提交 b9e0319f 编写于 作者: M Megvii Engine Team

fix(imperative): fix format transformation handle nchw format tensor

GitOrigin-RevId: f5838c1f7fbc1a1f4ffd9fc8951ed0cbdd422dc2
上级 fca4ae57
......@@ -20,6 +20,10 @@ def test_basic():
b = tensor(a)
assert b.format == "nhwc"
b = tensor(data, format="nchw")
result = F.pad(b, ((0, 0), (0, 0), (1, 1), (1, 1)), mode="reflect")
assert result.format == "default"
# TODO: init from tensor with new format
# c = tensor(a, format="nchw")
# assert c.format == "nchw"
......
......@@ -435,13 +435,22 @@ inline FT get_inputs_format(Span<ValueRef>& inputs, const FormatTransformation&
return format;
}
inline bool if_convert_format(const Format src_fmt, const FT& dst_fmt) {
if ((src_fmt == FT::NCHW && dst_fmt == FT::DEFAULT) ||
(src_fmt == FT::DEFAULT && dst_fmt == FT::NCHW)) {
return false;
} else {
return true;
}
}
inline ValueRefList unify_inputs_format(
const Span<ValueRef>& inputs, const FT& dst_fmt, const std::string& scope,
const FormatTransformation& t) {
ValueRefList unified_inputs(inputs.size());
for (size_t i = 0; i < inputs.size(); ++i) {
auto&& inp = inputs[i].cast(t.value_type());
if (inp.format() != dst_fmt) {
if (inp.format() != dst_fmt && if_convert_format(inp.format(), dst_fmt)) {
unified_inputs[i] = t.to(inp, dst_fmt, scope);
} else {
unified_inputs[i] = inputs[i];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册