提交 2cbf686a 编写于 作者: M malin10

bug fix

上级 2949a20f
...@@ -23,11 +23,13 @@ class Metric(object): ...@@ -23,11 +23,13 @@ class Metric(object):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
def __init__(self, config): def __init__(self, config):
""" """ """R
"""
pass pass
def clear(self, scope=None): def clear(self, scope=None):
""" """ """R
"""
if scope is None: if scope is None:
scope = fluid.global_scope() scope = fluid.global_scope()
...@@ -41,9 +43,13 @@ class Metric(object): ...@@ -41,9 +43,13 @@ class Metric(object):
data_array = np.zeros(var._get_dims()).astype(dtype) data_array = np.zeros(var._get_dims()).astype(dtype)
var.set(data_array, place) var.set(data_array, place)
def get_global_metric_state(self, fleet, scope, metric_name, mode="sum"): def _get_global_metric_state(self, fleet, scope, metric_name, mode="sum"):
""" """ """R
input = np.array(scope.find_var(metric_name).get_tensor()) """
var = scope.find_var(metric_name)
if not var:
return None
input = np.array(var.get_tensor())
if fleet is None: if fleet is None:
return input return input
fleet._role_maker._barrier_worker() fleet._role_maker._barrier_worker()
...@@ -54,8 +60,9 @@ class Metric(object): ...@@ -54,8 +60,9 @@ class Metric(object):
output = output.reshape(old_shape) output = output.reshape(old_shape)
return output return output
def cal_global_metrics(self, fleet, scope=None): def calc_global_metrics(self, fleet, scope=None):
""" """ """R
"""
if scope is None: if scope is None:
scope = fluid.global_scope() scope = fluid.global_scope()
...@@ -65,9 +72,9 @@ class Metric(object): ...@@ -65,9 +72,9 @@ class Metric(object):
global_metrics[key] = self.get_global_metric_state(fleet, scope, global_metrics[key] = self.get_global_metric_state(fleet, scope,
varname) varname)
return self.calculate(global_metrics) return self._calculate(global_metrics)
def calculate(self, global_metrics): def _calculate(self, global_metrics):
pass pass
@abc.abstractmethod @abc.abstractmethod
......
...@@ -74,7 +74,7 @@ class AUC(Metric): ...@@ -74,7 +74,7 @@ class AUC(Metric):
self.metrics["AUC"] = auc_out self.metrics["AUC"] = auc_out
self.metrics["BATCH_AUC"] = batch_auc_out self.metrics["BATCH_AUC"] = batch_auc_out
def calculate_bucket_error(self, global_pos, global_neg): def _calculate_bucket_error(self, global_pos, global_neg):
"""R """R
""" """
num_bucket = len(global_pos) num_bucket = len(global_pos)
...@@ -122,7 +122,7 @@ class AUC(Metric): ...@@ -122,7 +122,7 @@ class AUC(Metric):
bucket_error = error_sum / error_count if error_count > 0 else 0.0 bucket_error = error_sum / error_count if error_count > 0 else 0.0
return bucket_error return bucket_error
def calculate_auc(self, global_pos, global_neg): def _calculate_auc(self, global_pos, global_neg):
"""R """R
""" """
num_bucket = len(global_pos) num_bucket = len(global_pos)
...@@ -148,7 +148,7 @@ class AUC(Metric): ...@@ -148,7 +148,7 @@ class AUC(Metric):
auc_value = area / (pos * neg) auc_value = area / (pos * neg)
return auc_value return auc_value
def calculate(self, global_metrics): def _calculate(self, global_metrics):
result = dict() result = dict()
for key in self._global_metric_state_vars: for key in self._global_metric_state_vars:
if key not in global_metrics: if key not in global_metrics:
...@@ -165,10 +165,10 @@ class AUC(Metric): ...@@ -165,10 +165,10 @@ class AUC(Metric):
result['copc'] = 0 result['copc'] = 0
result['mean_q'] = 0 result['mean_q'] = 0
else: else:
result['auc'] = self.calculate_auc(result['stat_pos'], result['auc'] = self._calculate_auc(result['stat_pos'],
result['stat_neg']) result['stat_neg'])
result['bucket_error'] = self.calculate_auc(result['stat_pos'], result['bucket_error'] = self._calculate_bucket_error(
result['stat_neg']) result['stat_pos'], result['stat_neg'])
result['actual_ctr'] = result['pos_ins_num'] / result[ result['actual_ctr'] = result['pos_ins_num'] / result[
'total_ins_num'] 'total_ins_num']
result['mae'] = result['abserr'] / result['total_ins_num'] result['mae'] = result['abserr'] / result['total_ins_num']
......
...@@ -29,7 +29,8 @@ class PrecisionRecall(Metric): ...@@ -29,7 +29,8 @@ class PrecisionRecall(Metric):
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """ """R
"""
if "input" not in kwargs or "label" not in kwargs or "class_num" not in kwargs: if "input" not in kwargs or "label" not in kwargs or "class_num" not in kwargs:
raise ValueError( raise ValueError(
"PrecisionRecall expect input, label and class_num as inputs.") "PrecisionRecall expect input, label and class_num as inputs.")
...@@ -107,9 +108,7 @@ class PrecisionRecall(Metric): ...@@ -107,9 +108,7 @@ class PrecisionRecall(Metric):
self.metrics["precision_recall_f1"] = accum_metrics self.metrics["precision_recall_f1"] = accum_metrics
self.metrics["[TP FP TN FN]"] = states_info self.metrics["[TP FP TN FN]"] = states_info
# self.metrics["batch_metrics"] = batch_metrics def _calculate(self, global_metrics):
def calculate(self, global_metrics):
for key in self._global_metric_state_vars: for key in self._global_metric_state_vars:
if key not in global_metrics: if key not in global_metrics:
raise ValueError("%s not existed" % key) raise ValueError("%s not existed" % key)
......
...@@ -84,7 +84,7 @@ class PosNegRatio(Metric): ...@@ -84,7 +84,7 @@ class PosNegRatio(Metric):
self.metrics['RightCnt'] = global_right_cnt self.metrics['RightCnt'] = global_right_cnt
self.metrics['PN'] = self.pn self.metrics['PN'] = self.pn
def calculate(self, global_metrics): def _calculate(self, global_metrics):
for key in self._global_communicate_var: for key in self._global_communicate_var:
if key not in global_metrics: if key not in global_metrics:
raise ValueError("%s not existed" % key) raise ValueError("%s not existed" % key)
......
...@@ -88,7 +88,7 @@ class RecallK(Metric): ...@@ -88,7 +88,7 @@ class RecallK(Metric):
self.metrics[metric_name] = self.acc self.metrics[metric_name] = self.acc
# self.metrics["batch_metrics"] = batch_metrics # self.metrics["batch_metrics"] = batch_metrics
def calculate(self, global_metrics): def _calculate(self, global_metrics):
for key in self._global_metric_state_vars: for key in self._global_metric_state_vars:
if key not in global_metrics: if key not in global_metrics:
raise ValueError("%s not existed" % key) raise ValueError("%s not existed" % key)
......
...@@ -356,7 +356,7 @@ class SingleRunner(RunnerBase): ...@@ -356,7 +356,7 @@ class SingleRunner(RunnerBase):
metrics_result = [] metrics_result = []
for key in metrics: for key in metrics:
if isinstance(metrics[key], Metric): if isinstance(metrics[key], Metric):
_str = metrics[key].cal_global_metrics( _str = metrics[key].calc_global_metrics(
None, None,
context["model"][model_dict["name"]]["scope"]) context["model"][model_dict["name"]]["scope"])
metrics_result.append(_str) metrics_result.append(_str)
...@@ -404,7 +404,7 @@ class PSRunner(RunnerBase): ...@@ -404,7 +404,7 @@ class PSRunner(RunnerBase):
metrics_result = [] metrics_result = []
for key in metrics: for key in metrics:
if isinstance(metrics[key], Metric): if isinstance(metrics[key], Metric):
_str = metrics[key].cal_global_metrics( _str = metrics[key].calc_global_metrics(
context["fleet"], context["fleet"],
context["model"][model_dict["name"]]["scope"]) context["model"][model_dict["name"]]["scope"])
metrics_result.append(_str) metrics_result.append(_str)
...@@ -536,7 +536,7 @@ class SingleInferRunner(RunnerBase): ...@@ -536,7 +536,7 @@ class SingleInferRunner(RunnerBase):
metrics_result = [] metrics_result = []
for key in metrics: for key in metrics:
if isinstance(metrics[key], Metric): if isinstance(metrics[key], Metric):
_str = metrics[key].cal_global_metrics( _str = metrics[key].calc_global_metrics(
None, None,
context["model"][model_dict["name"]]["scope"]) context["model"][model_dict["name"]]["scope"])
metrics_result.append(_str) metrics_result.append(_str)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册