未验证 提交 1bf73072 编写于 作者: B Bubbliiiing 提交者: GitHub

Update dataloader.py

上级 bef8bde8
......@@ -14,13 +14,13 @@ from torch.utils.data.dataset import Dataset
MEANS = (104, 117, 123)
class SSDDataset(Dataset):
def __init__(self, train_lines, image_size, is_val):
def __init__(self, train_lines, image_size, is_train):
super(SSDDataset, self).__init__()
self.train_lines = train_lines
self.train_batches = len(train_lines)
self.image_size = image_size
self.is_val = is_val
self.is_train = is_train
def __len__(self):
return self.train_batches
......@@ -127,10 +127,10 @@ class SSDDataset(Dataset):
def __getitem__(self, index):
lines = self.train_lines
if self.is_val:
img, y = self.get_random_data(lines[index], self.image_size[0:2], random=False)
else:
if self.is_train:
img, y = self.get_random_data(lines[index], self.image_size[0:2])
else:
img, y = self.get_random_data(lines[index], self.image_size[0:2], random=False)
boxes = np.array(y[:,:4],dtype=np.float32)
boxes[:,0] = boxes[:,0]/self.image_size[1]
......@@ -138,7 +138,7 @@ class SSDDataset(Dataset):
boxes[:,2] = boxes[:,2]/self.image_size[1]
boxes[:,3] = boxes[:,3]/self.image_size[0]
boxes = np.maximum(np.minimum(boxes,1),0)
y = np.concatenate([boxes, y[:,-1:]],axis=-1)
img = np.array(img, dtype = np.float32)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册