src/util.rs (353 lines of code) (raw):
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under both the MIT license found in the
// LICENSE-MIT file in the root directory of this source tree and the Apache
// License, Version 2.0 found in the LICENSE-APACHE file in the root directory
// of this source tree.
//! Helper functions
use core::convert::TryFrom;
use derive_where::derive_where;
use digest::core_api::BlockSizeUser;
use digest::{Digest, OutputSizeUser};
use generic_array::sequence::Concat;
use generic_array::typenum::{IsLess, IsLessOrEqual, Unsigned, U11, U2, U256};
use generic_array::{ArrayLength, GenericArray};
use rand_core::{CryptoRng, RngCore};
use subtle::ConstantTimeEq;
use crate::group::{STR_HASH_TO_GROUP, STR_HASH_TO_SCALAR};
#[cfg(feature = "serde")]
use crate::serialization::serde::{Element, Scalar};
use crate::{CipherSuite, Error, Group, InternalError, Result};
///////////////
// Constants //
// ========= //
///////////////
/// ASCII label `Finalize`. Not referenced in this file.
// NOTE(review): presumably appended to the hash input when finalizing an
// evaluation elsewhere in the crate — confirm against the caller.
pub(crate) const STR_FINALIZE: [u8; 8] = *b"Finalize";
/// Prefix of the seed domain-separation tag (`"Seed-" || contextString`) built
/// in `compute_composites`.
pub(crate) const STR_SEED: [u8; 5] = *b"Seed-";
/// Prefix of the domain-separation tag (`"DeriveKeyPair" || contextString`)
/// built in `derive_keypair`.
pub(crate) const STR_DERIVE_KEYPAIR: [u8; 13] = *b"DeriveKeyPair";
/// Trailing label of the per-element hash input in `compute_composites`.
pub(crate) const STR_COMPOSITE: [u8; 9] = *b"Composite";
/// Trailing label of the challenge hash input in `generate_proof` and
/// `verify_proof`.
pub(crate) const STR_CHALLENGE: [u8; 9] = *b"Challenge";
/// ASCII label `Info`. Not referenced in this file.
// NOTE(review): presumably used by the partially-oblivious mode elsewhere in
// the crate — confirm against the caller.
pub(crate) const STR_INFO: [u8; 4] = *b"Info";
/// Leading bytes of the context string produced by `create_context_string`.
pub(crate) const STR_VOPRF: [u8; 8] = *b"VOPRF09-";
/// Determines the mode of operation: base (non-verifiable) mode, verifiable
/// mode, or partially-oblivious mode. This is only used for custom
/// implementations for [`Group`].
///
/// The mode is encoded as a single byte (see [`Mode::to_u8`]) inside the
/// context string that domain-separates every hash in the protocol.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Mode {
    /// Non-verifiable mode.
    Oprf,
    /// Verifiable mode.
    Voprf,
    /// Partially-oblivious mode.
    Poprf,
}

