From 129fa70cbaacab07dd3a7e5bd9cd9e14d2f09093 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 30 Jun 2020 11:22:50 +0800 Subject: [PATCH] fix(mgb/serialization): fix multiple graph load error GitOrigin-RevId: 89414b014b4ef2a465bca858bc6bcea24f4c951e --- src/serialization/impl/serializer_oss.cpp | 4 +-- src/serialization/test/serializer_oss.cpp | 30 ++++++++++++++++++++++- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/serialization/impl/serializer_oss.cpp b/src/serialization/impl/serializer_oss.cpp index 4b2917b50..3cbfeaeb3 100644 --- a/src/serialization/impl/serializer_oss.cpp +++ b/src/serialization/impl/serializer_oss.cpp @@ -846,7 +846,7 @@ GraphLoader::LoadResult GraphLoaderOSS::load(const LoadConfig& config, OprLoadContextImpl ctx{this, m_graph->mgb_version()}; auto result = ctx.load_oprs(); - auto fbs_end = tensor_begin + offset_to_fbs + size; + auto fbs_end = tensor_begin + offset_to_fbs + sizeof(size) + size; auto cur = m_file->tell(); mgb_assert(fbs_end > cur); // Skip to Graph end @@ -872,4 +872,4 @@ bool is_fbs_file(InputFile& file) { } // namespace serialization } // namespace mgb -#endif \ No newline at end of file +#endif diff --git a/src/serialization/test/serializer_oss.cpp b/src/serialization/test/serializer_oss.cpp index 2685c66b8..90364db52 100644 --- a/src/serialization/test/serializer_oss.cpp +++ b/src/serialization/test/serializer_oss.cpp @@ -64,6 +64,34 @@ TEST(TestSerializer2, GraphDumpLoad) { load(); } +TEST(TestSerializer2, MultiGraphDumpLoad) { + auto fname = GET_OUTPUT_FILE(); + + auto dump = [&]() { + auto cn = CompNode::load("cpu0"); + auto graph = ComputingGraph::make(); + auto x = opr::ImmutableTensor::make(*graph, 1926.0817f, {cn}); + x.rename("varz"); + auto dumper = GraphDumper::make(OutputFile::make_fs(fname.c_str()), + GraphDumpFormat::FLATBUFFERS); + // dump twice + dumper->dump({x}); + dumper->dump({x}); + }; + auto load = [&]() { + GraphLoader::LoadConfig load_config = {}; + auto loader = GraphLoader::make(InputFile::make_fs(fname.c_str()), + GraphDumpFormat::FLATBUFFERS); + // load twice + loader->load(load_config, false); + loader = GraphLoader::make(loader->reset_file(), loader->format()); + loader->load(load_config, false); + }; + + dump(); + load(); +} + TEST(TestSerializer2, APlusB) { auto fname = GET_OUTPUT_FILE(); TensorShape shape{2, 3}; @@ -733,4 +761,4 @@ TEST(TestSerializer2, HasOutputDtype) { load(); } -#endif \ No newline at end of file +#endif -- GitLab