video_model_unet.py 3.0 KB
Newer Older
H
hypox64 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
import torch
import torch.nn as nn
import torch.nn.functional as F
from .unet_parts import *


class conv_3d(nn.Module):
    def __init__(self,inchannel,outchannel,kernel_size=3,stride=2,padding=1):
        super(conv_3d, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(inchannel, outchannel, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.BatchNorm3d(outchannel),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.conv(x)
        return x


class encoder_3d(nn.Module):
    def __init__(self,in_channel):
        super(encoder_3d, self).__init__()
        self.down1 = conv_3d(1, 64, 3, 2, 1)
        self.down2 = conv_3d(64, 128, 3, 2, 1)
        self.down3 = conv_3d(128, 256, 3, 2, 1)
        self.down4 = conv_3d(256, 512, 3, 2, 1)
        self.conver2d = nn.Sequential(
            nn.Conv2d(int(in_channel/16)+1, 1, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
        )


    def forward(self, x):
        x = x.view(x.size(0),1,x.size(1),x.size(2),x.size(3))
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        x = self.down4(x)
        x = x.view(x.size(1),x.size(2),x.size(3),x.size(4))
        x = self.conver2d(x)
        x = x.view(x.size(1),x.size(0),x.size(2),x.size(3))
        # print(x.size())
        # x = self.avgpool(x)
        return x




class encoder_2d(nn.Module):
    def __init__(self, in_channel):
        super(encoder_2d, self).__init__()
        self.inc = inconv(in_channel, 64)
        self.down1 = down(64, 128)
        self.down2 = down(128, 256)
        self.down3 = down(256, 512)
        self.down4 = down(512, 512)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        return x1,x2,x3,x4,x5

class decoder_2d(nn.Module):
    def __init__(self, out_channel):
        super(decoder_2d, self).__init__()
        self.up1 = up(1024, 256,bilinear=False)
        self.up2 = up(512, 128,bilinear=False)
        self.up3 = up(256, 64,bilinear=False)
        self.up4 = up(128, 64,bilinear=False)
        self.outc = outconv(64, out_channel)

    def forward(self,x5,x4,x3,x2,x1):
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        
        return x


class HypoNet(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(HypoNet, self).__init__()

        self.encoder_2d = encoder_2d(4)
        self.encoder_3d = encoder_3d(in_channel)
        self.decoder_2d = decoder_2d(out_channel)

    def forward(self, x):

        N = int((x.size()[1])/3)
        x_2d = torch.cat((x[:,int((N-1)/2)*3:(int((N-1)/2)+1)*3,:,:], x[:,N-1:N,:,:]), 1)
        # print(x_2d.size())
        x_3d = self.encoder_3d(x)

        x1,x2,x3,x4,x5 = self.encoder_2d(x_2d)
        x5 = x5 + x_3d
        x_2d = self.decoder_2d(x5,x4,x3,x2,x1)

        return x_2d