Unverified commit 828f87ae, authored by B Baibaifan, committed by GitHub

sharding_stage2_pfp16 (#37836)

Parent 3e33ef5a
@@ -83,8 +83,14 @@ class ShardingOptimizerStage2(Optimizer):
         # Default information
         self._optim_defaults = kw
         self._optim = optim
+        assert hasattr(self._optim, "_master_weights"
+                       ), "Must use optimizer with _master_weights attribute"
         self._local_params = params
         self._default_device = device
+        self._pfp16 = len(
+            list(
+                filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
+                       self._local_params))) > 0

         assert group is not None, "Distributed communication group is must be gived"
         self.group = group
@@ -98,6 +104,12 @@ class ShardingOptimizerStage2(Optimizer):
         # Update optimizer parameters and adjust parameter storage and use according to rank.
         self.update_opt_status()

+    def _generate_master_params(self, trainable_params):
+        for param in trainable_params:
+            if param.dtype == Type.fp16.value:
+                self._optim._master_weights[param.name] = paddle.cast(
+                    param, Type.fp32.value)
+
     def update_opt_status(self):
         """Update optimizer status and parameter storage information, and special functions to be developed.
         """
@@ -207,6 +219,8 @@ class ShardingOptimizerStage2(Optimizer):
             # Merge all the trainable params in a single InternalStorage
             trainable_params = list(
                 filter(lambda x: x.trainable, params))
+            if self._pfp16 and dst_rank == self.rank:
+                self._generate_master_params(trainable_params)
             if trainable_params:
                 param_storage = ParamStorage(
                     size=self.rank_buffer_size[dtype][dst_rank],
...
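The new `_pfp16` flag and `_generate_master_params` hook above keep an fp32 "master" copy of every trainable fp16 parameter in the wrapped optimizer's `_master_weights` dict, keyed by parameter name, so the pure-fp16 update can be applied in fp32. Below is a minimal, framework-free sketch of that bookkeeping; `FakeParam`, the parameter names, and the use of numpy's `astype` in place of `paddle.cast` are illustrative stand-ins, not Paddle's actual classes or API.

import numpy as np

# Illustrative stand-in for a trainable parameter; not Paddle's Parameter class.
class FakeParam:
    def __init__(self, name, values, dtype):
        self.name = name
        self.trainable = True
        self.values = np.asarray(values, dtype=dtype)
        self.dtype = self.values.dtype

def generate_master_params(trainable_params, master_weights):
    """Keep an fp32 master copy of each trainable fp16 parameter, keyed by name."""
    for p in trainable_params:
        if p.dtype == np.float16:
            # The real code does paddle.cast(param, Type.fp32.value); astype is the numpy analogue.
            master_weights[p.name] = p.values.astype(np.float32)

params = [FakeParam("linear_0.w_0", [0.1, 0.2, 0.3], np.float16),
          FakeParam("linear_0.b_0", [0.0, 0.0, 0.0], np.float32)]
pfp16 = any(p.trainable and p.dtype == np.float16 for p in params)  # mirrors self._pfp16

master = {}
if pfp16:
    generate_master_params([p for p in params if p.trainable], master)
print(sorted(master))  # ['linear_0.w_0'] -- only the fp16 parameter gets a master copy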
@@ -30,6 +30,7 @@ from paddle import nn
 import paddle.distributed as dist

 from ...utils.internal_storage import GradStorage
+from ...meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
 from .sharding_utils import Taskflow, Type
@@ -70,6 +71,11 @@ class ShardingStage2(nn.Layer):
         self._layer = layer
         self._sharding_optimizers = [sharding_optimizer] if not isinstance(
             sharding_optimizer, list) else sharding_optimizer
+        assert all(
+            list(
+                map(lambda opt: isinstance(opt, ShardingOptimizerStage2),
+                    self._sharding_optimizers))
+        ), "Please use ShardingOptimizerStage2 optimizer"
         self._sync_buffers = sync_buffers
         self._auto_refresh_trainable = auto_refresh_trainable
@@ -88,8 +94,7 @@ class ShardingStage2(nn.Layer):
         # Global statistical parameters
         self._all_params = list(
-            chain(
-                * [optim.local_params for optim in self._sharding_optimizers]))
+            chain(*[optim.local_params for optim in self._sharding_optimizers]))
         self._trainable_params = []
         self._grad_reduced = []
         self._trainable_param2rank = {}
@@ -436,7 +441,7 @@ class ShardingStage2(nn.Layer):
                         ._fill))

         self._grad_storage_list = list(
-            chain(* [
+            chain(*[
                 self._grad_storages[dtype].values()
                 for dtype in self._grad_storages.keys()
             ]))
...
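Two things change in the `ShardingStage2` wrapper: it now asserts that every wrapped optimizer is a `ShardingOptimizerStage2`, and the `chain(* [...])` calls are normalized to `chain(*[...])` (same unpacking semantics, just formatting). A small self-contained sketch of both patterns, using a toy `FakeOptim` class in place of the real optimizer (class name and parameter lists are made up for illustration):

from itertools import chain

# Toy stand-in for ShardingOptimizerStage2; only the attribute used here is modeled.
class FakeOptim:
    def __init__(self, local_params):
        self.local_params = local_params

optimizers = [FakeOptim(["w0", "b0"]), FakeOptim(["w1", "b1"])]

# The added type check: every wrapped optimizer must be the stage-2 sharding optimizer.
assert all(isinstance(opt, FakeOptim) for opt in optimizers), \
    "Please use ShardingOptimizerStage2 optimizer"

# The chain(*[...]) idiom: flatten the per-optimizer parameter lists into one list,
# which is what the self._all_params construction above does.
all_params = list(chain(*[opt.local_params for opt in optimizers]))
print(all_params)  # ['w0', 'b0', 'w1', 'b1']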
@@ -24,7 +24,6 @@ from paddle.fluid.dygraph.nn import Linear
 from paddle.distributed import fleet
 from paddle.fluid.dygraph import nn

-from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import DygraphShardingOptimizer
 from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
 from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import ShardingStage2
@@ -70,7 +69,7 @@ def reader_decorator():
     return __reader__


-def optimizer_setting(model, use_pure_fp16, stage=1):
+def optimizer_setting(model, use_pure_fp16):
     clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
     optimizer = paddle.optimizer.AdamW(
         parameters=model.parameters(),
@@ -87,20 +86,16 @@ def train_mlp(model,
               use_pure_fp16=False,
               all_test=False,
               accumulate_grad=False):
-    if sharding_stage == 1:
+    if sharding_stage == "dp":
         hcg = fleet.get_hybrid_communicate_group()
         group = hcg.get_check_parallel_group()
     else:
         group = paddle.distributed.new_group([0, 1])
-    optimizer = optimizer_setting(
-        model=model, use_pure_fp16=use_pure_fp16, stage=sharding_stage)
+    optimizer = optimizer_setting(model=model, use_pure_fp16=use_pure_fp16)

     if use_pure_fp16:
-        model, optimizer = paddle.amp.decorate(
-            models=model,
-            optimizers=optimizer,
-            level='O2',
-            save_dtype='float32')
+        model = paddle.amp.decorate(
+            models=model, level='O2', save_dtype='float32')

     if sharding_stage == 2:
         optimizer = ShardingOptimizerStage2(
@@ -164,7 +159,7 @@ def train_mlp(model,
     return model.parameters()


-def test_stage1_stage2():
+def test_dp_stage2():
     mlp = MLP()
     state_dict = mlp.state_dict()
     mlp1 = MLP()
@@ -175,11 +170,13 @@ def test_stage1_stage2():
     mlp2.set_state_dict(state_dict)
     mlp3.set_state_dict(state_dict)
     mlp4.set_state_dict(state_dict)
-    stage1_params = train_mlp(mlp, sharding_stage=1, use_pure_fp16=False)
-    stage2_params = train_mlp(mlp, sharding_stage=2, use_pure_fp16=False)
-    for i in range(len(stage1_params)):
-        np.testing.assert_allclose(
-            stage1_params[i].numpy(), stage2_params[i].numpy(), rtol=1e-6)
+    dp_params = train_mlp(mlp1, sharding_stage="dp", use_pure_fp16=False)
+    stage2_params = train_mlp(mlp2, sharding_stage=2, use_pure_fp16=False)
+    for i in range(len(dp_params)):
+        for j in range(len(stage2_params)):
+            if dp_params[i].name == stage2_params[j].name:
+                np.testing.assert_allclose(
+                    dp_params[i].numpy(), stage2_params[j].numpy(), rtol=1e-6)

     stage2_params = train_mlp(
         mlp3, sharding_stage=2, use_pure_fp16=True, all_test=True)
@@ -201,4 +198,4 @@ def test_stage1_stage2():
 if __name__ == '__main__':
-    test_stage1_stage2()
+    test_dp_stage2()
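The rewritten test compares the pure data-parallel run against the stage-2 sharding run by matching parameters on `.name` rather than by list position, since the two runs may order the parameter lists differently. A minimal sketch of that matching loop, with plain numpy arrays standing in for the trained parameters (names and values are hypothetical):

import numpy as np

# Hypothetical (name, value) pairs standing in for dp_params / stage2_params.
dp_params = [("linear_0.w_0", np.array([1.0, 2.0, 3.0])),
             ("linear_0.b_0", np.array([0.5]))]
stage2_params = [("linear_0.b_0", np.array([0.5])),        # note: different ordering
                 ("linear_0.w_0", np.array([1.0, 2.0, 3.0]))]

# Match by name, then assert numerical agreement, mirroring the updated test loop.
for name_i, val_i in dp_params:
    for name_j, val_j in stage2_params:
        if name_i == name_j:
            np.testing.assert_allclose(val_i, val_j, rtol=1e-6)
print("all name-matched parameters agree")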