Commit 260d308a authored by yao_yf

adapt to parallel interface changes

Parent bacd6196
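
This commit adapts the training scripts to two changes in the parallel interface, applied file by file in the hunks below: ParallelMode is now imported from mindspore.context rather than from mindspore.train.model or mindspore.train.parallel_utils, and the mirror_mean flag of set_auto_parallel_context (together with the matching get_auto_parallel_context key) is renamed to gradients_mean. A minimal before/after sketch, using only names that appear in this diff:

    # before this change
    from mindspore import context
    from mindspore.train.model import Model, ParallelMode        # old import location
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)

    # after this change
    from mindspore import context
    from mindspore.train.model import Model
    from mindspore.context import ParallelMode                   # new import location
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)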
@@ -31,7 +31,8 @@ import mindspore.dataset.transforms.c_transforms as C2
 from mindspore import Tensor
 from mindspore.ops import operations as P
 from mindspore.nn.optim.momentum import Momentum
-from mindspore.train.model import Model, ParallelMode
+from mindspore.train.model import Model
+from mindspore.context import ParallelMode
 from mindspore import context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
@@ -113,7 +114,7 @@ def create_dataset(repeat_num=1, training=True):
 if __name__ == '__main__':
     if args_opt.mode == 'train' and args_opt.run_distribute:
-        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True)
         auto_parallel_context().set_all_reduce_fusion_split_indices([140])
         init()
......
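
For reference, the distributed setup in the script above, written out against the post-change API, looks roughly like the sketch below; device_num and the fusion split index 140 come from the hunk, while the hard-coded value and the Ascend target are illustrative assumptions rather than part of this commit:

    from mindspore import context
    from mindspore.context import ParallelMode
    from mindspore.communication.management import init
    from mindspore.parallel._auto_parallel_context import auto_parallel_context

    device_num = 8   # in the script this comes from args_opt.device_num

    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_auto_parallel_context(device_num=device_num,
                                      parallel_mode=ParallelMode.DATA_PARALLEL,
                                      gradients_mean=True)
    # split the gradient all-reduce into fusion groups at parameter index 140, as in the hunk above
    auto_parallel_context().set_all_reduce_fusion_split_indices([140])
    init()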
@@ -24,7 +24,7 @@ import mindspore.communication.management as D
 import mindspore.common.dtype as mstype
 from mindspore import context
 from mindspore.train.model import Model
-from mindspore.train.parallel_utils import ParallelMode
+from mindspore.context import ParallelMode
 from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
@@ -77,7 +77,7 @@ def run_pretrain():
         ckpt_save_dir = args_opt.save_checkpoint_path + 'ckpt_' + str(rank) + '/'
         context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
+        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
                                           device_num=device_num)
         from mindspore.parallel._auto_parallel_context import auto_parallel_context
         if bert_net_cfg.num_hidden_layers == 12:
......
@@ -24,7 +24,7 @@ from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common import dtype as mstype
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
-from mindspore.train.parallel_utils import ParallelMode
+from mindspore.context import ParallelMode
 from mindspore.communication.management import get_group_size
 from mindspore import context
 from mindspore.ops import _selected_ops
@@ -280,7 +280,7 @@ class BertTrainOneStepCell(nn.Cell):
             self.reducer_flag = True
         self.grad_reducer = None
         if self.reducer_flag:
-            mean = context.get_auto_parallel_context("mirror_mean")
+            mean = context.get_auto_parallel_context("gradients_mean")
             degree = get_group_size()
             self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
......
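
The rename also shows up on the read side: when a training wrapper decides whether to build a DistributedGradReducer, it now queries the gradients_mean key instead of mirror_mean. A minimal sketch of that pattern, assuming the auto-parallel context has already been configured as in the hunks above (the optimizer argument is whatever optimizer the cell wraps):

    from mindspore import context
    from mindspore.context import ParallelMode
    from mindspore.communication.management import get_group_size
    from mindspore.nn.wrap.grad_reducer import DistributedGradReducer

    def build_grad_reducer(optimizer):
        # Only data-parallel or hybrid-parallel runs need an all-reduce over gradients.
        parallel_mode = context.get_auto_parallel_context("parallel_mode")
        if parallel_mode not in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            return None
        mean = context.get_auto_parallel_context("gradients_mean")   # renamed from "mirror_mean"
        degree = get_group_size()
        return DistributedGradReducer(optimizer.parameters, mean, degree)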