in separate_vae/models/separate_clothing_encoder_models.py [0:0]
def initialize(self, opt):
    """Set up label split, networks, optimizer and checkpoint loading.

    Splits the label indices ``range(opt.output_nc)`` into clothing labels
    (read from ``opt.label_txt_path``) and the remaining non-clothing labels.
    It then builds the separate encoder, the together encoder and the shared
    decoder via ``networks.define_separate_Es_and_D``. In training mode it also
    creates the Adam optimizer, MSE criterion and LR scheduler.

    Network weights are loaded at inference time, and during training when
    ``opt.continue_train`` or ``opt.load_pretrain`` is set.

    Args:
        opt: Options namespace.
            Always read: ``isTrain``, ``gpu_ids``, ``checkpoints_dir``,
            ``name``, ``label_txt_path``, ``output_nc``, ``nz``, ``nef``,
            ``divide_by_K``, ``bottleneck``, ``n_downsample_global``,
            ``n_blocks_global``, ``max_mult`` and ``norm``.
            Read only when weights are loaded: ``which_epoch``.
            Read only in training: ``lr``, ``beta1``, ``continue_train``,
            ``load_pretrain``, plus whatever ``networks.get_scheduler``
            reads from ``opt``.
    """
    # Inference time uses a regular autoencoder, so the VAE variant of the
    # networks is requested only when training.
    self.use_vae = opt.isTrain
    self.isTrain = opt.isTrain
    self.gpu_ids = opt.gpu_ids
    self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
    self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
    # Store the options in both training and inference modes.
    self.opt = opt

    # Read the clothing-related label indices (comma-separated integers).
    # ndmin=1 + ravel() always yield a flat list of ints. Without them:
    #   - a file with a single label loads as a 0-d array, whose tolist()
    #     is a bare int, so the membership test below raises TypeError;
    #   - a multi-row file loads as a nested list, so the membership test
    #     silently never matches.
    clothing_labels = np.loadtxt(opt.label_txt_path, delimiter=',', dtype=int, ndmin=1)
    self.clothing_labels = clothing_labels.ravel().tolist()
    clothing_set = set(self.clothing_labels)
    self.not_clothing_labels = [i for i in range(opt.output_nc) if i not in clothing_set]
    num_clothing_irrelevant_labels = len(self.not_clothing_labels)

    # Define the encoders and a shared decoder. The first argument is the
    # number of non-clothing labels (presumably one separate encoder per
    # such label -- confirm in networks.define_separate_Es_and_D).
    self.Separate_encoder, self.Together_encoder, self.Decoder = networks.define_separate_Es_and_D(
        num_clothing_irrelevant_labels, opt.output_nc, opt.nz, opt.nef,
        opt.divide_by_K, opt.bottleneck, opt.n_downsample_global, opt.n_blocks_global,
        opt.max_mult, opt.norm, gpu_ids=self.gpu_ids, vaeLike=self.use_vae)

    # Set up the optimizer, loss and LR scheduler (training only).
    if self.isTrain:
        if opt.bottleneck == '2d':  # Deprecated...
            # NOTE(review): self.netG is never assigned in this method, so this
            # branch raises AttributeError unless netG is defined elsewhere.
            self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        else:
            # Parameter order (separate, together, decoder) is kept stable so
            # saved optimizer state stays compatible.
            params = [
                *self.Separate_encoder.parameters(),
                *self.Together_encoder.parameters(),
                *self.Decoder.parameters(),
            ]
            self.optimizer = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
        self.criterionMSE = torch.nn.MSELoss()
        self.loss_names = ['MSE', 'KL']
        self.scheduler = networks.get_scheduler(self.optimizer, opt)
        self.old_lr = self.optimizer.param_groups[0]['lr']

    # Load network weights. Inference passes an empty pretrained path;
    # training passes opt.load_pretrain.
    if not self.isTrain or opt.continue_train or opt.load_pretrain:
        pretrained_path = opt.load_pretrain if self.isTrain else ''
        for net, net_name in (
            (self.Separate_encoder, 'Separate_encoder'),
            (self.Together_encoder, 'Together_encoder'),
            (self.Decoder, 'Decoder'),
        ):
            self.load_network(net, net_name, opt.which_epoch, pretrained_path)