提交 d37180a2 编写于 作者: A AshihsKrShrivastava

modification for upsample node fused from unfused Resize subgraph

上级 245b2fec
......@@ -1397,8 +1397,7 @@ void ONNXImporter::populateNet(Net dstNet)
CV_Assert(layer_id.find(node_proto.input(i)) == layer_id.end());
String interp_mode = layerParams.get<String>("coordinate_transformation_mode");
CV_Assert_N(interp_mode != "tf_crop_and_resize", interp_mode != "asymmetric",
interp_mode != "tf_half_pixel_for_nn");
CV_Assert_N(interp_mode != "tf_crop_and_resize", interp_mode != "tf_half_pixel_for_nn");
layerParams.set("align_corners", interp_mode == "align_corners");
Mat shapes = getBlob(node_proto, constBlobs, node_proto.input_size() - 1);
......@@ -1426,6 +1425,22 @@ void ONNXImporter::populateNet(Net dstNet)
}
else if (layer_type == "Upsample")
{
//fused from Resize Subgraph
if (layerParams.has("coordinate_transformation_mode"))
{
String interp_mode = layerParams.get<String>("coordinate_transformation_mode");
CV_Assert_N(interp_mode != "tf_crop_and_resize", interp_mode != "tf_half_pixel_for_nn");
layerParams.set("align_corners", interp_mode == "align_corners");
if (layerParams.get<String>("mode") == "linear")
{
layerParams.set("mode", interp_mode == "pytorch_half_pixel" ?
"opencv_linear" : "bilinear");
}
}
if (layerParams.get<String>("mode") == "linear" && framework_name == "pytorch")
layerParams.set("mode", "opencv_linear");
layerParams.type = "Resize";
if (layerParams.has("scales"))
{
......@@ -1435,22 +1450,21 @@ void ONNXImporter::populateNet(Net dstNet)
layerParams.set("zoom_factor_y", scales.getIntValue(2));
layerParams.set("zoom_factor_x", scales.getIntValue(3));
}
else
else if (layerParams.has("height_scale") && layerParams.has("width_scale"))
{
// Caffe2 layer
replaceLayerParam(layerParams, "height_scale", "zoom_factor_y");
replaceLayerParam(layerParams, "width_scale", "zoom_factor_x");
}
replaceLayerParam(layerParams, "mode", "interpolation");
if (layerParams.get<String>("interpolation") == "linear" && framework_name == "pytorch") {
layerParams.type = "Resize";
else
{
// scales as input
Mat scales = getBlob(node_proto, constBlobs, 1);
CV_Assert(scales.total() == 4);
layerParams.set("interpolation", "opencv_linear");
layerParams.set("zoom_factor_y", scales.at<float>(2));
layerParams.set("zoom_factor_x", scales.at<float>(3));
}
replaceLayerParam(layerParams, "mode", "interpolation");
}
else if (layer_type == "SoftMax" || layer_type == "LogSoftmax")
{
......
......@@ -369,6 +369,7 @@ TEST_P(Test_ONNX_layers, ResizeUnfused)
testONNXModels("upsample_unfused_opset9_torch1.4");
testONNXModels("resize_nearest_unfused_opset11_torch1.4");
testONNXModels("resize_nearest_unfused_opset11_torch1.3");
testONNXModels("resize_bilinear_unfused_opset11_torch1.4");
}
TEST_P(Test_ONNX_layers, MultyInputs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册