in sample.py [0:0]
def main(config):
    """Draw samples from a checkpointed model and hand them to the PNG writer.

    The config stored alongside the checkpoint (``config.json`` in the parent
    of the directory that contains the checkpoint file) is loaded and used as
    the base configuration.  A few entries are then overridden from the
    ``config`` argument, the model and datasets are built, and the validation
    loader is iterated until ``config['n_seqs']`` sequences have been counted.

    For every sequence, ``train_fns.sample_step`` is called repeatedly until
    ``config['n_samples']`` samples are collected.  Results are written under
    ``<config_dir>/samples/<sequence id>/`` in sub-directories named after the
    zero-padded sample index, plus ``ctx`` and ``gt``.  Sequences whose output
    directory already exists are counted but not re-sampled.

    Args:
        config: Mapping that provides at least ``'checkpoint'``,
            ``'n_steps'`` (may be ``None`` to keep the stored value),
            ``'n_seqs'`` and ``'n_samples'``.  The stored config is expected
            to provide ``'n_ctx'``.
    """
    n_jobs = 20  # worker count used for every parallel PNG dump

    # Load checkpoint config
    cli_config = config
    config_dir = os.path.dirname(os.path.dirname(cli_config['checkpoint']))
    config = configs.load_config(os.path.join(config_dir, 'config.json'))

    # Remove multigpu flags and adjust batch size
    config['multigpu'] = 0
    config['batch_size'] = 1

    # Overwrite config params
    config['checkpoint'] = cli_config['checkpoint']
    # n_steps is optional: None means "keep the value stored with the checkpoint".
    if cli_config['n_steps'] is not None:
        config['n_steps'] = cli_config['n_steps']
    config['n_seqs'] = cli_config['n_seqs']
    config['n_samples'] = cli_config['n_samples']

    # Set up device
    local_rank = 0
    config['local_rank'] = local_rank
    config['device'] = 'cuda:{}'.format(local_rank)

    # Only the validation loader is used below.
    _train_loader, val_loader = get_dataset(config)
    print('Dataset loaded')
    model = init_model(config)
    print(model)
    print('Model loaded')

    # Define output dirs
    out_dir = config_dir
    samples_dir = os.path.join(out_dir, 'samples')
    os.makedirs(samples_dir, exist_ok=True)

    def save_frames(target_dir, frames_seq):
        # Create target_dir and pass every frame, with its index, to save_sample_png.
        os.makedirs(target_dir, exist_ok=True)
        Parallel(n_jobs=n_jobs)(
            delayed(save_sample_png)(target_dir, frame, f_id)
            for f_id, frame in enumerate(frames_seq)
        )

    def save_samples(preds, gt, ctx, seq_id):
        # Write every sample of one sequence, then its context and ground truth.
        seq_dir = os.path.join(samples_dir, '{:0>4}'.format(seq_id))
        for sample_id, sample in enumerate(preds):
            save_frames(os.path.join(seq_dir, '{:0>4}'.format(sample_id)), sample)
        # Only the first batch element of ctx / gt is saved.
        save_frames(os.path.join(seq_dir, 'ctx'), ctx[0])
        save_frames(os.path.join(seq_dir, 'gt'), gt[0])

    def to_numpy(tensor):
        # Move to host memory and put axis 2 last
        # (presumably channels-first -> channels-last; confirm with save_sample_png).
        return tensor.detach().cpu().numpy().transpose(0, 1, 3, 4, 2)

    model.eval()
    n_seqs = 0
    for batch in val_loader:
        if n_seqs >= config['n_seqs']:
            break

        frames, idxs = train_fns.prepare_batch(batch, config)

        # Find id of the sequence and decide whether to work on it or not
        sequence_id = idxs[0]
        sequence_dir = os.path.join(samples_dir, '{:0>4}'.format(sequence_id))
        if os.path.exists(sequence_dir):
            # Already processed (or attempted) by an earlier run: count and skip.
            n_seqs += frames.shape[0]
            continue
        os.makedirs(sequence_dir, exist_ok=True)

        batch_size = 1
        frames = frames.repeat(batch_size, 1, 1, 1, 1)

        samples_done = 0
        all_preds = []
        sampling_ok = True
        # NOTE(review): sampling is not wrapped in torch.no_grad(); confirm
        # whether sample_step already disables gradient tracking.
        while samples_done < config['n_samples']:
            try:
                (preds, targets), _ = train_fns.sample_step(model, config, frames)
            except Exception as err:
                # Best-effort: report and skip this sequence.  Catching
                # Exception (not a bare except) lets KeyboardInterrupt and
                # SystemExit stop the run as the user expects.
                print('Sampling failed for sequence {}: {!r}'.format(sequence_id, err))
                sampling_ok = False
                break
            # Keep only the frames after the context.
            all_preds.append(preds[:, config['n_ctx']:].contiguous().detach())
            targets = targets.detach()
            samples_done += batch_size

        if not sampling_ok:
            # NOTE(review): sequence_dir was already created above, so a failed
            # sequence is treated as done on the next run; n_seqs is not
            # incremented for it.
            continue

        # Trim extra samples
        all_preds = torch.cat(all_preds, 0)[:config['n_samples']]

        # Split targets into context / future frames and convert to numpy
        n_ctx = config['n_ctx']
        ctx = to_numpy(targets[:, :n_ctx])
        targets = to_numpy(targets[:, n_ctx:])
        all_preds = to_numpy(all_preds)

        # Save samples to PNG files
        save_samples(all_preds, targets, ctx, sequence_id)

        # Update number of sequences processed
        n_seqs += frames.shape[0]

    print('All done')