提交 a77e4aec 编写于 作者: E Eugene Brevdo 提交者: TensorFlower Gardener

AddN for variants adds in a tree structure (pairwise summation)

Improves numerical precision (if applicable) using pairwise summation:
https://en.wikipedia.org/wiki/Pairwise_summation

Thanks to Rasmus Larsen for the succinct binary tree aggregation pseudocode.

Also adds AsString for Variant types: this emits the Variant as a string via
its DebugString().

PiperOrigin-RevId: 340246073
Change-Id: I009281f46cbea30d6e33ecf79a1723d62e96cc6d
上级 faac5b2f
......@@ -509,7 +509,7 @@ array([b'3.14', b'2.72'], dtype=object)
}];
let arguments = (ins
TensorOf<[TF_Bool, TF_Complex128, TF_Complex64, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8]>:$input,
TensorOf<[TF_Bool, TF_Complex128, TF_Complex64, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Variant]>:$input,
DefaultValuedAttr<I64Attr, "-1">:$precision,
DefaultValuedAttr<BoolAttr, "false">:$scientific,
......@@ -15226,4 +15226,4 @@ execution the transfer corresponds to.}]>:$dynamic_key,
let results = (outs);
TF_DerivedOperandTypeListAttr Tinputs = TF_DerivedOperandTypeListAttr<0>;
}
\ No newline at end of file
}
......@@ -370,24 +370,77 @@ class AddNOp<Device, Variant, OpKernelT, OpKernelConstructionT,
i, " has shape: ", ctx->input(i).shape().DebugString(), "."));
}
// Step 2: attempt to add using
// Step 2: Sum input variants in a tree-like structure using
// BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...)
// For the output create a default-constructed variant object.
// TODO(ebrevdo): Perform summation in a tree-structure.
Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
Variant* v_out = &(out.scalar<Variant>()());
OP_REQUIRES_OK(ctx, BinaryOpVariants<Device>(
ctx, ADD_VARIANT_BINARY_OP,
ctx->input(0).template scalar<Variant>()(),
ctx->input(1).template scalar<Variant>()(), v_out));
for (int i = 2; i < num; ++i) {
const Variant tmp = std::move(*v_out);
const Variant& inp = ctx->input(i).template scalar<Variant>()();
OP_REQUIRES_OK(ctx, BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP,
inp, tmp, v_out));
//
// Pairwise summation provides better numerical precision by
// reducing round-off error:
//
// https://en.wikipedia.org/wiki/Pairwise_summation
//
// These two vectors are used to store and mark intermediate sums.
gtl::InlinedVector<bool, 4> temp_filled(num, false);
gtl::InlinedVector<Variant, 4> temp(num);
// Tree-based summation.
int skip = 1;
int n = num;
while (skip < n) {
int i = skip;
while (i < n) {
// TODO(ebrevdo, rmlarsen): Parallelize the pairwise summations in the
// inner loop if the variants are "large".
// x[i - skip] += x[i]
OP_REQUIRES_OK(ctx,
AddVariantTo(ctx, i - skip, i, &temp, &temp_filled));
// We won't use this index again, recover its memory.
temp[i].clear();
i += 2 * skip;
}
if (i == n) {
// x[0] += x[i - skip]
OP_REQUIRES_OK(ctx,
AddVariantTo(ctx, 0, i - skip, &temp, &temp_filled));
// We won't use this index again, recover its memory.
temp[i - skip].clear();
n -= skip;
}
skip *= 2;
}
Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
out.scalar<Variant>()() = std::move(temp[0]);
ctx->set_output(0, out);
}
private:
// AddVariantTo efficiently performs:
// temp[lhs_ix] <- array(lhs_ix) + array(rhs_ix)
// where array(ix) := (temp_filled[ix]
// ? temp[ix]
// : ctx->input(ix).scalar<Variant>()())
// This reduces (possibly expensive) copying of Variants from
// the inputs into temp at the lowest levels of the summation tree.
static inline Status AddVariantTo(OpKernelContextT* ctx, const int lhs_ix,
const int rhs_ix,
gtl::InlinedVector<Variant, 4>* temp,
gtl::InlinedVector<bool, 4>* temp_filled) {
Variant tmp;
if (temp_filled->at(lhs_ix)) tmp = std::move(temp->at(lhs_ix));
const Variant& a = temp_filled->at(lhs_ix)
? tmp
: ctx->input(lhs_ix).template scalar<Variant>()();
const Variant& b = temp_filled->at(rhs_ix)
? temp->at(rhs_ix)
: ctx->input(rhs_ix).template scalar<Variant>()();
Variant* c = &temp->at(lhs_ix);
TF_RETURN_IF_ERROR(
BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP, a, b, c));
temp_filled->at(lhs_ix) = true;
return Status::OK();
}
};
} // namespace tensorflow
......
......@@ -20,6 +20,9 @@ limitations under the License.
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
......@@ -112,6 +115,8 @@ class AsStringOp : public OpKernel {
break;
case DT_BOOL:
break;
case DT_VARIANT:
break;
default:
bool type_not_supported = true;
OP_REQUIRES(ctx, !type_not_supported,
......@@ -156,6 +161,12 @@ class AsStringOp : public OpKernel {
output_flat(i) = (input_flat(i)) ? "true" : "false";
}
} break;
case (DT_VARIANT): {
const auto& input_flat = input_tensor->flat<Variant>();
for (int i = 0; i < input_flat.size(); ++i) {
output_flat(i) = input_flat(i).DebugString();
}
} break;
case (DT_COMPLEX64): {
const auto& input_flat = input_tensor->flat<complex64>();
for (int i = 0; i < input_flat.size(); ++i) {
......
......@@ -18,6 +18,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
......@@ -148,6 +151,25 @@ TEST_F(AsStringGraphTest, Bool) {
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, Variant) {
TF_ASSERT_OK(Init(DT_VARIANT));
AddInput(DT_VARIANT, TensorShape({4}));
auto inputs = mutable_input(0)->flat<Variant>();
inputs(0) = 2;
inputs(1) = 3;
inputs(2) = true;
inputs(3) = Tensor("hi");
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(
&expected, {"Variant<type: int value: 2>", "Variant<type: int value: 3>",
"Variant<type: bool value: 1>",
("Variant<type: tensorflow::Tensor value: Tensor<type: string"
" shape: [] values: hi>>")});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, String) {
Status s = Init(DT_STRING);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
......
......@@ -116,7 +116,7 @@ REGISTER_OP("AsString")
.Output("output: string")
.Attr(
"T: {int8, int16, int32, int64, complex64, complex128, float, double, "
"bool}")
"bool, variant}")
.Attr("precision: int = -1")
.Attr("scientific: bool = false")
.Attr("shortest: bool = false")
......
......@@ -1606,6 +1606,7 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:string_ops",
"//third_party/py/numpy",
],
)
......
......@@ -26,8 +26,8 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
......@@ -100,24 +100,28 @@ class AddNTest(test.TestCase):
# TODO(ebrevdo): Re-enable use_gpu=True once non-DMA Variant
# copying between CPU and GPU is supported.
with self.session(use_gpu=False):
variant_const_3 = create_constant_variant(3)
variant_const_4 = create_constant_variant(4)
variant_const_5 = create_constant_variant(5)
# 3 + 3 + 5 + 4 = 15.
result = math_ops.add_n((variant_const_3, variant_const_3,
variant_const_5, variant_const_4))
num_tests = 127
values = list(range(100))
variant_consts = [create_constant_variant(x) for x in values]
sum_count_indices = np.random.randint(1, 29, size=num_tests)
sum_indices = [
np.random.randint(100, size=count) for count in sum_count_indices]
expected_sums = [np.sum(x) for x in sum_indices]
variant_sums = [math_ops.add_n([variant_consts[i] for i in x])
for x in sum_indices]
# Smoke test -- ensure this executes without trouble.
# We use as_string() to get the Variant DebugString for the
# variant_sums; we know its value so we can check via string equality
# here.
#
# Right now, non-numpy-compatible objects cannot be returned from a
# session.run call; similarly, objects that can't be converted to
# native numpy types cannot be passed to ops.convert_to_tensor.
# For now, run the test and examine the output to see that the result is
# equal to 15.
result_op = logging_ops.Print(
result, [variant_const_3, variant_const_4, variant_const_5, result],
message=("Variants stored an int: c(3), c(4), c(5), "
"add_n(c(3), c(3), c(5), c(4)): ")).op
result_op.run()
variant_sums_string = string_ops.as_string(variant_sums)
self.assertAllEqual(
variant_sums_string,
["Variant<type: int value: {}>".format(s).encode("utf-8")
for s in expected_sums])
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册