Commit 1fa143ce authored by Megvii Engine Team

refactor(mge/functional): matmul supports symbolic shape, batched mv multiply

GitOrigin-RevId: c4d8cf3306cd833828eca0fc7372397cbf2cc36f
Parent d47cf332
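
The commit title names two user-visible changes: matmul keeps its shape arithmetic symbolic (so it works under tracing) and accepts batched matrix-vector operands. A minimal usage sketch of the second point (shapes are illustrative, assuming a standard `megengine` install):

```python
import numpy as np
import megengine.functional as F
from megengine import Tensor

# batched matrix-vector multiply: the 1-d operand is treated as a column,
# multiplied against every matrix in the batch, then the extra axis is squeezed
m = Tensor(np.random.random((10, 3, 4)).astype("float32"))
v = Tensor(np.random.random((4,)).astype("float32"))
out = F.matmul(m, v)  # shape: (10, 3)
ref = np.matmul(m.numpy(), v.numpy())
np.testing.assert_allclose(out.numpy(), ref, atol=1e-5)
```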
@@ -23,7 +23,7 @@ from ..tensor import Tensor
 from .debug_param import get_conv_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
-from .math import argsort, max, sum
+from .math import argsort, max, prod, sum
 from .tensor import (
     broadcast_to,
     concat,
@@ -972,38 +972,42 @@ def matmul(
          [28. 40.]]
     """
+    remove_row, remove_col = False, False
     inp1, inp2 = utils.convert_inputs(inp1, inp2)
     dim1, dim2 = inp1.ndim, inp2.ndim
+    # handle dim=1 cases, dot and matrix-vector multiplication
     if dim1 == 1 and dim2 == 1:
         return dot(inp1, inp2)
+    # the underlying matmul op requires input dims to be at least 2
+    if dim1 == 1:
+        inp1 = expand_dims(inp1, 0)
+        dim1 = 2
+        remove_row = True
+    if dim2 == 1:
+        inp2 = expand_dims(inp2, 1)
+        dim2 = 2
+        remove_col = True
+    batch_shape = None
+    shape1 = astensor1d(inp1.shape, inp1, dtype="int32", device=inp1.device)
+    shape2 = astensor1d(inp2.shape, inp2, dtype="int32", device=inp2.device)
+    if dim1 >= 3 or dim2 >= 3:
+        if dim1 == dim2:
+            assert (
+                shape1[:-2] == shape2[:-2]
+            ).min(), "operands could not be broadcasted together."
+        if dim1 > dim2:
+            shape2 = concat([shape1[:-2], shape2[-2:]])
+            inp2 = broadcast_to(inp2, shape2)
+        if dim1 < dim2:
+            shape1 = concat([shape2[:-2], shape1[-2:]])
+            inp1 = broadcast_to(inp1, shape1)
+        batch_shape = shape1[:-2]
+        # compress inputs to 3d
+        inp1 = inp1.reshape(concat([prod(shape1[:-2]), shape1[-2:]]))
+        inp2 = inp2.reshape(concat([prod(shape2[:-2]), shape2[-2:]]))
-    shp = None
-    if dim1 > 3 or dim2 > 3:
-        shape1, shape2 = list(inp1.shape), list(inp2.shape)
-        if dim1 != dim2:
-            if dim1 < dim2:
-                shape1 = shape2[: dim2 - dim1] + shape1
-                inp1 = broadcast_to(inp1, shape1)
-            else:
-                shape2 = shape1[: dim1 - dim2] + shape2
-                inp2 = broadcast_to(inp2, shape2)
-        reshaped_batch_size = 1
-        for i in shape1[:-2]:
-            reshaped_batch_size *= i
-        inp1 = inp1.reshape(*([reshaped_batch_size] + shape1[-2:]))
-        inp2 = inp2.reshape(*([reshaped_batch_size] + shape2[-2:]))
         op = builtin.BatchedMatrixMul(
             transposeA=transpose_a,
             transposeB=transpose_b,
             compute_mode=compute_mode,
             format=format,
         )
-        shp = shape1[:-1] + shape2[-1:]
-    elif dim1 == 3 or dim2 == 3:
-        if dim2 < 3:
-            inp2 = broadcast_to(inp2, inp1.shape[:1] + inp2.shape)
-        elif dim1 < 3:
-            inp1 = broadcast_to(inp1, inp2.shape[:1] + inp1.shape)
-        op = builtin.BatchedMatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
@@ -1011,12 +1015,6 @@ def matmul(
             format=format,
         )
     else:
-        if dim1 == 1:
-            shp = (inp2.shape[1],)
-            inp1 = expand_dims(inp1, 0)
-        if dim2 == 1:
-            shp = (inp1.shape[0],)
-            inp2 = expand_dims(inp2, 1)
         op = builtin.MatrixMul(
             transposeA=transpose_a,
             transposeB=transpose_b,
@@ -1025,8 +1023,12 @@
         )
     (result,) = apply(op, inp1, inp2)
-    if shp is not None:
-        result = result.reshape(shp)
+    if batch_shape is not None:
+        result = result.reshape(concat([batch_shape, result.shape[-2:]]))
+    if remove_row:
+        result = squeeze(result, axis=-2)
+    if remove_col:
+        result = squeeze(result, axis=-1)
     return result
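The core of the new code path above: all leading batch dimensions are folded into a single axis via `prod(shape[:-2])`, the operands run through one `BatchedMatrixMul`, and the batch shape is restored afterwards. A NumPy sketch of that round-trip (`folded_batch_matmul` is an illustrative name, not MegEngine API; equal batch shapes are assumed):

```python
import numpy as np

def folded_batch_matmul(a, b):
    # fold every leading batch axis into one, as in the diff's
    # "compress inputs to 3d" step, then restore the batch shape
    batch_shape = a.shape[:-2]
    a3 = a.reshape((int(np.prod(batch_shape)),) + a.shape[-2:])
    b3 = b.reshape((int(np.prod(batch_shape)),) + b.shape[-2:])
    out = np.matmul(a3, b3)  # stands in for builtin.BatchedMatrixMul
    return out.reshape(batch_shape + out.shape[-2:])

a = np.random.random((10, 10, 4, 5)).astype("float32")
b = np.random.random((10, 10, 5, 2)).astype("float32")
assert np.allclose(folded_batch_matmul(a, b), np.matmul(a, b), atol=1e-5)
```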
@@ -77,26 +77,43 @@ def test_matmul():
     opr_test(cases, F.matmul, ref_fn=np.matmul)
     batch_size = 10
-    shape1 = (batch_size, 2, 3)
-    shape2 = (batch_size, 3, 4)
-    shape3 = (batch_size, 10, 4, 5)
+    shape1 = (2,)
+    shape2 = (batch_size, 2, 3)
+    shape3 = (batch_size, 3, 4)
+    shape4 = (batch_size, 10, 4, 2)
+    shape5 = (batch_size, 10, 2, 4)
     data1 = np.random.random(shape1).astype("float32")
     data2 = np.random.random(shape2).astype("float32")
     data3 = np.random.random(shape3).astype("float32")
+    data4 = np.random.random(shape4).astype("float32")
+    data5 = np.random.random(shape5).astype("float32")
-    cases = [{"input": [data1, data2]}, {"input": [data2, data3]}]
-    for i in range(0, batch_size):
-        def compare_fn(x, y):
-            x.numpy()[i, ...] == y
+    cases = [
+        {"input": [data1, data2]},
+        {"input": [data2, data3]},
+        {"input": [data3, data4]},
+        {"input": [data4, data5]},
+    ]
+    for _ in range(0, batch_size):
         opr_test(
-            cases,
-            F.matmul,
-            compare_fn=compare_fn,
-            ref_fn=lambda x, y: np.matmul(x[i, ...], y[i, ...]),
+            cases, F.matmul, ref_fn=np.matmul,
         )
+    opr_test(
+        [{"input": [data1, data4]}],
+        F.matmul,
+        ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)),
+        transpose_b=True,
+    )
+    opr_test(
+        [{"input": [data3, data2]}],
+        F.matmul,
+        ref_fn=lambda x, y: np.matmul(x.transpose(0, 2, 1), y.transpose(0, 2, 1)),
+        transpose_a=True,
+        transpose_b=True,
+    )
 def test_interpolate():
     def linear_interpolate():
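The "symbolic shape" half of the title shows up under tracing, where `inp.shape` is a Tensor rather than a Python tuple, which is why the shape arithmetic in the diff is expressed with tensor ops (`astensor1d`, `concat`, `prod`) instead of list math. A sketch of that scenario, assuming the `megengine.jit.trace` API of this release:

```python
import numpy as np
import megengine.functional as F
from megengine import Tensor
from megengine.jit import trace

@trace(symbolic=True)
def f(a, b):
    # in a symbolic trace, a.shape is itself a Tensor, so matmul's
    # internal reshape logic cannot rely on concrete Python integers
    return F.matmul(a, b)

a = Tensor(np.random.random((10, 2, 3, 4)).astype("float32"))
b = Tensor(np.random.random((4,)).astype("float32"))
print(f(a, b).shape)  # (10, 2, 3)
```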