提交 2cbf686a 编写于 作者: M malin10

bug fix

上级 2949a20f
......@@ -23,11 +23,13 @@ class Metric(object):
__metaclass__ = abc.ABCMeta
def __init__(self, config):
""" """
"""R
"""
pass
def clear(self, scope=None):
""" """
"""R
"""
if scope is None:
scope = fluid.global_scope()
......@@ -41,9 +43,13 @@ class Metric(object):
data_array = np.zeros(var._get_dims()).astype(dtype)
var.set(data_array, place)
def get_global_metric_state(self, fleet, scope, metric_name, mode="sum"):
""" """
input = np.array(scope.find_var(metric_name).get_tensor())
def _get_global_metric_state(self, fleet, scope, metric_name, mode="sum"):
"""R
"""
var = scope.find_var(metric_name)
if not var:
return None
input = np.array(var.get_tensor())
if fleet is None:
return input
fleet._role_maker._barrier_worker()
......@@ -54,8 +60,9 @@ class Metric(object):
output = output.reshape(old_shape)
return output
def cal_global_metrics(self, fleet, scope=None):
""" """
def calc_global_metrics(self, fleet, scope=None):
"""R
"""
if scope is None:
scope = fluid.global_scope()
......@@ -65,9 +72,9 @@ class Metric(object):
global_metrics[key] = self.get_global_metric_state(fleet, scope,
varname)
return self.calculate(global_metrics)
return self._calculate(global_metrics)
def calculate(self, global_metrics):
def _calculate(self, global_metrics):
pass
@abc.abstractmethod
......
......@@ -74,7 +74,7 @@ class AUC(Metric):
self.metrics["AUC"] = auc_out
self.metrics["BATCH_AUC"] = batch_auc_out
def calculate_bucket_error(self, global_pos, global_neg):
def _calculate_bucket_error(self, global_pos, global_neg):
"""R
"""
num_bucket = len(global_pos)
......@@ -122,7 +122,7 @@ class AUC(Metric):
bucket_error = error_sum / error_count if error_count > 0 else 0.0
return bucket_error
def calculate_auc(self, global_pos, global_neg):
def _calculate_auc(self, global_pos, global_neg):
"""R
"""
num_bucket = len(global_pos)
......@@ -148,7 +148,7 @@ class AUC(Metric):
auc_value = area / (pos * neg)
return auc_value
def calculate(self, global_metrics):
def _calculate(self, global_metrics):
result = dict()
for key in self._global_metric_state_vars:
if key not in global_metrics:
......@@ -165,10 +165,10 @@ class AUC(Metric):
result['copc'] = 0
result['mean_q'] = 0
else:
result['auc'] = self.calculate_auc(result['stat_pos'],
result['stat_neg'])
result['bucket_error'] = self.calculate_auc(result['stat_pos'],
result['stat_neg'])
result['auc'] = self._calculate_auc(result['stat_pos'],
result['stat_neg'])
result['bucket_error'] = self._calculate_bucket_error(
result['stat_pos'], result['stat_neg'])
result['actual_ctr'] = result['pos_ins_num'] / result[
'total_ins_num']
result['mae'] = result['abserr'] / result['total_ins_num']
......
......@@ -29,7 +29,8 @@ class PrecisionRecall(Metric):
"""
def __init__(self, **kwargs):
""" """
"""R
"""
if "input" not in kwargs or "label" not in kwargs or "class_num" not in kwargs:
raise ValueError(
"PrecisionRecall expect input, label and class_num as inputs.")
......@@ -107,9 +108,7 @@ class PrecisionRecall(Metric):
self.metrics["precision_recall_f1"] = accum_metrics
self.metrics["[TP FP TN FN]"] = states_info
# self.metrics["batch_metrics"] = batch_metrics
def calculate(self, global_metrics):
def _calculate(self, global_metrics):
for key in self._global_metric_state_vars:
if key not in global_metrics:
raise ValueError("%s not existed" % key)
......
......@@ -84,7 +84,7 @@ class PosNegRatio(Metric):
self.metrics['RightCnt'] = global_right_cnt
self.metrics['PN'] = self.pn
def calculate(self, global_metrics):
def _calculate(self, global_metrics):
for key in self._global_communicate_var:
if key not in global_metrics:
raise ValueError("%s not existed" % key)
......
......@@ -88,7 +88,7 @@ class RecallK(Metric):
self.metrics[metric_name] = self.acc
# self.metrics["batch_metrics"] = batch_metrics
def calculate(self, global_metrics):
def _calculate(self, global_metrics):
for key in self._global_metric_state_vars:
if key not in global_metrics:
raise ValueError("%s not existed" % key)
......
......@@ -356,7 +356,7 @@ class SingleRunner(RunnerBase):
metrics_result = []
for key in metrics:
if isinstance(metrics[key], Metric):
_str = metrics[key].cal_global_metrics(
_str = metrics[key].calc_global_metrics(
None,
context["model"][model_dict["name"]]["scope"])
metrics_result.append(_str)
......@@ -404,7 +404,7 @@ class PSRunner(RunnerBase):
metrics_result = []
for key in metrics:
if isinstance(metrics[key], Metric):
_str = metrics[key].cal_global_metrics(
_str = metrics[key].calc_global_metrics(
context["fleet"],
context["model"][model_dict["name"]]["scope"])
metrics_result.append(_str)
......@@ -536,7 +536,7 @@ class SingleInferRunner(RunnerBase):
metrics_result = []
for key in metrics:
if isinstance(metrics[key], Metric):
_str = metrics[key].cal_global_metrics(
_str = metrics[key].calc_global_metrics(
None,
context["model"][model_dict["name"]]["scope"])
metrics_result.append(_str)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册