tutorials-and-examples/nvidia-bionemo/fine-tuning/inference/inference_server.py (120 lines of code) (raw):

# inference_server.py from fastapi import FastAPI, HTTPException from pydantic import BaseModel import torch from transformers import AutoModel, PreTrainedTokenizer, EsmConfig import logging import os logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) app = FastAPI() class CustomEsmTokenizer(PreTrainedTokenizer): def __init__(self, **kwargs): # Initialize vocabulary first self.vocab = [ "<pad>", "<mask>", "<cls>", "<sep>", "<unk>", "L", "A", "G", "V", "S", "E", "R", "T", "I", "D", "P", "K", "Q", "N", "F", "Y", "M", "H", "W", "C", "X", "B", "U", "Z", "O", ".", "-", "*" ] self.ids_to_tokens = dict(enumerate(self.vocab)) self.tokens_to_ids = {tok: i for i, tok in enumerate(self.vocab)} # Set special token attributes kwargs["pad_token"] = "<pad>" kwargs["mask_token"] = "<mask>" kwargs["cls_token"] = "<cls>" kwargs["sep_token"] = "<sep>" kwargs["unk_token"] = "<unk>" # Now call parent constructor super().__init__(**kwargs) def get_vocab(self): return self.tokens_to_ids.copy() def _tokenize(self, text): return list(text.strip().upper()) def _convert_token_to_id(self, token): return self.tokens_to_ids.get(token, self.tokens_to_ids["<unk>"]) def _convert_id_to_token(self, index): return self.ids_to_tokens.get(index, "<unk>") def convert_tokens_to_string(self, tokens): return "".join(tokens) @property def vocab_size(self): return len(self.vocab) def save_vocabulary(self, save_directory): vocab_file = os.path.join(save_directory, "vocab.txt") with open(vocab_file, "w") as f: f.write("\n".join(self.vocab)) return (vocab_file,) class InferenceRequest(BaseModel): sequence: str @app.on_event("startup") async def load_model(): global model, tokenizer, config try: model_path = os.getenv("MODEL_PATH", "/mnt/data/model") logger.info(f"Loading model from {model_path}") # Load config config = EsmConfig.from_pretrained(model_path) logger.info(f"Model config loaded: vocab_size={config.vocab_size}") # Create custom tokenizer tokenizer = CustomEsmTokenizer() logger.info(f"Created custom tokenizer with vocab size: {tokenizer.vocab_size}") # Load model model = AutoModel.from_pretrained(model_path) model = model.eval() if torch.cuda.is_available(): model = model.cuda() logger.info("Model loaded on GPU") # Test tokenization test_seq = "MKTV" test_tokens = tokenizer( test_seq, return_tensors="pt", padding=True, truncation=True ) logger.info(f"Test tokenization shape: {test_tokens['input_ids'].shape}") logger.info(f"Test token values: {test_tokens['input_ids'].tolist()}") logger.info("Model and tokenizer loaded successfully") except Exception as e: logger.error(f"Error loading model: {str(e)}") if os.path.exists(model_path): logger.error(f"Directory contents: {os.listdir(model_path)}") raise RuntimeError(f"Failed to load model: {str(e)}") @app.post("/predict") async def predict(request: InferenceRequest): try: # Validate input if not request.sequence or len(request.sequence.strip()) == 0: raise HTTPException(status_code=400, detail="Empty sequence provided") logger.info(f"Processing sequence of length: {len(request.sequence)}") # Tokenize inputs = tokenizer( request.sequence, return_tensors="pt", padding=True, truncation=True, max_length=1024 ) # Remove token_type_ids as ESM doesn't use them if 'token_type_ids' in inputs: del inputs['token_type_ids'] logger.info(f"Tokenized shape: {inputs['input_ids'].shape}") # Move to GPU if available if torch.cuda.is_available(): inputs = {k: v.cuda() for k, v in inputs.items()} # Perform inference with torch.no_grad(): outputs = model(**inputs) # Get embeddings from last hidden state embeddings = outputs.last_hidden_state.mean(dim=1) return { "embeddings": embeddings.cpu().numpy().tolist()[0], "sequence_length": len(request.sequence), "input_length": inputs['input_ids'].shape[1] } except Exception as e: logger.error(f"Inference error: {str(e)}") logger.error("Detailed traceback:", exc_info=True) raise HTTPException( status_code=500, detail=f"Inference failed: {str(e)}" ) except Exception as e: logger.error(f"Inference error: {str(e)}") logger.error("Detailed traceback:", exc_info=True) raise HTTPException( status_code=500, detail=f"Inference failed: {str(e)}" ) @app.get("/health") async def health_check(): """Health check endpoint.""" if model is None or tokenizer is None: raise HTTPException(status_code=503, detail="Model not loaded") return {"status": "healthy"}