diff --git a/oneflow/core/job/job.proto b/oneflow/core/job/job.proto index f14b4e343973a5c84ce1538deae6ac016f87def4..7fdf52c08be27971db233499eb32cc7cc0e5462a 100644 --- a/oneflow/core/job/job.proto +++ b/oneflow/core/job/job.proto @@ -45,11 +45,12 @@ message MemoryAllocationAlgorithmConf { message XrtConfig { message XlaConfig { - // TODO + // TODO(hjchen2) } message TensorRTConfig { optional bool use_fp16 = 1 [default = false]; optional bool use_int8 = 2 [default = false]; + optional string int8_calibration = 3; } optional bool use_xla_jit = 1 [default = false]; optional bool use_tensorrt = 2 [default = false]; diff --git a/oneflow/python/contrib/__init__.py b/oneflow/python/contrib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..af986df4b795786fa1591371fc0e1e18b0d4c9b0 --- /dev/null +++ b/oneflow/python/contrib/__init__.py @@ -0,0 +1 @@ +from .tensorrt import * diff --git a/oneflow/python/contrib/tensorrt/__init__.py b/oneflow/python/contrib/tensorrt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/oneflow/python/contrib/tensorrt/tensorrt_api.py b/oneflow/python/contrib/tensorrt/tensorrt_api.py new file mode 100644 index 0000000000000000000000000000000000000000..8e782363817740258201f1c89a6ebdf622ac7dc4 --- /dev/null +++ b/oneflow/python/contrib/tensorrt/tensorrt_api.py @@ -0,0 +1,14 @@ +from __future__ import absolute_import + +import oneflow.oneflow_internal as oneflow_internal +from oneflow.python.oneflow_export import oneflow_export + + +@oneflow_export("tensorrt.write_int8_calibration") +def write_int8_calibration(path): + oneflow_internal.WriteInt8Calibration(path) + + +@oneflow_export("tensorrt.cache_int8_calibration") +def cache_int8_calibration(): + oneflow_internal.CacheInt8Calibration() diff --git a/oneflow/python/framework/function_util.py b/oneflow/python/framework/function_util.py index 6afcf3323f79c19f69e3e082a4c6bbcbf9be0f14..64cec0866e570d9546aa9b0030b635b37b3e2ace 100644 --- a/oneflow/python/framework/function_util.py +++ b/oneflow/python/framework/function_util.py @@ -414,6 +414,12 @@ def set_tensorrt_use_int8(func_desc, value=True): func_desc.job_config_proto.xrt_config.tensorrt_config.use_int8 = value +@oneflow_function_config("tensorrt.int8_calibration") +def set_tensorrt_int8_calibration(func_desc, value): + assert func_desc.job_config_proto.xrt_config.tensorrt_config.use_int8 + func_desc.job_config_proto.xrt_config.tensorrt_config.int8_calibration = value + + @oneflow_function_config("default_distribute_strategy") def set_default_distribute_strategy(func_desc, value): assert isinstance(value, distribute_ctx.DistributeStrategy) diff --git a/oneflow/python/oneflow_internal.h b/oneflow/python/oneflow_internal.h index 3ad449764ebfacf803be3300b3deefa6f31c1eb8..d3a87b52c88806f2c1a1e4a771e65e4053c976ea 100644 --- a/oneflow/python/oneflow_internal.h +++ b/oneflow/python/oneflow_internal.h @@ -211,3 +211,11 @@ void OfBlob_CurMutTensorCopyShapeFrom(uint64_t of_blob_ptr, long* array, int siz auto* of_blob = reinterpret_cast(of_blob_ptr); return of_blob->CurMutTensorCopyShapeFrom(array, size); } + +void CacheInt8Calibration(std::string* error_str) { + oneflow::CacheInt8Calibration().GetDataAndSerializedErrorProto(error_str); +} + +void WriteInt8Calibration(const std::string& path, std::string* error_str) { + oneflow::WriteInt8Calibration(path).GetDataAndSerializedErrorProto(error_str); +} diff --git a/oneflow/python/oneflow_internal_helper.h 
b/oneflow/python/oneflow_internal_helper.h index f9f82327d13dcd2fe47985d5ee4b45c5204798d6..853d5ce2a606e09cdd9f5fc005a9bea76ccb4c81 100644 --- a/oneflow/python/oneflow_internal_helper.h +++ b/oneflow/python/oneflow_internal_helper.h @@ -24,6 +24,10 @@ #include "oneflow/core/framework/op_registration.h" #include "oneflow/core/persistence/tee_persistent_log_stream.h" +#ifdef WITH_TENSORRT +#include "oneflow/xrt/api.h" +#endif // WITH_TENSORRT + namespace oneflow { Maybe RegisterWatcherOnlyOnce(ForeignWatcher* watcher) { @@ -177,6 +181,25 @@ Maybe GetSerializedMachineId2DeviceIdListOFRecord( return PbMessage2TxtString(*JUST(ParseMachineAndDeviceIdList(parallel_conf))); } +Maybe CacheInt8Calibration() { +#ifdef WITH_TENSORRT + xrt::tensorrt::CacheInt8Calibration(); +#else + CHECK_OR_RETURN(0) << "Please recompile with TensorRT."; +#endif // WITH_TENSORRT + return Maybe::Ok(); +} + +Maybe WriteInt8Calibration(const std::string& path) { +#ifdef WITH_TENSORRT + xrt::tensorrt::CacheInt8Calibration(); + xrt::tensorrt::WriteInt8Calibration(path); +#else + CHECK_OR_RETURN(0) << "Please recompile with TensorRT."; +#endif // WITH_TENSORRT + return Maybe::Ok(); +} + Maybe CheckAndCompleteUserOpConf(const std::string& op_conf_str) { OperatorConf op_conf; CHECK_OR_RETURN(TxtString2PbMessage(op_conf_str, &op_conf)) << "operator conf parse failed"; diff --git a/oneflow/python/test/xrt/test_leaky_relu.py b/oneflow/python/test/xrt/test_leaky_relu.py new file mode 100644 index 0000000000000000000000000000000000000000..e01b57d24e14f7dae624f335ff6e80f47e1aa929 --- /dev/null +++ b/oneflow/python/test/xrt/test_leaky_relu.py @@ -0,0 +1,75 @@ +import unittest +import numpy as np + +import oneflow as flow + +config = flow.function_config() + + +def make_job(input_shape, alpha, dtype=flow.float32): + config.use_xla_jit(False) + config.use_tensorrt(False) + + @flow.global_function(config) + def leaky_relu_job(x=flow.FixedTensorDef(input_shape, dtype=dtype)): + return flow.nn.leaky_relu(x, alpha=alpha) + + return leaky_relu_job + + +def make_trt_job(input_shape, alpha, dtype=flow.float32): + config.use_xla_jit(False) + config.use_tensorrt(True) + + @flow.global_function(config) + def trt_leaky_relu_job(x=flow.FixedTensorDef(input_shape, dtype=dtype)): + return flow.nn.leaky_relu(x, alpha=alpha) + + return trt_leaky_relu_job + + +class TestLeakyRelu(unittest.TestCase): + def _test_body(self, x, alpha, dtype=np.float32): + f1 = make_job(x.shape, alpha, dtype=flow.float32) + f2 = make_trt_job(x.shape, alpha, dtype=flow.float32) + a = f1(x).get() + b = f2(x).get() + print("oneflow: ", a) + print("oneflow with tensorrt: ", b) + self.assertTrue(np.allclose(a.ndarray(), b.ndarray(), rtol=1e-03, atol=1e-05)) + flow.clear_default_session() + + def _test_ones_body(self, shape, alpha=0.1, dtype=np.float32): + x = np.ones(shape, dtype=dtype) + self._test_body(x, alpha, dtype=dtype) + + def _test_random_body(self, shape, alpha=0.1, dtype=np.float32): + # np.random.random generates float range from 0 to 1. 
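+        # Shifting by 0.5 and scaling by 100 spreads the inputs over roughly (-50, 50), so both the negative-slope (alpha) branch and the positive branch of leaky_relu are exercised.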
+ x = 100 * (np.random.random(shape).astype(dtype) - 0.5) + self._test_body(x, alpha, dtype=dtype) + + def test_ones_input(self): + self._test_ones_body((1), alpha=0.1) + self._test_ones_body((1, 10), alpha=0.1) + self._test_ones_body((2, 10, 2), alpha=0.1) + self._test_ones_body((2, 5, 2, 2), alpha=0.1) + + self._test_ones_body((1), alpha=0.33) + self._test_ones_body((1, 10), alpha=0.33) + self._test_ones_body((2, 10, 2), alpha=0.33) + self._test_ones_body((2, 5, 2, 2), alpha=0.33) + + def test_random_input(self): + self._test_random_body((1), alpha=0.1) + self._test_random_body((1, 10), alpha=0.1) + self._test_random_body((2, 10, 2), alpha=0.1) + self._test_random_body((2, 5, 2, 2), alpha=0.1) + + self._test_random_body((1), alpha=0.33) + self._test_random_body((1, 10), alpha=0.33) + self._test_random_body((2, 10, 2), alpha=0.33) + self._test_random_body((2, 5, 2, 2), alpha=0.33) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/python/test/xrt/test_online_int8.py b/oneflow/python/test/xrt/test_online_int8.py new file mode 100644 index 0000000000000000000000000000000000000000..167fa4a48a8b8a268bf8b627494d829eb71d04c4 --- /dev/null +++ b/oneflow/python/test/xrt/test_online_int8.py @@ -0,0 +1,139 @@ +import unittest +import numpy as np + +import oneflow as flow + +config = flow.function_config() + + +def make_trt_job( + x_shape, + w_shape, + kernel_size=None, + strides=None, + padding="valid", + data_format="NCHW", + dilation_rate=None, + dtype=flow.float32, +): + config.use_xla_jit(False) + config.use_tensorrt(True) + config.tensorrt.use_int8() + + @flow.global_function(config) + def trt_conv2d_job( + x=flow.FixedTensorDef(x_shape, dtype=dtype), + weight=flow.FixedTensorDef(w_shape, dtype=dtype), + ): + return flow.nn.conv2d(x, weight, strides, padding, data_format, dilation_rate) + + return trt_conv2d_job + + +class TestConv2d(unittest.TestCase): + def make_filter_shape(self, shape, filters, kernel_size, data_format): + if data_format == "NCHW": + return [filters, shape[1], kernel_size, kernel_size] + else: + return [filters, kernel_size, kernel_size, shape[3]] + + def _test_body( + self, + x, + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + dtype=np.float32, + ): + f2 = make_trt_job( + x.shape, + filters.shape, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + dtype=flow.float32, + ) + + for i in range(1): + b = f2(x, filters).get() + print("with tensorrt float32: ", b) + + flow.tensorrt.cache_int8_calibration() + + for i in range(1): + b = f2(x, filters).get() + print("with tensorrt int8: ", b) + + flow.clear_default_session() + + def _test_ones_body( + self, + shape, + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + dtype=np.float32, + ): + assert len(shape) == 4 + x = np.ones(shape, dtype=dtype) + w_shape = self.make_filter_shape(shape, filters, kernel_size, data_format) + weight = np.random.random(w_shape).astype(dtype) + + self._test_body( + x, + weight, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + ) + + def _test_random_body( + self, + shape, + filters, + kernel_size, + strides, + padding, + data_format, + dilation_rate, + dtype=np.float32, + ): + assert len(shape) == 4 + x = np.random.random(shape).astype(dtype) + w_shape = self.make_filter_shape(shape, filters, kernel_size, data_format) + weight = np.random.random(w_shape).astype(dtype) + + self._test_body( + x, + weight, + 
            kernel_size=kernel_size,
+            strides=strides,
+            padding=padding,
+            data_format=data_format,
+            dilation_rate=dilation_rate,
+        )
+
+    def test_random_kernel_1x1(self):
+        self._test_random_body(
+            shape=[3, 3, 5, 5],
+            filters=1,
+            kernel_size=1,
+            strides=1,
+            padding="VALID",
+            data_format="NCHW",
+            dilation_rate=1,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/oneflow/xrt/README.md b/oneflow/xrt/README.md
index 9ed3cc2ba2c19507b6ed2df97c101d92ab4e2295..c71bc26fcc61b0cf8a87e85d6dc1094d6f6beb39 100644
--- a/oneflow/xrt/README.md
+++ b/oneflow/xrt/README.md
@@ -119,8 +119,64 @@ XRT is disabled by default in OneFlow and can be enabled through the front-end Python interface and
   # TensorRT float16
   config.tensorrt.use_fp16()
 
-  # TensorRT int8 (not supported yet)
+  # TensorRT int8 (offline calibration loading mode)
   config.tensorrt.use_int8()
+  # Set int8 calibration table path
+  int8_calibration_path = "./int8_calibration"
+  config.tensorrt.int8_calibration(int8_calibration_path)
   ```
+
+#### Using int8 quantized computation
+
+XRT can start int8 quantized computation either by loading a calibration table offline or by generating one online. The offline approach requires a TensorRT-format calibration table to be generated in advance, and that table can usually be reused. The online approach runs normal-precision computation and calibration-table generation in the same script; once the table has been generated, computation automatically switches to int8 precision in the next iteration.
+
+- Generating an int8 calibration table
+
+  First prepare a calibration dataset, typically a subset of the training or validation set. Then configure the network as usual and enable TensorRT int8. For example:
+
+  ```python
+  import oneflow as flow
+
+  config = flow.function_config()
+
+  config.use_tensorrt()
+  config.tensorrt.use_int8()
+
+  @flow.function(config)
+  def Job(input):
+      # define your network
+      pass
+  ```
+  When int8 is enabled but no calibration table is specified, XRT automatically enters calibration-table generation mode: the data fed afterwards is computed at normal precision (fp32 or fp16), and the results are used to generate the corresponding int8 calibration table. Finally, save the generated tables to a directory of your choice; one calibration table file is written for each subgraph.
+
+  ```python
+  # Generate the int8 calibration table from 10 batches of data
+  for _ in range(10):
+      input = next_calibration_batch()  # load the calibration dataset
+      Job(input).get()
+
+  # Save the calibration table
+  flow.tensorrt.write_int8_calibration("./int8_calibration")  # the int8_calibration directory must be created manually
+  ```
+  Once the int8 calibration table has been generated, you can start TensorRT int8 quantized computation with the offline calibration loading mode described above.
+
+- Generating the calibration table online and running int8 computation
+
+  The online approach has two steps: first generate the calibration table from the calibration dataset, then build the int8 graph and run with the generated table directly. Using the same Job as above,
+
+  ```python
+  # Generate the int8 calibration table from 10 batches of data
+  for _ in range(10):
+      input = next_calibration_batch()  # load the calibration dataset
+      Job(input).get()
+
+  # Cache the calibration table
+  flow.tensorrt.cache_int8_calibration()
+
+  # Once the calibration table has been cached, XRT switches to int8 computation automatically
+  for _ in range(100):
+      input = next_batch()  # load data
+      Job(input).get()
+  ```
 
 ### BenchMark
diff --git a/oneflow/xrt/api.cpp b/oneflow/xrt/api.cpp
index 90d62469218a18c78d1f8a94cf1f413e9522d26d..14680acdea2ff63efe641c353d71945f7023e2ea 100644
--- a/oneflow/xrt/api.cpp
+++ b/oneflow/xrt/api.cpp
@@ -1,11 +1,19 @@
 #include "oneflow/xrt/api.h"
+#include "absl/strings/str_cat.h"
 #include "glog/logging.h"
 #include "oneflow/core/operator/operator.h"  // GenLogicalBlobName, GenLogicalBlobId
 #include "oneflow/xrt/build_graph.h"
 #include "oneflow/xrt/utility/env.h"
+#include <fstream>
+#include <mutex>
+
+#ifdef WITH_TENSORRT
+#include "oneflow/xrt/tensorrt/trt_int8_calibrator.h"
+#endif  // WITH_TENSORRT
+
 DEFINE_int32(clustering_minimum_nodes, EnvToInt(FLAGS_clustering_minimum_nodes, 1),
              "Minium nodes of a cluster after clustering.");
 DEFINE_int32(clustering_maximum_nodes, EnvToInt(FLAGS_clustering_maximum_nodes, 1000),
@@ -24,6 +32,11 @@ DEFINE_bool(tensorrt_fp16, EnvToBool(FLAGS_tensorrt_fp16, false),
             "Enable fp16 precision for TENSORRT engine.");
 DEFINE_bool(tensorrt_int8, EnvToBool(FLAGS_tensorrt_int8, false),
             "Enable int8 precision for TENSORRT engine.");
+DEFINE_string(int8_calibration, EnvToString(FLAGS_int8_calibration, ""),
+              "TensorRT int8 calibration table directory. 
" + "Default is empty, and this means the calibration table will be " + "implictly generated if tensorrt_int8 flag is true."); + namespace oneflow { namespace xrt { @@ -73,6 +86,7 @@ static std::unordered_map user_op_type_name2string_map {"layer_norm_grad", "LayerNormGrad"}, {"scalar_add", "ScalarAdd"}, {"scalar_mul", "ScalarMul"}, + {"leaky_relu", "LeakyRelu"}, }; std::string ExtractOpTypeAsString(const OperatorConf &conf) { @@ -159,7 +173,12 @@ void InitXrtConfigurations(const XrtConfig &config) { if (config.has_tensorrt_config()) { const XrtConfig::TensorRTConfig &trt_config = config.tensorrt_config(); if (trt_config.has_use_fp16()) { FLAGS_tensorrt_fp16 = trt_config.use_fp16(); } - if (trt_config.has_use_int8()) { FLAGS_tensorrt_int8 = trt_config.use_int8(); } + if (trt_config.has_use_int8()) { + FLAGS_tensorrt_int8 = trt_config.use_int8(); + if (trt_config.has_int8_calibration()) { + FLAGS_int8_calibration = trt_config.int8_calibration(); + } + } } } @@ -193,5 +212,41 @@ void RunCompilationTimeXrtPasses(const OpGraph &op_graph, Job *job, bool train_p RunXrtPass("RebuildCompiledJob", graph.get(), options, job); } +#ifdef WITH_TENSORRT +namespace tensorrt { +void CacheInt8Calibration() { + const auto &calib_resources = TRTInt8CalibratorResource::All(); + for (const auto &res : calib_resources) { + std::lock_guard lock(res.second->mutex_); + if (!res.second->calibrator_->isDone()) { + res.second->calibrator_->waitAndSetDone(); + res.second->thread_->join(); + } + res.second->calibrator_->ReleaseDevBuffers(); + } +} + +void WriteInt8Calibration(const std::string &path) { + const auto &calib_resources = TRTInt8CalibratorResource::All(); + for (const auto &res : calib_resources) { + CHECK(res.second->calibrator_->isDone()) // NOLINT + << "Calibration table maybe has not been generated " + << "since the calibrator has not been done."; + + const std::string &calibration_table_data = + res.second->calibrator_->getCalibrationTableAsString(); + CHECK(calibration_table_data.size()) << "Calibration table data is empty."; + + std::string calib_store_path = // NOLINT + absl::StrCat(path, "/", res.first /*calibrator name*/); + std::ofstream ofile(calib_store_path, std::ios::out); + CHECK(ofile.good()) << "Could not open calibration file: " << calib_store_path; + ofile << calibration_table_data; + ofile.close(); + } +} +} // namespace tensorrt +#endif // WITH_TENSORRT + } // namespace xrt } // namespace oneflow diff --git a/oneflow/xrt/api.h b/oneflow/xrt/api.h index 5c80475c6151d83017ed1bb4d791edfd755156cf..844048242041f8384a19dc84ac963d12198a8e6d 100644 --- a/oneflow/xrt/api.h +++ b/oneflow/xrt/api.h @@ -66,6 +66,13 @@ inline void RunXrtPass(const std::string &pass, XrtGraph *graph, const XrtPassOp void RunCompilationTimeXrtPasses(const OpGraph &op_graph, Job *job, bool train_phase); +#ifdef WITH_TENSORRT +namespace tensorrt { +void CacheInt8Calibration(); +void WriteInt8Calibration(const std::string &path); +} // namespace tensorrt +#endif // WITH_TENSORRT + } // namespace xrt } // namespace oneflow diff --git a/oneflow/xrt/executable.h b/oneflow/xrt/executable.h index 1c382242c6df0430d2a29caaa4358ff86eb0d8ff..cefea3d721f6374f190908863db4b04d9b73f78a 100644 --- a/oneflow/xrt/executable.h +++ b/oneflow/xrt/executable.h @@ -35,6 +35,8 @@ struct ExecutableRunOptions { // Enable TensorRT int8 bool tensorrt_int8 = false; + std::string tensorrt_int8_calibration = ""; + // Feed the return parameters to reuse it's storage while running // the executable. 
std::vector return_params; @@ -42,21 +44,29 @@ struct ExecutableRunOptions { class Executable { public: - Executable(const XrtEngine &engine) : engine_(engine) {} + Executable(const std::string &name, const XrtEngine &engine) // NOLINT + : name_(name), engine_(engine) {} virtual ~Executable() = default; const XrtEngine &engine() const { return engine_; } - virtual bool Run(const std::vector &inputs, const ExecutableRunOptions &run_options, + const std::string &name() const { return name_; } + + virtual bool Run(const std::vector &inputs, // NOLINT + const ExecutableRunOptions &run_options, // NOLINT bool block_until_done = true) = 0; - bool RunAsync(const std::vector inputs, const ExecutableRunOptions &run_options) { + bool RunAsync(const std::vector inputs, // NOLINT + const ExecutableRunOptions &run_options) { return Run(inputs, run_options, false); } const std::vector &Results() const { return results_; } protected: + // Executable name. + std::string name_; + // Executable engine, XLA or TensorRT. XrtEngine engine_; std::vector results_; }; diff --git a/oneflow/xrt/launch_kernel.cpp b/oneflow/xrt/launch_kernel.cpp index b319a05e990fc64f25254dc73ba6023abca408ec..461d398d5feccb95d8e10f6eafa59793c48c16ce 100644 --- a/oneflow/xrt/launch_kernel.cpp +++ b/oneflow/xrt/launch_kernel.cpp @@ -15,6 +15,7 @@ DEFINE_int32(max_batch_size, EnvToInt(FLAGS_max_batch_size, 1), DECLARE_bool(tensorrt_fp16); DECLARE_bool(tensorrt_int8); +DECLARE_string(int8_calibration); namespace oneflow { namespace xrt { @@ -165,6 +166,8 @@ void XrtLaunchKernel::ForwardDataContent( CHECK_EQ(device_type, DeviceType::kGPU); run_options.max_batch_size = FLAGS_max_batch_size; run_options.tensorrt_fp16 = FLAGS_tensorrt_fp16; + run_options.tensorrt_int8 = FLAGS_tensorrt_int8; + run_options.tensorrt_int8_calibration = FLAGS_int8_calibration; } bool status = executable->Run(entry_params, run_options, block_until_done); CHECK(status) << "Executable is running failed."; diff --git a/oneflow/xrt/platform.cpp b/oneflow/xrt/platform.cpp index d5d73ac2996328b6ad3575f1c2efef44693c79f2..6e2ee1fd9cff704746229854a29550cacb4846d9 100644 --- a/oneflow/xrt/platform.cpp +++ b/oneflow/xrt/platform.cpp @@ -32,6 +32,25 @@ int GetDeviceId(const XrtDevice &device) { return 0; // Compiler warning free } +void SetDeviceId(const XrtDevice &device, const int device_id) { + switch (device) { + case XrtDevice::CPU_X86: return; + case XrtDevice::GPU_CUDA: { +#ifdef WITH_CUDA + CHECK_EQ(cudaSuccess, cudaSetDevice(device_id)); + return; +#endif + } + case XrtDevice::GPU_CL: + // TODO(hjchen2) + case XrtDevice::CPU_ARM: + // TODO(hjchen2) + case XrtDevice::GPU_ARM: + // TODO(hjchen2) + return; + } +} + } // namespace platform } // namespace xrt diff --git a/oneflow/xrt/platform.h b/oneflow/xrt/platform.h index c07935991d8f5a726b48f2e87d5129a9df52820d..907ec260057eef57d7e893eb58b9339c3663045f 100644 --- a/oneflow/xrt/platform.h +++ b/oneflow/xrt/platform.h @@ -10,6 +10,8 @@ namespace platform { int GetDeviceId(const XrtDevice &device); +void SetDeviceId(const XrtDevice &device, const int device_id); + } // namespace platform } // namespace xrt diff --git a/oneflow/xrt/tensorrt/ops/activation_op.cpp b/oneflow/xrt/tensorrt/ops/activation_op.cpp index 45476d3b16edd060ea1140f178c1906a02b2835e..67e19c9d0f30803821c7dbb7312a13b71be53db5 100644 --- a/oneflow/xrt/tensorrt/ops/activation_op.cpp +++ b/oneflow/xrt/tensorrt/ops/activation_op.cpp @@ -28,6 +28,22 @@ REGISTER_TRT_OP_KERNEL(Sigmoid, ActivationOp .EnableTrainPhase() .Finalize(); +template<> +class 
ActivationOp : public TrtOpKernel { + public: + void Compile(TrtOpContext *ctx) override { + nvinfer1::ITensor *in = ctx->SoleInput(); + auto *layer = ctx->builder()->addActivation(*in, nvinfer1::ActivationType::kLEAKY_RELU); + layer->setAlpha(ctx->Attr("alpha")); + layer->setName(ctx->op_name().c_str()); + ctx->SetSoleOutput(layer->getOutput(0)); + } +}; + +REGISTER_TRT_OP_KERNEL(LeakyRelu, ActivationOp) + .EnableTrainPhase() + .Finalize(); + } // namespace tensorrt } // namespace xrt } // namespace oneflow diff --git a/oneflow/xrt/tensorrt/trt_builder.h b/oneflow/xrt/tensorrt/trt_builder.h index 49101f2206cd6a0c4c09ae68117603600cca18b5..a8eee3262db9ec0b959096b2fcd86b2fdb04d351 100644 --- a/oneflow/xrt/tensorrt/trt_builder.h +++ b/oneflow/xrt/tensorrt/trt_builder.h @@ -65,7 +65,7 @@ class TrtBuilder { std::string builder_name_; // The next new handle number. - int64_t next_handle_ = 0; + int64_t next_handle_ = -1; nv::unique_ptr builder_; nv::unique_ptr network_; @@ -78,7 +78,7 @@ class TrtBuilder { util::Map>> host_weights_; public: - explicit TrtBuilder(const std::string &name) : builder_name_(name), next_handle_(0) { + explicit TrtBuilder(const std::string &name) : builder_name_(name), next_handle_(-1) { static nv::Logger logger; builder_.reset(nvinfer1::createInferBuilder(logger)); nvinfer1::NetworkDefinitionCreationFlags flags = @@ -87,6 +87,8 @@ class TrtBuilder { network_.reset(builder_->createNetworkV2(flags)); } + const std::string &name() const { return builder_name_; } + nvinfer1::ITensor *GetTensor(int64_t handle); nvinfer1::Weights &GetWeight(int64_t handle); @@ -138,7 +140,7 @@ class TrtBuilder { CHECK_GT(params_.count(handle), 0) << "Parameter is not found for handle " << handle; } - int64_t IncreaseHandle() { return next_handle_++; } + int64_t IncreaseHandle() { return ++next_handle_; } }; } // namespace tensorrt diff --git a/oneflow/xrt/tensorrt/trt_executable.cpp b/oneflow/xrt/tensorrt/trt_executable.cpp index fb18551a19d6d083674d9f4ecd03454bedc707d8..0311dfcb2c5ac51b81d0dc3b8be2a16f9fac4973 100644 --- a/oneflow/xrt/tensorrt/trt_executable.cpp +++ b/oneflow/xrt/tensorrt/trt_executable.cpp @@ -1,67 +1,120 @@ -#include "cuda_runtime.h" - #include "oneflow/xrt/tensorrt/trt_executable.h" +#include "oneflow/xrt/tensorrt/trt_int8_calibrator.h" +#include "oneflow/xrt/platform.h" + +#include +#include +#include "cuda_runtime.h" +#include "absl/strings/str_cat.h" namespace oneflow { namespace xrt { namespace tensorrt { -bool TrtExecutable::CreateExecutableEngine(const ExecutableRunOptions &run_options, - const int batch_size) { - if (!builder_ || !network_) { return false; } - auto build_config = nv::unique_ptr(builder_->createBuilderConfig()); - int64_t max_workspace_size = 1U << 24; // 16MiB - if (run_options.device_memory_limit > 0) { max_workspace_size = run_options.device_memory_limit; } +nvinfer1::ICudaEngine *TrtExecutable::CreateExecutableEngine( + const ExecutableRunOptions &run_options, const int batch_size /*= 1*/, + TRTInt8Calibrator *calibrator /*= nullptr*/) { + CHECK(builder_ && network_) << "Builder and network should be setup before."; + + auto build_config = // NOLINT + nv::unique_ptr(builder_->createBuilderConfig()); + int64_t max_workspace_size = 1U << 24; // 16MiB + if (run_options.device_memory_limit > 0) { // NOLINT + max_workspace_size = run_options.device_memory_limit; + } build_config->setMaxWorkspaceSize(max_workspace_size); nvinfer1::BuilderFlags flags = 0U; if (run_options.tensorrt_fp16) { if (builder_->platformHasFastFp16()) { flags |= (1U << 
int(nvinfer1::BuilderFlag::kFP16)); - // It does not guarantee using half precision if only set kFP16 flag, - // but you can set kSTRICT_TYPES to force using half precision. - // flags |= (1U << int(nvinfer1::BuilderFlag::kSTRICT_TYPES)); } else { LOG(INFO) << "TensorRT couldn't use fp16 precision since the GPU " "hardware does not support."; } } - // flags |= (1U << int(nvinfer1::BuilderFlag::kINT8)); + if (run_options.tensorrt_int8) { + if (builder_->platformHasFastInt8()) { + if (calibrator) { + flags |= (1U << int(nvinfer1::BuilderFlag::kINT8)); + if (builder_->platformHasFastFp16()) { // NOLINT + flags |= (1U << int(nvinfer1::BuilderFlag::kFP16)); + } + build_config->setInt8Calibrator(calibrator); + } + } else { + LOG(INFO) << "TensorRT couldn't use int8 precision since the GPU " + "hardware does not support."; + } + } + // It does not guarantee to use low precision if just set kFP16 or kint8 flag, + // but you can set kSTRICT_TYPES to enforce using half or int8 precision. + // flags |= (1U << int(nvinfer1::BuilderFlag::kSTRICT_TYPES)); + // flags |= (1U << int(nvinfer1::BuilderFlag::kREFIT)); build_config->setFlags(flags); - // build_config->setInt8Calibrator(); int32_t max_batch_size = std::max(run_options.max_batch_size, batch_size); builder_->setMaxBatchSize(max_batch_size); // builder_->setGpuAllocator(); - engine_.reset(builder_->buildEngineWithConfig(*network_, *build_config)); - return true; + return builder_->buildEngineWithConfig(*network_, *build_config); } bool TrtExecutable::ExecuteEngine(int batch_size, void **buffers, void *stream, bool block_until_done) { - if (!execution_context_) { execution_context_.reset(engine_->createExecutionContext()); } + if (!execution_context_) { // NOLINT + execution_context_.reset(engine_->createExecutionContext()); + } cudaStream_t cu_stream = reinterpret_cast(stream); bool status = // execution_context_->enqueue(batch_size, buffers, cu_stream, nullptr); execution_context_->enqueueV2(buffers, cu_stream, nullptr); - if (block_until_done) { CHECK_EQ(cudaSuccess, cudaStreamSynchronize(cu_stream)); } + if (block_until_done) { // NOLINT + CHECK_EQ(cudaSuccess, cudaStreamSynchronize(cu_stream)); + } return status; } +std::string TrtExecutable::LoadCalibrationTable( // NOLINT + const std::string &calibration_path) { + std::string calib_restore_path(absl::StrCat(calibration_path, "/", this->name())); + std::ifstream infile(calib_restore_path, std::ios::in); + CHECK(infile.good()) << "Could not open calibration file: " // NOLINT + << calib_restore_path; + std::stringstream buffer; + buffer << infile.rdbuf(); + return std::move(buffer.str()); +} + bool TrtExecutable::Run(const std::vector &inputs, - const ExecutableRunOptions &run_options, bool block_until_done) { - // TODO(hjchen2) + const ExecutableRunOptions &run_options, // NOLINT + bool block_until_done) { + // TODO(hjchen2): Refactor + if (run_options.tensorrt_int8 && !calibrator_ && // NOLINT + run_options.tensorrt_int8_calibration.size()) { + std::string calibration_data = // NOLINT + LoadCalibrationTable(run_options.tensorrt_int8_calibration); + CHECK(calibration_data.size()) << "Calibration data is empty."; + calibrator_.reset(new TRTInt8Calibrator(calibration_data)); + } if (!execution_context_ && !engine_) { - CHECK(CreateExecutableEngine(run_options)) << "Cannot create TensorRT executanble engine."; + engine_.reset(CreateExecutableEngine(run_options, 1 /*batch size*/, // NOLINT + calibrator_.get())); + CHECK(engine_) << "Cannot create TensorRT executable engine."; } - // All return 
params are results of the executable. + + // All return params are the results of the executable. this->results_ = run_options.return_params; + // TODO(hjchen2): Cache the parameters raw address. util::Map all_params; - for (const Parameter &input : inputs) { all_params.emplace(input.name(), &input); } - for (const Parameter &output : this->results_) { all_params.emplace(output.name(), &output); } + for (const Parameter &input : inputs) { // NOLINT + all_params.emplace(input.name(), &input); // NOLINT + } + for (const Parameter &output : this->results_) { // NOLINT + all_params.emplace(output.name(), &output); // NOLINT + } const int num_bindings = engine_->getNbBindings(); std::vector binding_params(num_bindings); @@ -75,13 +128,47 @@ bool TrtExecutable::Run(const std::vector &inputs, // TODO(hjchen2): Check batch size is same for all binding parameters. const int batch_size = binding_params[0]->shape().At(0); if (batch_size > engine_->getMaxBatchSize()) { - LOG(WARNING) << "Rebuild engine since the maximum batch size " << engine_->getMaxBatchSize() + LOG(WARNING) << "Rebuild engine since the maximum batch size " // NOLINT + << engine_->getMaxBatchSize() // NOLINT << " is less than the input batch size " << batch_size; - CHECK(CreateExecutableEngine(run_options, batch_size)) - << "Failed to create engine with batch size " << batch_size; + engine_.reset(CreateExecutableEngine(run_options, batch_size, // NOLINT + calibrator_.get())); + CHECK(engine_) << "Failed to create engine with batch size " << batch_size; execution_context_.reset(engine_->createExecutionContext()); } - return ExecuteEngine(batch_size, buffers.data(), run_options.stream, block_until_done); + + if (run_options.tensorrt_int8 && !calibrator_) { + auto *res = TRTInt8CalibratorResource::LookupOrCreate(this->name()); + { + std::lock_guard lock(res->mutex_); + if (!res->calibrator_) { + res->calibrator_.reset(new TRTInt8Calibrator()); + int ordinal = platform::GetDeviceId(XrtDevice::GPU_CUDA); + res->thread_.reset(new std::thread([this, ordinal, batch_size, res, // NOLINT + run_options]() { + platform::SetDeviceId(XrtDevice::GPU_CUDA, ordinal); + // TODO(hjchen2): TensorRT maybe crash if calibrator batch size > 1 + res->calibrator_->setBatchSize(1 /*batch_size*/); + res->engine_.reset( // NOLINT + this->CreateExecutableEngine(run_options, batch_size, // NOLINT + res->calibrator_.get())); + })); + } + } + + if (res->calibrator_->isDone()) { + CHECK_EQ(cudaSuccess, cudaStreamSynchronize( // NOLINT + reinterpret_cast(run_options.stream))); // NOLINT + calibrator_ = res->calibrator_; + // engine_ = std::move(res->engine_); + execution_context_.reset(res->engine_->createExecutionContext()); + } else { + res->calibrator_->setBatch(binding_params); + } + } + + return ExecuteEngine(batch_size, buffers.data(), run_options.stream, // NOLINT + block_until_done); } } // namespace tensorrt diff --git a/oneflow/xrt/tensorrt/trt_executable.h b/oneflow/xrt/tensorrt/trt_executable.h index c356aeda6b5e2de422014a8ba8e48f61a813e2c0..a70a68768994bcc526b9d67952244a4eb73f1283 100644 --- a/oneflow/xrt/tensorrt/trt_executable.h +++ b/oneflow/xrt/tensorrt/trt_executable.h @@ -7,6 +7,7 @@ #include "oneflow/xrt/executable.h" #include "oneflow/xrt/parameter.h" #include "oneflow/xrt/tensorrt/trt_unique_ptr.h" +#include "oneflow/xrt/tensorrt/trt_int8_calibrator.h" #include "oneflow/xrt/utility/stl.h" namespace oneflow { @@ -17,15 +18,17 @@ namespace tensorrt { class TrtExecutable : public Executable { public: explicit TrtExecutable( - nv::unique_ptr 
&&engine, + const std::string &name, nv::unique_ptr &&engine, const util::Map>> &host_weights) - : Executable(XrtEngine::TENSORRT), engine_(std::move(engine)), host_weights_(host_weights) {} + : Executable(name, XrtEngine::TENSORRT), + engine_(std::move(engine)), // NOLINT + host_weights_(host_weights) {} explicit TrtExecutable( - nv::unique_ptr &&builder, + const std::string &name, nv::unique_ptr &&builder, nv::unique_ptr &&network, const util::Map>> &host_weights) - : Executable(XrtEngine::TENSORRT), + : Executable(name, XrtEngine::TENSORRT), builder_(std::move(builder)), network_(std::move(network)), host_weights_(host_weights) {} @@ -36,16 +39,22 @@ class TrtExecutable : public Executable { bool block_until_done = true) override; private: - bool CreateExecutableEngine(const ExecutableRunOptions &run_options, const int batch_size = -1); + nvinfer1::ICudaEngine *CreateExecutableEngine(const ExecutableRunOptions &run_options, + const int batch_size = 1, + TRTInt8Calibrator *calibrator = nullptr); bool ExecuteEngine(const int batch_size, void **buffers, void *stream, bool block_until_done); + std::string LoadCalibrationTable(const std::string &calibration_path); + private: nv::unique_ptr engine_; nv::unique_ptr builder_; nv::unique_ptr network_; nv::unique_ptr execution_context_; + std::shared_ptr calibrator_; + util::Map>> host_weights_; }; diff --git a/oneflow/xrt/tensorrt/trt_graph_compiler.cpp b/oneflow/xrt/tensorrt/trt_graph_compiler.cpp index 17f6b377fcbeccbb3e85a6150db998ff34ac1e07..df20101cdbe70bc5f7462feb3c3dd9a0c201f4c2 100644 --- a/oneflow/xrt/tensorrt/trt_graph_compiler.cpp +++ b/oneflow/xrt/tensorrt/trt_graph_compiler.cpp @@ -81,9 +81,9 @@ std::shared_ptr TrtGraphCompiler::Compile( builder_->MarkOutput(value.handle()); } - // return std::make_shared(builder_->BuildCudaEngine()); - return std::make_shared(builder_->ReleaseBuilder(), builder_->ReleaseNetwork(), - builder_->host_weights()); + // return std::make_shared(builder_->name(), builder_->BuildCudaEngine()); + return std::make_shared(builder_->name(), builder_->ReleaseBuilder(), + builder_->ReleaseNetwork(), builder_->host_weights()); } REGISTER_GRAPH_COMPILER(XrtEngine::TENSORRT, TrtGraphCompiler); diff --git a/oneflow/xrt/tensorrt/trt_int8_calibrator.cpp b/oneflow/xrt/tensorrt/trt_int8_calibrator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1c2322dc26d78030ed7fef67fd89de7db294d143 --- /dev/null +++ b/oneflow/xrt/tensorrt/trt_int8_calibrator.cpp @@ -0,0 +1,174 @@ +#include +#include +#include +#include +#include "cuda_runtime.h" + +#include "oneflow/xrt/tensorrt/trt_int8_calibrator.h" + +namespace oneflow { +namespace xrt { + +namespace tensorrt { + +void TRTInt8Calibrator::setBatchSize(const int batch_size) { // NOLINT + batch_size_ = batch_size; +} + +// set the batch size before constructing the thread to execute engine +int TRTInt8Calibrator::getBatchSize() const { // NOLINT + return batch_size_; +} + +TRTInt8Calibrator::TRTInt8Calibrator() // NOLINT + : done_(false), calib_running_(true), batch_is_set_(false) {} + +TRTInt8Calibrator::TRTInt8Calibrator(const std::string& calib_data) + : batch_size_(0), + done_(true), + calib_running_(false), + batch_is_set_(false), + calibration_table_(calib_data) {} + +TRTInt8Calibrator::~TRTInt8Calibrator() { ReleaseDevBuffers(); } + +void TRTInt8Calibrator::waitAndSetDone() { + std::unique_lock lk(cond_mtx_); + while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lk); + if (!done_) { + done_ = true; + cond_.notify_all(); + } +} + +void* 
TRTInt8Calibrator::createDevBuffer(const size_t buffer_size) { + LOG(INFO) << "Alloc memory buffer which size is " << buffer_size; + void* dev_buffer = nullptr; + CHECK_EQ(cudaSuccess, cudaMalloc(&dev_buffer, buffer_size)) // NOLINT + << "Failed to alloc " << buffer_size << " bytes for calibrator."; + CHECK(dev_buffer) << "Failed to alloc " << buffer_size // NOLINT + << " bytes for calibrator."; + return dev_buffer; +} + +void TRTInt8Calibrator::ReleaseDevBuffers() { + std::unique_lock lk(cond_mtx_); + CHECK(done_) << "Calibrator could not release the device buffers " + << "since it had not been done."; + for (auto it : dev_buffers_) { CHECK_EQ(cudaSuccess, cudaFree(it.second.first)); } + dev_buffers_.clear(); +} + +// There might be more than one input for trt subgraph, +// So, we use a map to store input information. +bool TRTInt8Calibrator::setBatch(const std::vector& params) { + std::unique_lock lk(cond_mtx_); + // There is a producer and a consumer. The producer set the batch data and + // the consumer get the batch data. The size of the data pool is one. + // So, the producer has to wait for the consumer to finish processing before + // they can set the data. + while ((calib_running_ || batch_is_set_) && (!done_)) cond_.wait(lk); + // The done_ is set to true using waitAndSetDone, When all calibration data + // are processed. + if (done_) return false; + + // Sets the batch. + for (const auto& it : params) { + auto dataptr = dev_buffers_.find(it->name()); + if (dataptr == dev_buffers_.end()) { + void* buffer = createDevBuffer(it->byte_size()); + dataptr = dev_buffers_ + .emplace(it->name(), // NOLINT + std::make_pair(buffer, it->byte_size())) + .first; + // dataptr = dev_buffers_.emplace(it->name(), std::make_pair(it->data(), + // it->byte_size())).first; + } + CHECK(dataptr != dev_buffers_.end()) // NOLINT + << "Buffer '" << it->name() << "' does not exist."; + + const auto& d = dataptr->second; + CHECK_EQ(cudaSuccess, // NOLINT + cudaMemcpy(d.first, it->data(), d.second, // NOLINT + cudaMemcpyDeviceToDevice)) // NOLINT + << "Fail to cudaMemcpy for " << it->name(); + } + + batch_is_set_ = true; + cond_.notify_all(); + return true; +} + +bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, // NOLINT + int num_bindings) { + std::unique_lock lk(cond_mtx_); + // Notify finish of last round of calibration. + calib_running_ = false; + cond_.notify_all(); + + // As long as there is data in the pool, the consumer can get it. 
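+  // calib_running_ was cleared and waiters notified above, so setBatch() can copy in the next batch (or waitAndSetDone() can finish) while we block here.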
+ while (!batch_is_set_ && !done_) cond_.wait(lk); + if (done_) return false; + + // Gets the batch + for (int i = 0; i < num_bindings; i++) { + auto it = dev_buffers_.find(names[i]); + if (it == dev_buffers_.end()) { + LOG(FATAL) << "Calibration engine asked for unknown tensor name '" // NOLINT + << names[i] << "' at position " << i; + } + bindings[i] = it->second.first; + } + + batch_is_set_ = false; + calib_running_ = true; + return true; +} + +void TRTInt8Calibrator::setDone() { + std::unique_lock lk(cond_mtx_); + done_ = true; + cond_.notify_all(); +} + +bool TRTInt8Calibrator::isDone() const { + std::unique_lock lk(cond_mtx_); + return done_; +} + +const void* TRTInt8Calibrator::readCalibrationCache(size_t& length) { + if (calibration_table_.empty()) return nullptr; + length = calibration_table_.size(); + return calibration_table_.data(); +} + +void TRTInt8Calibrator::writeCalibrationCache(const void* ptr, // NOLINT + std::size_t length) { + calibration_table_ = std::string((const char*)ptr, length); +} + +static std::unordered_map + resources; + +/*static*/ TRTInt8CalibratorResource* // NOLINT +TRTInt8CalibratorResource::LookupOrCreate(const std::string& name) { + static std::mutex mutex; + std::lock_guard lock(mutex); + auto it = resources.find(name); + if (it == resources.end()) { + it = resources.emplace(name, new TRTInt8CalibratorResource).first; + } + return it->second; +} + +/*static*/ const std::unordered_map& +TRTInt8CalibratorResource::All() { + return resources; +} + +} // namespace tensorrt + +} // namespace xrt +} // namespace oneflow diff --git a/oneflow/xrt/tensorrt/trt_int8_calibrator.h b/oneflow/xrt/tensorrt/trt_int8_calibrator.h new file mode 100644 index 0000000000000000000000000000000000000000..d4b363a053a287f6ff493e620e5dbb33a54c940d --- /dev/null +++ b/oneflow/xrt/tensorrt/trt_int8_calibrator.h @@ -0,0 +1,107 @@ +#ifndef ONEFLOW_XRT_TENSORRT_TRT_INT8_CALIBRATOR_H_ +#define ONEFLOW_XRT_TENSORRT_TRT_INT8_CALIBRATOR_H_ + +#include +#include +#include +#include + +#include "NvInfer.h" +#include "oneflow/xrt/parameter.h" +#include "oneflow/xrt/tensorrt/trt_unique_ptr.h" + +namespace oneflow { +namespace xrt { + +namespace tensorrt { + +// Refered from tensorflow +class TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator2 { + // class TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { + public: + // Construct a calibrator for future calibration. + TRTInt8Calibrator(); + + // Construct a finalized calibrator where we don't need to run calibration any + // more, as the calibration data is provided. + TRTInt8Calibrator(const std::string& calibration_data); + + ~TRTInt8Calibrator(); + + int getBatchSize() const override; + + bool getBatch(void* bindings[], const char* names[], // NOLINT + int num_bindings) override; + + void setBatchSize(const int batch_size); + + // Feed calibration data to the calibrator, and return true if the data is + // accepted. Return false if the calibrator has been terminated. + bool setBatch(const std::vector& params); + + // Wait until the last batch is consumed by the calibrator and set done. + void waitAndSetDone(); + + // Notify that calibration is done and future batches provided by setBatch() + // will be ignored. + void setDone(); + + bool isDone() const; + + void* createDevBuffer(const size_t buffer_size); + + void ReleaseDevBuffers(); + + // If not null, calibration is skipped. 
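+  // TensorRT calls readCalibrationCache() before running calibration; returning a non-null pointer (a previously generated table) makes it reuse that table and skip the calibration run.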
+ const void* readCalibrationCache(std::size_t& length) override; + + void writeCalibrationCache(const void* ptr, std::size_t length) override; + + const std::string& getCalibrationTableAsString() { // NOLINT + return calibration_table_; + } + + private: + int batch_size_; + + // mutex for condition_variable + mutable std::mutex cond_mtx_; + + // condition variable to implement producer-consumer queue for calibration + std::condition_variable cond_; + + // Is calibration finished? + bool done_; + + // Map to keep tensorrt input buffers and sizes keyed with buffer names + std::unordered_map> dev_buffers_; + + bool calib_running_; + bool batch_is_set_; + + std::string calibration_table_; +}; + +struct TRTInt8CalibratorResource { + public: + static TRTInt8CalibratorResource* LookupOrCreate(const std::string& name); + + static const std::unordered_map& + All(); + + // Individual mutex + mutable std::mutex mutex_; + + std::shared_ptr calibrator_; + std::shared_ptr thread_; + + nv::unique_ptr engine_; +}; + +} // namespace tensorrt + +} // namespace xrt +} // namespace oneflow + +#endif // ONEFLOW_XRT_TENSORRT_TRT_INT8_CALIBRATOR_H_ diff --git a/oneflow/xrt/tensorrt/trt_logger.cpp b/oneflow/xrt/tensorrt/trt_logger.cpp index 585062fcc3632159a7069289467eebdbccac2dca..daf2f1540e07b3312308e2cbea6b739fbb12f642 100644 --- a/oneflow/xrt/tensorrt/trt_logger.cpp +++ b/oneflow/xrt/tensorrt/trt_logger.cpp @@ -21,7 +21,7 @@ void Logger::log(ILogger::Severity severity, const char* msg) { break; } case ILogger::Severity::kERROR: { - LOG(ERROR) << name_ << ": " << msg; + LOG(FATAL) << name_ << ": " << msg; break; } case ILogger::Severity::kINTERNAL_ERROR: { diff --git a/oneflow/xrt/xla/xla_executable.h b/oneflow/xrt/xla/xla_executable.h index 26f47f2c57b831110e45f35d8cb1bd16d71ce724..362074ce7d96ba0235ba3e54ef368344e1716896 100644 --- a/oneflow/xrt/xla/xla_executable.h +++ b/oneflow/xrt/xla/xla_executable.h @@ -10,9 +10,10 @@ namespace mola { class XlaExecutable : public Executable { public: - XlaExecutable(const XrtDevice &device, const std::vector &input_shapes, + XlaExecutable(const std::string &name, const XrtDevice &device, + const std::vector &input_shapes, const xla::Shape &output_shape, std::unique_ptr &&executable) - : Executable(XrtEngine::XLA), + : Executable(name, XrtEngine::XLA), device_(device), input_shapes_(input_shapes), output_shape_(output_shape), diff --git a/oneflow/xrt/xla/xla_graph_compiler.cpp b/oneflow/xrt/xla/xla_graph_compiler.cpp index 60629dce303743f22b55a83eeace9d866f06b679..d2f17907460426af4b94305fcaa8d68cf8a1b19f 100644 --- a/oneflow/xrt/xla/xla_graph_compiler.cpp +++ b/oneflow/xrt/xla/xla_graph_compiler.cpp @@ -124,7 +124,7 @@ std::shared_ptr XlaGraphCompiler::BuildExecutable( build_options.set_result_layout(xla_output_shape); MOLA_CHECK_AND_ASSIGN(auto executable, client->Compile(computation, argument_layouts, build_options)); - return std::make_shared(this->device_, xla_input_shapes, xla_output_shape, + return std::make_shared(builder_->name(), this->device_, xla_input_shapes, xla_output_shape, std::move(executable)); }
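For reference, a minimal sketch of the offline int8 path these changes enable. It assumes a calibration table was already written to `./int8_calibration` with `flow.tensorrt.write_int8_calibration()` (as in the README example above); the input shape and the `leaky_relu` body are placeholders for a real model.

```python
import numpy as np
import oneflow as flow

config = flow.function_config()
config.use_tensorrt()
config.tensorrt.use_int8()
# Point TensorRT at the previously generated calibration table so the
# int8 engine is built directly, without an online calibration phase.
config.tensorrt.int8_calibration("./int8_calibration")


@flow.global_function(config)
def InferJob(x=flow.FixedTensorDef((1, 3, 224, 224), dtype=flow.float32)):
    # placeholder network; replace with the real model
    return flow.nn.leaky_relu(x, alpha=0.1)


# The int8 engine is built from the cached table on the first call.
out = InferJob(np.random.random((1, 3, 224, 224)).astype(np.float32)).get()
print(out.ndarray().shape)
```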