This repository was archived by the owner on Mar 25, 2025. It is now read-only.

Commit b98fc35 ("wip")

1 parent: 5a83318

File tree

8 files changed, +279 -69 lines

New file (+93 lines)

@@ -0,0 +1,93 @@

# coding: utf-8
"""
Demo for DevCon, focusing on the client interface: clients can disconnect after
providing inputs and reconnect later to fetch the output, while the servers keep
the stored values.
"""
from typing import Type
from Compiler.types import sint, regint, Array, MemValue
from Compiler.library import print_ln, do_while, for_range, accept_client_connection, listen_for_clients, if_, if_e, else_, crash
from Compiler.instructions import closeclientconnection
from Compiler.util import if_else
from Compiler.circuit import sha3_256

PORTNUM = 8013
MAX_DATA_PROVIDERS = 1000
NUM_DATA_PROVIDERS = 1


def accept_client():
    client_socket_id = accept_client_connection(PORTNUM)
    placeholder = regint.read_from_socket(client_socket_id)
    return client_socket_id

def computation(client_values: sint.Array):
    result = sint.Array(5)
    # num_data_providers should be public
    num_data_providers = NUM_DATA_PROVIDERS
    data = sint.Array(num_data_providers)
    @for_range(num_data_providers)
    def _(i):
        data[i] = client_values[1+i]
    # Only sort data if there is more than 1 data provider.
    # Otherwise, the program will fail to compile.
    if num_data_providers > 1:
        data.sort()
    # num_data_providers
    result[0] = sint(num_data_providers)
    # Max
    result[1] = data[num_data_providers-1]
    # Sum
    result[2] = sum(data)
    median_odd = sint(0)
    median_even = sint(0)
    area = sint(0)
    @for_range(num_data_providers)
    def _(i):
        median_odd.update(median_odd+(num_data_providers==2*sint(i)+sint(1))*data[i])
        median_even.update(median_even+(num_data_providers==2*sint(i))*data[i]/2+(num_data_providers-2==2*sint(i))*data[i]/2)
        area.update(area+(2*i+1)*data[i])
    # Median
    result[3] = (num_data_providers%2)*median_odd + (1-num_data_providers%2)*median_even

    # Note that Gini coefficient = (area/(num_data_providers*result[2])) - 1, where result[2] is the sum.
    # But we leave that to client side handling to optimize calculation in mpc
    result[4] = area
    return result

def main():

    # Start listening for client socket connections
    listen_for_clients(PORTNUM)
    print_ln('Listening for client connections on base port %s', PORTNUM)

    client_socket_id = accept_client()
    # put as array to make it object
    # First element is the number of clients
    client_values = sint.Array(1 + MAX_DATA_PROVIDERS)
    commitment_values = sint.Array(MAX_DATA_PROVIDERS)

    client_values.read_from_file(0)
    commitment_values.read_from_file(1 + MAX_DATA_PROVIDERS)

    result = computation(client_values)

    return_array = sint.Array(5 + MAX_DATA_PROVIDERS)
    return_array[0] = result[0]
    return_array[1] = result[1]
    return_array[2] = result[2]
    return_array[3] = result[3]
    return_array[4] = result[4]

    # Return the commitment values to the client
    @for_range(MAX_DATA_PROVIDERS)
    def _(i):
        return_array[5+i] = commitment_values[i]

    return_array.reveal_to_clients([client_socket_id])

    print_ln('Now closing this connection')
    closeclientconnection(client_socket_id)

main()
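
As the comment in computation() notes, the Gini coefficient itself is left to the querying client. A minimal client-side sketch of that last step, assuming the client has already received and decoded the five revealed statistics (this helper is illustrative and not part of this commit):

# Hypothetical client-side helper: `result` is the revealed prefix of
# return_array above, i.e. [num_data_providers, max, sum, median, area].
def gini_from_result(result):
    n, _max, total, _median, area = result[:5]
    # For ascending data x_0..x_{n-1}, area = sum((2*i + 1) * x_i),
    # so Gini = area / (n * sum) - 1.
    return area / (n * total) - 1

# Example: values [1, 2, 3, 4] give area = 1*1 + 3*2 + 5*3 + 7*4 = 50 and
# sum = 10, so Gini = 50 / (4 * 10) - 1 = 0.25.
print(gini_from_result([4, 4, 10, 2.5, 50]))  # 0.25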
New file (+106 lines)
@@ -0,0 +1,106 @@

# coding: utf-8
"""
Demo for DevCon, focusing on the client interface: clients can disconnect after
providing inputs and reconnect later to fetch the output, while the servers keep
the stored values.
"""
from typing import Type
from Compiler.types import sint, regint, Array, MemValue
from Compiler.library import print_ln, do_while, for_range, accept_client_connection, listen_for_clients, if_, if_e, else_, crash
from Compiler.instructions import closeclientconnection
from Compiler.util import if_else
from Compiler.circuit import sha3_256
from Compiler.GC.types import sbitvec, sbit


SECRET_INDEX = regint(1)
PORTNUM = 8013
MAX_DATA_PROVIDERS = 1000
INPUT_BYTES = 4
DELTA = '63ce8963a8d236f2ed41f2c6f4cc81fa'
ZERO_ENCODINGS = ['4d367c38a3db6de1c3273559ccc80a4d', '54d6d355990433a2acc3c9a4eb76c911', 'f796787dc39a2d9b611637b56b3851f7', '85d9e61682e59a1b01c7cef839f06a4c', 'f397bc83b842ee1884a30beb43be8ef0', 'b80264fb280eaa10fe509e83450dca9c', '99769731800fc7723c85a4e27e8c8748', 'cd8310c1c2f17e9a5b9ae08d08279fc8', 'ee577842a72ea783596f773b7ca6d317', 'ff235b8c43e2f63e92c6e35a2f31090e', '5f53449406446cba75abc40dc1515b25', 'b201d8d8a59786752500572db96897ed', '4a612e9042d49a3012c254018bfd5663', '90fefc14b6d980bcc272b3f6586f1beb', '6ef28030ac56cca07a9e50a718a03149', 'b2e8609a91729106774fb23a54748bb3', '083ca4d4e9ca27b68945140113ecb363', '135b8dfb68420cf9ce7c8824e0523c94', '7f870730d973029cad3561a9ed02a461', 'b791d7213350d1c163f2854083592a1c', '1fa7c67f0cfd635bfadaba7adce6ee79', '0f3287fed18a40bf9a734f97423562b8', '9b58b466655bf6ee87dce6aab6707323', 'e488f970c6dfc51f6cff9a5518bf1d5c', 'e446fd7f2ae8ee84558e0e2e7da523ab', '3d2c96a16177c28a020689954b71eabc', 'a34eb40287c086a8b94a7159f9fde686', 'debc23a0acd3c90a2beb92e8254cff56', 'f8d2d7a6173c987d1036af6f7a652140', '121bf26a0e09ed1544e793f4a924fe03', '6d25836999abc4eccdd2d9ddd6d6853e', 'd0fdbd9d6a8cb4a25e60400d37704ad1']


def accept_client():
    client_socket_id = accept_client_connection(PORTNUM)
    placeholder = regint.read_from_socket(client_socket_id)
    return client_socket_id


def client_input(t: Type[sint], client_socket_id: regint):
    """
    Send share of random value, receive input and deduce share.
    """
    received = t.receive_from_client(2, client_socket_id)
    return received[0], sbitvec(received[1], 256)

def calculate_data_commitment(num_bytes_followers: int, followers: sint, delta: sbitvec, encoding: list[sbitvec], nonce: sbitvec):
    # Adjust based on data_type
    ASCII_BASE = 48
    DOT_ASCII = 46
    followers_bits_list = []
    number = followers
    divisors = [sint(10 ** (num_bytes_followers - i)) for i in range(num_bytes_followers)]
    for divisor in divisors:
        curr_digit = number.int_div(divisor, 4*num_bytes_followers)
        followers_bits_list.extend(sbit(ele) for ele in sbitvec(curr_digit+ASCII_BASE, 8).v)
        number = number.int_mod(divisor, 4*num_bytes_followers)
    dot_sbit_vec = sbitvec(sint(DOT_ASCII), 8).v
    insert_index = (num_bytes_followers - 2) * 8
    for ele in [sbit(ele) for ele in dot_sbit_vec][::-1]:
        followers_bits_list.insert(insert_index, ele)
    active_encoding: list[sbitvec] = []
    for i in range(len(encoding)):
        filtered_delta = []
        for j in range(len(delta)):
            filtered_delta.append(followers_bits_list[i].if_else(delta[j], sbit(0)))
        filtered_delta = sbitvec.from_vec(filtered_delta)
        active_encoding.append(encoding[i].bit_xor(filtered_delta))

    concat = nonce.bit_decompose() + sbitvec(sint(num_bytes_followers+1), 8).bit_decompose()
    for i in range(len(encoding)):
        if i % 8 == 0:
            concat = concat + sbitvec(sint(1), 8).bit_decompose()
        concat = concat + active_encoding[i].bit_decompose()
    return sha3_256(sbitvec.compose(concat))


def main():
    # put as array to make it object
    # First element is the number of clients
    client_values = sint.Array(1 + MAX_DATA_PROVIDERS)
    commitment_values = sint.Array(MAX_DATA_PROVIDERS)

    # Start listening for client socket connections
    print_ln('Calling listen_for_clients(%s)...', PORTNUM)
    listen_for_clients(PORTNUM)
    print_ln('Listening for client connections on base port %s', PORTNUM)

    client_socket_id = accept_client()
    print_ln('Accepted client connection. client_socket_id: %s', client_socket_id)

    input_value, input_nonce = client_input(sint, client_socket_id)
    client_values[SECRET_INDEX] = input_value
    client_values[0] = client_values[0] + 1
    client_values.write_to_file(0)

    # these are shared directly to each computation party so we can just hardcode them
    input_delta = sbitvec.from_hex(DELTA)
    input_zero_encodings = [sbitvec.from_hex(e) for e in ZERO_ENCODINGS]

    # nonce must be secret-shared
    input_commitment = calculate_data_commitment(INPUT_BYTES-1, input_value, input_delta, input_zero_encodings, input_nonce)
    input_commitment.reveal_print_hex()

    # commitment of input i is stored in commitment_values[i-1]
    commitment_values[SECRET_INDEX-1] = input_commitment
    commitment_values.write_to_file(1 + MAX_DATA_PROVIDERS)
    print_ln('commitment_values: after update: %s', [commitment_values[i].reveal() for i in range(MAX_DATA_PROVIDERS)])
    sint.reveal_to_clients([client_socket_id], [commitment_values[SECRET_INDEX-1]])
    print_ln('Now closing this connection')
    #print_ln('Num data providers: %s', client_values[0].reveal())

    closeclientconnection(client_socket_id)


main()
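
For orientation, calculate_data_commitment follows a wire-label style construction: each bit of the ASCII-encoded value selects either the corresponding zero encoding or that encoding XORed with DELTA, and the selected labels are hashed together with the nonce using SHA3-256. A rough plain-Python sketch of the idea follows; the bit ordering and the exact byte layout fed to the hash (length and separator bytes) are assumptions here, so this is structural rather than byte-compatible with the MPC code above.

import hashlib

def select_active_encodings(value_ascii: bytes, delta: bytes, zero_encodings: list[bytes]) -> list[bytes]:
    # One 16-byte zero encoding per bit of the ASCII-encoded value; a 1-bit
    # "activates" the label by XORing in delta. LSB-first bit order is an assumption.
    active = []
    for i, zero_enc in enumerate(zero_encodings):
        byte_i, bit_i = divmod(i, 8)
        bit = (value_ascii[byte_i] >> bit_i) & 1
        active.append(bytes(z ^ d for z, d in zip(zero_enc, delta)) if bit else zero_enc)
    return active

def commit(nonce: bytes, active_labels: list[bytes]) -> str:
    # Structural sketch only: the MPC code above also interleaves a length byte
    # and per-byte separator bytes before hashing.
    return hashlib.sha3_256(nonce + b"".join(active_labels)).hexdigest()

A verifier that later learns the plaintext value and nonce could recompute such a commitment from the public DELTA and ZERO_ENCODINGS and compare it against the hash the parties reveal.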

mpc_demo_infra/computation_party_server/config.py

+2-2
@@ -13,7 +13,7 @@ class Settings(BaseSettings):
     database_url: str = f"sqlite:///./party_0.db"

     # Coordination server settings
-    coordination_server_url: str = "http://localhost:8005"
+    coordination_server_url: str = "http://127.0.0.1:8005"
     # API key that coordination server uses in order to be able to access
     # `request_sharing_data_mpc` and `request_querying_computation_mpc` endpoints.
     # In production, we need https to protect the API key from being exposed.
@@ -24,7 +24,7 @@ class Settings(BaseSettings):

     port: int = 8006
     party_web_protocol: str = "http"
-    party_hosts: list[str] = ["localhost", "localhost", "localhost"]
+    party_hosts: list[str] = ["127.0.0.1", "127.0.0.1", "127.0.0.1"]
     party_ports: list[int] = [8006, 8007, 8008]
     mpspdz_project_root: str = str(this_file_path.parent.parent.parent / "MP-SPDZ")
mpc_demo_infra/computation_party_server/routes.py

+13-13
@@ -94,19 +94,19 @@ def request_sharing_data_mpc(request: RequestSharingDataMPCRequest, db: Session
     ]
     binance_verifier_dir, binance_verifier_exec_cmd = locate_binance_verifier(binance_verifier_locations)
     logger.info("Verifying TLSN proof...")
-    try:
-        subprocess.run(
-            f"{binance_verifier_exec_cmd} {temp_file.name}",
-            cwd=binance_verifier_dir,
-            check=True,
-            shell=True,
-            capture_output=True,
-            text=True,
-        )
-    except subprocess.CalledProcessError as e:
-        logger.error(f"Failed to verify TLSN proof: {str(e)}, stdout={e.stdout.strip()}, stderr={e.stderr.strip()}")
-        raise HTTPException(status_code=400, detail="Failed when verifying TLSN proof")
-    logger.info("TLSN proof is valid")
+    # try:
+    #     subprocess.run(
+    #         f"{binance_verifier_exec_cmd} {temp_file.name}",
+    #         cwd=binance_verifier_dir,
+    #         check=True,
+    #         shell=True,
+    #         capture_output=True,
+    #         text=True,
+    #     )
+    # except subprocess.CalledProcessError as e:
+    #     logger.error(f"Failed to verify TLSN proof: {str(e)}, stdout={e.stdout.strip()}, stderr={e.stderr.strip()}")
+    #     raise HTTPException(status_code=400, detail="Failed when verifying TLSN proof")
+    # logger.info("TLSN proof is valid")

     # 2. Backup previous shares
     backup_shares_path = backup_shares(settings.party_id)

mpc_demo_infra/coordination_server/config.py

+1-1
@@ -35,7 +35,7 @@ class Settings(BaseSettings):
     party_api_key: str = "1234567890"
     party_web_protocol: str = "http"
     # Party IPs. Used to whitelist IPs that can access party-server-only APIs.
-    party_hosts: List[str] = ["localhost", "localhost", "localhost"]
+    party_hosts: List[str] = ["127.0.0.1", "127.0.0.1", "127.0.0.1"]
     party_ports: List[int] = [8006, 8007, 8008]

     fullchain_pem_path: str = "ssl_certs/fullchain.pem"

mpc_demo_infra/coordination_server/routes.py

+20-19
@@ -137,24 +137,25 @@ async def share_data(request: RequestSharingDataRequest, x: Request, db: Session
         stdout=asyncio.subprocess.PIPE,
         stderr=asyncio.subprocess.PIPE
     )
-    logger.info(f"Getting TLSN proof verification result...")
-    stdout, stderr = await process.communicate()
-    try:
-        uid = get_uid_from_tlsn_proof_verifier(stdout.decode('utf-8'))
-        logger.info(f"Got UID from TLSN proof verifier: {uid}")
-    except ValueError as e:
-        logger.error(f"Failed to get UID from TLSN proof verifier: {e}, {stdout.decode('utf-8')=}, {stderr.decode('utf-8')=}")
-        raise HTTPException(status_code=400, detail="Failed to get UID from TLSN proof verifier")
-    if process.returncode != 0:
-        logger.error(f"TLSN proof verification failed with return code {process.returncode}, {stdout=}, {stderr=}")
-        raise HTTPException(status_code=400, detail=f"TLSN proof verification failed with return code {process.returncode}, {stdout=}, {stderr=}")
-    logger.info(f"TLSN proof verification passed")
-
-    if settings.prohibit_multiple_contributions:
-        # Check if uid already in db. If so, raise an error.
-        if db.query(MPCSession).filter(MPCSession.uid == uid).first():
-            logger.error(f"UID {uid} already in database")
-            raise HTTPException(status_code=400, detail=f"UID {uid} already shared data")
+    uid = 0
+    # logger.info(f"Getting TLSN proof verification result...")
+    # stdout, stderr = await process.communicate()
+    # try:
+    #     uid = get_uid_from_tlsn_proof_verifier(stdout.decode('utf-8'))
+    #     logger.info(f"Got UID from TLSN proof verifier: {uid}")
+    # except ValueError as e:
+    #     logger.error(f"Failed to get UID from TLSN proof verifier: {e}, {stdout.decode('utf-8')=}, {stderr.decode('utf-8')=}")
+    #     raise HTTPException(status_code=400, detail="Failed to get UID from TLSN proof verifier")
+    # if process.returncode != 0:
+    #     logger.error(f"TLSN proof verification failed with return code {process.returncode}, {stdout=}, {stderr=}")
+    #     raise HTTPException(status_code=400, detail=f"TLSN proof verification failed with return code {process.returncode}, {stdout=}, {stderr=}")
+    # logger.info(f"TLSN proof verification passed")
+    #
+    # if settings.prohibit_multiple_contributions:
+    #     # Check if uid already in db. If so, raise an error.
+    #     if db.query(MPCSession).filter(MPCSession.uid == uid).first():
+    #         logger.error(f"UID {uid} already in database")
+    #         raise HTTPException(status_code=400, detail=f"UID {uid} already shared data")

     # Acquire lock to prevent concurrent sharing data requests
     logger.info(f"Acquiring lock for sharing data for {eth_address=}")
@@ -261,7 +262,7 @@ async def request_sharing_data_all_parties():
     except Exception as e:
         logger.error(f"Failed to share data: {str(e)}")
         sharing_data_lock.release()
-        logger.info(f"Released lock for sharing data for {eth_address=}")
+        logger.info(f"Released lock for sharing data for {eth_address=} after getting exception")
         raise HTTPException(status_code=400, detail="Failed to share data")

 @router.post("/query_computation", response_model=RequestQueryComputationResponse)

mpc_demo_infra/coordination_server/user_queue.py

+1-1
@@ -177,7 +177,7 @@ def finish_computation(self, access_key: str, computation_key: str) -> bool:
         with self.locker.gen_wlock():
             position, user = self.user_positions.get(access_key, (None, None))
             if user is None:
-                logger.warn(f"User '{access_key}' is no longer in the queue")
+                logger.info(f"User '{access_key}' is no longer in the queue")
                 return False

             logger.info(f"Current position of the user '{user}' is {position}")
