提交 129fa70c 编写于 作者: M Megvii Engine Team 提交者: Xu Xinran

fix(mgb/serialization): fix multiple graph load error

GitOrigin-RevId: 89414b014b4ef2a465bca858bc6bcea24f4c951e
上级 4755400e
......@@ -846,7 +846,7 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config,
OprLoadContextImpl ctx{this, m_graph->mgb_version()};
auto result = ctx.load_oprs();
auto fbs_end = tensor_begin + offset_to_fbs + size;
auto fbs_end = tensor_begin + offset_to_fbs + sizeof(size) + size;
auto cur = m_file->tell();
mgb_assert(fbs_end > cur);
// Skip to Graph end
......@@ -872,4 +872,4 @@ bool is_fbs_file(InputFile& file) {
} // namespace serialization
} // namespace mgb
#endif
\ No newline at end of file
#endif
......@@ -64,6 +64,34 @@ TEST(TestSerializer2, GraphDumpLoad) {
load();
}
TEST(TestSerializer2, MultiGraphDumpLoad) {
auto fname = GET_OUTPUT_FILE();
auto dump = [&]() {
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
auto x = opr::ImmutableTensor::make(*graph, 1926.0817f, {cn});
x.rename("varz");
auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()),
GraphDumpFormat::FLATBUFFERS);
// dump twice
dumper->dump({x});
dumper->dump({x});
};
auto load = [&]() {
GraphLoader::LoadConfig load_config = {};
auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()),
GraphDumpFormat::FLATBUFFERS);
// load twice
loader->load(load_config, false);
loader = GraphLoader::make(loader->reset_file(), loader->format());
loader->load(load_config, false);
};
dump();
load();
}
TEST(TestSerializer2, APlusB) {
auto fname = GET_OUTPUT_FILE();
TensorShape shape{2, 3};
......@@ -733,4 +761,4 @@ TEST(TestSerializer2, HasOutputDtype) {
load();
}
#endif
\ No newline at end of file
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册