From 30f3cff6e42aebbecba4049e56181e95f0d15e53 Mon Sep 17 00:00:00 2001
From: Yizhuang Zhou <62599194+zhouyizhuang-megvii@users.noreply.github.com>
Date: Mon, 8 Jun 2020 15:30:45 +0800
Subject: [PATCH] fix(segmentation): fix VOC category and add cityscapes (#18)

---
 official/vision/segmentation/README.md        |  27 ++-
 .../vision/segmentation/cfg_cityscapes.py     |  42 ++++
 official/vision/segmentation/cfg_voc.py       |  43 ++++
 official/vision/segmentation/inference.py     |  30 ++-
 official/vision/segmentation/test.py          | 222 +++++++++---------
 official/vision/segmentation/train.py         |  77 +++---
 official/vision/segmentation/utils.py         |  10 +
 7 files changed, 280 insertions(+), 171 deletions(-)
 create mode 100644 official/vision/segmentation/cfg_cityscapes.py
 create mode 100644 official/vision/segmentation/cfg_voc.py
 create mode 100644 official/vision/segmentation/utils.py

diff --git a/official/vision/segmentation/README.md b/official/vision/segmentation/README.md
index d9b9f42..cffaf6c 100644
--- a/official/vision/segmentation/README.md
+++ b/official/vision/segmentation/README.md
@@ -1,6 +1,6 @@
 # Semantic Segmentation
 
-This directory contains a MegEngine implementation of the classic [Deeplabv3plus](https://arxiv.org/abs/1802.02611.pdf) architecture, together with complete training and testing code for the PASCAL VOC dataset.
+This directory contains a MegEngine implementation of the classic [Deeplabv3plus](https://arxiv.org/abs/1802.02611.pdf) architecture, together with complete training and testing code for the PASCAL VOC and Cityscapes datasets.
 
 The performance and results of the network on the PASCAL VOC2012 validation set are as follows:
 
@@ -38,20 +38,25 @@
 
 3. Start training:
 
 The command-line arguments of `train.py` are:
+- `--config`, the configuration file to use; a default config is provided for each of VOC and Cityscapes;
 - `--dataset_dir`, the directory holding the training set;
 - `--weight_file`, the pre-trained weights to start training from;
-- `--batch-size`, the batch size for training, default 8;
 - `--ngpus`, the number of GPUs for training, default 8; setting it to 1 selects single-GPU training
-- `--resume`, whether to resume training from a trained model;
-- `--train_epochs`, the number of epochs to train;
+- `--resume`, whether to resume training from a trained model, default `None`;
 
 ```bash
-python3 train.py --dataset_dir /path/to/VOC2012 \
+python3 train.py --config cfg_voc.py \
+                 --dataset_dir /path/to/VOC2012 \
                  --weight_file /path/to/weights.pkl \
-                 --batch_size 8 \
-                 --ngpus 8 \
-                 --train_epochs 50 \
-                 --resume /path/to/model
+                 --ngpus 8
+```
+
+Or, to train on the Cityscapes dataset:
+```bash
+python3 train.py --config cfg_cityscapes.py \
+                 --dataset_dir /path/to/Cityscapes \
+                 --weight_file /path/to/weights.pkl \
+                 --ngpus 8
 ```
 
 ## How to Test
 
@@ -59,11 +64,13 @@
 Once the model is trained, its performance on the VOC2012 validation set can be tested with the following command:
 
 ```bash
-python3 test.py --dataset_dir /path/to/VOC2012 \
+python3 test.py --config cfg_voc.py \
+                --dataset_dir /path/to/VOC2012 \
                 --model_path /path/to/model.pkl
 ```
 
 The command-line arguments of `test.py` are:
+- `--config`, the configuration file to use; a default config is provided for each of VOC and Cityscapes;
 - `--dataset_dir`, the directory of the validation set;
 - `--model_path`, the trained model to load;

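The new `--config` flag points at a plain Python file that exposes a module-level `cfg` object; `utils.import_config_from_file` (added at the end of this patch) loads it dynamically. A minimal sketch of the pattern, assuming the config files above sit in the working directory:

```python
import importlib.util

def load_cfg(path):
    # Execute the file at `path` as an anonymous module and return its `cfg`.
    spec = importlib.util.spec_from_file_location("config", path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module.cfg

cfg = load_cfg("cfg_voc.py")  # the same object train.py and test.py will see
print(cfg.DATASET, cfg.BATCH_SIZE, cfg.NUM_CLASSES)  # VOC2012 8 21
```
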
diff --git a/official/vision/segmentation/cfg_cityscapes.py b/official/vision/segmentation/cfg_cityscapes.py
new file mode 100644
index 0000000..7b54c43
--- /dev/null
+++ b/official/vision/segmentation/cfg_cityscapes.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import os
+
+
+class Config:
+    DATASET = "Cityscapes"
+
+    BATCH_SIZE = 4
+    LEARNING_RATE = 0.0065
+    EPOCHS = 200
+
+    ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__)))
+    MODEL_SAVE_DIR = os.path.join(ROOT_DIR, "log")
+    LOG_DIR = MODEL_SAVE_DIR
+    if not os.path.isdir(MODEL_SAVE_DIR):
+        os.makedirs(MODEL_SAVE_DIR)
+
+    DATA_WORKERS = 4
+
+    IGNORE_INDEX = 255
+    NUM_CLASSES = 19
+    IMG_HEIGHT = 800
+    IMG_WIDTH = 800
+    IMG_MEAN = [103.530, 116.280, 123.675]
+    IMG_STD = [57.375, 57.120, 58.395]
+
+    VAL_HEIGHT = 800
+    VAL_WIDTH = 800
+    VAL_BATCHES = 1
+    VAL_MULTISCALE = [1.0]  # [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+    VAL_FLIP = False
+    VAL_SLIP = True
+    VAL_SAVE = None
+
+
+cfg = Config()
diff --git a/official/vision/segmentation/cfg_voc.py b/official/vision/segmentation/cfg_voc.py
new file mode 100644
index 0000000..0c010da
--- /dev/null
+++ b/official/vision/segmentation/cfg_voc.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import os
+
+
+class Config:
+    DATASET = "VOC2012"
+
+    BATCH_SIZE = 8
+    LEARNING_RATE = 0.002
+    EPOCHS = 100
+
+    ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__)))
+    MODEL_SAVE_DIR = os.path.join(ROOT_DIR, "log")
+    LOG_DIR = MODEL_SAVE_DIR
+    if not os.path.isdir(MODEL_SAVE_DIR):
+        os.makedirs(MODEL_SAVE_DIR)
+
+    DATA_WORKERS = 4
+    DATA_TYPE = "trainaug"
+
+    IGNORE_INDEX = 255
+    NUM_CLASSES = 21
+    IMG_HEIGHT = 512
+    IMG_WIDTH = 512
+    IMG_MEAN = [103.530, 116.280, 123.675]
+    IMG_STD = [57.375, 57.120, 58.395]
+
+    VAL_HEIGHT = 512
+    VAL_WIDTH = 512
+    VAL_BATCHES = 1
+    VAL_MULTISCALE = [1.0]  # [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
+    VAL_FLIP = False
+    VAL_SLIP = False
+    VAL_SAVE = None
+
+
+cfg = Config()
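Both config files deliberately expose the same attribute surface, so `train.py` and `test.py` stay dataset-agnostic and only branch on `cfg.DATASET`. A hypothetical guard (not part of the patch) that fails fast if a custom config misses a field the training path reads:

```python
REQUIRED_FIELDS = (
    "DATASET", "BATCH_SIZE", "LEARNING_RATE", "EPOCHS", "DATA_WORKERS",
    "IGNORE_INDEX", "NUM_CLASSES", "IMG_HEIGHT", "IMG_WIDTH", "IMG_MEAN", "IMG_STD",
)

def validate_cfg(cfg):
    # Report every missing field at once instead of an AttributeError mid-training.
    missing = [name for name in REQUIRED_FIELDS if not hasattr(cfg, name)]
    if missing:
        raise ValueError("config is missing fields: " + ", ".join(missing))
```
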
diff --git a/official/vision/segmentation/inference.py b/official/vision/segmentation/inference.py
index 0405059..1e1c983 100644
--- a/official/vision/segmentation/inference.py
+++ b/official/vision/segmentation/inference.py
@@ -27,11 +27,35 @@ class Config:
 cfg = Config()
 
+# pre-defined colors for at most 20 categories
+class_colors = [
+    [0, 0, 0],  # background
+    [0, 0, 128],
+    [0, 128, 0],
+    [0, 128, 128],
+    [128, 0, 0],
+    [128, 0, 128],
+    [128, 128, 0],
+    [128, 128, 128],
+    [0, 0, 64],
+    [0, 0, 192],
+    [0, 128, 64],
+    [0, 128, 192],
+    [128, 0, 64],
+    [128, 0, 192],
+    [128, 128, 64],
+    [128, 128, 192],
+    [0, 64, 0],
+    [0, 64, 128],
+    [0, 192, 0],
+    [0, 192, 128],
+    [128, 64, 0],
+]
 
 def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--image_path", type=str, default=None, help="inference image")
-    parser.add_argument("--model_path", type=str, default=None, help="inference model")
+    parser.add_argument("-i", "--image_path", type=str, default=None, help="inference image")
+    parser.add_argument("-m", "--model_path", type=str, default=None, help="inference model")
     args = parser.parse_args()
 
     net = load_model(args.model_path)
@@ -43,7 +67,6 @@ def main():
     pred = inference(img, net)
     cv2.imwrite("out.jpg", pred)
 
-
 def load_model(model_path):
     model_dict = mge.load(model_path)
     net = DeepLabV3Plus(class_num=cfg.NUM_CLASSES)
@@ -73,7 +96,6 @@ def inference(img, net):
         pred.astype("uint8"), (oriw, orih), interpolation=cv2.INTER_NEAREST
     )
 
-    class_colors = dataset.PascalVOC.class_colors
     out = np.zeros((orih, oriw, 3))
     nids = np.unique(pred)
     for t in nids:
diff --git a/official/vision/segmentation/test.py b/official/vision/segmentation/test.py
index 88b05b8..0fa8bed 100644
--- a/official/vision/segmentation/test.py
+++ b/official/vision/segmentation/test.py
@@ -20,28 +20,14 @@ import numpy as np
 from tqdm import tqdm
 
 from official.vision.segmentation.deeplabv3plus import DeepLabV3Plus
-
-
-class Config:
-    DATA_WORKERS = 4
-
-    NUM_CLASSES = 21
-    IMG_SIZE = 512
-    IMG_MEAN = [103.530, 116.280, 123.675]
-    IMG_STD = [57.375, 57.120, 58.395]
-
-    VAL_BATCHES = 1
-    VAL_MULTISCALE = [1.0]  # [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
-    VAL_FLIP = False
-    VAL_SLIP = False
-    VAL_SAVE = None
-
-
-cfg = Config()
+from official.vision.segmentation.utils import import_config_from_file
 
 
 def main():
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-c", "--config", type=str, required=True, help="configuration file"
+    )
     parser.add_argument(
         "-d", "--dataset_dir", type=str, default="/data/datasets/VOC2012",
     )
@@ -50,7 +36,9 @@ def main():
     )
     args = parser.parse_args()
 
-    test_loader, test_size = build_dataloader(args.dataset_dir)
+    cfg = import_config_from_file(args.config)
+
+    test_loader, test_size = build_dataloader(args.dataset_dir, cfg)
     print("number of test images: %d" % (test_size))
     net = DeepLabV3Plus(class_num=cfg.NUM_CLASSES)
     model_dict = mge.load(args.model_path)
@@ -63,13 +51,15 @@ def main():
     for sample_batched in tqdm(test_loader):
         img = sample_batched[0].squeeze()
         label = sample_batched[1].squeeze()
-        pred = evaluate(net, img)
-        result_list.append({"pred": pred, "gt": label})
+        im_info = sample_batched[2]
+        pred = evaluate(net, img, cfg)
+        result_list.append({"pred": pred, "gt": label, "name": im_info[2]})
     if cfg.VAL_SAVE:
-        save_results(result_list, cfg.VAL_SAVE)
-    compute_metric(result_list)
+        save_results(result_list, cfg.VAL_SAVE, cfg)
+    compute_metric(result_list, cfg)
 
 
+# helpers for single-image inference: padding, one-window eval, multi-scale eval
 def pad_image_to_shape(img, shape, border_mode, value):
     margin = np.zeros(4, np.uint32)
     pad_height = shape[0] - img.shape[0] if shape[0] - img.shape[0] > 0 else 0
@@ -86,40 +76,39 @@ def pad_image_to_shape(img, shape, border_mode, value):
 
 def eval_single(net, img, is_flip):
     @jit.trace(symbolic=True, opt_level=2)
-    def pred_fun(input_data, net=None):
+    def pred_fun(data, net=None):
         net.eval()
-        pred = net(input_data)
+        pred = net(data)
         return pred
 
-    input_data = mge.tensor()
-    input_data.set_value(img.transpose(2, 0, 1)[np.newaxis])
-    pred = pred_fun(input_data, net=net)
+    data = mge.tensor()
+    data.set_value(img.transpose(2, 0, 1)[np.newaxis])
+    pred = pred_fun(data, net=net)
     if is_flip:
         img_flip = img[:, ::-1, :]
-        input_data.set_value(img_flip.transpose(2, 0, 1)[np.newaxis])
-        pred_flip = pred_fun(input_data, net=net)
+        data.set_value(img_flip.transpose(2, 0, 1)[np.newaxis])
+        pred_flip = pred_fun(data, net=net)
         pred = (pred + pred_flip[:, :, :, ::-1]) / 2.0
         del pred_flip
     pred = pred.numpy().squeeze().transpose(1, 2, 0)
-    del input_data
+    del data
    return pred
 
 
-def evaluate(net, img):
+def evaluate(net, img, cfg):
     ori_h, ori_w, _ = img.shape
     pred_all = np.zeros((ori_h, ori_w, cfg.NUM_CLASSES))
     for rate in cfg.VAL_MULTISCALE:
         if cfg.VAL_SLIP:
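`inference.py` now paints each predicted class id with its own fixed `class_colors` palette (BGR order, matching `cv2.imwrite`) instead of reaching into `dataset.PascalVOC`. The per-class loop it uses can also be expressed as a single vectorized lookup; a sketch, assuming `pred` is an `(H, W)` integer label map:

```python
import numpy as np

def colorize(pred, class_colors):
    # Build an (num_classes, 3) palette, then index it with the label map
    # in one fancy-indexing step.
    palette = np.asarray(class_colors, dtype=np.uint8)
    return palette[pred]  # (H, W) labels -> (H, W, 3) BGR image
```
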
-            img_scale = cv2.resize(
-                img, None, fx=rate, fy=rate, interpolation=cv2.INTER_LINEAR
-            )
-            val_size = (cfg.IMG_SIZE, cfg.IMG_SIZE)
+            new_h, new_w = int(ori_h * rate), int(ori_w * rate)
+            val_size = (cfg.VAL_HEIGHT, cfg.VAL_WIDTH)
         else:
-            out_h, out_w = int(cfg.IMG_SIZE * rate), int(cfg.IMG_SIZE * rate)
-            img_scale = cv2.resize(img, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
-            val_size = (out_h, out_w)
+            new_h, new_w = int(cfg.VAL_HEIGHT * rate), int(cfg.VAL_WIDTH * rate)
+            val_size = (new_h, new_w)
+        img_scale = cv2.resize(
+            img, (new_w, new_h), interpolation=cv2.INTER_LINEAR
+        )
 
-        new_h = img_scale.shape[0]
         if (new_h <= val_size[0]) and (new_h <= val_size[1]):
             img_pad, margin = pad_image_to_shape(
                 img_scale, val_size, cv2.BORDER_CONSTANT, value=0
             )
@@ -133,7 +122,6 @@ def evaluate(net, img):
         else:
             stride_rate = 2 / 3
             stride = [int(np.ceil(i * stride_rate)) for i in val_size]
-            print(img_scale.shape, stride, val_size)
             img_pad, margin = pad_image_to_shape(
                 img_scale, val_size, cv2.BORDER_CONSTANT, value=0
             )
@@ -154,19 +142,10 @@ def evaluate(net, img):
                     s_x = e_x - val_size[1]
                     s_y = e_y - val_size[0]
                     img_sub = img_pad[s_y:e_y, s_x:e_x, :]
-                    timg_pad, tmargin = pad_image_to_shape(
-                        img_sub, val_size, cv2.BORDER_CONSTANT, value=0
-                    )
-                    print(tmargin, timg_pad.shape)
-                    tpred = eval_single(net, timg_pad, cfg.VAL_FLIP)
-                    tpred = tpred[
-                        margin[0] : (tpred.shape[0] - margin[1]),
-                        margin[2] : (tpred.shape[1] - margin[3]),
-                        :,
-                    ]
+                    tpred = eval_single(net, img_sub, cfg.VAL_FLIP)
                     count_scale[s_y:e_y, s_x:e_x, :] += 1
                     pred_scale[s_y:e_y, s_x:e_x, :] += tpred
-            pred_scale = pred_scale / count_scale
+            # pred_scale = pred_scale / count_scale
             pred = pred_scale[
                 margin[0] : (pred_scale.shape[0] - margin[1]),
                 margin[2] : (pred_scale.shape[1] - margin[3]),
                 :,
             ]
@@ -176,77 +155,98 @@ def evaluate(net, img):
         pred = cv2.resize(pred, (ori_w, ori_h), interpolation=cv2.INTER_LINEAR)
         pred_all = pred_all + pred
 
-    pred_all = pred_all / len(cfg.VAL_MULTISCALE)
+    # pred_all = pred_all / len(cfg.VAL_MULTISCALE)
     result = np.argmax(pred_all, axis=2).astype(np.uint8)
     return result
 
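When the (multi-scale-resized) image is larger than the evaluation window, `evaluate` falls back to sliding-window inference: windows of `val_size` are placed every `stride = ceil(val_size * 2/3)` pixels, the last window is clamped to the image border, and overlapping logits are summed. A simplified sketch of that grid logic, with illustrative numbers:

```python
import numpy as np

def window_starts(length, win, stride):
    # Left edges of the windows covering `length`; the last one is clamped
    # so it ends exactly at the border, mirroring s_x = e_x - val_size above.
    n = int(np.ceil(max(length - win, 0) / stride)) + 1
    return [min(i * stride, length - win) for i in range(n)]

# e.g. a 1100-px-wide crop evaluated with an 800-px window and 2/3 stride:
print(window_starts(1100, 800, int(np.ceil(800 * 2 / 3))))  # [0, 300]
```
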
("backbound", IoU[i] * 100), end="\t") +# voc cityscapes metric +def compute_metric(result_list, cfg): + class_num = cfg.NUM_CLASSES + hist = np.zeros((class_num, class_num)) + correct = 0 + labeled = 0 + count = 0 + for idx in range(len(result_list)): + pred = result_list[idx]['pred'] + gt = result_list[idx]['gt'] + assert(pred.shape == gt.shape) + k = (gt>=0) & (gt 0] * freq[freq >0]).sum() + mean_pixel_acc = correct / labeled + + if cfg.DATASET == "VOC2012": + class_names = ("background", ) + dataset.PascalVOC.class_names + elif cfg.DATASET == "Cityscapes": + class_names = dataset.Cityscapes.class_names + else: + raise ValueError("Unsupported dataset {}".format(cfg.DATASET)) + + n = iu.size + lines = [] + for i in range(n): + if class_names is None: + cls = 'Class %d:' % (i+1) else: - if i % 2 != 1: - print("%11s:%7.3f%%" % (class_names[i - 1], IoU[i] * 100), end="\t") - else: - print("%11s:%7.3f%%" % (class_names[i - 1], IoU[i] * 100)) - miou = np.mean(np.array(IoU)) - print("\n======================================================") - print("%11s:%7.3f%%" % ("mIoU", miou * 100)) - return miou - + cls = '%d %s' % (i+1, class_names[i]) + lines.append('%-8s\t%.3f%%' % (cls, iu[i] * 100)) + lines.append('---------------------------- %-8s\t%.3f%%\t%-8s\t%.3f%%' % ('mean_IU', mean_IU * 100,'mean_pixel_ACC',mean_pixel_acc*100)) + line = "\n".join(lines) + print(line) + return mean_IU + + +class EvalPascalVOC(dataset.PascalVOC): + def _trans_mask(self, mask): + label = np.ones(mask.shape[:2]) * 255 + class_colors = self.class_colors.copy() + class_colors.insert(0, [0,0,0]) + for i in range(len(class_colors)): + b, g, r = class_colors[i] + label[ + (mask[:, :, 0] == b) & (mask[:, :, 1] == g) & (mask[:, :, 2] == r) + ] = i + return label.astype(np.uint8) + +def build_dataloader(dataset_dir, cfg): + if cfg.DATASET == "VOC2012": + val_dataset = EvalPascalVOC( + dataset_dir, + "val", + order=["image", "mask", "info"] + ) + elif cfg.DATASET == "Cityscapes": + val_dataset = dataset.Cityscapes( + dataset_dir, + "val", + mode='gtFine', + order=["image", "mask", "info"] + ) + else: + raise ValueError("Unsupported dataset {}".format(cfg.DATASET)) -def build_dataloader(dataset_dir): - val_dataset = dataset.PascalVOC(dataset_dir, "val", order=["image", "mask"]) val_sampler = data.SequentialSampler(val_dataset, cfg.VAL_BATCHES) val_dataloader = data.DataLoader( val_dataset, diff --git a/official/vision/segmentation/train.py b/official/vision/segmentation/train.py index f193e85..d4dff1f 100644 --- a/official/vision/segmentation/train.py +++ b/official/vision/segmentation/train.py @@ -23,32 +23,16 @@ from official.vision.segmentation.deeplabv3plus import ( DeepLabV3Plus, softmax_cross_entropy, ) +from official.vision.segmentation.utils import import_config_from_file logger = mge.get_logger(__name__) -class Config: - ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname("__file__"))) - MODEL_SAVE_DIR = os.path.join(ROOT_DIR, "log") - LOG_DIR = MODEL_SAVE_DIR - if not os.path.isdir(MODEL_SAVE_DIR): - os.makedirs(MODEL_SAVE_DIR) - - DATA_WORKERS = 4 - DATA_TYPE = "trainaug" - - IGNORE_INDEX = 255 - NUM_CLASSES = 21 - IMG_SIZE = 512 - IMG_MEAN = [103.530, 116.280, 123.675] - IMG_STD = [57.375, 57.120, 58.395] - - -cfg = Config() - - def main(): parser = argparse.ArgumentParser() + parser.add_argument( + "-c", "--config", type=str, required=True, help="configuration file" + ) parser.add_argument( "-d", "--dataset_dir", type=str, default="/data/datasets/VOC2012", ) @@ -58,19 +42,6 @@ def main(): 
diff --git a/official/vision/segmentation/train.py b/official/vision/segmentation/train.py
index f193e85..d4dff1f 100644
--- a/official/vision/segmentation/train.py
+++ b/official/vision/segmentation/train.py
@@ -23,32 +23,16 @@ from official.vision.segmentation.deeplabv3plus import (
     DeepLabV3Plus,
     softmax_cross_entropy,
 )
+from official.vision.segmentation.utils import import_config_from_file
 
 logger = mge.get_logger(__name__)
 
 
-class Config:
-    ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname("__file__")))
-    MODEL_SAVE_DIR = os.path.join(ROOT_DIR, "log")
-    LOG_DIR = MODEL_SAVE_DIR
-    if not os.path.isdir(MODEL_SAVE_DIR):
-        os.makedirs(MODEL_SAVE_DIR)
-
-    DATA_WORKERS = 4
-    DATA_TYPE = "trainaug"
-
-    IGNORE_INDEX = 255
-    NUM_CLASSES = 21
-    IMG_SIZE = 512
-    IMG_MEAN = [103.530, 116.280, 123.675]
-    IMG_STD = [57.375, 57.120, 58.395]
-
-
-cfg = Config()
-
-
 def main():
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-c", "--config", type=str, required=True, help="configuration file"
+    )
     parser.add_argument(
         "-d", "--dataset_dir", type=str, default="/data/datasets/VOC2012",
     )
@@ -58,19 +42,6 @@ def main():
     parser.add_argument(
         "-n", "--ngpus", type=int, default=8, help="batchsize for training"
     )
-    parser.add_argument(
-        "-b", "--batch_size", type=int, default=8, help="batchsize for training"
-    )
-    parser.add_argument(
-        "-lr",
-        "--base_lr",
-        type=float,
-        default=0.002,
-        help="base learning rate for training",
-    )
-    parser.add_argument(
-        "-e", "--train_epochs", type=int, default=100, help="epochs for training"
-    )
     parser.add_argument(
         "-r", "--resume", type=str, default=None, help="resume model file"
     )
@@ -92,6 +63,8 @@ def main():
 
 
 def worker(rank, world_size, args):
+    cfg = import_config_from_file(args.config)
+
     if world_size > 1:
         dist.init_process_group(
             master_ip="localhost",
@@ -103,11 +76,11 @@ def worker(rank, world_size, args):
     logger.info("Init process group done")
 
     logger.info("Prepare dataset")
-    train_loader, epoch_size = build_dataloader(args.batch_size, args.dataset_dir)
-    batch_iter = epoch_size // (args.batch_size * world_size)
+    train_loader, epoch_size = build_dataloader(cfg.BATCH_SIZE, args.dataset_dir, cfg)
+    batch_iter = epoch_size // (cfg.BATCH_SIZE * world_size)
 
     net = DeepLabV3Plus(class_num=cfg.NUM_CLASSES, pretrained=args.weight_file)
-    base_lr = args.base_lr * world_size
+    base_lr = cfg.LEARNING_RATE * world_size
     optimizer = optim.SGD(
         net.parameters(requires_grad=True),
         lr=base_lr,
@@ -116,15 +89,15 @@ def worker(rank, world_size, args):
     )
 
     @jit.trace(symbolic=True, opt_level=2)
-    def train_func(input_data, label, net=None, optimizer=None):
+    def train_func(data, label, net=None, optimizer=None):
         net.train()
-        pred = net(input_data)
+        pred = net(data)
         loss = softmax_cross_entropy(pred, label, ignore_index=cfg.IGNORE_INDEX)
         optimizer.backward(loss)
         return pred, loss
 
     begin_epoch = 0
-    end_epoch = args.train_epochs
+    end_epoch = cfg.EPOCHS
     if args.resume is not None:
         pretrained = mge.load(args.resume)
         begin_epoch = pretrained["epoch"] + 1
@@ -135,11 +108,11 @@ def worker(rank, world_size, args):
     max_itr = end_epoch * batch_iter
 
     image = mge.tensor(
-        np.zeros([args.batch_size, 3, cfg.IMG_SIZE, cfg.IMG_SIZE]).astype(np.float32),
+        np.zeros([cfg.BATCH_SIZE, 3, cfg.IMG_HEIGHT, cfg.IMG_WIDTH]).astype(np.float32),
         dtype="float32",
     )
     label = mge.tensor(
-        np.zeros([args.batch_size, cfg.IMG_SIZE, cfg.IMG_SIZE]).astype(np.int32),
+        np.zeros([cfg.BATCH_SIZE, cfg.IMG_HEIGHT, cfg.IMG_WIDTH]).astype(np.int32),
         dtype="int32",
     )
     exp_name = os.path.abspath(os.path.dirname(__file__)).split("/")[-1]
@@ -184,10 +157,22 @@ def worker(rank, world_size, args):
             logger.info("save epoch%d", epoch)
 
 
-def build_dataloader(batch_size, dataset_dir):
-    train_dataset = dataset.PascalVOC(
-        dataset_dir, cfg.DATA_TYPE, order=["image", "mask"]
-    )
+def build_dataloader(batch_size, dataset_dir, cfg):
+    if cfg.DATASET == "VOC2012":
+        train_dataset = dataset.PascalVOC(
+            dataset_dir,
+            cfg.DATA_TYPE,
+            order=["image", "mask"]
+        )
+    elif cfg.DATASET == "Cityscapes":
+        train_dataset = dataset.Cityscapes(
+            dataset_dir,
+            "train",
+            mode='gtFine',
+            order=["image", "mask"]
+        )
+    else:
+        raise ValueError("Unsupported dataset {}".format(cfg.DATASET))
     train_sampler = data.RandomSampler(train_dataset, batch_size, drop_last=True)
     train_dataloader = data.DataLoader(
         train_dataset,
@@ -197,7 +182,7 @@ def build_dataloader(batch_size, dataset_dir):
             T.RandomHorizontalFlip(0.5),
             T.RandomResize(scale_range=(0.5, 2)),
             T.RandomCrop(
-                output_size=(cfg.IMG_SIZE, cfg.IMG_SIZE),
+                output_size=(cfg.IMG_HEIGHT, cfg.IMG_WIDTH),
                 padding_value=[0, 0, 0],
                 padding_maskvalue=255,
             ),
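Two scaling rules in `worker` are worth spelling out: the base learning rate grows linearly with the number of workers, and the per-epoch iteration count shrinks with the global batch size. A worked example under the VOC config (`BATCH_SIZE = 8`, `LEARNING_RATE = 0.002`) on 8 GPUs, assuming the commonly used trainaug split of 10582 images:

```python
world_size = 8          # one worker per GPU
batch_size = 8          # per-worker batch size from cfg_voc.py
epoch_size = 10582      # assumption: size of the VOC trainaug split

base_lr = 0.002 * world_size                             # 0.016, linear LR scaling
batch_iter = epoch_size // (batch_size * world_size)     # 165 iterations per epoch
```
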
a/official/vision/segmentation/utils.py b/official/vision/segmentation/utils.py
new file mode 100644
index 0000000..8da4fc0
--- /dev/null
+++ b/official/vision/segmentation/utils.py
@@ -0,0 +1,10 @@
+import importlib.util
+import os
+
+
+def import_config_from_file(cfg_file):
+    assert os.path.exists(cfg_file), "config file {} does not exist".format(cfg_file)
+    spec = importlib.util.spec_from_file_location("config", cfg_file)
+    cfg_module = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(cfg_module)
+    return cfg_module.cfg
-- 
GitLab