提交 682bede3 编写于 作者: W wangxiao1021

fix bugs

上级 632de7af
......@@ -337,7 +337,7 @@ def _write_predictions(all_examples, all_features, all_results, n_best_size,
nbest_json = []
for (i, entry) in enumerate(nbest):
output = collections.OrderedDict()
output["text"] = entry.text
output["text"] = entry.text.encode('utf-8').decode('utf-8')
output["probability"] = probs[i]
output["start_logit"] = entry.start_logit
output["end_logit"] = entry.end_logit
......@@ -358,8 +358,11 @@ def _write_predictions(all_examples, all_features, all_results, n_best_size,
all_predictions[example.qas_id] = best_non_null_entry.text
all_nbest_json[example.qas_id] = nbest_json
with open(output_prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4, ensure_ascii=False) + "\n")
with open(output_nbest_file, "w") as writer:
......
......@@ -22,6 +22,7 @@ import paddle
from paddle import fluid
from paddle.fluid import layers
from paddlepalm.distribute import gpu_dev_count, cpu_dev_count
import six
dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count
......@@ -35,7 +36,8 @@ def create_feed_batch_process_fn(net_inputs):
inputs= net_inputs
for q, var in inputs.items():
if isinstance(var, str) or isinstance(var, unicode):
if isinstance(var, str) or (six.PY3 and isinstance(var, bytes)) or (six.PY2 and isinstance(var, unicode)):
temp[var] = data[q]
else:
temp[var.name] = data[q]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册