diff --git a/core/metric.py b/core/metric.py
index d621a06ead5fa85b0bb9d8f3e13a8f15aa6dafa5..ae91cd6de35e11ccbd2c7bd5f3d4b745c8723a4f 100755
--- a/core/metric.py
+++ b/core/metric.py
@@ -23,11 +23,13 @@ class Metric(object):
     __metaclass__ = abc.ABCMeta
 
     def __init__(self, config):
-        """ """
+        """R
+        """
         pass
 
     def clear(self, scope=None):
-        """ """
+        """R
+        """
         if scope is None:
             scope = fluid.global_scope()
 
@@ -41,9 +43,13 @@ class Metric(object):
             data_array = np.zeros(var._get_dims()).astype(dtype)
             var.set(data_array, place)
 
-    def get_global_metric_state(self, fleet, scope, metric_name, mode="sum"):
-        """ """
-        input = np.array(scope.find_var(metric_name).get_tensor())
+    def _get_global_metric_state(self, fleet, scope, metric_name, mode="sum"):
+        """R
+        """
+        var = scope.find_var(metric_name)
+        if not var:
+            return None
+        input = np.array(var.get_tensor())
         if fleet is None:
             return input
         fleet._role_maker._barrier_worker()
@@ -54,8 +60,9 @@ class Metric(object):
         output = output.reshape(old_shape)
         return output
 
-    def cal_global_metrics(self, fleet, scope=None):
-        """ """
+    def calc_global_metrics(self, fleet, scope=None):
+        """R
+        """
         if scope is None:
             scope = fluid.global_scope()
 
@@ -65,9 +72,9 @@ class Metric(object):
-            global_metrics[key] = self.get_global_metric_state(fleet, scope,
-                                                               varname)
+            global_metrics[key] = self._get_global_metric_state(fleet, scope,
+                                                                 varname)
 
-        return self.calculate(global_metrics)
+        return self._calculate(global_metrics)
 
-    def calculate(self, global_metrics):
+    def _calculate(self, global_metrics):
         pass
 
     @abc.abstractmethod
diff --git a/core/metrics/binary_class/auc.py b/core/metrics/binary_class/auc.py
index b847314636df53be8bdabcdd58961470b61148f5..129b8bc7eb0854f3a19b2ae0c2a101ccaf7d1d74 100755
--- a/core/metrics/binary_class/auc.py
+++ b/core/metrics/binary_class/auc.py
@@ -74,7 +74,7 @@ class AUC(Metric):
         self.metrics["AUC"] = auc_out
         self.metrics["BATCH_AUC"] = batch_auc_out
 
-    def calculate_bucket_error(self, global_pos, global_neg):
+    def _calculate_bucket_error(self, global_pos, global_neg):
         """R
         """
         num_bucket = len(global_pos)
@@ -122,7 +122,7 @@ class AUC(Metric):
         bucket_error = error_sum / error_count if error_count > 0 else 0.0
         return bucket_error
 
-    def calculate_auc(self, global_pos, global_neg):
+    def _calculate_auc(self, global_pos, global_neg):
         """R
         """
         num_bucket = len(global_pos)
@@ -148,7 +148,7 @@ class AUC(Metric):
         auc_value = area / (pos * neg)
         return auc_value
 
-    def calculate(self, global_metrics):
+    def _calculate(self, global_metrics):
         result = dict()
         for key in self._global_metric_state_vars:
             if key not in global_metrics:
@@ -165,10 +165,10 @@ class AUC(Metric):
             result['copc'] = 0
             result['mean_q'] = 0
         else:
-            result['auc'] = self.calculate_auc(result['stat_pos'],
-                                               result['stat_neg'])
-            result['bucket_error'] = self.calculate_auc(result['stat_pos'],
-                                                        result['stat_neg'])
+            result['auc'] = self._calculate_auc(result['stat_pos'],
+                                                result['stat_neg'])
+            result['bucket_error'] = self._calculate_bucket_error(
+                result['stat_pos'], result['stat_neg'])
             result['actual_ctr'] = result['pos_ins_num'] / result[
                 'total_ins_num']
             result['mae'] = result['abserr'] / result['total_ins_num']
diff --git a/core/metrics/binary_class/precision_recall.py b/core/metrics/binary_class/precision_recall.py
index 0eb80765232f7c318c49b926f61c74434c163d5d..a40b1e191b1cee7df9b4f457a0087ff3f58cce69 100755
--- a/core/metrics/binary_class/precision_recall.py
+++ b/core/metrics/binary_class/precision_recall.py
@@ -29,7 +29,8 @@ class PrecisionRecall(Metric):
     """
 
     def __init__(self, **kwargs):
-        """ """
+        """R
+        """
         if "input" not in kwargs or "label" not in kwargs or "class_num" not in kwargs:
             raise ValueError(
                 "PrecisionRecall expect input, label and class_num as inputs.")
@@ -107,9 +108,7 @@ class PrecisionRecall(Metric):
         self.metrics["precision_recall_f1"] = accum_metrics
         self.metrics["[TP FP TN FN]"] = states_info
 
-        # self.metrics["batch_metrics"] = batch_metrics
-
-    def calculate(self, global_metrics):
+    def _calculate(self, global_metrics):
         for key in self._global_metric_state_vars:
             if key not in global_metrics:
                 raise ValueError("%s not existed" % key)
diff --git a/core/metrics/pairwise_pn.py b/core/metrics/pairwise_pn.py
index 673ce79b9e8f5b7ec73fc440cdc4d959747cae26..156a86063efbe8380fa1314fa7613aa378f35302 100755
--- a/core/metrics/pairwise_pn.py
+++ b/core/metrics/pairwise_pn.py
@@ -84,7 +84,7 @@ class PosNegRatio(Metric):
         self.metrics['RightCnt'] = global_right_cnt
         self.metrics['PN'] = self.pn
 
-    def calculate(self, global_metrics):
+    def _calculate(self, global_metrics):
         for key in self._global_communicate_var:
             if key not in global_metrics:
                 raise ValueError("%s not existed" % key)
diff --git a/core/metrics/recall_k.py b/core/metrics/recall_k.py
index f570ef222f2cdb68cd5fc283539755ce97123905..27ade14503fe6d558c7f2345517bed831f57dccf 100755
--- a/core/metrics/recall_k.py
+++ b/core/metrics/recall_k.py
@@ -88,7 +88,7 @@ class RecallK(Metric):
         self.metrics[metric_name] = self.acc
         # self.metrics["batch_metrics"] = batch_metrics
 
-    def calculate(self, global_metrics):
+    def _calculate(self, global_metrics):
         for key in self._global_metric_state_vars:
             if key not in global_metrics:
                 raise ValueError("%s not existed" % key)
diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py
index 91164a65a51db847a53df8503e497fa14655508c..5da1f5df723d43dcefd327193e0b7e7ab5368366 100644
--- a/core/trainers/framework/runner.py
+++ b/core/trainers/framework/runner.py
@@ -356,7 +356,7 @@ class SingleRunner(RunnerBase):
             metrics_result = []
             for key in metrics:
                 if isinstance(metrics[key], Metric):
-                    _str = metrics[key].cal_global_metrics(
+                    _str = metrics[key].calc_global_metrics(
                         None, context["model"][model_dict["name"]]["scope"])
                     metrics_result.append(_str)
 
@@ -404,7 +404,7 @@ class PSRunner(RunnerBase):
             metrics_result = []
             for key in metrics:
                 if isinstance(metrics[key], Metric):
-                    _str = metrics[key].cal_global_metrics(
+                    _str = metrics[key].calc_global_metrics(
                         context["fleet"],
                         context["model"][model_dict["name"]]["scope"])
                     metrics_result.append(_str)
@@ -536,7 +536,7 @@ class SingleInferRunner(RunnerBase):
             metrics_result = []
             for key in metrics:
                 if isinstance(metrics[key], Metric):
-                    _str = metrics[key].cal_global_metrics(
+                    _str = metrics[key].calc_global_metrics(
                         None, context["model"][model_dict["name"]]["scope"])
                     metrics_result.append(_str)
 
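For reference, the renames above split the Metric interface into a public entry point and private hooks: runners call calc_global_metrics() (formerly cal_global_metrics), while subclasses override _calculate() and the base class handles _get_global_metric_state(). Below is a minimal sketch of a custom metric written against the renamed interface; the InsNum class, its state-variable name, and the paddlerec.core.metric import path are illustrative assumptions rather than part of this change, and any abstract methods on Metric not shown in this diff would still need to be implemented.

import numpy as np

from paddlerec.core.metric import Metric  # import path assumed from the repo layout


class InsNum(Metric):
    """Illustration only: reports a global instance count."""

    def __init__(self, **kwargs):
        # Same shape of state that AUC/PrecisionRecall register above:
        # metric key -> (persistable variable name, dtype).
        self._global_metric_state_vars = dict(ins_num=("ins_num", "int64"))
        self.metrics = dict()

    def _calculate(self, global_metrics):
        # Invoked by Metric.calc_global_metrics() with one numpy array per
        # registered state var, or None when the variable was not found.
        ins = global_metrics.get("ins_num")
        return "ins_num: {}".format(int(np.sum(ins)) if ins is not None else 0)


# Runner-side usage mirrors core/trainers/framework/runner.py:
#     if isinstance(metrics[key], Metric):
#         _str = metrics[key].calc_global_metrics(fleet, scope)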