wasm/wasm-sharding-js/src/internal_module/tensorflow_module.rs (595 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::*; /// reference https://github.com/second-state/wasmedge_tensorflow_interface mod wasmedge_tensorflow { /// wasmedge_tensorflow host functions. #[link(wasm_import_module = "wasmedge_tensorflow")] extern "C" { pub fn wasmedge_tensorflow_create_session(model_buf: *const u8, model_buf_len: u32) -> u64; pub fn wasmedge_tensorflow_delete_session(context: u64); pub fn wasmedge_tensorflow_run_session(context: u64) -> u32; pub fn wasmedge_tensorflow_get_output_tensor( context: u64, output_name: *const u8, output_name_len: u32, index: u32, ) -> u64; pub fn wasmedge_tensorflow_get_tensor_len(tensor_ptr: u64) -> u32; pub fn wasmedge_tensorflow_get_tensor_data(tensor_ptr: u64, buf: *mut u8); pub fn wasmedge_tensorflow_append_input( context: u64, input_name: *const u8, input_name_len: u32, index: u32, dim_vec: *const u8, dim_cnt: u32, data_type: u32, tensor_buf: *const u8, tensor_buf_len: u32, ); pub fn wasmedge_tensorflow_append_output( context: u64, output_name: *const u8, output_name_len: u32, index: u32, ); pub fn wasmedge_tensorflow_clear_input(context: u64); pub fn wasmedge_tensorflow_clear_output(context: u64); } /// wasmedge_tensorflowlite host functions. #[link(wasm_import_module = "wasmedge_tensorflowlite")] extern "C" { pub fn wasmedge_tensorflowlite_create_session( model_buf: *const u8, model_buf_len: u32, ) -> u64; pub fn wasmedge_tensorflowlite_delete_session(context: u64); pub fn wasmedge_tensorflowlite_run_session(context: u64) -> u32; pub fn wasmedge_tensorflowlite_get_output_tensor( context: u64, output_name: *const u8, output_name_len: u32, ) -> u64; pub fn wasmedge_tensorflowlite_get_tensor_len(tensor_ptr: u64) -> u32; pub fn wasmedge_tensorflowlite_get_tensor_data(tensor_ptr: u64, buf: *mut u8); pub fn wasmedge_tensorflowlite_append_input( context: u64, input_name: *const u8, input_name_len: u32, tensor_buf: *const u8, tensor_buf_len: u32, ); } /// wasmedge_image host helper functions. #[link(wasm_import_module = "wasmedge_image")] extern "C" { pub fn wasmedge_image_load_jpg_to_rgb8( img_buf: *const u8, img_buf_len: u32, img_width: u32, img_height: u32, dst_buf: *mut u8, ) -> u32; pub fn wasmedge_image_load_jpg_to_bgr8( img_buf: *const u8, img_buf_len: u32, img_width: u32, img_height: u32, dst_buf: *mut u8, ) -> u32; pub fn wasmedge_image_load_jpg_to_rgb32f( img_buf: *const u8, img_buf_len: u32, img_width: u32, img_height: u32, dst_buf: *mut u8, ) -> u32; pub fn wasmedge_image_load_jpg_to_bgr32f( img_buf: *const u8, img_buf_len: u32, img_width: u32, img_height: u32, dst_buf: *mut u8, ) -> u32; pub fn wasmedge_image_load_png_to_rgb8( img_buf: *const u8, img_buf_len: u32, img_width: u32, img_height: u32, dst_buf: *mut u8, ) -> u32; pub fn wasmedge_image_load_png_to_bgr8( img_buf: *const u8, img_buf_len: u32, img_width: u32, img_height: u32, dst_buf: *mut u8, ) -> u32; pub fn wasmedge_image_load_png_to_rgb32f( img_buf: *const u8, img_buf_len: u32, img_width: u32, img_height: u32, dst_buf: *mut u8, ) -> u32; pub fn wasmedge_image_load_png_to_bgr32f( img_buf: *const u8, img_buf_len: u32, img_width: u32, img_height: u32, dst_buf: *mut u8, ) -> u32; } } //--------------------- mod tensorflow { use super::wasmedge_tensorflow::*; use crate::*; use std::path::Path; pub enum InputDataType { F32 = 1, F64 = 2, I32 = 3, U8 = 4, U16 = 17, U32 = 22, U64 = 23, I16 = 5, I8 = 6, I64 = 9, Bool = 10, } pub struct TensorflowSession { context: u64, data: Vec<u8>, } impl Drop for TensorflowSession { fn drop(&mut self) { unsafe { wasmedge_tensorflow_delete_session(self.context); } } } impl TensorflowSession { pub fn new_from_path<T: AsRef<Path>>(path: T) -> Result<Self, String> { let data = std::fs::read(path).map_err(|e| e.to_string())?; let context = unsafe { wasmedge_tensorflow_create_session( data.as_slice().as_ptr().cast(), data.len() as u32, ) }; Ok(TensorflowSession { context, data }) } pub unsafe fn add_input( &mut self, name: &str, tensor_buf: *const u8, tensor_buf_len: u32, data_type: u32, shape: &[i64], ) { let mut idx: u32 = 0; let name_pair: Vec<&str> = name.split(":").collect(); if name_pair.len() > 1 { idx = name_pair[1].parse().unwrap(); } let input_name = make_c_string(name_pair[0]); wasmedge_tensorflow_append_input( self.context, input_name.as_ptr() as *const u8, input_name.as_bytes().len() as u32, idx, shape.as_ptr() as *const u8, shape.len() as u32, data_type, tensor_buf, tensor_buf_len, ); } pub unsafe fn add_output(&mut self, name: &str) { let name_pair: Vec<&str> = name.split(":").collect(); let output_name = make_c_string(name_pair[0]); let mut idx = 0; if name_pair.len() > 1 { idx = name_pair[1].parse().unwrap() }; wasmedge_tensorflow_append_output( self.context, output_name.as_ptr() as *const u8, output_name.as_bytes().len() as u32, idx, ); } pub unsafe fn run(&mut self) { wasmedge_tensorflow_run_session(self.context); } pub unsafe fn get_output(&self, name: &str) -> Vec<u8> { // Parse name and operation index. let name_pair: Vec<&str> = name.split(":").collect(); let output_name = make_c_string(name_pair[0]); let mut idx = 0; if name_pair.len() > 1 { idx = name_pair[1].parse().unwrap() }; // Get tensor data. let tensor = wasmedge_tensorflow_get_output_tensor( self.context, output_name.as_ptr() as *const u8, output_name.as_bytes().len() as u32, idx, ); let buf_len = wasmedge_tensorflow_get_tensor_len(tensor) as usize; if buf_len == 0 { return Vec::new(); } let mut data = vec![0u8; buf_len]; wasmedge_tensorflow_get_tensor_data(tensor, data.as_mut_ptr() as *mut u8); return data; } pub unsafe fn clear_input(&mut self) { wasmedge_tensorflow_clear_input(self.context); } pub unsafe fn clear_output(&mut self) { wasmedge_tensorflow_clear_output(self.context); } } impl TensorflowSession { fn js_add_input_8u( &mut self, _: &mut JsObject, ctx: &mut Context, argv: &[JsValue], ) -> JsValue { let name = if let Some(JsValue::String(s)) = argv.get(0) { s.to_string() } else { return ctx.throw_type_error("'name' must be of type string").into(); }; let tensor_buf = if let Some(JsValue::ArrayBuffer(buf)) = argv.get(1) { buf.as_ref() } else { return ctx .throw_type_error("'tensor_buf' must be of type buffer") .into(); }; let shape = if let Some(JsValue::Array(arr)) = argv.get(2) { match arr.to_vec() { Ok(a) => a, Err(e) => return e.into(), } } else { return ctx.throw_type_error("'shape' must be of type array").into(); }; let mut shape_arr = vec![]; for i in shape { let v = match i { JsValue::Int(i) => i as i64, JsValue::Float(i) => i as i64, _ => { return ctx .throw_type_error("'shape' must be of type number array") .into() } }; shape_arr.push(v); } unsafe { self.add_input( name.as_str(), tensor_buf.as_ptr(), tensor_buf.len() as u32, InputDataType::U8 as u32, shape_arr.as_slice(), ); } JsValue::UnDefined } fn js_add_input_32f( &mut self, _: &mut JsObject, ctx: &mut Context, argv: &[JsValue], ) -> JsValue { let name = if let Some(JsValue::String(s)) = argv.get(0) { s.to_string() } else { return ctx.throw_type_error("'name' must be of type string").into(); }; let tensor_buf = if let Some(JsValue::ArrayBuffer(buf)) = argv.get(1) { buf.as_ref() } else { return ctx .throw_type_error("'tensor_buf' must be of type buffer") .into(); }; let shape = if let Some(JsValue::Array(arr)) = argv.get(2) { match arr.to_vec() { Ok(a) => a, Err(e) => return e.into(), } } else { return ctx.throw_type_error("'shape' must be of type array").into(); }; let mut shape_arr = vec![]; for i in shape { let v = match i { JsValue::Int(i) => i as i64, JsValue::Float(i) => i as i64, _ => { return ctx .throw_type_error("'shape' must be of type number array") .into() } }; shape_arr.push(v); } unsafe { self.add_input( name.as_str(), tensor_buf.as_ptr(), tensor_buf.len() as u32, InputDataType::F32 as u32, shape_arr.as_slice(), ); } JsValue::UnDefined } fn js_add_output( &mut self, _: &mut JsObject, ctx: &mut Context, argv: &[JsValue], ) -> JsValue { let name = if let Some(JsValue::String(s)) = argv.get(0) { s.to_string() } else { return ctx.throw_type_error("'name' must be of type string").into(); }; unsafe { self.add_output(name.as_str()); } JsValue::UnDefined } fn js_run(&mut self, _: &mut JsObject, _ctx: &mut Context, _argv: &[JsValue]) -> JsValue { unsafe { self.run() } JsValue::UnDefined } fn js_get_output( &mut self, _: &mut JsObject, ctx: &mut Context, argv: &[JsValue], ) -> JsValue { let name = if let Some(JsValue::String(s)) = argv.get(0) { s.to_string() } else { return ctx.throw_type_error("'name' must be of type string").into(); }; let data = unsafe { self.get_output(name.as_str()) }; ctx.new_array_buffer(data.as_slice()).into() } fn js_clear_output( &mut self, _: &mut JsObject, _ctx: &mut Context, _argv: &[JsValue], ) -> JsValue { unsafe { self.clear_output() } JsValue::UnDefined } fn js_clear_input( &mut self, _: &mut JsObject, _ctx: &mut Context, _argv: &[JsValue], ) -> JsValue { unsafe { self.clear_input() } JsValue::UnDefined } } impl JsClassDef for TensorflowSession { type RefType = TensorflowSession; const CLASS_NAME: &'static str = "TensorflowSession\0"; const CONSTRUCTOR_ARGC: u8 = 1; const FIELDS: &'static [JsClassField<Self::RefType>] = &[]; const METHODS: &'static [JsClassMethod<Self::RefType>] = &[ ("add_input_8u", 3, Self::js_add_input_8u), ("add_input_32f", 3, Self::js_add_input_32f), ("add_output", 1, Self::js_add_output), ("run", 0, Self::js_run), ("get_output", 1, Self::js_get_output), ("clear_output", 0, Self::js_clear_output), ("clear_input", 0, Self::js_clear_input), ]; unsafe fn mut_class_id_ptr() -> &'static mut u32 { static mut CLASS_ID: u32 = 0; &mut CLASS_ID } fn constructor_fn( ctx: &mut Context, argv: &[JsValue], ) -> Result<TensorflowSession, JsValue> { match argv.get(0).ok_or(JsValue::UnDefined)? { JsValue::String(path) => { let path = path.to_string(); let session = TensorflowSession::new_from_path(path) .map_err(|e| ctx.throw_internal_type_error(e.as_str()))?; Ok(session) } _ => Err(JsValue::UnDefined), } } } struct TensorflowModDef; impl ModuleInit for TensorflowModDef { fn init_module(ctx: &mut Context, m: &mut JsModuleDef) { let ctor = register_class::<TensorflowSession>(ctx); m.add_export(TensorflowSession::CLASS_NAME, ctor) } } pub fn init_module_tensorflow(ctx: &mut Context) { ctx.register_module( "tensorflow\0", TensorflowModDef, &[TensorflowSession::CLASS_NAME], ) } } mod tensorflow_lite { use super::{tensorflow, wasmedge_tensorflow::*}; use crate::*; use std::path::Path; struct TensorflowLiteSession { context: u64, data: Vec<u8>, } impl Drop for TensorflowLiteSession { fn drop(&mut self) { unsafe { wasmedge_tensorflowlite_delete_session(self.context); } } } impl TensorflowLiteSession { pub fn new_from_path<T: AsRef<Path>>(path: T) -> Result<Self, String> { let data = std::fs::read(path).map_err(|e| e.to_string())?; let context = unsafe { wasmedge_tensorflowlite_create_session( data.as_slice().as_ptr().cast(), data.len() as u32, ) }; Ok(TensorflowLiteSession { context, data }) } pub unsafe fn add_input(&mut self, name: &str, tensor_buf: *const u8, tensor_buf_len: u32) { let input_name = make_c_string(name); wasmedge_tensorflowlite_append_input( self.context, input_name.as_ptr() as *const u8, input_name.as_bytes().len() as u32, tensor_buf as *const u8, tensor_buf_len, ); } pub unsafe fn run(&mut self) { wasmedge_tensorflowlite_run_session(self.context); } pub unsafe fn get_output(&self, name: &str) -> Vec<u8> { // Parse name and operation index. let output_name = make_c_string(name); // Get tensor data. let tensor = wasmedge_tensorflowlite_get_output_tensor( self.context, output_name.as_ptr() as *const u8, output_name.as_bytes().len() as u32, ); let buf_len = wasmedge_tensorflowlite_get_tensor_len(tensor) as usize; if buf_len == 0 { return Vec::new(); } let mut data = vec![0u8; buf_len]; wasmedge_tensorflowlite_get_tensor_data(tensor, data.as_mut_ptr() as *mut u8); return data; } } impl TensorflowLiteSession { pub fn js_add_input( &mut self, _: &mut JsObject, ctx: &mut Context, argv: &[JsValue], ) -> JsValue { let name = if let Some(JsValue::String(s)) = argv.get(0) { s.to_string() } else { return ctx.throw_type_error("'name' must be of type string").into(); }; let tensor_buf = if let Some(JsValue::ArrayBuffer(buf)) = argv.get(1) { buf.as_ref() } else { return ctx .throw_type_error("'tensor_buf' must be of type buffer") .into(); }; unsafe { self.add_input(name.as_str(), tensor_buf.as_ptr(), tensor_buf.len() as u32); } JsValue::UnDefined } pub fn js_run( &mut self, _: &mut JsObject, _ctx: &mut Context, _argv: &[JsValue], ) -> JsValue { unsafe { self.run() } JsValue::UnDefined } pub fn js_get_output( &mut self, _: &mut JsObject, ctx: &mut Context, argv: &[JsValue], ) -> JsValue { let name = if let Some(JsValue::String(s)) = argv.get(0) { s.to_string() } else { return ctx.throw_type_error("'name' must be of type string").into(); }; let data = unsafe { self.get_output(name.as_str()) }; ctx.new_array_buffer(data.as_slice()).into() } } impl JsClassDef for TensorflowLiteSession { type RefType = TensorflowLiteSession; const CLASS_NAME: &'static str = "TensorflowLiteSession\0"; const CONSTRUCTOR_ARGC: u8 = 1; const FIELDS: &'static [JsClassField<Self::RefType>] = &[]; const METHODS: &'static [JsClassMethod<Self::RefType>] = &[ ("add_input", 2, Self::js_add_input), ("run", 0, Self::js_run), ("get_output", 1, Self::js_get_output), ]; unsafe fn mut_class_id_ptr() -> &'static mut u32 { static mut CLASS_ID: u32 = 0; &mut CLASS_ID } fn constructor_fn( ctx: &mut Context, argv: &[JsValue], ) -> Result<TensorflowLiteSession, JsValue> { match argv.get(0).ok_or(JsValue::UnDefined)? { JsValue::String(path) => { let path = path.to_string(); let session = TensorflowLiteSession::new_from_path(path) .map_err(|e| ctx.throw_internal_type_error(e.as_str()))?; Ok(session) } _ => Err(JsValue::UnDefined), } } } struct TensorflowModDef; impl ModuleInit for TensorflowModDef { fn init_module(ctx: &mut Context, m: &mut JsModuleDef) { let ctor = register_class::<TensorflowLiteSession>(ctx); m.add_export(TensorflowLiteSession::CLASS_NAME, ctor) } } pub fn init_module_tensorflow_lite(ctx: &mut Context) { ctx.register_module( "tensorflow_lite\0", TensorflowModDef, &[TensorflowLiteSession::CLASS_NAME], ) } } pub use tensorflow::init_module_tensorflow; pub use tensorflow_lite::init_module_tensorflow_lite;