提交 2ab84e1f 编写于 作者: M Matthew Moore 提交者: Waleed

add IMAGE_CHANNEL_COUNT class variable to config to make it easier to use...

add IMAGE_CHANNEL_COUNT class variable to config to make it easier to use Mask_RCNN for non 3-channel images
上级 80cac4c0
......@@ -131,6 +131,9 @@ class Config(object):
# the width and height, or more, even if MIN_IMAGE_DIM doesn't require it.
# Howver, in 'square' mode, it can be overruled by IMAGE_MAX_DIM.
IMAGE_MIN_SCALE = 0
# IMAGE_CHANNEL_COUNT refers to number of channels per image, which can
# be overridden by a subclass. RGB = 3, grayscale = 1, RGB-D = 4
IMAGE_CHANNEL_COUNT = 3
# Image mean (RGB)
MEAN_PIXEL = np.array([123.7, 116.8, 103.9])
......@@ -213,9 +216,11 @@ class Config(object):
# Input image size
if self.IMAGE_RESIZE_MODE == "crop":
self.IMAGE_SHAPE = np.array([self.IMAGE_MIN_DIM, self.IMAGE_MIN_DIM, 3])
self.IMAGE_SHAPE = np.array([self.IMAGE_MIN_DIM, self.IMAGE_MIN_DIM,
self.__class__.IMAGE_CHANNEL_COUNT])
else:
self.IMAGE_SHAPE = np.array([self.IMAGE_MAX_DIM, self.IMAGE_MAX_DIM, 3])
self.IMAGE_SHAPE = np.array([self.IMAGE_MAX_DIM, self.IMAGE_MAX_DIM,
self.__class__.IMAGE_CHANNEL_COUNT])
# Image meta data length
# See compose_image_meta() for details
......
......@@ -70,7 +70,7 @@ class BatchNorm(KL.BatchNormalization):
def compute_backbone_shapes(config, image_shape):
"""Computes the width and height of each stage of the backbone network.
Returns:
[N, (height, width)]. Where N is the number of stages
"""
......@@ -808,7 +808,7 @@ class DetectionLayer(KE.Layer):
m = parse_image_meta_graph(image_meta)
image_shape = m['image_shape'][0]
window = norm_boxes_graph(m['window'], image_shape[:2])
# Run detection refinement graph on each item in the batch
detections_batch = utils.batch_slice(
[rois, mrcnn_class, mrcnn_bbox, window],
......@@ -1861,7 +1861,7 @@ class MaskRCNN():
# Inputs
input_image = KL.Input(
shape=[None, None, 3], name="input_image")
shape=[None, None, config.IMAGE_SHAPE[2]], name="input_image")
input_image_meta = KL.Input(shape=[config.IMAGE_META_SIZE],
name="input_image_meta")
if mode == "training":
......@@ -2046,7 +2046,7 @@ class MaskRCNN():
fc_layers_size=config.FPN_CLASSIF_FC_LAYERS_SIZE)
# Detections
# output is [batch, num_detections, (y1, x1, y2, x2, class_id, score)] in
# output is [batch, num_detections, (y1, x1, y2, x2, class_id, score)] in
# normalized coordinates
detections = DetectionLayer(config, name="mrcnn_detection")(
[rpn_rois, mrcnn_class, mrcnn_bbox, input_image_meta])
......@@ -2310,7 +2310,7 @@ class MaskRCNN():
imgaug.augmenters.GaussianBlur(sigma=(0.0, 5.0))
])
custom_callbacks: Optional. Add custom callbacks to be called
with the keras fit_generator method. Must be list of type keras.callbacks.
with the keras fit_generator method. Must be list of type keras.callbacks.
no_augmentation_sources: Optional. List of sources to exclude for
augmentation. A source is string that identifies a dataset and is
defined in the Dataset class.
......@@ -2350,7 +2350,7 @@ class MaskRCNN():
keras.callbacks.ModelCheckpoint(self.checkpoint_path,
verbose=0, save_weights_only=True),
]
# Add custom callbacks to the list
if custom_callbacks:
callbacks += custom_callbacks
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册