Skip to content

Commit

Permalink
Allow disabling rotations
Browse files Browse the repository at this point in the history
  • Loading branch information
Dzejkop committed Feb 29, 2024
1 parent 36b3edc commit e8a753c
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 8 deletions.
5 changes: 5 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ pub struct CoordinatorConfig {
pub participants: JsonStrWrapper<Vec<String>>,
pub hamming_distance_threshold: f64,
pub n_closest_distances: usize,
#[serde(default)]
pub no_rotations: bool,
pub db: DbConfig,
pub queues: CoordinatorQueuesConfig,
#[serde(default)]
Expand All @@ -61,6 +63,7 @@ pub struct CoordinatorConfig {
pub struct ParticipantConfig {
pub socket_addr: SocketAddr,
pub batch_size: usize,
pub no_rotations: bool,
pub db: DbConfig,
pub queues: ParticipantQueuesConfig,
#[serde(default)]
Expand Down Expand Up @@ -153,6 +156,7 @@ mod tests {
]),
hamming_distance_threshold: 0.375,
n_closest_distances: 20,
no_rotations: false,
db: DbConfig {
url: "postgres://localhost:5432/mpc".to_string(),
migrate: true,
Expand Down Expand Up @@ -196,6 +200,7 @@ mod tests {
participants = '["127.0.0.1:8000", "127.0.0.1:8001", "127.0.0.1:8002"]'
hamming_distance_threshold = 0.375
n_closest_distances = 20
no_rotations = false
[coordinator.db]
url = "postgres://localhost:5432/mpc"
Expand Down
3 changes: 2 additions & 1 deletion src/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,12 @@ impl Coordinator {
) -> (Receiver<Vec<[u16; 31]>>, JoinHandle<eyre::Result<()>>) {
let (sender, denom_receiver) = tokio::sync::mpsc::channel(4);
let masks = self.masks.clone();
let no_rotations = self.config.no_rotations;

let denominator_handle = tokio::task::spawn_blocking(move || {
let masks = masks.blocking_lock();
let masks: &[Bits] = bytemuck::cast_slice(&masks);
let engine = MasksEngine::new(&mask);
let engine = MasksEngine::new(&mask, no_rotations);
let total_masks: usize = masks.len();

tracing::info!("Processing denominators");
Expand Down
23 changes: 18 additions & 5 deletions src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,13 @@ pub struct DistanceEngine {
}

impl DistanceEngine {
pub fn new(rotations: impl Iterator<Item = EncodedBits>) -> Self {
pub fn new(rotations: impl IntoIterator<Item = EncodedBits>) -> Self {
Self {
rotations: rotations.collect::<Box<[_]>>().try_into().unwrap(),
rotations: rotations
.into_iter()
.collect::<Box<[_]>>()
.try_into()
.unwrap(),
}
}

Expand Down Expand Up @@ -98,9 +102,18 @@ pub struct MasksEngine {
}

impl MasksEngine {
pub fn new(query: &Bits) -> Self {
let rotations =
query.rotations().collect::<Box<[_]>>().try_into().unwrap();
pub fn new(query: &Bits, no_rotations: bool) -> Self {
let rotations = if no_rotations {
query
.rotations()
.map(|_| *query) // If rotations are disabled we just copy the original value for each rotation
.collect::<Box<[_]>>()
.try_into()
.unwrap()
} else {
query.rotations().collect::<Box<[_]>>().try_into().unwrap()
};

Self { rotations }
}

Expand Down
21 changes: 19 additions & 2 deletions src/participant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,15 @@ impl Participant {

// Process in worker thread
let (sender, mut receiver) = mpsc::channel(4);
let no_rotations = self.config.no_rotations;
let worker = tokio::task::spawn_blocking(move || {
calculate_share_distances(shares_ref, template, batch_size, sender)
calculate_share_distances(
shares_ref,
template,
batch_size,
sender,
no_rotations,
)
});

while let Some(buffer) = receiver.recv().await {
Expand Down Expand Up @@ -261,11 +268,21 @@ fn calculate_share_distances(
template: Template,
batch_size: usize,
sender: mpsc::Sender<Vec<u8>>,
no_rotations: bool,
) -> eyre::Result<()> {
let shares = shares.blocking_lock();
let patterns: &[EncodedBits] = bytemuck::cast_slice(&shares);

let template_rotations = template.rotations().map(|r| encode(&r));
let template_rotations: Vec<EncodedBits> = if no_rotations {
template
.rotations()
.map(|_| template) // ignore rotations
.map(|r| encode(&r))
.collect()
} else {
template.rotations().map(|r| encode(&r)).collect()
};

let engine = DistanceEngine::new(template_rotations);

for chunk in patterns.chunks(batch_size) {
Expand Down

0 comments on commit e8a753c

Please sign in to comment.