import torch
from torch import nn
from transformers import Idefics3Model, Idefics3ForConditionalGeneration
from typing import Dict, Any, List, Optional, Union, Tuple
from transformers.cache_utils import Cache, DynamicCache

from transformers.utils import add_start_docstrings_to_model_forward, logging
from transformers.models.idefics3.modeling_idefics3 import IDEFICS3_INPUTS_DOCSTRING, Idefics3BaseModelOutputWithPast

logger = logging.get_logger(__name__)

class SmolVLMModel(Idefics3Model):
    """
    A subclass of Idefics3Model with a custom, out-of-place `inputs_merger`.

    `forward` still calls `inputs_merger`; this class overrides it so that image
    embeddings are spliced into the text embeddings with slicing and `torch.cat`
    rather than by indexed assignment.
    """

    def inputs_merger(
        self,
        input_ids: torch.LongTensor,
        inputs_embeds: torch.Tensor,
        image_hidden_states: torch.Tensor
    ) -> torch.Tensor:
        """
        Merge text embeddings with image embeddings out-of-place (no in-place indexing).

        Shapes:
          - input_ids:           (B, T)
          - inputs_embeds:       (B, T, D)
          - image_hidden_states: (N, S, D) where N is the total number of images
            across the batch, S is the number of embeddings per image and D is
            the embedding dim.

        Logic:
          1) For each sample in the batch, find the positions equal to
             `self.image_token_id`.
          2) If there are none, the sample is text-only. A zero-length slice of
             `image_hidden_states` is concatenated to the text embeddings and no
             image block is consumed. The slice adds no rows; it only keeps
             `image_hidden_states` referenced by the output (noted by the
             original author as important for DeepSpeed).
          3) Otherwise the <image> positions are grouped, in order, into chunks
             of S. Each chunk consumes the next block `image_hidden_states[k]`
             (shape (S, D)) and row i of the block replaces the i-th <image>
             token of the chunk.

        Args:
            input_ids: Token ids used to locate <image> placeholders.
            inputs_embeds: Text embeddings aligned with `input_ids`.
            image_hidden_states: Image embeddings, one (S, D) block per image, in
                the order the images appear across the batch.

        Returns:
            A tensor of shape (B, T, D).

        Raises:
            ValueError: If the text and image embedding dims differ, if a sample's
                number of <image> tokens is not a multiple of S, or if the batch
                references more images than `image_hidden_states` contains.
        """
        _, T, D_text = inputs_embeds.shape
        N, S, D_img = image_hidden_states.shape
        if D_text != D_img:
            raise ValueError(
                f"Text embedding dim {D_text} != image embedding dim {D_img}"
            )

        # Index of the next unused image block, shared across the whole batch.
        image_offset = 0

        # One merged (T, D) tensor per batch sample.
        merged_outputs: List[torch.Tensor] = []

        for b_idx, (cur_ids, cur_embeds) in enumerate(zip(input_ids, inputs_embeds)):
            # Positions of <image> tokens in this sample, in ascending order.
            image_positions = (cur_ids == self.image_token_id).nonzero(as_tuple=True)[0]
            num_image_tokens = len(image_positions)

            if num_image_tokens == 0:
                # Text-only sample: consume no image block, but concatenate a
                # zero-length slice so image_hidden_states stays in the graph.
                # NOTE: this is important for DeepSpeed.
                # Slicing the image axis (instead of indexing block 0) also
                # works when there are no image blocks at all (N == 0).
                empty_slice = image_hidden_states[:0].reshape(0, D_img)  # shape (0, D)
                merged_outputs.append(torch.cat([cur_embeds, empty_slice], dim=0))
                continue

            # Each image contributes exactly S embeddings, so the <image> tokens
            # of a sample must come in multiples of S.
            if num_image_tokens % S != 0:
                raise ValueError(
                    f"Sample {b_idx} has {num_image_tokens} <image> tokens, not a multiple of S={S}. "
                    "Cannot map them to blocks of shape (S, D)."
                )

            # Fail with a clear message instead of an opaque IndexError when the
            # text references more images than were encoded.
            num_images_in_sample = num_image_tokens // S
            if image_offset + num_images_in_sample > N:
                raise ValueError(
                    f"Sample {b_idx} needs {num_images_in_sample} image block(s) starting at "
                    f"offset {image_offset}, but image_hidden_states only has {N}."
                )

            positions_list = image_positions.tolist()

            # Alternating segments: text slice, image row, text slice, ...
            segments = []
            text_start = 0

            # Example: num_image_tokens=162 and S=81 => 2 chunks of 81 => 2 images.
            for chunk_start in range(0, num_image_tokens, S):
                cur_block = image_hidden_states[image_offset]  # shape (S, D)
                image_offset += 1

                chunk = positions_list[chunk_start : chunk_start + S]
                for i_s, pos in enumerate(chunk):
                    # Text between the previous <image> token and this one.
                    if pos > text_start:
                        segments.append(cur_embeds[text_start:pos])
                    # One image row, kept 2-D as (1, D), replaces the token.
                    segments.append(cur_block[i_s : i_s + 1, :])
                    # Skip over the replaced <image> token.
                    text_start = pos + 1

            # Leftover text after the final <image> token.
            if text_start < T:
                segments.append(cur_embeds[text_start:])

            # Every <image> token was replaced one-for-one, so this is (T, D).
            merged_outputs.append(torch.cat(segments, dim=0))

        return torch.stack(merged_outputs)


    @add_start_docstrings_to_model_forward(
        """
        Inputs fed to the model can have an arbitrary number of images. To account for this, pixel_values fed to
        the model have image padding -> (batch_size, max_num_images, 3, max_heights, max_widths) where
        max_num_images is the maximum number of images among the batch_size samples in the batch.
        Padding images are not needed beyond padding the pixel_values at the entrance of the model.
        For efficiency, we only pass through the vision_model's forward the real images by
        discarding the padding images i.e. pixel_values of size (image_batch_size, 3, height, width) where
        image_batch_size would be 7 when num_images_per_sample=[1, 3, 1, 2] and max_num_images would be 3.
        """,
        IDEFICS3_INPUTS_DOCSTRING,
    )
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
        image_hidden_states: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Idefics3BaseModelOutputWithPast]:
        # Fall back to the config for any flag the caller left unset.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.training and self.text_model.gradient_checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        # retrieve input_ids and inputs_embeds
        if input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Number of tokens already held in the cache; stays 0 when caching is off.
        past_seen_tokens = 0
        if use_cache:
            if past_key_values is None:
                past_key_values = DynamicCache()
            past_seen_tokens = past_key_values.get_seq_length()

        # On the first call input_ids are required so <image> tokens can be located.
        if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0:
            raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.")

        if inputs_embeds is None:
            inputs_embeds = self.text_model.get_input_embeddings()(input_ids).to(self.device)

        # START VISUAL INPUTS INTEGRATION
        if pixel_values is not None and image_hidden_states is not None:
            raise ValueError("You cannot specify both pixel_values and image_hidden_states at the same time")
        elif pixel_values is not None:
            batch_size, num_images, num_channels, height, width = pixel_values.shape
            pixel_values = pixel_values.to(dtype=self.dtype)  # fp16 compatibility
            # Flatten (batch, image) into a single image axis.
            pixel_values = pixel_values.view(batch_size * num_images, *pixel_values.shape[2:])

            # Remove padding images - padding images are full 0.
            nb_values_per_image = pixel_values.shape[1:].numel()
            real_images_inds = (pixel_values == 0.0).sum(dim=(-1, -2, -3)) != nb_values_per_image

            # Single tensor reduction; the builtin any() would iterate the
            # tensor element by element in Python.
            if not real_images_inds.any():
                # no images, leave one empty image.
                real_images_inds[0] = True

            pixel_values = pixel_values[real_images_inds].contiguous()

            # Handle the vision attention mask
            if pixel_attention_mask is None:
                # No mask given: treat every pixel of every kept image as valid.
                pixel_attention_mask = torch.ones(
                    size=(pixel_values.size(0), pixel_values.size(2), pixel_values.size(3)),
                    dtype=torch.bool,
                    device=pixel_values.device,
                )
            else:
                # Remove padding images from the mask
                pixel_attention_mask = pixel_attention_mask.view(
                    batch_size * num_images, *pixel_attention_mask.shape[2:]
                )
                pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

            # Reduce the pixel mask to a patch mask: a patch is attended to if
            # any pixel inside its patch_size x patch_size window is unmasked.
            patch_size = self.config.vision_config.patch_size
            patches_subgrid = pixel_attention_mask.unfold(dimension=1, size=patch_size, step=patch_size)
            patches_subgrid = patches_subgrid.unfold(dimension=2, size=patch_size, step=patch_size)
            patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

            # Get sequence from the vision encoder
            image_hidden_states = self.vision_model(
                pixel_values=pixel_values,
                patch_attention_mask=patch_attention_mask,
            ).last_hidden_state

            # Modality projection & resampling
            image_hidden_states = self.connector(image_hidden_states)

        elif image_hidden_states is not None:
            # input_ids may be None when a later step is driven by inputs_embeds
            # alone; fall back to the embeddings' device instead of failing on
            # `None.device`.
            target_device = input_ids.device if input_ids is not None else inputs_embeds.device
            image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=target_device)

        if past_seen_tokens == 0 and inputs_embeds is not None and image_hidden_states is not None:
            # When we generate, we don't want to replace the potential image_token_id that we generated by images
            # that simply don't exist
            inputs_embeds = self.inputs_merger(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                image_hidden_states=image_hidden_states,
            )

        outputs = self.text_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not return_dict:
            # Tuple output: text model outputs followed by the image states, Nones dropped.
            return tuple(v for v in [*outputs, image_hidden_states] if v is not None)

        return Idefics3BaseModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_hidden_states,
        )




class SmolVLMForConditionalGeneration(Idefics3ForConditionalGeneration):
    """
    Idefics3ForConditionalGeneration variant whose backbone is SmolVLMModel.

    After the parent constructor has run, `self.model` is replaced with a
    SmolVLMModel and `self.lm_head` is re-created as a bias-free linear layer
    sized from the text config (hidden_size -> vocab_size).
    """

    def __init__(self, config):
        super().__init__(config)

        # Replace the backbone with the custom subclass.
        self.model = SmolVLMModel(config)

        # Re-create the output projection from the text config's dimensions.
        text_config = config.text_config
        self.lm_head = nn.Linear(
            in_features=text_config.hidden_size,
            out_features=text_config.vocab_size,
            bias=False,
        )

        # Run post_init() again now that both submodules have been replaced.
        self.post_init()