提交 ba29f12b 编写于 作者: Y yangqingyou

using rescale method to limit r, and add truncate ut

上级 80e1ec51
......@@ -191,6 +191,9 @@ public:
void max_pooling(FixedPointTensor* ret,
BooleanTensor<T>* pos = nullptr) const;
static void truncate(const FixedPointTensor* op, FixedPointTensor* ret,
size_t scaling_factor);
private:
static inline std::shared_ptr<CircuitContext> aby3_ctx() {
......@@ -201,9 +204,6 @@ private:
return paddle::mpc::ContextHolder::tensor_factory();
}
static void truncate(const FixedPointTensor* op, FixedPointTensor* ret,
size_t scaling_factor);
template<typename MulFunc>
static void mul_trunc(const FixedPointTensor<T, N>* lhs,
const FixedPointTensor<T, N>* rhs,
......
......@@ -14,6 +14,7 @@
#pragma once
#include <limits>
#include <memory>
#include <algorithm>
......@@ -21,7 +22,6 @@
#include "prng.h"
namespace aby3 {
template<typename T, size_t N>
FixedPointTensor<T, N>::FixedPointTensor(TensorAdapter<T>* share_tensor[2]) {
// TODO: check tensors' shapes
......@@ -245,15 +245,12 @@ void FixedPointTensor<T, N>::truncate(const FixedPointTensor<T, N>* op,
temp.emplace_back(
tensor_factory()->template create<T>(op->shape()));
}
// r', contraint in (constraint_low, contraint_upper)
// r'
aby3_ctx()->template gen_random_private(*temp[0]);
T contraint_upper = (T) 1 << (sizeof(T) * 8 - 2);
T contraint_low = - contraint_upper;
std::for_each(temp[0]->data(), temp[0]->data() + temp[0]->numel(),
[&contraint_upper, &contraint_low] (T& a) {
while ((a > contraint_upper || a < contraint_low)) {
a = aby3_ctx()->template gen_random_private<T>();
}
[] (T& a) {
a = (T) (a * std::pow(2, sizeof(T) * 8 - 2)
/ std::numeric_limits<T>::max());
});
//r'_0, r'_1
......
......@@ -3437,4 +3437,124 @@ TEST_F(FixedTensorTest, inv_sqrt_test) {
}
#ifdef USE_ABY3_TRUNC1 //use aby3 trunc1
TEST_F(FixedTensorTest, truncate1_msb_failed) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[3] = { gen(shape), gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> sout[6] = { gen(shape), gen(shape), gen(shape),
gen(shape), gen(shape), gen(shape)};
// lhs = 6 = 1 + 2 + 3, share before truncate
// zero share 0 = (1 << 62) + (1 << 62) - (1 << 63)
sl[0]->data()[0] = ((int64_t) 3 << 32) - ((uint64_t) 1 << 63);
sl[1]->data()[0] = ((int64_t) 2 << 32) + ((int64_t) 1 << 62);
sl[2]->data()[0] = ((int64_t) 1 << 32) + ((int64_t) 1 << 62);
auto pr = gen(shape);
// rhs = 15
pr->data()[0] = 6 << 16;
pr->scaling_factor() = 16;
Fix64N16 fl0(sl[0].get(), sl[1].get());
Fix64N16 fl1(sl[1].get(), sl[2].get());
Fix64N16 fl2(sl[2].get(), sl[0].get());
Fix64N16 fout0(sout[0].get(), sout[1].get());
Fix64N16 fout1(sout[2].get(), sout[3].get());
Fix64N16 fout2(sout[4].get(), sout[5].get());
auto p = gen(shape);
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
Fix64N16::truncate(&fl0, &fout0, 16);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
Fix64N16::truncate(&fl1, &fout1, 16);
fout1.reveal_to_one(0, nullptr);
});
}
);
_t[2] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[2], [&](){
Fix64N16::truncate(&fl2, &fout2, 16);
fout2.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
// failed: result is not close to 6
EXPECT_GT(std::abs((p->data()[0] >> 16) - 6), 1000);
}
#else
TEST_F(FixedTensorTest, truncate3_msb_not_failed) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[3] = { gen(shape), gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> sout[6] = { gen(shape), gen(shape), gen(shape),
gen(shape), gen(shape), gen(shape)};
// lhs = 6 = 1 + 2 + 3, share before truncate
// zero share 0 = (1 << 62) + (1 << 62) - (1 << 63)
sl[0]->data()[0] = ((int64_t) 3 << 32) - ((uint64_t) 1 << 63);
sl[1]->data()[0] = ((int64_t) 2 << 32) + ((int64_t) 1 << 62);
sl[2]->data()[0] = ((int64_t) 1 << 32) + ((int64_t) 1 << 62);
auto pr = gen(shape);
// rhs = 15
pr->data()[0] = 6 << 16;
pr->scaling_factor() = 16;
Fix64N16 fl0(sl[0].get(), sl[1].get());
Fix64N16 fl1(sl[1].get(), sl[2].get());
Fix64N16 fl2(sl[2].get(), sl[0].get());
Fix64N16 fout0(sout[0].get(), sout[1].get());
Fix64N16 fout1(sout[2].get(), sout[3].get());
Fix64N16 fout2(sout[4].get(), sout[5].get());
auto p = gen(shape);
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
Fix64N16::truncate(&fl0, &fout0, 16);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
Fix64N16::truncate(&fl1, &fout1, 16);
fout1.reveal_to_one(0, nullptr);
});
}
);
_t[2] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[2], [&](){
Fix64N16::truncate(&fl2, &fout2, 16);
fout2.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_EQ((p->data()[0] >> 16), 6);
}
#endif
} // namespace aby3
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册