未验证 提交 ea0b5f52 编写于 作者: I i-robot 提交者: Gitee

!52829 fix conv3d group

Merge pull request !52829 from chenkang/r2.0_fix_conv3d_group
......@@ -29,6 +29,10 @@ mindspore.ops.Conv3D
:math:`dilation` 为三维卷积核膨胀尺寸, :math:`stride` 为移动步长,
:math:`padding` 为在输入两侧的填充长度。
.. note::
在Ascend平台上,目前只支持 :math:`group=1` 。
参数:
- **out_channel** (int) - 输出的通道数 :math:`C_{out}` 。
- **kernel_size** (Union[int, tuple[int]]) - 指定三维卷积核的深度、高度和宽度。可以为单个int或包含三个整数的Tuple。一个整数表示卷积核的深度、高度和宽度均为该值。包含三个整数的Tuple分别表示卷积核的深度、高度和宽度。
......
......@@ -21,7 +21,7 @@ mindspore.ops.conv3d
.. note::
1. 在Ascend平台上,目前只支持深度卷积场景下的分组卷积运算。也就是说,当 `group>1` 的场景下,必须要满足 :math:`C_{in} = C_{out} = group` 的约束条件
1. 在Ascend平台上,目前只支持 :math:`groups=1`
2. 在Ascend平台上,目前只支持 :math:`dialtion=1` 。
......
......@@ -5143,8 +5143,7 @@ def conv3d(input, weight, bias=None, stride=1, pad_mode="valid", padding=0, dila
Recognition <http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_ .
Note:
1. On Ascend platform, only group convolution in depthwise convolution scenarios is supported.
That is, when `groups>1`, condition :math:`C_{in} = C_{out} = groups` must be satisfied.
1. On Ascend platform, :math:`groups = 1` must be satisfied.
2. On Ascend dilation on depth only supports the case of 1.
Args:
......
......@@ -1274,8 +1274,7 @@ class Conv2D(Primitive):
<http://vision.stanford.edu/cs598_spring07/papers/Lecun98.pdf>`_.
Note:
On Ascend platform, only group convolution in depthwise convolution scenarios is supported.
That is, when `group>1`, condition `in\_channels` = `out\_channels` = `group` must be satisfied.
On Ascend platform, :math:`group = 1` must be satisfied.
Args:
out_channel (int): The number of output channel :math:`C_{out}`.
......@@ -7775,8 +7774,8 @@ class Conv3D(Primitive):
validator.check_value_type("group", group, (int,), self.name)
validator.check_int_range(group, 1, out_channel, validator.INC_BOTH, "group", self.name)
device_target = context.get_context("device_target")
if device_target == "Ascend" and group > 1 and out_channel != group:
raise ValueError("On Ascend platform, when group > 1, condition C_in = C_out = group must be satisfied.")
if device_target == "Ascend" and group != 1:
raise ValueError("On Ascend platform, group = 1 must be satisfied.")
self.group = group
self.add_prim_attr('groups', self.group)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册