def from_pretrained()

in models/vision_transformer.py


    @classmethod
    def from_pretrained(cls, cfg):
        from transformers import SiglipVisionConfig
        from huggingface_hub import hf_hub_download
        import safetensors
        import torch  # used below to fuse the q/k/v projections

        # Pull the architecture hyperparameters from the Hugging Face config so our ViT matches the checkpoint
        hf_config = SiglipVisionConfig.from_pretrained(cfg.vit_model_type)
        cfg.vit_dropout = hf_config.attention_dropout
        cfg.vit_hidden_dim = hf_config.hidden_size
        cfg.vit_img_size = hf_config.image_size
        cfg.vit_inter_dim = hf_config.intermediate_size
        cfg.vit_ln_eps = hf_config.layer_norm_eps
        cfg.vit_n_heads = hf_config.num_attention_heads
        cfg.vit_n_blocks = hf_config.num_hidden_layers
        cfg.vit_patch_size = hf_config.patch_size
        model = cls(cfg)
        # Download the checkpoint from the Hub and grab our state dict so tensors can be copied in place
        safetensors_file = hf_hub_download(repo_id=cfg.vit_model_type, filename="model.safetensors")

        sd = model.state_dict()
        
        # Map Hugging Face SigLIP parameter names to the names used by our implementation
        mapping = {
            'vision_model.embeddings.patch_embedding.weight': 'patch_embedding.conv.weight',
            'vision_model.embeddings.patch_embedding.bias': 'patch_embedding.conv.bias',
            'vision_model.embeddings.position_embedding.weight': 'patch_embedding.position_embedding',
            'vision_model.post_layernorm.weight': 'layer_norm.weight',
            'vision_model.post_layernorm.bias': 'layer_norm.bias',
        }
        
        for i in range(cfg.vit_n_blocks):
            # Layer norms
            mapping[f'vision_model.encoder.layers.{i}.layer_norm1.weight'] = f'blocks.{i}.ln1.weight'
            mapping[f'vision_model.encoder.layers.{i}.layer_norm1.bias'] = f'blocks.{i}.ln1.bias'
            mapping[f'vision_model.encoder.layers.{i}.layer_norm2.weight'] = f'blocks.{i}.ln2.weight'
            mapping[f'vision_model.encoder.layers.{i}.layer_norm2.bias'] = f'blocks.{i}.ln2.bias'
            
            # MLP
            mapping[f'vision_model.encoder.layers.{i}.mlp.fc1.weight'] = f'blocks.{i}.mlp.fc1.weight'
            mapping[f'vision_model.encoder.layers.{i}.mlp.fc1.bias'] = f'blocks.{i}.mlp.fc1.bias'
            mapping[f'vision_model.encoder.layers.{i}.mlp.fc2.weight'] = f'blocks.{i}.mlp.fc2.weight'
            mapping[f'vision_model.encoder.layers.{i}.mlp.fc2.bias'] = f'blocks.{i}.mlp.fc2.bias'
            
            # Attention output projection (the q/k/v projections themselves are handled separately below)
            mapping[f'vision_model.encoder.layers.{i}.self_attn.out_proj.weight'] = f'blocks.{i}.attn.out_proj.weight'
            mapping[f'vision_model.encoder.layers.{i}.self_attn.out_proj.bias'] = f'blocks.{i}.attn.out_proj.bias'
        
        # safe_open reads tensors lazily, so only the weights we request are materialized in memory
        with safetensors.safe_open(filename=safetensors_file, framework="pt", device="cpu") as f:
            for hf_key, our_key in mapping.items():
                if hf_key in f.keys() and our_key in sd:
                    tensor = f.get_tensor(hf_key)
                    if tensor.shape == sd[our_key].shape:
                        sd[our_key].copy_(tensor)
                    else:
                        if 'position_embedding' in hf_key:
                            # the unsqueeze adds the leading dimension our position embedding parameter expects
                            sd[our_key].copy_(tensor.unsqueeze(0))
                        else:
                            print(f"Shape mismatch for {hf_key} -> {our_key}: {tensor.shape} vs {sd[our_key].shape}")
                else:
                    if hf_key not in f.keys():
                        print(f"Warning: Key {hf_key} not found in safetensors file")
                    if our_key not in sd:
                        print(f"Warning: Key {our_key} not found in model state dict")
            
            # Manually handle the QKV concatenation since our implementation fuses Q, K and V into a single qkv_proj (see the layout sketch after this function)
            for i in range(model.cfg.vit_n_blocks):
                q_weight = f.get_tensor(f'vision_model.encoder.layers.{i}.self_attn.q_proj.weight')
                k_weight = f.get_tensor(f'vision_model.encoder.layers.{i}.self_attn.k_proj.weight')
                v_weight = f.get_tensor(f'vision_model.encoder.layers.{i}.self_attn.v_proj.weight')
                
                qkv_weight = torch.cat((q_weight, k_weight, v_weight), dim=0)
                sd[f'blocks.{i}.attn.qkv_proj.weight'].copy_(qkv_weight)
                
                q_bias = f.get_tensor(f'vision_model.encoder.layers.{i}.self_attn.q_proj.bias')
                k_bias = f.get_tensor(f'vision_model.encoder.layers.{i}.self_attn.k_proj.bias')
                v_bias = f.get_tensor(f'vision_model.encoder.layers.{i}.self_attn.v_proj.bias')
                
                qkv_bias = torch.cat((q_bias, k_bias, v_bias), dim=0)
                sd[f'blocks.{i}.attn.qkv_proj.bias'].copy_(qkv_bias)
        
        model.load_state_dict(sd)
        print(f"Successfully loaded {cfg.vit_model_type} weights from safetensors. Model has {sum(p.numel() for p in model.parameters()):,} parameters.")
        return model
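
A quick layout check for the QKV fusion above (a standalone sanity-check sketch, not part of this file; the sizes are arbitrary). nn.Linear stores its weight as (out_features, in_features), so stacking the q/k/v weights along dim 0 yields a fused projection whose output is q, k and v concatenated along the feature dimension:

import torch
import torch.nn as nn

hidden = 8
q, k, v = (nn.Linear(hidden, hidden) for _ in range(3))
fused = nn.Linear(hidden, 3 * hidden)

# Row-wise (dim=0) concatenation, exactly as in the loader above
with torch.no_grad():
    fused.weight.copy_(torch.cat((q.weight, k.weight, v.weight), dim=0))
    fused.bias.copy_(torch.cat((q.bias, k.bias, v.bias), dim=0))

x = torch.randn(2, hidden)
expected = torch.cat((q(x), k(x), v(x)), dim=-1)
assert torch.allclose(fused(x), expected, atol=1e-6)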
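
For completeness, a minimal calling sketch (the class name ViT and the throwaway SimpleNamespace config are assumptions for illustration; only cfg.vit_model_type has to be set up front, since the remaining vit_* fields are filled in from the Hugging Face config, plus whatever other fields the class constructor reads):

from types import SimpleNamespace
from models.vision_transformer import ViT  # assumed name for the class this method is defined on

# assumed minimal config; any SigLIP vision checkpoint on the Hub can serve as vit_model_type
cfg = SimpleNamespace(vit_model_type="google/siglip-base-patch16-224")
vision_encoder = ViT.from_pretrained(cfg)  # downloads model.safetensors and copies the weights
vision_encoder.eval()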