impl Mode {
    /// Mode as it is represented in a context string.
    ///
    /// Returns `0` for [`Mode::Oprf`], `1` for [`Mode::Voprf`] and `2` for
    /// [`Mode::Poprf`].
    #[must_use]
    pub const fn to_u8(self) -> u8 {
        // Spelled out explicitly (rather than `self as u8`) so that reordering
        // the variants can never silently change the wire encoding.
        match self {
            Mode::Oprf => 0,
            Mode::Voprf => 1,
            Mode::Poprf => 2,
        }
    }
}
////////////////////////////
// High-level API Structs //
// ====================== //
////////////////////////////
/// The first client message sent from a client (either verifiable or not) to a
/// server (either verifiable or not).
///
/// This is a newtype around a single element of the cipher suite's group.
// `Clone` and `ZeroizeOnDrop` are derived without extra bounds; the remaining
// traits are bounded on the group element type rather than on `CS` itself.
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Elem)]
#[cfg_attr(
feature = "serde",
derive(serde::Deserialize, serde::Serialize),
serde(crate = "serde", bound = "")
)]
pub struct BlindedElement<CS: CipherSuite>(
// With the `serde` feature enabled, the element is (de)serialized through the
// `Element` helper imported from `crate::serialization::serde`.
#[cfg_attr(feature = "serde", serde(with = "Element::<CS::Group>"))]
pub(crate) <CS::Group as Group>::Elem,
)
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>;
/// The server's response to the [`BlindedElement`] message from a client
/// (either verifiable or not) to a server (either verifiable or not).
///
/// This is a newtype around a single element of the cipher suite's group.
// `Clone` and `ZeroizeOnDrop` are derived without extra bounds; the remaining
// traits are bounded on the group element type rather than on `CS` itself.
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Elem)]
#[cfg_attr(
feature = "serde",
derive(serde::Deserialize, serde::Serialize),
serde(crate = "serde", bound = "")
)]
pub struct EvaluationElement<CS: CipherSuite>(
// With the `serde` feature enabled, the element is (de)serialized through the
// `Element` helper imported from `crate::serialization::serde`.
#[cfg_attr(feature = "serde", serde(with = "Element::<CS::Group>"))]
pub(crate) <CS::Group as Group>::Elem,
)
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>;
/// Contains prepared [`EvaluationElement`]s by a server batch evaluate
/// preparation.
///
/// This is a newtype around one [`EvaluationElement`]; no custom serde helper
/// is attached to the field, so the inner type's own (de)serialization is used
/// when the `serde` feature is enabled.
// `Clone` and `ZeroizeOnDrop` are derived without extra bounds; the remaining
// traits are bounded on the group element type rather than on `CS` itself.
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Elem)]
#[cfg_attr(
feature = "serde",
derive(serde::Deserialize, serde::Serialize),
serde(crate = "serde", bound = "")
)]
pub struct PreparedEvaluationElement<CS: CipherSuite>(pub(crate) EvaluationElement<CS>)
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>;
/// A proof produced by a server that the OPRF output matches against a server
/// public key.
///
/// The proof consists of two scalars: the challenge `c` and the response `s`.
/// They are produced by `generate_proof` and checked by `verify_proof`.
// `Clone` and `ZeroizeOnDrop` are derived without extra bounds; the remaining
// traits are bounded on the group scalar type rather than on `CS` itself.
#[derive_where(Clone, ZeroizeOnDrop)]
#[derive_where(Debug, Eq, Hash, Ord, PartialEq, PartialOrd; <CS::Group as Group>::Scalar)]
#[cfg_attr(
feature = "serde",
derive(serde::Deserialize, serde::Serialize),
serde(crate = "serde", bound = "")
)]
pub struct Proof<CS: CipherSuite>
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
/// Challenge scalar `c`: the proof transcript hashed to a scalar.
#[cfg_attr(feature = "serde", serde(with = "Scalar::<CS::Group>"))]
pub(crate) c_scalar: <CS::Group as Group>::Scalar,
/// Response scalar `s`, computed by the prover as `r - c * k`.
#[cfg_attr(feature = "serde", serde(with = "Scalar::<CS::Group>"))]
pub(crate) s_scalar: <CS::Group as Group>::Scalar,
}
/////////////////////
// Proof Functions //
// =============== //
/////////////////////
/// Generates a [`Proof`] over the batch `(cs, ds)` using the secret scalar `k`.
///
/// The parameter names follow the specification's
/// `GenerateProof(k, A, B, Cs, Ds)`. The composites `(M, Z)` are computed with
/// knowledge of `k`, a random nonce `r` is drawn from `rng`, and the proof is
/// the pair `(c, s)` where `c` is the transcript hashed to a scalar and
/// `s = r - c * k`.
///
/// Can only fail with [`Error::Batch`].
#[allow(clippy::many_single_char_names)]
pub(crate) fn generate_proof<CS: CipherSuite, R: RngCore + CryptoRng>(
    rng: &mut R,
    k: <CS::Group as Group>::Scalar,
    a: <CS::Group as Group>::Elem,
    b: <CS::Group as Group>::Elem,
    cs: impl Iterator<Item = <CS::Group as Group>::Elem> + ExactSizeIterator,
    ds: impl Iterator<Item = <CS::Group as Group>::Elem> + ExactSizeIterator,
    mode: Mode,
) -> Result<Proof<CS>>
where
    <CS::Hash as OutputSizeUser>::OutputSize:
        IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
    // https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-08.html#section-3.3.2.2-1

    let (m, z) = compute_composites::<CS, _, _>(Some(k), b, cs, ds, mode)?;

    // r = GG.RandomScalar(); t2 = r * A; t3 = r * M
    let nonce = CS::Group::random_scalar(rng);
    let commit_a = a * &nonce;
    let commit_m = m * &nonce;

    // Serialized transcript elements, in the order they are hashed:
    // Bm, a0 = M, a1 = Z, a2 = t2, a3 = t3.
    let b_bytes = CS::Group::serialize_elem(b);
    let m_bytes = CS::Group::serialize_elem(m);
    let z_bytes = CS::Group::serialize_elem(z);
    let commit_a_bytes = CS::Group::serialize_elem(commit_a);
    let commit_m_bytes = CS::Group::serialize_elem(commit_m);

    // Every serialized element has the same length, so one 2-byte big-endian
    // length prefix is shared by all five of them.
    let elem_len = <CS::Group as Group>::ElemLen::U16.to_be_bytes();

    // h2Input = I2OSP(len(Bm), 2) || Bm ||
    //           I2OSP(len(a0), 2) || a0 ||
    //           I2OSP(len(a1), 2) || a1 ||
    //           I2OSP(len(a2), 2) || a2 ||
    //           I2OSP(len(a3), 2) || a3 ||
    //           "Challenge"
    let challenge_input = [
        &elem_len,
        b_bytes.as_slice(),
        &elem_len,
        &m_bytes,
        &elem_len,
        &z_bytes,
        &elem_len,
        &commit_a_bytes,
        &elem_len,
        &commit_m_bytes,
        &STR_CHALLENGE,
    ];

    let context = create_context_string::<CS>(mode);
    let dst = GenericArray::from(STR_HASH_TO_SCALAR).concat(context);

    // This can't fail, the size of the `input` is known.
    let challenge = CS::Group::hash_to_scalar::<CS>(&challenge_input, &dst).unwrap();
    // s = r - c * k
    let response = nonce - &(challenge * &k);

    Ok(Proof {
        c_scalar: challenge,
        s_scalar: response,
    })
}
/// Can only fail with [`Error::ProofVerification`] or [`Error::Batch`].
#[allow(clippy::many_single_char_names)]
pub(crate) fn verify_proof<CS: CipherSuite>(
a: <CS::Group as Group>::Elem,
b: <CS::Group as Group>::Elem,
cs: impl Iterator<Item = <CS::Group as Group>::Elem> + ExactSizeIterator,
ds: impl Iterator<Item = <CS::Group as Group>::Elem> + ExactSizeIterator,
proof: &Proof<CS>,
mode: Mode,
) -> Result<()>
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
// https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-08.html#section-3.3.4.1-2
let (m, z) = compute_composites::<CS, _, _>(None, b, cs, ds, mode)?;
let t2 = (a * &proof.s_scalar) + &(b * &proof.c_scalar);
let t3 = (m * &proof.s_scalar) + &(z * &proof.c_scalar);
// Bm = GG.SerializeElement(B)
let bm = CS::Group::serialize_elem(b);
// a0 = GG.SerializeElement(M)
let a0 = CS::Group::serialize_elem(m);
// a1 = GG.SerializeElement(Z)
let a1 = CS::Group::serialize_elem(z);
// a2 = GG.SerializeElement(t2)
let a2 = CS::Group::serialize_elem(t2);
// a3 = GG.SerializeElement(t3)
let a3 = CS::Group::serialize_elem(t3);
let elem_len = <CS::Group as Group>::ElemLen::U16.to_be_bytes();
// h2Input = I2OSP(len(Bm), 2) || Bm ||
// I2OSP(len(a0), 2) || a0 ||
// I2OSP(len(a1), 2) || a1 ||
// I2OSP(len(a2), 2) || a2 ||
// I2OSP(len(a3), 2) || a3 ||
// "Challenge"
let h2_input = [
&elem_len,
bm.as_slice(),
&elem_len,
&a0,
&elem_len,
&a1,
&elem_len,
&a2,
&elem_len,
&a3,
&STR_CHALLENGE,
];
let dst = GenericArray::from(STR_HASH_TO_SCALAR).concat(create_context_string::<CS>(mode));
// This can't fail, the size of the `input` is known.
let c = CS::Group::hash_to_scalar::<CS>(&h2_input, &dst).unwrap();
match c.ct_eq(&proof.c_scalar).into() {
true => Ok(()),
false => Err(Error::ProofVerification),
}
}
/// The `(M, Z)` pair of group elements returned by `compute_composites`.
type ComputeCompositesResult<CS> = (
<<CS as CipherSuite>::Group as Group>::Elem,
<<CS as CipherSuite>::Group as Group>::Elem,
);
/// Folds the batch `(c_slice, d_slice)` into the composite pair `(M, Z)` used
/// by proof generation and verification.
///
/// A seed is derived by hashing the serialized `b` together with the
/// `"Seed-" || contextString` tag. For each index `i`, a scalar weight is
/// derived from the seed, `i`, and the serialized `Cs[i]` / `Ds[i]`, and `M`
/// accumulates `weight * Cs[i]`.
///
/// * With `k_option == Some(k)` (prover side), `Z` is computed directly as
///   `k * M`; the `Ds[i]` still feed into the weights but are not summed.
/// * With `k_option == None` (verifier side), `Z` accumulates
///   `weight * Ds[i]`.
///
/// Can only fail with [`Error::Batch`]: when the two iterators report
/// different lengths, or when the batch size does not fit in a `u16`.
fn compute_composites<
    CS: CipherSuite,
    IC: Iterator<Item = <CS::Group as Group>::Elem> + ExactSizeIterator,
    ID: Iterator<Item = <CS::Group as Group>::Elem> + ExactSizeIterator,
