import os
import pathlib
import torch
import torch.nn.functional as F
from diffusers import FluxPipeline
from torch._inductor.package import load_package as inductor_load_package
from typing import List, Optional, Tuple
import inspect


@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] =None,
    causal: bool = False,
    # probably wrong type for these 4
    qv: Optional[float] = None,
    q_descale: Optional[float] = None,
    k_descale: Optional[float] = None,
    v_descale: Optional[float] = None,
    window_size: Optional[List[int]] = None,
    sink_token_length: int = 0,
    softcap: float = 0.0,
    num_splits: int = 1,
    # probably wrong type for this too
    pack_gqa: Optional[float] = None,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> torch.Tensor: #Tuple[torch.Tensor, torch.Tensor]:
    if window_size is None:
        window_size = (-1, -1)
    else:
        window_size = tuple(window_size)

    import flash_attn_interface

    dtype = torch.float8_e4m3fn

    sig = inspect.signature(flash_attn_interface.flash_attn_func)
    accepted = set(sig.parameters)
    all_kwargs = {
        "softmax_scale": softmax_scale,
        "causal": causal,
        "qv": qv,
        "q_descale": q_descale,
        "k_descale": k_descale,
        "v_descale": v_descale,
        "window_size": window_size,
        "sink_token_length": sink_token_length,
        "softcap": softcap,
        "num_splits": num_splits,
        "pack_gqa": pack_gqa,
        "deterministic": deterministic,
        "sm_margin": sm_margin,
    }
    kwargs = {k: v for k, v in all_kwargs.items() if k in accepted}

    outputs = flash_attn_interface.flash_attn_func(
        q.to(dtype), k.to(dtype), v.to(dtype), **kwargs,
    )
    return outputs[0]


@flash_attn_func.register_fake
def _(q, k, v, **kwargs):
    # two outputs:
    # 1. output: (batch, seq_len, num_heads, head_dim)
    # 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
    meta_q = torch.empty_like(q).contiguous()
    return meta_q #, q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)


# Copied FusedFluxAttnProcessor2_0 but using flash v3 instead of SDPA
class FlashFusedFluxAttnProcessor3_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self):
        try:
            import flash_attn_interface
        except ImportError:
            raise ImportError(
                "flash_attention v3 package is required to be installed"
            )

    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        # `sample` projections.
        qkv = attn.to_qkv(hidden_states)
        split_size = qkv.shape[-1] // 3
        query, key, value = torch.split(qkv, split_size, dim=-1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
        # `context` projections.
        if encoder_hidden_states is not None:
            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
            split_size = encoder_qkv.shape[-1] // 3
            (
                encoder_hidden_states_query_proj,
                encoder_hidden_states_key_proj,
                encoder_hidden_states_value_proj,
            ) = torch.split(encoder_qkv, split_size, dim=-1)

            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)
            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
                batch_size, -1, attn.heads, head_dim
            ).transpose(1, 2)

            if attn.norm_added_q is not None:
                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
            if attn.norm_added_k is not None:
                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

            # attention
            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        if image_rotary_emb is not None:
            from diffusers.models.embeddings import apply_rotary_emb

            query = apply_rotary_emb(query, image_rotary_emb)
            key = apply_rotary_emb(key, image_rotary_emb)

        # NB: transposes are necessary to match expected SDPA input shape
        hidden_states = flash_attn_func(
            query.transpose(1, 2),
            key.transpose(1, 2),
            value.transpose(1, 2))[0].transpose(1, 2)

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        if encoder_hidden_states is not None:
            encoder_hidden_states, hidden_states = (
                hidden_states[:, : encoder_hidden_states.shape[1]],
                hidden_states[:, encoder_hidden_states.shape[1] :],
            )

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

            return hidden_states, encoder_hidden_states
        else:
            return hidden_states


