未验证 提交 88687b57 编写于 作者: B Bubbliiiing 提交者: GitHub

Update get_map.py

上级 48878af0
import argparse
import glob
import json
import math
import operator
import os
import shutil
import operator
import sys
import argparse
import math
import numpy as np
#----------------------------------------------------#
# 用于计算mAP
# 代码克隆自https://github.com/Cartucho/mAP
#----------------------------------------------------#
MINOVERLAP = 0.5 # default value (defined in the PASCAL VOC2012 challenge)
'''
用于计算mAP
代码克隆自https://github.com/Cartucho/mAP
如果想要设定mAP0.x,比如计算mAP0.75,可以设定MINOVERLAP = 0.75。
'''
MINOVERLAP = 0.5
parser = argparse.ArgumentParser()
parser.add_argument('-na', '--no-animation', help="no animation is shown.", action="store_true")
parser.add_argument('-np', '--no-plot', help="no plot is shown.", action="store_true")
parser.add_argument('-q', '--quiet', help="minimalistic console output.", action="store_true")
# argparse receiving list of classes to be ignored
parser.add_argument('-i', '--ignore', nargs='+', type=str, help="ignore a list of classes.")
# argparse receiving list of classes with specific IoU (e.g., python main.py --set-class-iou person 0.7)
parser.add_argument('--set-class-iou', nargs='+', type=str, help="set IoU for a specific class.")
args = parser.parse_args()
......@@ -37,7 +36,6 @@ args = parser.parse_args()
(Right,Bottom)
'''
# if there are no classes to ignore then replace None by empty list
if args.ignore is None:
args.ignore = []
......@@ -45,22 +43,18 @@ specific_iou_flagged = False
if args.set_class_iou is not None:
specific_iou_flagged = True
# make sure that the cwd() is the location of the python script (so that every path makes sense)
os.chdir(os.path.dirname(os.path.abspath(__file__)))
GT_PATH = os.path.join(os.getcwd(), 'input', 'ground-truth')
DR_PATH = os.path.join(os.getcwd(), 'input', 'detection-results')
# if there are no images then no animation can be shown
IMG_PATH = os.path.join(os.getcwd(), 'input', 'images-optional')
if os.path.exists(IMG_PATH):
for dirpath, dirnames, files in os.walk(IMG_PATH):
if not files:
# no image files found
args.no_animation = True
else:
args.no_animation = True
# try to import OpenCV if the user didn't choose the option --no-animation
show_animation = False
if not args.no_animation:
try:
......@@ -70,7 +64,6 @@ if not args.no_animation:
print("\"opencv-python\" not found, please install to visualize the results.")
args.no_animation = True
# try to import Matplotlib if the user didn't choose the option --no-plot
draw_plot = False
if not args.no_plot:
try:
......@@ -98,7 +91,6 @@ def log_average_miss_rate(precision, fp_cumsum, num_images):
Transactions on 34.4 (2012): 743 - 761.
"""
# if there were no detections of that class
if precision.size == 0:
lamr = 0
mr = 1
......@@ -111,14 +103,11 @@ def log_average_miss_rate(precision, fp_cumsum, num_images):
fppi_tmp = np.insert(fppi, 0, -1.0)
mr_tmp = np.insert(mr, 0, 1.0)
# Use 9 evenly spaced reference points in log-space
ref = np.logspace(-2.0, 0.0, num = 9)
for i, ref_i in enumerate(ref):
# np.where() will always find at least 1 index, since min(ref) = 0.01 and min(fppi_tmp) = -1.0
j = np.where(fppi_tmp <= ref_i)[-1][-1]
ref[i] = mr_tmp[j]
# log(0) is undefined, so we use the np.maximum(1e-10, ref)
lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref))))
return lamr, mr, fppi
......@@ -172,10 +161,6 @@ def voc_ap(rec, prec):
matlab: for i=numel(mpre)-1:-1:1
mpre(i)=max(mpre(i),mpre(i+1));
"""
# matlab indexes start in 1 but python in 0, so I have to do:
# range(start=(len(mpre) - 2), end=0, step=-1)
# also the python function range excludes the end, resulting in:
# range(start=(len(mpre) - 2), end=-1, step=-1)
for i in range(len(mpre)-2, -1, -1):
mpre[i] = max(mpre[i], mpre[i+1])
"""
......@@ -386,10 +371,10 @@ for txt_file in ground_truth_files_list:
for line in lines_list:
try:
if "difficult" in line:
class_name, left, top, right, bottom, _difficult = line.split()
is_difficult = True
class_name, left, top, right, bottom, _difficult = line.split()
is_difficult = True
else:
class_name, left, top, right, bottom = line.split()
class_name, left, top, right, bottom = line.split()
except:
if "difficult" in line:
......@@ -401,7 +386,8 @@ for txt_file in ground_truth_files_list:
left = line_split[-5]
class_name = ""
for name in line_split[:-5]:
class_name += name
class_name += name + " "
class_name = class_name[:-1]
is_difficult = True
else:
line_split = line.split()
......@@ -411,8 +397,8 @@ for txt_file in ground_truth_files_list:
left = line_split[-4]
class_name = ""
for name in line_split[:-4]:
class_name += name
# check if class is in the ignore list, if yes skip
class_name += name + " "
class_name = class_name[:-1]
if class_name in args.ignore:
continue
bbox = left + " " + top + " " + right + " " +bottom
......@@ -421,32 +407,25 @@ for txt_file in ground_truth_files_list:
is_difficult = False
else:
bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
# count that object
if class_name in gt_counter_per_class:
gt_counter_per_class[class_name] += 1
else:
# if class didn't exist yet
gt_counter_per_class[class_name] = 1
if class_name not in already_seen_classes:
if class_name in counter_images_per_class:
counter_images_per_class[class_name] += 1
else:
# if class didn't exist yet
counter_images_per_class[class_name] = 1
already_seen_classes.append(class_name)
# dump bounding_boxes into a ".json" file
with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile:
json.dump(bounding_boxes, outfile)
gt_classes = list(gt_counter_per_class.keys())
# let's sort the classes alphabetically
gt_classes = sorted(gt_classes)
n_classes = len(gt_classes)
#print(gt_classes)
#print(gt_counter_per_class)
"""
Check format of the flag --set-class-iou (if used)
......@@ -476,15 +455,12 @@ if specific_iou_flagged:
detection-results
Load each of the detection-results files into a temporary ".json" file.
"""
# get a list with the detection-results files
dr_files_list = glob.glob(DR_PATH + '/*.txt')
dr_files_list.sort()
for class_index, class_name in enumerate(gt_classes):
bounding_boxes = []
for txt_file in dr_files_list:
#print(txt_file)
# the first time it checks if all the corresponding ground-truth files exist
file_id = txt_file.split(".txt",1)[0]
file_id = os.path.basename(os.path.normpath(file_id))
temp_path = os.path.join(GT_PATH, (file_id + ".txt"))
......@@ -506,14 +482,13 @@ for class_index, class_name in enumerate(gt_classes):
confidence = line_split[-5]
tmp_class_name = ""
for name in line_split[:-5]:
tmp_class_name += name
tmp_class_name += name + " "
tmp_class_name = tmp_class_name[:-1]
if tmp_class_name == class_name:
#print("match")
bbox = left + " " + top + " " + right + " " +bottom
bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})
#print(bounding_boxes)
# sort detection-results by decreasing confidence
bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile:
json.dump(bounding_boxes, outfile)
......@@ -524,7 +499,6 @@ for class_index, class_name in enumerate(gt_classes):
sum_AP = 0.0
ap_dictionary = {}
lamr_dictionary = {}
# open file to store the results
with open(results_files_path + "/results.txt", 'w') as results_file:
results_file.write("# AP and precision/recall per class\n")
count_true_positives = {}
......@@ -536,12 +510,11 @@ with open(results_files_path + "/results.txt", 'w') as results_file:
"""
dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json"
dr_data = json.load(open(dr_file))
"""
Assign detection-results to ground-truth objects
"""
nd = len(dr_data)
tp = [0] * nd # creates an array of zeros of size nd
tp = [0] * nd
fp = [0] * nd
score = [0] * nd
score05_idx = 0
......@@ -552,37 +525,28 @@ with open(results_files_path + "/results.txt", 'w') as results_file:
score05_idx = idx
if show_animation:
# find ground truth image
ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
#tifCounter = len(glob.glob1(myPath,"*.tif"))
if len(ground_truth_img) == 0:
error("Error. Image not found with id: " + file_id)
elif len(ground_truth_img) > 1:
error("Error. Multiple image with id: " + file_id)
else: # found image
#print(IMG_PATH + "/" + ground_truth_img[0])
# Load image
else:
img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0])
# load image with draws of multiple detections
img_cumulative_path = results_files_path + "/images/" + ground_truth_img[0]
if os.path.isfile(img_cumulative_path):
img_cumulative = cv2.imread(img_cumulative_path)
else:
img_cumulative = img.copy()
# Add bottom border to image
bottom_border = 60
BLACK = [0, 0, 0]
img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)
# assign detection-results to ground truth object if any
# open ground-truth with that file_id
gt_file = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
ground_truth_data = json.load(open(gt_file))
ovmax = -1
gt_match = -1
# load detected object bounding-box
bb = [ float(x) for x in detection["bbox"].split() ]
for obj in ground_truth_data:
# look for a class_name match
if obj["class_name"] == class_name:
bbgt = [ float(x) for x in obj["bbox"].split() ]
bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])]
......@@ -597,10 +561,8 @@ with open(results_files_path + "/results.txt", 'w') as results_file:
ovmax = ov
gt_match = obj
# assign detection as true positive/don't care/false positive
if show_animation:
status = "NO MATCH FOUND!" # status is only used in the animation
# set minimum overlap
status = "NO MATCH FOUND!"
min_overlap = MINOVERLAP
if specific_iou_flagged:
if class_name in specific_iou_classes:
......@@ -608,23 +570,19 @@ with open(results_files_path + "/results.txt", 'w') as results_file:
min_overlap = float(iou_list[index])
if ovmax >= min_overlap:
if "difficult" not in gt_match:
if not bool(gt_match["used"]):
# true positive
tp[idx] = 1
gt_match["used"] = True
count_true_positives[class_name] += 1
# update the ".json" file
with open(gt_file, 'w') as f:
f.write(json.dumps(ground_truth_data))
if show_animation:
status = "MATCH!"
else:
# false positive (multiple detection)
fp[idx] = 1
if show_animation:
status = "REPEATED MATCH!"
if not bool(gt_match["used"]):
tp[idx] = 1
gt_match["used"] = True
count_true_positives[class_name] += 1
with open(gt_file, 'w') as f:
f.write(json.dumps(ground_truth_data))
if show_animation:
status = "MATCH!"
else:
fp[idx] = 1
if show_animation:
status = "REPEATED MATCH!"
else:
# false positive
fp[idx] = 1
if ovmax > 0:
status = "INSUFFICIENT OVERLAP"
......@@ -684,27 +642,26 @@ with open(results_files_path + "/results.txt", 'w') as results_file:
# save the image with all the objects drawn to it
cv2.imwrite(img_cumulative_path, img_cumulative)
# compute precision/recall
cumsum = 0
for idx, val in enumerate(fp):
fp[idx] += cumsum
cumsum += val
cumsum = 0
for idx, val in enumerate(tp):
tp[idx] += cumsum
cumsum += val
#print(tp)
rec = tp[:]
rec = tp[:]
for idx, val in enumerate(tp):
rec[idx] = float(tp[idx]) / gt_counter_per_class[class_name]
#print(rec)
rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1)
prec = tp[:]
for idx, val in enumerate(tp):
prec[idx] = float(tp[idx]) / (fp[idx] + tp[idx])
#print(prec)
prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1)
ap, mrec, mprec = voc_ap(rec[:], prec[:])
F1 = np.array(rec)*np.array(prec)/(np.array(prec)+np.array(rec))*2
F1 = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec)))
sum_AP += ap
text = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP " #class_name + " AP = {0:.2f}%".format(ap*100)
......@@ -717,16 +674,16 @@ with open(results_files_path + "/results.txt", 'w') as results_file:
F1_text = "0.00" + " = " + class_name + " F1 "
Recall_text = "0.00%" + " = " + class_name + " Recall "
Precision_text = "0.00%" + " = " + class_name + " Precision "
"""
Write to results.txt
"""
rounded_prec = [ '%.2f' % elem for elem in prec ]
rounded_rec = [ '%.2f' % elem for elem in rec ]
results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
if not args.quiet:
if(len(rec)!=0):
if len(prec)>0:
print(text + "\t||\tscore_threhold=0.5 : " + "F1=" + "{0:.2f}".format(F1[score05_idx])\
+ " ; Recall=" + "{0:.2f}%".format(rec[score05_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score05_idx]*100))
else:
print(text + "\t||\tscore_threhold=0.5 : F1=0.00% ; Recall=0.00% ; Precision=0.00%")
ap_dictionary[class_name] = ap
n_images = counter_images_per_class[class_name]
......@@ -738,63 +695,52 @@ with open(results_files_path + "/results.txt", 'w') as results_file:
"""
if draw_plot:
plt.plot(rec, prec, '-o')
# add a new penultimate point to the list (mrec[-2], 0.0)
# since the last line segment (and respective area) do not affect the AP value
area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
# set window title
fig = plt.gcf() # gcf - get current figure
fig = plt.gcf()
fig.canvas.set_window_title('AP ' + class_name)
# set plot title
plt.title('class: ' + text)
#plt.suptitle('This is a somewhat long figure title', fontsize=16)
# set axis titles
plt.xlabel('Recall')
plt.ylabel('Precision')
# optional - set axes
axes = plt.gca() # gca - get current axes
axes = plt.gca()
axes.set_xlim([0.0,1.0])
axes.set_ylim([0.0,1.05]) # .05 to give some extra space
# Alternative option -> wait for button to be pressed
# while not plt.waitforbuttonpress(): pass # wait for key display
# Alternative option -> normal display
# plt.show()
# save the plot
axes.set_ylim([0.0,1.05])
fig.savefig(results_files_path + "/AP/" + class_name + ".png")
plt.cla() # clear axes for next plot
plt.cla()
plt.plot(score, F1, "-", color='orangered')
plt.title('class: ' + F1_text + "\nscore_threhold=0.5")
plt.xlabel('Score_Threhold')
plt.ylabel('F1')
axes = plt.gca() # gca - get current axes
axes = plt.gca()
axes.set_xlim([0.0,1.0])
axes.set_ylim([0.0,1.05]) # .05 to give some extra space
axes.set_ylim([0.0,1.05])
fig.savefig(results_files_path + "/F1/" + class_name + ".png")
plt.cla() # clear axes for next plot
plt.cla()
plt.plot(score, rec, "-H", color='gold')
plt.title('class: ' + Recall_text + "\nscore_threhold=0.5")
plt.xlabel('Score_Threhold')
plt.ylabel('Recall')
axes = plt.gca() # gca - get current axes
axes = plt.gca()
axes.set_xlim([0.0,1.0])
axes.set_ylim([0.0,1.05]) # .05 to give some extra space
axes.set_ylim([0.0,1.05])
fig.savefig(results_files_path + "/Recall/" + class_name + ".png")
plt.cla() # clear axes for next plot
plt.cla()
plt.plot(score, prec, "-s", color='palevioletred')
plt.title('class: ' + Precision_text + "\nscore_threhold=0.5")
plt.xlabel('Score_Threhold')
plt.ylabel('Precision')
axes = plt.gca() # gca - get current axes
axes = plt.gca()
axes.set_xlim([0.0,1.0])
axes.set_ylim([0.0,1.05]) # .05 to give some extra space
axes.set_ylim([0.0,1.05])
fig.savefig(results_files_path + "/Precision/" + class_name + ".png")
plt.cla() # clear axes for next plot
plt.cla()
if show_animation:
cv2.destroyAllWindows()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册