crates/ratchet-core/src/strides.rs (151 lines of code) (raw):
use std::ops::{Index, IndexMut, RangeFrom, RangeTo};
use std::slice::Iter;
use crate::{rvec, RVec, Shape};
use encase::impl_wrapper;
/// Per-axis strides of a tensor: one signed `isize` entry per dimension.
///
/// Contiguous strides are built from a [`Shape`] via `From<&Shape>` (row-major, innermost
/// axis last). The inner storage is an `RVec<isize>`; `Debug` is implemented by hand below.
#[derive(Clone, PartialEq, Eq, Default, Hash)]
pub struct Strides(RVec<isize>);
// NOTE(review): `impl_wrapper!` comes from `encase`; presumably it implements encase's
// shader-type traits for `Strides` by delegating to the wrapped value — confirm against
// the encase version in use.
impl_wrapper!(Strides; using);
impl Strides {
    /// Returns an owned `Vec` copy of the strides, in axis order.
    pub fn to_vec(&self) -> Vec<isize> {
        self.as_slice().to_vec()
    }

    /// Iterates over the strides by reference, in axis order.
    pub fn iter(&self) -> Iter<'_, isize> {
        self.as_slice().iter()
    }

    /// Swaps the strides of the last two axes in place.
    ///
    /// Strides of rank 0 or 1 are left unchanged.
    pub fn transpose(&mut self) {
        let dims: &mut [isize] = &mut self.0;
        // The slice pattern only matches when at least two entries exist.
        if let [.., second_last, last] = dims {
            std::mem::swap(second_last, last);
        }
    }

    /// Number of axes (one stride per axis).
    pub fn rank(&self) -> usize {
        self.as_slice().len()
    }

    /// Borrows the strides as a plain slice.
    pub fn as_slice(&self) -> &[isize] {
        &self.0[..]
    }
}
/// Formats strides as `[s0xs1x...]`, e.g. `[12x4x1]`.
///
/// A rank-0 (empty) `Strides` prints as `[]`, so no stride value is invented for it.
impl std::fmt::Debug for Strides {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("[")?;
        for (i, stride) in self.0.iter().enumerate() {
            // Separator goes between entries only, never before the first.
            if i > 0 {
                f.write_str("x")?;
            }
            write!(f, "{}", stride)?;
        }
        f.write_str("]")
    }
}
/// Lets `Strides` be used wherever an `&[isize]` is expected, so slice methods
/// (`len`, `first`, `iter`, ...) can be called on it directly.
impl core::ops::Deref for Strides {
    type Target = [isize];

    fn deref(&self) -> &[isize] {
        &self.0[..]
    }
}
/// Reads the stride of axis `index`; panics if `index >= rank`.
impl Index<usize> for Strides {
    type Output = isize;

    fn index(&self, index: usize) -> &Self::Output {
        let dims: &[isize] = &self.0;
        &dims[index]
    }
}
/// Mutable access to the stride of axis `index`; panics if `index >= rank`.
impl IndexMut<usize> for Strides {
    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
        let dims: &mut [isize] = &mut self.0;
        &mut dims[index]
    }
}
/// Slices the strides from axis `index.start` to the last axis (e.g. `strides[1..]`).
impl Index<RangeFrom<usize>> for Strides {
    type Output = [isize];

    fn index(&self, index: RangeFrom<usize>) -> &Self::Output {
        let dims: &[isize] = &self.0;
        &dims[index]
    }
}
/// Slices the strides of the first `index.end` axes (e.g. `strides[..2]`).
impl Index<RangeTo<usize>> for Strides {
    type Output = [isize];

