recall_k.py 3.9 KB
Newer Older
M
malin10 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import numpy as np
import paddle.fluid as fluid

from paddlerec.core.metric import Metric
M
bug fix  
malin10 已提交
21
from paddle.fluid.layers import accuracy
M
malin10 已提交
22 23
from paddle.fluid.initializer import Constant
from paddle.fluid.layer_helper import LayerHelper
M
malin10 已提交
24
from paddle.fluid.layers.tensor import Variable
M
malin10 已提交
25 26 27 28 29 30 31 32 33


class RecallK(Metric):
    """Streaming Recall@K metric for a Fluid (static-graph) model.

    Per batch, ``fluid.layers.accuracy`` is evaluated at cut-off ``k`` and
    converted into a hit count. Instance and hit counts are accumulated into
    two persistable global variables across batches, and the running recall
    is exposed as ``pos_cnt / ins_cnt``.
    """

    def __init__(self, **kwargs):
        """Append the Recall@K accumulation ops to the current program.

        Keyword Args:
            input (Variable): model predictions passed to
                ``fluid.layers.accuracy``.
            label (Variable): ground-truth labels.
            k (int): cut-off for Recall@K. Defaults to 20.

        Raises:
            ValueError: if ``input`` or ``label`` is missing, or if either
                is not a ``Variable``.
        """
        if "input" not in kwargs or "label" not in kwargs:
            raise ValueError("RecallK expect input and label as inputs.")
        # Presence was checked just above, so index directly.
        predict = kwargs["input"]
        label = kwargs["label"]
        self.k = kwargs.get("k", 20)

        if not isinstance(predict, Variable):
            raise ValueError("input must be Variable, but received %s" %
                             type(predict))
        if not isinstance(label, Variable):
            raise ValueError("label must be Variable, but received %s" %
                             type(label))

        helper = LayerHelper("PaddleRec_RecallK", **kwargs)
        # Top-k accuracy of this batch (a ratio), as computed by
        # fluid.layers.accuracy with k=self.k.
        batch_accuracy = accuracy(predict, label, self.k)

        # NOTE(review): these names are fixed, so a second RecallK in the same
        # program presumably gets the same counters back from
        # create_or_get_global_variable and both instances would accumulate
        # into them. Confirm before building more than one RecallK.
        global_ins_cnt, _ = helper.create_or_get_global_variable(
            name="ins_cnt", persistable=True, dtype='float32', shape=[1])
        global_pos_cnt, _ = helper.create_or_get_global_variable(
            name="pos_cnt", persistable=True, dtype='float32', shape=[1])

        # Both counters start at 0.0 and are pinned to CPU.
        for var in [global_ins_cnt, global_pos_cnt]:
            helper.set_variable_initializer(
                var, Constant(
                    value=0.0, force_cpu=True))

        # Batch instance count is the sum of a ones tensor shaped like
        # `label`, i.e. one per label element (presumably one per instance).
        tmp_ones = fluid.layers.fill_constant(
            shape=fluid.layers.shape(label), dtype="float32", value=1.0)
        batch_ins = fluid.layers.reduce_sum(tmp_ones)
        # Turn the batch accuracy ratio back into a number of hits.
        batch_pos = batch_ins * batch_accuracy

        # Accumulate in place (Out is the same variable as X) so the totals
        # carry over from batch to batch.
        helper.append_op(
            type="elementwise_add",
            inputs={"X": [global_ins_cnt],
                    "Y": [batch_ins]},
            outputs={"Out": [global_ins_cnt]})
        helper.append_op(
            type="elementwise_add",
            inputs={"X": [global_pos_cnt],
                    "Y": [batch_pos]},
            outputs={"Out": [global_pos_cnt]})

        # Running Recall@K over everything accumulated so far.
        self.acc = global_pos_cnt / global_ins_cnt

        # Maps each state key to (variable name, dtype). _calculate requires
        # a `global_metrics` entry for every key here; presumably the Metric
        # base class fills those from these variable names (verify).
        self._global_metric_state_vars = dict()
        self._global_metric_state_vars['ins_cnt'] = (global_ins_cnt.name,
                                                     "float32")
        self._global_metric_state_vars['pos_cnt'] = (global_pos_cnt.name,
                                                     "float32")

        metric_name = "Acc(Recall@%d)" % self.k
        self.metrics = dict()
        self.metrics["InsCnt"] = global_ins_cnt
        self.metrics["RecallCnt"] = global_pos_cnt
        self.metrics[metric_name] = self.acc

    def _calculate(self, global_metrics):
        """Format the accumulated counters as a one-line summary.

        Args:
            global_metrics (dict): must contain every key of
                ``self._global_metric_state_vars`` ('ins_cnt', 'pos_cnt').
                Each value is indexable, and its first element is the
                counter value.

        Returns:
            str: ``"InsCnt=<n> RecallCnt=<m> Acc(Recall@<k>)=<acc>"``.
            ``acc`` is 0 when no instances have been counted.

        Raises:
            ValueError: if a required counter is missing.
        """
        for key in self._global_metric_state_vars:
            if key not in global_metrics:
                raise ValueError("%s not existed" % key)
        ins_cnt = global_metrics['ins_cnt'][0]
        pos_cnt = global_metrics['pos_cnt'][0]
        # Avoid dividing by zero before any instance has been seen.
        if ins_cnt == 0:
            acc = 0
        else:
            acc = float(pos_cnt) / ins_cnt
        return "InsCnt=%s RecallCnt=%s Acc(Recall@%d)=%s" % (
            str(ins_cnt), str(pos_cnt), self.k, str(acc))

    def get_result(self):
        """Return the dict of metric name -> Variable built in __init__."""
        return self.metrics