/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef GE_GE_RUNTIME_RUNTIME_MODEL_H_ #define GE_GE_RUNTIME_RUNTIME_MODEL_H_ #include #include #include #include #include "ge_runtime/davinci_model.h" #include "common/ge_types.h" #include "runtime/base.h" #include "runtime/rt_model.h" namespace ge { namespace model_runner { using RuntimeInfo = std::tuple; class Task; class RuntimeModel { public: RuntimeModel() = default; ~RuntimeModel(); bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr &davinci_model); bool LoadComplete(); const std::vector &GetTaskIdList() const; const std::vector &GetStreamIdList() const; const std::map> &GetRuntimeInfoMap() const { return runtime_info_map_; } bool Run(); bool CopyInputData(const InputData &input_data); bool GetInputOutputDescInfo(bool zero_copy, std::vector *input_desc, std::vector *output_desc, std::vector *input_format, std::vector *output_format); private: bool InitResource(std::shared_ptr &davinci_model); void GenerateTask(uint32_t device_id, uint64_t session_id, std::shared_ptr &davinci_model); bool LoadTask(); bool InitStream(std::shared_ptr &davinci_model); bool InitEvent(uint32_t event_num); bool InitLabel(std::shared_ptr &davinci_model); bool InitDataInfo(std::shared_ptr &davinci_model); bool InitOutputInfo(std::shared_ptr &davinci_model); bool InitConstantInfo(std::shared_ptr &davinci_model); void RtModelUnbindStream() noexcept; void RtStreamDestory() noexcept; void RtModelDestory() noexcept; void RtLabelDestory() noexcept; void RtEventDestory() noexcept; bool CopyInputDataToModel(const std::vector &data, const std::shared_ptr &data_info); bool CopyHostData(const std::vector &data, const std::shared_ptr &data_info) const; bool CopyTransData(const std::vector &data, const std::shared_ptr &data_info); bool GetInputDescInfo(std::vector *input_desc, std::vector *formats); bool GetOutputDescInfo(std::vector *output_desc, std::vector *formats); void CreateOutput(uint32_t index, const OpInfo &op_info, InputOutputDescInfo *output, uint32_t *format); rtModel_t rt_model_handle_{}; rtStream_t rt_model_stream_{}; std::vector stream_list_{}; std::vector label_list_{}; std::vector event_list_{}; std::vector> task_list_{}; std::vector> data_info_list_{}; std::vector> output_info_list_{}; std::vector> constant_info_list_{}; std::vector task_id_list_{}; std::vector stream_id_list_{}; std::map> runtime_info_map_; }; } // namespace model_runner } // namespace ge #endif // GE_GE_RUNTIME_RUNTIME_MODEL_H_