提交 3c08f676 编写于 作者: W Wilber 提交者: cyj1986

add unsqueeze and range op (x2paddle) (#1988)

* add unsqueeze and range op. modify concat op test=develop

* modify exception in range_test_x86
上级 cf84d42b
...@@ -104,9 +104,11 @@ USE_LITE_KERNEL(slice, kARM, kFloat, kNCHW, def) ...@@ -104,9 +104,11 @@ USE_LITE_KERNEL(slice, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(affine_channel, kARM, kFloat, kNCHW, def) USE_LITE_KERNEL(affine_channel, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(anchor_generator, kARM, kFloat, kNCHW, def) USE_LITE_KERNEL(anchor_generator, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(generate_proposals, kARM, kFloat, kNCHW, def) USE_LITE_KERNEL(generate_proposals, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(squeeze, kARM, kFloat, kNCHW, def) // for x2paddle USE_LITE_KERNEL(squeeze, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(squeeze2, kARM, kFloat, kNCHW, def) // for x2paddle USE_LITE_KERNEL(squeeze2, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(expand, kARM, kFloat, kNCHW, def) // for x2paddle USE_LITE_KERNEL(unsqueeze, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(unsqueeze2, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(expand, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(roi_align, kARM, kFloat, kNCHW, def) USE_LITE_KERNEL(roi_align, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(box_clip, kARM, kFloat, kNCHW, def) USE_LITE_KERNEL(box_clip, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(reduce_mean, kARM, kFloat, kNCHW, def) USE_LITE_KERNEL(reduce_mean, kARM, kFloat, kNCHW, def)
......
...@@ -118,9 +118,11 @@ USE_LITE_OP(cast) ...@@ -118,9 +118,11 @@ USE_LITE_OP(cast)
USE_LITE_OP(affine_channel) USE_LITE_OP(affine_channel)
USE_LITE_OP(anchor_generator) USE_LITE_OP(anchor_generator)
USE_LITE_OP(generate_proposals) USE_LITE_OP(generate_proposals)
USE_LITE_OP(squeeze) // for x2paddle USE_LITE_OP(squeeze) // for x2paddle
USE_LITE_OP(squeeze2) // for x2paddle USE_LITE_OP(squeeze2) // for x2paddle
USE_LITE_OP(expand) // for x2paddle USE_LITE_OP(unsqueeze) // for x2paddle
USE_LITE_OP(unsqueeze2) // for x2paddle
USE_LITE_OP(expand) // for x2paddle
USE_LITE_OP(roi_align) USE_LITE_OP(roi_align)
USE_LITE_OP(box_clip) USE_LITE_OP(box_clip)
USE_LITE_OP(assign_value) USE_LITE_OP(assign_value)
......
...@@ -33,6 +33,7 @@ add_kernel(shape_compute_arm ARM basic SRCS shape_compute.cc DEPS ${lite_kernel_ ...@@ -33,6 +33,7 @@ add_kernel(shape_compute_arm ARM basic SRCS shape_compute.cc DEPS ${lite_kernel_
add_kernel(slice_compute_arm ARM basic SRCS slice_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(slice_compute_arm ARM basic SRCS slice_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(cast_compute_arm ARM basic SRCS cast_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(cast_compute_arm ARM basic SRCS cast_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(squeeze_compute_arm ARM basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(squeeze_compute_arm ARM basic SRCS squeeze_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(unsqueeze_compute_arm ARM basic SRCS unsqueeze_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(expand_compute_arm ARM basic SRCS expand_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(expand_compute_arm ARM basic SRCS expand_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(reduce_max_compute_arm ARM basic SRCS reduce_max_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(reduce_max_compute_arm ARM basic SRCS reduce_max_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_expand_compute_arm ARM basic SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(sequence_expand_compute_arm ARM basic SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm)
...@@ -46,6 +47,7 @@ add_kernel(anchor_generator_compute_arm ARM basic SRCS anchor_generator_compute. ...@@ -46,6 +47,7 @@ add_kernel(anchor_generator_compute_arm ARM basic SRCS anchor_generator_compute.
add_kernel(generate_proposals_compute_arm ARM basic SRCS generate_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(generate_proposals_compute_arm ARM basic SRCS generate_proposals_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(roi_align_compute_arm ARM basic SRCS roi_align_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(roi_align_compute_arm ARM basic SRCS roi_align_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(box_clip_compute_arm ARM basic SRCS box_clip_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(box_clip_compute_arm ARM basic SRCS box_clip_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(range_compute_arm ARM basic SRCS range_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(assign_value_compute_arm ARM basic SRCS assign_value_compute.cc DEPS ${lite_kernel_deps} math_arm) add_kernel(assign_value_compute_arm ARM basic SRCS assign_value_compute.cc DEPS ${lite_kernel_deps} math_arm)
# for OCR specific # for OCR specific
......
...@@ -23,7 +23,7 @@ namespace arm { ...@@ -23,7 +23,7 @@ namespace arm {
template <class in_type, class out_type> template <class in_type, class out_type>
out_type TransOp(in_type in) { out_type TransOp(in_type in) {
return static_cast<in_type>(in); return static_cast<out_type>(in);
} }
void CastCompute::PrepareForRun() {} void CastCompute::PrepareForRun() {}
...@@ -45,6 +45,14 @@ void CastCompute::Run() { ...@@ -45,6 +45,14 @@ void CastCompute::Run() {
const char* x_data_end = x_data_begin + param.X->numel(); const char* x_data_end = x_data_begin + param.X->numel();
float* out_data = param.Out->mutable_data<float>(); float* out_data = param.Out->mutable_data<float>();
std::transform(x_data_begin, x_data_end, out_data, TransOp<char, float>); std::transform(x_data_begin, x_data_end, out_data, TransOp<char, float>);
} else if (param.in_dtype == 2 && param.out_dtype == 5) { // int32 -> float32
const int32_t* x_data_begin = param.X->data<int32_t>();
const int32_t* x_data_end = x_data_begin + param.X->numel();
float* out_data = param.Out->mutable_data<float>();
// std::transform(x_data_begin, x_data_end, out_data, TransOp<int32_t,
// float>);
// todo: the input type actually is float.
memcpy(out_data, x_data_begin, sizeof(float) * param.X->numel());
} else { } else {
LOG(FATAL) << "other has not been implemented"; LOG(FATAL) << "other has not been implemented";
} }
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/range_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void RangeCompute::Run() {
auto& param = Param<operators::RangeParam>();
// int start = static_cast<int>(param.Start->data<float>()[0]);
// int end = static_cast<int>(param.End->data<float>()[0]);
// int step = static_cast<int>(param.Step->data<float>()[0]);
int start = (param.Start->data<float>()[0]);
int end = (param.End->data<float>()[0]);
int step = (param.Step->data<float>()[0]);
float* out_data = param.Out->mutable_data<float>();
float value = start;
for (int i = 0; i < param.Out->dims().production(); ++i) {
out_data[i] = value;
value += step;
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
range, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::RangeCompute, def)
.BindInput("Start", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("End", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Step", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class RangeCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~RangeCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/unsqueeze_compute.h"
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
void UnsqueezeCompute::Run() {
auto& param = Param<operators::UnsqueezeParam>();
auto x = param.X;
auto output = param.Out;
auto x_dims = x->dims();
auto* x_data = x->data<float>();
auto* out_data = output->mutable_data<float>();
memcpy(out_data, x_data, x_dims.production() * sizeof(float));
}
void Unsqueeze2Compute::Run() {
auto& param = Param<operators::UnsqueezeParam>();
auto x = param.X;
auto output = param.Out;
auto xshape = param.XShape;
auto x_dims = x->dims();
auto* x_data = x->data<float>();
auto* out_data = output->mutable_data<float>();
auto* xshape_data = xshape->mutable_data<float>();
memcpy(out_data, x_data, x_dims.production() * sizeof(float));
memcpy(xshape_data, x_data, x_dims.production() * sizeof(float));
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(unsqueeze,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::host::UnsqueezeCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(unsqueeze2,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::host::Unsqueeze2Compute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class UnsqueezeCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~UnsqueezeCompute() = default;
};
class Unsqueeze2Compute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~Unsqueeze2Compute() = default;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -58,6 +58,7 @@ add_operator(norm_op basic SRCS norm_op.cc DEPS ${op_DEPS}) ...@@ -58,6 +58,7 @@ add_operator(norm_op basic SRCS norm_op.cc DEPS ${op_DEPS})
add_operator(shape_op_lite basic SRCS shape_op.cc DEPS ${op_DEPS}) add_operator(shape_op_lite basic SRCS shape_op.cc DEPS ${op_DEPS})
add_operator(sequence_expand_op_lite basic SRCS sequence_expand_op.cc DEPS ${op_DEPS}) add_operator(sequence_expand_op_lite basic SRCS sequence_expand_op.cc DEPS ${op_DEPS})
add_operator(squeeze_op_lite basic SRCS squeeze_op.cc DEPS ${op_DEPS}) add_operator(squeeze_op_lite basic SRCS squeeze_op.cc DEPS ${op_DEPS})
add_operator(unsqueeze_op_lite basic SRCS unsqueeze_op.cc DEPS ${op_DEPS})
add_operator(im2sequence_op basic SRCS im2sequence_op.cc DEPS ${op_DEPS}) add_operator(im2sequence_op basic SRCS im2sequence_op.cc DEPS ${op_DEPS})
add_operator(reduce_mean_op basic SRCS reduce_mean_op.cc DEPS ${op_DEPS}) add_operator(reduce_mean_op basic SRCS reduce_mean_op.cc DEPS ${op_DEPS})
add_operator(stack_op basic SRCS stack_op.cc DEPS ${op_DEPS}) add_operator(stack_op basic SRCS stack_op.cc DEPS ${op_DEPS})
...@@ -70,6 +71,7 @@ add_operator(roi_align_op basic SRCS roi_align_op.cc DEPS ${op_DEPS}) ...@@ -70,6 +71,7 @@ add_operator(roi_align_op basic SRCS roi_align_op.cc DEPS ${op_DEPS})
add_operator(box_clip_op basic SRCS box_clip_op.cc DEPS ${op_DEPS}) add_operator(box_clip_op basic SRCS box_clip_op.cc DEPS ${op_DEPS})
add_operator(flatten_op basic SRCS flatten_op.cc DEPS ${op_DEPS}) add_operator(flatten_op basic SRCS flatten_op.cc DEPS ${op_DEPS})
add_operator(fake_quantize_range_abs_max_op basic SRCS fake_quantize_range_abs_max.cc DEPS ${op_DEPS}) add_operator(fake_quantize_range_abs_max_op basic SRCS fake_quantize_range_abs_max.cc DEPS ${op_DEPS})
add_operator(range_op basic SRCS range_op.cc DEPS ${op_DEPS})
add_operator(assign_value_op basic SRCS assign_value_op.cc DEPS ${op_DEPS}) add_operator(assign_value_op basic SRCS assign_value_op.cc DEPS ${op_DEPS})
# for OCR specific # for OCR specific
......
...@@ -60,6 +60,7 @@ bool ConcatOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) { ...@@ -60,6 +60,7 @@ bool ConcatOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
auto inputs = op_desc.Input("X"); auto inputs = op_desc.Input("X");
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
param_.x.clear();
for (auto var : inputs) { for (auto var : inputs) {
param_.x.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>()); param_.x.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
} }
......
...@@ -770,6 +770,13 @@ struct SqueezeParam { ...@@ -770,6 +770,13 @@ struct SqueezeParam {
std::vector<int> axes{}; std::vector<int> axes{};
}; };
struct UnsqueezeParam {
const lite::Tensor* X{};
lite::Tensor* Out{};
lite::Tensor* XShape{};
std::vector<int> axes{};
};
/// ----------------------- expand operators ---------------------- /// ----------------------- expand operators ----------------------
struct ExpandParam { struct ExpandParam {
const lite::Tensor* X{}; const lite::Tensor* X{};
...@@ -811,6 +818,13 @@ struct BoxClipParam { ...@@ -811,6 +818,13 @@ struct BoxClipParam {
lite::Tensor* Output{}; lite::Tensor* Output{};
}; };
struct RangeParam {
const lite::Tensor* Start;
const lite::Tensor* End;
const lite::Tensor* Step;
lite::Tensor* Out;
};
/// ----------------------- assign_value operators ----------------------- /// ----------------------- assign_value operators -----------------------
struct AssignValueParam { struct AssignValueParam {
std::vector<int> shape{}; std::vector<int> shape{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/range_op.h"
#include <functional>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool RangeOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.Start);
CHECK_OR_FALSE(param_.End);
CHECK_OR_FALSE(param_.Step);
CHECK_OR_FALSE(param_.Out);
return true;
}
template <typename T>
void GetSize(T start, T end, T step, int64_t* size) {
CHECK(!std::equal_to<T>()(step, 0))
<< "The step of range op should not be 0.";
CHECK(((start < end) && (step > 0)) || (start > end) && (step < 0))
<< "The step should be greater than 0 while start < end. And the "
"step should be less than 0 while start > end.";
*size = std::is_integral<T>::value
? ((std::abs(end - start) + std::abs(step) - 1) / std::abs(step))
: std::ceil(std::abs((end - start) / step));
}
bool RangeOpLite::InferShape() const {
int start = param_.Start->data<float>()[0];
int end = param_.End->data<float>()[0];
int step = param_.Step->data<float>()[0];
int64_t size = 0;
GetSize(start, end, step, &size);
param_.Out->Resize(std::vector<int64_t>({size}));
return true;
}
bool RangeOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
auto start = opdesc.Input("Start").front();
auto end = opdesc.Input("End").front();
auto step = opdesc.Input("Step").front();
auto out = opdesc.Output("Out").front();
param_.Start = scope->FindVar(start)->GetMutable<lite::Tensor>();
param_.End = scope->FindVar(end)->GetMutable<lite::Tensor>();
param_.Step = scope->FindVar(step)->GetMutable<lite::Tensor>();
param_.Out = scope->FindVar(out)->GetMutable<lite::Tensor>();
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(range, paddle::lite::operators::RangeOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class RangeOpLite : public OpLite {
public:
RangeOpLite() {}
explicit RangeOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "range"; }
private:
mutable RangeParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
...@@ -121,7 +121,7 @@ bool Squeeze2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -121,7 +121,7 @@ bool Squeeze2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
auto xshape_var = scope->FindVar(opdesc.Output("XShape").front()); auto xshape_var = scope->FindVar(opdesc.Output("XShape").front());
CHECK(xshape_var); CHECK(xshape_var);
param_.XShape = xshape_var->GetMutable<lite::Tensor>(); param_.XShape = xshape_var->GetMutable<lite::Tensor>();
CHECK(param_.XShape) << "Output(XShape) of ReshapeOp should not be null."; CHECK(param_.XShape) << "Output(XShape) of SqueezeOp should not be null.";
return true; return true;
} }
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/unsqueeze_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
static DDim GetOutputShape(const std::vector<int> &unsqz_dims,
const DDim &in_dims) {
int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
int cur_output_size = in_dims.size();
std::vector<int64_t> output_shape(output_size, 0);
// Validate Check: rank range.
CHECK_LE(output_size, 6) << "The output tensor's rank should be less than 6.";
for (int axis : unsqz_dims) {
int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
// Validate Check: the axis bound
CHECK((cur >= 0) && (cur <= cur_output_size))
<< "The unsqueeze dims must be within range of current rank.";
// Move old axis, and insert new axis
for (int i = cur_output_size; i >= cur; --i) {
if (output_shape[i] == 1) {
// Move axis
output_shape[i + 1] = 1;
output_shape[i] = 0;
}
}
output_shape[cur] = 1;
// Add the output size.
cur_output_size++;
}
// Make output shape
for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
if (output_shape[out_idx] == 0) {
output_shape[out_idx] = in_dims[in_idx++];
}
}
return DDim(output_shape);
}
bool UnsqueezeOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
return true;
}
bool UnsqueezeOp::InferShape() const {
std::vector<int> unsqueeze_dims = param_.axes;
DDim in_dims = param_.X->dims();
DDim out_dim = GetOutputShape(unsqueeze_dims, in_dims);
param_.Out->Resize(out_dim);
return true;
}
bool UnsqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
auto x_var = scope->FindVar(opdesc.Input("X").front());
auto output_var = scope->FindVar(opdesc.Output("Out").front());
CHECK(x_var);
CHECK(output_var);
param_.X = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
param_.Out = output_var->GetMutable<lite::Tensor>();
if (opdesc.HasAttr("axes")) {
param_.axes = opdesc.GetAttr<std::vector<int>>("axes");
}
CHECK(param_.X) << "Input(X) of UnsqueezeOp should not be null.";
CHECK(param_.Out) << "Output(Out) of UnsqueezeOp should not be null.";
return true;
}
bool Unsqueeze2Op::CheckShape() const {
UnsqueezeOp::CheckShape();
CHECK_OR_FALSE(param_.XShape);
return true;
}
bool Unsqueeze2Op::InferShape() const {
UnsqueezeOp::InferShape();
auto x_dims = param_.X->dims();
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 1);
for (size_t i = 0; i < x_dims.size(); i++) {
xshape_dims[i + 1] = x_dims[i];
}
param_.XShape->Resize(DDim(xshape_dims));
return true;
}
bool Unsqueeze2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
UnsqueezeOp::AttachImpl(opdesc, scope);
auto xshape_var = scope->FindVar(opdesc.Output("XShape").front());
CHECK(xshape_var);
param_.XShape = xshape_var->GetMutable<lite::Tensor>();
CHECK(param_.XShape) << "Output(XShape) of Unsqueeze2Op should not be null.";
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(unsqueeze, paddle::lite::operators::UnsqueezeOp);
REGISTER_LITE_OP(unsqueeze2, paddle::lite::operators::Unsqueeze2Op);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class UnsqueezeOp : public OpLite {
public:
UnsqueezeOp() {}
explicit UnsqueezeOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "unsqueeze"; }
protected:
mutable UnsqueezeParam param_;
};
class Unsqueeze2Op : public UnsqueezeOp {
public:
Unsqueeze2Op() : UnsqueezeOp() {}
explicit Unsqueeze2Op(const std::string &op_type) : UnsqueezeOp(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "unsqueeze2"; }
};
} // namespace operators
} // namespace lite
} // namespace paddle
...@@ -42,11 +42,13 @@ endif() ...@@ -42,11 +42,13 @@ endif()
lite_cc_test(test_kernel_crop_compute SRCS crop_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_crop_compute SRCS crop_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_expand_compute SRCS sequence_expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_sequence_expand_compute SRCS sequence_expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_unsqueeze_compute SRCS unsqueeze_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_slice_compute SRCS slice_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_slice_compute SRCS slice_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_reduce_mean_compute SRCS reduce_mean_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_stack_compute SRCS stack_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_range_compute SRCS range_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_affine_channel_compute SRCS affine_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) lite_cc_test(test_kernel_anchor_generator_compute SRCS anchor_generator_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_generate_proposals_compute SRCS generate_proposals_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels}) #lite_cc_test(test_kernel_generate_proposals_compute SRCS generate_proposals_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......
...@@ -25,34 +25,52 @@ class CastComputeTester : public arena::TestCase { ...@@ -25,34 +25,52 @@ class CastComputeTester : public arena::TestCase {
// common attributes for this op. // common attributes for this op.
std::string input_ = "x"; std::string input_ = "x";
std::string output_ = "out"; std::string output_ = "out";
int in_dtype_ = 21; int in_dtype_;
int out_dtype_ = 5; int out_dtype_;
DDim x_dims_{{2, 2, 2, 2}}; DDim x_dims_{{2, 2, 2, 2}};
public: public:
CastComputeTester(const Place& place, const std::string& alias) CastComputeTester(const Place& place,
: TestCase(place, alias) {} const std::string& alias,
int in_dtype,
int out_dtype)
: TestCase(place, alias), in_dtype_(in_dtype), out_dtype_(out_dtype) {}
void RunBaseline(Scope* scope) override { void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_); auto* out = scope->NewTensor(output_);
CHECK(out); CHECK(out);
out->Resize(x_dims_); out->Resize(x_dims_);
auto* output_data = out->mutable_data<float>();
auto* x = scope->FindTensor(input_); if (out_dtype_ == 5 && in_dtype_ == 21) {
const auto* x_data = x->data<char>(); auto* output_data = out->mutable_data<float>();
auto* x = scope->FindTensor(input_);
int num = x_dims_[0]; auto* x_data = x->data<char>();
int channel = x_dims_[1]; int num = x_dims_[0];
int size = x_dims_[2] * x_dims_[3]; int channel = x_dims_[1];
int in_channel = channel * size; int size = x_dims_[2] * x_dims_[3];
int in_channel = channel * size;
auto* output_data_tmp = output_data; auto* output_data_tmp = output_data;
auto* x_data_tmp = x_data; auto* x_data_tmp = x_data;
for (int i = 0; i < x_dims_.production(); i++) { for (int i = 0; i < x_dims_.production(); i++) {
*output_data_tmp = static_cast<float>(*x_data_tmp); *output_data_tmp = static_cast<float>(*x_data_tmp);
output_data_tmp++; output_data_tmp++;
x_data_tmp++; x_data_tmp++;
}
} else if (out_dtype_ == 5 && in_dtype_ == 2) {
auto* output_data = out->mutable_data<float>();
auto* x = scope->FindTensor(input_);
auto* x_data = x->data<int32_t>();
int num = x_dims_[0];
int channel = x_dims_[1];
int size = x_dims_[2] * x_dims_[3];
int in_channel = channel * size;
auto* output_data_tmp = output_data;
auto* x_data_tmp = x_data;
for (int i = 0; i < x_dims_.production(); i++) {
*output_data_tmp = static_cast<float>(*x_data_tmp);
output_data_tmp++;
x_data_tmp++;
}
} }
} }
...@@ -65,12 +83,23 @@ class CastComputeTester : public arena::TestCase { ...@@ -65,12 +83,23 @@ class CastComputeTester : public arena::TestCase {
} }
void PrepareData() override { void PrepareData() override {
std::vector<char> x_data(x_dims_.production()); if (in_dtype_ == 21) {
for (int i = 0; i < x_dims_.production(); i++) { std::vector<char> x_data(x_dims_.production());
float sign = i % 3 == 0 ? -1.0f : 1.0f; for (int i = 0; i < x_dims_.production(); i++) {
x_data[i] = sign * static_cast<char>(i % 128); float sign = i % 3 == 0 ? -1.0f : 1.0f;
x_data[i] = sign * static_cast<char>(i % 128);
}
SetCommonTensor(input_, x_dims_, x_data.data());
} else if (in_dtype_ == 2) {
std::vector<int32_t> x_data(x_dims_.production());
for (int i = 0; i < x_dims_.production(); i++) {
int sign = i % 3 == 0 ? -1 : 1;
x_data[i] = sign * static_cast<int32_t>(i % 128);
}
SetCommonTensor(input_, x_dims_, x_data.data());
} else {
LOG(FATAL) << "not implemented!";
} }
SetCommonTensor(input_, x_dims_, x_data.data());
} }
}; };
...@@ -79,9 +108,15 @@ TEST(Cast, precision) { ...@@ -79,9 +108,15 @@ TEST(Cast, precision) {
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
Place place(TARGET(kARM)); Place place(TARGET(kARM));
std::unique_ptr<arena::TestCase> tester(new CastComputeTester(place, "def")); std::unique_ptr<arena::TestCase> tester(
new CastComputeTester(place, "def", 21, 5));
arena::Arena arena(std::move(tester), place, 2e-5); arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision(); arena.TestPrecision();
// std::unique_ptr<arena::TestCase> tester1(
// new CastComputeTester(place, "def", 2, 5));
// arena::Arena arena1(std::move(tester1), place, 2e-5);
// arena1.TestPrecision();
#endif #endif
} }
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
class RangeComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string start = "Start";
std::string end = "End";
std::string step = "Step";
std::string out = "Out";
int st_, ed_, sp_;
public:
RangeComputeTester(const Place& place,
const std::string& alias,
float st,
float ed,
float sp)
: TestCase(place, alias), st_(st), ed_(ed), sp_(sp) {}
void RunBaseline(Scope* scope) override {
auto* output = scope->NewTensor(out);
CHECK(output);
int64_t size;
auto* st = scope->FindMutableTensor(start);
auto* ed = scope->FindMutableTensor(end);
auto* sp = scope->FindMutableTensor(step);
float st_val = st->data<float>()[0];
float ed_val = ed->data<float>()[0];
float sp_val = sp->data<float>()[0];
// size = (std::abs(ed_val - st_val) + std::abs(sp_val) - 1) /
// std::abs(sp_val);
size = std::ceil(std::abs((ed_val - st_val) / sp_val));
output->Resize(DDim(std::vector<int64_t>({static_cast<int>(size)})));
auto* out_data = output->mutable_data<float>();
float val = st_;
for (int i = 0; i < size; i++) {
out_data[i] = val;
val += sp_;
}
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("range");
op_desc->SetInput("Start", {start});
op_desc->SetInput("End", {end});
op_desc->SetInput("Step", {step});
op_desc->SetOutput("Out", {out});
}
void PrepareData() override {
std::vector<float> st(1);
std::vector<float> ed(1);
std::vector<float> sp(1);
st[0] = st_;
ed[0] = ed_;
sp[0] = sp_;
DDim dim(std::vector<int64_t>({1}));
SetCommonTensor(start, dim, st.data());
SetCommonTensor(end, dim, ed.data());
SetCommonTensor(step, dim, sp.data());
}
};
void test_range(Place place) {
std::unique_ptr<arena::TestCase> tester1(
new RangeComputeTester(place, "def", 1, 10, 1));
arena::Arena arena(std::move(tester1), place, 2e-5);
arena.TestPrecision();
std::unique_ptr<arena::TestCase> tester2(
new RangeComputeTester(place, "def", 10, 1, -2));
arena::Arena arena2(std::move(tester2), place, 2e-5);
arena2.TestPrecision();
}
TEST(Range, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_range(place);
#endif
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/core/arena/framework.h"
namespace paddle {
namespace lite {
class UnsqueezeComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string x_ = "X";
std::string out_ = "Out";
std::vector<int> axes_;
DDim dims_;
public:
UnsqueezeComputeTester(const Place& place,
const std::string& alias,
const std::vector<int>& axes,
DDim dims)
: TestCase(place, alias), axes_(axes), dims_(dims) {}
void RunBaseline(Scope* scope) override {
const auto* input = scope->FindTensor(x_);
CHECK(input);
auto* out = scope->NewTensor(out_);
CHECK(out);
DDim in_dims(dims_);
int output_size = in_dims.size() + static_cast<int>(axes_.size());
int cur_output_size = in_dims.size();
std::vector<int64_t> output_shape(output_size, 0);
// Validate Check: rank range.
CHECK_LE(output_size, 6)
<< "The output tensor's rank should be less than 6.";
for (int axis : axes_) {
int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
// Validate Check: the axis bound
CHECK((cur >= 0) && (cur <= cur_output_size))
<< "The unsqueeze dims must be within range of current rank.";
// Move old axis, and insert new axis
for (int i = cur_output_size; i >= cur; --i) {
if (output_shape[i] == 1) {
// Move axis
output_shape[i + 1] = 1;
output_shape[i] = 0;
}
}
output_shape[cur] = 1;
// Add the output size.
cur_output_size++;
}
// Make output shape
for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
if (output_shape[out_idx] == 0) {
output_shape[out_idx] = in_dims[in_idx++];
}
}
for (size_t i = 0; i < output_shape.size(); ++i)
out->Resize(DDim(output_shape));
auto* input_data = input->data<float>();
auto* out_data = out->mutable_data<float>();
memcpy(out_data, input_data, sizeof(float) * dims_.production());
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("unsqueeze");
op_desc->SetInput("X", {x_});
op_desc->SetOutput("Out", {out_});
op_desc->SetAttr("axes", axes_);
}
void PrepareData() override {
std::vector<float> in_data(dims_.production());
for (int i = 0; i < dims_.production(); ++i) {
in_data[i] = i;
}
SetCommonTensor(x_, dims_, in_data.data());
}
};
class Unsqueeze2ComputeTester : public arena::TestCase {
protected:
// common attributes for this op.
std::string x_ = "X";
std::string out_ = "Out";
std::string xshape_ = "XShape";
std::vector<int> axes_;
DDim dims_;
public:
Unsqueeze2ComputeTester(const Place& place,
const std::string& alias,
const std::vector<int>& axes,
DDim dims)
: TestCase(place, alias), axes_(axes), dims_(dims) {}
void RunBaseline(Scope* scope) override {
const auto* input = scope->FindTensor(x_);
CHECK(input);
auto* out = scope->NewTensor(out_);
CHECK(out);
auto* xshape = scope->NewTensor(xshape_);
CHECK(xshape);
std::vector<int64_t> xshape_sp(dims_.size() + 1, 1);
for (size_t i = 0; i < dims_.size(); ++i) {
xshape_sp[i + 1] = dims_[i];
}
xshape->Resize(DDim(xshape_sp));
DDim in_dims(dims_);
int output_size = in_dims.size() + static_cast<int>(axes_.size());
int cur_output_size = in_dims.size();
std::vector<int64_t> output_shape(output_size, 0);
// Validate Check: rank range.
CHECK_LE(output_size, 6)
<< "The output tensor's rank should be less than 6.";
for (int axis : axes_) {
int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
// Validate Check: the axis bound
CHECK((cur >= 0) && (cur <= cur_output_size))
<< "The unsqueeze dims must be within range of current rank.";
// Move old axis, and insert new axis
for (int i = cur_output_size; i >= cur; --i) {
if (output_shape[i] == 1) {
// Move axis
output_shape[i + 1] = 1;
output_shape[i] = 0;
}
}
output_shape[cur] = 1;
// Add the output size.
cur_output_size++;
}
// Make output shape
for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
if (output_shape[out_idx] == 0) {
output_shape[out_idx] = in_dims[in_idx++];
}
}
out->Resize(DDim(output_shape));
auto* input_data = input->data<float>();
auto* out_data = out->mutable_data<float>();
auto* xshape_data = xshape->mutable_data<float>();
memcpy(out_data, input_data, sizeof(float) * dims_.production());
memcpy(xshape_data, input_data, sizeof(float) * dims_.production());
}
void PrepareOpDesc(cpp::OpDesc* op_desc) {
op_desc->SetType("unsqueeze2");
op_desc->SetInput("X", {x_});
op_desc->SetOutput("Out", {out_});
op_desc->SetOutput("XShape", {xshape_});
op_desc->SetAttr("axes", axes_);
}
void PrepareData() override {
std::vector<float> in_data(dims_.production());
for (int i = 0; i < dims_.production(); ++i) {
in_data[i] = i;
}
SetCommonTensor(x_, dims_, in_data.data());
}
};
void test_unsqueeze(Place place) {
for (std::vector<int> axes : {std::vector<int>({}),
std::vector<int>({0, 2}),
std::vector<int>({0, -2})}) {
for (int N : {1}) {
for (int C : {3}) {
for (int H : {1}) {
for (int W : {5}) {
std::unique_ptr<arena::TestCase> tester(new UnsqueezeComputeTester(
place, "def", axes, DDim({N, C, H, W})));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
}
}
void test_unsqueeze2(Place place) {
for (std::vector<int> axes : {std::vector<int>({}),
std::vector<int>({0, 2}),
std::vector<int>({0, -2})}) {
for (int N : {1}) {
for (int C : {3}) {
for (int H : {1}) {
for (int W : {5}) {
std::unique_ptr<arena::TestCase> tester(new Unsqueeze2ComputeTester(
place, "def", axes, DDim({N, C, H, W})));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
}
}
TEST(squeeze, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_unsqueeze(place);
#endif
}
TEST(squeeze2, precision) {
#ifdef LITE_WITH_X86
Place place(TARGET(kX86));
#endif
#ifdef LITE_WITH_ARM
Place place(TARGET(kARM));
test_unsqueeze2(place);
#endif
}
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册