fn run_print()

in tensor-tools/src/main.rs [171:297]


/// Prints tensors stored in `file` to stdout.
///
/// * `file` - path of the tensor file to read.
/// * `names` - tensors to print. When empty, every tensor in the file is printed, in
///   sorted name order so the output is the same from run to run.
/// * `format` - file format. When `None` it is inferred with `Format::infer`; if that
///   fails, a message is printed to stdout and `Ok(())` is returned.
/// * `full` - when set, calls `candle::display::set_print_options_full` before printing.
/// * `line_width` - when set, is passed to `candle::display::set_line_width`.
/// * `device` - device tensors are loaded onto (GGML/GGUF tensors are also dequantized
///   on it before printing).
///
/// Each tensor is preceded by a `==== name ====` header. A requested name that is not
/// present in the file prints `not found` instead of a tensor.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or parsed, if a tensor cannot be read,
/// loaded or dequantized, or if `format` is `Format::Pickle` (not supported here).
fn run_print(
    file: &std::path::PathBuf,
    names: Vec<String>,
    format: Option<Format>,
    full: bool,
    line_width: Option<usize>,
    device: &Device,
) -> Result<()> {
    // These display setters take no handle, so they change settings process-wide
    // rather than for this call only.
    if full {
        candle::display::set_print_options_full();
    }
    if let Some(line_width) = line_width {
        candle::display::set_line_width(line_width);
    }
    let format = match format.or_else(|| Format::infer(file)) {
        Some(format) => format,
        None => {
            println!(
                "{file:?}: cannot infer format from file extension, use the --format flag"
            );
            return Ok(());
        }
    };
    match format {
        Format::Npz => {
            let tensors = candle::npy::NpzTensors::new(file)?;
            let names = requested_or_all_names(names, tensors.names());
            for name in names.iter() {
                println!("==== {name} ====");
                match tensors.get(name)? {
                    Some(tensor) => println!("{tensor}"),
                    None => println!("not found"),
                }
            }
        }
        Format::Safetensors => {
            use candle::safetensors::Load;
            // SAFETY: the file is memory-mapped; this is only sound while nothing else
            // modifies or truncates it during the mapping's lifetime, which this
            // short-lived CLI command assumes.
            let mmaped = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
            // Index the (borrowed) tensor views by name once so every lookup is O(1).
            let views: std::collections::HashMap<_, _> = mmaped.tensors().into_iter().collect();
            let names = requested_or_all_names(names, views.keys());
            for name in names.iter() {
                println!("==== {name} ====");
                match views.get(name) {
                    Some(view) => {
                        let tensor = view.load(device)?;
                        println!("{tensor}")
                    }
                    None => println!("not found"),
                }
            }
        }
        Format::Pth => {
            let pth_file = candle::pickle::PthTensors::new(file, None)?;
            let names = requested_or_all_names(names, pth_file.tensor_infos().keys());
            for name in names.iter() {
                println!("==== {name} ====");
                match pth_file.get(name)? {
                    Some(tensor) => println!("{tensor}"),
                    None => println!("not found"),
                }
            }
        }
        Format::Pickle => {
            candle::bail!("pickle format is not supported for print")
        }
        Format::Ggml => {
            let mut reader = std::fs::File::open(file)?;
            let content = candle::quantized::ggml_file::Content::read(&mut reader, device)?;
            let names = requested_or_all_names(names, content.tensors.keys());
            for name in names.iter() {
                println!("==== {name} ====");
                match content.tensors.get(name) {
                    Some(qtensor) => {
                        let tensor = qtensor.dequantize(device)?;
                        println!("{tensor}")
                    }
                    None => println!("not found"),
                }
            }
        }
        Format::Gguf => {
            let mut reader = std::fs::File::open(file)?;
            let content = gguf_file::Content::read(&mut reader)?;
            let names = requested_or_all_names(names, content.tensor_infos.keys());
            for name in names.iter() {
                println!("==== {name} ====");
                // Check the index explicitly so that genuine read/decode failures are
                // propagated instead of being misreported as a missing tensor.
                if !content.tensor_infos.contains_key(name) {
                    println!("not found");
                    continue;
                }
                let tensor = content.tensor(&mut reader, name, device)?.dequantize(device)?;
                println!("{tensor}");
            }
        }
    }
    Ok(())
}

/// Returns `requested` unchanged when it is non-empty; otherwise every name yielded by
/// `available`, sorted.
///
/// Some name sources used by `run_print` are hash maps (the safetensors index is an
/// explicit `HashMap`), whose iteration order is arbitrary; sorting keeps the printed
/// order deterministic regardless of where the names came from.
fn requested_or_all_names<I>(requested: Vec<String>, available: I) -> Vec<String>
where
    I: IntoIterator,
    I::Item: ToString,
{
    if !requested.is_empty() {
        return requested;
    }
    let mut all: Vec<String> = available.into_iter().map(|name| name.to_string()).collect();
    all.sort_unstable();
    all
}