From 1fa143ce87636264d9c00d674aa70cb64c321acb Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 19 Oct 2020 14:16:03 +0800
Subject: [PATCH] refactor(mge/functional): matmul supports symbolic shape,
 batched mv multiply

GitOrigin-RevId: c4d8cf3306cd833828eca0fc7372397cbf2cc36f
---
 imperative/python/megengine/functional/nn.py | 74 ++++++++++---------
 .../test/unit/functional/test_functional.py  | 43 +++++++----
 2 files changed, 68 insertions(+), 49 deletions(-)

diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index 1eac73f14..b4a8547d6 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -23,7 +23,7 @@ from ..tensor import Tensor
 from .debug_param import get_conv_execution_strategy
 from .distributed import all_reduce_sum
 from .elemwise import exp, floor, log, log1p, maximum, minimum, relu
-from .math import argsort, max, sum
+from .math import argsort, max, prod, sum
 from .tensor import (
     broadcast_to,
     concat,
@@ -972,38 +972,42 @@ def matmul(
          [28. 40.]]
 
     """
+    remove_row, remove_col = False, False
     inp1, inp2 = utils.convert_inputs(inp1, inp2)
+    dim1, dim2 = inp1.ndim, inp2.ndim
+    # handle dim=1 cases, dot and matrix-vector multiplication
     if dim1 == 1 and dim2 == 1:
         return dot(inp1, inp2)
+    # the underlying matmul op requires input dims to be at least 2
+    if dim1 == 1:
+        inp1 = expand_dims(inp1, 0)
+        dim1 = 2
+        remove_row = True
+    if dim2 == 1:
+        inp2 = expand_dims(inp2, 1)
+        dim2 = 2
+        remove_col = True
+
+    batch_shape = None
+    shape1 = astensor1d(inp1.shape, inp1, dtype="int32", device=inp1.device)
+    shape2 = astensor1d(inp2.shape, inp2, dtype="int32", device=inp2.device)
+    if dim1 >= 3 or dim2 >= 3:
+        if dim1 == dim2:
+            assert (
+                shape1[:-2] == shape2[:-2]
+            ).min(), "operands could not be broadcasted together."
+        if dim1 > dim2:
+            shape2 = concat([shape1[:-2], shape2[-2:]])
+            inp2 = broadcast_to(inp2, shape2)
+        if dim1 < dim2:
+            shape1 = concat([shape2[:-2], shape1[-2:]])
+            inp1 = broadcast_to(inp1, shape1)
+        batch_shape = shape1[:-2]
+        # compress inputs to 3d
+        inp1 = inp1.reshape(concat([prod(shape1[:-2]), shape1[-2:]]))
+        inp2 = inp2.reshape(concat([prod(shape2[:-2]), shape2[-2:]]))
-    shp = None
-    if dim1 > 3 or dim2 > 3:
-        shape1, shape2 = list(inp1.shape), list(inp2.shape)
-        if dim1 != dim2:
-            if dim1 < dim2:
-                shape1 = shape2[: dim2 - dim1] + shape1
-                inp1 = broadcast_to(inp1, shape1)
-            else:
-                shape2 = shape1[: dim1 - dim2] + shape2
-                inp2 = broadcast_to(inp2, shape2)
-        reshaped_batch_size = 1
-        for i in shape1[:-2]:
-            reshaped_batch_size *= i
-        inp1 = inp1.reshape(*([reshaped_batch_size] + shape1[-2:]))
-        inp2 = inp2.reshape(*([reshaped_batch_size] + shape2[-2:]))
-        op = builtin.BatchedMatrixMul(
-            transposeA=transpose_a,
-            transposeB=transpose_b,
-            compute_mode=compute_mode,
-            format=format,
-        )
-        shp = shape1[:-1] + shape2[-1:]
-    elif dim1 == 3 or dim2 == 3:
-        if dim2 < 3:
-            inp2 = broadcast_to(inp2, inp1.shape[:1] + inp2.shape)
-        elif dim1 < 3:
-            inp1 = broadcast_to(inp1, inp2.shape[:1] + inp1.shape)
         op = builtin.BatchedMatrixMul(
             transposeA=transpose_a,
             transposeB=transpose_b,
@@ -1011,12 +1015,6 @@ def matmul(
             format=format,
         )
     else:
-        if dim1 == 1:
-            shp = (inp2.shape[1],)
-            inp1 = expand_dims(inp1, 0)
-        if dim2 == 1:
-            shp = (inp1.shape[0],)
-            inp2 = expand_dims(inp2, 1)
         op = builtin.MatrixMul(
             transposeA=transpose_a,
             transposeB=transpose_b,
@@ -1025,8 +1023,12 @@ def matmul(
         )
 
     (result,) = apply(op, inp1, inp2)
-    if shp is not None:
-        result = result.reshape(shp)
+    if batch_shape is not None:
+        result = result.reshape(concat([batch_shape, result.shape[-2:]]))
+    if remove_row:
+        result = squeeze(result, axis=-2)
+    if remove_col:
+        result = squeeze(result, axis=-1)
     return result
 
 
diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py
index 4b3c7290a..76ca492d2 100644
--- a/imperative/python/test/unit/functional/test_functional.py
+++ b/imperative/python/test/unit/functional/test_functional.py
@@ -77,26 +77,43 @@ def test_matmul():
     opr_test(cases, F.matmul, ref_fn=np.matmul)
 
     batch_size = 10
-    shape1 = (batch_size, 2, 3)
-    shape2 = (batch_size, 3, 4)
-    shape3 = (batch_size, 10, 4, 5)
+    shape1 = (2,)
+    shape2 = (batch_size, 2, 3)
+    shape3 = (batch_size, 3, 4)
+    shape4 = (batch_size, 10, 4, 2)
+    shape5 = (batch_size, 10, 2, 4)
     data1 = np.random.random(shape1).astype("float32")
     data2 = np.random.random(shape2).astype("float32")
     data3 = np.random.random(shape3).astype("float32")
+    data4 = np.random.random(shape4).astype("float32")
+    data5 = np.random.random(shape5).astype("float32")
 
-    cases = [{"input": [data1, data2]}, {"input": [data2, data3]}]
-    for i in range(0, batch_size):
-
-        def compare_fn(x, y):
-            x.numpy()[i, ...] == y
-
+    cases = [
+        {"input": [data1, data2]},
+        {"input": [data2, data3]},
+        {"input": [data3, data4]},
+        {"input": [data4, data5]},
+    ]
+    for _ in range(0, batch_size):
         opr_test(
-            cases,
-            F.matmul,
-            compare_fn=compare_fn,
-            ref_fn=lambda x, y: np.matmul(x[i, ...], y[i, ...]),
+            cases, F.matmul, ref_fn=np.matmul,
         )
+
+    opr_test(
+        [{"input": [data1, data4]}],
+        F.matmul,
+        ref_fn=lambda x, y: np.matmul(x, y.transpose(0, 1, 3, 2)),
+        transpose_b=True,
+    )
+
+    opr_test(
+        [{"input": [data3, data2]}],
+        F.matmul,
+        ref_fn=lambda x, y: np.matmul(x.transpose(0, 2, 1), y.transpose(0, 2, 1)),
+        transpose_a=True,
+        transpose_b=True,
+    )
+
 
 def test_interpolate():
     def linear_interpolate():
--
GitLab
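
In user-facing terms the refactor changes two things. First, the batched (dim >= 3) paths now compute output shapes with concat/prod over shape tensors built by astensor1d, instead of over Python lists, so matmul keeps working when shapes are only symbolically known, e.g. under tracing. Second, a 1d operand may now be paired with a batched operand: the vector is promoted to a row or column matrix, broadcast across the batch, and the temporary axis is squeezed from the result. A minimal usage sketch, assuming a MegEngine build that includes this patch; the shapes below are illustrative and not taken from the test suite:

import numpy as np
import megengine.functional as F
from megengine import tensor

# 1d left operand against a batch of matrices: v acts as a (1, 2) row,
# is broadcast over the batch of 10, and the inserted row axis is
# squeezed from the result (the remove_row flag in the patch).
v = tensor(np.random.random((2,)).astype("float32"))
m = tensor(np.random.random((10, 2, 3)).astype("float32"))
print(F.matmul(v, m).shape)  # (10, 3)

# 1d right operand: w acts as a (4, 1) column and the inserted column
# axis is squeezed away (remove_col).
m2 = tensor(np.random.random((10, 3, 4)).astype("float32"))
w = tensor(np.random.random((4,)).astype("float32"))
print(F.matmul(m2, w).shape)  # (10, 3)

# transpose_a/transpose_b still act on the trailing two dims only,
# which is what the new transpose test cases exercise.
a = tensor(np.random.random((10, 10, 4, 2)).astype("float32"))
b = tensor(np.random.random((10, 10, 4, 3)).astype("float32"))
print(F.matmul(a, b, transpose_a=True).shape)  # (10, 10, 2, 3)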
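
For readers who want the new control flow without the MegEngine plumbing (astensor1d, builtin.BatchedMatrixMul, apply), the sketch below models the same dispatch in plain NumPy. matmul_ref is a hypothetical name used only for this illustration, not part of either codebase:

import numpy as np

def matmul_ref(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # dim1 == dim2 == 1: plain dot product, handled before anything else
    if a.ndim == 1 and b.ndim == 1:
        return np.dot(a, b)
    remove_row = remove_col = False
    # the underlying op wants >= 2 dims: promote vectors to matrices
    if a.ndim == 1:
        a = a[None, :]        # row vector -> (1, k)
        remove_row = True
    if b.ndim == 1:
        b = b[:, None]        # column vector -> (k, 1)
        remove_col = True
    batch_shape = None
    if a.ndim >= 3 or b.ndim >= 3:
        # equal-ndim operands must already share batch dims (the patch
        # asserts this); otherwise the lower-ndim one is broadcast up
        if a.ndim < b.ndim:
            a = np.broadcast_to(a, b.shape[:-2] + a.shape[-2:])
        elif a.ndim > b.ndim:
            b = np.broadcast_to(b, a.shape[:-2] + b.shape[-2:])
        batch_shape = a.shape[:-2]
        # compress to 3d: fold every batch dim into one leading dim
        a = a.reshape((-1,) + a.shape[-2:])
        b = b.reshape((-1,) + b.shape[-2:])
    out = np.matmul(a, b)
    if batch_shape is not None:
        out = out.reshape(batch_shape + out.shape[-2:])
    if remove_row:
        out = np.squeeze(out, axis=-2)   # drop the inserted row axis
    if remove_col:
        out = np.squeeze(out, axis=-1)   # drop the inserted column axis
    return out

# sanity check against NumPy's own broadcasting matmul
a = np.random.random((10, 10, 4, 2)).astype("float32")
b = np.random.random((2,)).astype("float32")
assert np.allclose(matmul_ref(a, b), np.matmul(a, b))

The design choice worth noting is the fold: rather than issuing one matmul per batch element, the patch broadcasts the batch dims, collapses them into a single leading dim, runs one BatchedMatrixMul, and unfolds the result. In the real code that fold is computed with concat and prod over shape tensors rather than Python ints, which is what makes the symbolic-shape case work.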