def save_model_for_inference()

in tutorials-and-examples/nvidia-bionemo/fine-tuning/job/finetuning.py
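Exports a fine-tuned BioNeMo ESM-2 checkpoint as a Hugging Face Transformers-style model directory: the raw state dict is written to pytorch_model.bin, accompanied by vocab.txt, config.json, and tokenizer_config.json.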


import json
import os

import torch


def save_model_for_inference(checkpoint_dir, save_path):
    """Save the model in a format compatible with Hugging Face Transformers."""
    os.makedirs(save_path, exist_ok=True)
    
    print(f"Loading checkpoint from directory: {checkpoint_dir}")
    
    try:
        # Check the weights directory
        weights_dir = os.path.join(checkpoint_dir, "weights")
        if not os.path.exists(weights_dir):
            raise FileNotFoundError(f"Weights directory not found in {checkpoint_dir}")
            
        print(f"Contents of weights directory:")
        for file in os.listdir(weights_dir):
            print(f"- {file}")
            
        # Load weights from the weights directory; sort so the file choice is deterministic
        weight_files = sorted(f for f in os.listdir(weights_dir) if f.endswith('.pt'))
        if not weight_files:
            raise FileNotFoundError(f"No weight files found in {weights_dir}")

        model_file = os.path.join(weights_dir, weight_files[0])
        print(f"Loading model weights from: {model_file}")
        
        # map_location lets the export run on CPU-only machines
        checkpoint = torch.load(model_file, map_location="cpu")
        print("Checkpoint loaded successfully")
        
        # Unwrap the state dict if the checkpoint nests it under 'state_dict'
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint
            
        torch.save(state_dict, os.path.join(save_path, "pytorch_model.bin"))
        print("Model weights saved successfully")
        
        # Save the vocabulary file, using the token order of the ESM-2 tokenizer
        # so the indices agree with pad_token_id=1 and eos_token_id=2 in the
        # config written below
        vocab_file = os.path.join(save_path, "vocab.txt")
        vocab = [
            "<cls>", "<pad>", "<eos>", "<unk>",
            "L", "A", "G", "V", "S", "E", "R", "T", "I", "D",
            "P", "K", "Q", "N", "F", "Y", "M", "H", "W", "C",
            "X", "B", "U", "Z", "O", ".", "-", "<null_1>", "<mask>"
        ]
        with open(vocab_file, "w") as f:
            f.write("\n".join(vocab))
        print("Vocabulary file saved successfully")
        
        # Create and save the config
        config = {
            "model_type": "esm",
            "architectures": ["ESMForSequenceClassification"],
            "hidden_size": 1280,
            "num_attention_heads": 20,
            "num_hidden_layers": 33,
            "vocab_size": 33,
            "max_position_embeddings": 1024,
            "pad_token_id": 1,
            "eos_token_id": 2,
            "hidden_act": "gelu",
            "attention_probs_dropout_prob": 0.0,
            "hidden_dropout_prob": 0.0,
            "initializer_range": 0.02,
            "layer_norm_eps": 1e-5,
            "position_embedding_type": "absolute"
        }
        
        with open(os.path.join(save_path, "config.json"), "w") as f:
            json.dump(config, f, indent=2)
        print("Config saved successfully")
        
        # Create tokenizer config; tokenizer_class lets AutoTokenizer resolve
        # the directory to EsmTokenizer
        tokenizer_config = {
            "tokenizer_class": "EsmTokenizer",
            "model_max_length": 1024,
            "padding_side": "right",
            "truncation_side": "right",
            "vocab_file": "vocab.txt",
            "do_lower_case": False,
            "special_tokens_map_file": None
        }
        
        with open(os.path.join(save_path, "tokenizer_config.json"), "w") as f:
            json.dump(tokenizer_config, f, indent=2)
        print("Tokenizer config saved successfully")
        
    except Exception as e:
        print(f"Error during model saving: {str(e)}")
        raise
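
For context, a minimal sketch of how the exported directory might be consumed with Hugging Face Transformers. The checkpoint path, output directory, and protein sequence below are placeholders; note also that if the BioNeMo state-dict keys do not match the Hugging Face module names, from_pretrained will warn and randomly initialize the unmatched weights, so a key-remapping step may be needed first.

from transformers import EsmForSequenceClassification, EsmTokenizer

# Hypothetical paths: point these at a real fine-tuning run
save_model_for_inference("results/checkpoints/last", "exported_model")

# Reload the exported artifacts
tokenizer = EsmTokenizer.from_pretrained("exported_model")
model = EsmForSequenceClassification.from_pretrained("exported_model")
model.eval()

# Placeholder protein sequence
inputs = tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")
logits = model(**inputs).logits
print(logits)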