# wrapper to automatically handle CUDAGraph record / replay over the given function
def cudagraph(f):
    from torch.utils._pytree import tree_map_only

    _graphs = {}
    def f_(*args, **kwargs):
        key = hash(tuple(tuple(kwargs[a].shape) for a in sorted(kwargs.keys())
                         if isinstance(kwargs[a], torch.Tensor)))
        if key in _graphs:
            # use the cached wrapper if one exists. this will perform CUDAGraph replay
            wrapped, *_ = _graphs[key]
            return wrapped(*args, **kwargs)

        # record a new CUDAGraph and cache it for future use
        g = torch.cuda.CUDAGraph()
        in_args, in_kwargs = tree_map_only(torch.Tensor, lambda t: t.clone(), (args, kwargs))
        f(*in_args, **in_kwargs) # stream warmup
        with torch.cuda.graph(g):
            out_tensors = f(*in_args, **in_kwargs)
        def wrapped(*args, **kwargs):
            # note that CUDAGraphs require inputs / outputs to be in fixed memory locations.
            # inputs must be copied into the fixed input memory locations.
            [a.copy_(b) for a, b in zip(in_args, args) if isinstance(a, torch.Tensor)]
            for key in kwargs:
                if isinstance(kwargs[key], torch.Tensor):
                    in_kwargs[key].copy_(kwargs[key])
            g.replay()
            # clone() outputs on the way out to disconnect them from the fixed output memory
            # locations. this allows for CUDAGraph reuse without accidentally overwriting memory
            return [o.clone() for o in out_tensors]

        # cache function that does CUDAGraph replay
        _graphs[key] = (wrapped, g, in_args, in_kwargs, out_tensors)
        return wrapped(*args, **kwargs)
    return f_


def use_compile(pipeline):
    # Compile the compute-intensive portions of the model: denoising transformer / decoder
    pipeline.transformer = torch.compile(
        pipeline.transformer, mode="max-autotune", fullgraph=True
    )
    pipeline.vae.decode = torch.compile(
        pipeline.vae.decode, mode="max-autotune", fullgraph=True
    )

    # warmup for a few iterations (`num_inference_steps` shouldn't matter)
    for _ in range(3):
        pipeline(
            "dummy prompt to trigger torch compilation",
            output_type="pil",
            num_inference_steps=4,
        ).images[0]

    return pipeline


def download_hosted_file(filename, output_path):
    # Download hosted binaries from huggingface Hub.
    from huggingface_hub import hf_hub_download

    REPO_NAME = "jbschlosser/flux-fast"
    hf_hub_download(REPO_NAME, filename, local_dir=os.path.dirname(output_path))


def load_package(package_path):
    if not os.path.exists(package_path):
        download_hosted_file(os.path.basename(package_path), package_path)

    loaded_package = inductor_load_package(package_path, run_single_threaded=True)
    return loaded_package


