From c33f9482309289ba457d7f2a1b8f7ea76b8b587b Mon Sep 17 00:00:00 2001 From: Jaesung Chung Date: Sun, 1 Nov 2020 22:13:09 -0800 Subject: [PATCH] Do not legalize tf.Conv2DOp when the filter type is not determined. Currently, the converter generates a wrong constant when the condition is hold. PiperOrigin-RevId: 340170033 Change-Id: Ia4af732f456254738c76840a32673ecff5f0fb6a --- tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir | 7 +++++++ tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc | 4 +++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir index 186c8631e56..88de48cf1f9 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf.mlir @@ -639,4 +639,11 @@ func @DontMatchFusedBatchNormV3(%arg0 :tensor, %arg1 : tensor<576 // CHECK: "tf.FusedBatchNormV3" } +// CHECK-LABEL: DoNotConvertConv2DWhenFilterTypeDimIsNotDecided +func @DoNotConvertConv2DWhenFilterTypeDimIsNotDecided(%arg0 : tensor, %arg1 : tensor<3x3x96x?xf32>) -> tensor { + %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor, tensor<3x3x96x?xf32>) -> tensor + return %0 : tensor +// CHECK: tf.Conv2D +} + } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index ef127659504..3fb5c2cc6f7 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -328,7 +328,9 @@ struct ConvertTFConvOp : public RewritePattern { // tensor, for setting depth_multiplier attribute, etc.). auto filter = tf_op.filter(); auto filter_type = filter.getType().template dyn_cast(); - if (!filter_type || filter_type.getRank() != 4) return failure(); + if (!filter_type || filter_type.getRank() != 4 || + !filter_type.hasStaticShape()) + return failure(); // TensorFlow convolution op only has two inputs, while the TFLite one has // three, with the bias vector marked as optional. However, TOCO has a -- GitLab