# Copyright (c) 2025, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# We use the position encoder ops from generative-recommenders as a starting point.
# https://github.com/facebookresearch/generative-recommenders
# Thanks to the authors for their public work.


from typing import List, Optional, Tuple

import torch
import triton
import triton.language as tl
from triton.runtime.autotuner import autotune as triton_autotune

from tzrec.ops.utils import (
    autotune_max_seq_len,
    prev_power_of_2,
    switch_to_contiguous_if_needed,
)

# Register prev_power_of_2 with torch.fx so that FX symbolic tracing records
# calls to it as opaque leaf calls instead of tracing into its body.
torch.fx.wrap(prev_power_of_2)


def _add_position_embeddings_configs() -> List[triton.Config]:
    """Build the autotune search space for the forward position kernels.

    Returns the cross product of row-tile sizes ``BLOCK_N`` in (16, 32, 64),
    pipeline depths in (1, 2) and warp counts in (2, 4, 8), ordered with
    ``BLOCK_N`` varying slowest and ``num_warps`` fastest.
    """
    block_n_choices = (16, 32, 64)
    stage_choices = (1, 2)
    warp_choices = (2, 4, 8)
    return [
        triton.Config(
            {"BLOCK_N": block_n},
            num_stages=stages,
            num_warps=warps,
        )
        for block_n in block_n_choices
        for stages in stage_choices
        for warps in warp_choices
    ]


