in domainbed/algorithms.py [0:0]
def __init__(self, input_shape, num_classes, num_domains,
             hparams, conditional, class_balance):
    """Build the networks and optimizers shared by the DANN-style variants.

    Args:
        input_shape: forwarded to the base class and to
            ``networks.Featurizer``.
        num_classes: sizes the classifier output and the number of rows in
            ``self.class_embeddings``.
        num_domains: sizes the discriminator output.
        hparams: forwarded to the base class. The entries read here (via
            ``self.hparams``) are ``nonlinear_classifier``, ``lr_d``,
            ``weight_decay_d``, ``lr_g``, ``weight_decay_g`` and ``beta1``.
        conditional: stored as ``self.conditional``; presumably selects the
            class-conditional variant elsewhere (not visible here).
        class_balance: stored as ``self.class_balance``; presumably consumed
            elsewhere (not visible here).
    """
    super(AbstractDANN, self).__init__(input_shape, num_classes,
                                       num_domains, hparams)

    # Registered as a buffer (persistent by default in PyTorch), so it is
    # part of the module state; presumably incremented per update step.
    self.register_buffer('update_count', torch.tensor([0]))
    self.conditional = conditional
    self.class_balance = class_balance

    # --- Networks --------------------------------------------------------
    # NOTE: keep this construction order. Submodules are registered in the
    # order they are assigned, and module constructors typically consume
    # the RNG for weight init, so reordering would change state_dict key
    # order and the initial weights.
    self.featurizer = networks.Featurizer(input_shape, self.hparams)
    feat_dim = self.featurizer.n_outputs
    self.classifier = networks.Classifier(
        feat_dim, num_classes, self.hparams['nonlinear_classifier'])
    # Maps features to num_domains outputs (presumably domain logits).
    self.discriminator = networks.MLP(feat_dim, num_domains, self.hparams)
    # One learned vector per class, with the same width as the features.
    self.class_embeddings = nn.Embedding(num_classes, feat_dim)

    # --- Optimizers ------------------------------------------------------
    def _adam(param_list, lr, weight_decay):
        # Both optimizers share betas=(beta1, 0.9); note 0.9 rather than
        # PyTorch's default 0.999 for the second-moment decay.
        return torch.optim.Adam(param_list,
                                lr=lr,
                                weight_decay=weight_decay,
                                betas=(self.hparams['beta1'], 0.9))

    # Discriminator side: discriminator weights plus the class embeddings.
    self.disc_opt = _adam(
        [*self.discriminator.parameters(),
         *self.class_embeddings.parameters()],
        self.hparams["lr_d"],
        self.hparams['weight_decay_d'])

    # Generator side: featurizer plus label classifier.
    self.gen_opt = _adam(
        [*self.featurizer.parameters(),
         *self.classifier.parameters()],
        self.hparams["lr_g"],
        self.hparams['weight_decay_g'])