提交 1f2e88a9 编写于 作者: R Raman Sarokin 提交者: TensorFlower Gardener

ConvolutionTransposed3x3 generation changed to use storage type properties...

ConvolutionTransposed3x3 generation changed to use storage type properties instead of specific storage types.

PiperOrigin-RevId: 339967056
Change-Id: I5f6e192e21cbabf33bd152f91f4ab664b139fd14
上级 64edb2fb
......@@ -1356,6 +1356,7 @@ test_suite(
"conv_buffer_1x1_test",
"conv_constants_test",
"conv_powervr_test",
"convolution_transposed_3x3_test",
"convolution_transposed_3x3_thin_test",
"convolution_transposed_4x4_test",
"convolution_transposed_test",
......
......@@ -86,10 +86,6 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode(
args_.AddInt("padding_x");
args_.AddInt("padding_y");
const auto src_tensor_type = op_def.src_tensors[0].storage_type;
const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER ||
src_tensor_type == TensorStorageType::IMAGE_BUFFER;
const bool need_local_mem =
weights_upload_type ==
ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_BY_THREADS ||
......@@ -170,26 +166,35 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode(
ConvolutionTransposed3x3::WeightsUploadType::LOCAL_MEM_BY_THREADS) {
c += " int local_id = (int)(get_local_id(1) * 8 + get_local_id(0));\n";
}
if (manual_clamp) {
const std::string next_x = "SRC_X + " + pixel_stride;
const std::string next_x = "SRC_X + " + pixel_stride;
if (!src_desc.SupportsZeroClamp(Axis::WIDTH)) {
c += " bool in_x0 = SRC_X >= 0 && SRC_X < args.src_tensor.Width();\n";
c += " bool in_x1 = " + next_x + " >= 0 && " + next_x +
" < args.src_tensor.Width();\n";
}
if (!src_desc.SupportsZeroClamp(Axis::HEIGHT)) {
c += " bool in_y0 = SRC_Y >= 0 && SRC_Y < args.src_tensor.Height();\n";
c += " bool in_y1 = SRC_Y + 1 >= 0 && SRC_Y + 1 < "
"args.src_tensor.Height();\n";
if (src_tensor_type == TensorStorageType::BUFFER) {
c += " int xc0 = clamp(SRC_X, 0, args.src_tensor.Width() - 1);\n";
c += " int xc1 = clamp(" + next_x +
", 0, args.src_tensor.Width() - 1);\n";
c += " int yc0 = clamp(SRC_Y, 0, args.src_tensor.Height() - 1);\n";
c += " int yc1 = clamp(SRC_Y + 1, 0, args.src_tensor.Height() - 1);\n";
c += " args.src_tensor.GetAddress(addr_0, xc0, yc0, 0);\n";
c += " args.src_tensor.GetAddress(addr_1, xc1, yc0, 0);\n";
c += " args.src_tensor.GetAddress(addr_2, xc0, yc1, 0);\n";
c += " args.src_tensor.GetAddress(addr_3, xc1, yc1, 0);\n";
c += " int dz = args.src_tensor.SliceStride();\n";
} else { // TensorStorageType::IMAGE_BUFFER
}
auto generate_check = [&](int x, int y) {
std::string check;
const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT};
const std::vector<std::string> names{"in_x" + std::to_string(x),
"in_y" + std::to_string(y)};
for (int i = 0; i < axes.size(); ++i) {
const auto& axis = axes[i];
if (src_desc.HasAxis(axis) && !src_desc.SupportsZeroClamp(axis)) {
if (!check.empty()) {
check += " && ";
}
check += names[i];
}
}
return check;
};
if (src_desc.IsLinear()) {
if (src_desc.ReturnsZeroForNegOneRead()) {
c += " args.src_tensor.GetAddress(addr_0, SRC_X, SRC_Y, 0);\n";
c += " args.src_tensor.GetAddress(addr_1," + next_x + ", SRC_Y, 0);\n";
c += " args.src_tensor.GetAddress(addr_2, SRC_X, SRC_Y + 1, 0);\n";
......@@ -206,13 +211,24 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode(
"in_y1));\n";
c += " int dz_3 = select(0, args.src_tensor.SliceStride(), (in_x1 && "
"in_y1));\n";
} else {
c += " int xc0 = clamp(SRC_X, 0, args.src_tensor.Width() - 1);\n";
c += " int xc1 = clamp(" + next_x +
", 0, args.src_tensor.Width() - 1);\n";
c += " int yc0 = clamp(SRC_Y, 0, args.src_tensor.Height() - 1);\n";
c += " int yc1 = clamp(SRC_Y + 1, 0, args.src_tensor.Height() - 1);\n";
c += " args.src_tensor.GetAddress(addr_0, xc0, yc0, 0);\n";
c += " args.src_tensor.GetAddress(addr_1, xc1, yc0, 0);\n";
c += " args.src_tensor.GetAddress(addr_2, xc0, yc1, 0);\n";
c += " args.src_tensor.GetAddress(addr_3, xc1, yc1, 0);\n";
c += " int dz = args.src_tensor.SliceStride();\n";
}
}
auto read_src = [&](int x, int y) {
if (manual_clamp) {
if (src_desc.IsLinear()) {
const std::string id = std::to_string(y * 2 + x);
const std::string addr = "addr_" + std::to_string(y * 2 + x);
if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) {
if (src_desc.ReturnsZeroForNegOneRead()) {
return "args.src_tensor.Read(" + addr + "); " + addr + " += dz_" + id +
";\n";
} else {
......@@ -221,8 +237,13 @@ std::string ConvolutionTransposed3x3::GenerateConvolutionTransposedCode(
addr + " += dz;\n";
}
} else {
std::string check = generate_check(x, y);
if (!check.empty()) {
check = " * (FLT)(" + check + ")";
}
return "args.src_tensor.Read(SRC_X + " + std::to_string(x) + "*" +
pixel_stride + ", SRC_Y + " + std::to_string(y) + ", s);\n";
pixel_stride + ", SRC_Y + " + std::to_string(y) + ", s)" + check +
";\n";
}
};
const int padding_x_rem = abs(padding.x) % 2;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册