math/src/field/extensions/cubic.rs (352 lines of code) (raw):
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
use super::{ExtensibleField, FieldElement};
use core::{
convert::TryFrom,
fmt,
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
slice,
};
use utils::{
collections::Vec, string::ToString, AsBytes, ByteReader, ByteWriter, Deserializable,
DeserializationError, Randomizable, Serializable, SliceReader,
};
// CUBIC EXTENSION FIELD
// ================================================================================================
/// Represents an element in a cubic extension of a [StarkField](crate::StarkField).
///
/// The extension element is defined as α + β * φ + γ * φ^2, where φ is a root of an irreducible
/// polynomial defined by the implementation of the [ExtensibleField] trait, and α, β, γ are base
/// field elements.
///
/// The first tuple field holds the base-field component α: both the multiplicative identity and
/// the conversion from a base element populate only that field and leave the other two at zero.
///
/// The `#[repr(C)]` attribute fixes the field order in memory; this module relies on it when it
/// reinterprets slices and vectors of extension elements as base elements or raw bytes.
#[repr(C)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Default)]
pub struct CubeExtension<B: ExtensibleField<3>>(B, B, B);
impl<B: ExtensibleField<3>> CubeExtension<B> {
    /// Returns a new extension element instantiated from the provided base elements.
    ///
    /// The arguments become the three components of the element, in the order given.
    pub fn new(a: B, b: B, c: B) -> Self {
        Self(a, b, c)
    }

    /// Returns true if the base field specified by B type parameter supports cubic extensions.
    pub fn is_supported() -> bool {
        <B as ExtensibleField<3>>::is_supported()
    }

    /// Converts a vector of base elements into a vector of elements in a cubic extension field
    /// by fusing three adjacent base elements together. The output vector is one third the
    /// length of the source vector.
    ///
    /// When the capacity of `source` is a multiple of three, the allocation is reused without
    /// copying. Otherwise the elements are copied into a freshly allocated vector, because the
    /// original allocation could not be described exactly by a `Vec<Self>`.
    ///
    /// In debug builds this panics if the length of `source` is not divisible by three; in
    /// release builds any trailing elements that do not form a full triple are dropped.
    fn base_to_cubic_vector(source: Vec<B>) -> Vec<Self> {
        debug_assert!(
            source.len() % 3 == 0,
            "source vector length must be divisible by three, but was {}",
            source.len()
        );

        // `Vec::from_raw_parts` requires the capacity to describe the allocation exactly; a
        // capacity that is not a whole number of triples cannot be expressed in units of `Self`,
        // so fall back to copying in that case.
        if source.capacity() % 3 != 0 {
            return source
                .chunks_exact(3)
                .map(|triple| Self(triple[0], triple[1], triple[2]))
                .collect();
        }

        // prevent the source vector from freeing the buffer we are about to hand over
        let mut v = core::mem::ManuallyDrop::new(source);
        let p = v.as_mut_ptr();
        let len = v.len() / 3;
        let cap = v.capacity() / 3;
        // SAFETY: `Self` is `#[repr(C)]` and consists of exactly three `B` fields, so it has the
        // alignment of `B` and three times its size. The capacity was checked above to be a
        // multiple of three, so `cap` elements of `Self` cover the allocation exactly, and the
        // first `len` of them are backed by initialized base elements.
        unsafe { Vec::from_raw_parts(p as *mut Self, len, cap) }
    }
}
impl<B: ExtensibleField<3>> FieldElement for CubeExtension<B> {
    type PositiveInteger = B::PositiveInteger;
    type BaseField = B;

    // an extension element occupies the space of three base elements
    const ELEMENT_BYTES: usize = B::ELEMENT_BYTES * 3;
    const IS_CANONICAL: bool = B::IS_CANONICAL;
    const ZERO: Self = Self(B::ZERO, B::ZERO, B::ZERO);
    const ONE: Self = Self(B::ONE, B::ZERO, B::ZERO);

