Unverified commit 1e7e9023, authored by Yizhuang Zhou, committed by GitHub

feat(cls/shufflenet) use native infinite sampler (#9)

Parent b5da0a1a
@@ -112,16 +112,6 @@ def get_parameters(model):
     return groups
 
 
-def infinite_iter(loader):
-    iterator = iter(loader)
-    while True:
-        try:
-            yield next(iterator)
-        except StopIteration:
-            iterator = iter(loader)
-            yield next(iterator)
-
-
 def worker(rank, world_size, args):
     if world_size > 1:
         # Initialize distributed process group
@@ -174,9 +164,9 @@ def worker(rank, world_size, args):
     # Build train and valid datasets
     logger.info("preparing dataset..")
     train_dataset = data.dataset.ImageNet(args.data, train=True)
-    train_sampler = data.RandomSampler(
+    train_sampler = data.Infinite(data.RandomSampler(
         train_dataset, batch_size=args.batch_size, drop_last=True
-    )
+    ))
     train_queue = data.DataLoader(
         train_dataset,
         sampler=train_sampler,
@@ -193,7 +183,6 @@ def worker(rank, world_size, args):
         ),
         num_workers=args.workers,
     )
-    train_queue = infinite_iter(train_queue)
 
     valid_dataset = data.dataset.ImageNet(args.data, train=False)
     valid_sampler = data.SequentialSampler(
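For reference, the hand-rolled infinite_iter generator removed above re-created the loader's iterator whenever it hit StopIteration; MegEngine's data.Infinite sampler wrapper provides the same endless iteration at the sampler level, which is what this commit switches to. Below is a minimal, self-contained sketch of the new pattern; the toy ArrayDataset, batch size, and step count are illustrative placeholders, not taken from this commit.

import numpy as np
import megengine.data as data

# Toy dataset standing in for ImageNet; shapes and sizes here are placeholders.
images = np.random.rand(256, 3, 32, 32).astype("float32")
labels = np.random.randint(0, 10, size=(256,)).astype("int32")
dataset = data.dataset.ArrayDataset(images, labels)

# Infinite(...) keeps drawing random batches even after a full pass over the data,
# so no manual iterator restart is needed.
sampler = data.Infinite(data.RandomSampler(dataset, batch_size=32, drop_last=True))
loader = data.DataLoader(dataset, sampler=sampler, num_workers=0)

# A training loop can now simply call next() for as many steps as it needs.
train_iter = iter(loader)
for step in range(100):
    batch_images, batch_labels = next(train_iter)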