diff --git a/python/paddle/distributed/fleet/base/fleet_base.py b/python/paddle/distributed/fleet/base/fleet_base.py
index 2a9b15c732541a22ff73b18b8f9aff0b6b3facc2..d0020a2776be7bc490c0a4cba401a5e066323023 100644
--- a/python/paddle/distributed/fleet/base/fleet_base.py
+++ b/python/paddle/distributed/fleet/base/fleet_base.py
@@ -269,11 +269,27 @@ class Fleet(object):
         cg.set_comm_group('global', global_rank, global_world_size,
                           global_ring_id, global_ranks)
 
+        use_tensor_parallel = self._user_defined_strategy.tensor_parallel
+        use_mp = use_sharding or use_tensor_parallel
+
         # hybrid group
-        if use_sharding is False: return
+        if use_mp is False: return
+
+        mp_degree_sharding = 1
+        mp_degree_tensor_parallel = 1
+        if use_sharding:
+            sharding_configs = self._user_defined_strategy.sharding_configs
+            mp_degree_sharding = int(sharding_configs['mp_degree'])
+
+        if use_tensor_parallel:
+            tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
+            mp_degree_tensor_parallel = int(tensor_parallel_configs[
+                'tensor_parallel_degree'])
+
+        if use_sharding and use_tensor_parallel:
+            assert mp_degree_sharding == mp_degree_tensor_parallel
 
-        sharding_configs = self._user_defined_strategy.sharding_configs
-        mp_degree = int(sharding_configs['mp_degree'])
+        mp_degree = mp_degree_sharding if use_sharding else mp_degree_tensor_parallel
 
         if mp_degree > 1:
             assert global_world_size % mp_degree == 0
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_static_mp_layers.py b/python/paddle/fluid/tests/unittests/test_fleet_static_mp_layers.py
index 6c7fab25a3096db565091284c67f53bb45a66a39..c9de3814f0acf983a3c298b1a6f9952468f7f0ed 100644
--- a/python/paddle/fluid/tests/unittests/test_fleet_static_mp_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_static_mp_layers.py
@@ -84,6 +84,8 @@ class TestDistTraning(unittest.TestCase):
             "mp_degree": self.model_parallel_size,
             "sharding_degree": 2,
         }
+        strategy.tensor_parallel = True
+        strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}
         fleet.init(is_collective=True, strategy=strategy)
 
     def get_program(self):
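
For context, a minimal sketch of how the new code path is exercised, mirroring the strategy settings added to the unit test above. The two-GPU launch setup is an assumption; `fleet.init` needs a collective environment (e.g. started via `python -m paddle.distributed.launch`) to discover trainer endpoints and build the communication groups.

```python
# Minimal sketch (assumption: run under
#   python -m paddle.distributed.launch --gpus 0,1 train.py
# so that fleet.init can resolve the trainer endpoints).
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()

strategy = fleet.DistributedStrategy()
# New path introduced by this diff: tensor_parallel alone now triggers the
# hybrid (model-parallel) group setup in fleet_base.py, without sharding.
strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 2}

# If sharding were also enabled, its 'mp_degree' would have to equal
# 'tensor_parallel_degree', or the new assert in fleet_base.py fires.
fleet.init(is_collective=True, strategy=strategy)
```

The consistency assert keeps the two configuration surfaces in agreement: when both sharding and tensor parallelism are requested, a single model-parallel degree is used to size the communication group.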