>(
    k_option: Option<<CS::Group as Group>::Scalar>,
    b: <CS::Group as Group>::Elem,
    c_slice: IC,
    d_slice: ID,
    mode: Mode,
) -> Result<ComputeCompositesResult<CS>>
where
    <CS::Hash as OutputSizeUser>::OutputSize:
        IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
    // https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-08.html#section-3.3.2.3-2

    let elem_len = <CS::Group as Group>::ElemLen::U16.to_be_bytes();

    if c_slice.len() != d_slice.len() {
        return Err(Error::Batch);
    }

    // The loop index is encoded as I2OSP(i, 2), so the batch must fit in u16.
    let len = u16::try_from(c_slice.len()).map_err(|_| Error::Batch)?;

    // seedDST = "Seed-" || contextString
    let seed_dst = GenericArray::from(STR_SEED).concat(create_context_string::<CS>(mode));

    // h1Input = I2OSP(len(Bm), 2) || Bm ||
    //           I2OSP(len(seedDST), 2) || seedDST
    // seed = Hash(h1Input)
    let seed = CS::Hash::new()
        .chain_update(&elem_len)
        .chain_update(CS::Group::serialize_elem(b))
        .chain_update(i2osp_2_array(&seed_dst))
        .chain_update(seed_dst)
        .finalize();
    let seed_len = i2osp_2_array(&seed);

    // The hash-to-scalar DST depends only on `mode`, so build it once instead
    // of on every loop iteration.
    let dst = GenericArray::from(STR_HASH_TO_SCALAR).concat(create_context_string::<CS>(mode));

    let mut m = CS::Group::identity_elem();
    let mut z = CS::Group::identity_elem();

    for (i, (c, d)) in (0..len).zip(c_slice.zip(d_slice)) {
        // Ci = GG.SerializeElement(Cs[i])
        let ci_bytes = CS::Group::serialize_elem(c);
        // Di = GG.SerializeElement(Ds[i])
        let di_bytes = CS::Group::serialize_elem(d);

        // h2Input = I2OSP(len(seed), 2) || seed || I2OSP(i, 2) ||
        //           I2OSP(len(Ci), 2) || Ci ||
        //           I2OSP(len(Di), 2) || Di ||
        //           "Composite"
        let h2_input = [
            seed_len.as_slice(),
            &seed,
            &i.to_be_bytes(),
            &elem_len,
            &ci_bytes,
            &elem_len,
            &di_bytes,
            &STR_COMPOSITE,
        ];

        // This can't fail, the size of the `input` is known.
        let weight = CS::Group::hash_to_scalar::<CS>(&h2_input, &dst).unwrap();

        m = c * &weight + &m;
        // Only the verifier path (no secret key) has to accumulate Z from the
        // individual `Ds[i]`.
        if k_option.is_none() {
            z = d * &weight + &z;
        }
    }

    // Prover path: with the secret key at hand, Z is simply k * M.
    if let Some(k) = k_option {
        z = m * &k;
    }

    Ok((m, z))
}
/////////////////////
// Inner Functions //
// =============== //
/////////////////////
/// The `(secret scalar, public element)` pair returned by `derive_keypair`.
type DeriveKeypairResult<CS> = (
<<CS as CipherSuite>::Group as Group>::Scalar,
<<CS as CipherSuite>::Group as Group>::Elem,
);
/// Can only fail with [`Error::DeriveKeyPair`] and [`Error::Protocol`].
pub(crate) fn derive_keypair<CS: CipherSuite>(
seed: &[u8],
info: &[u8],
mode: Mode,
) -> Result<DeriveKeypairResult<CS>, Error>
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
let context_string = create_context_string::<CS>(mode);
let dst = GenericArray::from(STR_DERIVE_KEYPAIR).concat(context_string);
let info_len = i2osp_2(info.len()).map_err(|_| Error::DeriveKeyPair)?;
for counter in 0_u8..=u8::MAX {
// deriveInput = seed || I2OSP(len(info), 2) || info
// skS = G.HashToScalar(deriveInput || I2OSP(counter, 1), DST = "DeriveKeyPair"
// || contextString)
let sk_s = <CS::Group as Group>::hash_to_scalar::<CS>(
&[seed, &info_len, info, &counter.to_be_bytes()],
&dst,
)
.map_err(|_| Error::DeriveKeyPair)?;
if !bool::from(CS::Group::is_zero_scalar(sk_s)) {
let pk_s = CS::Group::base_elem() * &sk_s;
return Ok((sk_s, pk_s));
}
}
Err(Error::Protocol)
}
/// Inner function for blind that assumes that the blinding factor has already
/// been chosen, and therefore takes it as input. Does not check if the blinding
/// factor is non-zero.
///
/// Can only fail with [`Error::Input`].
pub(crate) fn deterministic_blind_unchecked<CS: CipherSuite>(
input: &[u8],
blind: &<CS::Group as Group>::Scalar,
mode: Mode,
) -> Result<<CS::Group as Group>::Elem>
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
let dst = GenericArray::from(STR_HASH_TO_GROUP).concat(create_context_string::<CS>(mode));
let hashed_point = CS::Group::hash_to_curve::<CS>(&[input], &dst).map_err(|_| Error::Input)?;
Ok(hashed_point * blind)
}
/// Generates the contextString parameter as defined in
/// <https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-08.html>
///
/// The result is the 11-byte concatenation
/// `STR_VOPRF || I2OSP(mode, 1) || I2OSP(CS::ID, 2)`.
// NOTE(review): the link above cites draft-08 while `STR_VOPRF` is
// `"VOPRF09-"` — confirm which draft revision the link should point at.
pub(crate) fn create_context_string<CS: CipherSuite>(mode: Mode) -> GenericArray<u8, U11>
where
    <CS::Hash as OutputSizeUser>::OutputSize:
        IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
    let mode_byte = [mode.to_u8()];
    let suite_id = CS::ID.to_be_bytes();

    GenericArray::from(STR_VOPRF)
        .concat(mode_byte.into())
        .concat(suite_id.into())
}
///////////////////////
// Utility Functions //
// ================= //
///////////////////////
/// Encodes `input` as a 2-byte big-endian integer (`I2OSP(input, 2)`).
///
/// Returns [`InternalError::I2osp`] when `input` does not fit in a `u16`.
pub(crate) fn i2osp_2(input: usize) -> Result<[u8; 2], InternalError> {
    match u16::try_from(input) {
        Ok(narrowed) => Ok(narrowed.to_be_bytes()),
        Err(_) => Err(InternalError::I2osp),
    }
}
pub(crate) fn i2osp_2_array<L: ArrayLength<u8> + IsLess<U256>>(
_: &GenericArray<u8, L>,
) -> GenericArray<u8, U2> {
L::U16.to_be_bytes().into()
}
// Property tests checking that deserializing arbitrary bytes never panics.
// Each test feeds a random byte vector into a type's `deserialize` and
// discards the result, so the only way to fail is a crash.
#[cfg(test)]
mod unit_tests {
use proptest::collection::vec;
use proptest::prelude::*;
use crate::{
BlindedElement, EvaluationElement, OprfClient, OprfServer, PoprfClient, PoprfServer, Proof,
VoprfClient, VoprfServer,
};
// Calls `$item::<Suite>::deserialize` on `$bytes` and ignores the outcome.
// The P-256 suite is always exercised; Ristretto255 only when the
// `ristretto255` feature is enabled.
macro_rules! test_deserialize {
($item:ident, $bytes:ident) => {
#[cfg(feature = "ristretto255")]
{
let _ = $item::<crate::Ristretto255>::deserialize(&$bytes[..]);
}
let _ = $item::<p256::NistP256>::deserialize(&$bytes[..]);
};
}
proptest! {
// Every case below draws a `Vec<u8>` from the strategy
// `vec(any::<u8>(), 0..200)`.
#[test]
fn test_nocrash_oprf_client(bytes in vec(any::<u8>(), 0..200)) {
test_deserialize!(OprfClient, bytes);
}
#[test]
fn test_nocrash_voprf_client(bytes in vec(any::<u8>(), 0..200)) {
test_deserialize!(VoprfClient, bytes);
}
#[test]
fn test_nocrash_poprf_client(bytes in vec(any::<u8>(), 0..200)) {
test_deserialize!(PoprfClient, bytes);
}
#[test]
fn test_nocrash_oprf_server(bytes in vec(any::<u8>(), 0..200)) {
test_deserialize!(OprfServer, bytes);
}
#[test]
fn test_nocrash_voprf_server(bytes in vec(any::<u8>(), 0..200)) {
test_deserialize!(VoprfServer, bytes);
}
#[test]
fn test_nocrash_poprf_server(bytes in vec(any::<u8>(), 0..200)) {
test_deserialize!(PoprfServer, bytes);
}
#[test]
fn test_nocrash_blinded_element(bytes in vec(any::<u8>(), 0..200)) {
test_deserialize!(BlindedElement, bytes);
}
#[test]
fn test_nocrash_evaluation_element(bytes in vec(any::<u8>(), 0..200)) {
test_deserialize!(EvaluationElement, bytes);
}
#[test]
fn test_nocrash_proof(bytes in vec(any::<u8>(), 0..200)) {
test_deserialize!(Proof, bytes);
}
}
}