in separate_vae/decode_masks.py [0:0]
def batch_generation_from_update(batch_size, save_path, fname_list, features_mat, checkpoints_dir, classname, black=True):
    '''Decode input feature vectors into segmentation masks and save them as PNGs.

    Args:
        batch_size (int): number of feature rows fed to the VAE per forward
            pass; must be positive. Also written to ``vae_opt.batchSize``.
        save_path (str): directory the generated masks are written to;
            ``vae_util.mkdirs`` is called on it before generation.
        fname_list (list of str): base names for the saved masks; the mask
            decoded from row i of ``features_mat`` is saved as
            ``<save_path>/<fname_list[i]>.png``.
        features_mat (numpy array): 2-D array of input features to be decoded,
            one row per entry of ``fname_list`` and in the same order. Rows
            beyond ``len(fname_list)`` are ignored.
        checkpoints_dir (str): directory the VAE weights are loaded from.
        classname (str): label taxonomy defined by the dataset, passed to
            ``initialize_option``.
        black (bool): True for regular generation (masks converted with
            ``vae_util.tensor2label_black``); False for debugging (masks
            converted with ``vae_util.tensor2label``, which is not the format
            used as cGAN input).

    Raises:
        ValueError: if ``batch_size`` is not positive, or ``features_mat`` has
            fewer rows than ``fname_list`` has entries.
        NotImplementedError: if the VAE options do not share the encoder and
            decoder among all parts, or do not separate clothing from
            clothing-irrelevant regions.
    '''
    # Validate inputs up front so bad arguments fail before the model is loaded.
    if batch_size <= 0:
        raise ValueError('batch_size must be a positive integer, got %r' % (batch_size,))
    dataset_size = len(fname_list)
    if len(features_mat) < dataset_size:
        raise ValueError('features_mat has %d rows but fname_list has %d entries'
                         % (len(features_mat), dataset_size))

    # Create and load model weights in
    vae_opt = initialize_option(classname)
    vae_opt.batchSize = batch_size
    vae_opt.checkpoints_dir = checkpoints_dir
    vae_util.mkdirs(save_path)
    if not (vae_opt.share_decoder and vae_opt.share_encoder):
        raise NotImplementedError('Only supports sharing encoder and decoder among all parts')
    if not vae_opt.separate_clothing_unrelated:
        raise NotImplementedError('Only supports separating clothing and clothing-irrelevant')
    from models.separate_clothing_encoder_models import create_model as vae_create_model
    model = vae_create_model(vae_opt)

    # Pick the mask conversion once instead of branching per image.
    to_label = vae_util.tensor2label_black if black else vae_util.tensor2label

    # Pure inference: no autograd graph is needed for the generated masks.
    with torch.no_grad():
        # Walk forward over fname_list in steps of batch_size; the final batch
        # is simply shorter. Features and names are sliced with the same
        # [start:stop] window, so they always stay paired even when
        # features_mat has extra trailing rows.
        for start in range(0, dataset_size, batch_size):
            stop = min(start + batch_size, dataset_size)
            batch = torch.Tensor(features_mat[start:stop, :]).cuda()
            generated = model.generate_from_random(batch)
            for j, fname in enumerate(fname_list[start:stop]):
                label_img = to_label(generated.data[j], vae_opt.output_nc, normalize=True)
                vae_util.save_image(label_img, os.path.join(save_path, '%s.png' % fname))