Commit be51173c authored by 别团等shy哥发育

ConvMixer architecture reproduction

Parent f34be2e9
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
'''
patch_embedding splits the raw input image (h, w) into patches. Each patch has size
(patch_size, patch_size), so every image is divided into (h//patch_size, w//patch_size)
patches (e.g., a 224x224 input with patch_size=7 yields a 32x32 grid of patches).
This is implemented with a convolution whose kernel_size and stride both equal patch_size.
'''
# Patch embedding layer
def patch_embedding(inputs,      # input tensor
                    out_channel, # number of output channels
                    patch_size   # height/width of each patch
                    ):
    # Split the image into patches with a standard convolution whose
    # kernel size and stride both equal patch_size
    x = layers.Conv2D(filters=out_channel,
                      kernel_size=patch_size,
                      strides=patch_size,
                      padding='same',
                      use_bias=False)(inputs)
    # GELU activation + BN
    x = layers.Activation('gelu')(x)
    x = layers.BatchNormalization()(x)
    return x
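# Shape sketch (illustrative, assuming the ConvMixer-1536/20 settings used below):
# with a 224x224x3 input and patch_size=7, patch_embedding yields a 32x32 grid of
# 1536-dimensional patch embeddings, e.g.
#   inp = layers.Input(shape=(224, 224, 3))
#   emb = patch_embedding(inp, out_channel=1536, patch_size=7)  # shape: (None, 32, 32, 1536)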
# ConvMixer layer: depthwise conv + pointwise conv
# out_channel is the number of output channels of the pointwise conv
# kernel_size is the kernel size of the depthwise conv
def layer(inputs, out_channel, kernel_size):
    # Depthwise convolution
    x = layers.DepthwiseConv2D(kernel_size=kernel_size, # kernel size
                               strides=1,               # no downsampling
                               padding='same',          # spatial size unchanged
                               use_bias=False)(inputs)
    # GELU + BN
    x = layers.Activation('gelu')(x)
    x = layers.BatchNormalization()(x)
    # Shortcut (residual) connection
    x = layers.Add()([x, inputs])
    # Pointwise convolution: 1x1
    x = layers.Conv2D(filters=out_channel,
                      kernel_size=1,
                      strides=1)(x)
    # GELU + BN
    x = layers.Activation('gelu')(x)
    x = layers.BatchNormalization()(x)
    return x
# Stack of ConvMixer blocks
def blocks(x, depth, out_channel, kernel_size):
    for _ in range(depth):
        x = layer(x, out_channel, kernel_size)
    return x
# Backbone network
def convmixer(input_shape, num_classes):
    # Input layer: [b, 224, 224, 3]
    inputs = layers.Input(shape=input_shape)
    # Patch embedding layer: [b, 224//7, 224//7, 1536]
    x = patch_embedding(inputs, out_channel=1536, patch_size=7)
    # Stack of 20 ConvMixer blocks: [b, 224//7, 224//7, 1536]
    x = blocks(x, depth=20, out_channel=1536, kernel_size=9)
    # Global average pooling
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    # Build the model
    model = Model(inputs, outputs)
    return model
if __name__ == '__main__':
    model = convmixer(input_shape=(224, 224, 3), num_classes=1000)
    model.summary()
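    # Quick sanity check (illustrative sketch): pass a random batch through the
    # model and confirm the output has shape (batch, num_classes)
    dummy = tf.random.normal((1, 224, 224, 3))
    preds = model(dummy, training=False)
    print(preds.shape)  # expected: (1, 1000)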