diff --git a/nets/ssd_training.py b/nets/ssd_training.py index e56d81449c8866cc2846730bb61ad2ea92478024..b019d0814aab62b7ccc1d88efb3052e612ca0d80 100644 --- a/nets/ssd_training.py +++ b/nets/ssd_training.py @@ -283,7 +283,8 @@ class Generator(object): inputs = [] targets = [] yield tmp_inp, tmp_targets - + + class LossHistory(): def __init__(self, log_dir): import datetime @@ -315,9 +316,10 @@ class LossHistory(): plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss') plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss') try: - num = len(self.losses) /3 - num = num if num % 2 else num + 1 - num = max(num, 5) + if len(self.losses) < 25: + num = 5 + else: + num = 15 plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss') plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')