diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index 4f56666a64ba387fc979c93d51a1454a2a599165..597cc40ae0647afbfe1a8a769804f1399f7d0cfe 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -1664,6 +1664,16 @@ class Executor(object):
                                               print_period, fetch_handler,
                                               use_program_cache)
 
+        from paddle.optimizer.lr import LRScheduler
+        if hasattr(program, 'lr_sheduler'):
+            lr_sheduler = program.lr_sheduler
+            assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler"
+            lr_value = lr_sheduler()
+            lr_var = program.global_block().vars[lr_sheduler._var_name]
+            data = np.array([lr_value]).astype(convert_dtype(lr_var.dtype))
+            tensor = core.get_variable_tensor(scope, lr_sheduler._var_name)
+            tensor.set(data, self.place)
+
         self._default_executor.run_from_dataset(trainer_instance)
 
         if not use_program_cache:
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index d60e07674edad0fa40d2fddebc45b0ae68c5df24..58ce5eb66a3eed0912f4955875a7c717b10428e4 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4634,6 +4634,9 @@ class PipelineOptimizer(object):
                 op.type == 'elementwise_div'):
             device = f"{self._device}:all"
             op._set_attr(self._op_device_key, device)
+        elif self._is_weight_decay_op(op) and op.type == 'scale':
+            # set AdamW decay_coeff to device:all
+            op._set_attr(self._op_device_key, f"{self._device}:all")
         elif op.type == "alloc_float_status":
             op._set_attr(self._op_device_key, f"{self._device}:all")
         else:
@@ -5267,6 +5270,11 @@ class PipelineOptimizer(object):
         return op.desc.has_attr("op_namescope") \
             and op.desc.attr("op_namescope").startswith("/regularization")
 
+    def _is_weight_decay_op(self, op):
+        # in AdamW namescope is /optimizer_*/weight decay/
+        return op.desc.has_attr("op_namescope") \
+            and 'weight decay' in op.desc.attr("op_namescope")
+
     def _get_input_output_info(self, block):
         '''
         Get info of op input and output.
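Note on the two hunks above: before `run_from_dataset`, the executor evaluates the scheduler attached to the program (`program.lr_sheduler`, keeping Paddle's existing spelling) and writes the scalar into the learning-rate tensor, while the `PipelineOptimizer` change marks the AdamW weight-decay `scale` op to run on every pipeline stage. Below is a minimal sketch of the scheduler side of that flow, assuming Paddle 2.x; the dtype and step count are illustrative only, not what the executor actually looks up.

```python
# Hedged sketch (assumes Paddle 2.x): how an LRScheduler produces the scalar
# that the executor copies into the learning-rate variable before each run.
import numpy as np
import paddle

# PiecewiseDecay is the 2.x counterpart of fluid.layers.piecewise_decay.
scheduler = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[10, 20], values=[0.1, 0.01, 0.001])

for step in range(25):
    lr_value = scheduler()                         # same call as lr_sheduler() above
    data = np.array([lr_value]).astype('float32')  # dtype illustrative; the executor
                                                   # converts to the LR var's dtype
    scheduler.step()                               # advance once per trainer step
```

This per-step evaluation is what keeps the static-graph LR tensor in sync with the `lr_sheduler.step()` calls added to the trainer loops in test_dist_base.py further down.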
diff --git a/python/paddle/fluid/tests/unittests/pipeline_mnist.py b/python/paddle/fluid/tests/unittests/pipeline_mnist.py
index 8c3a66f933f59ddb01a624c57c3b1573e71c953e..37e992c4d1365b88b45826fa186171bf5a3514ce 100644
--- a/python/paddle/fluid/tests/unittests/pipeline_mnist.py
+++ b/python/paddle/fluid/tests/unittests/pipeline_mnist.py
@@ -116,10 +116,10 @@ class TestDistMnist2x2(TestDistRunnerBase):
         steps_per_pass = 10
         bd = [steps_per_pass * p for p in passes]
         lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
-        lr_val = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
-        opt = fluid.optimizer.Momentum(
+        lr_val = paddle.optimizer.lr.PiecewiseDecay(boundaries=bd, values=lr)
+
+        opt = paddle.optimizer.AdamW(
             learning_rate=lr_val,
-            momentum=0.9,
             grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))
 
         acc_steps = 2  # accumulated steps for pipeline
diff --git a/python/paddle/fluid/tests/unittests/test_dist_base.py b/python/paddle/fluid/tests/unittests/test_dist_base.py
index 446b9a1e697e9053ae0fb84e7f6254030252b902..f5e67f2ddfaccdbd56873e0b1ebf92856d6e5369 100755
--- a/python/paddle/fluid/tests/unittests/test_dist_base.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_base.py
@@ -96,6 +96,15 @@ class TestDistRunnerBase(object):
             current_endpoint=current_endpoint)
         return t
 
+    @staticmethod
+    def get_lr_scheduler(program):
+        lr_sheduler = None
+        if hasattr(program, 'lr_sheduler'):
+            from paddle.optimizer.lr import LRScheduler
+            lr_sheduler = program.lr_sheduler
+            assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler"
+        return lr_sheduler
+
     def run_pserver(self, args):
         self.lr = args.lr
         self.get_model(batch_size=args.batch_size)
@@ -139,11 +148,17 @@ class TestDistRunnerBase(object):
         data_loader.start()
         print_to_err(type(self).__name__, "begin to train on trainer")
         out_losses = []
+
+        main_program = fluid.default_main_program()
+        lr_sheduler = self.get_lr_scheduler(main_program)
         for i in six.moves.xrange(RUN_STEP):
-            loss = exe.run(fluid.default_main_program(), fetch_list=[avg_cost])
+            loss = exe.run(main_program, fetch_list=[avg_cost])
             loss = loss[0] if loss else None
             out_losses.append(loss)
             print_to_err(type(self).__name__, "run step %d finished" % i)
+            if lr_sheduler is not None:
+                lr_sheduler.step()
+
         data_loader.reset()
         print_to_err(type(self).__name__, "trainer run finished")
@@ -494,6 +509,7 @@ class TestDistRunnerBase(object):
             else:
                 return origin_batch
 
+        lr_scheduler = self.get_lr_scheduler(trainer_prog)
         print_to_err(type(self).__name__, "begin to train on trainer")
         out_losses = []
         for i in six.moves.xrange(RUN_STEP):
@@ -502,6 +518,9 @@ class TestDistRunnerBase(object):
                              feed=feeder.feed(get_data()))
             out_losses.append(loss[0])
             print_to_err(type(self).__name__, "run step %d finished" % i)
+            if lr_scheduler is not None:
+                lr_scheduler.step()
+
         print_to_err(type(self).__name__, "trainer run finished")
 
         print_to_out(out_losses)
diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py
index c3cffa2998f6cc0956412be7709251720f8a51db..d39de0ae7683f5c798cb54b0f8a68f586a33fc64 100644
--- a/python/paddle/optimizer/adamw.py
+++ b/python/paddle/optimizer/adamw.py
@@ -160,6 +160,7 @@ class AdamW(Adam):
         self._apply_decay_param_fun = apply_decay_param_fun
         self._coeff = coeff
         self._lr_to_coeff = dict()
+
         super(AdamW, self).__init__(
             learning_rate=learning_rate,
             parameters=parameters,
@@ -211,7 +212,9 @@ class AdamW(Adam):
             # we do this in _create_optimization_pass
             decay_coeff = self._lr_to_coeff.get(learning_rate, None)
             if decay_coeff is None:
-                decay_coeff = 1.0 - learning_rate * self._coeff
+                # NOTE(wangxi): for pipeline to set device:all
+                with paddle.static.device_guard(None):
+                    decay_coeff = 1.0 - learning_rate * self._coeff
             self._lr_to_coeff[learning_rate] = decay_coeff
 
         find_master = (self._multi_precision and
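On the final adamw.py hunk: wrapping the `1.0 - learning_rate * coeff` computation in `device_guard(None)` leaves the resulting `scale` op without a fixed `op_device` attribute, which is what lets the new `_is_weight_decay_op` branch in `PipelineOptimizer` assign it to `<device>:all`. A rough, self-contained illustration of that pattern, assuming Paddle 2.x static mode; the program and variable names are made up for the example.

```python
# Hedged sketch (assumes Paddle 2.x static graph mode): ops created under
# device_guard(None) carry no device assignment, so a pipeline pass is free
# to place them on every stage, mirroring the AdamW decay_coeff change.
import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()

with paddle.static.program_guard(main_prog, startup_prog):
    # stand-in for the optimizer's learning-rate variable
    lr = paddle.static.create_global_var(
        shape=[1], value=0.01, dtype='float32', name='learning_rate_demo')
    with paddle.static.device_guard(None):
        # same shape of expression as the AdamW change above; the scale op
        # built here is left device-less for a later placement pass
        decay_coeff = 1.0 - lr * 0.01
```

The diff comments suggest the intent is that every stage applying decoupled weight decay sees the coefficient, rather than pinning the op to a single pipeline device.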