diff --git a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp index 76937e08f3bd92c6c6d9fc406932e9517113807a..e4cf73fd07d3b3d45027ae66d0caf17be971cde9 100644 --- a/modules/dnn/src/onnx/onnx_graph_simplifier.cpp +++ b/modules/dnn/src/onnx/onnx_graph_simplifier.cpp @@ -107,17 +107,10 @@ private: opencv_onnx::GraphProto& net; }; -class SoftMaxSubgraph : public Subgraph +class SoftMaxSubgraphBase : public Subgraph { public: - SoftMaxSubgraph() : axis(1) - { - int input = addNodeToMatch(""); - int inpExp = addNodeToMatch("Exp", input); - int sum = addNodeToMatch("ReduceSum", inpExp); - addNodeToMatch("Div", inpExp, sum); - setFusedNode("Softmax", input); - } + SoftMaxSubgraphBase() : axis(1), id(-1) {} virtual bool match(const Ptr& net, int nodeId, std::vector& matchedNodesIds, @@ -125,7 +118,8 @@ public: { if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds)) { - Ptr sum = net->getNode(matchedNodesIds[1]); + CV_Assert(id >= 0 && id < matchedNodesIds.size()); + Ptr sum = net->getNode(matchedNodesIds[id]); opencv_onnx::NodeProto* node = sum.dynamicCast()->node; for (int i = 0; i < node->attribute_size(); i++) @@ -153,8 +147,60 @@ public: attr->set_i(axis); } -private: +protected: int axis; + int id; +}; + +class SoftMaxSubgraph : public SoftMaxSubgraphBase +{ +public: + SoftMaxSubgraph() + { + int input = addNodeToMatch(""); + int inpExp = addNodeToMatch("Exp", input); + + int sum = addNodeToMatch("ReduceSum", inpExp); + id = 1; + + addNodeToMatch("Div", inpExp, sum); + setFusedNode("Softmax", input); + } +}; + +class SoftMaxSubgraph2 : public SoftMaxSubgraphBase { +public: + SoftMaxSubgraph2() { + int input = addNodeToMatch(""); + + int reducemax = addNodeToMatch("ReduceMax", input); + id = 0; + + int sub = addNodeToMatch("Sub", input, reducemax); + int exp = addNodeToMatch("Exp", sub); + int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch("")); + addNodeToMatch("Div", exp, reducesum); + setFusedNode("Softmax", input); + } +}; + +class LogSoftMaxSubgraph : public SoftMaxSubgraphBase +{ +public: + LogSoftMaxSubgraph() + { + int input = addNodeToMatch(""); + + int reducemax = addNodeToMatch("ReduceMax", input); + id = 0; + + int sub_1 = addNodeToMatch("Sub", input, reducemax); + int exp = addNodeToMatch("Exp", sub_1); + int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch("")); + int log = addNodeToMatch("Log", reducesum); + addNodeToMatch("Sub", sub_1, log); + setFusedNode("LogSoftmax", input); + } }; class NormalizeSubgraphBase : public Subgraph @@ -574,6 +620,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net) subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); + subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr()); subgraphs.push_back(makePtr());