diff --git a/paddlepalm/head/cls.py b/paddlepalm/head/cls.py index 499c0f77d82f36872b745f94b94bd7ef89bf1727..66117ac8810b9844f9ee2f73972b1090aa470122 100644 --- a/paddlepalm/head/cls.py +++ b/paddlepalm/head/cls.py @@ -98,7 +98,7 @@ class Classify(Head): raise ValueError('argument output_dir not found in config. Please add it into config dict/file.') with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: for i in range(len(self._preds)): - label = 0 if self._preds[i][0] > self._preds[i][1] else 1 + label = np.argmax(np.array(self._preds[i])) result = {'index': i, 'label': label, 'logits': self._preds[i], 'probs': self._probs[i]} result = json.dumps(result) writer.write(result+'\n') diff --git a/paddlepalm/head/match.py b/paddlepalm/head/match.py index ec6bf40d9fa878d6b304d345026371a656dd935b..9df4a1a1d8c532db94eaf3484ac88581184e07b8 100644 --- a/paddlepalm/head/match.py +++ b/paddlepalm/head/match.py @@ -179,11 +179,10 @@ class Match(Head): with open(os.path.join(output_dir, 'predictions.json'), 'w') as writer: for i in range(len(self._preds)): if self._learning_strategy == 'pointwise': - label = 0 if self._preds[i][0] > self._preds[i][1] else 1 + label = np.argmax(np.array(self._preds[i])) result = {'index': i, 'label': label, 'logits': self._preds_logits[i], 'probs': self._preds[i]} elif self._learning_strategy == 'pairwise': - label = 0 if self._preds[i][0] < 0.5 else 1 - result = {'index': i, 'label': label, 'probs': self._preds[i][0]} + result = {'index': i, 'probs': self._preds[i][0]} result = json.dumps(result, ensure_ascii=False) writer.write(result+'\n') print('Predictions saved at '+os.path.join(output_dir, 'predictions.json'))