Commit f12355f7 authored by Megvii Engine Team, committed by huangxinda

fix(imperative/grad): fix hardcode dtype in subtensor_grad_rule

GitOrigin-RevId: 50da4af26dd4f0f0efe38f07573d704ea2fbe841
Parent 4e4497b9
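
In brief: the gradient rules for subtensor and multi-axis indexing built their zeros base tensor via make_tensor, which hardcoded dtype::Float32(). Whenever the incoming gradient had any other dtype (e.g. float16), the zeros tensor and the gradient disagreed on dtype. The replacement helper make_empty_tensor takes the dtype explicitly, and both call sites now pass grad->dtype(). A minimal self-contained sketch of the failure mode, using hypothetical stand-in types rather than the MegEngine API:

#include <cassert>
#include <string>

struct Tensor {
    std::string dtype;
};

// Old helper: dtype fixed to float32 regardless of the gradient.
Tensor make_zeros_hardcoded() { return Tensor{"float32"}; }

// Fixed helper: dtype supplied by the caller (here, the gradient's dtype).
Tensor make_zeros(const std::string& dtype) { return Tensor{dtype}; }

int main() {
    Tensor grad{"float16"};               // gradient flowing backward
    Tensor bad = make_zeros_hardcoded();  // old behaviour
    Tensor good = make_zeros(grad.dtype); // behaviour after this commit
    assert(bad.dtype != grad.dtype);      // mismatch -> error downstream
    assert(good.dtype == grad.dtype);
    return 0;
}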
@@ -35,9 +35,9 @@ std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
     return python::apply(op, x, s)[0];
 }
-std::shared_ptr<Tensor> make_tensor(CompNode cn, Tensor* shape, float v = 0) {
-    HostTensorND scalar{cn, {{1}, dtype::Float32()}};
-    scalar.ptr<float>()[0] = v;
+std::shared_ptr<Tensor> make_empty_tensor(CompNode cn, Tensor* shape, DType dtype) {
+    HostTensorND scalar{cn, {{1}, dtype}};
+    std::memset(scalar.raw_ptr(), 0, dtype.size());
     interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false);
     auto&& t = std::make_shared<Tensor>(handle);
     auto res = broadcast_to(t.get(), shape);
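
Side note on the zero fill: rather than assigning a typed value, the new helper zeroes the single element's dtype.size() bytes with std::memset. A standalone sketch (plain C++, not MegEngine code) of why an all-zero byte pattern is a valid numeric zero for the usual dtypes:

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
    float f;
    std::memset(&f, 0, sizeof(f)); // all-zero bytes encode +0.0f in IEEE-754
    assert(f == 0.0f);

    int32_t i;
    std::memset(&i, 0, sizeof(i)); // all-zero bytes are 0 in two's complement
    assert(i == 0);
    return 0;
}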
@@ -117,7 +117,7 @@ apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& mak
     apply_result_t ret(1);
     if (grad && inputs[0]) {
         SmallVector<Tensor*> args_(inputs.size()+1);
-        auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
+        auto&& zeros = make_empty_tensor(grad->comp_node(), inputs[0].get(), grad->dtype());
         args_[0] = zeros.get();
         args_[1] = grad;
         for (size_t i = 1; i < inputs.size(); ++i) {
@@ -147,7 +147,7 @@ apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward:
     apply_result_t ret(1);
     if (grad && inputs[0]) {
         SmallVector<Tensor*> args_(inputs.size()+1);
-        auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
+        auto&& zeros = make_empty_tensor(grad->comp_node(), inputs[0].get(), grad->dtype());
         args_[0] = zeros.get();
         args_[1] = grad;
         for (size_t i = 1; i < inputs.size(); ++i) {
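
Taken together, the two call-site hunks make both grad rules derive the dtype of the zeros base tensor from the incoming gradient instead of a hardcoded float32, which is the "hardcode dtype" the commit title refers to; backward through subtensor and multi-axis indexing on non-float32 tensors should now produce gradients of a consistent dtype.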