@triton_autotune(
    configs=_add_position_embeddings_configs(),
    key=["AUTOTUNE_MAX_SEQ_LEN"],
)
@triton.jit
def _add_position_embeddings_kernel(
    Jagged,
    seq_offsets,
    high_inds,
    Dense,
    Out,
    AUTOTUNE_MAX_SEQ_LEN,
    D,
    scale,
    stride_jn,
    stride_dk,
    stride_on,
    SCALE_JAGGED: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Add position embedding.

    Jagged has shape (sum_B(N_i), D),
    Dense has shape (K, D),
    Out has shape (sum_B(N_i), D)

    For row ``n`` of sequence ``b`` this computes
    ``Out[n] = Jagged[n] * scale + Dense[min(n, high_inds[b])]``
    (the multiply is applied only when SCALE_JAGGED is set).

    Program axis 0 selects the sequence ``b`` (delimited by
    ``seq_offsets[b]`` and ``seq_offsets[b + 1]``); axis 1 selects a tile of
    BLOCK_N rows inside it. The feature dimension is walked in chunks of
    BLOCK_D. Columns are addressed as ``row_ptr + offs_d``, so every tensor
    is assumed to have unit stride along D. AUTOTUNE_MAX_SEQ_LEN is not read
    in the body; it only serves as the autotune key.
    """
    off_b = tl.program_id(0)
    off_n = tl.program_id(1)
    seq_start = tl.load(seq_offsets + off_b)
    seq_end = tl.load(seq_offsets + off_b + 1)
    max_ind = tl.load(high_inds + off_b)
    seq_len = seq_end - seq_start
    start_n = off_n * BLOCK_N
    # Tiles that start past the end of this sequence have no rows to process.
    if start_n >= seq_len:
        return
    offs_n = start_n + tl.arange(0, BLOCK_N)
    # Positions at or beyond max_ind all share Dense row max_ind.
    clamped_offs_n = tl.where(offs_n >= max_ind, max_ind, offs_n)
    offs_d = tl.arange(0, BLOCK_D)
    # Rebase Jagged/Out to the first row of this sequence (int64 to avoid
    # overflowing the flat offset on large batches).
    Jagged += seq_start.to(tl.int64) * stride_jn
    jagged_ptr_offsets = offs_n[:, None] * stride_jn + offs_d[None, :]
    Out += seq_start.to(tl.int64) * stride_on
    out_ptrs = Out + offs_n[:, None] * stride_on + offs_d[None, :]
    dense_ptrs = Dense + clamped_offs_n[:, None] * stride_dk + offs_d[None, :]
    # Walk the feature dimension one BLOCK_D chunk at a time.
    for _d in range(0, D, BLOCK_D):
        # Valid rows of this sequence AND valid columns of this chunk.
        mask = (offs_n[:, None] < seq_len) and offs_d[None, :] < D
        jg = tl.load(Jagged + jagged_ptr_offsets, mask=mask)
        if SCALE_JAGGED:
            jg = jg * scale
        dn = tl.load(dense_ptrs, mask=mask)
        jg += dn
        tl.store(out_ptrs, jg, mask=mask)
        # Advance every pointer/offset to the next column chunk.
        dense_ptrs += BLOCK_D
        out_ptrs += BLOCK_D
        offs_d += BLOCK_D
        jagged_ptr_offsets += BLOCK_D


@triton.jit
def _add_position_embeddings_bwd_kernel(
    Jagged,
    seq_offsets,
    high_inds,
    DenseOut,
    JaggedOut,
    B,
    D,
    scale,
    stride_jn,
    stride_jon,
    stride_don,
    SCALE_JAGGED: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Backward of the position-embedding add.

    Program ``off_k`` produces row ``off_k`` of the dense gradient by summing,
    over all ``B`` sequences, the incoming gradient rows (``Jagged``) that the
    forward pass mapped onto dense row ``off_k``:

    * ``off_k < high_inds[b]``: the single row ``off_k`` of sequence ``b``,
      but only if the sequence actually has more than ``off_k`` rows;
    * ``off_k == high_inds[b]``: every row from ``high_inds[b]`` to the end
      of sequence ``b`` (the forward pass clamped all of them onto it).

    When SCALE_JAGGED is set, the program also writes ``Jagged * scale`` into
    ``JaggedOut`` for every row it visits, producing the jagged gradient.
    The whole feature dimension is handled as one BLOCK_D tile, so BLOCK_D
    must be >= D. Partial sums are kept in float32.
    """
    off_k = tl.program_id(0)
    offs_d = tl.arange(0, BLOCK_D)
    mask_d = offs_d < D
    accumulator = tl.zeros((BLOCK_D,), dtype=tl.float32)
    for off_b in range(0, B):
        max_ind = tl.load(high_inds + off_b)
        if off_k < max_ind:
            seq_start = tl.load(seq_offsets + off_b)
            seq_end = tl.load(seq_offsets + off_b + 1)
            # The forward pass only touches rows < seq_len. When high_inds[b]
            # exceeds the sequence length, row off_k would belong to the next
            # sequence (or lie past the end of Jagged), so skip it.
            if off_k < seq_end - seq_start:
                jagged_ptr = (
                    Jagged
                    + seq_start.to(tl.int64) * stride_jn
                    + off_k.to(tl.int64) * stride_jn
                )
                jagged_ptrs = jagged_ptr + offs_d
                jg = tl.load(
                    jagged_ptrs,
                    mask=mask_d,
                )
                accumulator += jg
                if SCALE_JAGGED:
                    out_jagged_ptr = (
                        JaggedOut
                        + seq_start.to(tl.int64) * stride_jon
                        + off_k.to(tl.int64) * stride_jon
                    )
                    out_jagged_ptrs = out_jagged_ptr + offs_d
                    tl.store(
                        out_jagged_ptrs,
                        jg * scale,
                        mask=mask_d,
                    )
        elif off_k == max_ind:
            # Every row from max_ind to the end of the sequence was clamped
            # onto this dense row in the forward pass (empty range when the
            # sequence is not longer than max_ind).
            seq_start = tl.load(seq_offsets + off_b).to(tl.int64)
            seq_end = tl.load(seq_offsets + off_b + 1)
            for k in range(seq_start + max_ind, seq_end):
                jagged_ptr = Jagged + k * stride_jn
                jagged_ptrs = jagged_ptr + offs_d
                jg = tl.load(
                    jagged_ptrs,
                    mask=mask_d,
                )
                accumulator += jg
                if SCALE_JAGGED:
                    out_jagged_ptr = JaggedOut + k * stride_jon
                    out_jagged_ptrs = out_jagged_ptr + offs_d
                    tl.store(
                        out_jagged_ptrs,
                        jg * scale,
                        mask=mask_d,
                    )
    out = accumulator.to(DenseOut.dtype.element_ty)
    out_ptrs = DenseOut + off_k * stride_don + offs_d
    tl.store(
        out_ptrs,
        out,
        mask=mask_d,
    )


class _AddPositionEmbeddingsFunction(torch.autograd.Function):
    """Autograd function adding clamped position embeddings to jagged rows.

    Forward: for row ``n`` of sequence ``b``,
    ``out[n] = jagged[n] * scale + dense[min(n, high_inds[b])]``.
    Backward: reduces ``d_out`` into a gradient for ``dense`` and returns
    ``d_out * scale`` (or ``d_out`` itself when ``scale == 1``) as the
    gradient for ``jagged``.
    """

    @staticmethod
    # pyre-ignore[14]
    def forward(
        ctx,
        jagged: torch.Tensor,
        jagged_offsets: torch.Tensor,
        high_inds: torch.Tensor,
        max_seq_len: int,
        dense: torch.Tensor,
        scale: float = 1.0,
    ):
        """Run the forward Triton kernel.

        Args:
            ctx: autograd context.
            jagged: (L, D) values of all sequences stacked row-wise.
            jagged_offsets: row offsets; entries ``b`` and ``b + 1`` delimit
                sequence ``b`` (so at least ``B + 1`` entries).
            high_inds: (B,) per-sequence clamp index into ``dense``; every
                position ``>= high_inds[b]`` reuses ``dense[high_inds[b]]``.
                Presumably must be ``< dense.size(0)`` -- TODO confirm.
            max_seq_len: must be >= every sequence length; the launch grid
                is sized from it.
            dense: (K, D) position embedding table.
            scale: factor applied to ``jagged`` before the add.

        Returns:
            (L, D) tensor with the dtype of ``jagged``.
        """
        jagged = switch_to_contiguous_if_needed(jagged)
        dense = switch_to_contiguous_if_needed(dense)
        L, D = jagged.shape
        assert len(dense.shape) == 2
        # The kernel reads D columns per dense row and backward builds the
        # dense gradient as (K, D), so the widths must agree.
        assert dense.shape[1] == D, "shape[1] of dense must match jagged"
        out = torch.empty_like(jagged)
        B = high_inds.size(0)
        grid = lambda meta: (  # noqa E731
            B,
            triton.cdiv(max_seq_len, meta["BLOCK_N"]),
        )
        # The forward kernel loops over D, so cap the column tile at 64.
        BLOCK_D = triton.next_power_of_2(D) if D < 64 else 64
        _add_position_embeddings_kernel[grid](
            Jagged=jagged,
            seq_offsets=jagged_offsets,
            high_inds=high_inds,
            Dense=dense,
            Out=out,
            AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len),
            D=D,
            scale=scale,
            stride_jn=jagged.stride(0),
            stride_dk=dense.stride(0),
            stride_on=out.stride(0),
            SCALE_JAGGED=scale != 1.0,
            BLOCK_D=BLOCK_D,
        )
        ctx.save_for_backward(jagged_offsets, high_inds)
        ctx.B = B
        ctx.D = D
        ctx.scale = scale
        ctx.K = dense.size(0)
        ctx.BLOCK_D = BLOCK_D
        return out

    @staticmethod
    # pyre-ignore[14]
    def backward(
        ctx, d_out: torch.Tensor
    ) -> Tuple[torch.Tensor, None, None, None, torch.Tensor, None]:
        """Compute gradients for ``jagged`` and ``dense``.

        Returns:
            ``(d_jagged, None, None, None, d_dense, None)`` matching the
            forward inputs; non-tensor / index inputs get ``None``.
        """
        jagged_offsets, high_inds = ctx.saved_tensors
        # The kernel addresses columns as ``row_ptr + offs_d``, so the last
        # dim must have unit stride; autograd may hand us an expanded or
        # otherwise strided gradient.
        d_out = switch_to_contiguous_if_needed(d_out)
        d_dense = torch.empty((ctx.K, ctx.D), device=d_out.device, dtype=d_out.dtype)
        scale_jagged = ctx.scale != 1.0
        # With scale == 1 the jagged gradient is d_out itself; otherwise the
        # kernel writes d_out * scale into a fresh buffer.
        d_jagged = torch.empty_like(d_out) if scale_jagged else None
        # One program per dense row; the bwd kernel covers all of D in a
        # single tile, so BLOCK_D is the full next power of two.
        _add_position_embeddings_bwd_kernel[(ctx.K,)](
            Jagged=d_out,
            seq_offsets=jagged_offsets,
            high_inds=high_inds,
            DenseOut=d_dense,
            JaggedOut=d_jagged,
            B=ctx.B,
            D=ctx.D,
            scale=ctx.scale,
            stride_jn=d_out.stride(0),
            stride_jon=d_jagged.stride(0) if d_jagged is not None else 0,
            stride_don=d_dense.stride(0),
            SCALE_JAGGED=scale_jagged,
            BLOCK_D=triton.next_power_of_2(ctx.D),
        )
        grad_jagged = d_jagged if d_jagged is not None else d_out
        return grad_jagged, None, None, None, d_dense, None


