src/key_exchange/group/ristretto255.rs (49 lines of code) (raw):
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under both the MIT license found in the
// LICENSE-MIT file in the root directory of this source tree and the Apache
// License, Version 2.0 found in the LICENSE-APACHE file in the root directory
// of this source tree.
//! Key Exchange group implementation for ristretto255
use curve25519_dalek::constants::RISTRETTO_BASEPOINT_POINT;
use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
use curve25519_dalek::scalar::Scalar;
use digest::core_api::BlockSizeUser;
use digest::{Digest, OutputSizeUser};
use elliptic_curve::hash2curve::{ExpandMsg, ExpandMsgXmd, Expander};
use generic_array::typenum::{IsLess, IsLessOrEqual, U256, U32, U64};
use generic_array::GenericArray;
use rand::{CryptoRng, RngCore};
use voprf::Group;
use zeroize::Zeroize;
use super::KeGroup;
use crate::errors::InternalError;
/// Implementation for Ristretto255.
// This is necessary because Rust lacks specialization, otherwise we could
// implement `KeGroup` for `voprf::Ristretto255`.
pub struct Ristretto255;
impl KeGroup for Ristretto255 {
type Pk = RistrettoPoint;
type PkLen = U32;
type Sk = Scalar;
type SkLen = U32;
fn serialize_pk(pk: &Self::Pk) -> GenericArray<u8, Self::PkLen> {
pk.compress().to_bytes().into()
}
fn deserialize_pk(bytes: &GenericArray<u8, Self::PkLen>) -> Result<Self::Pk, InternalError> {
CompressedRistretto::from_slice(bytes)
.decompress()
.ok_or(InternalError::PointError)
}
fn random_sk<R: RngCore + CryptoRng>(rng: &mut R) -> Self::Sk {
loop {
let scalar = {
#[cfg(not(test))]
{
let mut scalar_bytes = [0u8; 64];
rng.fill_bytes(&mut scalar_bytes);
Scalar::from_bytes_mod_order_wide(&scalar_bytes)
}
// Tests need an exact conversion from bytes to scalar, sampling only 32 bytes
// from rng
#[cfg(test)]
{
let mut scalar_bytes = [0u8; 32];
rng.fill_bytes(&mut scalar_bytes);
Scalar::from_bytes_mod_order(scalar_bytes)
}
};
if scalar != Scalar::zero() && scalar.is_canonical() {
break scalar;
}
}
}
// Implements the `HashToScalar()` function from
// https://www.ietf.org/archive/id/draft-irtf-cfrg-voprf-08.html#section-4.1
fn hash_to_scalar<'a, H>(input: &[&[u8]], dst: &[u8]) -> Result<Self::Sk, InternalError>
where
H: Digest + BlockSizeUser,
H::OutputSize: IsLess<U256> + IsLessOrEqual<H::BlockSize>,
{
let mut uniform_bytes = GenericArray::<_, U64>::default();
ExpandMsgXmd::<H>::expand_message(input, dst, 64)
.map_err(|_| InternalError::HashToScalar)?
.fill_bytes(&mut uniform_bytes);
Ok(Scalar::from_bytes_mod_order_wide(&uniform_bytes.into()))
}
fn public_key(sk: &Self::Sk) -> Self::Pk {
RISTRETTO_BASEPOINT_POINT * sk
}
fn diffie_hellman(pk: &Self::Pk, sk: &Self::Sk) -> GenericArray<u8, Self::PkLen> {
Self::serialize_pk(&(pk * sk))
}
fn zeroize_sk_on_drop(sk: &mut Self::Sk) {
sk.zeroize()
}
fn serialize_sk(sk: &Self::Sk) -> GenericArray<u8, Self::SkLen> {
sk.to_bytes().into()
}
fn deserialize_sk(bytes: &GenericArray<u8, Self::PkLen>) -> Result<Self::Sk, InternalError> {
Scalar::from_canonical_bytes((*bytes).into()).ok_or(InternalError::PointError)
}
}
#[cfg(feature = "ristretto255_voprf")]
impl voprf::CipherSuite for Ristretto255 {
const ID: u16 = voprf::Ristretto255::ID;
type Group = <voprf::Ristretto255 as voprf::CipherSuite>::Group;
type Hash = <voprf::Ristretto255 as voprf::CipherSuite>::Hash;
}
impl Group for Ristretto255 {
type Elem = <voprf::Ristretto255 as Group>::Elem;
type ElemLen = <voprf::Ristretto255 as Group>::ElemLen;
type Scalar = <voprf::Ristretto255 as Group>::Scalar;
type ScalarLen = <voprf::Ristretto255 as Group>::ScalarLen;
fn hash_to_curve<CS: voprf::CipherSuite>(
input: &[&[u8]],
dst: &[u8],
) -> voprf::Result<Self::Elem, voprf::InternalError>
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
<voprf::Ristretto255 as Group>::hash_to_curve::<CS>(input, dst)
}
fn hash_to_scalar<CS: voprf::CipherSuite>(
input: &[&[u8]],
dst: &[u8],
) -> voprf::Result<Self::Scalar, voprf::InternalError>
where
<CS::Hash as OutputSizeUser>::OutputSize:
IsLess<U256> + IsLessOrEqual<<CS::Hash as BlockSizeUser>::BlockSize>,
{
<voprf::Ristretto255 as Group>::hash_to_scalar::<CS>(input, dst)
}
fn base_elem() -> Self::Elem {
<voprf::Ristretto255 as Group>::base_elem()
}
fn identity_elem() -> Self::Elem {
<voprf::Ristretto255 as Group>::identity_elem()
}
fn serialize_elem(elem: Self::Elem) -> GenericArray<u8, Self::ElemLen> {
<voprf::Ristretto255 as Group>::serialize_elem(elem)
}
fn deserialize_elem(element_bits: &[u8]) -> voprf::Result<Self::Elem> {
<voprf::Ristretto255 as Group>::deserialize_elem(element_bits)
}
fn random_scalar<R: RngCore + CryptoRng>(rng: &mut R) -> Self::Scalar {
<voprf::Ristretto255 as Group>::random_scalar(rng)
}
fn invert_scalar(scalar: Self::Scalar) -> Self::Scalar {
<voprf::Ristretto255 as Group>::invert_scalar(scalar)
}
fn is_zero_scalar(scalar: Self::Scalar) -> subtle::Choice {
<voprf::Ristretto255 as Group>::is_zero_scalar(scalar)
}
fn serialize_scalar(scalar: Self::Scalar) -> GenericArray<u8, Self::ScalarLen> {
<voprf::Ristretto255 as Group>::serialize_scalar(scalar)
}
fn deserialize_scalar(scalar_bits: &[u8]) -> voprf::Result<Self::Scalar> {
<voprf::Ristretto255 as Group>::deserialize_scalar(scalar_bits)
}
}