    /// Returns the multiplicative inverse of this element; zero is returned unchanged.
    ///
    /// The inverse is computed as `frob(x) * frob(frob(x))` divided by the norm
    /// `x * frob(x) * frob(frob(x))`, where `frob` is the Frobenius map supplied by the base
    /// field. Only the first coefficient of the norm is inverted; debug builds assert that the
    /// other two coefficients are zero.
    #[inline]
    fn inv(self) -> Self {
        if self == Self::ZERO {
            return self;
        }

        let element = [self.0, self.1, self.2];
        let first_conjugate = <B as ExtensibleField<3>>::frobenius(element);
        let second_conjugate = <B as ExtensibleField<3>>::frobenius(first_conjugate);
        let conjugate_product = <B as ExtensibleField<3>>::mul(first_conjugate, second_conjugate);

        let norm = <B as ExtensibleField<3>>::mul(element, conjugate_product);
        debug_assert_eq!(norm[1], B::ZERO, "norm must be in the base field");
        debug_assert_eq!(norm[2], B::ZERO, "norm must be in the base field");

        let scale = norm[0].inv();
        Self(
            conjugate_product[0] * scale,
            conjugate_product[1] * scale,
            conjugate_product[2] * scale,
        )
    }

    /// Returns the image of this element under the Frobenius map of the base field.
    #[inline]
    fn conjugate(&self) -> Self {
        let Self(a, b, c) = *self;
        let [x, y, z] = <B as ExtensibleField<3>>::frobenius([a, b, c]);
        Self(x, y, z)
    }

    /// Reinterprets a slice of extension elements as a slice of bytes without copying.
    ///
    /// The returned slice is `elements.len() * ELEMENT_BYTES` bytes long.
    fn elements_as_bytes(elements: &[Self]) -> &[u8] {
        let num_bytes = elements.len() * Self::ELEMENT_BYTES;
        let start = elements.as_ptr() as *const u8;
        unsafe { slice::from_raw_parts(start, num_bytes) }
    }

    /// Reinterprets a slice of bytes as a slice of extension elements without copying.
    ///
    /// # Errors
    /// Returns an error if the number of bytes is not a multiple of `ELEMENT_BYTES`, or if the
    /// start address of the slice is not a multiple of the base field's `ELEMENT_BYTES`.
    ///
    /// # Safety
    /// No check is made that the bytes encode valid field elements.
    /// NOTE(review): the exact contract callers must uphold is defined by the trait declaration,
    /// which is outside this file - confirm there.
    unsafe fn bytes_as_elements(bytes: &[u8]) -> Result<&[Self], DeserializationError> {
        let num_bytes = bytes.len();
        if num_bytes % Self::ELEMENT_BYTES != 0 {
            return Err(DeserializationError::InvalidValue(format!(
                "number of bytes ({}) does not divide into whole number of field elements",
                bytes.len(),
            )));
        }

        // make sure the bytes are aligned on the boundary consistent with base element alignment
        let start = bytes.as_ptr();
        let misaligned = (start as usize) % Self::BaseField::ELEMENT_BYTES != 0;
        if misaligned {
            return Err(DeserializationError::InvalidValue(
                "slice memory alignment is not valid for this field element type".to_string(),
            ));
        }

        let num_elements = num_bytes / Self::ELEMENT_BYTES;
        Ok(slice::from_raw_parts(start as *const Self, num_elements))
    }

    /// Returns a vector of `n` extension elements built from `3 * n` base elements obtained from
    /// the base field's `zeroed_vector`.
    fn zeroed_vector(n: usize) -> Vec<Self> {
        // request three base elements per extension element, then fuse them into triples
        Self::base_to_cubic_vector(B::zeroed_vector(n * 3))
    }

    /// Reinterprets a slice of extension elements as a slice of base elements without copying.
    ///
    /// The returned slice is three times as long as the input.
    fn as_base_elements(elements: &[Self]) -> &[Self::BaseField] {
        let num_base_elements = elements.len() * 3;
        let start = elements.as_ptr() as *const Self::BaseField;
        unsafe { slice::from_raw_parts(start, num_base_elements) }
    }
}
impl<B: ExtensibleField<3>> Randomizable for CubeExtension<B> {
    /// Number of bytes needed to build one element; same as `ELEMENT_BYTES`.
    const VALUE_SIZE: usize = Self::ELEMENT_BYTES;

