提交 b14bb89c 编写于 作者: W wangxiao1021

add evaluate-slot.py

上级 ad7d3792
# -*- coding: utf-8 -*-
import json
def load_label_map(map_dir="./data/atis/atis_slot/label_map.json"):
"""
:param map_dir: dict indictuing chunk type
:return:
"""
return json.load(open(map_dir, "r"))
def cal_chunk(total_res, total_label):
assert len(total_label) == len(total_res), "prediction result doesn't match to labels, {}, {}".format(len(total_res),len(total_label))
num_labels = 0
num_corr = 0
num_infers = 0
for res, label in zip(total_res, total_label):
assert len(res) == len(label), "prediction result doesn't match to labels, {}, {}".format(len(res),len(label))
num_labels += sum([0 if i == 6 else 1 for i in label])
num_corr += sum([1 if label[i] == res[i] and label[i] != 6 else 0 for i in range(len(label))])
num_infers += sum([0 if i == 6 else 1 for i in res])
precision = num_corr * 1.0 / num_infers if num_infers > 0 else 0.0
recall = num_corr * 1.0 / num_labels if num_labels > 0 else 0.0
f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
return precision, recall, f1
def res_evaluate(res_dir="./outputs/predict-slot/predictions.json", data_dir="./data/atis/atis_slot/test.tsv"):
label_map = load_label_map()
total_label = []
with open(data_dir, "r") as file:
first_flag = True
for line in file:
if first_flag:
first_flag = False
continue
line = line.strip("\n")
if len(line) == 0:
continue
line = line.split("\t")
if len(line) < 2:
continue
labels = line[1][:-1].split("\x02")
total_label.append(labels)
total_label = [[label_map[j] for j in i] for i in total_label]
total_res = []
with open(res_dir, "r") as file:
cnt = 0
for line in file:
line = line.strip("\n")
if len(line) == 0:
continue
try:
res_arr = json.loads(line)
if len(total_label[cnt]) < len(res_arr):
total_res.append(res_arr[1: 1 + len(total_label[cnt])])
elif len(total_label[cnt]) == len(res_arr):
total_res.append(res_arr)
else:
total_res.append(res_arr)
total_label[cnt] = total_label[cnt][: len(res_arr)]
except:
print("json format error: {}".format(cnt))
print(line)
cnt += 1
precision, recall, f1 = cal_chunk(total_res, total_label)
print("precision: {}, recall: {}, f1: {}".format(precision, recall, f1))
res_evaluate()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册