Greenplum / Annotated Deep Learning Paper Implementations

Commit f038ab67
Authored Jul 17, 2021 by Varuna Jayasiri

vit

Parent: f0bf8d39

Showing 12 changed files with 979 additions and 186 deletions (+979 -186)
docs/index.html                            +1    -0
docs/transformers/index.html               +7    -4
docs/transformers/vit/experiment.html      +172  -70
docs/transformers/vit/index.html           +438  -90
docs/transformers/vit/readme.html          +162  -0
labml_nn/__init__.py                       +1    -0
labml_nn/transformers/__init__.py          +5    -0
labml_nn/transformers/vit/__init__.py      +135  -9
labml_nn/transformers/vit/experiment.py    +24   -12
labml_nn/transformers/vit/readme.md        +32   -0
readme.md                                  +1    -0
setup.py                                   +1    -1
docs/index.html
@@ -95,6 +95,7 @@ implementations.</p>
 <li><a href="transformers/mlm/index.html">Masked Language Model</a></li>
 <li><a href="transformers/mlp_mixer/index.html">MLP-Mixer: An all-MLP Architecture for Vision</a></li>
 <li><a href="transformers/gmlp/index.html">Pay Attention to MLPs (gMLP)</a></li>
+<li><a href="transformers/vit/index.html">Vision Transformer (ViT)</a></li>
 </ul>
 <h4>✨ <a href="recurrent_highway_networks/index.html">Recurrent Highway Networks</a></h4>
 <h4>✨ <a href="lstm/index.html">LSTM</a></h4>
docs/transformers/index.html
@@ -117,12 +117,15 @@ It does single GPU training but we implement the concept of switching as describ
 <h2><a href="gmlp/index.html">Pay Attention to MLPs (gMLP)</a></h2>
 <p>This is an implementation of the paper
 <a href="https://papers.labml.ai/paper/2105.08050">Pay Attention to MLPs</a>.</p>
+<h2><a href="vit/index.html">Vision Transformer (ViT)</a></h2>
+<p>This is an implementation of the paper
+<a href="https://arxiv.org/abs/2010.11929">An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale</a>.</p>

(The highlighted import listing embedded in this page is unchanged; its line numbers shift from 87-90 to 92-95:)

 from .configs import TransformerConfigs
 from .models import TransformerLayer, Encoder, Decoder, Generator, EncoderDecoder
 from .mha import MultiHeadAttention
 from labml_nn.transformers.xl.relative_mha import RelativeMultiHeadAttention
docs/transformers/vit/experiment.html
(This diff is collapsed.)

docs/transformers/vit/index.html
(This diff is collapsed.)
docs/transformers/vit/readme.html
new file (0 → 100644)

<!DOCTYPE html>
<html>
<head>
    <meta http-equiv="content-type" content="text/html;charset=utf-8"/>
    <meta name="viewport" content="width=device-width, initial-scale=1.0"/>
    <meta name="description" content=""/>
    <meta name="twitter:card" content="summary"/>
    <meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
    <meta name="twitter:title" content="Vision Transformer (ViT)"/>
    <meta name="twitter:description" content=""/>
    <meta name="twitter:site" content="@labmlai"/>
    <meta name="twitter:creator" content="@labmlai"/>
    <meta property="og:url" content="https://nn.labml.ai/transformers/vit/readme.html"/>
    <meta property="og:title" content="Vision Transformer (ViT)"/>
    <meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&v=4"/>
    <meta property="og:site_name" content="LabML Neural Networks"/>
    <meta property="og:type" content="object"/>
    <meta property="og:title" content="Vision Transformer (ViT)"/>
    <meta property="og:description" content=""/>
    <title>Vision Transformer (ViT)</title>
    <link rel="shortcut icon" href="/icon.png"/>
    <link rel="stylesheet" href="../../pylit.css">
    <link rel="canonical" href="https://nn.labml.ai/transformers/vit/readme.html"/>
    <!-- Global site tag (gtag.js) - Google Analytics -->
    <script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
    <script>
        window.dataLayer = window.dataLayer || [];
        function gtag() { dataLayer.push(arguments); }
        gtag('js', new Date());
        gtag('config', 'G-4V3HC8HBLH');
    </script>
</head>
<body>
<div id='container'>
    <div id="background"></div>
    <div class='section'>
        <div class='docs'>
            <p>
                <a class="parent" href="/">home</a>
                <a class="parent" href="../index.html">transformers</a>
                <a class="parent" href="index.html">vit</a>
            </p>
            <p>
                <a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/transformers/vit/readme.md">
                    <img alt="Github" src="https://img.shields.io/github/stars/lab-ml/nn?style=social" style="max-width:100%;"/></a>
                <a href="https://twitter.com/labmlai" rel="nofollow">
                    <img alt="Twitter" src="https://img.shields.io/twitter/follow/labmlai?style=social" style="max-width:100%;"/></a>
            </p>
        </div>
    </div>
    <div class='section' id='section-0'>
        <div class='docs'>
            <div class='section-link'><a href='#section-0'>#</a></div>
            <h1><a href="https://nn.labml.ai/transformer/vit/index.html">Vision Transformer (ViT)</a></h1>
            <p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the paper
                <a href="https://arxiv.org/abs/2010.11929">An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale</a>.</p>
            <p>The vision transformer applies a pure transformer to images, without any convolution layers.
                The image is split into patches, and a transformer is applied to the patch embeddings.
                <a href="https://nn.labml.ai/transformer/vit/index.html#PathEmbeddings">Patch embeddings</a>
                are generated by applying a simple linear transformation to the flattened pixel values of each patch.
                A standard transformer encoder is then fed the patch embeddings, along with a
                classification token <code>[CLS]</code>.
                The encoding of the <code>[CLS]</code> token is used to classify the image with an MLP.</p>
            <p>When feeding the transformer with the patches, learned positional embeddings are
                added to the patch embeddings, because the patch embeddings alone carry no information
                about where each patch came from.
                The positional embeddings are a set of vectors, one per patch location, trained
                with gradient descent along with the other parameters.</p>
            <p>ViTs perform well when they are pre-trained on large datasets.
                The paper suggests pre-training them with an MLP classification head and
                then using a single linear layer when fine-tuning.
                The paper beats SOTA with a ViT pre-trained on a 300-million-image dataset.
                They also use higher-resolution images during inference while keeping the
                patch size the same.
                The positional embeddings for the new patch locations are calculated by interpolating
                the learned positional embeddings.</p>
            <p>Here&rsquo;s <a href="https://nn.labml.ai/transformer/vit/experiment.html">an experiment</a>
                that trains ViT on CIFAR-10.
                It doesn&rsquo;t do very well because it&rsquo;s trained on a small dataset.
                It&rsquo;s a simple experiment that anyone can run to play with ViTs.</p>
        </div>
        <div class='code'></div>
    </div>
</div>
<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.4/MathJax.js?config=TeX-AMS_HTML"></script>
<!-- MathJax configuration -->
<script type="text/x-mathjax-config">
    MathJax.Hub.Config({
        tex2jax: {
            inlineMath: [['$', '$']],
            displayMath: [['$$', '$$']],
            processEscapes: true,
            processEnvironments: true
        },
        // Center justify equations in code and markdown cells. Elsewhere
        // we use CSS to left justify single line equations in code cells.
        displayAlign: 'center',
        "HTML-CSS": { fonts: ["TeX"] }
    });
</script>
<script>
    function handleImages() {
        var images = document.querySelectorAll('p>img')
        console.log(images);
        for (var i = 0; i < images.length; ++i) {
            handleImage(images[i])
        }
    }

    function handleImage(img) {
        img.parentElement.style.textAlign = 'center'

        var modal = document.createElement('div')
        modal.id = 'modal'

        var modalContent = document.createElement('div')
        modal.appendChild(modalContent)

        var modalImage = document.createElement('img')
        modalContent.appendChild(modalImage)

        var span = document.createElement('span')
        span.classList.add('close')
        span.textContent = 'x'
        modal.appendChild(span)

        img.onclick = function () {
            console.log('clicked')
            document.body.appendChild(modal)
            modalImage.src = img.src
        }

        span.onclick = function () {
            document.body.removeChild(modal)
        }
    }

    handleImages()
</script>
</body>
</html>
\ No newline at end of file
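
Aside: the readme above states that patch embeddings are a simple linear transformation of the flattened patch pixels, implemented in this commit as a convolution whose kernel size and stride both equal the patch size. The following minimal PyTorch sketch (not part of the commit; all sizes are illustrative) checks that equivalence:

import torch
from torch import nn

d_model, patch_size, in_channels = 8, 4, 3
conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)

x = torch.randn(2, in_channels, 32, 32)  # a batch of two 32x32 RGB images
out = conv(x)                            # [2, d_model, 8, 8]: an 8x8 grid of patch embeddings

# Flattening the top-left patch and multiplying by the flattened conv weight
# reproduces the conv output at grid position (0, 0).
patch = x[:, :, :patch_size, :patch_size].reshape(2, -1)  # [2, 48]
w = conv.weight.reshape(d_model, -1)                      # [8, 48]
manual = patch @ w.t() + conv.bias                        # [2, d_model]
assert torch.allclose(manual, out[:, :, 0, 0], atol=1e-5)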
labml_nn/__init__.py
@@ -31,6 +31,7 @@ implementations.
 * [Masked Language Model](transformers/mlm/index.html)
 * [MLP-Mixer: An all-MLP Architecture for Vision](transformers/mlp_mixer/index.html)
 * [Pay Attention to MLPs (gMLP)](transformers/gmlp/index.html)
+* [Vision Transformer (ViT)](transformers/vit/index.html)

 #### ✨ [Recurrent Highway Networks](recurrent_highway_networks/index.html)
labml_nn/transformers/__init__.py
@@ -82,6 +82,11 @@ This is an implementation of the paper
 This is an implementation of the paper
 [Pay Attention to MLPs](https://papers.labml.ai/paper/2105.08050).

+## [Vision Transformer (ViT)](vit/index.html)
+
+This is an implementation of the paper
+[An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale](https://arxiv.org/abs/2010.11929).
 """

 from .configs import TransformerConfigs
labml_nn/transformers/vit/__init__.py

+"""
+---
+title: Vision Transformer (ViT)
+summary: >
+ A PyTorch implementation/tutorial of the paper
+ "An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale"
+---
+
+# Vision Transformer (ViT)
+
+This is a [PyTorch](https://pytorch.org) implementation of the paper
+[An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale](https://arxiv.org/abs/2010.11929).
+
+The vision transformer applies a pure transformer to images, without any convolution layers.
+The image is split into patches, and a transformer is applied to the patch embeddings.
+[Patch embeddings](#PathEmbeddings) are generated by applying a simple linear transformation
+to the flattened pixel values of each patch.
+A standard transformer encoder is then fed the patch embeddings, along with a
+classification token `[CLS]`.
+The encoding of the `[CLS]` token is used to classify the image with an MLP.
+
+When feeding the transformer with the patches, learned positional embeddings are
+added to the patch embeddings, because the patch embeddings alone carry no information
+about where each patch came from.
+The positional embeddings are a set of vectors, one per patch location, trained
+with gradient descent along with the other parameters.
+
+ViTs perform well when they are pre-trained on large datasets.
+The paper suggests pre-training them with an MLP classification head and
+then using a single linear layer when fine-tuning.
+The paper beats SOTA with a ViT pre-trained on a 300-million-image dataset.
+They also use higher-resolution images during inference while keeping the
+patch size the same.
+The positional embeddings for the new patch locations are calculated by interpolating
+the learned positional embeddings.
+
+Here's [an experiment](experiment.html) that trains ViT on CIFAR-10.
+It doesn't do very well because it's trained on a small dataset.
+It's a simple experiment that anyone can run to play with ViTs.
+
+[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f)
+"""
 import torch
 from torch import nn

@@ -9,24 +53,41 @@ from labml_nn.utils import clone_module_list

 class PatchEmbeddings(Module):
     """
     <a id="PatchEmbeddings">
-    ## Embed patches
+    ## Get patch embeddings
     </a>
+
+    The paper splits the image into patches of equal size and does a linear transformation
+    on the flattened pixels of each patch.
+
+    We implement the same thing through a convolution layer, because it's simpler to implement.
     """

     def __init__(self, d_model: int, patch_size: int, in_channels: int):
+        """
+        * `d_model` is the transformer embeddings size
+        * `patch_size` is the size of the patch
+        * `in_channels` is the number of channels in the input image (3 for rgb)
+        """
         super().__init__()
         self.patch_size = patch_size
+        # We create a convolution layer with a kernel size and stride length equal to the patch size.
+        # This is equivalent to splitting the image into patches and doing a linear
+        # transformation on each patch.
         self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)

     def __call__(self, x: torch.Tensor):
         """
-        x has shape `[batch_size, channels, height, width]`
+        * `x` is the input image of shape `[batch_size, channels, height, width]`
         """
+        # Apply convolution layer
         x = self.conv(x)
+        # Get the shape.
         bs, c, h, w = x.shape
+        # Rearrange to shape `[patches, batch_size, d_model]`
         x = x.permute(2, 3, 0, 1)
         x = x.view(h * w, bs, c)

+        # Return the patch embeddings
         return x

@@ -35,56 +96,121 @@ class LearnedPositionalEmbeddings(Module):

     <a id="LearnedPositionalEmbeddings">
     ## Add parameterized positional encodings
     </a>
+
+    This adds learned positional embeddings to patch embeddings.
     """

     def __init__(self, d_model: int, max_len: int = 5_000):
+        """
+        * `d_model` is the transformer embeddings size
+        * `max_len` is the maximum number of patches
+        """
         super().__init__()
+        # Positional embeddings for each location
         self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)

     def __call__(self, x: torch.Tensor):
+        """
+        * `x` is the patch embeddings of shape `[patches, batch_size, d_model]`
+        """
+        # Get the positional embeddings for the given patches
         pe = self.positional_encodings[:x.shape[0]]
+        # Add to patch embeddings and return
         return x + pe

 class ClassificationHead(Module):
+    """
+    <a id="ClassificationHead">
+    ## MLP Classification Head
+    </a>
+
+    This is the two layer MLP head to classify the image based on `[CLS]` token embedding.
+    """

     def __init__(self, d_model: int, n_hidden: int, n_classes: int):
+        """
+        * `d_model` is the transformer embedding size
+        * `n_hidden` is the size of the hidden layer
+        * `n_classes` is the number of classes in the classification task
+        """
         super().__init__()
-        self.ln = nn.LayerNorm([d_model])
+        # First layer
         self.linear1 = nn.Linear(d_model, n_hidden)
+        # Activation
         self.act = nn.ReLU()
+        # Second layer
         self.linear2 = nn.Linear(n_hidden, n_classes)

     def __call__(self, x: torch.Tensor):
-        x = self.ln(x)
+        """
+        * `x` is the transformer encoding for `[CLS]` token
+        """
+        # First layer and activation
         x = self.act(self.linear1(x))
+        # Second layer
         x = self.linear2(x)

+        #
         return x

 class VisionTransformer(Module):
+    """
+    ## Vision Transformer
+
+    This combines the [patch embeddings](#PatchEmbeddings),
+    [positional embeddings](#LearnedPositionalEmbeddings),
+    transformer and the [classification head](#ClassificationHead).
+    """

     def __init__(self, transformer_layer: TransformerLayer, n_layers: int,
                  patch_emb: PatchEmbeddings, pos_emb: LearnedPositionalEmbeddings,
                  classification: ClassificationHead):
+        """
+        * `transformer_layer` is a copy of a single [transformer layer](../models.html#TransformerLayer).
+          We make copies of it to make the transformer with `n_layers`.
+        * `n_layers` is the number of [transformer layers](../models.html#TransformerLayer).
+        * `patch_emb` is the [patch embeddings layer](#PatchEmbeddings).
+        * `pos_emb` is the [positional embeddings layer](#LearnedPositionalEmbeddings).
+        * `classification` is the [classification head](#ClassificationHead).
+        """
         super().__init__()
-        self.classification = classification
-        self.pos_emb = pos_emb
+        # Patch embeddings
         self.patch_emb = patch_emb
+        self.pos_emb = pos_emb
+        # Classification head
+        self.classification = classification
         # Make copies of the transformer layer
         self.transformer_layers = clone_module_list(transformer_layer, n_layers)

+        # `[CLS]` token embedding
         self.cls_token_emb = nn.Parameter(torch.randn(1, 1, transformer_layer.size), requires_grad=True)
+        # Final normalization layer
+        self.ln = nn.LayerNorm([transformer_layer.size])

-    def __call__(self, x):
+    def __call__(self, x: torch.Tensor):
+        """
+        * `x` is the input image of shape `[batch_size, channels, height, width]`
+        """
+        # Get patch embeddings. This gives a tensor of shape `[patches, batch_size, d_model]`
         x = self.patch_emb(x)
+        # Add positional embeddings
         x = self.pos_emb(x)
+        # Concatenate the `[CLS]` token embeddings before feeding the transformer
         cls_token_emb = self.cls_token_emb.expand(-1, x.shape[1], -1)
         x = torch.cat([cls_token_emb, x])

+        # Pass through transformer layers with no attention masking
         for layer in self.transformer_layers:
             x = layer(x=x, mask=None)

+        # Get the transformer output of the `[CLS]` token (which is the first in the sequence).
         x = x[0]
+        # Layer normalization
+        x = self.ln(x)
+        # Classification head, to get logits
         x = self.classification(x)

+        #
         return x
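
The docstring above mentions that, when fine-tuning at a higher resolution, positional embeddings for the new patch locations are obtained by interpolating the learned ones. That step is not implemented in this file; the following is only a sketch of how it could be done (the grid sizes, the bicubic mode, and all variable names are illustrative assumptions):

import torch
import torch.nn.functional as F

d_model = 64
old_grid, new_grid = 8, 16                         # 8x8 patch grid -> 16x16 patch grid
pe = torch.randn(old_grid * old_grid, 1, d_model)  # [patches, 1, d_model], as in LearnedPositionalEmbeddings

# Treat the embeddings as a [1, d_model, h, w] image, resize it, and flatten back
pe_img = pe.squeeze(1).t().reshape(1, d_model, old_grid, old_grid)
pe_new = F.interpolate(pe_img, size=(new_grid, new_grid), mode='bicubic', align_corners=False)
pe_new = pe_new.reshape(d_model, new_grid * new_grid).t().unsqueeze(1)
print(pe_new.shape)  # torch.Size([256, 1, 64])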
labml_nn/transformers/vit/experiment.py

 """
 ---
-title: Train a ViT on CIFAR 10
+title: Train a Vision Transformer (ViT) on CIFAR 10
 summary: >
-  Train a ViT on CIFAR 10
+  Train a Vision Transformer (ViT) on CIFAR 10
 ---

-# Train a ViT on CIFAR 10
+# Train a [Vision Transformer (ViT)](index.html) on CIFAR 10
+
+[![View Run](https://img.shields.io/badge/labml-experiment-brightgreen)](https://app.labml.ai/run/8b531d9ce3dc11eb84fc87df6756eb8f)
 """

 from labml import experiment

@@ -18,19 +20,27 @@ class Configs(CIFAR10Configs):
     """
     ## Configurations

-    We use [`CIFAR10Configs`](../experiments/cifar10.html) which defines all the
+    We use [`CIFAR10Configs`](../../experiments/cifar10.html) which defines all the
     dataset related configurations, optimizer, and a training loop.
     """

+    # [Transformer configurations](../configs.html#TransformerConfigs)
+    # to get [transformer layer](../models.html#TransformerLayer)
     transformer: TransformerConfigs
+    # Size of a patch
     patch_size: int = 4
-    n_hidden: int = 2048
+    # Size of the hidden layer in classification head
+    n_hidden_classification: int = 2048
+    # Number of classes in the task
     n_classes: int = 10

 @option(Configs.transformer)
-def _transformer(c: Configs):
+def _transformer():
+    """
+    Create transformer configs
+    """
     return TransformerConfigs()

@@ -42,11 +52,13 @@ def _vit(c: Configs):
     from labml_nn.transformers.vit import VisionTransformer, LearnedPositionalEmbeddings, ClassificationHead, \
         PatchEmbeddings

+    # Transformer size from [Transformer configurations](../configs.html#TransformerConfigs)
     d_model = c.transformer.d_model
+    # Create a vision transformer
     return VisionTransformer(c.transformer.encoder_layer, c.transformer.n_layers,
                              PatchEmbeddings(d_model, c.patch_size, 3),
                              LearnedPositionalEmbeddings(d_model),
-                             ClassificationHead(d_model, c.n_hidden, c.n_classes)).to(c.device)
+                             ClassificationHead(d_model, c.n_hidden_classification, c.n_classes)).to(c.device)

@@ -56,20 +68,20 @@ def main():
     conf = Configs()
     # Load configurations
     experiment.configs(conf, {
+        'device.cuda_device': 0,
+
+        # Optimizer
+        # 'optimizer.optimizer': 'Noam',
+        # 'optimizer.learning_rate': 1.,
         'optimizer.optimizer': 'Adam',
         'optimizer.learning_rate': 2.5e-4,
-        'optimizer.d_model': 512,

+        # Transformer embedding size
         'transformer.d_model': 512,

+        # Training epochs and batch size
         'epochs': 1000,
         'train_batch_size': 64,

+        # Augment CIFAR 10 images for training
         'train_dataset': 'cifar10_train_augmented',
+        # Do not augment CIFAR 10 images for validation
         'valid_dataset': 'cifar10_valid_no_augment',
     })
     # Set model for saving/loading
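
For reference, with the configuration above (32x32 CIFAR-10 images and patch_size = 4) the transformer sees a fairly short sequence. A quick check of the arithmetic (not part of the commit):

image_size, patch_size = 32, 4
grid = image_size // patch_size  # 8: patches per side
seq_len = grid * grid + 1        # 64 patches plus one [CLS] token
print(grid, seq_len)             # 8 65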
labml_nn/transformers/vit/readme.md
new file (0 → 100644)

# [Vision Transformer (ViT)](https://nn.labml.ai/transformer/vit/index.html)

This is a [PyTorch](https://pytorch.org) implementation of the paper
[An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale](https://arxiv.org/abs/2010.11929).

The vision transformer applies a pure transformer to images, without any convolution layers.
The image is split into patches, and a transformer is applied to the patch embeddings.
[Patch embeddings](https://nn.labml.ai/transformer/vit/index.html#PathEmbeddings) are generated
by applying a simple linear transformation to the flattened pixel values of each patch.
A standard transformer encoder is then fed the patch embeddings, along with a
classification token `[CLS]`.
The encoding of the `[CLS]` token is used to classify the image with an MLP.

When feeding the transformer with the patches, learned positional embeddings are
added to the patch embeddings, because the patch embeddings alone carry no information
about where each patch came from.
The positional embeddings are a set of vectors, one per patch location, trained
with gradient descent along with the other parameters.

ViTs perform well when they are pre-trained on large datasets.
The paper suggests pre-training them with an MLP classification head and
then using a single linear layer when fine-tuning.
The paper beats SOTA with a ViT pre-trained on a 300-million-image dataset.
They also use higher-resolution images during inference while keeping the
patch size the same.
The positional embeddings for the new patch locations are calculated by interpolating
the learned positional embeddings.

Here's [an experiment](https://nn.labml.ai/transformer/vit/experiment.html) that trains ViT on CIFAR-10.
It doesn't do very well because it's trained on a small dataset.
It's a simple experiment that anyone can run to play with ViTs.
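
To make the forward pass described above concrete, here is a self-contained sketch using plain PyTorch stand-ins (nn.TransformerEncoder in place of labml_nn's TransformerLayer stack; all sizes are illustrative assumptions, not the commit's defaults):

import torch
from torch import nn

d_model, patch_size, n_classes = 64, 4, 10
patch_emb = nn.Conv2d(3, d_model, patch_size, stride=patch_size)
pos_emb = nn.Parameter(torch.zeros(5_000, 1, d_model))           # learned positional embeddings
cls_token = nn.Parameter(torch.randn(1, 1, d_model))             # [CLS] token embedding
encoder = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model, nhead=4), num_layers=2)
head = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, n_classes))

x = torch.randn(2, 3, 32, 32)                    # a CIFAR-10 sized batch
x = patch_emb(x).flatten(2).permute(2, 0, 1)     # [64 patches, batch 2, d_model]
x = x + pos_emb[:x.shape[0]]                     # add positional embeddings
x = torch.cat([cls_token.expand(-1, x.shape[1], -1), x])  # prepend [CLS]: [65, 2, d_model]
x = encoder(x)                                   # sequence-first by default
logits = head(x[0])                              # classify from the [CLS] encoding
print(logits.shape)                              # torch.Size([2, 10])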
readme.md
@@ -37,6 +37,7 @@ implementations almost weekly.
 * [Masked Language Model](https://nn.labml.ai/transformers/mlm/index.html)
 * [MLP-Mixer: An all-MLP Architecture for Vision](https://nn.labml.ai/transformers/mlp_mixer/index.html)
 * [Pay Attention to MLPs (gMLP)](https://nn.labml.ai/transformers/gmlp/index.html)
+* [Vision Transformer (ViT)](https://nn.labml.ai/transformers/vit/index.html)

 #### ✨ [Recurrent Highway Networks](https://nn.labml.ai/recurrent_highway_networks/index.html)
setup.py
@@ -5,7 +5,7 @@ with open("readme.md", "r") as f:

 setuptools.setup(
     name='labml-nn',
-    version='0.4.102',
+    version='0.4.103',
     author="Varuna Jayasiri, Nipun Wijerathne",
     author_email="vpjayasiri@gmail.com, hnipun@gmail.com",
     description="A collection of PyTorch implementations of neural network architectures and layers.",