arrow-schema/src/extension/canonical/fixed_shape_tensor.rs (101 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. //! FixedShapeTensor //! //! <https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor> use serde::{Deserialize, Serialize}; use crate::{extension::ExtensionType, ArrowError, DataType}; /// The extension type for fixed shape tensor. /// /// Extension name: `arrow.fixed_shape_tensor`. /// /// The storage type of the extension: `FixedSizeList` where: /// - `value_type` is the data type of individual tensor elements. /// - `list_size` is the product of all the elements in tensor shape. /// /// Extension type parameters: /// - `value_type`: the Arrow data type of individual tensor elements. /// - `shape`: the physical shape of the contained tensors as an array. /// /// Optional parameters describing the logical layout: /// - `dim_names`: explicit names to tensor dimensions as an array. The /// length of it should be equal to the shape length and equal to the /// number of dimensions. /// `dim_names` can be used if the dimensions have /// well-known names and they map to the physical layout (row-major). /// - `permutation`: indices of the desired ordering of the original /// dimensions, defined as an array. /// The indices contain a permutation of the values `[0, 1, .., N-1]` /// where `N` is the number of dimensions. The permutation indicates /// which dimension of the logical layout corresponds to which dimension /// of the physical tensor (the i-th dimension of the logical view /// corresponds to the dimension with number `permutations[i]` of the /// physical tensor). /// Permutation can be useful in case the logical order of the tensor is /// a permutation of the physical order (row-major). /// When logical and physical layout are equal, the permutation will /// always be `([0, 1, .., N-1])` and can therefore be left out. /// /// Description of the serialization: /// The metadata must be a valid JSON object including shape of the /// contained tensors as an array with key `shape` plus optional /// dimension names with keys `dim_names` and ordering of the /// dimensions with key `permutation`. /// Example: `{ "shape": [2, 5]}` /// Example with `dim_names` metadata for NCHW ordered data: /// `{ "shape": [100, 200, 500], "dim_names": ["C", "H", "W"]}` /// Example of permuted 3-dimensional tensor: /// `{ "shape": [100, 200, 500], "permutation": [2, 0, 1]}` /// /// This is the physical layout shape and the shape of the logical layout /// would in this case be `[500, 100, 200]`. /// /// <https://arrow.apache.org/docs/format/CanonicalExtensions.html#fixed-shape-tensor> #[derive(Debug, Clone, PartialEq)] pub struct FixedShapeTensor { /// The data type of individual tensor elements. value_type: DataType, /// The metadata of this extension type. metadata: FixedShapeTensorMetadata, } impl FixedShapeTensor { /// Returns a new fixed shape tensor extension type. /// /// # Error /// /// Return an error if the provided dimension names or permutations are /// invalid. pub fn try_new( value_type: DataType, shape: impl IntoIterator<Item = usize>, dimension_names: Option<Vec<String>>, permutations: Option<Vec<usize>>, ) -> Result<Self, ArrowError> { // TODO: are all data types are suitable as value type? FixedShapeTensorMetadata::try_new(shape, dimension_names, permutations).map(|metadata| { Self { value_type, metadata, } }) } /// Returns the value type of the individual tensor elements. pub fn value_type(&self) -> &DataType { &self.value_type } /// Returns the product of all the elements in tensor shape. pub fn list_size(&self) -> usize { self.metadata.list_size() } /// Returns the number of dimensions in this fixed shape tensor. pub fn dimensions(&self) -> usize { self.metadata.dimensions() } /// Returns the names of the dimensions in this fixed shape tensor, if /// set. pub fn dimension_names(&self) -> Option<&[String]> { self.metadata.dimension_names() } /// Returns the indices of the desired ordering of the original /// dimensions, if set. pub fn permutations(&self) -> Option<&[usize]> { self.metadata.permutations() } } /// Extension type metadata for [`FixedShapeTensor`]. #[derive(Debug, Clone, PartialEq, Deserialize, Serialize)] pub struct FixedShapeTensorMetadata { /// The physical shape of the contained tensors. shape: Vec<usize>, /// Explicit names to tensor dimensions. dim_names: Option<Vec<String>>, /// Indices of the desired ordering of the original dimensions. permutations: Option<Vec<usize>>, } impl FixedShapeTensorMetadata { /// Returns metadata for a fixed shape tensor extension type. /// /// # Error /// /// Return an error if the provided dimension names or permutations are /// invalid. pub fn try_new( shape: impl IntoIterator<Item = usize>, dimension_names: Option<Vec<String>>, permutations: Option<Vec<usize>>, ) -> Result<Self, ArrowError> { let shape = shape.into_iter().collect::<Vec<_>>(); let dimensions = shape.len(); let dim_names = dimension_names.map(|dimension_names| { if dimension_names.len() != dimensions { Err(ArrowError::InvalidArgumentError(format!( "FixedShapeTensor dimension names size mismatch, expected {dimensions}, found {}", dimension_names.len() ))) } else { Ok(dimension_names) } }).transpose()?; let permutations = permutations .map(|permutations| { if permutations.len() != dimensions { Err(ArrowError::InvalidArgumentError(format!( "FixedShapeTensor permutations size mismatch, expected {dimensions}, found {}", permutations.len() ))) } else { let mut sorted_permutations = permutations.clone(); sorted_permutations.sort_unstable(); if (0..dimensions).zip(sorted_permutations).any(|(a, b)| a != b) { Err(ArrowError::InvalidArgumentError(format!( "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: {dimensions}" ))) } else { Ok(permutations) } } }) .transpose()?; Ok(Self { shape, dim_names, permutations, }) } /// Returns the product of all the elements in tensor shape. pub fn list_size(&self) -> usize { self.shape.iter().product() } /// Returns the number of dimensions in this fixed shape tensor. pub fn dimensions(&self) -> usize { self.shape.len() } /// Returns the names of the dimensions in this fixed shape tensor, if /// set. pub fn dimension_names(&self) -> Option<&[String]> { self.dim_names.as_ref().map(AsRef::as_ref) } /// Returns the indices of the desired ordering of the original /// dimensions, if set. pub fn permutations(&self) -> Option<&[usize]> { self.permutations.as_ref().map(AsRef::as_ref) } } impl ExtensionType for FixedShapeTensor { const NAME: &'static str = "arrow.fixed_shape_tensor"; type Metadata = FixedShapeTensorMetadata; fn metadata(&self) -> &Self::Metadata { &self.metadata } fn serialize_metadata(&self) -> Option<String> { Some(serde_json::to_string(&self.metadata).expect("metadata serialization")) } fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> { metadata.map_or_else( || { Err(ArrowError::InvalidArgumentError( "FixedShapeTensor extension types requires metadata".to_owned(), )) }, |value| { serde_json::from_str(value).map_err(|e| { ArrowError::InvalidArgumentError(format!( "FixedShapeTensor metadata deserialization failed: {e}" )) }) }, ) } fn supports_data_type(&self, data_type: &DataType) -> Result<(), ArrowError> { let expected = DataType::new_fixed_size_list( self.value_type.clone(), i32::try_from(self.list_size()).expect("overflow"), false, ); data_type .equals_datatype(&expected) .then_some(()) .ok_or_else(|| { ArrowError::InvalidArgumentError(format!( "FixedShapeTensor data type mismatch, expected {expected}, found {data_type}" )) }) } fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> { match data_type { DataType::FixedSizeList(field, list_size) if !field.is_nullable() => { // Make sure the metadata is valid. let metadata = FixedShapeTensorMetadata::try_new( metadata.shape, metadata.dim_names, metadata.permutations, )?; // Make sure it is compatible with this data type. let expected_size = i32::try_from(metadata.list_size()).expect("overflow"); if *list_size != expected_size { Err(ArrowError::InvalidArgumentError(format!( "FixedShapeTensor list size mismatch, expected {expected_size} (metadata), found {list_size} (data type)" ))) } else { Ok(Self { value_type: field.data_type().clone(), metadata, }) } } data_type => Err(ArrowError::InvalidArgumentError(format!( "FixedShapeTensor data type mismatch, expected FixedSizeList with non-nullable field, found {data_type}" ))), } } } #[cfg(test)] mod tests { #[cfg(feature = "canonical_extension_types")] use crate::extension::CanonicalExtensionType; use crate::{ extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY}, Field, }; use super::*; #[test] fn valid() -> Result<(), ArrowError> { let fixed_shape_tensor = FixedShapeTensor::try_new( DataType::Float32, [100, 200, 500], Some(vec!["C".to_owned(), "H".to_owned(), "W".to_owned()]), Some(vec![2, 0, 1]), )?; let mut field = Field::new_fixed_size_list( "", Field::new("", DataType::Float32, false), i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"), false, ); field.try_with_extension_type(fixed_shape_tensor.clone())?; assert_eq!( field.try_extension_type::<FixedShapeTensor>()?, fixed_shape_tensor ); #[cfg(feature = "canonical_extension_types")] assert_eq!( field.try_canonical_extension_type()?, CanonicalExtensionType::FixedShapeTensor(fixed_shape_tensor) ); Ok(()) } #[test] #[should_panic(expected = "Field extension type name missing")] fn missing_name() { let field = Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false) .with_metadata( [( EXTENSION_TYPE_METADATA_KEY.to_owned(), r#"{ "shape": [100, 200, 500], }"#.to_owned(), )] .into_iter() .collect(), ); field.extension_type::<FixedShapeTensor>(); } #[test] #[should_panic(expected = "FixedShapeTensor data type mismatch, expected FixedSizeList")] fn invalid_type() { let fixed_shape_tensor = FixedShapeTensor::try_new(DataType::Int32, [100, 200, 500], None, None).unwrap(); let field = Field::new_fixed_size_list( "", Field::new("", DataType::Float32, false), i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"), false, ); field.with_extension_type(fixed_shape_tensor); } #[test] #[should_panic(expected = "FixedShapeTensor extension types requires metadata")] fn missing_metadata() { let field = Field::new_fixed_size_list("", Field::new("", DataType::Float32, false), 3, false) .with_metadata( [( EXTENSION_TYPE_NAME_KEY.to_owned(), FixedShapeTensor::NAME.to_owned(), )] .into_iter() .collect(), ); field.extension_type::<FixedShapeTensor>(); } #[test] #[should_panic( expected = "FixedShapeTensor metadata deserialization failed: missing field `shape`" )] fn invalid_metadata() { let fixed_shape_tensor = FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, None).unwrap(); let field = Field::new_fixed_size_list( "", Field::new("", DataType::Float32, false), i32::try_from(fixed_shape_tensor.list_size()).expect("overflow"), false, ) .with_metadata( [ ( EXTENSION_TYPE_NAME_KEY.to_owned(), FixedShapeTensor::NAME.to_owned(), ), ( EXTENSION_TYPE_METADATA_KEY.to_owned(), r#"{ "not-shape": [] }"#.to_owned(), ), ] .into_iter() .collect(), ); field.extension_type::<FixedShapeTensor>(); } #[test] #[should_panic( expected = "FixedShapeTensor dimension names size mismatch, expected 3, found 2" )] fn invalid_metadata_dimension_names() { FixedShapeTensor::try_new( DataType::Float32, [100, 200, 500], Some(vec!["a".to_owned(), "b".to_owned()]), None, ) .unwrap(); } #[test] #[should_panic(expected = "FixedShapeTensor permutations size mismatch, expected 3, found 2")] fn invalid_metadata_permutations_len() { FixedShapeTensor::try_new(DataType::Float32, [100, 200, 500], None, Some(vec![1, 0])) .unwrap(); } #[test] #[should_panic( expected = "FixedShapeTensor permutations invalid, expected a permutation of [0, 1, .., N-1], where N is the number of dimensions: 3" )] fn invalid_metadata_permutations_values() { FixedShapeTensor::try_new( DataType::Float32, [100, 200, 500], None, Some(vec![4, 3, 2]), ) .unwrap(); } }