提交 b75bd29c 编写于 作者: M minqiyang

Remove debug info

上级 7a43e517
......@@ -26,46 +26,17 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
scope_(scope),
place_(place) {}
struct RecordTime {
RecordTime(const std::string &name, const std::string &type)
: name_(name), type_(type), start_(std::chrono::system_clock::now()) {}
~RecordTime() {
if (type_ == "elementsize_add") {
end_ = std::chrono::system_clock::now();
std::chrono::duration<double> diff = end_ - start_;
VLOG(1) << name_ << " " << type_ << " time record: " << diff.count();
}
}
std::string name_;
std::string type_;
std::chrono::system_clock::time_point start_;
std::chrono::system_clock::time_point end_;
};
void ComputationOpHandle::RunImpl() {
{
RecordTime rt("ComputationOpHandle::RunImpl", "Wait");
WaitInputVarGenerated(place_);
}
Scope *scope = nullptr;
{
RecordTime rt("ComputationOpHandle::RunImpl", "PrepareScope");
scope = scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
}
{
RecordTime rt("ComputationOpHandle::RunImpl", "ReallyRun " + op_->Type());
WaitInputVarGenerated(place_);
auto run_func = [this, scope]() { op_->Run(*scope, place_); };
auto run_func = [this]() {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
};
if (is_lock_and_record_event_free_) {
run_func();
} else {
this->RunAndRecordEvent(run_func);
}
if (is_lock_and_record_event_free_) {
run_func();
} else {
this->RunAndRecordEvent(run_func);
}
}
......
......@@ -41,7 +41,7 @@ OpHandleBase::~OpHandleBase() {
void OpHandleBase::Run(bool use_cuda) {
#ifdef PADDLE_WITH_CUDA
if (events_.empty() && use_cuda && !dev_ctxes_.empty()) {
if (events_.empty() && use_cuda) {
for (auto &p : dev_ctxes_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
PADDLE_ENFORCE(cudaSetDevice(dev_id));
......
......@@ -20,6 +20,10 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
DEFINE_bool(enforce_when_check_program, true,
"Checking whether the program is correct or not. We will log "
"errors rather than throwing exceptions if this flag turned off");
namespace paddle {
namespace framework {
namespace ir {
......@@ -28,55 +32,85 @@ namespace {
void CheckProgram(const ProgramDesc &program) {
#define _INT(role) static_cast<int>(role)
// std::map<int, bool> visit;
// for (OpDesc *op : program.Block(0).AllOps()) {
// // For backward compatibility, some program doesn't have role added.
// if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
// int role_id =
// boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
// visit[role_id] = true;
// switch (role_id) {
// case _INT(OpRole::kForward):
// if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
// LOG(ERROR)
// << "Cannot add backward operator before forward operator %s."
// << op->Type();
// }
// break;
// case _INT(OpRole::kBackward):
// case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
// PADDLE_ENFORCE(
// visit.find(_INT(OpRole::kOptimize)) == visit.end(),
// "Cannot add backward operator %s after optimize operator.",
// op->Type());
// break;
// case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
// PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
// _INT(OpRole::kLoss)) == visit.end(),
// "Cannot add backward|loss operator before "
// "forward|loss operator %s.",
// op->Type());
// PADDLE_ENFORCE(
// visit.find(_INT(OpRole::kOptimize)) == visit.end(),
// "Cannot add forward|loss operator %s after optimize operator.",
// op->Type());
// break;
// case _INT(OpRole::kOptimize):
// case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
// PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
// "Optimize operators %s must follow backward operator.",
// op->Type());
// break;
// case _INT(OpRole::kLRSched):
// case _INT(OpRole::kDist):
// case _INT(OpRole::kRPC):
// case _INT(OpRole::kNotSpecified):
// break;
// default:
// LOG(FATAL) << "Unknown operator role. Don't add new role because "
// "you don't know what you are doing.";
// }
// }
std::map<int, bool> visit;
for (OpDesc *op : program.Block(0).AllOps()) {
// For backward compatibility, some program doesn't have role added.
if (!op->HasAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) continue;
int role_id =
boost::get<int>(op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName()));
visit[role_id] = true;
switch (role_id) {
case _INT(OpRole::kForward):
if (visit.find(_INT(OpRole::kBackward)) != visit.end()) {
LOG(ERROR)
<< "Cannot add backward operator before forward operator %s."
<< op->Type();
}
break;
case _INT(OpRole::kBackward):
case _INT(OpRole::kBackward) | _INT(OpRole::kLoss):
if (!FLAGS_enforce_when_check_program) {
PADDLE_ENFORCE(
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
"Cannot add backward operator %s after optimize operator.",
op->Type());
} else {
if (visit.find(_INT(OpRole::kOptimize)) != visit.end()) {
LOG(ERROR)
<< "Cannot add backward operator %s after optimize operator.",
<< op->Type();
}
}
break;
case _INT(OpRole::kForward) | _INT(OpRole::kLoss):
if (!FLAGS_enforce_when_check_program) {
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward) |
_INT(OpRole::kLoss)) == visit.end(),
"Cannot add backward|loss operator before "
"forward|loss operator %s.",
op->Type());
PADDLE_ENFORCE(
visit.find(_INT(OpRole::kOptimize)) == visit.end(),
"Cannot add forward|loss operator %s after optimize operator.",
op->Type());
} else {
if (visit.find(_INT(OpRole::kBackward) | _INT(OpRole::kLoss)) !=
visit.end()) {
LOG(ERROR) << "Cannot add backward|loss operator before "
<< "forward|loss operator %s." << op->Type();
}
if (visit.find(_INT(OpRole::kOptimize)) != visit.end()) {
LOG(ERROR) << "Cannot add forward|loss operator %s after optimize "
"operator.",
<< op->Type();
}
}
break;
case _INT(OpRole::kOptimize):
case _INT(OpRole::kOptimize) | _INT(OpRole::kLRSched):
if (!FLAGS_enforce_when_check_program) {
PADDLE_ENFORCE(visit.find(_INT(OpRole::kBackward)) != visit.end(),
"Optimize operators %s must follow backward operator.",
op->Type());
} else {
if (visit.find(_INT(OpRole::kBackward)) == visit.end()) {
LOG(ERROR)
<< "Optimize operators %s must follow backward operator.",
<< op->Type();
}
}
break;
case _INT(OpRole::kLRSched):
case _INT(OpRole::kDist):
case _INT(OpRole::kRPC):
case _INT(OpRole::kNotSpecified):
break;
default:
LOG(FATAL) << "Unknown operator role. Don't add new role because "
"you don't know what you are doing.";
}
}
#undef _INT
}
......
......@@ -701,125 +701,85 @@ void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
this->InferShape(&infer_shape_ctx);
}
struct RecordTime {
RecordTime(const std::string& name, const std::string& type)
: name_(name), type_(type), start_(std::chrono::system_clock::now()) {}
void inline stop() {
end_ = std::chrono::system_clock::now();
std::chrono::duration<double> diff = end_ - start_;
VLOG(1) << name_ << " " << type_ << " time record: " << diff.count();
}
~RecordTime() {
if (type_ == "elementwise_add") {
stop();
}
// stop();
}
std::string name_;
std::string type_;
std::chrono::system_clock::time_point start_;
std::chrono::system_clock::time_point end_;
};
void OperatorWithKernel::RunImpl(const Scope& scope,
const platform::Place& place) const {
RecordTime rt("OperatorWithKernel::All", type_);
{
RecordTime rt("OperatorWithKernel::InferShape", type_);
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
}
{
RecordTime* rt_1 = new RecordTime("OperatorWithKernel::Compute1", type_);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW(
"There are no kernels which are registered in the %s operator.",
type_);
}
// check if op[type] has kernel registered.
auto& all_op_kernels = AllOpKernels();
auto kernels_iter = all_op_kernels.find(type_);
if (kernels_iter == all_op_kernels.end()) {
PADDLE_THROW(
"There are no kernels which are registered in the %s operator.", type_);
}
OpKernelMap& kernels = kernels_iter->second;
OpKernelMap& kernels = kernels_iter->second;
// TODO(dzhwinter) : kernel fallback mechanism will be added when all the
// transform functions are ready.
// TODO(dzhwinter) : kernel fallback mechanism will be added when all the
// transform functions are ready.
// for (auto& candidate : kKernelPriority) {
// Do selection
// }
// for (auto& candidate : kKernelPriority) {
// Do selection
// }
auto expected_kernel_key =
this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto expected_kernel_key =
this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx));
VLOG(3) << "expected_kernel_key:" << expected_kernel_key;
auto kernel_iter = kernels.find(expected_kernel_key);
auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_MKLDNN
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
if (kernel_iter == kernels.end() &&
expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
expected_kernel_key.library_type_ = LibraryType::kPlain;
expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
kernel_iter = kernels.find(expected_kernel_key);
}
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
if (kernel_iter == kernels.end() &&
expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
expected_kernel_key.library_type_ = LibraryType::kPlain;
expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
kernel_iter = kernels.find(expected_kernel_key);
}
#endif
if (kernel_iter == kernels.end()) {
PADDLE_THROW("op %s does not have kernel for %s", type_,
KernelTypeToString(expected_kernel_key));
}
if (kernel_iter == kernels.end()) {
PADDLE_THROW("op %s does not have kernel for %s", type_,
KernelTypeToString(expected_kernel_key));
}
// do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars;
Scope* transfer_scope = nullptr;
// auto* transfer_scope =
// TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars);
// do data transformScope &transfer_scope;
std::vector<std::string> transfered_inplace_vars;
auto* transfer_scope =
TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars);
// exec scope is the scope that kernel actually executed on.
const Scope& exec_scope = scope;
// const Scope& exec_scope =
// (transfer_scope == nullptr ? scope : *transfer_scope);
// exec scope is the scope that kernel actually executed on.
const Scope& exec_scope =
(transfer_scope == nullptr ? scope : *transfer_scope);
if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
dev_ctx = pool.Get(expected_kernel_key.place_);
}
delete rt_1;
if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
dev_ctx = pool.Get(expected_kernel_key.place_);
}
RecordTime* rt_2 = new RecordTime("OperatorWithKernel::Compute2", type_);
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
delete rt_2;
kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));
RecordTime* rt_3 = new RecordTime("OperatorWithKernel::Compute3", type_);
if (!transfered_inplace_vars.empty()) {
// there is inplace variable has been transfered.
TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope);
}
if (!transfered_inplace_vars.empty()) {
// there is inplace variable has been transfered.
TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope);
}
/*For profiling/benchmark only*/
if (FLAGS_benchmark) {
dev_ctx->Wait();
}
/*For profiling/benchmark only*/
if (FLAGS_benchmark) {
dev_ctx->Wait();
}
if (FLAGS_check_nan_inf) {
for (auto& vname : OutputVars(true)) {
auto* var = exec_scope.FindVar(vname);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
} else if (var->IsType<framework::SelectedRows>()) {
CheckTensorNANOrInf(vname,
var->Get<framework::SelectedRows>().value());
}
if (FLAGS_check_nan_inf) {
for (auto& vname : OutputVars(true)) {
auto* var = exec_scope.FindVar(vname);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
} else if (var->IsType<framework::SelectedRows>()) {
CheckTensorNANOrInf(vname, var->Get<framework::SelectedRows>().value());
}
}
delete rt_3;
}
}
void OperatorWithKernel::TransferInplaceVarsBack(
......
......@@ -33,37 +33,34 @@ class ElementwiseOp : public framework::OperatorWithKernel {
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContext *ctx) const override {
if (!ctx->IsRuntime()) {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of elementwise op should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of elementwise op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of elementwise op should not be null.");
PADDLE_ENFORCE(ctx->GetInputsVarType("Y").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the "
"received is %s [%s]",
ctx->GetInputsVarType("Y").front(),
ctx->Inputs("Y").front());
if (ctx->GetInputsVarType("X").front() ==
framework::proto::VarType::LOD_TENSOR) {
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.");
} else if (ctx->GetInputsVarType("X").front() ==
framework::proto::VarType::SELECTED_ROWS) {
PADDLE_ENFORCE((ctx->GetInputDim("Y").size() == 1u) &&
(ctx->GetInputDim("Y")[0] == 1),
"For elementwise_op, if X is Sparse, "
"Y must be scalar.");
} else {
PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
ctx->GetInputsVarType("X").front());
}
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of elementwise op should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
"Input(Y) of elementwise op should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of elementwise op should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Y").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s [%s]",
ctx->GetInputsVarType("Y").front(), ctx->Inputs("Y").front());
if (ctx->GetInputsVarType("X").front() ==
framework::proto::VarType::LOD_TENSOR) {
auto x_dim = ctx->GetInputDim("X");
auto y_dim = ctx->GetInputDim("Y");
PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
"Rank of first input must >= rank of second input.");
} else if (ctx->GetInputsVarType("X").front() ==
framework::proto::VarType::SELECTED_ROWS) {
PADDLE_ENFORCE((ctx->GetInputDim("Y").size() == 1u) &&
(ctx->GetInputDim("Y")[0] == 1),
"For elementwise_op, if X is Sparse, "
"Y must be scalar.");
} else {
PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
ctx->GetInputsVarType("X").front());
}
ctx->ShareDim("X", /*->*/ "Out");
......@@ -128,7 +125,7 @@ The equation is:
$$%s$$
- $X$: a tensor of any dimension.
- $X$: a tensor of any dimension.
- $Y$: a tensor whose dimensions must be less than or equal to the dimensions of $X$.
There are two cases for this operator:
......@@ -138,10 +135,10 @@ There are two cases for this operator:
For case 2:
1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
for broadcasting $Y$ onto $X$.
1. Broadcast $Y$ to match the shape of $X$, where $axis$ is the start dimension index
for broadcasting $Y$ onto $X$.
2. If $axis$ is -1 (default), $axis = rank(X) - rank(Y)$.
3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of
3. The trailing dimensions of size 1 for $Y$ will be ignored for the consideration of
subsequence, such as shape(Y) = (2, 1) => (2).
For example:
......@@ -155,7 +152,7 @@ For example:
shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1), with axis=0
The inputs $X$ and $Y$ can carry the different LoD information.
The inputs $X$ and $Y$ can carry the different LoD information.
But the output only shares the LoD information with the input $X$.
)DOC",
......
......@@ -23,57 +23,56 @@ class AdamOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
// PADDLE_ENFORCE(ctx->HasInput("Param"),
// "Input(Param) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasInput("Grad"),
// "Input(Grad) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasInput("Moment1"),
// "Input(Moment1) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasInput("Moment2"),
// "Input(Moment2) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
// "Input(LearningRate) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
// "Input(Beta1Pow) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"),
// "Input(Beta2Pow) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
// "Output(ParamOut) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"),
// "Output(Moment1Out) of AdamOp should not be null.");
// PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"),
// "Output(Moment2Out) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
"Input(Grad) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment1"),
"Input(Moment1) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment2"),
"Input(Moment2) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
"Input(Beta1Pow) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Beta2Pow"),
"Input(Beta2Pow) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Moment1Out"),
"Output(Moment1Out) of AdamOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Moment2Out"),
"Output(Moment2Out) of AdamOp should not be null.");
auto lr_dims = ctx->GetInputDim("LearningRate");
// PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
// "Learning rate should have 1 dimension");
PADDLE_ENFORCE_EQ(framework::product(lr_dims), 1,
"Learning rate should have 1 dimension");
auto beta1_pow_dims = ctx->GetInputDim("Beta1Pow");
// PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
// "Beta1 power accumulator should have 1 dimension");
PADDLE_ENFORCE_EQ(framework::product(beta1_pow_dims), 1,
"Beta1 power accumulator should have 1 dimension");
auto beta2_pow_dims = ctx->GetInputDim("Beta2Pow");
// PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1,
// "Beta2 power accumulator should have 1 dimension");
PADDLE_ENFORCE_EQ(framework::product(beta2_pow_dims), 1,
"Beta2 power accumulator should have 1 dimension");
auto param_dims = ctx->GetInputDim("Param");
// if (ctx->GetInputsVarType("Grad")[0] ==
// framework::proto::VarType::LOD_TENSOR) {
// PADDLE_ENFORCE_EQ(
// param_dims, ctx->GetInputDim("Grad"),
// "Param and Grad input of AdamOp should have same dimension");
// }
// PADDLE_ENFORCE_EQ(
// param_dims, ctx->GetInputDim("Moment1"),
// "Param and Moment1 input of AdamOp should have same dimension");
// PADDLE_ENFORCE_EQ(
// param_dims, ctx->GetInputDim("Moment2"),
// "Param and Moment2 input of AdamOp should have same dimension");
if (ctx->GetInputsVarType("Grad")[0] ==
framework::proto::VarType::LOD_TENSOR) {
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Grad"),
"Param and Grad input of AdamOp should have same dimension");
}
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment1"),
"Param and Moment1 input of AdamOp should have same dimension");
PADDLE_ENFORCE_EQ(
param_dims, ctx->GetInputDim("Moment2"),
"Param and Moment2 input of AdamOp should have same dimension");
ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("Moment1Out", param_dims);
ctx->SetOutputDim("Moment2Out", param_dims);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册