提交 739ff847 编写于 作者: S Smirnov Egor

add Max layer to TFImporter

上级 f7c82bae
......@@ -647,7 +647,7 @@ const TFImporter::DispatchMap TFImporter::buildDispatchMap()
dispatch["PriorBox"] = &TFImporter::parsePriorBox;
dispatch["Softmax"] = &TFImporter::parseSoftmax;
dispatch["CropAndResize"] = &TFImporter::parseCropAndResize;
dispatch["Mean"] = dispatch["Sum"] = &TFImporter::parseMean;
dispatch["Mean"] = dispatch["Sum"] = dispatch["Max"] = &TFImporter::parseMean;
dispatch["Pack"] = &TFImporter::parsePack;
dispatch["ClipByValue"] = &TFImporter::parseClipByValue;
dispatch["LeakyRelu"] = &TFImporter::parseLeakyRelu;
......@@ -657,6 +657,7 @@ const TFImporter::DispatchMap TFImporter::buildDispatchMap()
return dispatch;
}
// "Conv2D" "SpaceToBatchND" "DepthwiseConv2dNative" "Pad" "MirrorPad" "Conv3D"
void TFImporter::parseConvolution(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer_, LayerParams& layerParams)
{
tensorflow::NodeDef layer = layer_;
......@@ -876,6 +877,7 @@ void TFImporter::parseConvolution(tensorflow::GraphDef& net, const tensorflow::N
data_layouts[name] = DATA_LAYOUT_NHWC;
}
// "BiasAdd" "Add" "AddV2" "Sub" "AddN"
void TFImporter::parseBias(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
{
const std::string& name = layer.name();
......@@ -1087,6 +1089,7 @@ void TFImporter::parseReshape(tensorflow::GraphDef& net, const tensorflow::NodeD
}
}
// "Flatten" "Squeeze"
void TFImporter::parseFlatten(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
{
const std::string& name = layer.name();
......@@ -1245,6 +1248,7 @@ void TFImporter::parseLrn(tensorflow::GraphDef& net, const tensorflow::NodeDef&
connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, num_inputs);
}
// "Concat" "ConcatV2"
void TFImporter::parseConcat(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
{
const std::string& name = layer.name();
......@@ -1295,6 +1299,7 @@ void TFImporter::parseConcat(tensorflow::GraphDef& net, const tensorflow::NodeDe
}
}
// "MaxPool" "MaxPool3D"
void TFImporter::parseMaxPool(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
{
const std::string& name = layer.name();
......@@ -1316,6 +1321,7 @@ void TFImporter::parseMaxPool(tensorflow::GraphDef& net, const tensorflow::NodeD
connectToAllBlobs(layer_id, dstNet, parsePin(inputName), id, num_inputs);
}
// "AvgPool" "AvgPool3D"
void TFImporter::parseAvgPool(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
{
const std::string& name = layer.name();
......@@ -1502,6 +1508,7 @@ void TFImporter::parseStridedSlice(tensorflow::GraphDef& net, const tensorflow::
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
}
// "Mul" "RealDiv"
void TFImporter::parseMul(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
{
const std::string& name = layer.name();
......@@ -1659,6 +1666,7 @@ void TFImporter::parseMul(tensorflow::GraphDef& net, const tensorflow::NodeDef&
}
}
// "FusedBatchNorm" "FusedBatchNormV3"
void TFImporter::parseFusedBatchNorm(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
{
// op: "FusedBatchNorm"
......@@ -1918,6 +1926,7 @@ void TFImporter::parseBlockLSTM(tensorflow::GraphDef& net, const tensorflow::Nod
data_layouts[name] = DATA_LAYOUT_UNKNOWN;
}
// "ResizeNearestNeighbor" "ResizeBilinear" "FusedResizeAndPadConv2D"
void TFImporter::parseResize(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer_, LayerParams& layerParams)
{
tensorflow::NodeDef layer = layer_;
......@@ -2106,6 +2115,7 @@ void TFImporter::parseCropAndResize(tensorflow::GraphDef& net, const tensorflow:
connect(layer_id, dstNet, parsePin(layer.input(1)), id, 1);
}
// "Mean" "Sum" "Max"
void TFImporter::parseMean(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
{
// Computes the mean of elements across dimensions of a tensor.
......@@ -2124,7 +2134,12 @@ void TFImporter::parseMean(tensorflow::GraphDef& net, const tensorflow::NodeDef&
const std::string& name = layer.name();
const std::string& type = layer.op();
const int num_inputs = layer.input_size();
std::string pool_type = cv::toLowerCase(type);
if (pool_type == "mean")
{
pool_type = "ave";
}
CV_CheckGT(num_inputs, 0, "");
Mat indices = getTensorContent(getConstBlob(layer, value_id, 1));
......@@ -2161,7 +2176,7 @@ void TFImporter::parseMean(tensorflow::GraphDef& net, const tensorflow::NodeDef&
LayerParams avgLp;
std::string avgName = name + "/avg";
CV_Assert(layer_id.find(avgName) == layer_id.end());
avgLp.set("pool", type == "Mean" ? "ave" : "sum");
avgLp.set("pool", pool_type);
// pooling kernel H x 1
avgLp.set("global_pooling_h", true);
avgLp.set("kernel_w", 1);
......@@ -2202,7 +2217,7 @@ void TFImporter::parseMean(tensorflow::GraphDef& net, const tensorflow::NodeDef&
int axis = toNCHW(indices.at<int>(0));
if (axis == 2 || axis == 3)
{
layerParams.set("pool", type == "Mean" ? "ave" : "sum");
layerParams.set("pool", pool_type);
layerParams.set(axis == 2 ? "kernel_w" : "kernel_h", 1);
layerParams.set(axis == 2 ? "global_pooling_h" : "global_pooling_w", true);
int id = dstNet.addLayer(name, "Pooling", layerParams);
......@@ -2234,7 +2249,7 @@ void TFImporter::parseMean(tensorflow::GraphDef& net, const tensorflow::NodeDef&
Pin inpId = parsePin(layer.input(0));
addPermuteLayer(order, name + "/nhwc", inpId);
layerParams.set("pool", type == "Mean" ? "ave" : "sum");
layerParams.set("pool", pool_type);
layerParams.set("kernel_h", 1);
layerParams.set("global_pooling_w", true);
int id = dstNet.addLayer(name, "Pooling", layerParams);
......@@ -2264,7 +2279,7 @@ void TFImporter::parseMean(tensorflow::GraphDef& net, const tensorflow::NodeDef&
if (indices.total() != 2 || indices.at<int>(0) != 1 || indices.at<int>(1) != 2)
CV_Error(Error::StsNotImplemented, "Unsupported mode of reduce_mean or reduce_sum operation.");
layerParams.set("pool", type == "Mean" ? "ave" : "sum");
layerParams.set("pool", pool_type);
layerParams.set("global_pooling", true);
int id = dstNet.addLayer(name, "Pooling", layerParams);
layer_id[name] = id;
......@@ -2368,6 +2383,7 @@ void TFImporter::parseLeakyRelu(tensorflow::GraphDef& net, const tensorflow::Nod
connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, num_inputs);
}
// "Abs" "Tanh" "Sigmoid" "Relu" "Elu" "Exp" "Identity" "Relu6"
void TFImporter::parseActivation(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
{
const std::string& name = layer.name();
......
......@@ -128,6 +128,13 @@ TEST_P(Test_TensorFlow_layers, reduce_mean)
runTensorFlowNet("global_pool_by_axis");
}
TEST_P(Test_TensorFlow_layers, reduce_max)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
runTensorFlowNet("max_pool_by_axis");
}
TEST_P(Test_TensorFlow_layers, reduce_sum)
{
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
......@@ -135,11 +142,21 @@ TEST_P(Test_TensorFlow_layers, reduce_sum)
runTensorFlowNet("sum_pool_by_axis");
}
TEST_P(Test_TensorFlow_layers, reduce_max_channel)
{
runTensorFlowNet("reduce_max_channel");
}
TEST_P(Test_TensorFlow_layers, reduce_sum_channel)
{
runTensorFlowNet("reduce_sum_channel");
}
TEST_P(Test_TensorFlow_layers, reduce_max_channel_keep_dims)
{
runTensorFlowNet("reduce_max_channel", false, 0.0, 0.0, false, "_keep_dims");
}
TEST_P(Test_TensorFlow_layers, reduce_sum_channel_keep_dims)
{
runTensorFlowNet("reduce_sum_channel", false, 0.0, 0.0, false, "_keep_dims");
......@@ -386,6 +403,11 @@ TEST_P(Test_TensorFlow_layers, pooling_reduce_mean)
runTensorFlowNet("reduce_mean"); // an average pooling over all spatial dimensions.
}
TEST_P(Test_TensorFlow_layers, pooling_reduce_max)
{
runTensorFlowNet("reduce_max"); // a MAX pooling over all spatial dimensions.
}
TEST_P(Test_TensorFlow_layers, pooling_reduce_sum)
{
runTensorFlowNet("reduce_sum"); // a SUM pooling over all spatial dimensions.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册