in protocol/src/private_id_multi_key/company.rs [203:357]
/// Matches partner records against company records key-by-key and fills the
/// output buffers `v_partner`, `v_company`, `s_partner` and `s_prime_company`.
///
/// Each record in `e_partner` / `e_company` is a list of key points. The
/// company records are scanned one key position ("column") at a time. For every
/// partner record that is still unmatched, the lowest-index unmatched company
/// record whose key in the current column equals *any* of the partner's keys is
/// selected. Both records are then marked as matched and their slots in the
/// output vectors are overwritten with the matched key.
///
/// Buffers written (all are cleared first):
/// * `v_company` - per company record, the record's first key, or the matched
///   key, encrypted with `private_keys.1` and then `private_keys.2`.
/// * `v_partner` - per partner record, the record's first key, or the matched
///   key, encrypted with `private_keys.1`.
/// * `s_partner` - first keys of the partner records that never matched,
///   permuted and serialised with `to_bytes` (no encryption is applied here).
/// * `s_prime_company` - first keys of the company records that never matched,
///   encrypted with `private_keys.1`.
///
/// An empty `e_company` is handled: no matching is attempted and every partner
/// record is treated as unmatched.
///
/// # Errors
///
/// Returns `ProtocolError::ErrorCalcSetDiff` if any of the six buffer locks
/// cannot be acquired.
///
/// # Panics
///
/// Panics if any record in `e_company` or `e_partner` contains no keys, because
/// the first key of every record is read unconditionally.
fn calculate_set_diff(&self) -> Result<(), ProtocolError> {
    match (
        self.e_partner.read(),
        self.e_company.read(),
        self.v_partner.write(),
        self.v_company.write(),
        self.s_partner.write(),
        self.s_prime_company.write(),
    ) {
        (
            Ok(e_partner),
            Ok(e_company),
            Ok(mut v_partner),
            Ok(mut v_company),
            Ok(mut s_partner),
            Ok(mut s_prime_company),
        ) => {
            // Default output for every record is derived from its first key.
            let s_c = e_company.iter().map(|rec| rec[0]).collect::<Vec<_>>();
            let s_p = e_partner.iter().map(|rec| rec[0]).collect::<Vec<_>>();

            // NOTE(review): the indexed writes into `v_c` / `v_p` below assume
            // `encrypt` returns exactly one output per input point - confirm.
            let mut v_c = self.ec_cipher.encrypt(
                &self.ec_cipher.encrypt(s_c.as_slice(), &self.private_keys.1),
                &self.private_keys.2,
            );
            let mut v_p = self.ec_cipher.encrypt(s_p.as_slice(), &self.private_keys.1);

            // Number of key columns to scan. An empty company set yields zero
            // columns, so the matching loop is skipped instead of panicking.
            let max_len = e_company.iter().map(|rec| rec.len()).max().unwrap_or(0);

            // `true` means the record has not been matched yet.
            let mut e_c_valid = vec![true; e_company.len()];
            let mut e_p_valid = vec![true; e_partner.len()];

            for idx in 0..max_len {
                // Map from the compressed bytes of the idx-th company key to
                // the index of the record that owns it. If several records
                // share the same key in this column, the last one wins.
                // TODO: This should be a ByteBuffer instead of a vec<u8>
                let mut e_c_map = HashMap::<Vec<u8>, usize>::with_capacity(e_company.len());
                for (i, rec) in e_company.iter().enumerate() {
                    // Records shorter than this column simply do not take part.
                    if let Some(key) = rec.get(idx) {
                        // Points are not hashable by themselves, so their
                        // compressed byte form is used as the map key.
                        e_c_map.insert(key.compress().to_bytes().to_vec(), i);
                    }
                }

                // Partner records matched in this column. They are invalidated
                // after the scan because `e_p_valid` is borrowed by the loop.
                let mut matched_partners = Vec::<usize>::new();
                for ((i, rec), _) in e_partner
                    .iter()
                    .enumerate()
                    .zip_eq(e_p_valid.iter())
                    .filter(|(_, &valid)| valid)
                {
                    // Lowest-index company record that is still unmatched and
                    // whose idx-th key equals one of this partner's keys. One
                    // borrowed-slice lookup per key: no second hash probe and
                    // no temporary Vec allocation.
                    let match_idx = rec
                        .iter()
                        .filter_map(|key| e_c_map.get(&key.compress().to_bytes()[..]).copied())
                        .filter(|&m_idx| e_c_valid[m_idx])
                        .min();

                    if let Some(m_idx) = match_idx {
                        // `encrypt` works on slices, so wrap the matched key in
                        // a single-element vector.
                        let matched = vec![e_company[m_idx][idx]];
                        let c = self.ec_cipher.encrypt(
                            &self
                                .ec_cipher
                                .encrypt(matched.as_slice(), &self.private_keys.1),
                            &self.private_keys.2,
                        );
                        let p = self
                            .ec_cipher
                            .encrypt(matched.as_slice(), &self.private_keys.1);

                        // Retire the company record right away so that no other
                        // partner record can claim it, and overwrite both
                        // output slots with the matched key.
                        e_c_valid[m_idx] = false;
                        v_c[m_idx] = c[0];
                        matched_partners.push(i);
                        v_p[i] = p[0];
                    }
                }

                // Retire matched partner records so that they are skipped in
                // the following columns.
                for &i in &matched_partners {
                    e_p_valid[i] = false;
                }
            }

            // Create V_c and V_p
            v_company.clear();
            v_company.extend(self.ec_cipher.to_bytes(v_c.as_slice()));
            v_partner.clear();
            v_partner.extend(self.ec_cipher.to_bytes(v_p.as_slice()));

            // Create S_p from the partner records that never matched.
            s_partner.clear();
            let mut unmatched_p = s_p
                .iter()
                .zip_eq(e_p_valid.iter())
                .filter(|(_, &valid)| valid)
                .map(|(&key, _)| key)
                .collect::<Vec<_>>();
            if !unmatched_p.is_empty() {
                // Shuffle before saving the output.
                permute(
                    gen_permute_pattern(unmatched_p.len()).as_slice(),
                    &mut unmatched_p,
                );
                s_partner.extend(self.ec_cipher.to_bytes(unmatched_p.as_slice()));
            }

            // Create S_c from the company records that never matched.
            let unmatched_c = s_c
                .iter()
                .zip_eq(e_c_valid.iter())
                .filter(|(_, &valid)| valid)
                .map(|(&key, _)| key)
                .collect::<Vec<_>>();
            s_prime_company.clear();
            if !unmatched_c.is_empty() {
                s_prime_company.extend(
                    self.ec_cipher
                        .encrypt_to_bytes(unmatched_c.as_slice(), &self.private_keys.1),
                );
            }

            Ok(())
        }
        _ => {
            error!("Unable to obtain locks to buffers for set diff operation");
            Err(ProtocolError::ErrorCalcSetDiff(
                "cannot calculate set difference".to_string(),
            ))
        }
    }
}