提交 c69281b3 编写于 作者: 嗷我懂了's avatar 嗷我懂了

添加 README.md

上级 b00660c8
第一次用仓库,咱也不知道咋用,data文件夹不知道怎么传上来,就直接放百度云里了。
链接:https://pan.baidu.com/s/1DBstEBXH5C0IsM4fiUV15Q
提取码:oly0
简要对代码说明一下,压缩包解压后,就一个main可以直接运行。
1、直接用官方的vgg19,并且保留预训练参数,可以加快训练速度和提高精度,由于是二分类,所以把最后一层1000改成2
model = torchvision.models.vgg19(pretrained=True)
model.classifier[-1] = torch.nn.Linear(4096, 2)
2、读取文件夹中的内容,文件夹名字作为label
train_data = torchvision.datasets.ImageFolder(train_path, transform)
classes = train_data.classes
train_iterator = DataLoader(train_data, bs, shuffle=True)
3、模型放在gpu中,model.train()是把模型改成训练模式,后面进行评估时候要model.eval()
model = model.to(device)
model.train()
4、把训练好的模型整个保存下来
torch.save(model.state_dict(), 'cat_dog_classification.pth')
最后,作者也正在学习,欢迎大家批评指正。
mail: sunhy6594@outlook.com
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册