未验证 提交 2a1d32a2 编写于 作者: J jed 提交者: GitHub

Merge pull request #113 from kaih70/master

fix ks bug & add one hot encoding map
......@@ -21,7 +21,6 @@ namespace psi {
PseudorandomNumberGenerator::PseudorandomNumberGenerator(const block &seed)
: _ctr(0), _now_byte(0) {
set_seed(seed);
refill_buffer();
}
void PseudorandomNumberGenerator::set_seed(const block &b) {
......@@ -59,4 +58,10 @@ void PseudorandomNumberGenerator::get_array(void *res, size_t len) {
}
}
template <>
bool PseudorandomNumberGenerator::get<bool>() {
uint8_t data;
get_array(&data, sizeof(data));
return data & 1;
}
} // namespace smc
......@@ -69,5 +69,6 @@ from .version import version
from .layers import mpc_math_op_patch
from . import input
from . import initializer
from . import metrics
mpc_math_op_patch.monkey_patch_mpc_variable()
......@@ -17,7 +17,10 @@ Import data_utils module.
from . import aby3
from . import alignment
from . import one_hot_encoding
from .alignment import *
from .one_hot_encoding import *
__all__ = []
__all__ += alignment.__all__
__all__ += one_hot_encoding.__all__
......@@ -338,7 +338,7 @@ def _transpile_type_and_shape(block):
for op in block.ops:
if _is_supported_op(op.type):
if op.type == 'fill_constant':
op._set_attr(name='shape', val=(2L, 1L))
op._set_attr(name='shape', val=(2, 1))
# set default MPC value for fill_constant OP
op._set_attr(name='value', val=MPC_ONE_SHARE)
op._set_attr(name='dtype', val=3)
......@@ -482,7 +482,7 @@ def decrypt_model(mpc_model_dir, plain_model_path, mpc_model_filename=None, plai
new_type = str(mpc_op.type)[len(MPC_OP_PREFIX):]
mpc_op.desc.set_type(new_type)
elif mpc_op.type == 'fill_constant':
mpc_op._set_attr(name='shape', val=(1L))
mpc_op._set_attr(name='shape', val=(1))
mpc_op._set_attr(name='value', val=1.0)
mpc_op._set_attr(name='dtype', val=5)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module provide one hot encoding tools, implemented by OT (Oblivious Transfer)-based
PSI (Private Set Intersection) algorithm.
"""
from multiprocessing.connection import Client, Listener
import mpc_data_utils as mdu
__all__ = ['one_hot_encoding_map', ]
def one_hot_encoding_map(input_set, host_addr, is_client=True):
"""
A protocol to get agreement between 2 parties for encoding one
discrete feature to one hot vector via OT-PSI.
Args:
input_set (set:str): The set of possible feature value owned by this
party. Element of set is str, convert before pass in.
host_addr (str): The info of host_addr,e.g., ip:port
is_receiver (bool): True if this party plays as socket client
otherwise, plays as socket server
Return Val: dict, int.
dict key: feature values in input_set,
dict value: corresponding idx in one hot vector.
int: length of one hot vector for this feature.
Examples:
.. code-block:: python
import paddle_fl.mpc.data_utils
import sys
is_client = sys.argv[1] == "1"
a = set([str(x) for x in range(7)])
b = set([str(x) for x in range(5, 10)])
addr = "127.0.0.1:33784"
ins = a if is_client else b
x, y = paddle_fl.mpc.data_utils.one_hot_encoding_map(ins, addr, is_client)
# y = 10
# x['5'] = 0, x['6'] = 1
# for those feature val owned only by one party, dict val shall
not be conflicting.
print(x, y)
"""
ip = host_addr.split(":")[0]
port = int(host_addr.split(":")[1])
if is_client:
intersection = input_set
intersection = mdu.recv_psi(ip, port, intersection)
intersection = sorted(list(intersection))
# Only the receiver can obtain the result.
# Send result to other parties.
else:
ret_code = mdu.send_psi(port, input_set)
if ret_code != 0:
raise RuntimeError("Errors occurred in PSI send lib, "
"error code = {}".format(ret_code))
if not is_client:
server = Listener((ip, port))
conn = Client((ip, port)) if is_client else server.accept()
if is_client:
conn.send(intersection)
diff_size_local = len(input_set) - len(intersection)
conn.send(diff_size_local)
diff_size_remote = conn.recv()
else:
intersection = conn.recv()
diff_size_local = len(input_set) - len(intersection)
diff_size_remote = conn.recv()
conn.send(diff_size_local)
conn.close()
if not is_client:
server.close()
ret = dict()
cnt = 0
for x in intersection:
ret[x] = cnt
cnt += 1
if is_client:
cnt += diff_size_remote
for x in [x for x in input_set if x not in intersection]:
ret[x] = cnt
cnt += 1
return ret, len(intersection) + diff_size_local + diff_size_remote
......@@ -50,8 +50,6 @@ class KSstatistic(MetricBase):
import paddle_fl.mpc
import numpy as np
# init the KSstatistic
ks = paddle_fl.mpc.metrics.KSstatistic('ks')
# suppose that batch_size is 128
batch_num = 100
......@@ -65,6 +63,10 @@ class KSstatistic(MetricBase):
preds = np.concatenate((class0_preds, class1_preds), axis=1)
labels = np.random.randint(2, size = (batch_size, 1))
# init the KSstatistic for each batch
# to get global ks statistic, init ks before for-loop
ks = paddle_fl.mpc.metrics.KSstatistic('ks')
ks.update(preds = preds, labels = labels)
# shall be some score closing to 0.1 as the preds are randomly assigned
......
......@@ -17,14 +17,16 @@ This module test align in aby3 module.
"""
import unittest
from multiprocessing import Process
import multiprocessing as mp
import paddle_fl.mpc.data_utils.alignment as alignment
class TestDataUtilsAlign(unittest.TestCase):
def run_align(self, input_set, party_id, endpoints, is_receiver):
@staticmethod
def run_align(input_set, party_id, endpoints, is_receiver, ret_list):
"""
Call align function in data_utils.
:param input_set:
......@@ -37,7 +39,7 @@ class TestDataUtilsAlign(unittest.TestCase):
party_id=party_id,
endpoints=endpoints,
is_receiver=is_receiver)
self.assertEqual(result, {'0'})
ret_list.append(result)
def test_align(self):
"""
......@@ -49,14 +51,27 @@ class TestDataUtilsAlign(unittest.TestCase):
set_1 = {'0', '10', '11', '111'}
set_2 = {'0', '30', '33', '333'}
party_0 = Process(target=self.run_align, args=(set_0, 0, endpoints, True))
party_1 = Process(target=self.run_align, args=(set_1, 1, endpoints, False))
party_2 = Process(target=self.run_align, args=(set_2, 2, endpoints, False))
mp.set_start_method('spawn')
manager = mp.Manager()
ret_list = manager.list()
party_0 = mp.Process(target=self.run_align, args=(set_0, 0, endpoints, True, ret_list))
party_1 = mp.Process(target=self.run_align, args=(set_1, 1, endpoints, False, ret_list))
party_2 = mp.Process(target=self.run_align, args=(set_2, 2, endpoints, False, ret_list))
party_1.start()
party_2.start()
party_0.start()
party_0.join()
party_1.join()
party_2.join()
self.assertEqual(3, len(ret_list))
self.assertEqual(ret_list[0], ret_list[1])
self.assertEqual(ret_list[0], ret_list[2])
self.assertEqual({'0'}, ret_list[0])
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册