# Origin:   https://github.com/predibase/lorax
# Path:     lorax/server/lorax_server/adapters/weights.py
# License:  Apache License Version 2.0, January 2004

from abc import ABC, abstractclassmethod, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Type

import torch


@dataclass
class AdapterBatchMetadata:
    """Describes how adapters are assigned across the current batch.

    NOTE(review): the field layout suggests the batch is split into contiguous
    segments that each use a single adapter — confirm against the code that
    builds this object.
    """

    # Adapter index for each batch entry.
    # [batch_size]
    adapter_indices: torch.Tensor

    # Set of adapter indices used by the batch (presumably the distinct values
    # of adapter_indices — TODO confirm).
    # [num_adapters]
    adapter_set: Set[int]

    # Segment boundaries; presumably segment s spans
    # adapter_segments[s] : adapter_segments[s + 1] — TODO confirm.
    # [num_segments + 1]
    adapter_segments: torch.Tensor

    # [num_segments]
    # maps from segment index to adapter index, i.e.:
    # segment_indices[s] == adapter_indices[i]
    # (presumably for every position i that falls inside segment s)
    segment_indices: List[int]


class AdapterWeights(ABC):
    """Weights of a single adapter for one layer.

    Concrete subclasses report which batched representations
    (``BatchAdapterWeights`` subclasses) their weights can be loaded into.
    """

    # `classmethod` stacked on `abstractmethod` replaces the deprecated
    # `abc.abstractclassmethod` helper.
    @classmethod
    @abstractmethod
    def get_batch_types(cls) -> List[Type["BatchAdapterWeights"]]:
        """Return the batch weight classes these adapter weights feed into."""

    @property
    def speculative_tokens(self) -> int:
        """Speculative-token count for this adapter; the base class returns 0.

        Presumably overridden by adapters that propose speculative tokens.
        """
        return 0


