diff --git a/src/serialization/impl/serializer_oss.cpp b/src/serialization/impl/serializer_oss.cpp index 3cbfeaeb37bcdf37a564e13491811210e0ec252c..d303804a979f65185267164a683f2b5d55f34f14 100644 --- a/src/serialization/impl/serializer_oss.cpp +++ b/src/serialization/impl/serializer_oss.cpp @@ -33,6 +33,8 @@ #include "megbrain/serialization/serializer.h" #include "megbrain/version.h" +#include + #include #include #include @@ -121,7 +123,12 @@ public: void dump_tensor(const std::string& name, const HostTensorND& tensor, TensorWriteMethod method) override; flatbuffers::FlatBufferBuilder& builder() override { return m_builder; } - void append_param(uint32_t type, flatbuffers::Offset value) override { + void append_param(uint32_t type, uint32_t value) override { + static_assert(std::is_same::value, + "append_param depends on uoffset_t being uint32_t"); + static_assert(std::is_standard_layout>::value, + "append_param depends on flatbuffers::Offset having " + "standard memory layout"); mgb_assert(type != fbs::OperatorParam_NONE); m_cur_opr_param_type.emplace_back( static_cast(type)); diff --git a/src/serialization/include/megbrain/serialization/opr_load_dump.h b/src/serialization/include/megbrain/serialization/opr_load_dump.h index af43238e3b31d6dff5deb104be7d6ad6e696627f..7d8f4f5f66461d284c61da2afc400ed145a69842 100644 --- a/src/serialization/include/megbrain/serialization/opr_load_dump.h +++ b/src/serialization/include/megbrain/serialization/opr_load_dump.h @@ -12,9 +12,12 @@ #include "megbrain/graph.h" #include "megbrain/serialization/load_dump_config.h" #include "megbrain/serialization/opr_registry.h" -#if MGB_ENABLE_FBS_SERIALIZATION -#include -#endif + +// Forward declaration for breaking header dependency: we do not want to hard +// depend on flatbuffers/flatbuffers.h in our public headers. +namespace flatbuffers { +class FlatBufferBuilder; +} // namespace flatbuffers namespace mgb { namespace serialization { @@ -122,8 +125,12 @@ class OprDumpContextFlatBuffers : public OprDumpContext { protected: OprDumpContextFlatBuffers() : OprDumpContext(SerializationFormat::FLATBUFFERS) {} - virtual void append_param(uint32_t type, - flatbuffers::Offset value) = 0; + // value_offset should be a flatbuffers::Offset (or ). + // Assuming flatbuffers::Offset is a wrapper around uoffset_t = uint32_t, + // we pass around a uint32_t to avoid dependency to flatbuffers in public + // headers. There are a few static_asserts in serializer_oss.cpp about the + // assumption. + virtual void append_param(uint32_t type, uint32_t value_offset) = 0; public: virtual flatbuffers::FlatBufferBuilder& builder() = 0; @@ -136,7 +143,7 @@ public: auto param_offset = fbs::ParamConverter::to_flatbuffer(builder(), param); append_param(fbs::OperatorParamTraits::enum_value, - param_offset.Union()); + param_offset.Union().o); } template