Unverified commit 828f87ae, authored by Baibaifan and committed by GitHub

sharding_stage2_pfp16 (#37836)

Parent 3e33ef5a
@@ -83,8 +83,14 @@ class ShardingOptimizerStage2(Optimizer):
# Default information
self._optim_defaults = kw
self._optim = optim
assert hasattr(self._optim, "_master_weights"
), "Must use optimizer with _master_weights attribute"
self._local_params = params
self._default_device = device
self._pfp16 = len(
list(
filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
self._local_params))) > 0
assert group is not None, "Distributed communication group must be given"
self.group = group
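The new `_pfp16` flag records whether any trainable parameter is kept in float16; it later gates the creation of fp32 master weights on the owning rank. A minimal standalone sketch of the same check (hypothetical helper name; assumes `Type.fp16.value` is `paddle.float16`, as in `sharding_utils`):

import paddle

def has_trainable_fp16(params):
    # True if at least one trainable parameter is stored in float16,
    # mirroring the `_pfp16` test in ShardingOptimizerStage2.__init__.
    return any(p.trainable and p.dtype == paddle.float16 for p in params)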
@@ -98,6 +104,12 @@ class ShardingOptimizerStage2(Optimizer):
# Update optimizer parameters and adjust parameter storage and use according to rank.
self.update_opt_status()
def _generate_master_params(self, trainable_params):
for param in trainable_params:
if param.dtype == Type.fp16.value:
self._optim._master_weights[param.name] = paddle.cast(
param, Type.fp32.value)
def update_opt_status(self):
"""Update optimizer status and parameter storage information, and special functions to be developed.
"""
@@ -207,6 +219,8 @@ class ShardingOptimizerStage2(Optimizer):
# Merge all the trainable params in a single InternalStorage
trainable_params = list(
filter(lambda x: x.trainable, params))
if self._pfp16 and dst_rank == self.rank:
self._generate_master_params(trainable_params)
if trainable_params:
param_storage = ParamStorage(
size=self.rank_buffer_size[dtype][dst_rank],
......
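Note that `_generate_master_params` is only called when `dst_rank == self.rank`, so each rank keeps fp32 master weights solely for the parameter shard whose optimizer state it owns; that is the stage-2 memory saving. A toy illustration of per-rank ownership (the round-robin rule is an assumption; the real class partitions by per-rank buffer size):

import paddle

def build_local_master_weights(params, rank, world_size):
    # Keep an fp32 master copy only for parameters this rank owns.
    master = {}
    for i, p in enumerate(params):
        owner = i % world_size  # toy ownership rule, not the real partitioner
        if owner == rank and p.dtype == paddle.float16:
            master[p.name] = paddle.cast(p, 'float32')
    return master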
@@ -30,6 +30,7 @@ from paddle import nn
import paddle.distributed as dist
from ...utils.internal_storage import GradStorage
from ...meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from .sharding_utils import Taskflow, Type
@@ -70,6 +71,11 @@ class ShardingStage2(nn.Layer):
self._layer = layer
self._sharding_optimizers = [sharding_optimizer] if not isinstance(
sharding_optimizer, list) else sharding_optimizer
assert all(
list(
map(lambda opt: isinstance(opt, ShardingOptimizerStage2),
self._sharding_optimizers))
), "Please use ShardingOptimizerStage2 optimizer"
self._sync_buffers = sync_buffers
self._auto_refresh_trainable = auto_refresh_trainable
@@ -88,8 +94,7 @@ class ShardingStage2(nn.Layer):
# Global statistical parameters
self._all_params = list(
chain(
* [optim.local_params for optim in self._sharding_optimizers]))
chain(*[optim.local_params for optim in self._sharding_optimizers]))
self._trainable_params = []
self._grad_reduced = []
self._trainable_param2rank = {}
@@ -436,7 +441,7 @@ class ShardingStage2(nn.Layer):
._fill))
self._grad_storage_list = list(
chain(* [
chain(*[
self._grad_storages[dtype].values()
for dtype in self._grad_storages.keys()
]))
......
@@ -24,7 +24,6 @@ from paddle.fluid.dygraph.nn import Linear
from paddle.distributed import fleet
from paddle.fluid.dygraph import nn
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import DygraphShardingOptimizer
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
@@ -70,7 +69,7 @@ def reader_decorator():
return __reader__
def optimizer_setting(model, use_pure_fp16, stage=1):
def optimizer_setting(model, use_pure_fp16):
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
optimizer = paddle.optimizer.AdamW(
parameters=model.parameters(),
@@ -87,20 +86,16 @@ def train_mlp(model,
use_pure_fp16=False,
all_test=False,
accumulate_grad=False):
if sharding_stage == 1:
if sharding_stage == "dp":
hcg = fleet.get_hybrid_communicate_group()
group = hcg.get_check_parallel_group()
else:
group = paddle.distributed.new_group([0, 1])
optimizer = optimizer_setting(
model=model, use_pure_fp16=use_pure_fp16, stage=sharding_stage)
optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)
if use_pure_fp16:
model, optimizer = paddle.amp.decorate(
models=model,
optimizers=optimizer,
level='O2',
save_dtype='float32')
model = paddle.amp.decorate(
models=model, level='O2', save_dtype='float32')
if sharding_stage == 2:
optimizer = ShardingOptimizerStage2(
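The hunk is truncated here, but the visible change is the key one for pure fp16: only the model is now passed to `paddle.amp.decorate` at level 'O2', while the optimizer is left undecorated because `ShardingOptimizerStage2` builds its own fp32 master weights. A hedged sketch of how the pieces are typically wired in this test (the exact keyword names of the two sharding classes are assumptions, since their constructor calls are cut off in the diff):

import paddle
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2

def build_pure_fp16_stage2(model, optimizer, group):
    # Cast the model to fp16 (O2) but keep fp32 checkpoints via save_dtype.
    model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
    # Shard optimizer state across the group; master weights are created inside.
    optimizer = ShardingOptimizerStage2(
        params=model.parameters(), optim=optimizer, group=group)
    # Wrap the layer so gradients are reduced to the owning rank only.
    model = ShardingStage2(model, optimizer, group=group)
    return model, optimizer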
@@ -164,7 +159,7 @@ def train_mlp(model,
return model.parameters()
def test_stage1_stage2():
def test_dp_stage2():
mlp = MLP()
state_dict = mlp.state_dict()
mlp1 = MLP()
@@ -175,11 +170,13 @@ def test_stage1_stage2():
mlp2.set_state_dict(state_dict)
mlp3.set_state_dict(state_dict)
mlp4.set_state_dict(state_dict)
stage1_params = train_mlp(mlp, sharding_stage=1, use_pure_fp16=False)
stage2_params = train_mlp(mlp, sharding_stage=2, use_pure_fp16=False)
for i in range(len(stage1_params)):
np.testing.assert_allclose(
stage1_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)
dp_params = train_mlp(mlp1, sharding_stage="dp", use_pure_fp16=False)
stage2_params = train_mlp(mlp2, sharding_stage=2, use_pure_fp16=False)
for i in range(len(dp_params)):
for j in range(len(stage2_params)):
if dp_params[i].name == stage2_params[j].name:
np.testing.assert_allclose(
dp_params[i].numpy(), stage2_params[j].numpy(), rtol=1e-6)
stage2_params = train_mlp(
mlp3, sharding_stage=2, use_pure_fp16=True, all_test=True)
@@ -201,4 +198,4 @@ def test_stage1_stage2():
if __name__ == '__main__':
test_stage1_stage2()
test_dp_stage2()