class BatchAdapterWeights(ABC):
    """Weights of several adapters combined for one batch on one layer."""

    # Declared as an abstract *instance* method: the signature takes `self`, so
    # the previous `abstractclassmethod` decoration described the wrong kind.
    @abstractmethod
    def has_adapter(self, adapter_index: int) -> bool:
        """Return whether the adapter ``adapter_index`` is part of this batch.

        The exact contract is defined by subclasses.
        """

    @classmethod
    @abstractmethod
    def load(
        cls,
        adapter_weights: Dict[int, "AdapterWeights"],
        meta: "AdapterBatchMetadata",
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Optional["BatchAdapterWeights"]:
        """Combine per-adapter weights into a batched representation.

        Args:
            adapter_weights: adapter index -> that adapter's weights for the layer.
            meta: metadata describing adapter assignment in the batch.
            prefill: whether the batch is in its prefill step.
            prefill_head_indices: may be None; callers pass None for every
                layer other than ``lm_head``.

        Returns:
            The batched weights, or None if nothing could be batched.
        """


class LayerAdapterWeights:
    """Adapter weights that apply to a particular layer."""

    def __init__(self):
        # adapter index -> that adapter's weights for this layer
        self.adapter_weights: Dict[int, AdapterWeights] = {}

    def add_adapter(self, adapter_idx: int, weights: AdapterWeights):
        """Register (or overwrite) the weights for ``adapter_idx``."""
        self.adapter_weights[adapter_idx] = weights

    def remove_adapter(self, adapter_idx: int):
        """Drop the weights for ``adapter_idx``; no-op if it is not registered."""
        self.adapter_weights.pop(adapter_idx, None)

    def is_empty(self) -> bool:
        """Return True when no adapter weights are registered for this layer."""
        return not self.adapter_weights

    def get_data(
        self,
        meta: AdapterBatchMetadata,
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> Optional[BatchAdapterWeights]:
        """Batch this layer's adapter weights for the current batch.

        Every registered adapter is grouped under each batch type listed by its
        ``get_batch_types()``, and each group is passed to that batch type's
        ``load``.

        Args:
            meta: metadata describing adapter assignment in the batch.
            prefill: whether the batch is in its prefill step.
            prefill_head_indices: forwarded unchanged to each ``load`` call.

        Returns:
            The weights produced by the last batch type whose ``load`` returned
            a non-None value, or None when no batch type produced weights.
        """
        # bucket adapters by batch class
        adapter_batch_types: Dict[
            Type[BatchAdapterWeights], Dict[int, AdapterWeights]
        ] = defaultdict(dict)
        for adapter_index, adapter_weights in self.adapter_weights.items():
            for batch_type in adapter_weights.get_batch_types():
                adapter_batch_types[batch_type][adapter_index] = adapter_weights

        # A single weights object (not a per-type dict) is returned because
        # consumers such as AdapterBatchData.ranks() read `.rank_data` directly
        # on it. None (rather than an empty dict) signals "nothing loaded", which
        # is the sentinel ranks() checks for.
        # NOTE(review): if several batch types yield weights, only the last one
        # is kept and the earlier ones are silently dropped.
        batch_data: Optional[BatchAdapterWeights] = None
        for batch_type, adapter_weights in adapter_batch_types.items():
            batched_weights = batch_type.load(
                adapter_weights, meta, prefill, prefill_head_indices
            )
            if batched_weights is not None:
                batch_data = batched_weights
        return batch_data


@dataclass
class AdapterBatchData:
    """Adapter state for one batch: its metadata plus batched weights per layer."""

    meta: AdapterBatchMetadata

    # layer name -> batched adapter weights for that layer, as returned by
    # LayerAdapterWeights.get_data (None or an empty dict when no batch type
    # produced weights for the layer)
    data: Dict[str, Optional[BatchAdapterWeights]]

    prefill: bool

    @staticmethod
    def from_meta(
        meta: AdapterBatchMetadata,
        weights: Dict[str, LayerAdapterWeights],
        prefill: bool,
        prefill_head_indices: Optional[torch.Tensor],
    ) -> "AdapterBatchData":
        """Build batch data for every layer in ``weights`` that has adapters.

        Layers with no registered adapters are left out of ``data``. Only the
        layer named ``"lm_head"`` receives ``prefill_head_indices``; all other
        layers get None (presumably because only lm_head output is restricted
        to head positions during prefill — confirm).
        """
        data = {}
        for k, v in weights.items():
            if v.is_empty():
                continue
            data[k] = v.get_data(
                meta, prefill, prefill_head_indices if k == "lm_head" else None
            )
        return AdapterBatchData(meta=meta, data=data, prefill=prefill)

    def ranks(self) -> Set[int]:
        """Return the set of adapter ranks found across all layers."""
        # TODO(travis): refactor to be less coupled to lora implementation
        ranks = set()
        for layer_data in self.data.values():
            # Skip entries that carry no rank data: get_data can leave None or an
            # empty dict for a layer, and batch weights of other kinds may lack
            # `rank_data` entirely. Accessing it directly would raise.
            rank_data = getattr(layer_data, "rank_data", None)
            if rank_data is None:
                continue

            for rank_segment in rank_data.values():
                ranks.add(rank_segment.rank)

        return ranks

    def layer_names(self) -> Set[str]:
        """Return the names of the layers that have batch data."""
        return set(self.data.keys())

    def adapter_keys(self) -> Set[str]:
        """Return the union of the keys of every per-layer entry.

        NOTE(review): this assumes each entry is a mapping, while ranks() treats
        entries as weight objects exposing `rank_data`. Confirm which shape
        callers rely on before using this method.
        """
        adapter_keys = set()
        for layer_data in self.data.values():
            # Layers with nothing loaded contribute no keys (same sentinel ranks()
            # skips).
            if layer_data is None:
                continue
            adapter_keys.update(layer_data.keys())
        return adapter_keys

    @property
    def max_rank(self) -> int:
        """Largest rank reported by ranks(), or 0 when there are none."""
        return max(self.ranks(), default=0)
