Unverified commit cf652101, authored by Yichen Zhang, committed by GitHub

add softmax backward rule (#56502)

Parent: 8e9de875
@@ -20,10 +20,10 @@ namespace auto_parallel {
using phi::distributed::auto_parallel::str_join;
// step0: verify input args based on softmax logic
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SoftmaxSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: verify input args based on softmax logic
auto input_specs_size = input_specs.size();
PADDLE_ENFORCE_EQ(
input_specs_size,
@@ -33,7 +33,7 @@ SoftmaxSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
input_specs_size));
auto x_shape = input_specs[0].shape();
int x_ndim = static_cast<int>(x_shape.size());
int x_ndim = x_shape.size();
auto x_dist_attr_src = input_specs[0].dist_attr();
std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
@@ -94,7 +94,7 @@ SoftmaxSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
x_dist_attr_dst.set_dims_mapping(x_dims_mapping);
VLOG(4) << "EmbeddingSPMDRule InferForward: "
VLOG(4) << "SoftmaxSPMDRule InferForward: "
<< "Einsum notation: [" << x_axes << " --> " << out_axes << "]. "
<< std::endl
<< "X shape: [" << str_join(x_shape) << "], src_dims_mapping: ["
@@ -107,9 +107,71 @@ SoftmaxSPMDRule::InferForward(const std::vector<DistTensorSpec>& input_specs,
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
SoftmaxSPMDRule::InferBackward(const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
  PADDLE_THROW(phi::errors::Unimplemented(
      "InferBackward of SoftmaxSPMDRule is NOT implemented yet."));
  // step0: verify input args based on softmax logic
  int64_t input_specs_size = input_specs.size();
  int64_t output_specs_size = output_specs.size();
  PADDLE_ENFORCE_EQ(
      input_specs_size,
      1,
      phi::errors::InvalidArgument(
          "The size of InputSpec of softmax should be 1, but got [%d].",
          input_specs_size));
  PADDLE_ENFORCE_EQ(
      output_specs_size,
      1,
      phi::errors::InvalidArgument(
          "The size of OutputSpec of softmax should be 1, but got [%d].",
          output_specs_size));
  VerifySpecs(output_specs, "softmax_backward");

  // step1: build Einsum Notation
  std::vector<int64_t> x_shape = input_specs[0].shape();
  int64_t x_ndim = input_specs[0].shape().size();
  std::string alphabet = "abcdefghijklmnopqrstuvwxyz";
  std::string x_axes = GetBroadcastAxes(x_ndim, x_ndim, alphabet);
  std::string out_axes = x_axes;

  int axis = ExtractAttr<int>("axis", attrs);
  // normalize a negative axis
  if (axis < 0) {
    axis = x_ndim + axis;
  }
  // sharding on softmax_axis is not supported for now, so set
  // the notation on softmax_axis to '1' so that its dims mapping
  // can be set to -1
  x_axes[axis] = '1';

  // step2: Sharding Propagation
  std::vector<int64_t> out_dims_mapping = output_specs[0].dims_mapping();
  std::unordered_map<std::string, int64_t> axis_to_dim_map =
      ShardingMergeForTensors({{out_axes, out_dims_mapping}});

  // infer input's dims mapping.
  std::vector<int64_t> x_dims_mapping =
      GetDimsMappingForAxes(x_axes, axis_to_dim_map);
  TensorDistAttr input_dist_attr(input_specs[0].dist_attr());
  input_dist_attr.set_dims_mapping(x_dims_mapping);

  // update output's dims mapping.
  out_dims_mapping[axis] = -1;
  TensorDistAttr output_dist_attr(output_specs[0].dist_attr());
  output_dist_attr.set_dims_mapping(out_dims_mapping);

  VLOG(4) << "SoftmaxSPMDRule InferBackward: "
          << "softmax_axis: " << axis << std::endl
          << "Einsum notation: [" << x_axes << " --> " << out_axes << "]. "
          << std::endl
          << "Output shape: [" << str_join(output_specs[0].shape())
          << "], src_dims_mapping: ["
          << str_join(output_specs[0].dims_mapping())
          << "], dst_dims_mapping: [" << str_join(out_dims_mapping)
          << "]; Input dims_mapping: [" << str_join(x_dims_mapping) << "]";

  return {{input_dist_attr}, {output_dist_attr}};
}
} // namespace auto_parallel
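In effect, the backward rule copies the output's dims mapping to the input and forces the softmax axis to be replicated on both tensors. Below is a minimal stand-alone sketch of that dims-mapping logic, written in plain Python for illustration only; the helper name infer_backward_dims_mapping is hypothetical and not part of Paddle's API.

def infer_backward_dims_mapping(out_dims_mapping, axis):
    # illustrative re-implementation of the dims-mapping logic above
    ndim = len(out_dims_mapping)
    # normalize a negative softmax axis, as the C++ rule does
    if axis < 0:
        axis = ndim + axis
    # copy the output mapping to the input, then force the softmax axis
    # to be replicated (-1) on both tensors
    x_dims_mapping = list(out_dims_mapping)
    x_dims_mapping[axis] = -1
    out_dims_mapping = list(out_dims_mapping)
    out_dims_mapping[axis] = -1
    return x_dims_mapping, out_dims_mapping

# mirrors the "sharding on softmax_axis" case in the unit test below:
# an output sharded as [1, -1, 0] with axis = -1 yields [1, -1, -1] for both
print(infer_backward_dims_mapping([1, -1, 0], -1))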
@@ -32,7 +32,8 @@ class SoftmaxSPMDRule : public SPMDRuleBase {
const paddle::framework::AttributeMap& attrs) override;
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& output_specs,
InferBackward(const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
@@ -33,6 +33,8 @@ class TestEmbeddingSPMDRule(unittest.TestCase):
        x_tensor_dist_attr.process_mesh = process_mesh
        self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
        self.out_dist_tensor_spec = DistTensorSpec(self.x_dist_tensor_spec)

        self.attrs = {
            'axis': -1,
        }
@@ -99,6 +101,68 @@ class TestEmbeddingSPMDRule(unittest.TestCase):
        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0])

    def test_softmax_infer_backward(self):
        # sharding on batch axis I
        self.out_dist_tensor_spec.set_dims_mapping([1, -1, -1])
        result_dist_attrs = self.rule1.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        self.assertEqual(len(result_dist_attrs), 2)
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]
        self.assertEqual(len(infered_input_dist_attrs), 1)
        self.assertEqual(len(infered_output_dist_attrs), 1)
        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])

        # sharding on batch axis II
        self.out_dist_tensor_spec.set_dims_mapping([-1, 1, -1])
        result_dist_attrs = self.rule1.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]
        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, 1, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, 1, -1])

        # sharding on softmax_axis
        self.out_dist_tensor_spec.set_dims_mapping([1, -1, 0])
        result_dist_attrs = self.rule1.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]
        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [1, -1, -1])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [1, -1, -1])

        # sharding on softmax_axis + axis = 1
        self.attrs = {
            'axis': 1,
        }
        self.out_dist_tensor_spec.set_dims_mapping([-1, 1, 0])
        result_dist_attrs = self.rule1.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]
        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0])

        # sharding on softmax_axis + axis = -2
        self.attrs = {
            'axis': -2,
        }
        self.out_dist_tensor_spec.set_dims_mapping([-1, 1, 0])
        result_dist_attrs = self.rule1.infer_backward(
            [self.x_dist_tensor_spec], [self.out_dist_tensor_spec], self.attrs
        )
        infered_input_dist_attrs = result_dist_attrs[0]
        infered_output_dist_attrs = result_dist_attrs[1]
        self.assertEqual(infered_input_dist_attrs[0].dims_mapping, [-1, -1, 0])
        self.assertEqual(infered_output_dist_attrs[0].dims_mapping, [-1, -1, 0])

if __name__ == "__main__":
    unittest.main()
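For the negative-axis cases in the test, axis = -2 on a 3-D tensor normalizes to index 1, so the mesh dimension on that axis is dropped while the sharding on the last axis is kept. A quick sanity check using the hypothetical helper sketched after the C++ rule above:

x_map, out_map = infer_backward_dims_mapping([-1, 1, 0], -2)
assert x_map == [-1, -1, 0] and out_map == [-1, -1, 0]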