# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import time
from typing import Union

import mlflow
import numpy as np
import torch.multiprocessing as mp
from ml_collections import config_dict

import pyrit.auxiliary_attacks.gcg.attack.gcg.gcg_attack as attack_lib
from pyrit.auxiliary_attacks.gcg.attack.base.attack_manager import (
    IndividualPromptAttack,
    ProgressiveMultiPromptAttack,
    get_goals_and_targets,
    get_workers,
)
from pyrit.auxiliary_attacks.gcg.experiments.log import (
    log_gpu_memory,
    log_params,
    log_train_goals,
)


class GreedyCoordinateGradientAdversarialSuffixGenerator:
    def __init__(self):
        if mp.get_start_method(allow_none=True) != "spawn":
            mp.set_start_method("spawn")

    def generate_suffix(
        self,
        *,
        token: str = "",
        tokenizer_paths: list = [],
        model_name: str = "",
        model_paths: list = [],
        conversation_templates: list = [],
        result_prefix: str = "",
        train_data: str = "",
        control_init: str = "! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !",
        n_train_data: int = 50,
        n_steps: int = 500,
        test_steps: int = 50,
        batch_size: int = 512,
        transfer: bool = False,
        target_weight: float = 1.0,
        control_weight: float = 0.0,
        progressive_goals: bool = False,
        progressive_models: bool = False,
        anneal: bool = False,
        incr_control: bool = False,
        stop_on_success: bool = False,
        verbose: bool = True,
        allow_non_ascii: bool = False,
        num_train_models: int = 1,
        devices: list = ["cuda:0"],
        model_kwargs: list = [{"low_cpu_mem_usage": True, "use_cache": False}],
        tokenizer_kwargs: list = [{"use_fast": False}],
        n_test_data: int = 0,
        test_data: str = "",
        lr: float = 0.01,
        topk: int = 256,
        temp: int = 1,
        filter_cand: bool = True,
        gbda_deterministic: bool = True,
        logfile: str = "",
        random_seed: int = 42,
    ):

        params = config_dict.ConfigDict()
        params.result_prefix = result_prefix
        params.train_data = train_data
        params.control_init = control_init
        params.n_train_data = n_train_data
        params.n_steps = n_steps
        params.test_steps = test_steps
        params.batch_size = batch_size
        params.transfer = transfer
        params.target_weight = target_weight
        params.control_weight = control_weight
        params.progressive_goals = progressive_goals
        params.progressive_models = progressive_models
        params.anneal = anneal
        params.incr_control = incr_control
        params.stop_on_success = stop_on_success
        params.verbose = verbose
        params.allow_non_ascii = allow_non_ascii
        params.num_train_models = num_train_models
        params.tokenizer_paths = tokenizer_paths
        params.tokenizer_kwargs = tokenizer_kwargs
        params.model_name = model_name
        params.model_paths = model_paths
        params.model_kwargs = model_kwargs
        params.conversation_templates = conversation_templates
        params.devices = devices
        params.n_test_data = n_test_data
        params.test_data = test_data
        params.lr = lr
        params.topk = topk
        params.temp = temp
        params.filter_cand = filter_cand
        params.gbda_deterministic = gbda_deterministic
        params.token = token
        params.logfile = logfile
        params.random_seed = random_seed
        print(params)

        # Start mlflow logging
        mlflow.start_run()
        log_gpu_memory(step=0)
        log_params(params=params)

        train_goals, train_targets, test_goals, test_targets = get_goals_and_targets(params)
        log_train_goals(train_goals=train_goals)

        def process_fn(s):
            return s.replace("Sure, h", "H")

        def process_fn2(s):
            return s.replace("Sure, here is", "Sure, here's")

        train_targets = [process_fn(t) if np.random.random() < 0.5 else process_fn2(t) for t in train_targets]
        test_targets = [process_fn(t) if np.random.random() < 0.5 else process_fn2(t) for t in test_targets]

        workers, test_workers = get_workers(params)
        managers = {
            "AP": attack_lib.GCGAttackPrompt,
            "PM": attack_lib.GCGPromptManager,
            "MPA": attack_lib.GCGMultiPromptAttack,
        }

        timestamp = time.strftime("%Y%m%d-%H%M%S")
        attack: Union[ProgressiveMultiPromptAttack, IndividualPromptAttack]
        if params.transfer:
            attack = ProgressiveMultiPromptAttack(
                train_goals,
                train_targets,
                workers,
                progressive_models=params.progressive_models,
                progressive_goals=params.progressive_goals,
                control_init=params.control_init,
                logfile=f"{params.result_prefix}_{timestamp}.json",
                managers=managers,
                test_goals=test_goals,
                test_targets=test_targets,
                test_workers=test_workers,
                mpa_deterministic=params.gbda_deterministic,
                mpa_lr=params.lr,
                mpa_batch_size=params.batch_size,
                mpa_n_steps=params.n_steps,
            )
        else:
            attack = IndividualPromptAttack(
                train_goals,
                train_targets,
                workers,
                control_init=params.control_init,
                logfile=f"{params.result_prefix}_{timestamp}.json",
                managers=managers,
                test_goals=getattr(params, "test_goals", []),
                test_targets=getattr(params, "test_targets", []),
                test_workers=test_workers,
                mpa_deterministic=params.gbda_deterministic,
                mpa_lr=params.lr,
                mpa_batch_size=params.batch_size,
                mpa_n_steps=params.n_steps,
            )
        attack.run(
            n_steps=params.n_steps,
            batch_size=params.batch_size,
            topk=params.topk,
            temp=params.temp,
            target_weight=params.target_weight,
            control_weight=params.control_weight,
            test_steps=getattr(params, "test_steps", 1),
            anneal=params.anneal,
            incr_control=params.incr_control,
            stop_on_success=params.stop_on_success,
            verbose=params.verbose,
            filter_cand=params.filter_cand,
            allow_non_ascii=params.allow_non_ascii,
        )

        for worker in workers + test_workers:
            worker.stop()
