From b623f4b863cd624120b63a06801d6b7c5c5fae9d Mon Sep 17 00:00:00 2001 From: Guoxia Wang Date: Tue, 17 Jan 2023 15:28:23 +0800 Subject: [PATCH] fix json output (#5708) --- modelcenter/PLSC-ViT/APP/app.py | 8 ++++++-- modelcenter/PLSC-ViT/APP/predictor.py | 12 ++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/modelcenter/PLSC-ViT/APP/app.py b/modelcenter/PLSC-ViT/APP/app.py index 89e38d59..7dc1791b 100644 --- a/modelcenter/PLSC-ViT/APP/app.py +++ b/modelcenter/PLSC-ViT/APP/app.py @@ -15,8 +15,12 @@ def model_inference(image): model_path=model_path, params_path=params_path, label_path=label_path) - scores, labels = predictor.predict(image) - json_out = {"scores": scores.tolist(), "labels": labels.tolist()} + class_ids, scores, labels = predictor.predict(image) + json_out = { + "class_ids": class_ids.tolist(), + "scores": scores.tolist(), + "labels": labels.tolist() + } return image, json_out diff --git a/modelcenter/PLSC-ViT/APP/predictor.py b/modelcenter/PLSC-ViT/APP/predictor.py index 46fd81a7..6d3782f8 100644 --- a/modelcenter/PLSC-ViT/APP/predictor.py +++ b/modelcenter/PLSC-ViT/APP/predictor.py @@ -4,6 +4,7 @@ import numpy as np import paddle from download import get_model_path, get_data_path + class Predictor(object): def __init__(self, model_type="paddle", @@ -42,8 +43,8 @@ class Predictor(object): self.predictor.run() outputs = [] for output_idx in range(len(self.output_names)): - output_tensor = self.predictor.get_output_handle( - self.output_names[output_idx]) + output_tensor = self.predictor.get_output_handle(self.output_names[ + output_idx]) outputs.append(output_tensor.copy_to_cpu()) if self.postprocess is not None: output_data = self.postprocess(outputs) @@ -69,7 +70,7 @@ class Predictor(object): for line in f: if len(line) < 2: continue - label = line.strip().split(',')[0].split(' ')[2] + label = line.strip().split(',')[1].strip() labels.append(label) return labels @@ -83,6 +84,5 @@ class Predictor(object): pred = np.array(logits).squeeze() pred = self.softmax(pred) class_idx = pred.argsort()[::-1] - return pred[class_idx[:5]], np.array(self.labels)[class_idx[:5]] - - + return class_idx[:5], pred[class_idx[:5]], np.array(self.labels)[ + class_idx[:5]] -- GitLab