From b790cc16533e66fa415c77c95ff74f85cdb4badc Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Tue, 8 Sep 2020 17:10:37 +0800 Subject: [PATCH] fix image gradient height 1 --- mindspore/nn/layer/image.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py index af7e729b9..e807c0cf0 100644 --- a/mindspore/nn/layer/image.py +++ b/mindspore/nn/layer/image.py @@ -66,13 +66,19 @@ class ImageGradients(Cell): check = _check_input_4d(F.shape(images), "images", self.cls_name) images = F.depend(images, check) batch_size, depth, height, width = P.Shape()(images) - dy = images[:, :, 1:, :] - images[:, :, :height - 1, :] - dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0) - dy = P.Concat(2)((dy, dy_last)) - - dx = images[:, :, :, 1:] - images[:, :, :, :width - 1] - dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0) - dx = P.Concat(3)((dx, dx_last)) + if height == 1: + dy = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0) + else: + dy = images[:, :, 1:, :] - images[:, :, :height - 1, :] + dy_last = P.Fill()(P.DType()(images), (batch_size, depth, 1, width), 0) + dy = P.Concat(2)((dy, dy_last)) + + if width == 1: + dx = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0) + else: + dx = images[:, :, :, 1:] - images[:, :, :, :width - 1] + dx_last = P.Fill()(P.DType()(images), (batch_size, depth, height, 1), 0) + dx = P.Concat(3)((dx, dx_last)) return dy, dx -- GitLab