From fa1ca0ea6c108219a5136542811f42600322edf2 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 25 Mar 2021 15:11:52 +0800 Subject: [PATCH] fix(imperative/opr): fix apply_on_var_node for broadcast GitOrigin-RevId: 686fff4f739e332b0c2bdfd7678c67a0b9ec0a5f --- imperative/python/test/unit/functional/test_tensor.py | 5 +++++ imperative/src/impl/ops/broadcast.cpp | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 0ffa5ebad..e43386b8c 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -309,12 +309,17 @@ def test_broadcast(): output2_shape = (20, 10, 20) data2 = np.random.random(input2_shape).astype(np.float32) + input3_shape = (10, 10) + output3_shape = (10, 10) + data3 = np.random.random(input3_shape).astype(np.float32) + def compare_fn(x, y): assert x.shape[0] == y cases = [ {"input": [data1, output1_shape], "output": output1_shape}, {"input": [data2, output2_shape], "output": output2_shape}, + {"input": [data3, output3_shape], "output": output3_shape}, ] opr_test(cases, F.broadcast_to, compare_fn=compare_fn) diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp index dac2282ad..21671e06e 100644 --- a/imperative/src/impl/ops/broadcast.cpp +++ b/imperative/src/impl/ops/broadcast.cpp @@ -24,14 +24,14 @@ std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { return Broadcast::make(); } -cg::OperatorNodeBase* apply_on_var_node( +auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& op = def.cast_final_safe(); size_t nr_inp = inputs.size(); mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); OperatorNodeConfig config{op.make_name()}; - return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr(); + return opr::Broadcast::make(inputs[0], inputs[1], config); } bool valid_broadcast(const TensorShape& src_shape, -- GitLab