-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathot.py
137 lines (109 loc) · 4.49 KB
/
ot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import hashlib
import logging
import pickle
import util
import yao
class ObliviousTransfer:
    """Exchange inputs for a two-party Yao evaluation over a socket.

    Alice (the garbler) sends her input labels and, for each of Bob's input
    wires, offers both candidate labels through a 1-out-of-2 oblivious
    transfer. Bob (the evaluator) picks the label matching his clear input
    bit, evaluates the circuit and sends the result back to Alice.

    Args:
        socket: Connection to the other party. It must provide ``send``,
            ``receive`` and ``send_wait`` (send, then return the peer's reply).
        enabled: When False, no OT is performed. Alice sends both of Bob's
            candidate labels in the clear, so Bob learns both (insecure).
        group: Optional prime group for Alice to reuse. If omitted, one is
            generated when OT is enabled. Bob overwrites it with the group
            Alice sends.
    """

    def __init__(self, socket, enabled=True, group=None):
        self.socket = socket
        self.enabled = enabled
        self.group = group

    def get_result(self, a_inputs, b_keys):
        """Send Alice's inputs, serve Bob's OT requests and return his result.

        Args:
            a_inputs: A dict mapping Alice's wires to (key, encr_bit) inputs.
            b_keys: A dict mapping each of Bob's wires to a pair of
                (key, encr_bit) labels. Index 0 is the label for clear bit 0
                and index 1 is the label for clear bit 1.

        Returns:
            The result of the yao circuit evaluation, as sent back by Bob.
        """
        logging.debug("Sending inputs to Bob")
        self.socket.send_wait(a_inputs)

        logging.debug("Generating prime group to use for OT")
        # Reuse a caller-supplied group if present. No group is needed
        # (None is sent) when OT is disabled.
        if self.enabled:
            self.group = self.group or util.PrimeGroup()
        else:
            self.group = None
        logging.debug("Sending prime group")
        self.socket.send(self.group)

        # Bob makes one request per input wire, naming the wire each time.
        for _ in range(len(b_keys)):
            w = self.socket.receive()  # gate ID where to perform OT
            logging.debug("Received gate ID %s", w)
            if self.enabled:
                # OT works on byte strings, so serialize both labels.
                pair = (pickle.dumps(b_keys[w][0]), pickle.dumps(b_keys[w][1]))
                self.ot_garbler(pair)
            else:
                # Insecure fallback: Bob receives both labels.
                self.socket.send((b_keys[w][0], b_keys[w][1]))

        return self.socket.receive()

    def send_result(self, circuit, g_tables, pbits_out, b_inputs):
        """Evaluate the circuit and send the result to Alice.

        Args:
            circuit: A dict containing circuit spec.
            g_tables: Garbled tables of yao circuit.
            pbits_out: p-bits of outputs.
            b_inputs: A dict mapping Bob's wires to (clear) input bits.

        Returns:
            The result of the yao circuit evaluation.
        """
        # Map from Alice's wires to (key, encr_bit) inputs.
        a_inputs = self.socket.receive()
        self.socket.send(True)  # reply that unblocks Alice's send_wait
        logging.debug("Received Alice's inputs")

        self.group = self.socket.receive()
        logging.debug("Received group to use for OT")

        # Map from Bob's wires to the (key, encr_bit) label his bit selects.
        b_inputs_encr = {}
        for w, b_input in b_inputs.items():
            logging.debug("Sending gate ID %s", w)
            self.socket.send(w)
            if self.enabled:
                # SECURITY(review): these bytes come from the peer (Alice).
                # Unpickling them lets a malicious garbler run arbitrary code
                # here. Consider a non-executable encoding for the labels.
                b_inputs_encr[w] = pickle.loads(self.ot_evaluator(b_input))
            else:
                pair = self.socket.receive()
                logging.debug("Received key pair, key %s selected", b_input)
                b_inputs_encr[w] = pair[b_input]

        result = yao.evaluate(circuit, g_tables, pbits_out, a_inputs,
                              b_inputs_encr)
        logging.debug("Sending circuit evaluation")
        self.socket.send(result)
        return result

    def ot_garbler(self, msgs):
        """Oblivious transfer, Alice's side.

        Args:
            msgs: A pair (msg0, msg1) of byte strings offered to Bob. He
                should be able to unmask only the one his choice bit selects.
        """
        logging.debug("OT protocol started")
        G = self.group
        # OT protocol based on Nigel Smart's "Cryptography Made Simple".
        # c is a fresh random group element whose exponent Bob is not told.
        c = G.gen_pow(G.rand_int())
        # Bob replies with h0. Together with h1 = c / h0, Bob knows the
        # exponent of at most one of them (the one at his choice index).
        h0 = self.socket.send_wait(c)
        h1 = G.mul(c, G.inv(h0))
        k = G.rand_int()
        c1 = G.gen_pow(k)
        # Mask msg_i with a hash of h_i^k. Bob can recompute only the mask
        # whose h_i he built from his own secret exponent.
        e0 = util.xor_bytes(msgs[0], self.ot_hash(G.pow(h0, k), len(msgs[0])))
        e1 = util.xor_bytes(msgs[1], self.ot_hash(G.pow(h1, k), len(msgs[1])))
        self.socket.send((c1, e0, e1))
        logging.debug("OT protocol ended")

    def ot_evaluator(self, b):
        """Oblivious transfer, Bob's side.

        Args:
            b: Bob's input bit used to select one of Alice's messages.

        Returns:
            The message (bytes) selected by Bob.
        """
        logging.debug("OT protocol started")
        G = self.group
        # OT protocol based on Nigel Smart's "Cryptography Made Simple".
        c = self.socket.receive()
        x = G.rand_int()
        x_pow = G.gen_pow(x)
        # Send gen_pow(x) if b == 0, or c / gen_pow(x) if b == 1. Either way,
        # Alice's h_b ends up equal to gen_pow(x).
        h = (x_pow, G.mul(c, G.inv(x_pow)))
        c1, e0, e1 = self.socket.send_wait(h[b])
        e = (e0, e1)
        # pow(c1, x) reproduces the mask Alice applied to message b.
        mask = self.ot_hash(G.pow(c1, x), len(e[b]))
        mb = util.xor_bytes(e[b], mask)
        logging.debug("OT protocol ended")
        return mb

    @staticmethod
    def ot_hash(pub_key, msg_length):
        """Hash function for OT keys.

        Args:
            pub_key: Non-negative integer (``int.to_bytes`` is unsigned here).
            msg_length: Number of output bytes.

        Returns:
            ``msg_length`` bytes of SHAKE-256 output over the minimal
            big-endian encoding of ``pub_key``.
        """
        key_length = (pub_key.bit_length() + 7) // 8  # key length in bytes
        key_bytes = pub_key.to_bytes(key_length, byteorder="big")
        return hashlib.shake_256(key_bytes).digest(msg_length)