提交 c5a2b289 编写于 作者: C chenyuntc

refactor code

上级 e28a0651
......@@ -4,7 +4,7 @@ from skimage import transform as sktsf
from torchvision import transforms as tvtsf
from . import util
import numpy as np
from config import opt
from utils.config import opt
def inverse_normalize(img):
......
此差异由.gitattributes 抑制。
......@@ -4,16 +4,15 @@ import ipdb
import matplotlib
from tqdm import tqdm
import torch as t
from config import opt
from utils.config import opt
from data.dataset import Dataset, TestDataset
from model import FasterRCNNVGG16
from torch.autograd import Variable
from torch.utils import data as data_
from trainer import FasterRCNNTrainer
from util import array_tool as at
from util.vis_tool import visdom_bbox
from util.eval_tool import eval_detection_voc
from utils import array_tool as at
from utils.vis_tool import visdom_bbox
from utils.eval_tool import eval_detection_voc
matplotlib.use('agg')
......
......@@ -2,14 +2,14 @@ from __future__ import division
import torch as t
import numpy as np
import cupy as cp
from util import array_tool as at
from utils import array_tool as at
from model.utils.bbox_tools import loc2bbox
from model.utils.nms import non_maximum_suppression
from torch import nn
from data.dataset import preprocess
from torch.nn import functional as F
from config import opt
from utils.config import opt
class FasterRCNN(nn.Module):
......
import torch as t
from torch import nn
from torchvision.models import vgg16
from .region_proposal_network import RegionProposalNetwork
from .faster_rcnn import FasterRCNN
from .ROIModule import ROIPooling2D
from util import array_tool as at
from config import opt
from model.region_proposal_network import RegionProposalNetwork
from model.faster_rcnn import FasterRCNN
from model.roi_module import RoIPooling2D
from utils import array_tool as at
from utils.config import opt
def decom_vgg16():
......@@ -154,7 +154,7 @@ class VGG16RoIHead(nn.Module):
self.n_class = n_class
self.roi_size = roi_size
self.spatial_scale = spatial_scale
self.roi = ROIPooling2D(self.roi_size, self.roi_size, self.spatial_scale)
self.roi = RoIPooling2D(self.roi_size, self.roi_size, self.spatial_scale)
def forward(self, x, rois, roi_indices):
"""Forward the chain.
......
......@@ -26,7 +26,7 @@ def GET_BLOCKS(N, K=CUDA_NUM_THREADS):
return (N + K - 1) // K
class ROI(Function):
class RoI(Function):
"""
NOTE:only CUDA-compatible
"""
......@@ -79,14 +79,14 @@ class ROI(Function):
return grad_input, None
class ROIPooling2D(t.nn.Module):
class RoIPooling2D(t.nn.Module):
def __init__(self, outh, outw, spatial_scale):
super(ROIPooling2D, self).__init__()
self.ROI = ROI(outh, outw, spatial_scale)
super(RoIPooling2D, self).__init__()
self.RoI = RoI(outh, outw, spatial_scale)
def forward(self, x, rois):
return self.ROI(x, rois)
return self.RoI(x, rois)
def test_roi_module():
......@@ -103,7 +103,7 @@ def test_roi_module():
outh, outw = PH, PW
# pytorch version
module = ROIPooling2D(outh, outw, spatial_scale)
module = RoIPooling2D(outh, outw, spatial_scale)
x = t.autograd.Variable(bottom_data, requires_grad=True)
rois = t.autograd.Variable(bottom_rois)
output = module(x, rois)
......
......@@ -137,7 +137,6 @@ kernel_backward = '''
int index_ = ph * pooled_width + pw + offset;
if (argmax_data[index_] == (h * width + w)) {
gradient += top_diff[index_];
//printf("%d-%f ",index_, top_diff[index_]);
}
}
}
......
......@@ -4,15 +4,15 @@ import ipdb
import matplotlib
from tqdm import tqdm
from config import opt
from utils.config import opt
from data.dataset import Dataset, TestDataset, inverse_normalize
from model import FasterRCNNVGG16
from torch.autograd import Variable
from torch.utils import data as data_
from trainer import FasterRCNNTrainer
from util import array_tool as at
from util.vis_tool import visdom_bbox
from util.eval_tool import eval_detection_voc
from utils import array_tool as at
from utils.vis_tool import visdom_bbox
from utils.eval_tool import eval_detection_voc
# fix for ulimit
# https://github.com/pytorch/pytorch/issues/973#issuecomment-346405667
......
......@@ -6,10 +6,10 @@ from model.utils.creator_tool import AnchorTargetCreator, ProposalTargetCreator
from torch import nn
import torch as t
from torch.autograd import Variable
from util import array_tool as at
from util.vis_tool import Visualizer
from utils import array_tool as at
from utils.vis_tool import Visualizer
from config import opt
from utils.config import opt
from torchnet.meter import ConfusionMeter, AverageValueMeter
LossTuple = namedtuple('LossTuple',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册