function/src/principal_components_analysis.rs (74 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. use std::convert::TryFrom; use std::format; use std::io::{self, BufRead, BufReader, Write}; use teaclave_types::{FunctionArguments, FunctionRuntime}; use rusty_machine::learning::pca::PCA; use rusty_machine::learning::UnSupModel; use rusty_machine::linalg; use rusty_machine::linalg::BaseMatrix; const IN_DATA: &str = "input_data"; const OUT_RESULT: &str = "output_data"; #[derive(Default)] pub struct PrincipalComponentsAnalysis; #[derive(serde::Deserialize)] struct PrincipalComponentsAnalysisArguments { n: usize, center: bool, feature_size: usize, } impl TryFrom<FunctionArguments> for PrincipalComponentsAnalysisArguments { type Error = anyhow::Error; fn try_from(arguments: FunctionArguments) -> Result<Self, Self::Error> { use anyhow::Context; serde_json::from_str(&arguments.into_string()).context("Cannot deserialize arguments") } } impl PrincipalComponentsAnalysis { pub const NAME: &'static str = "builtin_principal_components_analysis"; pub fn new() -> Self { Default::default() } pub fn run( &self, arguments: FunctionArguments, runtime: FunctionRuntime, ) -> anyhow::Result<String> { let args = PrincipalComponentsAnalysisArguments::try_from(arguments)?; let input = runtime.open_input(IN_DATA)?; let (flattend_features, targets) = parse_input_data(input, args.feature_size)?; let data_size = targets.len(); let input_features = linalg::Matrix::new(data_size, args.feature_size, flattend_features); let mut model = PCA::new(args.n, args.center); model.train(&input_features)?; let predict_result = model.predict(&input_features)?; let mut output = runtime.create_output(OUT_RESULT)?; for i in 0..predict_result.rows() { for j in 0..predict_result.cols() { if j == predict_result.cols() - 1 { write!(&mut output, "{:?}", predict_result[[i, j]])?; } else { write!(&mut output, "{:?},", predict_result[[i, j]])?; } } writeln!(&mut output)?; } Ok(format!( "transform {} rows * {} cols lines of data.", predict_result.rows(), predict_result.cols() )) } } fn parse_input_data( input: impl io::Read, feature_size: usize, ) -> anyhow::Result<(Vec<f64>, Vec<f64>)> { let reader = BufReader::new(input); let mut targets = Vec::<f64>::new(); let mut features = Vec::new(); for line_result in reader.lines() { let line = line_result?; let trimed_line = line.trim(); anyhow::ensure!(!trimed_line.is_empty(), "Empty line"); let mut v: Vec<f64> = trimed_line .split(',') .map(|x| x.parse::<f64>()) .collect::<std::result::Result<_, _>>()?; anyhow::ensure!( v.len() == feature_size + 1, "Data format error: column len = {}, expected = {}", v.len(), feature_size + 1 ); let label = v.swap_remove(feature_size); targets.push(label); features.extend(v); } Ok((features, targets)) } #[cfg(feature = "enclave_unit_test")] pub mod tests { use super::*; use serde_json::json; use std::path::Path; use std::untrusted::fs; use teaclave_crypto::*; use teaclave_runtime::*; use teaclave_test_utils::*; use teaclave_types::*; pub fn run_tests() -> bool { run_tests!(test_pca_predict) } fn test_pca_predict() { let args = FunctionArguments::from_json(json!({ "n": 2, "feature_size": 4, "center":true })) .unwrap(); let base = Path::new("fixtures/functions/princopal_components_analysis"); let input_data_file = base.join("input.txt"); let output_data_file = base.join("result.txt"); let expected_output = base.join("expected_result.txt"); let input_files = StagedFiles::new(hashmap!( IN_DATA => StagedFileInfo::new(&input_data_file, TeaclaveFile128Key::random(), FileAuthTag::mock()), )); let output_files = StagedFiles::new(hashmap!( OUT_RESULT => StagedFileInfo::new(&output_data_file, TeaclaveFile128Key::random(), FileAuthTag::mock()), )); let runtime = Box::new(RawIoRuntime::new(input_files, output_files)); let summary = PrincipalComponentsAnalysis::new() .run(args, runtime) .unwrap(); assert_eq!(summary, "transform 90 rows * 2 cols lines of data."); let result = fs::read_to_string(&output_data_file).unwrap(); let expected = fs::read_to_string(&expected_output).unwrap(); assert_eq!(&result[..], &expected[..]); } }