diff --git a/python/paddle/distributed/collective.py b/python/paddle/distributed/collective.py index cfba9f656b333a743fbf5890d5928a3178faede9..df6df856222299f7ff0751086aedd344c020f8bf 100644 --- a/python/paddle/distributed/collective.py +++ b/python/paddle/distributed/collective.py @@ -1341,20 +1341,20 @@ def split(x, Examples: .. code-block:: python - + # required: distributed import paddle - from paddle.distributed import init_parallel_env - - # required: gpu + import paddle.distributed.fleet as fleet + paddle.enable_static() paddle.set_device('gpu:%d'%paddle.distributed.ParallelEnv().dev_id) - init_parallel_env() + fleet.init(is_collective=True) data = paddle.randint(0, 8, shape=[10,4]) emb_out = paddle.distributed.split( data, (8, 8), operation="embedding", num_partitions=2) + """ assert isinstance(size, (list, tuple)), ( "The type of size for "