未验证 提交 8e9de875 编写于 作者: Y Yichen Zhang 提交者: GitHub

add reshape backward rule (#56443)

上级 f2968742
......@@ -135,6 +135,7 @@ std::vector<DimTrans*> MakeReshapeDimTrans(
return ret;
}
//
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward(
const std::vector<DistTensorSpec>& input_specs,
......@@ -195,12 +196,64 @@ paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward(
// Infers the input's dims mapping from the output's dims mapping for a
// reshape op (the "backward" direction of SPMD inference). Returns a pair:
// first = inferred dist attrs for the inputs, second = (possibly updated)
// dist attrs for the outputs.
// NOTE(review): `attrs` is accepted to satisfy the SPMDRuleBase interface
// but is not consulted here — the shapes come from the specs themselves.
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
paddle::distributed::auto_parallel::ReshapeSPMDRule::InferBackward(
    const std::vector<DistTensorSpec>& input_specs,
    const std::vector<DistTensorSpec>& output_specs,
    const paddle::framework::AttributeMap& attrs) {
  // step0: Verify Input Args Based on Reshape Logic
  int64_t ninputs = input_specs.size();
  int64_t noutputs = output_specs.size();
  PADDLE_ENFORCE_EQ(
      ninputs,
      1,
      phi::errors::InvalidArgument("The size of InputSpec in reshape must "
                                   "be equal to 1, but got [%d].",
                                   ninputs));
  PADDLE_ENFORCE_EQ(
      noutputs,
      1,
      phi::errors::InvalidArgument("The size of OutputSpec in reshape must "
                                   "be equal to 1, but got [%d].",
                                   noutputs));
  VerifySpecs(output_specs, "reshape");

  // step1: build the transformation from the output shape
  // to original shape. Inferbackward infers the dims mapping
  // from output to input, we first get the transformation
  // from output to input so that we can infer the dims mapping
  // with the map from output axes to input axes.
  // Shapes in Inferbackward don't contain -1 or 0, so they will
  // not be modified and we can use ref here.
  const std::vector<int64_t>& output_shape = output_specs[0].shape();
  const std::vector<int64_t>& input_shape = input_specs[0].shape();
  std::vector<DimTrans*> trans = MakeReshapeDimTrans(output_shape, input_shape);

  // step2: infer the dims mapping of input with
  // output's dims_mapping and the transformation.
  // dims_mapping_vec[0] is the (re-derived) output mapping,
  // dims_mapping_vec[1] is the inferred input mapping.
  std::vector<std::vector<int64_t>> dims_mapping_vec =
      InferFromDimTrans(output_specs[0], trans);

  // step3: update the dist attributes of input
  // and output with the inferred dims mapping
  TensorDistAttr new_output_dist_attr(output_specs[0].dist_attr());
  new_output_dist_attr.set_dims_mapping(dims_mapping_vec[0]);
  TensorDistAttr input_dist_attr(input_specs[0].dist_attr());
  input_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

  VLOG(4) << "Reshape Inferbackward: output_shape: [" << str_join(output_shape)
          << "] input_shape: [" << str_join(input_shape) << "]";
  VLOG(4) << "Transformation from output to input:";
  for (int64_t i = 0, n = trans.size(); i < n; i++) {
    DimTrans* t = trans[i];
    VLOG(4) << "\tInput axis " << i << ": " << t->to_string();
  }
  VLOG(4) << "input_dims_mapping: [" << str_join(dims_mapping_vec[1])
          << "] output_dims_mapping: [" << str_join(dims_mapping_vec[0])
          << "]\n\n";

  // MakeReshapeDimTrans allocates DimTrans nodes; release them before return.
  CleanUp();

  return {{input_dist_attr}, {new_output_dist_attr}};
}
} // namespace auto_parallel
......
......@@ -32,7 +32,8 @@ class ReshapeSPMDRule : public SPMDRuleBase {
const paddle::framework::AttributeMap& attrs) override;
// Backward SPMD inference for reshape: derive the input dims mapping from
// the output's dims mapping. Overrides SPMDRuleBase::InferBackward.
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
InferBackward(const std::vector<DistTensorSpec>& input_specs,
              const std::vector<DistTensorSpec>& output_specs,
              const paddle::framework::AttributeMap& attrs) override;
};
} // namespace auto_parallel
......
......@@ -30,7 +30,7 @@ class TestReshapeSPMDRule(unittest.TestCase):
process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])
x_tensor_dist_attr = TensorDistAttr()
x_tensor_dist_attr.dims_mapping = [-1, -1]
x_tensor_dist_attr.dims_mapping = [-1, -1, -1, -1]
x_tensor_dist_attr.process_mesh = process_mesh
self.x_dist_tensor_spec = DistTensorSpec(x_shape, x_tensor_dist_attr)
......@@ -248,6 +248,171 @@ class TestReshapeSPMDRule(unittest.TestCase):
with self.assertRaises(BaseException):
self.rule.infer_forward([self.x_dist_tensor_spec], self.attrs)
def test_reshape_infer_backward(self):
    """Check that infer_backward derives the input dims mapping from the
    output's dims mapping for several reshape patterns.

    Input tensor shape is [6, 12, 48, 24] (from setUp); each case sets an
    output shape and output dims_mapping, then verifies the inferred
    (input, output) dims mappings.
    """
    process_mesh = auto.ProcessMesh(mesh=[[0, 1, 2], [3, 4, 5]])

    output_tensor_dist_attr = TensorDistAttr()
    output_tensor_dist_attr.dims_mapping = [-1, -1, -1, -1]
    output_tensor_dist_attr.process_mesh = process_mesh
    self.output_dist_tensor_spec = DistTensorSpec(
        [1, 72, 48, 4, 6], output_tensor_dist_attr
    )

    def run_backward():
        # Invoke the rule and return (input_dist_attrs, output_dist_attrs).
        result = self.rule.infer_backward(
            [self.x_dist_tensor_spec],
            [self.output_dist_tensor_spec],
            self.attrs,
        )
        return result[0], result[1]

    # shape: [6, 12, 48, 24] --> [1, 72, 48, 4, 6] (input --> output)
    # dims_mapping: [-1, 0, 1, -1, -1] --> [0, -1, 1, -1], [-1, 0, 1, -1, -1]
    self.output_dist_tensor_spec.set_dims_mapping([-1, 0, 1, -1, -1])
    infered_input_dist_attrs, infered_output_dist_attrs = run_backward()
    self.assertEqual(len(infered_input_dist_attrs), 1)
    self.assertEqual(len(infered_output_dist_attrs), 1)
    self.assertEqual(
        infered_input_dist_attrs[0].dims_mapping, [0, -1, 1, -1]
    )
    self.assertEqual(
        infered_output_dist_attrs[0].dims_mapping, [-1, 0, 1, -1, -1]
    )

    # Remaining cases, each as:
    # (output_shape, output_dims_mapping,
    #  expected_input_dims_mapping, expected_output_dims_mapping)
    cases = [
        (
            [1, 72, 48, 4, 6],
            [-1, -1, -1, -1, -1],
            [-1, -1, -1, -1],
            [-1, -1, -1, -1, -1],
        ),
        (
            [1, 72, 48, 4, 6],
            [-1, 1, -1, 0, -1],
            [1, -1, -1, 0],
            [-1, 1, -1, 0, -1],
        ),
        (
            [3, 24, 6, 8, 24],
            [1, -1, -1, -1, 0],
            [1, -1, -1, 0],
            [1, -1, -1, -1, 0],
        ),
        (
            [3, 24, 6, 8, 24],
            [-1, -1, 0, -1, 1],
            [-1, -1, 0, 1],
            [-1, -1, 0, -1, 1],
        ),
        (
            [6, 12, 48, 24],
            [-1, -1, 0, 1],
            [-1, -1, 0, 1],
            [-1, -1, 0, 1],
        ),
        (
            [72, 3, 16, 24],
            [0, 1, -1, -1],
            [0, -1, 1, -1],
            [0, 1, -1, -1],
        ),
        (
            [72, 3, 16, 24],
            [1, -1, -1, -1],
            [1, -1, -1, -1],
            [1, -1, -1, -1],
        ),
    ]
    for out_shape, out_mapping, want_input, want_output in cases:
        self.output_dist_tensor_spec.shape = out_shape
        self.output_dist_tensor_spec.set_dims_mapping(out_mapping)
        infered_input_dist_attrs, infered_output_dist_attrs = run_backward()
        self.assertEqual(
            infered_input_dist_attrs[0].dims_mapping, want_input
        )
        self.assertEqual(
            infered_output_dist_attrs[0].dims_mapping, want_output
        )
# Run the test suite only when executed as a script (not on import).
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册