提交 598039c0 编写于 作者: D Dmitry Kurtaev

Fix embedded Torch's nn.ConcatTable

上级 dbcb4549
......@@ -101,6 +101,8 @@ struct TorchImporter
std::set<int> readedIndexes;
std::map<int, Mat> storages;
std::map<int, Mat> tensors;
// Stack with numbers of unconnected layers per scope (Sequential, ConcatTable etc.)
std::vector<int> numUnconnectedLayers;
struct Module
{
......@@ -489,15 +491,7 @@ struct TorchImporter
layerParams.set("inputDimension", scalarParams.get<int>("inputDimension"));
layerParams.set("outputDimension", scalarParams.get<int>("outputDimension"));
}
if (nnName == "Concat")
{
layerParams.set("dimension", scalarParams.get<int>("dimension"));
}
if (nnName == "JoinTable")
{
layerParams.set("dimension", scalarParams.get<int>("dimension"));
}
if (nnName == "DepthConcat")
else if (nnName == "Concat" || nnName == "JoinTable" || nnName == "DepthConcat")
{
layerParams.set("dimension", scalarParams.get<int>("dimension"));
}
......@@ -1096,6 +1090,7 @@ struct TorchImporter
{
newId = fill(module->modules[i], addedModules, prevLayerId, prevOutNum);
}
numUnconnectedLayers.push_back(module->modules.size());
return newId;
}
else if (module->thName == "JoinTable") {
......@@ -1108,9 +1103,14 @@ struct TorchImporter
mergeId = net.addLayer(generateLayerName("torchMerge"), "Concat", mergeParams);
addedModules.push_back(std::make_pair(mergeId, module));
for (int i = 0; i < ids.size(); i++)
// Connect to the last number of unconnected layers.
CV_Assert(!numUnconnectedLayers.empty());
const int numInputs = numUnconnectedLayers.back();
numUnconnectedLayers.pop_back();
CV_Assert(numInputs <= ids.size());
for (int i = 0; i < numInputs; i++)
{
net.connect(ids[i], 0, mergeId, i);
net.connect(ids[ids.size() - numInputs + i], 0, mergeId, i);
}
return mergeId;
......@@ -1124,9 +1124,14 @@ struct TorchImporter
int id = net.addLayer(name, "Eltwise", params);
for (int i = 0; i < ids.size(); i++)
// Connect to the last number of unconnected layers.
CV_Assert(!numUnconnectedLayers.empty());
const int numInputs = numUnconnectedLayers.back();
numUnconnectedLayers.pop_back();
CV_Assert(numInputs <= ids.size());
for (int i = 0; i < numInputs; i++)
{
net.connect(ids[i], 0, id, i);
net.connect(ids[ids.size() - numInputs + i], 0, id, i);
}
addedModules.push_back(std::make_pair(id, module));
......
......@@ -320,4 +320,9 @@ TEST(Torch_Importer, DISABLED_run_paralel)
runTorchNet("net_parallel", DNN_TARGET_OPENCL, "l5_torchMerge");
}
TEST(Torch_Importer, net_residual)
{
runTorchNet("net_residual", DNN_TARGET_CPU, "", false, true);
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册