Commit b37b0d00 authored by Raman Sarokin, committed by TensorFlower Gardener

DepthwiseConvolution generation changed to use storage type properties instead of specific storage types.

PiperOrigin-RevId: 339964553
Change-Id: I2cded9c306a40b136002c08e610daca2d75e1758
Parent 50ec85cd
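In essence, the generator now asks the source tensor descriptor for per-axis properties (HasAxis, SupportsZeroClamp) and emits a manual bounds check only for axes that cannot zero-clamp out-of-bounds reads, instead of branching on a manual_clamp flag computed from TensorStorageType::BUFFER / IMAGE_BUFFER. Below is a minimal, self-contained sketch of that pattern; SrcDescStub, GenerateCheck, is_buffer_like and has_depth are hypothetical stand-ins, not the real TFLite GPU classes.

// Minimal sketch of the property-based check generation (hypothetical types).
#include <iostream>
#include <string>
#include <vector>

enum class Axis { WIDTH, HEIGHT, DEPTH };

// Stand-in for the source tensor descriptor: buffer-like storages cannot
// zero-clamp out-of-bounds reads, texture-like ones can.
struct SrcDescStub {
  bool is_buffer_like = true;
  bool has_depth = false;
  bool HasAxis(Axis axis) const { return axis != Axis::DEPTH || has_depth; }
  bool SupportsZeroClamp(Axis /*axis*/) const { return !is_buffer_like; }
};

// Builds a condition such as "!outside_x && !outside_y", including only the
// axes that actually require a manual bounds check for this storage.
std::string GenerateCheck(const SrcDescStub& src_desc) {
  const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH};
  const std::vector<std::string> names{"outside_x", "outside_y", "outside_z"};
  std::string check;
  for (size_t i = 0; i < axes.size(); ++i) {
    if (src_desc.HasAxis(axes[i]) && !src_desc.SupportsZeroClamp(axes[i])) {
      if (!check.empty()) check += " && ";
      check += "!" + names[i];
    }
  }
  return check;  // empty string => no "if (...)" guard is emitted at all
}

int main() {
  SrcDescStub buffer_like;   // e.g. BUFFER / IMAGE_BUFFER
  SrcDescStub texture_like;  // e.g. a texture storage with zero clamping
  texture_like.is_buffer_like = false;
  std::cout << "buffer-like:  \"" << GenerateCheck(buffer_like) << "\"\n";
  std::cout << "texture-like: \"" << GenerateCheck(texture_like) << "\"\n";
}

With a buffer-like source this yields "!outside_x && !outside_y" (plus "!outside_z" when a DEPTH axis is present), while a storage that supports zero clamping produces an empty check and no guard, which is what lets the diff below collapse the old manual_clamp / texture two-branch structure into one code path.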
@@ -86,13 +86,8 @@ std::string GenerateDepthwiseConvolutionCode(
}
op->AddDstTensor("dst_tensor", dst_desc);
const auto src_tensor_type = op_def.src_tensors[0].storage_type;
std::string c = GetCommonDefines(op_def.precision);
const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER ||
src_tensor_type == TensorStorageType::IMAGE_BUFFER;
c += "__kernel void main_function(\n";
c += "$0) {\n";
c += " int X = get_global_id(0);\n";
@@ -142,84 +137,91 @@ std::string GenerateDepthwiseConvolutionCode(
std::string kernel_size_z =
dynamic_weights ? "args.weights.Depth()" : "args.kernel_size_z";
std::string flat_coords = "x_c, y_c";
if (manual_clamp) {
std::string check = "!outside_x && !outside_y";
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
check += " && !outside_z";
flat_coords += ", z_c";
c += " for (int kz = 0; kz < " + kernel_size_z + "; ++kz) {\n";
c += " int z_c = z_offseted + kz * args.dilation_z;\n";
c += " bool outside_z = z_c < 0 || z_c >= args.src_tensor.Depth();\n";
}
c += " for (int ky = 0; ky < " + kernel_size_y + "; ++ky) {\n";
c += " int y_c = y_offseted + ky * args.dilation_y;\n";
c += " bool outside_y = y_c < 0 || y_c >= args.src_tensor.Height();\n";
c += " for (int kx = 0; kx < " + kernel_size_x + "; ++kx) {\n";
const std::string dilation_x =
op_def.IsBatchSupported() ? "args.dilation_x * args.src_tensor.Batch()"
: "args.dilation_x";
c += " int x_c = x_offseted + kx * " + dilation_x + ";\n";
c += " bool outside_x = x_c < 0 || x_c >= args.src_tensor.Width();\n";
c += " if (" + check + ") {\n";
if (dynamic_weights) {
c += " FLT4 f = args.weights.Read(kx, ky, S);\n";
} else {
if (weights_are_buffer) {
c += " FLT4 f = args.weights.Read(fx_c);\n";
} else {
c += " FLT4 f = args.weights.Read(fx_c, S);\n";
auto generate_check = [&]() {
std::string check;
const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH};
const std::vector<std::string> names{"outside_x", "outside_y", "outside_z"};
for (int i = 0; i < axes.size(); ++i) {
const auto& axis = axes[i];
if (src_desc.HasAxis(axis) && !src_desc.SupportsZeroClamp(axis)) {
if (!check.empty()) {
check += " && ";
}
check += "!" + names[i];
}
}
c += GetSrcValue(channel_multiplier, flat_coords);
c += " r += TO_ACCUM_TYPE(src_final * f);\n";
c += " };\n";
if (!dynamic_weights) {
c += " fx_c++;\n";
}
c += " }\n";
c += " }\n";
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
c += " }\n";
}
} else { // Texture types with ZERO clamping
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
flat_coords += ", z_c";
c += " for (int kz = 0; kz < " + kernel_size_z + "; ++kz) {\n";
c += " int z_c = z_offseted + kz * args.dilation_z;\n";
if (src_tensor_type !=
TensorStorageType::TEXTURE_3D) { // Only TEXTURE_3D supports clamping
// in DEPTH dimension
c += " if (z_c < 0 || z_c >= args.src_tensor.Depth()) {\n";
c += " fx_c += args.kernel_size_y * args.kernel_size_x;\n";
c += " continue;\n";
c += " }\n";
return check;
};
auto generate_coords = [&]() {
std::string check;
const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH};
const std::vector<std::string> names{"x_c", "y_c", "z_c"};
for (int i = 0; i < axes.size(); ++i) {
const auto& axis = axes[i];
if (src_desc.HasAxis(axis)) {
if (!check.empty()) {
check += ", ";
}
check += names[i];
}
}
return check;
};
const std::string check = generate_check();
const std::string coords = generate_coords();
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
c += " for (int kz = 0; kz < " + kernel_size_z + "; ++kz) {\n";
c += " int z_c = z_offseted + kz * args.dilation_z;\n";
if (!src_desc.SupportsZeroClamp(Axis::DEPTH)) {
c += " bool outside_z = z_c < 0 || z_c >= args.src_tensor.Depth();\n";
}
}
if (op_def.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
c += " for (int ky = 0; ky < " + kernel_size_y + "; ++ky) {\n";
c += " int y_c = y_offseted + ky * args.dilation_y;\n";
c += " for (int kx = 0; kx < " + kernel_size_x + "; ++kx) {\n";
if (!src_desc.SupportsZeroClamp(Axis::HEIGHT)) {
c += " bool outside_y = y_c < 0 || y_c >= args.src_tensor.Height();\n";
}
}
if (op_def.dst_tensors[0].HasAxis(Axis::WIDTH)) {
c += " for (int kx = 0; kx < " + kernel_size_x + "; ++kx) {\n";
const std::string dilation_x =
op_def.IsBatchSupported() ? "args.dilation_x * args.src_tensor.Batch()"
: "args.dilation_x";
c += " int x_c = x_offseted + kx * " + dilation_x + ";\n";
c += GetSrcValue(channel_multiplier, flat_coords);
if (dynamic_weights) {
c += " FLT4 f = args.weights.Read(kx, ky, S);\n";
c += " int x_c = x_offseted + kx * " + dilation_x + ";\n";
if (!src_desc.SupportsZeroClamp(Axis::WIDTH)) {
c += " bool outside_x = x_c < 0 || x_c >= args.src_tensor.Width();\n";
}
}
if (!check.empty()) {
c += " if (" + check + ") {\n";
}
if (dynamic_weights) {
c += " FLT4 f = args.weights.Read(kx, ky, S);\n";
} else {
if (weights_are_buffer) {
c += " FLT4 f = args.weights.Read(fx_c);\n";
} else {
if (weights_are_buffer) {
c += " FLT4 f = args.weights.Read(fx_c);\n";
} else {
c += " FLT4 f = args.weights.Read(fx_c, S);\n";
}
c += " fx_c++;\n";
c += " FLT4 f = args.weights.Read(fx_c, S);\n";
}
c += " r += TO_ACCUM_TYPE(src_final * f);\n";
}
c += GetSrcValue(channel_multiplier, coords);
c += " r += TO_ACCUM_TYPE(src_final * f);\n";
if (!check.empty()) {
c += " }\n";
}
if (!dynamic_weights) {
c += " fx_c++;\n";
}
if (op_def.dst_tensors[0].HasAxis(Axis::WIDTH)) {
c += " }\n";
}
if (op_def.dst_tensors[0].HasAxis(Axis::HEIGHT)) {
c += " }\n";
}
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
c += " }\n";
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
c += " }\n";
}
}
c += " FLT4 res0 = TO_FLT4(r) + args.biases.Read(S);\n";
if (op_def.dst_tensors[0].HasAxis(Axis::DEPTH)) {
@@ -228,7 +230,6 @@ std::string GenerateDepthwiseConvolutionCode(
c += " args.dst_tensor.Write(res0, X, Y, S);\n";
}
c += "}\n";
return c;
}
} // namespace