提交 c5a2b289 编写于 作者: C chenyuntc

refactor code

上级 e28a0651
...@@ -4,7 +4,7 @@ from skimage import transform as sktsf ...@@ -4,7 +4,7 @@ from skimage import transform as sktsf
from torchvision import transforms as tvtsf from torchvision import transforms as tvtsf
from . import util from . import util
import numpy as np import numpy as np
from config import opt from utils.config import opt
def inverse_normalize(img): def inverse_normalize(img):
......
此差异由.gitattributes 抑制。
...@@ -4,16 +4,15 @@ import ipdb ...@@ -4,16 +4,15 @@ import ipdb
import matplotlib import matplotlib
from tqdm import tqdm from tqdm import tqdm
import torch as t from utils.config import opt
from config import opt
from data.dataset import Dataset, TestDataset from data.dataset import Dataset, TestDataset
from model import FasterRCNNVGG16 from model import FasterRCNNVGG16
from torch.autograd import Variable from torch.autograd import Variable
from torch.utils import data as data_ from torch.utils import data as data_
from trainer import FasterRCNNTrainer from trainer import FasterRCNNTrainer
from util import array_tool as at from utils import array_tool as at
from util.vis_tool import visdom_bbox from utils.vis_tool import visdom_bbox
from util.eval_tool import eval_detection_voc from utils.eval_tool import eval_detection_voc
matplotlib.use('agg') matplotlib.use('agg')
......
...@@ -2,14 +2,14 @@ from __future__ import division ...@@ -2,14 +2,14 @@ from __future__ import division
import torch as t import torch as t
import numpy as np import numpy as np
import cupy as cp import cupy as cp
from util import array_tool as at from utils import array_tool as at
from model.utils.bbox_tools import loc2bbox from model.utils.bbox_tools import loc2bbox
from model.utils.nms import non_maximum_suppression from model.utils.nms import non_maximum_suppression
from torch import nn from torch import nn
from data.dataset import preprocess from data.dataset import preprocess
from torch.nn import functional as F from torch.nn import functional as F
from config import opt from utils.config import opt
class FasterRCNN(nn.Module): class FasterRCNN(nn.Module):
......
import torch as t import torch as t
from torch import nn from torch import nn
from torchvision.models import vgg16 from torchvision.models import vgg16
from .region_proposal_network import RegionProposalNetwork from model.region_proposal_network import RegionProposalNetwork
from .faster_rcnn import FasterRCNN from model.faster_rcnn import FasterRCNN
from .ROIModule import ROIPooling2D from model.roi_module import RoIPooling2D
from util import array_tool as at from utils import array_tool as at
from config import opt from utils.config import opt
def decom_vgg16(): def decom_vgg16():
...@@ -154,7 +154,7 @@ class VGG16RoIHead(nn.Module): ...@@ -154,7 +154,7 @@ class VGG16RoIHead(nn.Module):
self.n_class = n_class self.n_class = n_class
self.roi_size = roi_size self.roi_size = roi_size
self.spatial_scale = spatial_scale self.spatial_scale = spatial_scale
self.roi = ROIPooling2D(self.roi_size, self.roi_size, self.spatial_scale) self.roi = RoIPooling2D(self.roi_size, self.roi_size, self.spatial_scale)
def forward(self, x, rois, roi_indices): def forward(self, x, rois, roi_indices):
"""Forward the chain. """Forward the chain.
......
...@@ -26,7 +26,7 @@ def GET_BLOCKS(N, K=CUDA_NUM_THREADS): ...@@ -26,7 +26,7 @@ def GET_BLOCKS(N, K=CUDA_NUM_THREADS):
return (N + K - 1) // K return (N + K - 1) // K
class ROI(Function): class RoI(Function):
""" """
NOTE:only CUDA-compatible NOTE:only CUDA-compatible
""" """
...@@ -79,14 +79,14 @@ class ROI(Function): ...@@ -79,14 +79,14 @@ class ROI(Function):
return grad_input, None return grad_input, None
class ROIPooling2D(t.nn.Module): class RoIPooling2D(t.nn.Module):
def __init__(self, outh, outw, spatial_scale): def __init__(self, outh, outw, spatial_scale):
super(ROIPooling2D, self).__init__() super(RoIPooling2D, self).__init__()
self.ROI = ROI(outh, outw, spatial_scale) self.RoI = RoI(outh, outw, spatial_scale)
def forward(self, x, rois): def forward(self, x, rois):
return self.ROI(x, rois) return self.RoI(x, rois)
def test_roi_module(): def test_roi_module():
...@@ -103,7 +103,7 @@ def test_roi_module(): ...@@ -103,7 +103,7 @@ def test_roi_module():
outh, outw = PH, PW outh, outw = PH, PW
# pytorch version # pytorch version
module = ROIPooling2D(outh, outw, spatial_scale) module = RoIPooling2D(outh, outw, spatial_scale)
x = t.autograd.Variable(bottom_data, requires_grad=True) x = t.autograd.Variable(bottom_data, requires_grad=True)
rois = t.autograd.Variable(bottom_rois) rois = t.autograd.Variable(bottom_rois)
output = module(x, rois) output = module(x, rois)
......
...@@ -137,7 +137,6 @@ kernel_backward = ''' ...@@ -137,7 +137,6 @@ kernel_backward = '''
int index_ = ph * pooled_width + pw + offset; int index_ = ph * pooled_width + pw + offset;
if (argmax_data[index_] == (h * width + w)) { if (argmax_data[index_] == (h * width + w)) {
gradient += top_diff[index_]; gradient += top_diff[index_];
//printf("%d-%f ",index_, top_diff[index_]);
} }
} }
} }
......
...@@ -4,15 +4,15 @@ import ipdb ...@@ -4,15 +4,15 @@ import ipdb
import matplotlib import matplotlib
from tqdm import tqdm from tqdm import tqdm
from config import opt from utils.config import opt
from data.dataset import Dataset, TestDataset, inverse_normalize from data.dataset import Dataset, TestDataset, inverse_normalize
from model import FasterRCNNVGG16 from model import FasterRCNNVGG16
from torch.autograd import Variable from torch.autograd import Variable
from torch.utils import data as data_ from torch.utils import data as data_
from trainer import FasterRCNNTrainer from trainer import FasterRCNNTrainer
from util import array_tool as at from utils import array_tool as at
from util.vis_tool import visdom_bbox from utils.vis_tool import visdom_bbox
from util.eval_tool import eval_detection_voc from utils.eval_tool import eval_detection_voc
# fix for ulimit # fix for ulimit
# https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667 # https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667
......
...@@ -6,10 +6,10 @@ from model.utils.creator_tool import AnchorTargetCreator, ProposalTargetCreator ...@@ -6,10 +6,10 @@ from model.utils.creator_tool import AnchorTargetCreator, ProposalTargetCreator
from torch import nn from torch import nn
import torch as t import torch as t
from torch.autograd import Variable from torch.autograd import Variable
from util import array_tool as at from utils import array_tool as at
from util.vis_tool import Visualizer from utils.vis_tool import Visualizer
from config import opt from utils.config import opt
from torchnet.meter import ConfusionMeter, AverageValueMeter from torchnet.meter import ConfusionMeter, AverageValueMeter
LossTuple = namedtuple('LossTuple', LossTuple = namedtuple('LossTuple',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册