# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import os
import logging
import json
from abc import ABC

from ts.torch_handler.base_handler import BaseHandler
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)

class TransformersSeqGeneration(BaseHandler, ABC):
    """TorchServe handler that translates text with a T5 sequence-to-sequence model.

    Each request body is expected to provide three fields:

    * ``"text"``: the text to translate,
    * ``"from"``: the source language code,
    * ``"to"``: the target language code.

    Language codes must be keys of ``_LANG_MAP``. The handler builds a
    ``"translate <source> to <target>: <text>"`` prompt per request, runs
    ``generate`` on the loaded model and returns one ``{"text": ...}`` dict per
    request.
    """

    # Maps the language codes accepted in requests to the language names that
    # are interpolated into the translation prompt.
    _LANG_MAP = {
        "ro": "Romanian",
        "fr": "French",
        "de": "German",
        "en": "English",
    }

    def __init__(self):
        """Create an uninitialized handler; ``initialize`` loads the model."""
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Load the tokenizer and model described by the serving context.

        Args:
            ctx: Serving context exposing ``manifest`` and
                ``system_properties`` (read for ``model_dir`` and ``gpu_id``).

        Raises:
            ValueError: If ``setup_config.json`` is missing or its
                ``save_mode`` is neither ``"torchscript"`` nor
                ``"pretrained"``, so no model can be loaded.
        """
        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")
        serialized_file = self.manifest["model"]["serializedFile"]
        model_pt_path = os.path.join(model_dir, serialized_file)

        # Only target a GPU when one was actually assigned; otherwise the
        # device string would become the invalid "cuda:None".
        gpu_id = properties.get("gpu_id")
        self.device = torch.device(
            f"cuda:{gpu_id}"
            if torch.cuda.is_available() and gpu_id is not None
            else "cpu"
        )

        # Read configs for the mode, model_name, etc. from setup_config.json.
        # Start from an empty config so the attribute always exists.
        self.setup_config = {}
        setup_config_path = os.path.join(model_dir, "setup_config.json")
        if os.path.isfile(setup_config_path):
            with open(setup_config_path, encoding="utf-8") as setup_config_file:
                self.setup_config = json.load(setup_config_file)
        else:
            logger.warning("Missing the setup_config.json file.")

        # Load the tokenizer and the model from the checkpoint and config
        # files based on the configured save mode; further setup config can
        # be added.
        self.tokenizer = T5Tokenizer.from_pretrained(model_dir)
        save_mode = self.setup_config.get("save_mode")
        if save_mode == "torchscript":
            # NOTE(review): `inference` calls `self.model.generate`; confirm
            # that the serialized TorchScript module actually exposes it.
            self.model = torch.jit.load(model_pt_path, map_location=self.device)
        elif save_mode == "pretrained":
            self.model = T5ForConditionalGeneration.from_pretrained(model_dir)
        else:
            logger.warning("Missing the checkpoint or state_dict.")
            # Fail here with an explicit message instead of an unrelated
            # AttributeError on `self.model` a few lines further down.
            raise ValueError(
                f"Unsupported or missing save_mode {save_mode!r} in "
                f"{setup_config_path}; expected 'torchscript' or 'pretrained'."
            )
        self.model.to(self.device)
        self.model.eval()
        logger.info("Transformer model from path %s loaded successfully", model_dir)
        self.initialized = True

    @staticmethod
    def _as_text(value):
        """Return ``value`` as ``str``, decoding UTF-8 bytes when necessary."""
        if isinstance(value, (bytes, bytearray)):
            return value.decode("utf-8")
        return value

    def _language_name(self, code):
        """Translate a language code into the name used in the prompt."""
        try:
            return self._LANG_MAP[code]
        except KeyError:
            raise KeyError(
                f"Unsupported language code {code!r}; "
                f"expected one of {sorted(self._LANG_MAP)}"
            ) from None

    def preprocess(self, requests):
        """Turn a batch of requests into a padded tensor of input token ids.

        Args:
            requests: Iterable of request dicts. The payload is read from
                ``"body"`` (falling back to ``"data"``) and may be a dict or
                a JSON document given as ``str``/``bytes``.

        Returns:
            The tokenizer's ``input_ids`` tensor for the whole batch, moved
            to ``self.device``.

        Raises:
            KeyError: If a payload lacks ``"text"``, ``"from"`` or ``"to"``,
                or names a language that is not in ``_LANG_MAP``.
        """
        texts_batch = []
        for request in requests:
            data = request.get("body")
            if data is None:
                data = request.get("data")
            # A raw JSON payload arrives as bytes/str rather than as a dict.
            data = self._as_text(data)
            if isinstance(data, str):
                data = json.loads(data)

            # Decode every field on its own: the fields are not guaranteed
            # to share the same type.
            input_text = self._as_text(data["text"])
            src_lang = self._language_name(self._as_text(data["from"]))
            tgt_lang = self._language_name(self._as_text(data["to"]))
            texts_batch.append(f"translate {src_lang} to {tgt_lang}: {input_text}")

        # Padding is required so that prompts of different lengths can be
        # stacked into a single tensor when requests are batched.
        inputs = self.tokenizer(texts_batch, return_tensors="pt", padding=True)
        return inputs["input_ids"].to(self.device)

    def inference(self, input_batch):
        """Generate translations for a batch of input token ids.

        Args:
            input_batch: Tensor of input token ids as returned by
                ``preprocess``.

        Returns:
            A list with one decoded output string per input row.
        """
        generate_kwargs = {}
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is not None:
            # Mask out padding positions so padded rows are not affected by
            # the pad tokens added in `preprocess`.
            generate_kwargs["attention_mask"] = input_batch.ne(pad_token_id).long()
        with torch.no_grad():
            generations = self.model.generate(input_batch, **generate_kwargs)
        return self.tokenizer.batch_decode(generations, skip_special_tokens=True)

    def postprocess(self, inference_output):
        """Wrap each generated string in a ``{"text": ...}`` response dict."""
        return [{"text": text} for text in inference_output]