    /// Attempts to build an element from the provided bytes.
    ///
    /// Returns `None` whenever the byte-slice conversion for this type fails.
    fn from_random_bytes(bytes: &[u8]) -> Option<Self> {
        match Self::try_from(bytes) {
            Ok(element) => Some(element),
            Err(_) => None,
        }
    }
}
impl<B: ExtensibleField<3>> fmt::Display for CubeExtension<B> {
    /// Formats the element as a parenthesized, comma-separated triple of its components.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self(first, second, third) = self;
        write!(f, "({}, {}, {})", first, second, third)
    }
}
// OVERLOADED OPERATORS
// ------------------------------------------------------------------------------------------------
impl<B: ExtensibleField<3>> Add for CubeExtension<B> {
    type Output = Self;

    /// Adds two extension elements component by component.
    #[inline]
    fn add(self, rhs: Self) -> Self {
        let Self(a0, a1, a2) = self;
        let Self(b0, b1, b2) = rhs;
        Self(a0 + b0, a1 + b1, a2 + b2)
    }
}
impl<B: ExtensibleField<3>> AddAssign for CubeExtension<B> {
    /// Replaces `self` with the sum of `self` and `rhs`.
    #[inline]
    fn add_assign(&mut self, rhs: Self) {
        let sum = *self + rhs;
        *self = sum;
    }
}
impl<B: ExtensibleField<3>> Sub for CubeExtension<B> {
    type Output = Self;

    /// Subtracts `rhs` from `self` component by component.
    #[inline]
    fn sub(self, rhs: Self) -> Self {
        let Self(a0, a1, a2) = self;
        let Self(b0, b1, b2) = rhs;
        Self(a0 - b0, a1 - b1, a2 - b2)
    }
}
impl<B: ExtensibleField<3>> SubAssign for CubeExtension<B> {
    /// Replaces `self` with the difference of `self` and `rhs`.
    #[inline]
    fn sub_assign(&mut self, rhs: Self) {
        let difference = *self - rhs;
        *self = difference;
    }
}
impl<B: ExtensibleField<3>> Mul for CubeExtension<B> {
    type Output = Self;

    /// Multiplies two extension elements using the multiplication routine supplied by the base
    /// field's [ExtensibleField] implementation.
    #[inline]
    fn mul(self, rhs: Self) -> Self {
        let lhs_coefficients = [self.0, self.1, self.2];
        let rhs_coefficients = [rhs.0, rhs.1, rhs.2];
        let [c0, c1, c2] = <B as ExtensibleField<3>>::mul(lhs_coefficients, rhs_coefficients);
        Self(c0, c1, c2)
    }
}
impl<B: ExtensibleField<3>> MulAssign for CubeExtension<B> {
    /// Replaces `self` with the product of `self` and `rhs`.
    #[inline]
    fn mul_assign(&mut self, rhs: Self) {
        let product = *self * rhs;
        *self = product;
    }
}
impl<B: ExtensibleField<3>> Div for CubeExtension<B> {
    type Output = Self;

    /// Divides `self` by `rhs` by multiplying `self` with the inverse of `rhs`.
    #[inline]
    // division is implemented as multiplication by the inverse, which clippy flags as suspicious
    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: Self) -> Self {
        let rhs_inverse = rhs.inv();
        self * rhs_inverse
    }
}
impl<B: ExtensibleField<3>> DivAssign for CubeExtension<B> {
    /// Replaces `self` with the quotient of `self` and `rhs`.
    #[inline]
    fn div_assign(&mut self, rhs: Self) {
        let quotient = *self / rhs;
        *self = quotient;
    }
}
impl<B: ExtensibleField<3>> Neg for CubeExtension<B> {
    type Output = Self;

