diff --git a/imperative/python/test/unit/functional/test_tensor.py b/imperative/python/test/unit/functional/test_tensor.py index 0ffa5ebad8f3ba53e198155163ee6c59cfdb1273..e43386b8c94227ca2cb6a7275e43b2422d943e34 100644 --- a/imperative/python/test/unit/functional/test_tensor.py +++ b/imperative/python/test/unit/functional/test_tensor.py @@ -309,12 +309,17 @@ def test_broadcast(): output2_shape = (20, 10, 20) data2 = np.random.random(input2_shape).astype(np.float32) + input3_shape = (10, 10) + output3_shape = (10, 10) + data3 = np.random.random(input3_shape).astype(np.float32) + def compare_fn(x, y): assert x.shape[0] == y cases = [ {"input": [data1, output1_shape], "output": output1_shape}, {"input": [data2, output2_shape], "output": output2_shape}, + {"input": [data3, output3_shape], "output": output3_shape}, ] opr_test(cases, F.broadcast_to, compare_fn=compare_fn) diff --git a/imperative/src/impl/ops/broadcast.cpp b/imperative/src/impl/ops/broadcast.cpp index dac2282ad2c215e6547fc031d78cf0b0ea24cdc1..21671e06e8f5b84ce9fb92ced6e5f7b783f15285 100644 --- a/imperative/src/impl/ops/broadcast.cpp +++ b/imperative/src/impl/ops/broadcast.cpp @@ -24,14 +24,14 @@ std::shared_ptr make_from_op_node(cg::OperatorNodeBase* node_) { return Broadcast::make(); } -cg::OperatorNodeBase* apply_on_var_node( +auto apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { auto&& op = def.cast_final_safe(); size_t nr_inp = inputs.size(); mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp); OperatorNodeConfig config{op.make_name()}; - return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr(); + return opr::Broadcast::make(inputs[0], inputs[1], config); } bool valid_broadcast(const TensorShape& src_shape,