in main.py
def process_main(rank, sel, fname, world_size, devices):
    import os
    # -- pin this worker to its assigned GPU (e.g. 'cuda:0' -> '0') before
    # -- any CUDA initialization happens
    os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1])

    import logging
    logging.basicConfig()
    logger = logging.getLogger()

    # -- module-level imports in main.py, repeated here so the excerpt is
    # -- self-contained; init_distributed, paws, suncet, fine_tune, and
    # -- snn_fine_tune are project helpers imported elsewhere in the file
    import pprint
    import yaml

    logger.info(f'called-params {sel} {fname}')

    # -- load script params
    params = None
    with open(fname, 'r') as y_file:
        params = yaml.load(y_file, Loader=yaml.FullLoader)
        logger.info('loaded params...')
    if rank == 0:
        pp = pprint.PrettyPrinter(indent=4)
        pp.pprint(params)

    # -- rank 0 saves a copy of the resolved config next to the run's logs
    if rank == 0:
        dump = os.path.join(params['logging']['folder'], f'params-{sel}.yaml')
        with open(dump, 'w') as f:
            yaml.dump(params, f)

    world_size, rank = init_distributed(rank_and_world_size=(rank, world_size))

    # -- make sure all processes correctly initialized torch-distributed
    logger.info(f'Running {sel} (rank: {rank}/{world_size})')

    # -- turn off info-logging for ranks > 0, otherwise too much std output
    if rank == 0:
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.ERROR)

    # -- dispatch to the selected training / fine-tuning entry point
    if sel == 'paws_train':
        return paws(params)
    elif sel == 'suncet_train':
        return suncet(params)
    elif sel == 'fine_tune':
        return fine_tune(params)
    elif sel == 'snn_fine_tune':
        return snn_fine_tune(params)
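
process_main relies on a project helper, init_distributed, that sets up torch-distributed and returns the (world_size, rank) pair. The repo's actual implementation lives elsewhere; the following is only a minimal stand-in consistent with the call site above, and the rendezvous address, port, and backend choice are assumptions.

import os
import torch
import torch.distributed as dist

def init_distributed(rank_and_world_size):
    # stand-in for the project's helper (assumed behavior, not the repo's code)
    rank, world_size = rank_and_world_size
    # assumed single-node env:// rendezvous; address and port are illustrative
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '40101')
    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
    return world_size, rank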
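
For context, process_main is designed to run once per GPU. A minimal launcher sketch is shown below, assuming the usual one-process-per-device pattern; the device list, selector, and config path are hypothetical placeholders, and the repo's actual entry point parses these from the command line.

import multiprocessing as mp

if __name__ == '__main__':
    devices = ['cuda:0', 'cuda:1']                        # hypothetical device list
    world_size = len(devices)
    sel, fname = 'paws_train', 'configs/paws_train.yaml'  # hypothetical args
    mp.set_start_method('spawn')   # required before using CUDA in child processes
    procs = []
    for rank in range(world_size):
        p = mp.Process(target=process_main,
                       args=(rank, sel, fname, world_size, devices))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()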