提交 8574a757 编写于 作者: D Dmitry Kurtaev

Case sensitive dnn layers types

上级 7b82ad29
......@@ -4626,16 +4626,15 @@ void LayerFactory::registerLayer(const String &type, Constructor constructor)
CV_TRACE_ARG_VALUE(type, "type", type.c_str());
cv::AutoLock lock(getLayerFactoryMutex());
String type_ = type.toLowerCase();
LayerFactory_Impl::iterator it = getLayerFactoryImpl().find(type_);
LayerFactory_Impl::iterator it = getLayerFactoryImpl().find(type);
if (it != getLayerFactoryImpl().end())
{
if (it->second.back() == constructor)
CV_Error(cv::Error::StsBadArg, "Layer \"" + type_ + "\" already was registered");
CV_Error(cv::Error::StsBadArg, "Layer \"" + type + "\" already was registered");
it->second.push_back(constructor);
}
getLayerFactoryImpl().insert(std::make_pair(type_, std::vector<Constructor>(1, constructor)));
getLayerFactoryImpl().insert(std::make_pair(type, std::vector<Constructor>(1, constructor)));
}
void LayerFactory::unregisterLayer(const String &type)
......@@ -4644,9 +4643,8 @@ void LayerFactory::unregisterLayer(const String &type)
CV_TRACE_ARG_VALUE(type, "type", type.c_str());
cv::AutoLock lock(getLayerFactoryMutex());
String type_ = type.toLowerCase();
LayerFactory_Impl::iterator it = getLayerFactoryImpl().find(type_);
LayerFactory_Impl::iterator it = getLayerFactoryImpl().find(type);
if (it != getLayerFactoryImpl().end())
{
if (it->second.size() > 1)
......@@ -4662,8 +4660,7 @@ Ptr<Layer> LayerFactory::createLayerInstance(const String &type, LayerParams& pa
CV_TRACE_ARG_VALUE(type, "type", type.c_str());
cv::AutoLock lock(getLayerFactoryMutex());
String type_ = type.toLowerCase();
LayerFactory_Impl::const_iterator it = getLayerFactoryImpl().find(type_);
LayerFactory_Impl::const_iterator it = getLayerFactoryImpl().find(type);
if (it != getLayerFactoryImpl().end())
{
......
......@@ -95,6 +95,7 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(LRN, LRNLayer);
CV_DNN_REGISTER_LAYER_CLASS(InnerProduct, InnerProductLayer);
CV_DNN_REGISTER_LAYER_CLASS(Softmax, SoftmaxLayer);
CV_DNN_REGISTER_LAYER_CLASS(SoftMax, SoftmaxLayer); // For compatibility. See https://github.com/opencv/opencv/issues/16877
CV_DNN_REGISTER_LAYER_CLASS(MVN, MVNLayer);
CV_DNN_REGISTER_LAYER_CLASS(ReLU, ReLULayer);
......
......@@ -615,6 +615,15 @@ void ONNXImporter::populateNet(Net dstNet)
layerParams.type = "ReLU";
replaceLayerParam(layerParams, "alpha", "negative_slope");
}
else if (layer_type == "Relu")
{
layerParams.type = "ReLU";
}
else if (layer_type == "PRelu")
{
layerParams.type = "PReLU";
layerParams.blobs.push_back(getBlob(node_proto, constBlobs, 1));
}
else if (layer_type == "LRN")
{
replaceLayerParam(layerParams, "size", "local_size");
......@@ -1133,10 +1142,10 @@ void ONNXImporter::populateNet(Net dstNet)
layerParams.set("zoom_factor_x", scales.at<float>(3));
}
}
else if (layer_type == "LogSoftmax")
else if (layer_type == "SoftMax" || layer_type == "LogSoftmax")
{
layerParams.type = "Softmax";
layerParams.set("log_softmax", true);
layerParams.set("log_softmax", layer_type == "LogSoftmax");
}
else
{
......
......@@ -865,15 +865,10 @@ struct TorchImporter
layerParams.set("indices_blob_id", tensorParams["indices"].first);
curModule->modules.push_back(newModule);
}
else if (nnName == "SoftMax")
else if (nnName == "LogSoftMax" || nnName == "SoftMax")
{
newModule->apiType = "SoftMax";
curModule->modules.push_back(newModule);
}
else if (nnName == "LogSoftMax")
{
newModule->apiType = "SoftMax";
layerParams.set("log_softmax", true);
newModule->apiType = "Softmax";
layerParams.set("log_softmax", nnName == "LogSoftMax");
curModule->modules.push_back(newModule);
}
else if (nnName == "SpatialCrossMapLRN")
......
......@@ -431,7 +431,7 @@ TEST_P(SoftMax, Accuracy)
Backend backendId = get<0>(get<1>(GetParam()));
Target targetId = get<1>(get<1>(GetParam()));
LayerParams lp;
lp.type = "SoftMax";
lp.type = "Softmax";
lp.name = "testLayer";
int sz[] = {1, inChannels, 1, 1};
......
......@@ -70,7 +70,7 @@ public:
{
LayerParams lp;
Net netSoftmax;
netSoftmax.addLayerToPrev("softmaxLayer", "SoftMax", lp);
netSoftmax.addLayerToPrev("softmaxLayer", "Softmax", lp);
netSoftmax.setPreferableBackend(DNN_BACKEND_OPENCV);
netSoftmax.setInput(out);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册