@triton_autotune(
    configs=_add_position_embeddings_configs(),
    key=["AUTOTUNE_MAX_SEQ_LEN"],
)
@triton.jit
def _add_timestamp_position_embeddings_kernel(
    SeqEmb,
    Offsets,
    Lengths,
    PosEmb,
    TsEmb,
    Out,
    TS,
    PosInds,
    TsInds,
    NumTargets,
    AUTOTUNE_MAX_SEQ_LEN,
    D,
    num_time_buckets,
    time_bucket_increments,
    time_bucket_scale,
    time_delta,
    max_contextual_seq_len,
    max_pos_ind,
    stride_sn,
    stride_pn,
    stride_tn,
    stride_on,
    TRAINING: tl.constexpr,
    HAS_MULTIPLE_TARGETS: tl.constexpr,
    INTERLEAVE_TARGETS: tl.constexpr,
    TIME_BUCKET_FN: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Add timestamp position embedding.

    SeqEmb has shape (sum_B(N_i), D),
    PosEmb has shape (N_p, D),
    TsEmb has shape (N_t, D),
    Out has shape (sum_B(N_i), D)

    For each row this writes ``Out = SeqEmb + PosEmb[pos_ind] + TsEmb[ts_ind]``
    where both indices are computed below. Program axis 0 selects the
    sequence and axis 1 a tile of BLOCK_N rows; the feature dimension is
    walked in BLOCK_D chunks assuming unit column stride. When TRAINING is
    set, the per-row pos_ind / ts_ind are also stored to PosInds / TsInds.
    AUTOTUNE_MAX_SEQ_LEN is not read in the body; it is the autotune key.
    """
    off_b = tl.program_id(0)
    off_n = tl.program_id(1)
    seq_start = tl.load(Offsets + off_b)
    seq_end = tl.load(Offsets + off_b + 1)
    seq_len = seq_end - seq_start
    start_n = off_n * BLOCK_N
    # Tiles that start past the end of this sequence have no rows to process.
    if start_n >= seq_len:
        return
    offs_n = start_n + tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)
    seq_emb_offsets = offs_n[:, None] * stride_sn + offs_d[None, :]
    SeqEmb += seq_start.to(tl.int64) * stride_sn
    # Row-validity mask from Offsets; used for the index stores and TS load.
    mask_n = offs_n < seq_len
    # position encoding
    # NOTE(review): seq_len is rebound from Lengths here and the final
    # load/store mask below uses this value, not the Offsets-derived one;
    # Lengths presumably equals seq_end - seq_start -- confirm.
    seq_len = tl.load(Lengths + off_b)
    # high_ind: the row where targets begin (seq_len when there are no
    # targets; num_targets is subtracted twice when INTERLEAVE_TARGETS).
    if HAS_MULTIPLE_TARGETS:
        num_targets = tl.load(NumTargets + off_b)
        if INTERLEAVE_TARGETS:
            high_ind = seq_len - num_targets * 2
        else:
            high_ind = seq_len - num_targets
    else:
        high_ind = seq_len
    # Rows before high_ind get high_ind - n (counting backwards from
    # high_ind); rows at or after high_ind get 0. Both are shifted by
    # max_contextual_seq_len and capped at the last PosEmb row.
    pos_inds = tl.where(offs_n < high_ind, offs_n, high_ind)
    pos_inds = high_ind - pos_inds + max_contextual_seq_len
    pos_inds = tl.where(pos_inds < max_pos_ind - 1, pos_inds, max_pos_ind - 1)
    # The first max_contextual_seq_len rows use their own row index instead.
    pos_inds = tl.where(offs_n < max_contextual_seq_len, offs_n, pos_inds)
    if TRAINING:
        tl.store(PosInds + seq_start + offs_n, pos_inds, mask=mask_n)
    pos_emb_offsets = pos_inds[:, None] * stride_pn + offs_d[None, :]
    # timestamp encoding
    ts = tl.load(TS + seq_start + offs_n, mask=mask_n)
    # The last timestamp of the sequence is the reference time; each row is
    # bucketed by its distance to it.
    query_time = tl.load(TS + seq_end - 1)
    ts = query_time - ts + time_delta
    # Floor the delta at 1e-6 so log/sqrt stay finite, then rescale.
    ts = tl.where(ts > 1e-6, ts, 1e-6) / time_bucket_increments
    # Any TIME_BUCKET_FN other than "log" falls through to sqrt.
    if TIME_BUCKET_FN == "log":
        ts = tl.log(ts)
    else:
        ts = tl.sqrt(ts)
    ts = ts * time_bucket_scale
    ts = ts.to(tl.int32)
    # Clamp to [0, num_time_buckets]; the upper bound is inclusive, so TsEmb
    # needs at least num_time_buckets + 1 rows.
    ts = tl.where(ts > 0, ts, 0)
    ts = tl.where(ts < num_time_buckets, ts, num_time_buckets)
    if TRAINING:
        tl.store(TsInds + seq_start + offs_n, ts, mask=mask_n)
    ts_emb_offsets = ts[:, None] * stride_tn + offs_d[None, :]
    Out += seq_start.to(tl.int64) * stride_on
    out_offsets = Out + offs_n[:, None] * stride_on + offs_d[None, :]
    # Walk the feature dimension one BLOCK_D chunk at a time.
    for _d in range(0, D, BLOCK_D):
        mask = (offs_n[:, None] < seq_len) and offs_d[None, :] < D
        seq_emb = tl.load(SeqEmb + seq_emb_offsets, mask=mask)
        pos_emb = tl.load(PosEmb + pos_emb_offsets, mask=mask)
        ts_emb = tl.load(TsEmb + ts_emb_offsets, mask=mask)
        # The two table rows are summed first, then cast to SeqEmb's dtype.
        tl.store(out_offsets, seq_emb + (pos_emb + ts_emb).to(seq_emb.dtype), mask=mask)
        seq_emb_offsets += BLOCK_D
        pos_emb_offsets += BLOCK_D
        ts_emb_offsets += BLOCK_D
        out_offsets += BLOCK_D
        offs_d += BLOCK_D


def bwd_pre_hook(nargs):
    """Clear the ``Out`` buffer in place before an accumulating kernel launch.

    Args:
        nargs: mapping from kernel argument name to value, as passed to a
            Triton config ``pre_hook``; must contain an ``"Out"`` tensor.
    """
    target = nargs["Out"]
    target.zero_()


def _add_embeddings_bwd_configs() -> List[triton.Config]:
    """Build the autotune search space for the embedding scatter-add backward.

    Cross product of entry-block sizes ``BLOCK`` in (32, 64, 128), pipeline
    depths in (2, 3, 4) and warp counts in (2, 4, 8), ordered with ``BLOCK``
    varying slowest. Every config zeroes ``Out`` via ``bwd_pre_hook`` first,
    because the kernel accumulates into it with atomic adds.
    """
    block_choices = (32, 64, 128)
    stage_choices = (2, 3, 4)
    warp_choices = (2, 4, 8)
    return [
        triton.Config(
            {"BLOCK": block},
            num_stages=stages,
            num_warps=warps,
            pre_hook=bwd_pre_hook,
        )
        for block in block_choices
        for stages in stage_choices
        for warps in warp_choices
    ]


@triton_autotune(
    configs=_add_embeddings_bwd_configs(),
    key=["AUTOTUNE_MAX_SEQ_LEN", "AUTOTUNE_B", "D"],
)
@triton.jit
def _add_embeddings_bwd_kernel(
    In,
    KeyInds,
    ValueInds,
    Out,
    AUTOTUNE_MAX_SEQ_LEN,
    B,
    AUTOTUNE_B,
    D,
    jagged_size,
    stride_in,
    stride_on,
    BLOCK_D: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Scatter-add rows of ``In`` into ``Out`` grouped by key.

    Entry ``i`` (``i < jagged_size``) of the parallel arrays KeyInds /
    ValueInds means ``In[ValueInds[i]]`` is added to ``Out[KeyInds[i]]``.
    Each program walks BLOCK consecutive entries, sums every run of equal
    consecutive keys locally and flushes the run with a single atomic add.
    Correctness does not depend on the order of KeyInds (every key change
    flushes), but runs are only long when equal keys are adjacent, which is
    why the caller passes sorted keys. Atomics are still needed because a run
    can straddle two programs. ``Out`` is accumulated into, so it must be
    zeroed before launch (the autotune configs attach ``bwd_pre_hook`` for
    that). The whole feature dimension is one BLOCK_D tile (BLOCK_D >= D).
    Partial sums are kept in float32 regardless of the dtype of ``In``.

    B is not read in the body; AUTOTUNE_MAX_SEQ_LEN, AUTOTUNE_B and D form the
    autotune key.
    """
    off_block = tl.program_id(0)
    offs_d = tl.arange(0, BLOCK_D)
    mask_d = offs_d < D
    # -1 marks "no run open yet"; only keys >= 0 are ever flushed.
    key_ind = -1
    key_ind = key_ind.to(KeyInds.dtype.element_ty)  # pyre-ignore[16]
    # Accumulate in float32 so runs of low-precision (e.g. bf16) gradient
    # rows do not lose precision before the atomic add.
    accumulator = tl.zeros((BLOCK_D,), dtype=tl.float32)
    for off_i in range(0, BLOCK):
        off = off_block * BLOCK + off_i
        if off < jagged_size:
            value_ind = tl.load(ValueInds + off)
            in_offset = In + value_ind.to(tl.int64) * stride_in
            jagged_in = tl.load(in_offset + offs_d, mask=mask_d).to(tl.float32)
            key_ind_new = tl.load(KeyInds + off)
            if key_ind == key_ind_new:
                accumulator += jagged_in
            else:
                # Key changed: flush the finished run, then start a new one.
                if key_ind >= 0:
                    out_offset = Out + key_ind.to(tl.int64) * stride_on
                    tl.atomic_add(
                        out_offset + offs_d,
                        accumulator.to(Out.dtype.element_ty),
                        mask=mask_d,
                        sem="relaxed",
                    )
                key_ind = key_ind_new
                accumulator = jagged_in
    # Flush the run still open at the end of this program's block.
    if key_ind >= 0:
        out_offset = Out + key_ind.to(tl.int64) * stride_on
        tl.atomic_add(
            out_offset + offs_d,
            accumulator.to(Out.dtype.element_ty),
            mask=mask_d,
            sem="relaxed",
        )


class _AddTimestampPositionEmbeddingsFunction(torch.autograd.Function):
    """Autograd function adding position and timestamp embeddings to jagged rows.

    Forward launches ``_add_timestamp_position_embeddings_kernel``. Per row,
    that kernel looks up one row of ``pos_embeddings`` and one row of
    ``ts_embeddings`` and adds both to ``seq_embeddings``. The per-row lookup
    indices are sorted and saved so backward can reduce ``d_out`` rows that
    share an index into the corresponding embedding-table gradient.
    """

    @staticmethod
    # pyre-ignore[14]
    def forward(
        ctx,
        seq_embeddings: torch.Tensor,
        seq_offsets: torch.Tensor,
        pos_embeddings: torch.Tensor,
        ts_embeddings: torch.Tensor,
        timestamps: torch.Tensor,
        max_seq_len: int,
        max_contextual_seq_len: int,
        seq_lengths: torch.Tensor,
        num_targets: Optional[torch.Tensor],
        interleave_targets: bool,
        time_bucket_fn: str,
    ):
        """Compute ``seq_embeddings + pos_emb + ts_emb`` for every jagged row.

        Args:
            ctx: autograd context.
            seq_embeddings: (N, D) jagged sequence embeddings.
            seq_offsets: row offsets; entries ``b`` and ``b + 1`` delimit
                sequence ``b``.
            pos_embeddings: (N_p, D) position embedding table.
            ts_embeddings: (N_t, D) timestamp-bucket table. The kernel is
                launched with ``num_time_buckets=2048`` and may read row
                2048, so presumably N_t >= 2049 -- TODO confirm.
            timestamps: per-row timestamps aligned with ``seq_embeddings``.
            max_seq_len: must be >= every sequence length; sizes the grid.
            max_contextual_seq_len: number of leading rows that use their own
                row index as the position index.
            seq_lengths: (B,) sequence lengths; B is taken from its size.
            num_targets: optional (B,) per-sequence target counts.
            interleave_targets: subtract ``2 * num_targets`` instead of
                ``num_targets`` when locating where targets begin.
            time_bucket_fn: ``"log"`` selects log bucketing; any other value
                uses sqrt.

        Returns:
            (N, D) tensor with the dtype of ``seq_embeddings``.
        """
        seq_embeddings = switch_to_contiguous_if_needed(seq_embeddings)
        pos_embeddings = switch_to_contiguous_if_needed(pos_embeddings)
        ts_embeddings = switch_to_contiguous_if_needed(ts_embeddings)

        max_pos_ind = pos_embeddings.shape[0]
        B = seq_lengths.shape[0]
        N, D = seq_embeddings.shape
        assert len(pos_embeddings.shape) == 2
        assert len(ts_embeddings.shape) == 2
        assert pos_embeddings.shape[1] == D, (
            "shape[1] of pos_embeddings must match seq_embeddings"
        )
        assert ts_embeddings.shape[1] == D, (
            "shape[1] of ts_embeddings must match seq_embeddings"
        )
        out = torch.empty_like(seq_embeddings)

        timestamps = switch_to_contiguous_if_needed(timestamps)
        # Per-row table indices written by the kernel (TRAINING=True) and
        # consumed, after sorting, by backward.
        ts_inds = torch.empty_like(seq_embeddings[:, 0], dtype=torch.int32)
        pos_inds = torch.empty_like(seq_embeddings[:, 0], dtype=torch.int32)

        grid = lambda meta: (  # noqa E731
            B,
            triton.cdiv(max_seq_len, meta["BLOCK_N"]),
        )
        # The forward kernel loops over D, so cap the column tile at 64.
        BLOCK_D = triton.next_power_of_2(D) if D < 64 else 64
        _add_timestamp_position_embeddings_kernel[grid](
            SeqEmb=seq_embeddings,
            Offsets=seq_offsets,
            Lengths=seq_lengths,
            PosEmb=pos_embeddings,
            TsEmb=ts_embeddings,
            Out=out,
            TS=timestamps,
            PosInds=pos_inds,
            TsInds=ts_inds,
            NumTargets=num_targets,
            AUTOTUNE_MAX_SEQ_LEN=autotune_max_seq_len(max_seq_len),
            D=D,
            num_time_buckets=2048,
            time_bucket_increments=60.0,
            time_bucket_scale=1.0,
            time_delta=0,
            max_contextual_seq_len=max_contextual_seq_len,
            max_pos_ind=max_pos_ind,
            stride_sn=seq_embeddings.stride(0),
            stride_pn=pos_embeddings.stride(0),
            stride_tn=ts_embeddings.stride(0),
            stride_on=out.stride(0),
            TRAINING=True,
            HAS_MULTIPLE_TARGETS=num_targets is not None,
            INTERLEAVE_TARGETS=interleave_targets,
            TIME_BUCKET_FN=time_bucket_fn,
            BLOCK_D=BLOCK_D,
        )
        # Sort (index, row) pairs so equal indices are adjacent for the
        # backward reduction. ``hammer.sort_kv_pairs`` is an optional custom
        # op; deliberately fall back to torch.sort when it is unavailable or
        # fails for any reason.
        try:
            values = torch.arange(0, N, dtype=torch.int32, device=timestamps.device)
            sorted_ts_key_inds, sorted_ts_value_inds = torch.ops.hammer.sort_kv_pairs(
                ts_inds, values
            )
            sorted_pos_key_inds, sorted_pos_value_inds = torch.ops.hammer.sort_kv_pairs(
                pos_inds, values
            )
        except Exception:
            sorted_ts_key_inds, sorted_ts_value_inds = torch.sort(ts_inds)
            sorted_pos_key_inds, sorted_pos_value_inds = torch.sort(pos_inds)
        ctx.save_for_backward(
            sorted_pos_key_inds,
            sorted_pos_value_inds,
            sorted_ts_key_inds,
            sorted_ts_value_inds,
        )
        ctx.B = B
        ctx.D = D
        ctx.max_seq_len = max_seq_len
        ctx.pos_emb_size = pos_embeddings.shape[0]
        ctx.ts_emb_size = ts_embeddings.shape[0]
        ctx.pos_dtype = pos_embeddings.dtype
        ctx.ts_dtype = ts_embeddings.dtype
        return out

    @staticmethod
    # pyre-ignore[14]
    def backward(
        ctx, d_out: torch.Tensor
    ) -> Tuple[
        torch.Tensor,
        None,
        torch.Tensor,
        torch.Tensor,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
    ]:
        """Reduce ``d_out`` into gradients for the two embedding tables.

        ``d_out`` is returned unchanged as the gradient of ``seq_embeddings``;
        the table gradients are accumulated in float32 and cast back to the
        tables' dtypes. All other inputs receive ``None``.
        """
        (
            sorted_pos_key_inds,
            sorted_pos_value_inds,
            sorted_ts_key_inds,
            sorted_ts_value_inds,
        ) = ctx.saved_tensors
        # The kernel reads columns as ``row_ptr + offs_d``, so the last dim of
        # the incoming gradient must have unit stride; autograd may pass an
        # expanded or otherwise strided tensor.
        d_out = switch_to_contiguous_if_needed(d_out)
        # torch.empty is fine: the bwd kernel's autotune configs zero ``Out``
        # in their pre-hook before every launch.
        d_pos_embeddings = torch.empty(
            (ctx.pos_emb_size, ctx.D), device=d_out.device, dtype=torch.float32
        )
        d_ts_embeddings = torch.empty(
            (ctx.ts_emb_size, ctx.D), device=d_out.device, dtype=torch.float32
        )
        jagged_size = d_out.shape[0]
        grid = lambda meta: (triton.cdiv(jagged_size, meta["BLOCK"]),)  # noqa E731
        B = ctx.B
        AUTOTUNE_B = prev_power_of_2(B)
        autotune_seq_len = autotune_max_seq_len(ctx.max_seq_len)
        # The bwd kernel has no D loop: one tile must cover the full width.
        BLOCK_D = triton.next_power_of_2(ctx.D)
        for key_inds, value_inds, d_table in (
            (sorted_pos_key_inds, sorted_pos_value_inds, d_pos_embeddings),
            (sorted_ts_key_inds, sorted_ts_value_inds, d_ts_embeddings),
        ):
            _add_embeddings_bwd_kernel[grid](
                In=d_out,
                KeyInds=key_inds,
                ValueInds=value_inds,
                Out=d_table,
                AUTOTUNE_MAX_SEQ_LEN=autotune_seq_len,
                B=B,
                AUTOTUNE_B=AUTOTUNE_B,
                D=ctx.D,
                jagged_size=jagged_size,
                stride_in=d_out.stride(0),
                stride_on=d_table.stride(0),
                BLOCK_D=BLOCK_D,
            )
        return (
            d_out,
            None,
            d_pos_embeddings.to(ctx.pos_dtype),
            d_ts_embeddings.to(ctx.ts_dtype),
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


@torch.fx.wrap
def triton_add_position_embeddings(
    jagged: torch.Tensor,
    jagged_offsets: torch.Tensor,
    high_inds: torch.Tensor,
    max_seq_len: int,
    dense: torch.Tensor,
    scale: float = 1.0,
) -> torch.Tensor:
    """Add clamped position embeddings from ``dense`` to jagged sequences.

    Entry point around ``_AddPositionEmbeddingsFunction``, registered with
    ``torch.fx.wrap`` so FX symbolic tracing keeps it as a single leaf call.
    Row ``n`` of sequence ``b`` becomes
    ``jagged[n] * scale + dense[min(n, high_inds[b])]``.

    Args:
        jagged: (L, D) values of all sequences stacked row-wise.
        jagged_offsets: row offsets delimiting each sequence.
        high_inds: (B,) per-sequence clamp index into ``dense``.
        max_seq_len: upper bound on any sequence length.
        dense: (K, D) position embedding table.
        scale: factor applied to ``jagged`` before the add.

    Returns:
        (L, D) tensor with the dtype of ``jagged``.
    """
    forwarded = (jagged, jagged_offsets, high_inds, max_seq_len, dense, scale)
    return _AddPositionEmbeddingsFunction.apply(*forwarded)


@torch.fx.wrap
def triton_add_timestamp_positional_embeddings(
    seq_embeddings: torch.Tensor,
    seq_offsets: torch.Tensor,
    pos_embeddings: torch.Tensor,
    ts_embeddings: torch.Tensor,
    timestamps: torch.Tensor,
    max_seq_len: int,
    max_contextual_seq_len: int,
    seq_lengths: torch.Tensor,
    num_targets: Optional[torch.Tensor],
    interleave_targets: bool,
    time_bucket_fn: str,
) -> torch.Tensor:
    """Add position and timestamp embeddings to jagged sequence embeddings.

    Entry point around ``_AddTimestampPositionEmbeddingsFunction``, registered
    with ``torch.fx.wrap`` so FX symbolic tracing keeps it as a single leaf
    call. All arguments are forwarded positionally in declaration order.

    Returns:
        (N, D) tensor: ``seq_embeddings`` plus one position-table row and one
        timestamp-table row per jagged row.
    """
    forwarded = (
        seq_embeddings,
        seq_offsets,
        pos_embeddings,
        ts_embeddings,
        timestamps,
        max_seq_len,
        max_contextual_seq_len,
        seq_lengths,
        num_targets,
        interleave_targets,
        time_bucket_fn,
    )
    return _AddTimestampPositionEmbeddingsFunction.apply(*forwarded)
