# loadmodel.py (603 bytes)
# Last commit by HypoX64: "preview"
import torch
from .pix2pix_model import *
from .unet_model import UNet

def pix2pix(model_path, G_model_type, use_gpu=True):
    """Build a pix2pix generator and load saved weights into it.

    Args:
        model_path: Path to a file holding the generator's ``state_dict``.
        G_model_type: Generator architecture identifier, forwarded unchanged
            to ``define_G``.
        use_gpu: When True, ``define_G`` is given ``gpu_ids=[0]`` and the
            network is moved to CUDA with ``.cuda()`` after loading.

    Returns:
        The generator with weights loaded, switched to ``eval()`` mode.
    """
    gpu_ids = [0] if use_gpu else []
    # (3, 3, 64) are forwarded to define_G as-is; presumably input channels,
    # output channels and base filter count -- confirm in pix2pix_model.
    netG = define_G(3, 3, 64, G_model_type, norm='instance', init_type='normal', gpu_ids=gpu_ids)
    # Deserialize onto the CPU: a checkpoint saved from a GPU run would
    # otherwise fail to load on a machine without CUDA (use_gpu=False).
    # load_state_dict copies the values into the network's existing
    # parameters, so their device is unaffected by where the file was read.
    state_dict = torch.load(model_path, map_location='cpu')
    netG.load_state_dict(state_dict)
    netG.eval()
    if use_gpu:
        netG.cuda()
    return netG

def unet(model_path, use_gpu=True):
    """Build a UNet (3 input channels, 1 output class) and load saved weights.

    Args:
        model_path: Path to a file holding the UNet's ``state_dict``.
        use_gpu: When True, the network is moved to CUDA with ``.cuda()``
            after loading.

    Returns:
        The UNet with weights loaded, switched to ``eval()`` mode.
    """
    net = UNet(n_channels=3, n_classes=1)
    # Deserialize onto the CPU: a checkpoint saved from a GPU run would
    # otherwise fail to load on a machine without CUDA (use_gpu=False).
    # The device is then chosen by the .cuda() call below.
    state_dict = torch.load(model_path, map_location='cpu')
    net.load_state_dict(state_dict)
    net.eval()
    if use_gpu:
        net.cuda()
    return net


# def unet():