in data_augmentation/my_training.py [0:0]
def run_model(params, ckpt_path=None, repo=None, dist_url="tcp://localhost:10001"):
    """Seed RNGs, optionally join a process group, then run ``main_worker``.

    Args:
        params: Namespace-like config object. This function reads
            ``seed``, ``gpu``, ``distributed``, ``rank``, ``dist_backend``
            and ``world_size`` from it; the object is passed on to
            ``main_worker`` unchanged.
        ckpt_path: Forwarded unchanged to ``main_worker`` (presumably a
            checkpoint path -- confirm against ``main_worker``).
        repo: Forwarded unchanged to ``main_worker``.
        dist_url: ``init_method`` URL for
            ``torch.distributed.init_process_group``. Defaults to the
            previously hard-coded ``"tcp://localhost:10001"``.

    Any exception raised by ``main_worker`` propagates to the caller. When
    ``params.distributed`` is true, ``cleanup_distributed()`` is still
    called on that path, so the process group is not leaked.
    """
    args = params

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)

    if args.gpu is not None:
        warnings.warn(
            "You have chosen a specific GPU. This will completely "
            "disable data parallelism."
        )

    ngpus_per_node = torch.cuda.device_count()

    if args.distributed:
        # Binds this process to the CUDA device whose index equals its rank.
        # NOTE(review): this only works when rank < device count, i.e. a
        # single-node launch, which matches the localhost default URL.
        torch.cuda.set_device(args.rank)
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )

    # The try starts only after the group exists, so cleanup never runs for
    # a group that failed to initialize.
    try:
        main_worker(args.gpu, ngpus_per_node, args, ckpt_path, repo)
    finally:
        # Tear down distributed state even when training raises.
        if args.distributed:
            cleanup_distributed()