diff --git a/mindspore/context.py b/mindspore/context.py index 8b5023f957dcdf902449f87243ec561a86f1816b..ea08960182f103c2db9cae7d36bebfb6f7b998cb 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -325,7 +325,8 @@ def _context(): @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, parallel_mode=str, auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, - strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool) + strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, + all_reduce_fusion_config=list) def set_auto_parallel_context(**kwargs): """ Set auto parallel context. @@ -371,8 +372,9 @@ def set_auto_parallel_context(**kwargs): strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: '' strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: '' full_batch (bool): Whether to load the whole batch on each device. Default: False. - enable_parallel_optimizer(bool): This is a developing feature, which shards the weight update computation in + enable_parallel_optimizer (bool): This is a developing feature, which shards the weight update computation in data parallel training in the benefit of time and memory saving. + all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Raises: ValueError: If input key is not attribute in auto parallel context. 
diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index 6756912d64173feea20f7aa57c0a03f55d3c7147..0e543eb54a7ed369756d4f64690c96db64406da8 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -462,7 +462,8 @@ _set_auto_parallel_context_func_map = { "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file, "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file, "full_batch": auto_parallel_context().set_full_batch, - "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer} + "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer, + "all_reduce_fusion_config": auto_parallel_context().set_all_reduce_fusion_split_indices} _get_auto_parallel_context_func_map = { @@ -477,13 +478,15 @@ _get_auto_parallel_context_func_map = { "strategy_ckpt_load_file": auto_parallel_context().get_strategy_ckpt_load_file, "strategy_ckpt_save_file": auto_parallel_context().get_strategy_ckpt_save_file, "full_batch": auto_parallel_context().get_full_batch, - "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer} + "enable_parallel_optimizer": auto_parallel_context().get_enable_parallel_optimizer, + "all_reduce_fusion_config": auto_parallel_context().get_all_reduce_fusion_split_indices} @args_type_check(device_num=int, global_rank=int, gradients_mean=bool, gradient_fp32_sync=bool, loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str, parameter_broadcast=bool, strategy_ckpt_load_file=str, - strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool) + strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool, + all_reduce_fusion_config=list) def _set_auto_parallel_context(**kwargs): """ @@ -526,6 +529,7 @@ def _set_auto_parallel_context(**kwargs): strategy_ckpt_save_file (str): The path to save 
parallel strategy checkpoint. Default: '' full_batch (bool): Whether to load the whole batch on each device. Default: False. enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False. + all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices. Raises: ValueError: If input key is not attribute in auto parallel context. diff --git a/model_zoo/official/cv/mobilenetv2/src/utils.py b/model_zoo/official/cv/mobilenetv2/src/utils.py index 5a05f397a4afb4464bdb4464b0934e4f36313a83..b56b592b78f1894b2b8553190fbe841c750b5a25 100644 --- a/model_zoo/official/cv/mobilenetv2/src/utils.py +++ b/model_zoo/official/cv/mobilenetv2/src/utils.py @@ -47,8 +47,8 @@ def context_device_init(config): if config.run_distribute: context.set_auto_parallel_context(device_num=config.rank_size, parallel_mode=ParallelMode.DATA_PARALLEL, - parameter_broadcast=True, gradients_mean=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([140]) + parameter_broadcast=True, gradients_mean=True, + all_reduce_fusion_config=[140]) init() else: raise ValueError("Only support CPU, GPU and Ascend.") diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index a45e26bd834813cffb23686527f603fcc4473d00..0f22a261929d705b88a37f573d5fde3940716459 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -18,7 +18,6 @@ import argparse import ast from mindspore import context from mindspore import Tensor -from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.nn.optim.momentum import Momentum from mindspore.train.model import Model from mindspore.context import ParallelMode @@ -78,9 +77,9 @@ if __name__ == '__main__': context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) if args_opt.net == "resnet50" or args_opt.net == "se-resnet50": - 
auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160]) + context.set_auto_parallel_context(all_reduce_fusion_config=[85, 150]) else: - auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313]) + context.set_auto_parallel_context(all_reduce_fusion_config=[180, 313]) init() # GPU target else: @@ -88,7 +87,7 @@ if __name__ == '__main__': context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True) if args_opt.net == "resnet50": - auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160]) + context.set_auto_parallel_context(all_reduce_fusion_config=[85, 160]) ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" # create dataset diff --git a/model_zoo/official/cv/resnet50_quant/train.py b/model_zoo/official/cv/resnet50_quant/train.py index f4a3965c838794b6e30ba5a21447bbd8bf2cc34b..fa968c7d9a751cda50f3dbd2981b26281413533f 100755 --- a/model_zoo/official/cv/resnet50_quant/train.py +++ b/model_zoo/official/cv/resnet50_quant/train.py @@ -19,7 +19,6 @@ import argparse from mindspore import context from mindspore import Tensor -from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.nn.optim.momentum import Momentum from mindspore.train.model import Model from mindspore.context import ParallelMode @@ -80,8 +79,7 @@ if __name__ == '__main__': init() context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - gradients_mean=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160]) + gradients_mean=True, all_reduce_fusion_config=[107, 160]) # define network net = resnet50_quant(class_num=config.class_num) diff --git a/model_zoo/official/cv/resnet_thor/train.py b/model_zoo/official/cv/resnet_thor/train.py index 29d4e58e32169fb3ae5953340621751be2a13392..7de1034628b821a8ff6f1605a7c7244d6fd38b8d 100644 --- 
a/model_zoo/official/cv/resnet_thor/train.py +++ b/model_zoo/official/cv/resnet_thor/train.py @@ -20,7 +20,6 @@ import numpy as np from mindspore import context from mindspore import Tensor from mindspore.common import set_seed -from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.context import ParallelMode from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor from mindspore.train.loss_scale_manager import FixedLossScaleManager @@ -94,15 +93,13 @@ if __name__ == '__main__': device_id = int(os.getenv('DEVICE_ID')) context.set_context(device_id=device_id, enable_auto_mixed_precision=True) context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - gradients_mean=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([107]) + gradients_mean=True, all_reduce_fusion_config=[107]) init() # GPU target else: init() context.set_auto_parallel_context(device_num=get_group_size(), parallel_mode=ParallelMode.DATA_PARALLEL, - gradients_mean=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([107]) + gradients_mean=True, all_reduce_fusion_config=[107]) ckpt_save_dir = config.save_checkpoint_path + "ckpt_" + str(get_rank()) + "/" # create dataset diff --git a/model_zoo/official/nlp/bert/run_pretrain.py b/model_zoo/official/nlp/bert/run_pretrain.py index 6836da1d73509123a8d04c106dd3d34bad647036..48802efa9ab3050461b10ffd6c33feb804e7b541 100644 --- a/model_zoo/official/nlp/bert/run_pretrain.py +++ b/model_zoo/official/nlp/bert/run_pretrain.py @@ -87,17 +87,16 @@ def run_pretrain(): context.reset_auto_parallel_context() context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True, device_num=device_num) - from mindspore.parallel._auto_parallel_context import auto_parallel_context if bert_net_cfg.num_hidden_layers == 12: if bert_net_cfg.use_relative_positions: - 
auto_parallel_context().set_all_reduce_fusion_split_indices([29, 58, 87, 116, 145, 174, 203, 217]) + context.set_auto_parallel_context(all_reduce_fusion_config=[29, 58, 87, 116, 145, 174, 203, 217]) else: - auto_parallel_context().set_all_reduce_fusion_split_indices([28, 55, 82, 109, 136, 163, 190, 205]) + context.set_auto_parallel_context(all_reduce_fusion_config=[28, 55, 82, 109, 136, 163, 190, 205]) elif bert_net_cfg.num_hidden_layers == 24: if bert_net_cfg.use_relative_positions: - auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421]) + context.set_auto_parallel_context(all_reduce_fusion_config=[30, 90, 150, 210, 270, 330, 390, 421]) else: - auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397]) + context.set_auto_parallel_context(all_reduce_fusion_config=[38, 93, 148, 203, 258, 313, 368, 397]) else: rank = 0 device_num = 1 diff --git a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py index 3efdf78310ca62d8946ec71d63374f301c73b154..78ee7fdaf2430c98f7d2df591732e19755f2e8b6 100644 --- a/tests/st/networks/models/resnet50/test_resnet50_imagenet.py +++ b/tests/st/networks/models/resnet50/test_resnet50_imagenet.py @@ -23,7 +23,6 @@ import numpy as np from mindspore import context, Tensor from mindspore.communication.management import init -from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.train.model import Model from mindspore.context import ParallelMode from mindspore.train.callback import Callback @@ -137,8 +136,8 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl): os.environ['RANK_SIZE'] = str(device_num) if enable_hccl: context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - gradients_mean=True, parameter_broadcast=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160]) + 
gradients_mean=True, parameter_broadcast=True, + all_reduce_fusion_config=[107, 160]) init() # network @@ -240,8 +239,8 @@ def train_process_thor(q, device_id, epoch_size, device_num, enable_hccl): os.environ['RANK_SIZE'] = str(device_num) if enable_hccl: context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, - gradients_mean=True, parameter_broadcast=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([107]) + gradients_mean=True, parameter_broadcast=True, + all_reduce_fusion_config=[107]) init() # network diff --git a/tests/st/tbe_networks/resnet_cifar.py b/tests/st/tbe_networks/resnet_cifar.py index c40b4809be70b9500b2f6f78d0b12c8697718c47..909cf1829863ead843aacaef7331bd2e029fa4ef 100644 --- a/tests/st/tbe_networks/resnet_cifar.py +++ b/tests/st/tbe_networks/resnet_cifar.py @@ -31,7 +31,6 @@ from mindspore import context from mindspore.communication.management import init from mindspore.nn.optim.momentum import Momentum from mindspore.ops import operations as P -from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.train.model import Model from mindspore.context import ParallelMode @@ -124,8 +123,8 @@ class CrossEntropyLoss(nn.Cell): if __name__ == '__main__': if not args_opt.do_eval and args_opt.run_distribute: - context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL) - auto_parallel_context().set_all_reduce_fusion_split_indices([140]) + context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, + all_reduce_fusion_config=[140]) init() context.set_context(mode=context.GRAPH_MODE) diff --git a/tests/st/tbe_networks/test_resnet_cifar_8p.py b/tests/st/tbe_networks/test_resnet_cifar_8p.py index 7eefdcc7a915d4b8f5a9671f5fbe74b99683263c..5b610cceba0f43578541f6e7f45aa06a31286f3f 100644 --- a/tests/st/tbe_networks/test_resnet_cifar_8p.py +++ 
b/tests/st/tbe_networks/test_resnet_cifar_8p.py @@ -30,7 +30,6 @@ from mindspore import context from mindspore.communication.management import init from mindspore.nn.optim.momentum import Momentum from mindspore.ops import operations as P -from mindspore.parallel._auto_parallel_context import auto_parallel_context from mindspore.train.callback import Callback from mindspore.train.model import Model from mindspore.context import ParallelMode @@ -154,8 +153,7 @@ def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size, os.environ['RANK_SIZE'] = str(device_num) if enable_hccl: context.set_auto_parallel_context( - device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL) - auto_parallel_context().set_all_reduce_fusion_split_indices([140]) + device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL, all_reduce_fusion_config=[140]) init() context.set_context(mode=context.GRAPH_MODE) net = resnet50(batch_size, num_classes) diff --git a/tests/ut/python/parallel/test_parallel_optimizer.py b/tests/ut/python/parallel/test_parallel_optimizer.py index ca5fe0ac3e71c732ef0cad3d1a7062279e403f00..f164716bb5a32cf0780cde64fb337f9718c0a246 100644 --- a/tests/ut/python/parallel/test_parallel_optimizer.py +++ b/tests/ut/python/parallel/test_parallel_optimizer.py @@ -23,7 +23,6 @@ from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn.optim import Adam, AdamWeightDecay, Lamb from mindspore.ops import operations as P from mindspore import context -from mindspore.parallel._auto_parallel_context import auto_parallel_context class Net(nn.Cell): """Net definition""" @@ -85,8 +84,8 @@ def test_lamb_compile(): def test_lamb_split_fusion(): """ test_Lamb_split_fusion """ - context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, enable_parallel_optimizer=True) - auto_parallel_context().set_all_reduce_fusion_split_indices([2, 4, 6, 8]) + context.set_auto_parallel_context(parallel_mode="data_parallel", device_num=2, 
enable_parallel_optimizer=True, + all_reduce_fusion_config=[2, 4, 6, 8]) inputs = Tensor(np.ones([32, 128]).astype(np.float32)) label = Tensor(np.zeros([32, 768]).astype(np.float32)) net = Net()