Skip to content

Commit 6e1d0b1

Browse files
naureAurel
and
Aurel
authored
Get Vector IDs from server (#1151)
* aurel/expose-vector-id: Flip the dependency VectorId / PointId * aurel/expose-vector-id: Move VectorId to common * aurel/expose-vector-id: API for vector_id of loaded irises * aurel/expose-vector-id: _from_s3 * aurel/expose-vector-id: Use 0-based index in results * aurel/expose-vector-id: conversion in tests * aurel/expose-vector-id: rename variable * aurel/expose-vector-id: Less pub * aurel/expose-vector-id: Convert plaintext test IDs --------- Co-authored-by: Aurel <[email protected]>
1 parent c20d10d commit 6e1d0b1

File tree

15 files changed

+157
-124
lines changed

15 files changed

+157
-124
lines changed

Diff for: iris-mpc-common/src/helpers/inmemory_store.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use crate::vector_id::VectorId;
2+
13
/// A helper trait encapsulating the functionality to add iris codes to some
24
/// form of in-memory store.
35
pub trait InMemoryStore {
@@ -15,6 +17,7 @@ pub trait InMemoryStore {
1517
fn load_single_record_from_db(
1618
&mut self,
1719
index: usize,
20+
vector_id: VectorId,
1821
left_code: &[u16],
1922
left_mask: &[u16],
2023
right_code: &[u16],
@@ -52,6 +55,7 @@ pub trait InMemoryStore {
5255
fn load_single_record_from_s3(
5356
&mut self,
5457
index: usize,
58+
vector_id: VectorId,
5559
left_code_odd: &[u8],
5660
left_code_even: &[u8],
5761
right_code_odd: &[u8],
@@ -90,7 +94,14 @@ pub trait InMemoryStore {
9094
.zip(right_mask_even.iter())
9195
.map(|(odd, even)| map_back_to_u16(*odd, *even))
9296
.collect::<Vec<_>>();
93-
self.load_single_record_from_db(index, &left_code, &left_mask, &right_code, &right_mask);
97+
self.load_single_record_from_db(
98+
index,
99+
vector_id,
100+
&left_code,
101+
&left_mask,
102+
&right_code,
103+
&right_mask,
104+
);
94105
}
95106

96107
/// Executes any necessary preprocessing steps on the in-memory store.

Diff for: iris-mpc-common/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ pub mod job;
1010
pub mod shamir;
1111
#[cfg(feature = "helpers")]
1212
pub mod test;
13+
pub mod vector_id;
1314

1415
pub const IRIS_CODE_LENGTH: usize = 12_800;
1516
pub const MASK_CODE_LENGTH: usize = 6_400;

Diff for: iris-mpc-common/src/test.rs

+9-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use crate::{
1010
iris::{IrisCode, IrisCodeArray},
1111
},
1212
job::{BatchQuery, JobSubmissionHandle, ServerJobResult},
13+
vector_id::VectorId,
1314
IRIS_CODE_LENGTH,
1415
};
1516
use eyre::Result;
@@ -1172,7 +1173,14 @@ pub fn load_test_db(
11721173
) -> Result<()> {
11731174
let iris_shares = generate_test_db(party_id, db_size, db_rng_seed);
11741175
for (idx, (code, mask)) in iris_shares.into_iter().enumerate() {
1175-
loader.load_single_record_from_db(idx, &code.coefs, &mask.coefs, &code.coefs, &mask.coefs);
1176+
loader.load_single_record_from_db(
1177+
idx,
1178+
VectorId::from_0_index(idx as u32),
1179+
&code.coefs,
1180+
&mask.coefs,
1181+
&code.coefs,
1182+
&mask.coefs,
1183+
);
11761184
loader.increment_db_size(idx);
11771185
}
11781186

Diff for: iris-mpc-common/src/vector_id.rs

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
use std::{fmt::Display, num::ParseIntError, str::FromStr};
2+
3+
use serde::{Deserialize, Serialize};
4+
5+
/// Unique identifier for an immutable pair of iris codes.
6+
#[derive(Copy, Default, Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
7+
pub struct VectorId {
8+
id: u32,
9+
}
10+
11+
impl Display for VectorId {
12+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
13+
Display::fmt(&self.id, f)
14+
}
15+
}
16+
17+
impl FromStr for VectorId {
18+
type Err = ParseIntError;
19+
20+
fn from_str(s: &str) -> Result<Self, Self::Err> {
21+
Ok(VectorId {
22+
id: FromStr::from_str(s)?,
23+
})
24+
}
25+
}
26+
27+
impl VectorId {
28+
/// From Serial ID (1-indexed).
29+
pub fn from_serial_id(id: u32) -> Self {
30+
VectorId { id }
31+
}
32+
33+
/// To Serial ID (1-indexed).
34+
pub fn serial_id(&self) -> u32 {
35+
self.id
36+
}
37+
38+
/// From index (0-indexed).
39+
pub fn from_0_index(index: u32) -> Self {
40+
VectorId { id: index + 1 }
41+
}
42+
43+
/// To index (0-indexed).
44+
pub fn index(&self) -> u32 {
45+
self.id - 1
46+
}
47+
}

Diff for: iris-mpc-cpu/src/execution/hawk_main.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -482,13 +482,13 @@ pub struct IrisLoader<'a> {
482482
impl<'a> InMemoryStore for IrisLoader<'a> {
483483
fn load_single_record_from_db(
484484
&mut self,
485-
index: usize,
485+
_index: usize, // TODO: Map.
486+
vector_id: VectorId,
486487
left_code: &[u16],
487488
left_mask: &[u16],
488489
right_code: &[u16],
489490
right_mask: &[u16],
490491
) {
491-
let vector_id = VectorId::from_serial_id(index as u32);
492492
for (side, code, mask) in izip!(
493493
&mut self.irises,
494494
[left_code, right_code],
@@ -674,7 +674,7 @@ impl HawkResult {
674674
.iter()
675675
.enumerate()
676676
.map(|(idx, plan)| match plan {
677-
Some(plan) => plan.inserted_vector.to_serial_id(),
677+
Some(plan) => plan.inserted_vector.index(),
678678
None => match_ids[idx][0],
679679
})
680680
.collect()
@@ -684,13 +684,13 @@ impl HawkResult {
684684
// Graph matches.
685685
let mut match_ids = self
686686
.match_results
687-
.filter_map(|(id, [l, r])| (*l && *r).then_some(id.to_serial_id()));
687+
.filter_map(|(id, [l, r])| (*l && *r).then_some(id.index()));
688688

689689
// Intra-batch matches. Find the serial IDs that were just inserted.
690690
for (graph_matches, intra_matches) in izip!(match_ids.iter_mut(), &self.intra_results) {
691691
for i_request in intra_matches {
692692
if let Some(plan) = &self.connect_plans.0[LEFT][*i_request] {
693-
graph_matches.push(plan.inserted_vector.to_serial_id());
693+
graph_matches.push(plan.inserted_vector.index());
694694
}
695695
}
696696
}
@@ -717,10 +717,10 @@ impl HawkResult {
717717

718718
let partial_match_ids_left = self
719719
.match_results
720-
.filter_map(|(id, [l, _r])| l.then_some(id.to_serial_id()));
720+
.filter_map(|(id, [l, _r])| l.then_some(id.index()));
721721
let partial_match_ids_right = self
722722
.match_results
723-
.filter_map(|(id, [_l, r])| r.then_some(id.to_serial_id()));
723+
.filter_map(|(id, [_l, r])| r.then_some(id.index()));
724724
let partial_match_counters_left = partial_match_ids_left.iter().map(Vec::len).collect();
725725
let partial_match_counters_right = partial_match_ids_right.iter().map(Vec::len).collect();
726726

@@ -1192,7 +1192,7 @@ mod tests_db {
11921192
#[tokio::test]
11931193
async fn test_graph_load() -> Result<()> {
11941194
// The test data is a sequence of mutations on the graph.
1195-
let vectors = (0..5).map(VectorId::from).collect_vec();
1195+
let vectors = (0..5).map(VectorId::from_0_index).collect_vec();
11961196
let distance = DistanceShare::new(Default::default(), Default::default());
11971197

11981198
let make_plans = |side| {

Diff for: iris-mpc-cpu/src/hawkers/aby3/aby3_store.rs

+11-72
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
use crate::{
22
execution::{player::Identity, session::Session},
3-
hawkers::plaintext_store::PointId,
43
hnsw::{vector_store::VectorStoreMut, VectorStore},
54
protocol::{
65
ops::{
@@ -13,71 +12,15 @@ use crate::{
1312
};
1413
use itertools::Itertools;
1514
use serde::{Deserialize, Serialize};
16-
use std::{
17-
collections::HashMap,
18-
fmt::{Debug, Display},
19-
num::ParseIntError,
20-
str::FromStr,
21-
sync::Arc,
22-
vec,
23-
};
15+
use std::{collections::HashMap, fmt::Debug, sync::Arc, vec};
2416
use tokio::sync::{RwLock, RwLockWriteGuard};
2517
use tracing::instrument;
2618

19+
pub use iris_mpc_common::vector_id::VectorId;
20+
2721
/// Reference to an iris in the Shamir secret shared form over a Galois ring.
2822
pub type IrisRef = Arc<GaloisRingSharedIris>;
2923

30-
/// Unique identifier for an iris inserted into the store.
31-
#[derive(Copy, Default, Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
32-
pub struct VectorId {
33-
pub(crate) id: PointId,
34-
}
35-
36-
impl Display for VectorId {
37-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38-
Display::fmt(&self.id, f)
39-
}
40-
}
41-
42-
impl FromStr for VectorId {
43-
type Err = ParseIntError;
44-
45-
fn from_str(s: &str) -> Result<Self, Self::Err> {
46-
Ok(VectorId {
47-
id: FromStr::from_str(s)?,
48-
})
49-
}
50-
}
51-
52-
impl From<PointId> for VectorId {
53-
fn from(id: PointId) -> Self {
54-
VectorId { id }
55-
}
56-
}
57-
58-
impl From<&PointId> for VectorId {
59-
fn from(id: &PointId) -> Self {
60-
VectorId { id: *id }
61-
}
62-
}
63-
64-
impl From<usize> for VectorId {
65-
fn from(id: usize) -> Self {
66-
VectorId { id: id.into() }
67-
}
68-
}
69-
70-
impl VectorId {
71-
pub fn from_serial_id(id: u32) -> Self {
72-
VectorId { id: id.into() }
73-
}
74-
75-
/// Returns the ID of a vector as a number.
76-
pub fn to_serial_id(&self) -> u32 {
77-
self.id.0
78-
}
79-
}
80-
8124
/// Iris to be searcher or inserted into the store.
8225
#[derive(Clone, Serialize, Deserialize, Hash, Eq, PartialEq, Debug)]
8326
pub struct Query {
@@ -126,13 +69,11 @@ pub struct SharedIrises {
12669
impl SharedIrises {
12770
pub fn insert(&mut self, vector_id: VectorId, iris: IrisRef) {
12871
self.points.insert(vector_id, iris);
129-
self.next_id = self.next_id.max(vector_id.to_serial_id() + 1);
72+
self.next_id = self.next_id.max(vector_id.serial_id() + 1);
13073
}
13174

13275
fn next_id(&mut self) -> VectorId {
133-
let new_id = VectorId {
134-
id: PointId(self.next_id),
135-
};
76+
let new_id = VectorId::from_serial_id(self.next_id);
13677
self.next_id += 1;
13778
new_id
13879
}
@@ -171,11 +112,7 @@ impl Default for SharedIrisesRef {
171112
// Constructor.
172113
impl SharedIrisesRef {
173114
pub fn new(points: HashMap<VectorId, IrisRef>) -> Self {
174-
let next_id = points
175-
.keys()
176-
.map(|v| v.to_serial_id() + 1)
177-
.max()
178-
.unwrap_or(0);
115+
let next_id = points.keys().map(|v| v.serial_id()).max().unwrap_or(0) + 1;
179116
let body = SharedIrises {
180117
points,
181118
next_id,
@@ -463,6 +400,7 @@ mod tests {
463400
let hawk_searcher = HnswSearcher::default();
464401

465402
for i in 0..database_size {
403+
let vector_id = VectorId::from_0_index(i as u32);
466404
let cleartext_neighbors = hawk_searcher
467405
.search(&mut cleartext_data.0, &mut cleartext_data.1, &i.into(), 1)
468406
.await;
@@ -477,7 +415,7 @@ mod tests {
477415
let hawk_searcher = hawk_searcher.clone();
478416
let v_lock = v.lock().await;
479417
let mut g = g.clone();
480-
let q = v_lock.storage.get_vector(&i.into()).await;
418+
let q = v_lock.storage.get_vector(&vector_id).await;
481419
let q = prepare_query((*q).clone());
482420
let v = v.clone();
483421
jobs.spawn(async move {
@@ -498,7 +436,7 @@ mod tests {
498436
let mut g = g.clone();
499437
jobs.spawn(async move {
500438
let mut v_lock = v.lock().await;
501-
let query = v_lock.storage.get_vector(&i.into()).await;
439+
let query = v_lock.storage.get_vector(&vector_id).await;
502440
let query = prepare_query((*query).clone());
503441
let secret_neighbors =
504442
hawk_searcher.search(&mut *v_lock, &mut g, &query, 1).await;
@@ -707,11 +645,12 @@ mod tests {
707645
.unwrap();
708646

709647
for i in 0..database_size {
648+
let vector_id = VectorId::from_0_index(i as u32);
710649
let mut jobs = JoinSet::new();
711650
for (store, graph) in vectors_and_graphs.iter_mut() {
712651
let mut graph = graph.clone();
713652
let searcher = searcher.clone();
714-
let q = store.lock().await.storage.get_vector(&i.into()).await;
653+
let q = store.lock().await.storage.get_vector(&vector_id).await;
715654
let q = prepare_query((*q).clone());
716655
let store = store.clone();
717656
jobs.spawn(async move {

Diff for: iris-mpc-cpu/src/hawkers/aby3/test_utils.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ use crate::{
1111
local::{generate_local_identities, LocalRuntime},
1212
session::SessionHandles,
1313
},
14-
hawkers::plaintext_store::PlaintextStore,
14+
hawkers::plaintext_store::{PlaintextStore, PointId},
1515
hnsw::{graph::layered_graph::Layer, GraphMem, HnswSearcher, SortedNeighborhood, VectorStore},
1616
network::NetworkType,
1717
protocol::shared_iris::GaloisRingSharedIris,
@@ -39,8 +39,8 @@ pub async fn setup_local_aby3_players_with_preloaded_db<R: RngCore + CryptoRng>(
3939

4040
let mut shared_irises = vec![HashMap::new(); identities.len()];
4141

42-
for (vector_id, iris) in plain_store.points.iter().enumerate() {
43-
let vector_id = VectorId::from_serial_id(vector_id as u32);
42+
for (i, iris) in plain_store.points.iter().enumerate() {
43+
let vector_id = VectorId::from(PointId::from(i));
4444
let all_shares = GaloisRingSharedIris::generate_shares_locally(rng, iris.data.0.clone());
4545
for (party_id, share) in all_shares.into_iter().enumerate() {
4646
shared_irises[party_id].insert(vector_id, Arc::new(share));
@@ -138,7 +138,7 @@ async fn graph_from_plain(
138138
recompute_distances: bool,
139139
) -> GraphMem<Aby3Store> {
140140
let ep = graph_store.get_entry_point().await;
141-
let new_ep = ep.map(|(vector_ref, layer_count)| (VectorId { id: vector_ref }, layer_count));
141+
let new_ep = ep.map(|(vector_ref, layer_count)| (VectorId::from(vector_ref), layer_count));
142142

143143
let layers = graph_store.get_layers();
144144

@@ -150,10 +150,10 @@ async fn graph_from_plain(
150150
let links = layer.get_links_map();
151151
let mut shared_links = HashMap::new();
152152
for (source_v, queue) in links {
153-
let source_v = source_v.into();
153+
let source_v = VectorId::from(*source_v);
154154
let mut shared_queue = vec![];
155155
for (target_v, dist) in queue.as_vec_ref() {
156-
let target_v = target_v.into();
156+
let target_v = VectorId::from(*target_v);
157157
let distance = if recompute_distances {
158158
// recompute distances of graph edges from scratch
159159
eval_vector_distance(&mut vectore_store_lock, &source_v, &target_v).await

Diff for: iris-mpc-cpu/src/hawkers/plaintext_store.rs

+8
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ use std::{
1818
};
1919
use tracing::debug;
2020

21+
use super::aby3::aby3_store::VectorId;
22+
2123
#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq)]
2224
pub struct PlaintextIris(pub IrisCode);
2325

@@ -107,6 +109,12 @@ impl From<u32> for PointId {
107109
}
108110
}
109111

112+
impl From<PointId> for VectorId {
113+
fn from(id: PointId) -> Self {
114+
VectorId::from_0_index(id.0)
115+
}
116+
}
117+
110118
#[derive(Default, Debug, Clone, Serialize, Deserialize, PartialEq)]
111119
pub struct PlaintextStore {
112120
pub points: Vec<PlaintextPoint>,

0 commit comments

Comments
 (0)