Commit e8c62ab3 authored by Allen Lavoie, committed by TensorFlower Gardener

A very basic start on some op handler infrastructure

This does not yet include a handler's tensor representation (and so no copy-on etc.), and almost all of the hooks are still missing.

My medium-term goal is to get the parallel device working with function replay so TPU collectives work inside functions. That will also give us a replication primitive for use with the eager/graph-agnostic C API, and I plan to call it from the existing custom device to start.

PiperOrigin-RevId: 340253840
Change-Id: Ic9a5acca7bf42ceb9cb54aca635a9861daca3b38
Parent: aeafc6a5
tensorflow/c/eager/abstract_context.h
@@ -32,7 +32,7 @@ namespace tensorflow {
// environment, a traced representation etc.
class AbstractContext {
protected:
-  enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape };
+  enum AbstractContextKind { kGraph, kMlir, kEager, kTfrt, kTape, kOpHandler };
explicit AbstractContext(AbstractContextKind kind) : kind_(kind) {}
virtual ~AbstractContext() {}
......
tensorflow/c/eager/abstract_operation.h
@@ -30,7 +30,14 @@ namespace tensorflow {
// tracing or immediate execution mode.
class AbstractOperation {
protected:
-  enum AbstractOperationKind { kGraph, kMlir, kEager, kTfrt, kTape };
+  enum AbstractOperationKind {
+    kGraph,
+    kMlir,
+    kEager,
+    kTfrt,
+    kTape,
+    kOpHandler
+  };
explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
virtual ~AbstractOperation() {}
......
@@ -25,7 +25,7 @@ TapeOperation::TapeOperation(AbstractOperation* parent_op, Tape* tape,
parent_op_(parent_op),
tape_(tape),
registry_(registry) {
-  // TODO(srbs): Make AbstractOperation RefCounted.
+  // TODO(b/172003047): Consider making AbstractOperation RefCounted.
// parent_op_->Ref();
}
void TapeOperation::Release() {
@@ -33,7 +33,7 @@ void TapeOperation::Release() {
delete this;
}
TapeOperation::~TapeOperation() {
-  // TODO(srbs): Make AbstractOperation RefCounted.
+  // TODO(b/172003047): Consider making AbstractOperation RefCounted.
// parent_op->Unref();
}
Status TapeOperation::Reset(const char* op, const char* raw_device_name) {
......
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
package(
licenses = ["notice"], # Apache 2.0
)
tf_cc_test(
name = "internal_test",
srcs = ["internal_test.cc"],
deps = [
":internal",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/platform:errors",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "internal",
srcs = ["internal.cc"],
hdrs = ["internal.h"],
deps = [
":wrapper_operation",
"//tensorflow/c:conversion_macros",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core/platform:refcount",
"//tensorflow/core/platform:types",
],
)
cc_library(
name = "wrapper_operation",
srcs = ["wrapper_operation.cc"],
hdrs = ["wrapper_operation.h"],
deps = ["//tensorflow/c/eager:abstract_operation"],
)
tensorflow/c/experimental/op_handler/internal.cc (new file)

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_
#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_
#include "tensorflow/c/experimental/op_handler/internal.h"
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
OpHandlerContext::OpHandlerContext(AbstractContext* parent_ctx)
: AbstractContext(kOpHandler), parent_ctx_(parent_ctx) {}
OpHandlerContext::~OpHandlerContext() {}
void OpHandlerContext::Release() { delete this; }
Status OpHandlerContext::RegisterFunction(AbstractFunction* function) {
return parent_ctx_->RegisterFunction(function);
}
Status OpHandlerContext::RemoveFunction(const string& function) {
return parent_ctx_->RemoveFunction(function);
}
void OpHandlerContext::set_default_handler(OpHandler* handler) {
handler->Ref();
default_handler_.reset(handler);
}
OpHandlerOperation* OpHandlerContext::CreateOperation() {
OpHandlerOperation* result =
new OpHandlerOperation(parent_ctx_->CreateOperation());
if (default_handler_ != nullptr) {
result->set_handler(default_handler_.get());
}
return result;
}
OpHandlerOperation::OpHandlerOperation(AbstractOperation* parent_op)
: WrapperOperation(parent_op, kOpHandler) {}
OpHandler* OpHandlerOperation::get_handler() { return handler_.get(); }
void OpHandlerOperation::set_handler(OpHandler* handler) {
if (handler != nullptr) {
handler->Ref();
}
handler_.reset(handler);
}
Status OpHandlerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) {
if (handler_ == nullptr) {
return WrapperOperation::Execute(retvals, num_retvals);
} else {
return handler_->Execute(this, retvals, num_retvals);
}
}
} // namespace tensorflow
#endif  // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_CC_
tensorflow/c/experimental/op_handler/internal.h (new file)

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
#include "tensorflow/c/conversion_macros.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class OpHandlerOperation;
// Op handlers are a convenient way to intercept and transform computation.
//
// The implementation is currently experimental and incomplete, but aims
// eventually to support tracing and replay of function bodies, gradients
// through copy operations, and a variety of hooks for things like debug
// strings. A public C API for op handlers is planned.
class OpHandler : public core::RefCounted {
public:
// Called on operation->Execute when operation->get_handler() == this.
//
// Allows the handler to customize or inspect `operation`'s execution.
virtual Status Execute(OpHandlerOperation* operation,
absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) = 0;
// Creates a new handler by merging this handler with `next_handler`.
//
// The new handler is expected to transform operations first with this handler
// and then execute the resulting operations on `next_handler` (by calling
// `OpHandlerOperation::set_handler` and passing `next_handler`). If this is
// not possible then the merge operation should fail.
virtual Status Merge(OpHandler* next_handler,
core::RefCountPtr<OpHandler>& merged_handler) = 0;
};
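// Example (a hypothetical sketch, not part of this change): a minimal handler
// that records the name of each op it sees and then lets the wrapped operation
// execute normally. `LoggingOpHandler` and `last_op_name_` are illustrative
// names; `errors::Unimplemented` assumes tensorflow/core/platform/errors.h.
//
//   class LoggingOpHandler : public OpHandler {
//    public:
//     Status Execute(OpHandlerOperation* operation,
//                    absl::Span<AbstractTensorHandle*> retvals,
//                    int* num_retvals) override {
//       last_op_name_ = operation->Name();
//       // Clear the handler so OpHandlerOperation::Execute falls through to
//       // WrapperOperation::Execute on the wrapped operation.
//       operation->set_handler(nullptr);
//       return operation->Execute(retvals, num_retvals);
//     }
//     Status Merge(OpHandler* next_handler,
//                  core::RefCountPtr<OpHandler>& merged_handler) override {
//       return errors::Unimplemented("LoggingOpHandler does not merge.");
//     }
//
//    private:
//     std::string last_op_name_;
//   };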
// Keeps some handler-specific metadata, but otherwise wraps a single
// AbstractOperation in the underlying context. The operation is created, its
// attributes set, etc., and at execution time it is presented to its handler,
// which may choose to execute it or simply inspect it and do something else.
//
// This is somewhat different from the Context approach, where the operation's
// construction is streamed through each layered Context. The streaming approach
// would require a much larger op handler public API (one function pointer per
// attribute type), and before an op is finalized there is some ambiguity about
// whether it should be presented as-is to handlers (regular operations) or
// replayed (function calls and control flow operations).
class OpHandlerOperation : public WrapperOperation {
public:
explicit OpHandlerOperation(AbstractOperation*);
OpHandler* get_handler();
void set_handler(OpHandler* handler);
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) override;
protected:
core::RefCountPtr<OpHandler> handler_;
};
// A context which allows a default handler to be set for new operations. It
// otherwise defers to the context it wraps.
//
// TODO(allenl): A stack of contexts and a stack of handlers look pretty similar
// in some ways. Having each handler be its own context seems almost doable,
// with things like copy operations and function/control flow replay being
// somewhat tricky (since they should be generated at the top of the handler
// stack and "caught" at the bottom). After handlers have evolved for a bit we
// should re-evaluate whether the handler+context concepts can be merged.
class OpHandlerContext : public AbstractContext {
public:
explicit OpHandlerContext(AbstractContext*);
void Release() override;
OpHandlerOperation* CreateOperation() override;
Status RegisterFunction(AbstractFunction*) override;
Status RemoveFunction(const string&) override;
// For LLVM style RTTI.
static bool classof(const AbstractContext* ptr) {
return ptr->getKind() == kOpHandler;
}
~OpHandlerContext() override;
void set_default_handler(OpHandler* handler);
private:
AbstractContext* parent_ctx_; // Not owned.
core::RefCountPtr<OpHandler> default_handler_;
};
class ReleaseOpHandlerOperation {
public:
void operator()(OpHandlerOperation* operation) { operation->Release(); }
};
typedef std::unique_ptr<OpHandlerOperation, ReleaseOpHandlerOperation>
OpHandlerOperationPtr;
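// Example usage (a hypothetical sketch, not part of this change): wrap an
// existing context, install a default handler, and run an op through it.
// `parent_ctx` and `LoggingOpHandler` are illustrative; TF_RETURN_IF_ERROR
// assumes tensorflow/core/platform/errors.h.
//
//   AbstractContext* parent_ctx = ...;  // e.g. an unwrapped eager context.
//   OpHandlerContext ctx(parent_ctx);
//   core::RefCountPtr<OpHandler> handler(new LoggingOpHandler());
//   ctx.set_default_handler(handler.get());
//
//   OpHandlerOperationPtr op(ctx.CreateOperation());
//   TF_RETURN_IF_ERROR(op->Reset("NoOp", /*raw_device_name=*/""));
//   std::vector<AbstractTensorHandle*> retvals;
//   int num_retvals = 0;
//   TF_RETURN_IF_ERROR(op->Execute(absl::MakeSpan(retvals), &num_retvals));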
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_INTERNAL_H_
tensorflow/c/experimental/op_handler/internal_test.cc (new file)

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/op_handler/internal.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
class TestOpHandler : public OpHandler {
public:
TestOpHandler() : last_operation_(new std::string("")) {}
Status Execute(OpHandlerOperation* operation,
absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) override {
CHECK(operation->get_handler() == this);
*last_operation_ = operation->Name();
operation->set_handler(next_handler_.get());
return operation->Execute(retvals, num_retvals);
}
Status Merge(OpHandler* next_handler,
core::RefCountPtr<OpHandler>& merged_handler) override {
merged_handler.reset(new TestOpHandler(next_handler, last_operation_));
return Status::OK();
}
core::RefCountPtr<OpHandler> next_handler_ = nullptr;
// Shared between merged handlers of this type.
std::shared_ptr<std::string> last_operation_;
private:
TestOpHandler(OpHandler* next_handler,
std::shared_ptr<std::string> last_operation)
: next_handler_(next_handler), last_operation_(last_operation) {
next_handler->Ref();
}
};
TEST(INTERNAL_TEST, UseOpHandler) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_ExecutionContext, decltype(&TF_DeleteExecutionContext)>
c_ctx(TF_NewEagerExecutionContext(opts.get(), status.get()),
TF_DeleteExecutionContext);
OpHandlerContext ctx(unwrap(c_ctx.get()));
core::RefCountPtr<TestOpHandler> outer_handler(new TestOpHandler());
core::RefCountPtr<TestOpHandler> inner_handler(new TestOpHandler());
ctx.set_default_handler(outer_handler.get());
OpHandlerOperationPtr op(ctx.CreateOperation());
Status s = op->Reset("NoOp", "");
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
std::vector<AbstractTensorHandle*> retvals;
int num_retvals = 0;
EXPECT_EQ("", *outer_handler->last_operation_);
s = op->Execute(absl::Span<AbstractTensorHandle*>(retvals), &num_retvals);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
EXPECT_EQ("NoOp", *outer_handler->last_operation_);
*outer_handler->last_operation_ = "";
EXPECT_EQ("", *inner_handler->last_operation_);
// This op executes on both handlers, changing the state of `inner_handler`
// since the handler has decided to preserve that state across merges.
core::RefCountPtr<OpHandler> merged;
s = inner_handler->Merge(outer_handler.get(), merged);
ctx.set_default_handler(merged.get());
op.reset(ctx.CreateOperation());
s = op->Reset("NoOp", "");
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
s = op->Execute(absl::Span<AbstractTensorHandle*>(retvals), &num_retvals);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
EXPECT_EQ("NoOp", *inner_handler->last_operation_);
EXPECT_EQ("NoOp", *outer_handler->last_operation_);
inner_handler.reset();
outer_handler.reset();
op.reset(ctx.CreateOperation());
s = op->Reset("NoOp", "");
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
s = op->Execute(absl::Span<AbstractTensorHandle*>(retvals), &num_retvals);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}
} // namespace tensorflow
tensorflow/c/experimental/op_handler/wrapper_operation.cc (new file)

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/op_handler/wrapper_operation.h"
namespace tensorflow {
WrapperOperation::WrapperOperation(AbstractOperation* parent_op,
AbstractOperationKind kind)
: AbstractOperation(kind), parent_op_(parent_op) {
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
// parent_op_->Ref();
}
void WrapperOperation::Release() {
parent_op_->Release();
// TODO(b/172003047): Consider making AbstractOperation RefCounted.
delete this;
}
Status WrapperOperation::Reset(const char* op, const char* raw_device_name) {
return parent_op_->Reset(op, raw_device_name);
}
const string& WrapperOperation::Name() const { return parent_op_->Name(); }
const string& WrapperOperation::DeviceName() const {
return parent_op_->DeviceName();
}
Status WrapperOperation::SetDeviceName(const char* name) {
return parent_op_->SetDeviceName(name);
}
Status WrapperOperation::AddInput(AbstractTensorHandle* input) {
return parent_op_->AddInput(input);
}
Status WrapperOperation::AddInputList(
absl::Span<AbstractTensorHandle* const> inputs) {
return parent_op_->AddInputList(inputs);
}
Status WrapperOperation::SetAttrString(const char* attr_name, const char* data,
size_t length) {
return parent_op_->SetAttrString(attr_name, data, length);
}
Status WrapperOperation::SetAttrInt(const char* attr_name, int64_t value) {
return parent_op_->SetAttrInt(attr_name, value);
}
Status WrapperOperation::SetAttrFloat(const char* attr_name, float value) {
return parent_op_->SetAttrFloat(attr_name, value);
}
Status WrapperOperation::SetAttrBool(const char* attr_name, bool value) {
return parent_op_->SetAttrBool(attr_name, value);
}
Status WrapperOperation::SetAttrType(const char* attr_name, DataType value) {
return parent_op_->SetAttrType(attr_name, value);
}
Status WrapperOperation::SetAttrShape(const char* attr_name,
const int64_t* dims, const int num_dims) {
return parent_op_->SetAttrShape(attr_name, dims, num_dims);
}
Status WrapperOperation::SetAttrFunction(const char* attr_name,
const AbstractOperation* value) {
return parent_op_->SetAttrFunction(attr_name, value);
}
Status WrapperOperation::SetAttrFunctionName(const char* attr_name,
const char* value, size_t length) {
return parent_op_->SetAttrFunctionName(attr_name, value, length);
}
Status WrapperOperation::SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) {
return parent_op_->SetAttrTensor(attr_name, tensor);
}
Status WrapperOperation::SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths,
int num_values) {
return parent_op_->SetAttrStringList(attr_name, values, lengths, num_values);
}
Status WrapperOperation::SetAttrFloatList(const char* attr_name,
const float* values, int num_values) {
return parent_op_->SetAttrFloatList(attr_name, values, num_values);
}
Status WrapperOperation::SetAttrIntList(const char* attr_name,
const int64_t* values, int num_values) {
return parent_op_->SetAttrIntList(attr_name, values, num_values);
}
Status WrapperOperation::SetAttrTypeList(const char* attr_name,
const DataType* values,
int num_values) {
return parent_op_->SetAttrTypeList(attr_name, values, num_values);
}
Status WrapperOperation::SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) {
return parent_op_->SetAttrBoolList(attr_name, values, num_values);
}
Status WrapperOperation::SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims, int num_values) {
return parent_op_->SetAttrShapeList(attr_name, dims, num_dims, num_values);
}
Status WrapperOperation::SetAttrFunctionList(
const char* attr_name, absl::Span<const AbstractOperation*> values) {
return parent_op_->SetAttrFunctionList(attr_name, values);
}
AbstractOperation* WrapperOperation::GetBackingOperation() {
return parent_op_;
}
Status WrapperOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) {
return parent_op_->Execute(retvals, num_retvals);
}
} // namespace tensorflow
tensorflow/c/experimental/op_handler/wrapper_operation.h (new file)

/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_
#define TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_
#include "tensorflow/c/eager/abstract_operation.h"
namespace tensorflow {
// Forwards all of the AbstractOperation's methods to its wrapped operation.
//
// Useful as a base class to default to forwarding while adding some
// customization.
class WrapperOperation : public AbstractOperation {
public:
explicit WrapperOperation(AbstractOperation*, AbstractOperationKind kind);
void Release() override;
Status Reset(const char* op, const char* raw_device_name) override;
const string& Name() const override;
const string& DeviceName() const override;
Status SetDeviceName(const char* name) override;
Status AddInput(AbstractTensorHandle* input) override;
Status AddInputList(absl::Span<AbstractTensorHandle* const> inputs) override;
Status Execute(absl::Span<AbstractTensorHandle*> retvals,
int* num_retvals) override;
Status SetAttrString(const char* attr_name, const char* data,
size_t length) override;
Status SetAttrInt(const char* attr_name, int64_t value) override;
Status SetAttrFloat(const char* attr_name, float value) override;
Status SetAttrBool(const char* attr_name, bool value) override;
Status SetAttrType(const char* attr_name, DataType value) override;
Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override;
Status SetAttrFunction(const char* attr_name,
const AbstractOperation* value) override;
Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) override;
Status SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) override;
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override;
Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) override;
Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) override;
Status SetAttrTypeList(const char* attr_name, const DataType* values,
int num_values) override;
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
int num_values) override;
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) override;
Status SetAttrFunctionList(
const char* attr_name,
absl::Span<const AbstractOperation*> values) override;
AbstractOperation* GetBackingOperation();
private:
AbstractOperation* parent_op_;
};
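// Example (a hypothetical sketch, not part of this change): a subclass that
// forwards everything but counts how many times Execute is called.
// `CountingOperation` is an illustrative name only.
//
//   class CountingOperation : public WrapperOperation {
//    public:
//     CountingOperation(AbstractOperation* parent_op,
//                       AbstractOperationKind kind, int* counter)
//         : WrapperOperation(parent_op, kind), counter_(counter) {}
//     Status Execute(absl::Span<AbstractTensorHandle*> retvals,
//                    int* num_retvals) override {
//       ++*counter_;
//       return WrapperOperation::Execute(retvals, num_retvals);
//     }
//
//    private:
//     int* counter_;  // Not owned.
//   };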
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_OP_HANDLER_WRAPPER_OPERATION_H_