Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
嗷我懂了
猫狗分类-pytorch
提交
b00660c8
猫
猫狗分类-pytorch
项目概览
嗷我懂了
/
猫狗分类-pytorch
通知
9
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
猫
猫狗分类-pytorch
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
b00660c8
编写于
12月 25, 2020
作者:
嗷我懂了
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add new file
上级
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
125 addition
and
0 deletion
+125
-0
main
main
+125
-0
未找到文件。
main
0 → 100644
浏览文件 @
b00660c8
import
torchvision
from
torch
.
utils
.
data
import
DataLoader
import
torch
import
time
from
PIL
import
Image
#
搭建模型
model
=
torchvision
.
models
.
vgg19
(
pretrained
=
True
)
model
.
classifier
[-
1
]
=
torch
.
nn
.
Linear
(
4096
,
2
)
print
(
model
)
#
初始化运行条件
if
True
:
sp
=
'\n'
+
'--------'
*
20
+
'\n'
root
=
'./data'
train_path
=
root
+
'/train'
test_path
=
root
+
'/test'
bs
=
16
lr
=
0.0001
epoch
=
20
device
=
'cuda'
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
optimizer
=
torch
.
optim
.
SGD
(
params
=
model
.
parameters
(),
momentum
=
0.9
,
lr
=
lr
)
transform
=
torchvision
.
transforms
.
Compose
([
torchvision
.
transforms
.
RandomResizedCrop
(
224
),
torchvision
.
transforms
.
ToTensor
()
])
#
读取数据
if
True
:
train_data
=
torchvision
.
datasets
.
ImageFolder
(
train_path
,
transform
)
classes
=
train_data
.
classes
train_iterator
=
DataLoader
(
train_data
,
bs
,
shuffle
=
True
)
#
展示数据细节
if
False
:
print
(
train_data
,
end
=
sp
)
print
(
'class_to_idx: '
,
train_data
.
class_to_idx
)
print
(
'classes: '
,
train_data
.
classes
)
print
(
'extension: '
,
train_data
.
extensions
)
print
(
'extra_repr: '
,
train_data
.
extra_repr
())
print
(
'imgs: '
,
train_data
.
imgs
)
print
(
'loader: '
,
train_data
.
loader
)
print
(
'root: '
,
train_data
.
root
)
print
(
'samples: '
,
train_data
.
samples
)
print
(
'target_transform: '
,
train_data
.
target_transform
)
print
(
'targets: '
,
train_data
.
targets
)
print
(
'transform: '
,
train_data
.
transform
)
print
(
'transforms: '
,
train_data
.
transforms
)
print
(
'\n\n'
,
end
=
sp
)
print
(
train_iterator
,
end
=
sp
)
print
(
'batch_sampler: '
,
train_iterator
.
batch_sampler
)
print
(
'batch_size: '
,
train_iterator
.
batch_size
)
print
(
'collate_fn: '
,
train_iterator
.
collate_fn
)
print
(
'dataset: '
,
train_iterator
.
dataset
)
print
(
'drop_last: '
,
train_iterator
.
drop_last
)
print
(
'generator: '
,
train_iterator
.
generator
)
print
(
'multiprocessing_context: '
,
train_iterator
.
multiprocessing_context
)
print
(
'num_workers: '
,
train_iterator
.
num_workers
)
print
(
'persistent_workers: '
,
train_iterator
.
persistent_workers
)
print
(
'pin_memory: '
,
train_iterator
.
pin_memory
)
print
(
'prefetch_factor: '
,
train_iterator
.
prefetch_factor
)
print
(
'sampler: '
,
train_iterator
.
sampler
)
print
(
'timeout: '
,
train_iterator
.
timeout
)
print
(
'worker_init_fn: '
,
train_iterator
.
worker_init_fn
)
print
(
'\n\n'
,
end
=
sp
)
def
train
(
model
,
iterator
,
optimizer
,
criterion
):
def
accuracy
(
outputs
,
label
):
pre
=
torch
.
argmax
(
outputs
,
dim
=
1
)
acc_num
=
(
pre
==
label
).
sum
()
return
acc_num
/
len
(
label
)
start_time
=
time
.
monotonic
()
epoch_loss
=
0.0
epoch_acc
=
0.0
model
=
model
.
to
(
device
)
model
.
train
()
for
(
images
,
labels
)
in
iterator
:
optimizer
.
zero_grad
()
images
=
images
.
to
(
device
)
labels
=
labels
.
to
(
device
)
outputs
=
model
(
images
)
loss
=
criterion
(
outputs
,
labels
)
acc
=
accuracy
(
outputs
,
labels
)
loss
.
backward
()
optimizer
.
step
()
epoch_loss
+=
loss
epoch_acc
+=
acc
cost_time
=
time
.
monotonic
()
-
start_time
return
epoch_loss
/
len
(
iterator
),
epoch_acc
/
len
(
iterator
),
cost_time
if
__name__
==
'__main__'
:
#
是否训练
if
False
:
for
epoch
in
range
(
epoch
):
loss
,
acc
,
cost_t
=
train
(
model
,
train_iterator
,
optimizer
,
criterion
)
print
(
f
'epoch: {epoch}\tcost time: {cost_t}\nloss: {loss}\tacc: {acc}'
)
torch
.
save
(
model
.
state_dict
(),
'cat_dog_classification.pth'
)
model
.
load_state_dict
(
torch
.
load
(
'cat_dog_classification.pth'
))
#
是否选择图片进行预测
choice
=
input
(
'是否选择图片进行预测? '
)
if
choice
in
{
'y'
,
'Y'
}:
path
=
input
(
'输入图片路径(仅限于jpg): '
)
def
classification
(
path
):
image
=
transform
(
Image
.
open
(
path
))
image
=
torch
.
unsqueeze
(
image
,
0
)
out
=
model
(
image
)
poss
=
torch
.
softmax
(
out
,
dim
=
1
)
index
=
int
(
torch
.
argmax
(
out
,
dim
=
1
))
print
(
'name: '
,
classes
[
index
],
'\tpossibility: '
,
float
(
poss
[
0
,
index
]))
classification
(
path
)
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录