diff --git a/chapter07/src/bert_for_pre_training.py b/chapter07/src/bert_for_pre_training.py index 802391ee861cdc4780ba2e52b7646f7b08709e62..4f3dbc80683e11f0363fe87f6d7bcd5ee79d87dc 100644 --- a/chapter07/src/bert_for_pre_training.py +++ b/chapter07/src/bert_for_pre_training.py @@ -272,7 +272,7 @@ class BertTrainOneStepCell(nn.Cell): self.network = network self.weights = ParameterTuple(network.trainable_params()) self.optimizer = optimizer - self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.sens = sens self.reducer_flag = False self.parallel_mode = context.get_auto_parallel_context("parallel_mode") @@ -351,8 +351,7 @@ class BertTrainOneStepWithLossScaleCell(nn.Cell): self.network = network self.weights = ParameterTuple(network.trainable_params()) self.optimizer = optimizer - self.grad = C.GradOperation('grad', - get_by_list=True, + self.grad = C.GradOperation(get_by_list=True, sens_param=True) self.reducer_flag = False self.allreduce = P.AllReduce()