in vision/m4/models/vmistral/modeling_vmistral.py
def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images, max_num_tokens=None):
    config_vl_model = self.config

    language_embed_size = config_vl_model.hidden_size
    num_language_layers = config_vl_model.num_hidden_layers
    ffn_inner_size = config_vl_model.intermediate_size

    vision_config = config_vl_model.vision_config

    # Get vision model blocks infos
    vision_patch_size = vision_config.patch_size
    vision_hidden_size = vision_config.embed_dim
    num_vision_layers = vision_config.num_hidden_layers
    # No "+ 1" for a CLS token here: that extra token only applies to ViT/CLIP-style encoders
    single_image_vision_encoder_seq_len = (vision_config.image_size // vision_patch_size) ** 2
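    # Illustrative arithmetic (hypothetical values, not read from this config): with a
    # 384px image and 14px patches, (384 // 14) ** 2 = 27 ** 2 = 729 visual tokens per image.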
    vision_exp_factor = vision_config.intermediate_size // vision_hidden_size
    # Get language blocks infos
    language_seq_len = max_num_tokens if max_num_tokens is not None else data_param.max_seq_len
    language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4

    # Get modality projection infos
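    # When the perceiver resampler is enabled, each image is compressed down to
    # `resampler_n_latents` tokens before the modality projection; otherwise the
    # projection sees the full patch grid of the vision encoder.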
    vision_pipeline_output_seq_len = (
        self.config.perceiver_config.resampler_n_latents
        if self.config.use_resampler
        else single_image_vision_encoder_seq_len
    )
    language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
        num_layers=num_language_layers,
        batch_size=hparams.batch_size_per_gpu,
        q_seq_len=language_seq_len,
        k_seq_len=language_seq_len,
        hidden_size=language_embed_size,
        kv_in_dim=language_embed_size,
        ff_exp_factor=language_exp_factor,
        grad_acc_size=hparams.grad_acc_size,
        swiglu=True,
        vocab_size=tokenizer.vocab_size,
        count_backward=True,  # Always True regardless of freezing, because gradients are still computed for the vision adapter
        use_grad_checkpointing=hparams.gradient_checkpointing,
    )
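    # Rough sanity check (an assumption for orientation, not the exact formula inside
    # compute_tflops_per_batch_per_gpu): a dense decoder costs ~2 * n_params FLOPs per token
    # forward, and counting the backward pass roughly triples that, so the value above should be
    # on the order of 6 * n_params * batch_size * seq_len / 1e12.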
    modality_projection_tflops_per_batch_per_gpu = compute_linear_tflops_per_batch_per_gpu(
        batch_size=hparams.batch_size_per_gpu * max_num_images,
        seq_len=vision_pipeline_output_seq_len,
        in_features=vision_hidden_size,
        out_features=language_embed_size,
        count_backward=True,
        use_grad_checkpointing=hparams.gradient_checkpointing,
    )
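    # The modality projection is a single linear map from the vision hidden size to the language
    # hidden size, applied to every visual token of every image in the batch; its forward cost is
    # roughly 2 * in_features * out_features FLOPs per token (an assumption about what
    # compute_linear_tflops_per_batch_per_gpu counts).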
    vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
        num_layers=num_vision_layers,
        batch_size=hparams.batch_size_per_gpu * max_num_images,
        q_seq_len=single_image_vision_encoder_seq_len,
        k_seq_len=single_image_vision_encoder_seq_len,
        hidden_size=vision_hidden_size,
        kv_in_dim=vision_hidden_size,
        ff_exp_factor=vision_exp_factor,
        grad_acc_size=hparams.grad_acc_size,
        swiglu=False,
        vocab_size=None,
        count_backward=not hparams.model_config["freeze_vision_layers"],
        use_grad_checkpointing=hparams.gradient_checkpointing,
    )
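    # Backward FLOPs for the vision tower are only counted when its layers are trainable
    # (freeze_vision_layers=False); the forward pass is always counted, since images are
    # encoded even when the tower is frozen.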
    if self.config.use_resampler:
        perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu(
            num_layers=self.config.perceiver_config.resampler_depth,
            batch_size=hparams.batch_size_per_gpu * max_num_images,
            q_seq_len=self.config.perceiver_config.resampler_n_latents,
            vision_embed_seq_len=single_image_vision_encoder_seq_len,
            q_k_v_input_dim=vision_hidden_size,
            attention_hidden_size=self.config.perceiver_config.resampler_n_heads
            * self.config.perceiver_config.resampler_head_dim,
            ff_exp_factor=4,
            count_backward=True,
            use_grad_checkpointing=hparams.gradient_checkpointing,
        )
        tflop_count = (
            language_tflops_per_batch_per_gpu
            + modality_projection_tflops_per_batch_per_gpu
            + perceiver_tflops_per_batch_per_gpu
            + vision_tflops_per_batch_per_gpu
        )
    else:
        tflop_count = (
            language_tflops_per_batch_per_gpu
            + modality_projection_tflops_per_batch_per_gpu
            + vision_tflops_per_batch_per_gpu
        )
    return tflop_count
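
A minimal usage sketch (all of `model`, `hparams`, `data_param`, `tokenizer`, `step_time_s`, the
image/token counts, and the peak-TFLOPs figure below are hypothetical placeholders, not taken from
this file), showing how the returned per-batch TFLOPs could feed an MFU-style estimate:

# Hypothetical example: turn the per-batch TFLOPs estimate into a rough MFU figure.
tflops_per_batch = model.get_model_tflops_per_batch_per_gpu(
    hparams=hparams,
    data_param=data_param,
    tokenizer=tokenizer,
    max_num_images=8,        # assumed images per sample for this run
    max_num_tokens=2048,     # overrides data_param.max_seq_len when provided
)
gpu_peak_tflops = 312.0      # e.g. A100 BF16 peak; adjust for the actual hardware
step_time_s = 2.5            # measured wall-clock time for one such batch (placeholder)
achieved_tflops_per_s = tflops_per_batch / step_time_s
mfu = achieved_tflops_per_s / gpu_peak_tflops
print(f"~{achieved_tflops_per_s:.1f} TFLOPs/s per GPU, MFU ~= {mfu:.1%}")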