    fn index(&self, index: RangeTo<usize>) -> &Self::Output {
        let dims: &[isize] = &self.0;
        &dims[index]
    }
}
impl From<&Shape> for Strides {
fn from(shape: &Shape) -> Self {
let mut strides = rvec![];
let mut stride = 1;
for size in shape.inner().iter().rev() {
strides.push(stride);
stride *= *size as isize;
}
strides.reverse();
Self(strides)
}
}
impl From<Vec<isize>> for Strides {
fn from(strides: Vec<isize>) -> Self {
Self(strides.into())
}
}
impl From<&[isize]> for Strides {
fn from(strides: &[isize]) -> Self {
Self(strides.into())
}
}
/// Packs up to three strides into a `[u32; 3]`, zero-filling unused trailing slots.
///
/// # Panics
/// Panics if the rank exceeds 3. Each stride is converted with `as u32`, so negative
/// or out-of-range values wrap rather than erroring.
impl From<&Strides> for [u32; 3] {
    fn from(strides: &Strides) -> Self {
        let dims = &strides.0;
        assert!(dims.len() <= 3);
        let mut packed = [0u32; 3];
        for (slot, &stride) in packed.iter_mut().zip(dims.iter()) {
            *slot = stride as u32;
        }
        packed
    }
}
/// Converts up to three strides into a `glam::UVec3`, reusing the `[u32; 3]`
/// conversion (same zero-padding and rank-≤3 panic).
impl From<&Strides> for glam::UVec3 {
    fn from(strides: &Strides) -> Self {
        glam::UVec3::from(<[u32; 3]>::from(strides))
    }
}
/// Packs up to four strides into a `[u32; 4]`, zero-filling unused trailing slots.
///
/// # Panics
/// Panics if the rank exceeds 4. Each stride is converted with `as u32`, so negative
/// or out-of-range values wrap rather than erroring.
impl From<&Strides> for [u32; 4] {
    fn from(strides: &Strides) -> Self {
        assert!(strides.0.len() <= 4);
        std::array::from_fn(|axis| strides.0.get(axis).map_or(0, |&stride| stride as u32))
    }
}
/// Packs up to four strides into a `[usize; 4]`, zero-filling unused trailing slots.
///
/// # Panics
/// Panics if the rank exceeds 4. Each stride is converted with `as usize`, so negative
/// values wrap rather than erroring.
impl From<&Strides> for [usize; 4] {
    fn from(strides: &Strides) -> Self {
        let dims = &strides.0;
        assert!(dims.len() <= 4);
        let mut packed = [0usize; 4];
        packed
            .iter_mut()
            .zip(dims.iter())
            .for_each(|(slot, &stride)| *slot = stride as usize);
        packed
    }
}
/// Converts up to four strides into a `glam::UVec4`, reusing the `[u32; 4]`
/// conversion (same zero-padding and rank-≤4 panic).
impl From<&Strides> for glam::UVec4 {
    fn from(strides: &Strides) -> Self {
        glam::UVec4::from(<[u32; 4]>::from(strides))
    }
}
/// By-value convenience conversion; delegates to `From<&Strides> for glam::IVec3`.
impl From<Strides> for glam::IVec3 {
    fn from(strides: Strides) -> Self {
        glam::IVec3::from(&strides)
    }
}
/// Converts the first three strides into a `glam::IVec3`.
///
/// Strides beyond the third are ignored. Each stride is cast with `as` to the
/// vector's `i32` components, so values outside the `i32` range are truncated.
///
/// # Panics
/// Panics with a descriptive message if `strides` has fewer than three axes.
impl From<&Strides> for glam::IVec3 {
    fn from(strides: &Strides) -> Self {
        // Rank < 3 already panicked via indexing; make the precondition explicit so
        // the failure names the problem instead of reporting a bare index error.
        assert!(
            strides.0.len() >= 3,
            "cannot convert strides of rank {} to IVec3: at least 3 axes are required",
            strides.0.len()
        );
        glam::IVec3::new(strides.0[0] as _, strides.0[1] as _, strides.0[2] as _)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::shape;

    /// Contiguous strides for a 3-D shape follow row-major order.
    #[test]
    fn test_strides() {
        let shape = shape![2, 3, 4];
        let strides = Strides::from(&shape);
        assert_eq!(strides.to_vec(), vec![12, 4, 1]);
    }

    /// `transpose` swaps only the last two axes and is a no-op below rank 2.
    #[test]
    fn test_transpose_swaps_last_two_axes() {
        let mut strides = Strides::from(vec![12isize, 4, 1]);
        strides.transpose();
        assert_eq!(strides.to_vec(), vec![12, 1, 4]);

        let mut single = Strides::from(vec![7isize]);
        single.transpose();
        assert_eq!(single.to_vec(), vec![7]);
    }

    /// Fixed-size array conversions zero-fill the slots beyond the rank.
    #[test]
    fn test_array_conversions_zero_pad() {
        let strides = Strides::from(vec![4isize, 1]);
        let as_u32_3: [u32; 3] = (&strides).into();
        assert_eq!(as_u32_3, [4, 1, 0]);
        let as_u32_4: [u32; 4] = (&strides).into();
        assert_eq!(as_u32_4, [4, 1, 0, 0]);
        let as_usize_4: [usize; 4] = (&strides).into();
        assert_eq!(as_usize_4, [4, 1, 0, 0]);
    }

    /// Non-empty strides are printed as an `x`-separated list.
    #[test]
    fn test_debug_format() {
        let strides = Strides::from(vec![12isize, 4, 1]);
        assert_eq!(format!("{:?}", strides), "[12x4x1]");
    }
}