    /// Negates every component of the element.
    #[inline]
    fn neg(self) -> Self {
        let Self(a, b, c) = self;
        Self(-a, -b, -c)
    }
}
// TYPE CONVERSIONS
// ------------------------------------------------------------------------------------------------
impl<B: ExtensibleField<3>> From<B> for CubeExtension<B> {
    /// Embeds a base field element into the extension: the value becomes the first component
    /// and the remaining two components are zero.
    fn from(value: B) -> Self {
        let zero = B::ZERO;
        Self(value, zero, zero)
    }
}
impl<B: ExtensibleField<3>> From<u128> for CubeExtension<B> {
    /// Converts the integer into a base element and places it in the first component; the
    /// remaining two components are zero.
    fn from(value: u128) -> Self {
        let base = B::from(value);
        Self(base, B::ZERO, B::ZERO)
    }
}
impl<B: ExtensibleField<3>> From<u64> for CubeExtension<B> {
    /// Converts the integer into a base element and places it in the first component; the
    /// remaining two components are zero.
    fn from(value: u64) -> Self {
        let base = B::from(value);
        Self(base, B::ZERO, B::ZERO)
    }
}
impl<B: ExtensibleField<3>> From<u32> for CubeExtension<B> {
    /// Converts the integer into a base element and places it in the first component; the
    /// remaining two components are zero.
    fn from(value: u32) -> Self {
        let base = B::from(value);
        Self(base, B::ZERO, B::ZERO)
    }
}
impl<B: ExtensibleField<3>> From<u16> for CubeExtension<B> {
    /// Converts the integer into a base element and places it in the first component; the
    /// remaining two components are zero.
    fn from(value: u16) -> Self {
        let base = B::from(value);
        Self(base, B::ZERO, B::ZERO)
    }
}
impl<B: ExtensibleField<3>> From<u8> for CubeExtension<B> {
    /// Converts the integer into a base element and places it in the first component; the
    /// remaining two components are zero.
    fn from(value: u8) -> Self {
        let base = B::from(value);
        Self(base, B::ZERO, B::ZERO)
    }
}
impl<'a, B: ExtensibleField<3>> TryFrom<&'a [u8]> for CubeExtension<B> {
    type Error = DeserializationError;

    /// Converts a slice of bytes into a field element; returns error if the value encoded in bytes
    /// is not a valid field element. The bytes are assumed to be in little-endian byte order.
    ///
    /// The slice must contain exactly `ELEMENT_BYTES` bytes; shorter and longer inputs are
    /// rejected with distinct error messages before any deserialization is attempted.
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        let expected = Self::ELEMENT_BYTES;
        let actual = bytes.len();

        if actual < expected {
            return Err(DeserializationError::InvalidValue(format!(
                "not enough bytes for a full field element; expected {} bytes, but was {} bytes",
                expected, actual,
            )));
        }
        if actual > expected {
            return Err(DeserializationError::InvalidValue(format!(
                "too many bytes for a field element; expected {} bytes, but was {} bytes",
                expected, actual,
            )));
        }

