def get_model_tflops_per_batch_per_gpu()

in vision/m4/models/vllama3/modeling_vllama3.py


    def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images, max_num_tokens=None):
        config_vl_model = self.config

        language_embed_size = config_vl_model.hidden_size
        num_language_layers = config_vl_model.num_hidden_layers
        ffn_inner_size = config_vl_model.intermediate_size

        vision_config = config_vl_model.vision_config

        # Get vision model block info
        vision_patch_size = vision_config.patch_size
        vision_hidden_size = vision_config.embed_dim
        num_vision_layers = vision_config.num_hidden_layers
        # Number of patch tokens per image after pixel-shuffle downsampling
        # (no CLS token is added here, so there is no +1)
        single_image_vision_encoder_seq_len = int(
            ((vision_config.image_size // vision_patch_size) ** 2) // (self.config.pixel_shuffle_factor**2)
        )
        vision_exp_factor = vision_config.intermediate_size // vision_hidden_size

        # Get language model block info
        language_seq_len = max_num_tokens if max_num_tokens is not None else data_param.max_seq_len
        language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4

        # Get modality projection info
        vision_pipeline_output_seq_len = (
            self.config.perceiver_config.resampler_n_latents
            if self.config.use_resampler
            else single_image_vision_encoder_seq_len
        )

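        # FLOPs for the language-model transformer over the full token sequence (including the LM head via vocab_size)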
        language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
            num_layers=num_language_layers,
            batch_size=hparams.batch_size_per_gpu,
            q_seq_len=language_seq_len,
            k_seq_len=language_seq_len,
            hidden_size=language_embed_size,
            kv_in_dim=language_embed_size,
            ff_exp_factor=language_exp_factor,
            grad_acc_size=hparams.grad_acc_size,
            swiglu=True,
            vocab_size=tokenizer.vocab_size,
            count_backward=True,  # Always True, even if the language model is frozen: the backward pass still runs through it to reach the trainable vision adapter
            use_grad_checkpointing=hparams.gradient_checkpointing,
        )
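        # FLOPs for the linear modality projection that maps vision features into the language embedding space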
        modality_projection_tflops_per_batch_per_gpu = compute_linear_tflops_per_batch_per_gpu(
            batch_size=hparams.batch_size_per_gpu * max_num_images,
            seq_len=vision_pipeline_output_seq_len,
            in_features=vision_hidden_size,
            out_features=language_embed_size,
            count_backward=True,
            use_grad_checkpointing=hparams.gradient_checkpointing,
        )

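        # FLOPs for the vision encoder, run once per image (backward is skipped when the vision layers are frozen)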
        vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
            num_layers=num_vision_layers,
            batch_size=hparams.batch_size_per_gpu * max_num_images,
            q_seq_len=single_image_vision_encoder_seq_len,
            k_seq_len=single_image_vision_encoder_seq_len,
            hidden_size=vision_hidden_size,
            kv_in_dim=vision_hidden_size,
            ff_exp_factor=vision_exp_factor,
            grad_acc_size=hparams.grad_acc_size,
            swiglu=False,
            vocab_size=None,
            count_backward=not hparams.model_config["freeze_vision_layers"],
            use_grad_checkpointing=hparams.gradient_checkpointing,
        )
        if self.config.use_resampler:
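            # FLOPs for the perceiver resampler, which pools each image's tokens into resampler_n_latents latents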
            perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu(
                num_layers=self.config.perceiver_config.resampler_depth,
                batch_size=hparams.batch_size_per_gpu * max_num_images,
                q_seq_len=self.config.perceiver_config.resampler_n_latents,
                vision_embed_seq_len=single_image_vision_encoder_seq_len,
                q_k_v_input_dim=vision_hidden_size,
                attention_hidden_size=self.config.perceiver_config.resampler_n_heads
                * self.config.perceiver_config.resampler_head_dim,
                ff_exp_factor=4,
                count_backward=True,
                use_grad_checkpointing=hparams.gradient_checkpointing,
            )
            tflop_count = (
                language_tflops_per_batch_per_gpu
                + modality_projection_tflops_per_batch_per_gpu
                + perceiver_tflops_per_batch_per_gpu
                + vision_tflops_per_batch_per_gpu
            )
        else:
            tflop_count = (
                language_tflops_per_batch_per_gpu
                + modality_projection_tflops_per_batch_per_gpu
                + vision_tflops_per_batch_per_gpu
            )
        return tflop_count
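
The compute_tflops_per_batch_per_gpu, compute_linear_tflops_per_batch_per_gpu, and compute_perceiver_tflops_per_batch_per_gpu helpers are imported from elsewhere in the repository and are not shown here. The sketch below is a minimal, illustrative version of the standard dense-transformer FLOPs accounting such a helper typically performs; the function name and exact term grouping are assumptions, not the repository's implementation, and the real helpers additionally handle grad_acc_size, distinct query/key sequence lengths, and kv_in_dim.

    def approx_transformer_tflops(num_layers, batch_size, seq_len, hidden_size,
                                  ff_exp_factor, swiglu=False, vocab_size=None,
                                  count_backward=True, use_grad_checkpointing=False):
        # A matmul of shapes (s, d) x (d, k) costs roughly 2*s*d*k FLOPs (multiply + add).
        qkvo = 4 * 2 * seq_len * hidden_size * hidden_size            # Q, K, V, and output projections
        attn = 2 * 2 * seq_len * seq_len * hidden_size                # QK^T scores and attention-weighted values
        ffn_matmuls = 3 if swiglu else 2                              # SwiGLU adds a gate projection
        ffn = ffn_matmuls * 2 * seq_len * hidden_size * (ff_exp_factor * hidden_size)
        flops = num_layers * (qkvo + attn + ffn)
        if vocab_size is not None:
            flops += 2 * seq_len * hidden_size * vocab_size           # LM head
        flops *= batch_size
        multiplier = 1.0
        if count_backward:
            multiplier += 2.0                                         # backward pass costs ~2x the forward
        if use_grad_checkpointing:
            multiplier += 1.0                                         # checkpointing recomputes one extra forward
        return flops * multiplier / 1e12                              # FLOPs -> TFLOPs

Under the same conventions, a single linear layer such as the modality projection contributes roughly 2 * seq_len * in_features * out_features FLOPs per sample, before the backward and checkpointing multipliers.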
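
One common use of the returned estimate is to divide it by the measured optimizer-step time to report achieved throughput. A minimal sketch, assuming a step_time_seconds measurement taken in the training loop (the variable names here are illustrative, not from the repository):

    # Hypothetical usage: turn the theoretical per-batch estimate into achieved TFLOPs/s per GPU.
    tflops_per_batch = model.get_model_tflops_per_batch_per_gpu(
        hparams, data_param, tokenizer, max_num_images=max_num_images
    )
    print(f"Achieved throughput: {tflops_per_batch / step_time_seconds:.1f} TFLOPs/s per GPU")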