From c69281b3e4dfefa79a13117a4cee76d5102999b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=97=B7=E6=88=91=E6=87=82=E4=BA=86?= Date: Fri, 25 Dec 2020 17:20:23 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20README.md?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..324c328 --- /dev/null +++ b/README.md @@ -0,0 +1,26 @@ +第一次用仓库,咱也不知道咋用,data文件夹不知道怎么传上来,就直接放百度云里了。 + +链接:https://pan.baidu.com/s/1DBstEBXH5C0IsM4fiUV15Q +提取码:oly0 + +简要对代码说明一下,压缩包解压后,就一个main可以直接运行。 + +1、直接用官方的vgg19,并且保留预训练参数,可以加快训练速度和提高精度,由于是二分类,所以把最后一层1000改成2 + model = torchvision.models.vgg19(pretrained=True) + model.classifier[-1] = torch.nn.Linear(4096, 2) + +2、读取文件夹中的内容,文件夹名字作为label + train_data = torchvision.datasets.ImageFolder(train_path, transform) + classes = train_data.classes + train_iterator = DataLoader(train_data, bs, shuffle=True) + +3、模型放在gpu中,model.train()是把模型改成训练模式,后面进行评估时候要model.eval() + model = model.to(device) + model.train() + +4、把训练好的模型整个保存下来 + torch.save(model.state_dict(), 'cat_dog_classification.pth') + + +最后,作者也正在学习,欢迎大家批评指正。 +mail: sunhy6594@outlook.com -- GitLab