Unverified · Commit 4d87afd6 authored by liu zhengxi, committed by GitHub

Fix hung (#5121)

* fix hung

* add shuffle batch

* update

* reader_seed to shuffle_seed

* seed for shuffle batch
Parent 047b8b69
@@ -27,6 +27,12 @@
 pool_size: 200000
 sort_type: "global"
 batch_size: 4096
 infer_batch_size: 16
+shuffle_batch: True
+# Data shuffle only works when sort_type is pool or none
+shuffle: True
+# shuffle_seed must be set when shuffle is True and using multi-cards to train.
+# Otherwise, the number of batches cannot be guaranteed.
+shuffle_seed: 128
 # Hyparams for training:
 # The number of epoches for training
...
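The three keys added above are plain YAML scalars. As a point of reference only, the sketch below (not the project's own config loader; the path and the `load_config` helper are hypothetical) shows how they could be read into an attribute-style `args` object, and how an unset seed, whether a missing key or the literal string "None", maps onto the `args.shuffle_seed == "None" or args.shuffle_seed is None` check in `reader.py`.

```python
# Minimal sketch, not the project's config loader: read the YAML keys added
# above and normalize shuffle_seed the same way reader.py does. The file path
# below is an assumption for illustration.
from types import SimpleNamespace

import yaml


def load_config(path):
    with open(path) as f:
        cfg = yaml.safe_load(f)
    args = SimpleNamespace(**cfg)
    # Treat a missing seed or the literal string "None" as "no fixed seed";
    # reader.py falls back to 0 in that case.
    seed = getattr(args, "shuffle_seed", None)
    args.shuffle_seed = 0 if seed in (None, "None") else int(seed)
    return args


# Example (hypothetical path):
# args = load_config("configs/transformer.base.yaml")
# print(args.shuffle, args.shuffle_batch, args.shuffle_seed)
```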
@@ -43,6 +43,12 @@ def create_data_loader(args):
             mode=m, transform_func=transform_func) for m in ["train", "dev"]
     ]
 
+    if args.shuffle or args.shuffle_batch:
+        if args.shuffle_seed == "None" or args.shuffle_seed is None:
+            shuffle_seed = 0
+        else:
+            shuffle_seed = args.shuffle_seed
+
     def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
                       data_source):
         return max(tokens_sofar,
@@ -69,7 +75,8 @@ def create_data_loader(args):
                 key=trg_key, buffer_size=buffer_size).sort(
                     key=src_key, buffer_size=buffer_size)
         else:
-            sampler = sampler.shuffle()
+            if args.shuffle:
+                sampler = sampler.shuffle(seed=shuffle_seed)
             if args.sort_type == SortType.POOL:
                 buffer_size = args.pool_size
                 sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
@@ -83,6 +90,9 @@ def create_data_loader(args):
         if m == "train":
             batch_sampler = batch_sampler.shard()
+            if args.shuffle_batch:
+                batch_sampler.shuffle(seed=shuffle_seed)
+
         data_loader = DataLoader(
             dataset=dataset,
             batch_sampler=batch_sampler,
...
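The shared seed matters because the sampler here packs batches by token count, so the number of batches per card depends on the sample order. If each card shuffles with its own seed, the cards can end up with different batch counts, and the ones that finish their epoch early leave the others hanging at the next collective op, which is the hang this commit fixes. The toy script below is plain Python with no Paddle dependency; the 4096-token budget and the synthetic length distribution are made up for illustration.

```python
# Illustration only: with token-budget batching, two "cards" that shuffle the
# same data with different seeds usually produce different numbers of batches;
# a shared shuffle_seed keeps the counts identical.
import random


def count_batches(lengths, max_tokens=4096, seed=None):
    """Pack samples in shuffled order into batches whose cost, measured as
    max-length-in-batch * batch-size, stays within max_tokens."""
    order = list(lengths)
    random.Random(seed).shuffle(order)
    num_batches, cur_max, cur_size = 0, 0, 0
    for n in order:
        if cur_size and max(cur_max, n) * (cur_size + 1) > max_tokens:
            num_batches += 1  # close the current batch before it overflows
            cur_max, cur_size = 0, 0
        cur_max, cur_size = max(cur_max, n), cur_size + 1
    return num_batches + (1 if cur_size else 0)


# Synthetic sentence lengths shared by every "card".
rng = random.Random(0)
lengths = [rng.randint(5, 250) for _ in range(2000)]

# Per-card seeds: the two cards usually disagree on how many batches exist.
print("card 0:", count_batches(lengths, seed=1),
      "card 1:", count_batches(lengths, seed=2))
# Shared shuffle_seed: every card sees exactly the same number of batches.
print("card 0:", count_batches(lengths, seed=128),
      "card 1:", count_batches(lengths, seed=128))
```

With `shuffle_seed` set in the config, every card applies the same permutation before sharding, so the per-card batch counts line up and the training loop no longer stalls.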
@@ -27,6 +27,12 @@
 pool_size: 200000
 sort_type: "global"
 batch_size: 4096
 infer_batch_size: 8
+shuffle_batch: True
+# Data shuffle only works when sort_type is pool or none
+shuffle: True
+# shuffle_seed must be set when shuffle is True and using multi-cards to train.
+# Otherwise, the number of batches cannot be guaranteed.
+shuffle_seed: 128
 # Hyparams for training:
 # The number of epoches for training
...
@@ -27,6 +27,12 @@
 pool_size: 200000
 sort_type: "global"
 batch_size: 4096
 infer_batch_size: 8
+shuffle_batch: True
+# Data shuffle only works when sort_type is pool or none
+shuffle: True
+# shuffle_seed must be set when shuffle is True and using multi-cards to train.
+# Otherwise, the number of batches cannot be guaranteed.
+shuffle_seed: 128
 # Hyparams for training:
 # The number of epoches for training
...
@@ -43,6 +43,12 @@ def create_data_loader(args):
             mode=m, transform_func=transform_func) for m in ["train", "dev"]
     ]
 
+    if args.shuffle or args.shuffle_batch:
+        if args.shuffle_seed == "None" or args.shuffle_seed is None:
+            shuffle_seed = 0
+        else:
+            shuffle_seed = args.shuffle_seed
+
     def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
                       data_source):
         return max(tokens_sofar,
@@ -69,7 +75,8 @@ def create_data_loader(args):
                 key=trg_key, buffer_size=buffer_size).sort(
                     key=src_key, buffer_size=buffer_size)
         else:
-            sampler = sampler.shuffle()
+            if args.shuffle:
+                sampler = sampler.shuffle(seed=shuffle_seed)
             if args.sort_type == SortType.POOL:
                 buffer_size = args.pool_size
                 sampler = sampler.sort(key=src_key, buffer_size=buffer_size)
@@ -83,6 +90,9 @@ def create_data_loader(args):
         if m == "train":
             batch_sampler = batch_sampler.shard()
+            if args.shuffle_batch:
+                batch_sampler.shuffle(seed=shuffle_seed)
+
         data_loader = DataLoader(
             dataset=dataset,
             batch_sampler=batch_sampler,
...
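To confirm the fix on an actual multi-card run, one hedged option (not part of this commit) is to gather each rank's batch count right after the loaders are built and assert that they match. The helper below assumes the parallel environment has already been initialized by the training script and that `len(data_loader)` is defined for the batch sampler in use.

```python
# Hedged sanity check, not part of this commit: verify that all ranks agree on
# the number of batches produced by create_data_loader(). Assumes
# paddle.distributed.init_parallel_env() has already been called and that the
# loader's batch sampler supports len().
import paddle
import paddle.distributed as dist


def assert_same_batch_count(data_loader):
    if dist.get_world_size() <= 1:
        return
    local = paddle.to_tensor([len(data_loader)], dtype="int64")
    gathered = []
    dist.all_gather(gathered, local)
    counts = [int(t.numpy()[0]) for t in gathered]
    assert len(set(counts)) == 1, f"ranks disagree on batch count: {counts}"
```

Called once per rank after `create_data_loader(args)`, a mismatch surfaces immediately as an assertion instead of showing up later as a hang at a collective op.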