opencv commit 84336202: Bidirectional LSTM

Authored Mar 22, 2020 by Dmitry Kurtaev
Parent commit: 11d565ca
Showing 3 changed files with 116 additions and 94 deletions:

    modules/dnn/src/layers/recurrent_layers.cpp    +86  -76
    modules/dnn/src/onnx/onnx_importer.cpp         +25  -18
    modules/dnn/test/test_onnx_importer.cpp         +5   -0
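Taken together, the three changes let cv::dnn import and run ONNX LSTM nodes whose direction attribute is "bidirectional": the importer stacks the per-direction weights into the layer blobs, the layer runs its time loop once per direction, and each direction writes into its own half of the output's last dimension. A minimal usage sketch follows; the model filename and the input shape are hypothetical placeholders, not artifacts of this commit.

    #include <opencv2/dnn.hpp>
    #include <iostream>

    int main()
    {
        // Hypothetical model: any ONNX graph whose LSTM node declares
        // direction == "bidirectional" should now be importable.
        cv::dnn::Net net = cv::dnn::readNetFromONNX("lstm_bidirectional.onnx");

        // Hypothetical input layout: 5 time steps, batch 1, 10 features.
        int sz[] = {5, 1, 10};
        cv::Mat x(3, sz, CV_32F, cv::Scalar(0.1));

        net.setInput(x);
        cv::Mat y = net.forward();

        // The last output dimension carries both directions, i.e. 2 * hidden_size.
        std::cout << "last output dim: " << y.size[y.dims - 1] << std::endl;
        return 0;
    }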
modules/dnn/src/layers/recurrent_layers.cpp
@@ -93,6 +93,7 @@ class LSTMLayerImpl CV_FINAL : public LSTMLayer
     float forgetBias, cellClip;
     bool useCellClip, usePeephole;
     bool reverse;   // If true, go in negative direction along the time axis
+    bool bidirectional;  // If true, produces both forward and reversed directions along time axis

 public:
@@ -101,6 +102,7 @@ public:
     {
         setParamsFrom(params);
+        bidirectional = params.get<bool>("bidirectional", false);

         if (!blobs.empty())
         {
             CV_Assert(blobs.size() >= 3);
@@ -113,7 +115,7 @@ public:
             CV_CheckEQ(Wh.dims, 2, "");
             CV_CheckEQ(Wx.dims, 2, "");
             CV_CheckEQ(Wh.rows, Wx.rows, "");
-            CV_CheckEQ(Wh.rows, 4*Wh.cols, "");
+            CV_CheckEQ(Wh.rows, (1 + static_cast<int>(bidirectional))*4*Wh.cols, "");
             CV_CheckEQ(Wh.rows, (int)bias.total(), "");
             CV_Assert(Wh.type() == Wx.type() && Wx.type() == bias.type());
@@ -136,6 +138,7 @@ public:
         useCellClip = params.get<bool>("use_cell_clip", false);
         usePeephole = params.get<bool>("use_peephole", false);
         reverse = params.get<bool>("reverse", false);
+        CV_Assert(!reverse || !bidirectional);

         allocated = false;
         outTailShape.clear();
@@ -207,6 +210,7 @@ public:
         outResShape.push_back(_numSamples);
         outResShape.insert(outResShape.end(), outTailShape_.begin(), outTailShape_.end());
+        outResShape.back() *= (1 + static_cast<int>(bidirectional));

         size_t noutputs = produceCellOutput ? 2 : 1;
         outputs.assign(noutputs, outResShape);
@@ -253,6 +257,7 @@ public:
         outTsShape.clear();
         outTsShape.push_back(numSamples);
         outTsShape.insert(outTsShape.end(), outTailShape.begin(), outTailShape.end());
+        outTsShape.back() *= (1 + static_cast<int>(bidirectional));

         allocated = true;
     }
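The two shape hunks above are the interface-visible part of the change: when bidirectional is set, the last dimension of the inferred output shape is doubled so both directions fit side by side. A tiny self-contained sketch of that rule, with hypothetical sizes:

    #include <iostream>
    #include <vector>

    int main()
    {
        // Hypothetical sizes: 10 time steps, batch 1, 128 hidden units per direction.
        const int numTimeStamps = 10, numSamples = 1, numHidden = 128;
        const bool bidirectional = true;

        std::vector<int> outShape = {numTimeStamps, numSamples, numHidden};
        // Same rule as the patch: the tail dimension holds both directions.
        outShape.back() *= (1 + static_cast<int>(bidirectional));

        std::cout << outShape[0] << " x " << outShape[1] << " x "
                  << outShape[2] << std::endl;   // prints 10 x 1 x 256
        return 0;
    }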
@@ -273,91 +278,96 @@ public:
         outputs_arr.getMatVector(output);
         internals_arr.getMatVector(internals);

-        const Mat &Wh = blobs[0];
-        const Mat &Wx = blobs[1];
-        const Mat &bias = blobs[2];
-
-        int numOut = Wh.size[1];
-
-        Mat hInternal = internals[0], cInternal = internals[1],
-                dummyOnes = internals[2], gates = internals[3];
-        hInternal.setTo(0.);
-        cInternal.setTo(0.);
-        dummyOnes.setTo(1.);
-
-        int numSamplesTotal = numTimeStamps*numSamples;
-        Mat xTs = input[0].reshape(1, numSamplesTotal);
-
-        Mat hOutTs = output[0].reshape(1, numSamplesTotal);
-        Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat();
-
-        int tsStart, tsEnd, tsInc;
-        if (reverse) {
-            tsStart = numTimeStamps - 1;
-            tsEnd = -1;
-            tsInc = -1;
-        }
-        else {
-            tsStart = 0;
-            tsEnd = numTimeStamps;
-            tsInc = 1;
-        }
-        for (int ts = tsStart; ts != tsEnd; ts += tsInc)
+        const int numDirs = 1 + static_cast<int>(bidirectional);
+        for (int i = 0; i < numDirs; ++i)
         {
-            Range curRowRange(ts*numSamples, (ts + 1)*numSamples);
-            Mat xCurr = xTs.rowRange(curRowRange);
+            const Mat &Wh = blobs[0].rowRange(i * blobs[0].rows / numDirs, (i + 1) * blobs[0].rows / numDirs);
+            const Mat &Wx = blobs[1].rowRange(i * blobs[1].rows / numDirs, (i + 1) * blobs[1].rows / numDirs);
+            const Mat &bias = blobs[2].colRange(i * blobs[2].cols / numDirs, (i + 1) * blobs[2].cols / numDirs);
+
+            int numOut = Wh.size[1];
+
+            Mat hInternal = internals[0], cInternal = internals[1],
+                    dummyOnes = internals[2], gates = internals[3];
+            hInternal.setTo(0.);
+            cInternal.setTo(0.);
+            dummyOnes.setTo(1.);
+
+            int numSamplesTotal = numTimeStamps*numSamples;
+            Mat xTs = input[0].reshape(1, numSamplesTotal);
+
+            Mat hOutTs = output[0].reshape(1, numSamplesTotal);
+            hOutTs = hOutTs.colRange(i * hOutTs.cols / numDirs, (i + 1) * hOutTs.cols / numDirs);
+            Mat cOutTs = produceCellOutput ? output[1].reshape(1, numSamplesTotal) : Mat();
+
+            int tsStart, tsEnd, tsInc;
+            if (reverse || i == 1) {
+                tsStart = numTimeStamps - 1;
+                tsEnd = -1;
+                tsInc = -1;
+            }
+            else {
+                tsStart = 0;
+                tsEnd = numTimeStamps;
+                tsInc = 1;
+            }
+            for (int ts = tsStart; ts != tsEnd; ts += tsInc)
+            {
+                Range curRowRange(ts*numSamples, (ts + 1)*numSamples);
+                Mat xCurr = xTs.rowRange(curRowRange);

-            gemm(xCurr, Wx, 1, gates, 0, gates, GEMM_2_T);      // Wx * x_t
-            gemm(hInternal, Wh, 1, gates, 1, gates, GEMM_2_T);  //+Wh * h_{t-1}
-            gemm(dummyOnes, bias, 1, gates, 1, gates);          //+b
+                gemm(xCurr, Wx, 1, gates, 0, gates, GEMM_2_T);      // Wx * x_t
+                gemm(hInternal, Wh, 1, gates, 1, gates, GEMM_2_T);  //+Wh * h_{t-1}
+                gemm(dummyOnes, bias, 1, gates, 1, gates);          //+b

-            Mat gateI = gates.colRange(0*numOut, 1*numOut);
-            Mat gateF = gates.colRange(1*numOut, 2*numOut);
-            Mat gateO = gates.colRange(2*numOut, 3*numOut);
-            Mat gateG = gates.colRange(3*numOut, 4*numOut);
+                Mat gateI = gates.colRange(0*numOut, 1*numOut);
+                Mat gateF = gates.colRange(1*numOut, 2*numOut);
+                Mat gateO = gates.colRange(2*numOut, 3*numOut);
+                Mat gateG = gates.colRange(3*numOut, 4*numOut);

-            if (forgetBias)
-                add(gateF, forgetBias, gateF);
+                if (forgetBias)
+                    add(gateF, forgetBias, gateF);

-            if (usePeephole)
-            {
-                Mat gatesIF = gates.colRange(0, 2*numOut);
-                gemm(cInternal, blobs[3], 1, gateI, 1, gateI);
-                gemm(cInternal, blobs[4], 1, gateF, 1, gateF);
-                sigmoid(gatesIF, gatesIF);
-            }
-            else
-            {
-                Mat gatesIFO = gates.colRange(0, 3*numOut);
-                sigmoid(gatesIFO, gatesIFO);
-            }
+                if (usePeephole)
+                {
+                    Mat gatesIF = gates.colRange(0, 2*numOut);
+                    gemm(cInternal, blobs[3], 1, gateI, 1, gateI);
+                    gemm(cInternal, blobs[4], 1, gateF, 1, gateF);
+                    sigmoid(gatesIF, gatesIF);
+                }
+                else
+                {
+                    Mat gatesIFO = gates.colRange(0, 3*numOut);
+                    sigmoid(gatesIFO, gatesIFO);
+                }

-            tanh(gateG, gateG);
+                tanh(gateG, gateG);

-            //compute c_t
-            multiply(gateF, cInternal, gateF);  // f_t (*) c_{t-1}
-            multiply(gateI, gateG, gateI);      // i_t (*) g_t
-            add(gateF, gateI, cInternal);       // c_t = f_t (*) c_{t-1} + i_t (*) g_t
+                //compute c_t
+                multiply(gateF, cInternal, gateF);  // f_t (*) c_{t-1}
+                multiply(gateI, gateG, gateI);      // i_t (*) g_t
+                add(gateF, gateI, cInternal);       // c_t = f_t (*) c_{t-1} + i_t (*) g_t

-            if (useCellClip)
-            {
-                min(cInternal, cellClip, cInternal);
-                max(cInternal, -cellClip, cInternal);
-            }
-            if (usePeephole)
-            {
-                gemm(cInternal, blobs[5], 1, gateO, 1, gateO);
-                sigmoid(gateO, gateO);
-            }
+                if (useCellClip)
+                {
+                    min(cInternal, cellClip, cInternal);
+                    max(cInternal, -cellClip, cInternal);
+                }
+                if (usePeephole)
+                {
+                    gemm(cInternal, blobs[5], 1, gateO, 1, gateO);
+                    sigmoid(gateO, gateO);
+                }

-            //compute h_t
-            tanh(cInternal, hInternal);
-            multiply(gateO, hInternal, hInternal);
+                //compute h_t
+                tanh(cInternal, hInternal);
+                multiply(gateO, hInternal, hInternal);

-            //save results in output blobs
-            hInternal.copyTo(hOutTs.rowRange(curRowRange));
-            if (produceCellOutput)
-                cInternal.copyTo(cOutTs.rowRange(curRowRange));
+                //save results in output blobs
+                hInternal.copyTo(hOutTs.rowRange(curRowRange));
+                if (produceCellOutput)
+                    cInternal.copyTo(cOutTs.rowRange(curRowRange));
+            }
         }
     }
 };
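The core of the patch is the new outer loop over numDirs in forward(): the stacked weight blobs are sliced per direction with rowRange/colRange, the second direction (i == 1) walks the time axis backwards via the reversed tsStart/tsEnd/tsInc, and each direction writes its hidden states into its own column half of the output. The sketch below only illustrates the slicing convention the layer relies on, with hypothetical sizes; it is not part of the patch.

    #include <opencv2/core.hpp>
    #include <iostream>

    int main()
    {
        // Hypothetical sizes: 4 hidden units per direction, 6 input features.
        const int numHidden = 4, numFeatures = 6;
        const bool bidirectional = true;
        const int numDirs = 1 + static_cast<int>(bidirectional);

        // Layout expected for blobs[1]: the 4*numHidden gate rows of the forward
        // direction stacked on top of the 4*numHidden rows of the reverse direction.
        cv::Mat Wx(numDirs * 4 * numHidden, numFeatures, CV_32F);
        cv::randu(Wx, cv::Scalar(-1), cv::Scalar(1));

        for (int i = 0; i < numDirs; ++i)
        {
            // Per-direction view, the same slice forward() takes from blobs[1].
            cv::Mat WxDir = Wx.rowRange(i * Wx.rows / numDirs, (i + 1) * Wx.rows / numDirs);
            std::cout << "direction " << i << ": "
                      << WxDir.rows << " x " << WxDir.cols << std::endl;
        }
        return 0;
    }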
modules/dnn/src/onnx/onnx_importer.cpp
@@ -630,37 +630,44 @@ void ONNXImporter::populateNet(Net dstNet)
             Mat Wx = getBlob(node_proto, constBlobs, 1);
             Mat Wh = getBlob(node_proto, constBlobs, 2);
             Mat b = getBlob(node_proto, constBlobs, 3);
+            b = b.reshape(1, b.size[0]);

             const int numHidden = lstmParams.get<int>("hidden_size");
-
-            Wx = Wx.reshape(1, Wx.size[1]);
-            Wh = Wh.reshape(1, Wh.size[1]);
-            b = b.reshape(1, 2);
-            reduce(b, b, 0, REDUCE_SUM);
+            const int numDirs = Wx.size[0];  // Is 1 for forward only and 2 for bidirectional LSTM.
+            const int numFeatures = Wx.size[2];
+            Mat bx = b.colRange(0, b.cols / 2);
+            Mat bh = b.colRange(b.cols / 2, b.cols);
+            b = bx + bh;

             // IFGO->IGFO
-            float* WxData = (float*)Wx.data;
-            float* WhData = (float*)Wh.data;
-            float* biasData = (float*)b.data;
-            for (int j = 0; j < numHidden; ++j)
-            {
-                for (int i = 0; i < Wx.cols; ++i)
-                {
-                    std::swap(WxData[(numHidden + j) * Wx.cols + i],
-                              WxData[(numHidden * 2 + j) * Wx.cols + i]);
-                }
-                for (int i = 0; i < Wh.cols; ++i)
-                {
-                    std::swap(WhData[(numHidden + j) * Wh.cols + i],
-                              WhData[(numHidden * 2 + j) * Wh.cols + i]);
-                }
-                std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
-            }
+            for (int k = 0; k < numDirs; ++k)
+            {
+                float* WxData = Wx.ptr<float>(k);
+                float* WhData = Wh.ptr<float>(k);
+                float* biasData = b.ptr<float>(k);
+                for (int j = 0; j < numHidden; ++j)
+                {
+                    for (int i = 0; i < numFeatures; ++i)
+                    {
+                        std::swap(WxData[(numHidden + j) * numFeatures + i],
+                                  WxData[(numHidden * 2 + j) * numFeatures + i]);
+                    }
+                    for (int i = 0; i < numHidden; ++i)
+                    {
+                        std::swap(WhData[(numHidden + j) * numHidden + i],
+                                  WhData[(numHidden * 2 + j) * numHidden + i]);
+                    }
+                    std::swap(biasData[numHidden + j], biasData[numHidden * 2 + j]);
+                }
+            }
+            Wx = Wx.reshape(1, Wx.size[0] * Wx.size[1]);
+            Wh = Wh.reshape(1, Wh.size[0] * Wh.size[1]);

             lstmParams.blobs.resize(3);
             lstmParams.blobs[0] = Wh;
             lstmParams.blobs[1] = Wx;
             lstmParams.blobs[2] = b;
+            lstmParams.set("bidirectional", lstmParams.get<String>("direction", "") == "bidirectional");

             node_proto.set_output(0, lstmParams.name);  // set different name so output shapes will be registered on that name
             addLayer(dstNet, lstmParams, node_proto, layer_id, outShapes);
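On the importer side, the ONNX weights now keep their leading num_directions axis until after the gate reordering: for each direction k, the second and third blocks of numHidden rows of Wx and Wh (and the matching bias entries) are swapped in place through ptr<float>(k), per the IFGO->IGFO comment, and only then are the tensors flattened into the 2-D blobs the layer expects. A standalone sketch of that per-direction block swap, with hypothetical sizes rather than the importer's real data:

    #include <opencv2/core.hpp>
    #include <iostream>
    #include <utility>

    int main()
    {
        // Hypothetical sizes: 2 directions, 2 hidden units, 3 input features.
        const int numDirs = 2, numHidden = 2, numFeatures = 3;
        int sz[] = {numDirs, 4 * numHidden, numFeatures};
        cv::Mat Wx(3, sz, CV_32F);

        // Label every gate block with its index (0..3) so the swap is visible.
        for (int k = 0; k < numDirs; ++k)
            for (int g = 0; g < 4; ++g)
                for (int j = 0; j < numHidden; ++j)
                    for (int i = 0; i < numFeatures; ++i)
                        Wx.ptr<float>(k)[(g * numHidden + j) * numFeatures + i] = (float)g;

        // Same reordering as the patch: swap the 2nd and 3rd gate blocks per direction.
        for (int k = 0; k < numDirs; ++k)
        {
            float* WxData = Wx.ptr<float>(k);
            for (int j = 0; j < numHidden; ++j)
                for (int i = 0; i < numFeatures; ++i)
                    std::swap(WxData[(numHidden + j) * numFeatures + i],
                              WxData[(numHidden * 2 + j) * numFeatures + i]);
        }

        // Gate blocks of direction 0 now come out in the order 0, 2, 1, 3.
        for (int g = 0; g < 4; ++g)
            std::cout << Wx.ptr<float>(0)[g * numHidden * numFeatures] << " ";
        std::cout << std::endl;
        return 0;
    }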
modules/dnn/test/test_onnx_importer.cpp
@@ -456,6 +456,11 @@ TEST_P(Test_ONNX_layers, LSTM)
     testONNXModels("lstm");
 }

+TEST_P(Test_ONNX_layers, LSTM_bidirectional)
+{
+    testONNXModels("lstm_bidirectional");
+}
+
 INSTANTIATE_TEST_CASE_P(/*nothing*/, Test_ONNX_layers, dnnBackendsAndTargets());

 class Test_ONNX_nets : public Test_ONNX_layers