Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
stoneliu1981
pytorch-image-models
提交
b7de82e8
P
pytorch-image-models
项目概览
stoneliu1981
/
pytorch-image-models
与 Fork 源项目一致
从无法访问的项目Fork
通知
7
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
pytorch-image-models
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
b7de82e8
编写于
5月 21, 2021
作者:
R
Ross Wightman
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ConViT cleanup, fix torchscript, bit of reformatting, reuse existing layers.
上级
306c86b6
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
98 addition
and
192 deletion
+98
-192
timm/models/convit.py
timm/models/convit.py
+98
-192
未找到文件。
timm/models/convit.py
浏览文件 @
b7de82e8
"""These modules are adapted from those of timm, see
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
""" ConViT Model
@article{d2021convit,
title={ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases},
author={d'Ascoli, St{
\'
e}phane and Touvron, Hugo and Leavitt, Matthew and Morcos, Ari and Biroli, Giulio and Sagun, Levent},
journal={arXiv preprint arXiv:2103.10697},
year={2021}
}
Paper link: https://arxiv.org/abs/2103.10697
Original code: https://github.com/facebookresearch/convit, original copyright below
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
#
'''These modules are adapted from those of timm, see
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
'''
import
torch
import
torch.nn
as
nn
...
...
@@ -9,8 +27,9 @@ import torch.nn.functional as F
from
timm.data
import
IMAGENET_DEFAULT_MEAN
,
IMAGENET_DEFAULT_STD
from
.helpers
import
build_model_with_cfg
from
timm.models.layers
import
DropPath
,
to_2tuple
,
trunc_normal_
from
timm.models.registry
import
register_model
from
.layers
import
DropPath
,
to_2tuple
,
trunc_normal_
,
PatchEmbed
,
Mlp
from
.registry
import
register_model
from
.vision_transformer_hybrid
import
HybridEmbed
import
torch
import
torch.nn
as
nn
...
...
@@ -29,7 +48,7 @@ def _cfg(url='', **kwargs):
default_cfgs
=
{
# ConViT
'convit_tiny'
:
_cfg
(
url
=
"https://dl.fbaipublicfiles.com/convit/convit_tiny.pth"
),
url
=
"https://dl.fbaipublicfiles.com/convit/convit_tiny.pth"
),
'convit_small'
:
_cfg
(
url
=
"https://dl.fbaipublicfiles.com/convit/convit_small.pth"
),
'convit_base'
:
_cfg
(
...
...
@@ -37,71 +56,31 @@ default_cfgs = {
}
class
Mlp
(
nn
.
Module
):
def
__init__
(
self
,
in_features
,
hidden_features
=
None
,
out_features
=
None
,
act_layer
=
nn
.
GELU
,
drop
=
0.
):
super
().
__init__
()
out_features
=
out_features
or
in_features
hidden_features
=
hidden_features
or
in_features
self
.
fc1
=
nn
.
Linear
(
in_features
,
hidden_features
)
self
.
act
=
act_layer
()
self
.
fc2
=
nn
.
Linear
(
hidden_features
,
out_features
)
self
.
drop
=
nn
.
Dropout
(
drop
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
def
forward
(
self
,
x
):
x
=
self
.
fc1
(
x
)
x
=
self
.
act
(
x
)
x
=
self
.
drop
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
drop
(
x
)
return
x
class
GPSA
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
num_heads
=
8
,
qkv_bias
=
False
,
qk_scale
=
None
,
attn_drop
=
0.
,
proj_drop
=
0.
,
locality_strength
=
1.
,
use_local_init
=
True
):
locality_strength
=
1.
):
super
().
__init__
()
self
.
num_heads
=
num_heads
self
.
dim
=
dim
head_dim
=
dim
//
num_heads
self
.
scale
=
qk_scale
or
head_dim
**
-
0.5
self
.
locality_strength
=
locality_strength
self
.
qk
=
nn
.
Linear
(
dim
,
dim
*
2
,
bias
=
qkv_bias
)
self
.
v
=
nn
.
Linear
(
dim
,
dim
,
bias
=
qkv_bias
)
self
.
qk
=
nn
.
Linear
(
dim
,
dim
*
2
,
bias
=
qkv_bias
)
self
.
v
=
nn
.
Linear
(
dim
,
dim
,
bias
=
qkv_bias
)
self
.
attn_drop
=
nn
.
Dropout
(
attn_drop
)
self
.
proj
=
nn
.
Linear
(
dim
,
dim
)
self
.
pos_proj
=
nn
.
Linear
(
3
,
num_heads
)
self
.
proj_drop
=
nn
.
Dropout
(
proj_drop
)
self
.
locality_strength
=
locality_strength
self
.
gating_param
=
nn
.
Parameter
(
torch
.
ones
(
self
.
num_heads
))
self
.
apply
(
self
.
_init_weights
)
if
use_local_init
:
self
.
local_init
(
locality_strength
=
locality_strength
)
self
.
rel_indices
:
torch
.
Tensor
=
torch
.
zeros
(
1
,
1
,
1
,
3
)
# silly torchscript hack, won't work with None
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
def
forward
(
self
,
x
):
B
,
N
,
C
=
x
.
shape
if
not
hasattr
(
self
,
'rel_indices'
)
or
self
.
rel_indices
.
size
(
1
)
!=
N
:
self
.
get_rel_indices
(
N
)
if
self
.
rel_indices
is
None
or
self
.
rel_indices
.
shape
[
1
]
!=
N
:
self
.
rel_indices
=
self
.
get_rel_indices
(
N
)
attn
=
self
.
get_attention
(
x
)
v
=
self
.
v
(
x
).
reshape
(
B
,
N
,
self
.
num_heads
,
C
//
self
.
num_heads
).
permute
(
0
,
2
,
1
,
3
)
x
=
(
attn
@
v
).
transpose
(
1
,
2
).
reshape
(
B
,
N
,
C
)
...
...
@@ -110,61 +89,58 @@ class GPSA(nn.Module):
return
x
def
get_attention
(
self
,
x
):
B
,
N
,
C
=
x
.
shape
B
,
N
,
C
=
x
.
shape
qk
=
self
.
qk
(
x
).
reshape
(
B
,
N
,
2
,
self
.
num_heads
,
C
//
self
.
num_heads
).
permute
(
2
,
0
,
3
,
1
,
4
)
q
,
k
=
qk
[
0
],
qk
[
1
]
pos_score
=
self
.
rel_indices
.
expand
(
B
,
-
1
,
-
1
,
-
1
)
pos_score
=
self
.
pos_proj
(
pos_score
).
permute
(
0
,
3
,
1
,
2
)
pos_score
=
self
.
rel_indices
.
expand
(
B
,
-
1
,
-
1
,
-
1
)
pos_score
=
self
.
pos_proj
(
pos_score
).
permute
(
0
,
3
,
1
,
2
)
patch_score
=
(
q
@
k
.
transpose
(
-
2
,
-
1
))
*
self
.
scale
patch_score
=
patch_score
.
softmax
(
dim
=-
1
)
pos_score
=
pos_score
.
softmax
(
dim
=-
1
)
gating
=
self
.
gating_param
.
view
(
1
,
-
1
,
1
,
1
)
attn
=
(
1.
-
torch
.
sigmoid
(
gating
))
*
patch_score
+
torch
.
sigmoid
(
gating
)
*
pos_score
gating
=
self
.
gating_param
.
view
(
1
,
-
1
,
1
,
1
)
attn
=
(
1.
-
torch
.
sigmoid
(
gating
))
*
patch_score
+
torch
.
sigmoid
(
gating
)
*
pos_score
attn
/=
attn
.
sum
(
dim
=-
1
).
unsqueeze
(
-
1
)
attn
=
self
.
attn_drop
(
attn
)
return
attn
def
get_attention_map
(
self
,
x
,
return_map
=
False
):
attn_map
=
self
.
get_attention
(
x
).
mean
(
0
)
# average over batch
distances
=
self
.
rel_indices
.
squeeze
()[:,:,
-
1
]
**
.
5
dist
=
torch
.
einsum
(
'nm,hnm->h'
,
(
distances
,
attn_map
))
dist
/=
distances
.
size
(
0
)
def
get_attention_map
(
self
,
x
,
return_map
=
False
):
attn_map
=
self
.
get_attention
(
x
).
mean
(
0
)
# average over batch
distances
=
self
.
rel_indices
.
squeeze
()[:,
:,
-
1
]
**
.
5
dist
=
torch
.
einsum
(
'nm,hnm->h'
,
(
distances
,
attn_map
))
/
distances
.
size
(
0
)
if
return_map
:
return
dist
,
attn_map
else
:
return
dist
def
local_init
(
self
,
locality_strength
=
1.
):
def
local_init
(
self
):
self
.
v
.
weight
.
data
.
copy_
(
torch
.
eye
(
self
.
dim
))
locality_distance
=
1
#
max(1,1/locality_strength**.5)
kernel_size
=
int
(
self
.
num_heads
**
.
5
)
center
=
(
kernel_size
-
1
)
/
2
if
kernel_size
%
2
==
0
else
kernel_size
//
2
locality_distance
=
1
#
max(1,1/locality_strength**.5)
kernel_size
=
int
(
self
.
num_heads
**
.
5
)
center
=
(
kernel_size
-
1
)
/
2
if
kernel_size
%
2
==
0
else
kernel_size
//
2
for
h1
in
range
(
kernel_size
):
for
h2
in
range
(
kernel_size
):
position
=
h1
+
kernel_size
*
h2
self
.
pos_proj
.
weight
.
data
[
position
,
2
]
=
-
1
self
.
pos_proj
.
weight
.
data
[
position
,
1
]
=
2
*
(
h1
-
center
)
*
locality_distance
self
.
pos_proj
.
weight
.
data
[
position
,
0
]
=
2
*
(
h2
-
center
)
*
locality_distance
self
.
pos_proj
.
weight
.
data
*=
locality_strength
def
get_rel_indices
(
self
,
num_patches
)
:
img_size
=
int
(
num_patches
**
.
5
)
rel_indices
=
torch
.
zeros
(
1
,
num_patches
,
num_patches
,
3
)
ind
=
torch
.
arange
(
img_size
).
view
(
1
,
-
1
)
-
torch
.
arange
(
img_size
).
view
(
-
1
,
1
)
indx
=
ind
.
repeat
(
img_size
,
img_size
)
indy
=
ind
.
repeat_interleave
(
img_size
,
dim
=
0
).
repeat_interleave
(
img_size
,
dim
=
1
)
indd
=
indx
**
2
+
indy
**
2
rel_indices
[:,
:,:,
2
]
=
indd
.
unsqueeze
(
0
)
rel_indices
[:,
:,:,
1
]
=
indy
.
unsqueeze
(
0
)
rel_indices
[:,
:,:,
0
]
=
indx
.
unsqueeze
(
0
)
position
=
h1
+
kernel_size
*
h2
self
.
pos_proj
.
weight
.
data
[
position
,
2
]
=
-
1
self
.
pos_proj
.
weight
.
data
[
position
,
1
]
=
2
*
(
h1
-
center
)
*
locality_distance
self
.
pos_proj
.
weight
.
data
[
position
,
0
]
=
2
*
(
h2
-
center
)
*
locality_distance
self
.
pos_proj
.
weight
.
data
*=
self
.
locality_strength
def
get_rel_indices
(
self
,
num_patches
:
int
)
->
torch
.
Tensor
:
img_size
=
int
(
num_patches
**
.
5
)
rel_indices
=
torch
.
zeros
(
1
,
num_patches
,
num_patches
,
3
)
ind
=
torch
.
arange
(
img_size
).
view
(
1
,
-
1
)
-
torch
.
arange
(
img_size
).
view
(
-
1
,
1
)
indx
=
ind
.
repeat
(
img_size
,
img_size
)
indy
=
ind
.
repeat_interleave
(
img_size
,
dim
=
0
).
repeat_interleave
(
img_size
,
dim
=
1
)
indd
=
indx
**
2
+
indy
**
2
rel_indices
[:,
:,
:,
2
]
=
indd
.
unsqueeze
(
0
)
rel_indices
[:,
:,
:,
1
]
=
indy
.
unsqueeze
(
0
)
rel_indices
[:,
:,
:,
0
]
=
indx
.
unsqueeze
(
0
)
device
=
self
.
qk
.
weight
.
device
self
.
rel_indices
=
rel_indices
.
to
(
device
)
return
rel_indices
.
to
(
device
)
class
MHSA
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
num_heads
=
8
,
qkv_bias
=
False
,
qk_scale
=
None
,
attn_drop
=
0.
,
proj_drop
=
0.
):
super
().
__init__
()
...
...
@@ -176,41 +152,28 @@ class MHSA(nn.Module):
self
.
attn_drop
=
nn
.
Dropout
(
attn_drop
)
self
.
proj
=
nn
.
Linear
(
dim
,
dim
)
self
.
proj_drop
=
nn
.
Dropout
(
proj_drop
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
def
get_attention_map
(
self
,
x
,
return_map
=
False
):
def
get_attention_map
(
self
,
x
,
return_map
=
False
):
B
,
N
,
C
=
x
.
shape
qkv
=
self
.
qkv
(
x
).
reshape
(
B
,
N
,
3
,
self
.
num_heads
,
C
//
self
.
num_heads
).
permute
(
2
,
0
,
3
,
1
,
4
)
q
,
k
,
v
=
qkv
[
0
],
qkv
[
1
],
qkv
[
2
]
attn_map
=
(
q
@
k
.
transpose
(
-
2
,
-
1
))
*
self
.
scale
attn_map
=
attn_map
.
softmax
(
dim
=-
1
).
mean
(
0
)
img_size
=
int
(
N
**
.
5
)
ind
=
torch
.
arange
(
img_size
).
view
(
1
,
-
1
)
-
torch
.
arange
(
img_size
).
view
(
-
1
,
1
)
indx
=
ind
.
repeat
(
img_size
,
img_size
)
indy
=
ind
.
repeat_interleave
(
img_size
,
dim
=
0
).
repeat_interleave
(
img_size
,
dim
=
1
)
indd
=
indx
**
2
+
indy
**
2
distances
=
indd
**
.
5
img_size
=
int
(
N
**
.
5
)
ind
=
torch
.
arange
(
img_size
).
view
(
1
,
-
1
)
-
torch
.
arange
(
img_size
).
view
(
-
1
,
1
)
indx
=
ind
.
repeat
(
img_size
,
img_size
)
indy
=
ind
.
repeat_interleave
(
img_size
,
dim
=
0
).
repeat_interleave
(
img_size
,
dim
=
1
)
indd
=
indx
**
2
+
indy
**
2
distances
=
indd
**
.
5
distances
=
distances
.
to
(
'cuda'
)
dist
=
torch
.
einsum
(
'nm,hnm->h'
,
(
distances
,
attn_map
))
dist
/=
N
dist
=
torch
.
einsum
(
'nm,hnm->h'
,
(
distances
,
attn_map
))
/
N
if
return_map
:
return
dist
,
attn_map
else
:
return
dist
def
forward
(
self
,
x
):
B
,
N
,
C
=
x
.
shape
qkv
=
self
.
qkv
(
x
).
reshape
(
B
,
N
,
3
,
self
.
num_heads
,
C
//
self
.
num_heads
).
permute
(
2
,
0
,
3
,
1
,
4
)
...
...
@@ -228,15 +191,19 @@ class MHSA(nn.Module):
class
Block
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
num_heads
,
mlp_ratio
=
4.
,
qkv_bias
=
False
,
qk_scale
=
None
,
drop
=
0.
,
attn_drop
=
0.
,
def
__init__
(
self
,
dim
,
num_heads
,
mlp_ratio
=
4.
,
qkv_bias
=
False
,
qk_scale
=
None
,
drop
=
0.
,
attn_drop
=
0.
,
drop_path
=
0.
,
act_layer
=
nn
.
GELU
,
norm_layer
=
nn
.
LayerNorm
,
use_gpsa
=
True
,
**
kwargs
):
super
().
__init__
()
self
.
norm1
=
norm_layer
(
dim
)
self
.
use_gpsa
=
use_gpsa
if
self
.
use_gpsa
:
self
.
attn
=
GPSA
(
dim
,
num_heads
=
num_heads
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
,
**
kwargs
)
self
.
attn
=
GPSA
(
dim
,
num_heads
=
num_heads
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
,
**
kwargs
)
else
:
self
.
attn
=
MHSA
(
dim
,
num_heads
=
num_heads
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
,
**
kwargs
)
self
.
attn
=
MHSA
(
dim
,
num_heads
=
num_heads
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
,
**
kwargs
)
self
.
drop_path
=
DropPath
(
drop_path
)
if
drop_path
>
0.
else
nn
.
Identity
()
self
.
norm2
=
norm_layer
(
dim
)
mlp_hidden_dim
=
int
(
dim
*
mlp_ratio
)
...
...
@@ -246,75 +213,12 @@ class Block(nn.Module):
x
=
x
+
self
.
drop_path
(
self
.
attn
(
self
.
norm1
(
x
)))
x
=
x
+
self
.
drop_path
(
self
.
mlp
(
self
.
norm2
(
x
)))
return
x
class
PatchEmbed
(
nn
.
Module
):
""" Image to Patch Embedding, from timm
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
16
,
in_chans
=
3
,
embed_dim
=
768
):
super
().
__init__
()
img_size
=
to_2tuple
(
img_size
)
patch_size
=
to_2tuple
(
patch_size
)
num_patches
=
(
img_size
[
1
]
//
patch_size
[
1
])
*
(
img_size
[
0
]
//
patch_size
[
0
])
self
.
img_size
=
img_size
self
.
patch_size
=
patch_size
self
.
num_patches
=
num_patches
self
.
proj
=
nn
.
Conv2d
(
in_chans
,
embed_dim
,
kernel_size
=
patch_size
,
stride
=
patch_size
)
self
.
apply
(
self
.
_init_weights
)
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
assert
H
==
self
.
img_size
[
0
]
and
W
==
self
.
img_size
[
1
],
\
f
"Input image size (
{
H
}
*
{
W
}
) doesn't match model (
{
self
.
img_size
[
0
]
}
*
{
self
.
img_size
[
1
]
}
)."
x
=
self
.
proj
(
x
).
flatten
(
2
).
transpose
(
1
,
2
)
return
x
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
class
HybridEmbed
(
nn
.
Module
):
""" CNN Feature Map Embedding, from timm
"""
def
__init__
(
self
,
backbone
,
img_size
=
224
,
feature_size
=
None
,
in_chans
=
3
,
embed_dim
=
768
):
super
().
__init__
()
assert
isinstance
(
backbone
,
nn
.
Module
)
img_size
=
to_2tuple
(
img_size
)
self
.
img_size
=
img_size
self
.
backbone
=
backbone
if
feature_size
is
None
:
with
torch
.
no_grad
():
training
=
backbone
.
training
if
training
:
backbone
.
eval
()
o
=
self
.
backbone
(
torch
.
zeros
(
1
,
in_chans
,
img_size
[
0
],
img_size
[
1
]))[
-
1
]
feature_size
=
o
.
shape
[
-
2
:]
feature_dim
=
o
.
shape
[
1
]
backbone
.
train
(
training
)
else
:
feature_size
=
to_2tuple
(
feature_size
)
feature_dim
=
self
.
backbone
.
feature_info
.
channels
()[
-
1
]
self
.
num_patches
=
feature_size
[
0
]
*
feature_size
[
1
]
self
.
proj
=
nn
.
Linear
(
feature_dim
,
embed_dim
)
self
.
apply
(
self
.
_init_weights
)
def
forward
(
self
,
x
):
x
=
self
.
backbone
(
x
)[
-
1
]
x
=
x
.
flatten
(
2
).
transpose
(
1
,
2
)
x
=
self
.
proj
(
x
)
return
x
class
ConViT
(
nn
.
Module
):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
16
,
in_chans
=
3
,
num_classes
=
1000
,
embed_dim
=
768
,
depth
=
12
,
num_heads
=
12
,
mlp_ratio
=
4.
,
qkv_bias
=
False
,
qk_scale
=
None
,
drop_rate
=
0.
,
attn_drop_rate
=
0.
,
drop_path_rate
=
0.
,
hybrid_backbone
=
None
,
norm_layer
=
nn
.
LayerNorm
,
global_pool
=
None
,
...
...
@@ -335,7 +239,7 @@ class ConViT(nn.Module):
img_size
=
img_size
,
patch_size
=
patch_size
,
in_chans
=
in_chans
,
embed_dim
=
embed_dim
)
num_patches
=
self
.
patch_embed
.
num_patches
self
.
num_patches
=
num_patches
self
.
cls_token
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
1
,
embed_dim
))
self
.
pos_drop
=
nn
.
Dropout
(
p
=
drop_rate
)
...
...
@@ -350,7 +254,7 @@ class ConViT(nn.Module):
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
i
],
norm_layer
=
norm_layer
,
use_gpsa
=
True
,
locality_strength
=
locality_strength
)
if
i
<
local_up_to_layer
else
if
i
<
local_up_to_layer
else
Block
(
dim
=
embed_dim
,
num_heads
=
num_heads
,
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
i
],
norm_layer
=
norm_layer
,
...
...
@@ -363,7 +267,10 @@ class ConViT(nn.Module):
self
.
head
=
nn
.
Linear
(
embed_dim
,
num_classes
)
if
num_classes
>
0
else
nn
.
Identity
()
trunc_normal_
(
self
.
cls_token
,
std
=
.
02
)
self
.
head
.
apply
(
self
.
_init_weights
)
self
.
apply
(
self
.
_init_weights
)
for
n
,
m
in
self
.
named_modules
():
if
hasattr
(
m
,
'local_init'
):
m
.
local_init
()
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
...
...
@@ -395,8 +302,8 @@ class ConViT(nn.Module):
x
=
x
+
self
.
pos_embed
x
=
self
.
pos_drop
(
x
)
for
u
,
blk
in
enumerate
(
self
.
blocks
):
if
u
==
self
.
local_up_to_layer
:
for
u
,
blk
in
enumerate
(
self
.
blocks
):
if
u
==
self
.
local_up_to_layer
:
x
=
torch
.
cat
((
cls_tokens
,
x
),
dim
=
1
)
x
=
blk
(
x
)
...
...
@@ -415,30 +322,29 @@ def _create_convit(variant, pretrained=False, **kwargs):
default_cfg
=
default_cfgs
[
variant
],
**
kwargs
)
@
register_model
def
convit_tiny
(
pretrained
=
False
,
**
kwargs
):
model_args
=
dict
(
local_up_to_layer
=
10
,
locality_strength
=
1.0
,
embed_dim
=
48
,
local_up_to_layer
=
10
,
locality_strength
=
1.0
,
embed_dim
=
48
,
num_heads
=
4
,
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
),
**
kwargs
)
model
=
_create_convit
(
variant
=
'convit_tiny'
,
pretrained
=
pretrained
,
**
model_args
)
model
=
_create_convit
(
variant
=
'convit_tiny'
,
pretrained
=
pretrained
,
**
model_args
)
return
model
@
register_model
def
convit_small
(
pretrained
=
False
,
**
kwargs
):
model_args
=
dict
(
local_up_to_layer
=
10
,
locality_strength
=
1.0
,
embed_dim
=
48
,
local_up_to_layer
=
10
,
locality_strength
=
1.0
,
embed_dim
=
48
,
num_heads
=
9
,
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
),
**
kwargs
)
model
=
_create_convit
(
variant
=
'convit_small'
,
pretrained
=
pretrained
,
**
model_args
)
model
=
_create_convit
(
variant
=
'convit_small'
,
pretrained
=
pretrained
,
**
model_args
)
return
model
@
register_model
def
convit_base
(
pretrained
=
False
,
**
kwargs
):
model_args
=
dict
(
local_up_to_layer
=
10
,
locality_strength
=
1.0
,
embed_dim
=
48
,
local_up_to_layer
=
10
,
locality_strength
=
1.0
,
embed_dim
=
48
,
num_heads
=
16
,
norm_layer
=
partial
(
nn
.
LayerNorm
,
eps
=
1e-6
),
**
kwargs
)
model
=
_create_convit
(
variant
=
'convit_base'
,
pretrained
=
pretrained
,
**
model_args
)
model
=
_create_convit
(
variant
=
'convit_base'
,
pretrained
=
pretrained
,
**
model_args
)
return
model
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录