in fairnr_cli/render_multigpu.py [0:0]
def _main(args, output_file):
    """Render this worker's share of frames and merge the results on shard 0.

    Configures logging to ``output_file``, sets up the task, loads the model
    ensemble from ``args.path`` and iterates over the ``args.gen_subset``
    batches, calling ``task.inference_step`` for each one. The collected
    output files are handed to ``generator.save_images``; the value it
    returns is gathered from all workers so that the worker whose
    ``args.distributed_rank`` is 0 can call ``generator.merge_videos``.

    Args:
        args: Parsed options namespace. ``args.max_tokens`` is set to 12000
            in place when both it and ``args.max_sentences`` are ``None``.
        output_file: Stream passed to ``logging.basicConfig`` as the log
            destination.
    """
    logging.basicConfig(
        format='%(asctime)s | %(levelname)s | %(name)s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S',
        level=logging.INFO,
        stream=output_file,
    )
    logger = logging.getLogger('fairnr_cli.render')

    utils.import_user_module(args)

    # Fall back to a default batch budget when the caller gave no limit at all.
    if args.max_tokens is None and args.max_sentences is None:
        args.max_tokens = 12000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Load ensemble
    logger.info('loading model(s) from {}'.format(args.path))
    # SECURITY NOTE(review): eval() executes whatever string is supplied in
    # args.model_overrides. Kept as-is for compatibility; only safe while the
    # option comes from a trusted command line. Consider ast.literal_eval.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(os.pathsep),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()
        # Use the named logger (not the root logger) like every other message,
        # and log each model so an empty ensemble cannot hit an unbound name.
        logger.info(model)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
            *[model.max_positions() for model in models]
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        seed=args.seed,
        num_workers=args.num_workers
    ).next_epoch_itr(shuffle=False)

    # Initialize generator
    gen_timer = StopwatchMeter()
    generator = task.build_generator(args)
    shard_id, world_size = args.distributed_rank, args.distributed_world_size
    output_files = []

    if generator.test_poses is not None:
        # Split the test poses evenly across workers; the last worker also
        # takes whatever remainder the integer division leaves over.
        total_frames = generator.test_poses.shape[0]
        _frames = total_frames // world_size
        step = shard_id * _frames
        frames = _frames if shard_id < (world_size - 1) else total_frames - step
    else:
        # No test poses: every worker renders args.render_num_frames frames,
        # starting at an offset determined by its rank.
        step = shard_id * args.render_num_frames
        frames = args.render_num_frames

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            gen_timer.start()

            # inference_step returns the updated step, which is fed back into
            # the next call, together with the files written for this batch.
            step, _output_files = task.inference_step(
                generator, models, [sample, step, frames])
            output_files += _output_files

            # NOTE(review): 500 is a fixed count reported to both meters for
            # every batch; its unit is not visible in this function.
            gen_timer.stop(500)
            wps_meter.update(500)
            t.log({'wps': round(wps_meter.avg)})

    timestamp = generator.save_images(
        output_files, steps='shard{}'.format(shard_id), combine_output=args.render_combine_output)

    # join videos from all GPUs and delete temp files
    try:
        timestamps = distributed_utils.all_gather_list(timestamp)
    except Exception as err:
        # Best-effort fallback: presumably the gather fails when no
        # distributed process group is set up (single-process run) -- TODO
        # confirm. Narrowed from a bare except so KeyboardInterrupt and
        # SystemExit still propagate.
        logger.info(
            'could not gather timestamps across workers ({}); '
            'using the local timestamp only'.format(err))
        timestamps = [timestamp]

    if shard_id == 0:
        generator.merge_videos(timestamps)