Commit 9cc7dda6
Authored April 29, 2021 by Ross Wightman

Fixup byoanet configs to pass unit tests. Add swin_attn and swinnet26t model for testing.

Parent: e15c3886
Showing 3 changed files with 216 additions and 7 deletions (+216, −7):

- timm/models/byoanet.py: +33, −7
- timm/models/layers/create_self_attn.py: +5, −0
- timm/models/layers/swin_attn.py: +178, −0 (new file)
timm/models/byoanet.py

```diff
@@ -35,7 +35,7 @@ __all__ = ['ByoaNet']
 def _cfg(url='', **kwargs):
     return {
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
         'fixed_input_size': False, 'min_input_size': (3, 224, 224),
@@ -45,17 +45,19 @@ def _cfg(url='', **kwargs):
 default_cfgs = {
     # GPU-Efficient (ResNet) weights
-    'botnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256)),
+    'botnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
     'botnet50t_224': _cfg(url='', fixed_input_size=True),
     'botnet50t_c4c5_224': _cfg(url='', fixed_input_size=True),
     'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
     'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
-    'halonet26t': _cfg(url='', input_size=(3, 256, 256)),
-    'halonet50t': _cfg(url=''),
-    'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256)),
+    'halonet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
+    'halonet50t': _cfg(url='', min_input_size=(3, 224, 224)),
+    'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
     'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)),
+    'swinnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
 }
@@ -95,10 +97,10 @@ model_cfgs = dict(
     botnet26t=ByoaCfg(
         blocks=(
-            ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
+            ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
             ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
             interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
-            ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
+            ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
         ),
         stem_chs=64,
         stem_type='tiered',
@@ -230,6 +232,22 @@ model_cfgs = dict(
         self_attn_layer='lambda',
         self_attn_kwargs=dict()
     ),
+    swinnet26t=ByoaCfg(
+        blocks=(
+            ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
+            ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
+            interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
+            ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='maxpool',
+        num_features=0,
+        self_attn_layer='swin',
+        self_attn_fixed_size=True,
+        self_attn_kwargs=dict(win_size=8)
+    ),
 )
@@ -452,3 +470,11 @@ def lambda_resnet50t(pretrained=False, **kwargs):
     """ Lambda-ResNet-50T. Lambda layers in one C4 stage and all C5.
     """
     return _create_byoanet('lambda_resnet50t', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def swinnet26t_256(pretrained=False, **kwargs):
+    """
+    """
+    kwargs.setdefault('img_size', 256)
+    return _create_byoanet('swinnet26t_256', 'swinnet26t', pretrained=pretrained, **kwargs)
```
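With the registration above, the new model is constructible through the standard timm factory. A minimal sketch of exercising it (illustrative only; `swinnet26t_256` ships no pretrained weights in this commit, so `pretrained=False` is the only meaningful option):

```python
import torch
import timm

# Build the experimental Swin-attention ByoaNet variant registered above.
model = timm.create_model('swinnet26t_256', pretrained=False)
model.eval()

# The default_cfg pins fixed_input_size=True with input_size=(3, 256, 256).
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    out = model(x)
print(out.shape)  # expected: torch.Size([1, 1000])
```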
timm/models/layers/create_self_attn.py

```diff
 from .bottleneck_attn import BottleneckAttn
 from .halo_attn import HaloAttn
 from .lambda_layer import LambdaLayer
+from .swin_attn import WindowAttention


 def get_self_attn(attn_type):
@@ -10,6 +11,10 @@ def get_self_attn(attn_type):
         return HaloAttn
     elif attn_type == 'lambda':
         return LambdaLayer
+    elif attn_type == 'swin':
+        return WindowAttention
     else:
         assert False, f"Unknown attn type ({attn_type})"


 def create_self_attn(attn_type, dim, stride=1, **kwargs):
```
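The body of `create_self_attn` sits outside this hunk; assuming it simply instantiates the class returned by `get_self_attn(attn_type)` with the remaining arguments, the new `'swin'` type could be exercised like this (a sketch only; the dimension, feature size, and window values are chosen for illustration):

```python
import torch
from timm.models.layers.create_self_attn import create_self_attn

# 'swin' now resolves to WindowAttention; feat_size is required because the
# shifted-window attention mask is precomputed for a fixed spatial size.
attn = create_self_attn('swin', dim=256, stride=1, feat_size=(16, 16), win_size=8)

x = torch.randn(2, 256, 16, 16)  # NCHW feature map
y = attn(x)
print(y.shape)  # torch.Size([2, 256, 16, 16]) when stride=1
```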
timm/models/layers/swin_attn.py (new file, mode 100644)

```python
""" Shifted Window Attn

This is a WIP experiment to apply windowed attention from the Swin Transformer
to a stand-alone module for use as an attn block in conv nets.

Based on original swin window code at https://github.com/microsoft/Swin-Transformer
Swin Transformer paper: https://arxiv.org/pdf/2103.14030.pdf
"""
from typing import Optional

import torch
import torch.nn as nn

from .drop import DropPath
from .helpers import to_2tuple
from .weight_init import trunc_normal_


def window_partition(x, win_size: int):
    """
    Args:
        x: (B, H, W, C)
        win_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C)
    return windows


def window_reverse(windows, win_size: int, H: int, W: int):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        win_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / win_size / win_size))
    x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
```
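`window_partition` and `window_reverse` are exact inverses whenever `H` and `W` are multiples of `win_size`, since both are pure reshape/permute operations. A quick illustrative round-trip check (not part of the commit):

```python
import torch

# Partition a (B, H, W, C) map into 8x8 windows, then stitch it back.
x = torch.randn(2, 16, 16, 32)                  # (B, H, W, C)
windows = window_partition(x, win_size=8)       # (2 * 4, 8, 8, 32)
y = window_reverse(windows, win_size=8, H=16, W=16)
assert torch.equal(x, y)
```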
```python
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        win_size (int): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
    """

    def __init__(
            self, dim, dim_out=None, feat_size=None, stride=1, win_size=8, shift_size=None, num_heads=8,
            qkv_bias=True, attn_drop=0.):

        super().__init__()
        self.dim_out = dim_out or dim
        self.feat_size = to_2tuple(feat_size)
        self.win_size = win_size
        self.shift_size = shift_size or win_size // 2
        if min(self.feat_size) <= win_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.win_size = min(self.feat_size)
        assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-window_size"
        self.num_heads = num_heads
        head_dim = self.dim_out // num_heads
        self.scale = head_dim ** -0.5

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.feat_size
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (
                slice(0, -self.win_size),
                slice(-self.win_size, -self.shift_size),
                slice(-self.shift_size, None))
            w_slices = (
                slice(0, -self.win_size),
                slice(-self.win_size, -self.shift_size),
                slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = window_partition(img_mask, self.win_size)  # num_win, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.win_size * self.win_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        self.register_buffer("attn_mask", attn_mask)

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            # 2 * Wh - 1 * 2 * Ww - 1, nH
            torch.zeros((2 * self.win_size - 1) * (2 * self.win_size - 1), num_heads))

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.win_size)
        coords_w = torch.arange(self.win_size)
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.win_size - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.win_size - 1
        relative_coords[:, :, 0] *= 2 * self.win_size - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)
        trunc_normal_(self.relative_position_bias_table, std=.02)

        self.qkv = nn.Linear(dim, self.dim_out * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.softmax = nn.Softmax(dim=-1)
        self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        win_size_sq = self.win_size * self.win_size
        x_windows = window_partition(shifted_x, self.win_size)  # num_win * B, window_size, window_size, C
        x_windows = x_windows.view(-1, win_size_sq, C)  # num_win * B, window_size*window_size, C
        BW, N, _ = x_windows.shape

        qkv = self.qkv(x_windows)
        qkv = qkv.reshape(BW, N, 3, self.num_heads, self.dim_out // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(win_size_sq, win_size_sq, -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh * Ww, Wh * Ww
        attn = attn + relative_position_bias.unsqueeze(0)
        if self.attn_mask is not None:
            num_win = self.attn_mask.shape[0]
            attn = attn.view(B, num_win, self.num_heads, N, N) + self.attn_mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(BW, N, self.dim_out)

        # merge windows
        x = x.view(-1, self.win_size, self.win_size, self.dim_out)
        shifted_x = window_reverse(x, self.win_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H, W, self.dim_out).permute(0, 3, 1, 2)
        x = self.pool(x)
        return x
```
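The new module can also be smoke-tested standalone; a sketch (shape values chosen for illustration, matching the `win_size=8` used by the `swinnet26t` config above):

```python
import torch
from timm.models.layers.swin_attn import WindowAttention

# feat_size must be known up front: the SW-MSA attention mask and the
# relative position indices are built against a fixed spatial size in __init__.
attn = WindowAttention(dim=256, feat_size=(16, 16), win_size=8, num_heads=8)
x = torch.randn(2, 256, 16, 16)   # NCHW, as expected by forward()
print(attn(x).shape)              # torch.Size([2, 256, 16, 16])

# With stride=2, the trailing AvgPool2d(2, 2) halves the spatial dims.
attn_s2 = WindowAttention(dim=256, feat_size=(16, 16), win_size=8, num_heads=8, stride=2)
print(attn_s2(x).shape)           # torch.Size([2, 256, 8, 8])
```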