def use_export_aoti(pipeline, cache_dir, serialize=False, is_timestep_distilled=True):
    # create cache dir if needed
    pathlib.Path(cache_dir).mkdir(parents=True, exist_ok=True)

    def _example_tensor(*shape):
        return torch.randn(*shape, device="cuda", dtype=torch.bfloat16)

    # === Transformer compile / export ===
    seq_length = 256 if is_timestep_distilled else 512
    # these shapes are for 1024x1024 resolution.
    transformer_kwargs = {
        "hidden_states": _example_tensor(1, 4096, 64),
        "timestep": torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
        "guidance": None if is_timestep_distilled else torch.tensor([1.], device="cuda", dtype=torch.bfloat16),
        "pooled_projections": _example_tensor(1, 768),
        "encoder_hidden_states": _example_tensor(1, seq_length, 4096),
        "txt_ids": _example_tensor(seq_length, 3),
        "img_ids": _example_tensor(4096, 3),
        "joint_attention_kwargs": {},
        "return_dict": False,
    }

    # Possibly serialize model out
    transformer_package_path = os.path.join(
        cache_dir, "exported_transformer.pt2" if is_timestep_distilled else "exported_dev_transformer.pt2"
    )
    if serialize:
        # Apply export
        exported_transformer: torch.export.ExportedProgram = torch.export.export(
            pipeline.transformer, args=(), kwargs=transformer_kwargs
        )

        # Apply AOTI
        path = torch._inductor.aoti_compile_and_package(
            exported_transformer,
            package_path=transformer_package_path,
            inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
        )
    # download serialized model if needed
    loaded_transformer = load_package(transformer_package_path)

    # warmup before cudagraphing
    with torch.no_grad():
        loaded_transformer(**transformer_kwargs)

    # Apply CUDAGraphs. CUDAGraphs are utilized in torch.compile with mode="max-autotune", but
    # they must be manually applied for torch.export + AOTI.
    loaded_transformer = cudagraph(loaded_transformer)
    pipeline.transformer.forward = loaded_transformer

    # warmup after cudagraphing
    with torch.no_grad():
        pipeline.transformer(**transformer_kwargs)

    # hack to get around export's limitations
    pipeline.vae.forward = pipeline.vae.decode

    vae_decode_kwargs = {"return_dict": False}

    # Possibly serialize model out
    decoder_package_path = os.path.join(
        cache_dir, "exported_decoder.pt2" if is_timestep_distilled else "exported_dev_decoder.pt2"
    )
    if serialize:
        # Apply export
        exported_decoder: torch.export.ExportedProgram = torch.export.export(
            pipeline.vae, args=(_example_tensor(1, 16, 128, 128),), kwargs=vae_decode_kwargs
        )

        # Apply AOTI
        path = torch._inductor.aoti_compile_and_package(
            exported_decoder,
            package_path=decoder_package_path,
            inductor_configs={"max_autotune": True, "triton.cudagraphs": True},
        )
    # download serialized model if needed
    loaded_decoder = load_package(decoder_package_path)

    # warmup before cudagraphing
    with torch.no_grad():
        loaded_decoder(_example_tensor(1, 16, 128, 128), **vae_decode_kwargs)

    loaded_decoder = cudagraph(loaded_decoder)
    pipeline.vae.decode = loaded_decoder

    # warmup for a few iterations
    for _ in range(3):
        pipeline(
            "dummy prompt to trigger torch compilation",
            output_type="pil",
            num_inference_steps=4,
        ).images[0]

    return pipeline


def optimize(pipeline, args):
    is_timestep_distilled = not pipeline.transformer.config.guidance_embeds

    # fuse QKV projections in Transformer and VAE
    if not args.disable_fused_projections:
        pipeline.transformer.fuse_qkv_projections()
        pipeline.vae.fuse_qkv_projections()

    # Use flash attention v3
    if not args.disable_fa3:
        pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())

    # switch memory layout to channels_last
    if not args.disable_channels_last:
        pipeline.vae.to(memory_format=torch.channels_last)

    # apply float8 quantization
    if not args.disable_quant:
        from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight #, PerRow

        quantize_(
            pipeline.transformer,
            float8_dynamic_activation_float8_weight(),
            # float8_dynamic_activation_float8_weight(granularity=PerRow()),
        )

    # set inductor flags
    if not args.disable_inductor_tuning_flags:
        config = torch._inductor.config
        config.conv_1x1_as_mm = True  # treat 1x1 convolutions as matrix muls
        config.epilogue_fusion = False  # do not fuse pointwise ops into matmuls
        # adjust autotuning algorithm
        config.coordinate_descent_tuning = True
        config.coordinate_descent_check_all_directions = True

        # TODO: Test out more mm settings
        # config.triton.enable_persistent_tma_matmul = True
        # config.max_autotune_gemm_backends = "ATEN,TRITON,CPP,CUTLASS"

    if args.compile_export_mode == "compile":
        pipeline = use_compile(pipeline)
    elif args.compile_export_mode == "export_aoti":
        pipeline = use_export_aoti(
            pipeline,
            cache_dir=args.cache_dir,
            serialize=(not args.use_cached_model),
            is_timestep_distilled=is_timestep_distilled
        )
    elif args.compile_export_mode == "disabled":
        pass
    else:
        raise RuntimeError(
            "expected compile_export_mode arg to be one of {compile, export_aoti, disabled}"
        )

    return pipeline


def load_pipeline(args):
    load_dtype = torch.float32 if args.disable_bf16 else torch.bfloat16
    pipeline = FluxPipeline.from_pretrained(args.ckpt, torch_dtype=load_dtype).to(args.device)
    pipeline.set_progress_bar_config(disable=True)
    pipeline = optimize(pipeline, args)
    return pipeline
