Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
cf652101
P
Paddle
项目概览
PaddlePaddle
/
Paddle
9 个月 前同步成功
通知
2282
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
cf652101
编写于
9月 08, 2023
作者:
Y
Yichen Zhang
提交者:
GitHub
9月 08, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add softmax backward rule (#56502)
上级
8e9de875
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
133 addition
and
6 deletion
+133
-6
paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.cc
...distributed/auto_parallel/spmd_rules/softmax_spmd_rule.cc
+67
-5
paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h
.../distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h
+2
-1
test/auto_parallel/spmd_rules/test_softmax_rule.py
test/auto_parallel/spmd_rules/test_softmax_rule.py
+64
-0
未找到文件。
paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.cc
浏览文件 @
cf652101
...
...
@@ -20,10 +20,10 @@ namespace auto_parallel {
using
phi
::
distributed
::
auto_parallel
::
str_join
;
// step0: verify input args based on softmax logic
std
::
pair
<
std
::
vector
<
TensorDistAttr
>
,
std
::
vector
<
TensorDistAttr
>>
SoftmaxSPMDRule
::
InferForward
(
const
std
::
vector
<
DistTensorSpec
>&
input_specs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
{
// step0: verify input args based on softmax logic
auto
input_specs_size
=
input_specs
.
size
();
PADDLE_ENFORCE_EQ
(
input_specs_size
,
...
...
@@ -33,7 +33,7 @@ SoftmaxSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
input_specs_size
));
auto
x_shape
=
input_specs
[
0
].
shape
();
int
x_ndim
=
static_cast
<
int
>
(
x_shape
.
size
()
);
int
x_ndim
=
x_shape
.
size
(
);
auto
x_dist_attr_src
=
input_specs
[
0
].
dist_attr
();
std
::
vector
<
int64_t
>
x_dims_mapping
=
x_dist_attr_src
.
dims_mapping
();
...
...
@@ -94,7 +94,7 @@ SoftmaxSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
TensorDistAttr
x_dist_attr_dst
=
CopyTensorDistAttrForOutput
(
x_dist_attr_src
);
x_dist_attr_dst
.
set_dims_mapping
(
x_dims_mapping
);
VLOG
(
4
)
<<
"
Embedding
SPMDRule InferForward: "
VLOG
(
4
)
<<
"
Softmax
SPMDRule InferForward: "
<<
"Einsum notation: ["
<<
x_axes
<<
" --> "
<<
out_axes
<<
"]. "
<<
std
::
endl
<<
"X shape: ["
<<
str_join
(
x_shape
)
<<
"], src_dims_mapping: ["
...
...
@@ -107,9 +107,71 @@ SoftmaxSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
std
::
pair
<
std
::
vector
<
TensorDistAttr
>
,
std
::
vector
<
TensorDistAttr
>>
SoftmaxSPMDRule
::
InferBackward
(
const
std
::
vector
<
DistTensorSpec
>&
input_specs
,
const
std
::
vector
<
DistTensorSpec
>&
output_specs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
{
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"InferBackward of SoftmaxSPMDRule is NOT implemented yet."
));
// step0: verify input args based on softmax logic
int64_t
input_specs_size
=
input_specs
.
size
();
int64_t
output_specs_size
=
output_specs
.
size
();
PADDLE_ENFORCE_EQ
(
input_specs_size
,
1
,
phi
::
errors
::
InvalidArgument
(
"The size of InputSpec of softmax should be 1, but got [%d]."
,
input_specs_size
));
PADDLE_ENFORCE_EQ
(
output_specs_size
,
1
,
phi
::
errors
::
InvalidArgument
(
"The size of InputSpec of softmax should be 1, but got [%d]."
,
output_specs_size
));
VerifySpecs
(
output_specs
,
"softmax_backward"
);
// step1: build Einsum Notation
std
::
vector
<
int64_t
>
x_shape
=
input_specs
[
0
].
shape
();
int64_t
x_ndim
=
input_specs
[
0
].
shape
().
size
();
std
::
string
alphabet
=
"abcdefghijklmnopqrstuvwxyz"
;
std
::
string
x_axes
=
GetBroadcastAxes
(
x_ndim
,
x_ndim
,
alphabet
);
std
::
string
out_axes
=
x_axes
;
int
axis
=
ExtractAttr
<
int
>
(
"axis"
,
attrs
);
// normalize axis
if
(
axis
<
0
)
{
axis
=
x_ndim
+
axis
;
}
// sharding on softmax_axis is not supported now, so set
// the notation on softmax_axis to '1' so that we can set
// its dim mapping to -1
x_axes
[
axis
]
=
'1'
;
// step2: Sharding Propogation
std
::
vector
<
int64_t
>
out_dims_mapping
=
output_specs
[
0
].
dims_mapping
();
std
::
unordered_map
<
std
::
string
,
int64_t
>
axis_to_dim_map
=
ShardingMergeForTensors
({{
out_axes
,
out_dims_mapping
}});
// infer input's dims mapping.
std
::
vector
<
int64_t
>
x_dims_mapping
=
GetDimsMappingForAxes
(
x_axes
,
axis_to_dim_map
);
TensorDistAttr
input_dist_attr
(
input_specs
[
0
].
dist_attr
());
input_dist_attr
.
set_dims_mapping
(
x_dims_mapping
);
// update output's dims mapping.
out_dims_mapping
[
axis
]
=
-
1
;
TensorDistAttr
output_dist_attr
(
output_specs
[
0
].
dist_attr
());
output_dist_attr
.
set_dims_mapping
(
out_dims_mapping
);
VLOG
(
4
)
<<
"SoftmaxSPMDRule InferBackward: "
<<
"softmax_axis: "
<<
axis
<<
std
::
endl
<<
"Einsum notation: ["
<<
x_axes
<<
" --> "
<<
out_axes
<<
"]. "
<<
std
::
endl
<<
"Output shape: ["
<<
str_join
(
output_specs
[
0
].
shape
())
<<
"], src_dims_mapping: ["
<<
str_join
(
output_specs
[
0
].
dims_mapping
())
<<
"], dst_dims_mapping: ["
<<
str_join
(
out_dims_mapping
)
<<
"]; Input dims_mapping: ["
<<
str_join
(
x_dims_mapping
)
<<
"]"
;
return
{{
input_dist_attr
},
{
output_dist_attr
}};
}
}
// namespace auto_parallel
...
...
paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h
浏览文件 @
cf652101
...
...
@@ -32,7 +32,8 @@ class SoftmaxSPMDRule : public SPMDRuleBase {
const
paddle
::
framework
::
AttributeMap
&
attrs
)
override
;
std
::
pair
<
std
::
vector
<
TensorDistAttr
>
,
std
::
vector
<
TensorDistAttr
>>
InferBackward
(
const
std
::
vector
<
DistTensorSpec
>&
output_specs
,
InferBackward
(
const
std
::
vector
<
DistTensorSpec
>&
input_specs
,
const
std
::
vector
<
DistTensorSpec
>&
output_specs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
override
;
};
}
// namespace auto_parallel
...
...
test/auto_parallel/spmd_rules/test_softmax_rule.py
浏览文件 @
cf652101
...
...
@@ -33,6 +33,8 @@ class TestEmbeddingSPMDRule(unittest.TestCase):
x_tensor_dist_attr
.
process_mesh
=
process_mesh
self
.
x_dist_tensor_spec
=
DistTensorSpec
(
x_shape
,
x_tensor_dist_attr
)
self
.
out_dist_tensor_spec
=
DistTensorSpec
(
self
.
x_dist_tensor_spec
)
self
.
attrs
=
{
'axis'
:
-
1
,
}
...
...
@@ -99,6 +101,68 @@ class TestEmbeddingSPMDRule(unittest.TestCase):
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
0
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
0
])
def
test_softmax_infer_backward
(
self
):
# sharding on batch axis I
self
.
out_dist_tensor_spec
.
set_dims_mapping
([
1
,
-
1
,
-
1
])
result_dist_attrs
=
self
.
rule1
.
infer_backward
(
[
self
.
x_dist_tensor_spec
],
[
self
.
out_dist_tensor_spec
],
self
.
attrs
)
self
.
assertEqual
(
len
(
result_dist_attrs
),
2
)
infered_input_dist_attrs
=
result_dist_attrs
[
0
]
infered_output_dist_attrs
=
result_dist_attrs
[
1
]
self
.
assertEqual
(
len
(
infered_input_dist_attrs
),
1
)
self
.
assertEqual
(
len
(
infered_output_dist_attrs
),
1
)
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
,
-
1
])
# sharding on batch axis II
self
.
out_dist_tensor_spec
.
set_dims_mapping
([
-
1
,
1
,
-
1
])
result_dist_attrs
=
self
.
rule1
.
infer_backward
(
[
self
.
x_dist_tensor_spec
],
[
self
.
out_dist_tensor_spec
],
self
.
attrs
)
infered_input_dist_attrs
=
result_dist_attrs
[
0
]
infered_output_dist_attrs
=
result_dist_attrs
[
1
]
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
1
,
-
1
])
# sharding on softmax_axis
self
.
out_dist_tensor_spec
.
set_dims_mapping
([
1
,
-
1
,
0
])
result_dist_attrs
=
self
.
rule1
.
infer_backward
(
[
self
.
x_dist_tensor_spec
],
[
self
.
out_dist_tensor_spec
],
self
.
attrs
)
infered_input_dist_attrs
=
result_dist_attrs
[
0
]
infered_output_dist_attrs
=
result_dist_attrs
[
1
]
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
,
-
1
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
1
,
-
1
,
-
1
])
# sharding on softmax_axis + axis = 1
self
.
attrs
=
{
'axis'
:
1
,
}
self
.
out_dist_tensor_spec
.
set_dims_mapping
([
-
1
,
1
,
0
])
result_dist_attrs
=
self
.
rule1
.
infer_backward
(
[
self
.
x_dist_tensor_spec
],
[
self
.
out_dist_tensor_spec
],
self
.
attrs
)
infered_input_dist_attrs
=
result_dist_attrs
[
0
]
infered_output_dist_attrs
=
result_dist_attrs
[
1
]
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
0
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
0
])
# sharding on softmax_axis + axis = -2
self
.
attrs
=
{
'axis'
:
-
2
,
}
self
.
out_dist_tensor_spec
.
set_dims_mapping
([
-
1
,
1
,
0
])
result_dist_attrs
=
self
.
rule1
.
infer_backward
(
[
self
.
x_dist_tensor_spec
],
[
self
.
out_dist_tensor_spec
],
self
.
attrs
)
infered_input_dist_attrs
=
result_dist_attrs
[
0
]
infered_output_dist_attrs
=
result_dist_attrs
[
1
]
self
.
assertEqual
(
infered_input_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
0
])
self
.
assertEqual
(
infered_output_dist_attrs
[
0
].
dims_mapping
,
[
-
1
,
-
1
,
0
])
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录