import copy
import torch
import torch.nn as nn

from dassl.optim import build_optimizer, build_lr_scheduler
from dassl.utils import check_isfile, count_num_param, open_specified_layers
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.modeling import build_head


@TRAINER_REGISTRY.register()
class ADDA(TrainerXU):
    """Adversarial Discriminative Domain Adaptation.

    https://arxiv.org/abs/1702.05464.

    At construction time ``self.model`` is deep-copied into a frozen
    ``source_model`` that encodes source images. The trainable ``self.model``
    encodes unlabeled target images. A separately optimized ``critic`` learns
    to label source features 1 and target features 0. ``self.model`` is
    trained to make the critic label its target features 1. Source labels are
    not used in this trainer.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Names of the sub-modules of self.model that stay trainable during
        # adaptation. The head is included only when the model actually has
        # a module there.
        self.open_layers = ["backbone"]
        if isinstance(self.model.head, nn.Module):
            self.open_layers.append("head")

        # Frozen snapshot of self.model as built by the base class. It is
        # presumably initialized from cfg.MODEL.INIT_WEIGHTS (required by
        # check_cfg) -- TODO confirm in the base trainer.
        self.source_model = copy.deepcopy(self.model)
        self.source_model.eval()
        for param in self.source_model.parameters():
            param.requires_grad_(False)

        self.build_critic()

        self.bce = nn.BCEWithLogitsLoss()

    def check_cfg(self, cfg):
        """Require a source-model checkpoint to exist.

        Raises:
            ValueError: if ``cfg.MODEL.INIT_WEIGHTS`` is not an existing file.
        """
        # Explicit raise rather than ``assert``: asserts are removed under
        # ``python -O``, which would let training start without this check.
        if not check_isfile(cfg.MODEL.INIT_WEIGHTS):
            raise ValueError("The weights of source model must be provided")

    def build_critic(self):
        """Build the domain critic and register it with its own optimizer.

        The critic is an MLP head over ``self.model.fdim``-dimensional
        features, with hidden sizes ``[fdim, fdim // 2]``, followed by a
        linear layer that emits one domain logit per sample.
        """
        cfg = self.cfg

        print("Building critic network")
        fdim = self.model.fdim
        hidden_layers = [fdim, fdim // 2]
        critic_body = build_head(
            "mlp",
            verbose=cfg.VERBOSE,
            in_features=fdim,
            hidden_layers=hidden_layers,
            activation="leaky_relu",
        )
        # Tie the output layer's input width to the last hidden size so the
        # two definitions cannot drift apart.
        self.critic = nn.Sequential(critic_body, nn.Linear(hidden_layers[-1], 1))
        print("# params: {:,}".format(count_num_param(self.critic)))
        self.critic.to(self.device)
        # The critic has its own optimizer/scheduler, built from cfg.OPTIM,
        # and is registered under the name "critic".
        self.optim_c = build_optimizer(self.critic, cfg.OPTIM)
        self.sched_c = build_lr_scheduler(self.optim_c, cfg.OPTIM)
        self.register_model("critic", self.critic, self.optim_c, self.sched_c)

    def forward_backward(self, batch_x, batch_u):
        """Run one critic update followed by one target-model update.

        Args:
            batch_x: source batch (its labels are parsed but unused).
            batch_u: unlabeled target batch.

        Returns:
            dict with the float values ``loss_critic`` and ``loss_model``.
        """
        # Helper from dassl.utils. Presumably it leaves only
        # self.open_layers trainable -- TODO confirm.
        open_specified_layers(self.model, self.open_layers)
        input_x, _, input_u = self.parse_batch_train(batch_x, batch_u)
        # Domain labels for the critic: 1 = source, 0 = target.
        domain_x = torch.ones(input_x.shape[0], 1, device=self.device)
        domain_u = torch.zeros(input_u.shape[0], 1, device=self.device)

        # The source model's parameters are frozen (see __init__), so there
        # is nothing to backpropagate into. Skip building an autograd graph.
        with torch.no_grad():
            _, feat_x = self.source_model(input_x, return_feature=True)
        _, feat_u = self.model(input_u, return_feature=True)

        # Step 1: update the critic. Target features are detached so this
        # loss does not backpropagate into self.model.
        logit_xd = self.critic(feat_x)
        logit_ud = self.critic(feat_u.detach())
        loss_critic = self.bce(logit_xd, domain_x) + self.bce(logit_ud, domain_u)
        self.model_backward_and_update(loss_critic, "critic")

        # Step 2: update self.model so the (just updated) critic labels target
        # features as source (1 - 0 = 1). This backward also populates the
        # critic's .grad. Only the "model" optimizer is named here; the
        # critic's grads are presumably cleared before its next update.
        logit_ud = self.critic(feat_u)
        loss_model = self.bce(logit_ud, 1 - domain_u)
        self.model_backward_and_update(loss_model, "model")

        loss_summary = {
            "loss_critic": loss_critic.item(),
            "loss_model": loss_model.item(),
        }

        # Learning rates are updated once, after the last batch of the epoch.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary
