From d4b3142e642e62480696974dfa907dad55abb615 Mon Sep 17 00:00:00 2001 From: Bubbliiiing <47347516+bubbliiiing@users.noreply.github.com> Date: Wed, 30 Dec 2020 16:44:06 +0800 Subject: [PATCH] Add files via upload --- get_gt_txt.py | 25 +++++++++++++++++++++++++ get_map.py | 50 ++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 61 insertions(+), 14 deletions(-) diff --git a/get_gt_txt.py b/get_gt_txt.py index 4e62ce3..06046ba 100644 --- a/get_gt_txt.py +++ b/get_gt_txt.py @@ -8,6 +8,20 @@ import os import glob import xml.etree.ElementTree as ET +''' +!!!!!!!!!!!!!注意事项!!!!!!!!!!!!! +# 这一部分是当xml有无关的类的时候,下方有代码可以进行筛选! +''' +#---------------------------------------------------# +# 获得类 +#---------------------------------------------------# +def get_classes(classes_path): + '''loads the classes''' + with open(classes_path) as f: + class_names = f.readlines() + class_names = [c.strip() for c in class_names] + return class_names + image_ids = open('VOCdevkit/VOC2007/ImageSets/Main/test.txt').read().strip().split() if not os.path.exists("./input"): @@ -25,11 +39,22 @@ for image_id in image_ids: if int(difficult)==1: difficult_flag = True obj_name = obj.find('name').text + ''' + !!!!!!!!!!!!注意事项!!!!!!!!!!!! + # 这一部分是当xml有无关的类的时候,可以取消下面代码的注释 + # 利用对应的classes.txt来进行筛选!!!!!!!!!!!! + ''' + # classes_path = 'model_data/voc_classes.txt' + # class_names = get_classes(classes_path) + # if obj_name not in class_names: + # continue + bndbox = obj.find('bndbox') left = bndbox.find('xmin').text top = bndbox.find('ymin').text right = bndbox.find('xmax').text bottom = bndbox.find('ymax').text + if difficult_flag: new_f.write("%s %s %s %s %s difficult\n" % (obj_name, left, top, right, bottom)) else: diff --git a/get_map.py b/get_map.py index 632e918..786bb29 100644 --- a/get_map.py +++ b/get_map.py @@ -389,13 +389,28 @@ for txt_file in ground_truth_files_list: is_difficult = True else: class_name, left, top, right, bottom = line.split() - except ValueError: - error_msg = "Error: File " + txt_file + " in the wrong format.\n" - error_msg += " Expected: ['difficult']\n" - error_msg += " Received: " + line - error_msg += "\n\nIf you have a with spaces between words you should remove them\n" - error_msg += "by running the script \"remove_space.py\" or \"rename_class.py\" in the \"extra/\" folder." - error(error_msg) + + except: + if "difficult" in line: + line_split = line.split() + _difficult = line_split[-1] + bottom = line_split[-2] + right = line_split[-3] + top = line_split[-4] + left = line_split[-5] + class_name = "" + for name in line_split[:-5]: + class_name += name + is_difficult = True + else: + line_split = line.split() + bottom = line_split[-1] + right = line_split[-2] + top = line_split[-3] + left = line_split[-4] + class_name = "" + for name in line_split[:-4]: + class_name += name # check if class is in the ignore list, if yes skip if class_name in args.ignore: continue @@ -481,11 +496,17 @@ for class_index, class_name in enumerate(gt_classes): for line in lines: try: tmp_class_name, confidence, left, top, right, bottom = line.split() - except ValueError: - error_msg = "Error: File " + txt_file + " in the wrong format.\n" - error_msg += " Expected: \n" - error_msg += " Received: " + line - error(error_msg) + except: + line_split = line.split() + bottom = line_split[-1] + right = line_split[-2] + top = line_split[-3] + left = line_split[-4] + confidence = line_split[-5] + tmp_class_name = "" + for name in line_split[:-5]: + tmp_class_name += name + if tmp_class_name == class_name: #print("match") bbox = left + " " + top + " " + right + " " +bottom @@ -702,8 +723,9 @@ with open(results_files_path + "/results.txt", 'w') as results_file: rounded_rec = [ '%.2f' % elem for elem in rec ] results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n") if not args.quiet: - print(text + "\t||\tscore_threhold=0.5 : " + "F1=" + "{0:.2f}".format(F1[score05_idx])\ - + " ; Recall=" + "{0:.2f}%".format(rec[score05_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score05_idx]*100)) + if(len(rec)!=0): + print(text + "\t||\tscore_threhold=0.5 : " + "F1=" + "{0:.2f}".format(F1[score05_idx])\ + + " ; Recall=" + "{0:.2f}%".format(rec[score05_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score05_idx]*100)) ap_dictionary[class_name] = ap n_images = counter_images_per_class[class_name] -- GitLab