提交 5b76abd4 编写于 作者: Y Yuanzhong Xu 提交者: TensorFlower Gardener

[XLA:SPMD] Fix contracting dim loop einsum threshold

PiperOrigin-RevId: 339989040
Change-Id: I46eaffb4e4b28c168b5f0182c4c0bb946d6eefe8
上级 05483cd4
......@@ -536,15 +536,13 @@ StatusOr<HloInstruction*> PartitionBaseCase(
}
if (lhs_contracting_partitions == rhs_contracting_partitions &&
lhs_contracting_partitions == num_partitions &&
output_sharding_dim > -1) {
if (output_lhs_non_contracting_partitions == num_partitions &&
ShapeSizeInBytes(rhs.base_shape()) >=
options.threshold_for_windowed_einsum_mib * 1024 * 1024) {
output_sharding_dim > -1 &&
ShapeSizeInBytes(output_base_shape) >=
options.threshold_for_windowed_einsum_mib * 1024 * 1024) {
if (output_lhs_non_contracting_partitions == num_partitions) {
return emit_windowed_dot_general(0, 1, false, false, true);
}
if (output_rhs_non_contracting_partitions == num_partitions &&
ShapeSizeInBytes(lhs.base_shape()) >=
options.threshold_for_windowed_einsum_mib * 1024 * 1024) {
if (output_rhs_non_contracting_partitions == num_partitions) {
return emit_windowed_dot_general(1, 0, false, false, true);
}
}
......
......@@ -3681,12 +3681,12 @@ TEST_F(SpmdPartitioningTest,
HloModule module
ENTRY entry {
%lhs = f32[32,25,64,128] parameter(0)
%lhs.copy = f32[32,25,64,128] copy(%lhs), sharding={devices=[1,1,4,1]0,1,2,3}
%rhs = f32[32,39296,64,128] parameter(1)
%rhs.copy = f32[32,39296,64,128] copy(%rhs),
%lhs = f32[320,25,64,128] parameter(0)
%lhs.copy = f32[320,25,64,128] copy(%lhs), sharding={devices=[1,1,4,1]0,1,2,3}
%rhs = f32[320,39296,64,128] parameter(1)
%rhs.copy = f32[320,39296,64,128] copy(%rhs),
sharding={devices=[1,1,4,1]0,1,2,3}
ROOT %dot = f32[32,25,39296] dot(%lhs.copy, %rhs.copy),
ROOT %dot = f32[320,25,39296] dot(%lhs.copy, %rhs.copy),
lhs_batch_dims={0}, rhs_batch_dims={0},
lhs_contracting_dims={2,3}, rhs_contracting_dims={2,3},
sharding={devices=[1,4,1]0,1,2,3}
......@@ -3700,14 +3700,14 @@ ENTRY entry {
auto lhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(0), op::Constant(),
op::Constant(), op::Reshape(), op::Constant())),
op::Shape("f32[32,25,16,128]"));
op::Shape("f32[320,25,16,128]"));
auto rhs = AllOf(
op::Copy(op::DynamicSlice(op::Parameter(1), op::Constant(),
op::Constant(), op::Reshape(), op::Constant())),
op::Shape("f32[32,39296,16,128]"));
op::Shape("f32[320,39296,16,128]"));
EXPECT_THAT(root, AllOf(op::GetTupleElement(op::While(op::Tuple(
lhs, rhs, op::Broadcast(), op::Constant()))),
op::Shape("f32[32,7,39296]")));
op::Shape("f32[320,7,39296]")));
auto while_loop = root->operand(0);
// Check loop condition.
......@@ -3721,11 +3721,11 @@ ENTRY entry {
AllOf(op::DynamicSlice(
op::Pad(op::GetTupleElement(op::Parameter(0)), op::Constant()),
op::Constant(), op::Multiply(), op::Constant(), op::Constant()),
op::Shape("f32[32,7,16,128]"));
op::Shape("f32[320,7,16,128]"));
auto partial_output =
AllOf(op::Add(op::GetTupleElement(op::Parameter(0)),
op::Dot(ds, op::GetTupleElement(op::Parameter(0)))),
op::Shape("f32[32,7,39296]"));
op::Shape("f32[320,7,39296]"));
auto window = op::Conditional(op::Compare(next_i, op::Constant()),
partial_output, partial_output);
EXPECT_THAT(while_loop->while_body()->root_instruction(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册