提交 086bd64e 编写于 作者: R Raman Sarokin 提交者: TensorFlower Gardener

ConvConstants generation changed to use storage type properties instead of specific storage types.

ConvConstants extended for better mapping on thin dst tensors.

PiperOrigin-RevId: 339971199
Change-Id: Ia14be12361ec81c70c0522baa9516c8cb28ba7ff
上级 1f2e88a9
......@@ -46,9 +46,57 @@ int GetOptimalMaxConstantSize(const DeviceInfo& info) {
// Builds the kernel-source statements that accumulate one source value `src`
// into the accumulator `dst`, reading weights from the `constants` array
// starting at index `const_mem_offset`.
//
// * use_dot_conv == true: emits one statement per destination component,
//     dst.<c> += dot(src, constants[offset + c]<src swizzle>);
//   and references `dst_size` consecutive constants.
// * use_dot_conv == false: references `src_size` consecutive constants, one per
//   source component. For F32_F16 precision the products are summed inside a
//   single convert_float<N>(...) statement; otherwise one "+=" statement is
//   emitted per source component.
//
// src_size and dst_size must be in [1, 4]: they index 4-entry swizzle tables.
std::string GenerateConv(int src_size, int dst_size, bool use_dot_conv,
                         int const_mem_offset, CalculationsPrecision precision,
                         const std::string& dst, const std::string& src) {
  static const char* const kLanes[] = {".x", ".y", ".z", ".w"};
  // Swizzle selecting the first N components; a full 4-vector needs none.
  static const char* const kFirstN[] = {".x", ".xy", ".xyz", ""};

  // Source text of the k-th weight vector used by this call.
  const auto weight_at = [const_mem_offset](int k) -> std::string {
    return "constants[" + std::to_string(const_mem_offset + k) + "]";
  };

  if (use_dot_conv) {
    const std::string src_swizzle = kFirstN[src_size - 1];
    std::string code;
    for (int lane = 0; lane < dst_size; ++lane) {
      code += " " + dst + kLanes[lane] + " += dot(" + src + ", " +
              weight_at(lane) + src_swizzle + ");\n";
    }
    return code;
  }

  const std::string dst_swizzle = kFirstN[dst_size - 1];
  // A scalar source is used as is; a vector source is read lane by lane.
  const auto src_lane = [&](int ch) -> std::string {
    if (src_size == 1) {
      return src;
    }
    return src + kLanes[ch];
  };

  if (precision == CalculationsPrecision::F32_F16) {
    // Sum all products first, then convert once to the accumulator type.
    std::string sum;
    for (int ch = 0; ch < src_size; ++ch) {
      if (ch > 0) {
        sum += " + ";
      }
      sum += src_lane(ch) + " * " + weight_at(ch) + dst_swizzle;
    }
    const std::string width = dst_size == 1 ? "" : std::to_string(dst_size);
    return " " + dst + dst_swizzle + " += convert_float" + width + "(" + sum +
           ");\n";
  }

  std::string code;
  for (int ch = 0; ch < src_size; ++ch) {
    code += " " + dst + dst_swizzle + " += " + src_lane(ch) + " * " +
            weight_at(ch) + dst_swizzle + ";\n";
  }
  return code;
}
std::string GenerateConvolutionConstantCode(const OperationDef& op_def,
const OHWI& weights_shape,
bool stride_correction,
bool use_dot_conv,
GPUOperation* op) {
auto src_desc = op_def.src_tensors[0];
......@@ -69,48 +117,6 @@ std::string GenerateConvolutionConstantCode(const OperationDef& op_def,
const std::string kOutZ = std::to_string(out_z);
const int src_depth = DivideRoundUp(weights_shape.i, 4);
const auto src_tensor_type = op_def.src_tensors[0].storage_type;
const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER ||
src_tensor_type == TensorStorageType::IMAGE_BUFFER;
switch (op_def.precision) {
case CalculationsPrecision::F32:
case CalculationsPrecision::F16:
c += "#define CONV4(R, SRC, F, i) \\\n";
c += " R += SRC.x * F[i + 0]; \\\n";
c += " R += SRC.y * F[i + 1]; \\\n";
c += " R += SRC.z * F[i + 2]; \\\n";
c += " R += SRC.w * F[i + 3]; \n";
c += "#define CONV3(R, SRC, F, i) \\\n";
c += " R += SRC.x * F[i + 0]; \\\n";
c += " R += SRC.y * F[i + 1]; \\\n";
c += " R += SRC.z * F[i + 2]; \n";
c += "#define CONV2(R, SRC, F, i) \\\n";
c += " R += SRC.x * F[i + 0]; \\\n";
c += " R += SRC.y * F[i + 1]; \n";
c += "#define CONV1(R, SRC, F, i) \\\n";
c += " R += SRC * F[i + 0]; \n";
case CalculationsPrecision::F32_F16:
c += "#define CONV4(R, SRC, F, i) \\\n";
c += " R += convert_float4(SRC.x * F[i + 0] + SRC.y * F[i + 1]";
c += " + SRC.z * F[i + 2] + SRC.w * F[i + 3]);\n";
c += "#define CONV3(R, SRC, F, i) \\\n";
c += " R += convert_float4(SRC.x * F[i + 0] + SRC.y * F[i + 1]";
c += " + SRC.z * F[i + 2]);\n";
c += "#define CONV2(R, SRC, F, i) \\\n";
c += " R += convert_float4(SRC.x * F[i + 0] + SRC.y * F[i + 1]);\n";
c += "#define CONV1(R, SRC, F, i) \\\n";
c += " R += convert_float4(SRC * F[i + 0]);\n";
const std::string postfixes[] = {".x", ".xy", ".xyz", ""};
c += "__kernel void main_function(\n";
......@@ -133,23 +139,40 @@ std::string GenerateConvolutionConstantCode(const OperationDef& op_def,
c += " int start_y = Y * args.stride_y + args.padding_y;\n";
c += " ACCUM_FLT4 r[" + kOutZ + "];\n";
c += " for (int i = 0; i < " + kOutZ + "; ++i) {\n";
c += " r[i] = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
c += " }\n";
c += " __constant FLT4* constants = args.weights.GetPtr();\n";
for (int i = 0; i < out_z; ++i) {
c += " ACCUM_FLT4 r" + std::to_string(i) +
" = (ACCUM_FLT4)(0.0f, 0.0f, 0.0f, 0.0f);\n";
auto generate_check = [&]() {
std::string check;
const std::vector<Axis> axes{Axis::WIDTH, Axis::HEIGHT, Axis::DEPTH};
const std::vector<std::string> names{"x_out", "y_out", "z_out"};
for (int i = 0; i < axes.size(); ++i) {
const auto& axis = axes[i];
if (src_desc.HasAxis(axis) && !src_desc.SupportsZeroClamp(axis)) {
if (!check.empty()) {
check += " || ";
check += names[i];
return check;
const std::string check = generate_check();
int filters_counter = 0;
for (int s = 0; s < src_depth; ++s) {
const int ch_count = std::min(4, weights_shape.i - s * 4);
const std::string s_conv = "CONV" + std::to_string(ch_count);
const std::string s_count = ch_count == 1 ? "" : std::to_string(ch_count);
const int src_ch_count = std::min(4, weights_shape.i - s * 4);
const std::string s_count =
src_ch_count == 1 ? "" : std::to_string(src_ch_count);
const std::string s_type = absl::StrCat("FLT", s_count);
const std::string s_postfix = postfixes[ch_count - 1];
const std::string s_postfix = postfixes[src_ch_count - 1];
const std::string dilation_x =
op_def.IsBatchSupported() ? "args.dilation_x * args.src_tensor.Batch()"
: "args.dilation_x";
for (int ky = 0; ky < weights_shape.h; ++ky) {
std::string s_y = absl::StrCat("(start_y + ", ky, " * args.dilation_y)");
if (manual_clamp) {
if (!src_desc.SupportsZeroClamp(Axis::HEIGHT)) {
c += " {\n";
c += " bool y_out = " + s_y + " < 0 || " + s_y +
" >= args.src_tensor.Height();\n";
......@@ -158,25 +181,28 @@ std::string GenerateConvolutionConstantCode(const OperationDef& op_def,
c += " {\n";
std::string s_x =
absl::StrCat("(start_x + ", kx, " * " + dilation_x + ")");
if (manual_clamp) {
c += " bool x_out = " + s_x + "< 0 || " + s_x +
if (!src_desc.SupportsZeroClamp(Axis::WIDTH)) {
c += " bool x_out = " + s_x + " < 0 || " + s_x +
">= args.src_tensor.Width();\n";
c += " " + s_type + " src = x_out || y_out ?";
c += "(" + s_type + ")(0.0) : args.src_tensor.Read(" + s_x + ", " +
if (check.empty()) {
c += " " + s_type + " src = args.src_tensor.Read(" + s_x + ", " +
s_y + ", " + std::to_string(s) + ")" + s_postfix + ";\n";
} else {
c += " " + s_type + " src = args.src_tensor.Read(" + s_x + ", " +
c += " " + s_type + " src = x_out || y_out ? ";
c += "(" + s_type + ")(0.0) : args.src_tensor.Read(" + s_x + ", " +
s_y + ", " + std::to_string(s) + ")" + s_postfix + ";\n";
for (int d = 0; d < out_z; ++d) {
c += " " + s_conv + "(r[" + std::to_string(d) +
"], src, args.weigths.GetPtr(),";
c += " " + std::to_string(filters_counter) + ");\n";
filters_counter += ch_count;
const int dst_ch_count = std::min(4, weights_shape.o - d * 4);
c += GenerateConv(src_ch_count, dst_ch_count, use_dot_conv,
filters_counter, op_def.precision,
"r" + std::to_string(d), "src");
filters_counter += use_dot_conv ? dst_ch_count : src_ch_count;
c += " }\n";
if (manual_clamp) {
if (!src_desc.SupportsZeroClamp(Axis::HEIGHT)) {
c += " }\n";
......@@ -184,15 +210,31 @@ std::string GenerateConvolutionConstantCode(const OperationDef& op_def,
for (int i = 0; i < out_z; ++i) {
std::string s_i = std::to_string(i);
c += " {\n";
c += " FLT4 res = TO_FLT4(r[" + s_i + "]) + args.biases.Read(" + s_i +
c += " FLT4 res = TO_FLT4(r" + s_i + ") + args.biases.Read(" + s_i +
c += " args.dst_tensor.Write(res, X, Y, " + s_i + ");\n";
c += " args.dst_tensor.Write(res, X, Y, " + s_i + ");\n";
c += " }\n";
c += "}\n";
return c;
// Returns true when the "dot" weights layout should be used for a convolution
// with the given source/destination channel counts.
//
// The two products compared at the end match, up to a common factor, the
// padded weight counts that the visible callers allocate for each layout
// (dst_channels * src_depth * 4 for dot, src_channels * dst_depth * 4
// otherwise), so the dot layout is chosen only when it is strictly smaller.
bool IsDotConvBetter(int src_channels, int dst_channels) {
  const bool dst_is_aligned = dst_channels % 4 == 0;
  const bool src_is_aligned = src_channels % 4 == 0;
  if (dst_is_aligned || src_is_aligned) {
    // Aligned dst: never use dot. Unaligned dst with aligned src: always.
    return !dst_is_aligned;
  }
  // Neither channel count is a multiple of 4: compare the padded sizes.
  const int dot_cost = dst_channels * DivideRoundUp(src_channels, 4);
  const int plain_cost = src_channels * DivideRoundUp(dst_channels, 4);
  return dot_cost < plain_cost;
}
} // namespace
bool IsConvConstantsSupported(const DeviceInfo& device_info,
......@@ -205,9 +247,14 @@ bool IsConvConstantsSupported(const DeviceInfo& device_info,
return false;
const bool use_dot_conv =
IsDotConvBetter(attr.weights.shape.i, attr.weights.shape.o);
const auto& w_shape = attr.weights.shape;
const int dst_channels = AlignByN(w_shape.o, 4);
const int filters_count = w_shape.i * dst_channels * w_shape.h * w_shape.w;
const int src_depth = DivideRoundUp(w_shape.i, 4);
const int dst_depth = DivideRoundUp(w_shape.o, 4);
const int aligned_ch_count =
use_dot_conv ? w_shape.o * src_depth * 4 : w_shape.i * dst_depth * 4;
const int filters_count = aligned_ch_count * w_shape.h * w_shape.w;
const int float_size = definition.precision == CalculationsPrecision::F32
? sizeof(float)
: sizeof(half);
......@@ -220,8 +267,11 @@ bool IsConvConstantsSupported(const DeviceInfo& device_info,
GPUOperation CreateConvConstants(const DeviceInfo& device_info,
const OperationDef& definition,
const Convolution2DAttributes& attr) {
const bool use_dot_conv =
IsDotConvBetter(attr.weights.shape.i, attr.weights.shape.o);
GPUOperation op(definition);
UploadWeightsForConvConstants(attr.weights, definition.precision, &op);
UploadWeightsForConvConstants(attr.weights, definition.precision,
use_dot_conv, &op);
op.args_.AddInt("stride_x", attr.strides.w);
op.args_.AddInt("stride_y", attr.strides.h);
op.args_.AddInt("padding_x", -attr.padding.prepended.w);
......@@ -232,8 +282,9 @@ GPUOperation CreateConvConstants(const DeviceInfo& device_info,
const bool stride_correction =
definition.IsBatchSupported() && attr.strides.w != 1;
op.code_ = GenerateConvolutionConstantCode(definition, attr.weights.shape,
stride_correction, &op);
op.code_ = GenerateConvolutionConstantCode(
definition, attr.weights.shape, stride_correction, use_dot_conv, &op);
if (definition.precision == CalculationsPrecision::F16 &&
device_info.IsAdreno3xx()) {
......@@ -54,20 +54,51 @@ void RearrangeWeightsForConvConstants(
if (s_ch < weights.shape.i && d_ch < weights.shape.o) {
const int f_index =
weights.shape.LinearIndex({d_ch, y, x, s_ch});
filters[i][j] = weights.data[f_index];
filters[j][i] = weights.data[f_index];
} else {
filters[i][j] = 0.0f;
filters[j][i] = 0.0f;
T filters_new[4];
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
filters_new[i][j] = filters[j][i];
for (int i = 0; i < channels_count; ++i) {
dst[counter++] = filters[i];
// Repacks OHWI convolution weights into the layout used by the "dot" variant
// of the constants kernel.
//
// Iteration order is: source slice, kernel row, kernel column, destination
// slice. For each such position one element of `dst` is written per real
// destination channel of the slice (at most 4; padded destination channels are
// not stored). Each written element holds that output channel's weights for
// the 4 input channels of the source slice, zero-filled where the input
// channel index is past weights.shape.i.
//
// T must be indexable with operator[] for components 0..3 (presumably
// float4/half4 -- confirm against the callers). `dst` must be large enough to
// hold every written element; no bounds check is performed here.
template <DataType S, typename T>
void RearrangeWeightsForConvConstantsDot(
    const tflite::gpu::Tensor<OHWI, S>& weights, absl::Span<T> dst) {
  const int dst_depth = DivideRoundUp(weights.shape.o, 4);
  const int src_depth = DivideRoundUp(weights.shape.i, 4);
  const int kernel_x = weights.shape.w;
  const int kernel_y = weights.shape.h;

  int counter = 0;
  for (int s = 0; s < src_depth; ++s) {
    for (int y = 0; y < kernel_y; ++y) {
      for (int x = 0; x < kernel_x; ++x) {
        for (int d = 0; d < dst_depth; ++d) {
          // Number of real (non-padded) output channels in this dst slice.
          const int channels_count = std::min(4, weights.shape.o - d * 4);
          T filters[4];
          for (int j = 0; j < channels_count; ++j) {
            for (int i = 0; i < 4; ++i) {
              const int s_ch = s * 4 + i;
              const int d_ch = d * 4 + j;
              if (s_ch < weights.shape.i && d_ch < weights.shape.o) {
                const int f_index =
                    weights.shape.LinearIndex({d_ch, y, x, s_ch});
                filters[j][i] = weights.data[f_index];
              } else {
                filters[j][i] = 0.0f;
              }
            }
          }
          // Only rows 0..channels_count-1 of `filters` were initialized above,
          // and only those are emitted.
          for (int i = 0; i < channels_count; ++i) {
            dst[counter++] = filters[i];
          }
        }
      }
    }
  }
}
......@@ -78,14 +109,17 @@ void RearrangeWeightsForConvConstants(
template <DataType T>
void UploadWeightsForConvConstants(const tflite::gpu::Tensor<OHWI, T>& weights,
CalculationsPrecision precision,
GPUOperation* op) {
bool use_dot_conv, GPUOperation* op) {
const int src_depth = DivideRoundUp(weights.shape.i, 4);
const int dst_depth = DivideRoundUp(weights.shape.o, 4);
const int kernel_x = weights.shape.w;
const int kernel_y = weights.shape.h;
const bool f32_weights = precision == CalculationsPrecision::F32;
const int float_size = f32_weights ? 4 : 2;
const int float_count = weights.shape.i * dst_depth * 4 * kernel_x * kernel_y;
const int aligned_ch_count = use_dot_conv ? weights.shape.o * src_depth * 4
: weights.shape.i * dst_depth * 4;
const int float_count = aligned_ch_count * kernel_x * kernel_y;
BufferDescriptor desc;
desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
......@@ -96,15 +130,25 @@ void UploadWeightsForConvConstants(const tflite::gpu::Tensor<OHWI, T>& weights,
if (f32_weights) {
float4* ptr = reinterpret_cast<float4*>(desc.data.data());
absl::MakeSpan(ptr, float_count / 4));
if (use_dot_conv) {
absl::MakeSpan(ptr, float_count / 4));
} else {
absl::MakeSpan(ptr, float_count / 4));
} else {
half4* ptr = reinterpret_cast<half4*>(desc.data.data());
absl::MakeSpan(ptr, float_count / 4));
if (use_dot_conv) {
absl::MakeSpan(ptr, float_count / 4));
} else {
absl::MakeSpan(ptr, float_count / 4));
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
想要评论请 注册