From 40f715989dfb9bb6ad940f28edeba8f0b62f1528 Mon Sep 17 00:00:00 2001 From: JiaQi Xu <47347516+bubbliiiing@users.noreply.github.com> Date: Tue, 26 May 2020 19:48:19 +0800 Subject: [PATCH] Update yolo_training.py --- nets/yolo_training.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/nets/yolo_training.py b/nets/yolo_training.py index 04678d7..eb83296 100644 --- a/nets/yolo_training.py +++ b/nets/yolo_training.py @@ -47,7 +47,7 @@ def box_ciou(b1, b2): b1_area = b1_wh[..., 0] * b1_wh[..., 1] b2_area = b2_wh[..., 0] * b2_wh[..., 1] union_area = b1_area + b2_area - intersect_area - iou = intersect_area / (union_area + 1e-6) + iou = intersect_area / torch.clamp(union_area,min = 1e-6) # 计算中心的差距 center_distance = torch.sum(torch.pow((b1_xy - b2_xy), 2), axis=-1) @@ -58,13 +58,13 @@ def box_ciou(b1, b2): enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes)) # 计算对角线距离 enclose_diagonal = torch.sum(torch.pow(enclose_wh,2), axis=-1) - ciou = iou - 1.0 * (center_distance) / (enclose_diagonal + 1e-7) + ciou = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal,min = 1e-6) - v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(b1_wh[..., 0]/b1_wh[..., 1]) - torch.atan(b2_wh[..., 0]/b2_wh[..., 1])), 2) - alpha = v / (1.0 - iou + v) + v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(b1_wh[..., 0]/torch.clamp(b1_wh[..., 1],min = 1e-6)) - torch.atan(b2_wh[..., 0]/torch.clamp(b2_wh[..., 1],min = 1e-6))), 2) + alpha = v / torch.clamp((1.0 - iou + v),min=1e-6) ciou = ciou - alpha * v return ciou - + def clip_by_tensor(t,t_min,t_max): t=t.float() result = (t >= t_min).float() * t + (t < t_min).float() * t_min @@ -504,4 +504,4 @@ class Generator(object): tmp_targets = np.array(targets) inputs = [] targets = [] - yield tmp_inp, tmp_targets \ No newline at end of file + yield tmp_inp, tmp_targets -- GitLab