提交 c33f9482 编写于 作者: J Jaesung Chung 提交者: TensorFlower Gardener

Do not legalize tf.Conv2DOp when the filter type is not determined.

Currently, the converter generates a wrong constant when the condition is hold.

PiperOrigin-RevId: 340170033
Change-Id: Ia4af732f456254738c76840a32673ecff5f0fb6a
上级 7d4e7b86
......@@ -639,4 +639,11 @@ func @DontMatchFusedBatchNormV3(%arg0 :tensor<?x576x1x1xf32>, %arg1 : tensor<576
// CHECK: "tf.FusedBatchNormV3"
}
// CHECK-LABEL: DoNotConvertConv2DWhenFilterTypeDimIsNotDecided
func @DoNotConvertConv2DWhenFilterTypeDimIsNotDecided(%arg0 : tensor<?x?x?x96xf32>, %arg1 : tensor<3x3x96x?xf32>) -> tensor<?x?x?x?xf32> {
%0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<?x?x?x96xf32>, tensor<3x3x96x?xf32>) -> tensor<?x?x?x?xf32>
return %0 : tensor<?x?x?x?xf32>
// CHECK: tf.Conv2D
}
}
......@@ -328,7 +328,9 @@ struct ConvertTFConvOp : public RewritePattern {
// tensor, for setting depth_multiplier attribute, etc.).
auto filter = tf_op.filter();
auto filter_type = filter.getType().template dyn_cast<RankedTensorType>();
if (!filter_type || filter_type.getRank() != 4) return failure();
if (!filter_type || filter_type.getRank() != 4 ||
!filter_type.hasStaticShape())
return failure();
// TensorFlow convolution op only has two inputs, while the TFLite one has
// three, with the bias vector marked as optional. However, TOCO has a
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册