Commit fa1ca0ea authored by Megvii Engine Team

fix(imperative/opr): fix apply_on_var_node for broadcast

GitOrigin-RevId: 686fff4f739e332b0c2bdfd7678c67a0b9ec0a5f
Parent e1c83d8d
@@ -309,12 +309,17 @@ def test_broadcast():
     output2_shape = (20, 10, 20)
     data2 = np.random.random(input2_shape).astype(np.float32)
 
+    input3_shape = (10, 10)
+    output3_shape = (10, 10)
+    data3 = np.random.random(input3_shape).astype(np.float32)
+
     def compare_fn(x, y):
         assert x.shape[0] == y
 
     cases = [
         {"input": [data1, output1_shape], "output": output1_shape},
         {"input": [data2, output2_shape], "output": output2_shape},
+        {"input": [data3, output3_shape], "output": output3_shape},
     ]
     opr_test(cases, F.broadcast_to, compare_fn=compare_fn)
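The added case broadcasts a tensor to a shape identical to its own, the situation that exposed the bug fixed in the C++ change below. A minimal eager-mode sketch of the same check, assuming a MegEngine build that exposes megengine.functional.broadcast_to and megengine.tensor:

import numpy as np
import megengine.functional as F
from megengine import tensor

# Broadcasting to the input's own shape should behave as a no-op.
x = tensor(np.random.random((10, 10)).astype(np.float32))
y = F.broadcast_to(x, (10, 10))
assert y.numpy().shape == (10, 10)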
@@ -24,14 +24,14 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
     return Broadcast::make();
 }
 
-cg::OperatorNodeBase* apply_on_var_node(
+auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
     auto&& op = def.cast_final_safe<Broadcast>();
     size_t nr_inp = inputs.size();
     mgb_assert(nr_inp == 2, "Broadcast expects 2 inputs; got %lu actually", nr_inp);
     OperatorNodeConfig config{op.make_name()};
-    return opr::Broadcast::make(inputs[0], inputs[1], config).node()->owner_opr();
+    return opr::Broadcast::make(inputs[0], inputs[1], config);
 }
 
 bool valid_broadcast(const TensorShape& src_shape,
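The signature change from cg::OperatorNodeBase* to auto means apply_on_var_node now returns the SymbolVar produced by opr::Broadcast::make directly instead of its .node()->owner_opr(). A plausible reading of the fix: when the target shape equals the input shape, the var returned by make is not guaranteed to be owned by a freshly created Broadcast operator, so dereferencing owner_opr() could hand back the wrong operator node, whereas returning the var itself always yields exactly the requested output. A hedged sketch of how this path is typically reached from Python, assuming a symbolic trace lowers F.broadcast_to through this apply_on_var_node:

import numpy as np
import megengine.functional as F
from megengine import tensor
from megengine.jit import trace

@trace(symbolic=True)
def same_shape_broadcast(x):
    # Target shape equals the input shape, the case added in the test above.
    return F.broadcast_to(x, (10, 10))

x = tensor(np.random.random((10, 10)).astype(np.float32))
assert same_shape_broadcast(x).numpy().shape == (10, 10)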