提交 b00660c8 编写于 作者: 嗷我懂了's avatar 嗷我懂了

Add new file

上级
import torchvision
from torch.utils.data import DataLoader
import torch
import time
from PIL import Image
# 搭建模型
model = torchvision.models.vgg19(pretrained=True)
model.classifier[-1] = torch.nn.Linear(4096, 2)
print(model)
# 初始化运行条件
if True:
sp = '\n' + '--------' * 20 + '\n'
root = './data'
train_path = root + '/train'
test_path = root + '/test'
bs = 16
lr = 0.0001
epoch = 20
device = 'cuda'
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), momentum=0.9, lr=lr)
transform = torchvision.transforms.Compose([
torchvision.transforms.RandomResizedCrop(224),
torchvision.transforms.ToTensor()
])
# 读取数据
if True:
train_data = torchvision.datasets.ImageFolder(train_path, transform)
classes = train_data.classes
train_iterator = DataLoader(train_data, bs, shuffle=True)
# 展示数据细节
if False:
print(train_data, end=sp)
print('class_to_idx: ', train_data.class_to_idx)
print('classes: ', train_data.classes)
print('extension: ', train_data.extensions)
print('extra_repr: ', train_data.extra_repr())
print('imgs: ', train_data.imgs)
print('loader: ', train_data.loader)
print('root: ', train_data.root)
print('samples: ', train_data.samples)
print('target_transform: ', train_data.target_transform)
print('targets: ', train_data.targets)
print('transform: ', train_data.transform)
print('transforms: ', train_data.transforms)
print('\n\n', end=sp)
print(train_iterator, end=sp)
print('batch_sampler: ', train_iterator.batch_sampler)
print('batch_size: ', train_iterator.batch_size)
print('collate_fn: ', train_iterator.collate_fn)
print('dataset: ', train_iterator.dataset)
print('drop_last: ', train_iterator.drop_last)
print('generator: ', train_iterator.generator)
print('multiprocessing_context: ', train_iterator.multiprocessing_context)
print('num_workers: ', train_iterator.num_workers)
print('persistent_workers: ', train_iterator.persistent_workers)
print('pin_memory: ', train_iterator.pin_memory)
print('prefetch_factor: ', train_iterator.prefetch_factor)
print('sampler: ', train_iterator.sampler)
print('timeout: ', train_iterator.timeout)
print('worker_init_fn: ', train_iterator.worker_init_fn)
print('\n\n', end=sp)
def train(model, iterator, optimizer, criterion):
def accuracy(outputs, label):
pre = torch.argmax(outputs, dim=1)
acc_num = (pre == label).sum()
return acc_num / len(label)
start_time = time.monotonic()
epoch_loss = 0.0
epoch_acc = 0.0
model = model.to(device)
model.train()
for (images, labels) in iterator:
optimizer.zero_grad()
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
acc = accuracy(outputs, labels)
loss.backward()
optimizer.step()
epoch_loss += loss
epoch_acc += acc
cost_time = time.monotonic() - start_time
return epoch_loss / len(iterator), epoch_acc / len(iterator), cost_time
if __name__ == '__main__':
# 是否训练
if False:
for epoch in range(epoch):
loss, acc, cost_t = train(model, train_iterator, optimizer, criterion)
print(f'epoch: {epoch}\tcost time: {cost_t}\nloss: {loss}\tacc: {acc}')
torch.save(model.state_dict(), 'cat_dog_classification.pth')
model.load_state_dict(torch.load('cat_dog_classification.pth'))
# 是否选择图片进行预测
choice = input('是否选择图片进行预测? ')
if choice in {'y', 'Y'}:
path = input('输入图片路径(仅限于jpg): ')
def classification(path):
image = transform(Image.open(path))
image = torch.unsqueeze(image, 0)
out = model(image)
poss = torch.softmax(out, dim=1)
index = int(torch.argmax(out, dim=1))
print('name: ', classes[index], '\tpossibility: ', float(poss[0, index]))
classification(path)
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册