def get_model_tflops_per_batch_per_gpu()

in vision/m4/models/idefics/modeling_idefics.py
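
Estimates the theoretical TFLOPs processed per batch per GPU by summing four contributions: the language model's self-attention blocks, the gated cross-attention blocks, the vision tower and, when use_resampler is set, the perceiver resampler.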


    def get_model_tflops_per_batch_per_gpu(self, hparams, data_param, tokenizer, max_num_images):
        config_vl_model = self.config

        language_embed_size = config_vl_model.hidden_size
        num_language_layers = config_vl_model.num_hidden_layers
        ffn_inner_size = config_vl_model.intermediate_size

        vision_config = self.model.vision_model.config
        if hasattr(vision_config, "vision_config"):
            vision_config = vision_config.vision_config

        # Get vision model blocks infos
        vision_patch_size = vision_config.patch_size
        vision_hidden_size = vision_config.hidden_size
        num_vision_layers = vision_config.num_hidden_layers
        # The +1 is for the CLS token
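        # e.g., a 224-pixel image with 14-pixel patches gives (224 // 14) ** 2 + 1 = 257 tokens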
        single_image_seq_len = (vision_config.image_size // vision_patch_size) ** 2 + 1
        vision_exp_factor = vision_config.intermediate_size // vision_hidden_size

        # Get language and cross-att blocks infos
        num_cross_attn_layers = num_language_layers // config_vl_model.cross_layer_interval
        language_seq_len = data_param.max_seq_len
        language_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
        cross_att_exp_factor = (ffn_inner_size // language_embed_size) if ffn_inner_size is not None else 4
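        # With the resampler, each image contributes resampler_n_latents keys/values;
        # otherwise every patch token (including CLS) does.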
        k_v_cross_attn_seq_len = (
            (self.config.perceiver_config.resampler_n_latents * max_num_images)
            if self.config.use_resampler
            else (single_image_seq_len * max_num_images)
        )

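        # Language blocks: self-attention over the full text sequence; the LM head
        # is counted here via vocab_size.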
        language_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
            num_layers=num_language_layers,
            batch_size=hparams.batch_size_per_gpu,
            q_seq_len=language_seq_len,
            k_seq_len=language_seq_len,
            hidden_size=language_embed_size,
            kv_in_dim=language_embed_size,
            ff_exp_factor=language_exp_factor,
            grad_acc_size=hparams.grad_acc_size,
            swiglu=True,
            vocab_size=tokenizer.vocab_size,
            count_backward=True,  # Always True regardless of freezing, because gradients are computed for cross-attentions
            use_grad_checkpointing=hparams.gradient_checkpointing,
        )
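        # Gated cross-attention blocks: text queries attend to image keys/values.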
        cross_attention_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
            num_layers=num_cross_attn_layers,
            batch_size=hparams.batch_size_per_gpu,
            q_seq_len=language_seq_len,
            k_seq_len=k_v_cross_attn_seq_len,
            hidden_size=language_embed_size,
            kv_in_dim=vision_hidden_size,
            ff_exp_factor=cross_att_exp_factor,
            grad_acc_size=hparams.grad_acc_size,
            swiglu=True,
            vocab_size=None,
            count_backward=True,
            use_grad_checkpointing=hparams.gradient_checkpointing,
        )
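        # Vision tower: each image is its own sequence of patch tokens; backward
        # FLOPs are skipped when the vision layers are frozen.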
        vision_tflops_per_batch_per_gpu = compute_tflops_per_batch_per_gpu(
            num_layers=num_vision_layers,
            batch_size=hparams.batch_size_per_gpu * max_num_images,
            q_seq_len=single_image_seq_len,
            k_seq_len=single_image_seq_len,
            hidden_size=vision_hidden_size,
            kv_in_dim=vision_hidden_size,
            ff_exp_factor=vision_exp_factor,
            grad_acc_size=hparams.grad_acc_size,
            swiglu=False,
            vocab_size=None,
            count_backward=not hparams.model_config["freeze_vision_layers"],
            use_grad_checkpointing=hparams.gradient_checkpointing,
        )
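        # The perceiver resampler adds cross-attention FLOPs between its latents
        # and the patch embeddings when enabled.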
        if self.config.use_resampler:
            perceiver_tflops_per_batch_per_gpu = compute_perceiver_tflops_per_batch_per_gpu(
                num_layers=self.config.perceiver_config.resampler_depth,
                batch_size=hparams.batch_size_per_gpu * max_num_images,
                q_seq_len=self.config.perceiver_config.resampler_n_latents,
                vision_embed_seq_len=single_image_seq_len,
                q_k_v_input_dim=vision_hidden_size,
                attention_hidden_size=self.config.perceiver_config.resampler_n_heads
                * self.config.perceiver_config.resampler_head_dim,
                ff_exp_factor=cross_att_exp_factor,
                count_backward=True,
                use_grad_checkpointing=hparams.gradient_checkpointing,
            )
            flop_count = (
                language_tflops_per_batch_per_gpu
                + cross_attention_tflops_per_batch_per_gpu
                + vision_tflops_per_batch_per_gpu
                + perceiver_tflops_per_batch_per_gpu
            )
        else:
            flop_count = (
                language_tflops_per_batch_per_gpu
                + cross_attention_tflops_per_batch_per_gpu
                + vision_tflops_per_batch_per_gpu
            )
        return flop_count
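
compute_tflops_per_batch_per_gpu and compute_perceiver_tflops_per_batch_per_gpu are defined elsewhere in the repo and are not shown here. As a rough guide to what such a helper has to account for, here is a minimal sketch assuming the conventional 2 FLOPs per multiply-accumulate, a backward pass of roughly twice the forward pass, and one extra forward pass under gradient checkpointing. The signature mirrors the call sites above, but the body is illustrative, not the repo's implementation:

def compute_tflops_per_batch_per_gpu_sketch(
    num_layers, batch_size, q_seq_len, k_seq_len, hidden_size, kv_in_dim,
    ff_exp_factor, grad_acc_size, swiglu, vocab_size,
    count_backward, use_grad_checkpointing,
):
    # Per-layer forward FLOPs for one sequence, counting 2 FLOPs per multiply-accumulate.
    q_proj = 2 * q_seq_len * hidden_size * hidden_size
    kv_proj = 2 * (2 * k_seq_len * kv_in_dim * hidden_size)   # K and V projections
    attn_scores = 2 * q_seq_len * k_seq_len * hidden_size     # Q @ K^T
    attn_values = 2 * q_seq_len * k_seq_len * hidden_size     # softmax(scores) @ V
    out_proj = 2 * q_seq_len * hidden_size * hidden_size
    n_ff_mats = 3 if swiglu else 2                            # SwiGLU adds a gating projection
    ffn = n_ff_mats * 2 * q_seq_len * hidden_size * (ff_exp_factor * hidden_size)
    forward = num_layers * (q_proj + kv_proj + attn_scores + attn_values + out_proj + ffn)
    if vocab_size is not None:
        forward += 2 * q_seq_len * hidden_size * vocab_size   # LM head, applied once
    # Backward is roughly twice the forward; checkpointing re-runs the forward.
    multiplier = 1 + (2 if count_backward else 0) + (1 if use_grad_checkpointing else 0)
    return forward * multiplier * batch_size * grad_acc_size / 1e12  # FLOPs -> TFLOPs

Note how the method above passes vocab_size=None for the cross-attention and vision calls so the LM head is counted exactly once, in the language-block call, and multiplies the vision batch size by max_num_images so that every image is treated as its own sequence.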