        // the length is exact, so the three components can be read back to back
        Self::read_from(&mut SliceReader::new(bytes))
    }
}
impl<B: ExtensibleField<3>> AsBytes for CubeExtension<B> {
    /// Returns a view of the element's in-memory representation as `ELEMENT_BYTES` raw bytes.
    fn as_bytes(&self) -> &[u8] {
        // TODO: take endianness into account
        let start = self as *const Self as *const u8;
        unsafe { slice::from_raw_parts(start, Self::ELEMENT_BYTES) }
    }
}
// SERIALIZATION / DESERIALIZATION
// ------------------------------------------------------------------------------------------------
impl<B: ExtensibleField<3>> Serializable for CubeExtension<B> {
    /// Writes the three components into `target`, one after another, in field order.
    fn write_into<W: ByteWriter>(&self, target: &mut W) {
        let Self(first, second, third) = self;
        first.write_into(target);
        second.write_into(target);
        third.write_into(target);
    }
}
impl<B: ExtensibleField<3>> Deserializable for CubeExtension<B> {
    /// Reads three base elements from `source`, in order, and assembles them into an extension
    /// element. The first read that fails aborts the operation and its error is returned.
    fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
        // arguments are evaluated left to right, so the components are read in field order
        Ok(Self(
            B::read_from(source)?,
            B::read_from(source)?,
            B::read_from(source)?,
        ))
    }
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests {
    use super::{CubeExtension, DeserializationError, FieldElement, Vec};
    use crate::field::f64::BaseElement;
    use rand_utils::rand_value;

    /// Extension type exercised by every test in this module.
    type Ext = CubeExtension<BaseElement>;

    /// Builds the two-element fixture (1, 2, 3), (4, 5, 6) shared by the layout tests.
    fn sample_elements() -> Vec<Ext> {
        vec![
            Ext::new(
                BaseElement::new(1),
                BaseElement::new(2),
                BaseElement::new(3),
            ),
            Ext::new(
                BaseElement::new(4),
                BaseElement::new(5),
                BaseElement::new(6),
            ),
        ]
    }

    // BASIC ALGEBRA
    // --------------------------------------------------------------------------------------------

    /// Addition has zero as its identity and acts component-wise.
    #[test]
    fn add() {
        // identity
        let r: Ext = rand_value();
        assert_eq!(r, r + Ext::ZERO);

        // test random values
        let r1: Ext = rand_value();
        let r2: Ext = rand_value();
        let expected = Ext::new(r1.0 + r2.0, r1.1 + r2.1, r1.2 + r2.2);
        assert_eq!(expected, r1 + r2);
    }

    /// Subtraction has zero as its right identity and acts component-wise.
    #[test]
    fn sub() {
        // identity
        let r: Ext = rand_value();
        assert_eq!(r, r - Ext::ZERO);

        // test random values
        let r1: Ext = rand_value();
        let r2: Ext = rand_value();
        let expected = Ext::new(r1.0 - r2.0, r1.1 - r2.1, r1.2 - r2.2);
        assert_eq!(expected, r1 - r2);
    }

    // INITIALIZATION
    // --------------------------------------------------------------------------------------------

    /// A zeroed vector has the requested length and contains only zero elements.
    #[test]
    fn zeroed_vector() {
        let result = Ext::zeroed_vector(4);
        assert_eq!(4, result.len());
        for element in result {
            assert_eq!(Ext::ZERO, element);
        }
    }

    // SERIALIZATION / DESERIALIZATION
    // --------------------------------------------------------------------------------------------

    /// Elements are exposed as the concatenation of their components' 8-byte encodings.
    #[test]
    fn elements_as_bytes() {
        let source = sample_elements();

        let expected: Vec<u8> = vec![
            1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0,
            0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0,
        ];

        assert_eq!(expected, Ext::elements_as_bytes(&source));
    }

    /// Well-formed byte slices are reinterpreted as elements; slices with a bad length or a bad
    /// alignment are rejected.
    ///
    /// NOTE(review): the success case relies on the `Vec<u8>` buffer happening to start at an
    /// address suitably aligned for `BaseElement` - confirm this holds for the allocators in use.
    #[test]
    fn bytes_as_elements() {
        // 48 bytes encoding two elements, followed by one stray byte
        let bytes: Vec<u8> = vec![
            1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0,
            0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 7,
        ];
        let expected = sample_elements();

        // exactly two elements worth of bytes
        let result = unsafe { Ext::bytes_as_elements(&bytes[..48]) };
        assert!(result.is_ok());
        assert_eq!(expected, result.unwrap());

        // length is not a multiple of the element size
        let result = unsafe { Ext::bytes_as_elements(&bytes) };
        assert!(matches!(result, Err(DeserializationError::InvalidValue(_))));

        // length is fine, but the slice starts one byte past the start of the buffer
        let result = unsafe { Ext::bytes_as_elements(&bytes[1..]) };
        assert!(matches!(result, Err(DeserializationError::InvalidValue(_))));
    }

    // UTILITIES
    // --------------------------------------------------------------------------------------------

    /// Extension elements flatten into their base components in order.
    #[test]
    fn as_base_elements() {
        let elements = sample_elements();

        let expected = vec![
            BaseElement::new(1),
            BaseElement::new(2),
            BaseElement::new(3),
            BaseElement::new(4),
            BaseElement::new(5),
            BaseElement::new(6),
        ];

        assert_eq!(expected, Ext::as_base_elements(&elements));
    }
}