in gossip_sgd_adpsgd.py [0:0]
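# NOTE: a minimal sketch of the imports this function relies on (standard
# library and PyTorch only); repo-local helpers such as parse_args,
# make_logger, init_model, BilatGossipDataParallel, get_tcp_interface_name,
# Meter, ClusterManager, update_state, make_dataloader, train, and validate
# are assumed to be imported elsewhere in this file.
import os
import socket
import time

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
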
def main():
    global args, state, log
    args = parse_args()
    log = make_logger(args.rank, args.verbose)
    log.info('args: {}'.format(args))
    log.info(socket.gethostname())

    # seed for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = True
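    # deterministic cuDNN kernels trade some speed for run-to-run
    # reproducibility of convolution results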

    # init model, loss, and optimizer
    model = init_model()

    assert args.bilat and not args.all_reduce
    model = BilatGossipDataParallel(
        model, master_addr=args.master_addr, master_port=args.master_port,
        backend=args.backend, world_size=args.world_size, rank=args.rank,
        graph_class=args.graph_class, mixing_class=args.mixing_class,
        comm_device=args.comm_device, lr=args.lr, momentum=args.momentum,
        weight_decay=args.weight_decay, nesterov=args.nesterov,
        verbose=args.verbose, num_peers=args.ppi_schedule[0],
        network_interface_type=args.network_interface_type,
        tcp_interface_name=get_tcp_interface_name(
            network_interface_type=args.network_interface_type
        )
    )
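    # the wrapper above presumably implements the asynchronous, bilateral
    # (pairwise) gossip averaging of AD-PSGD: each agent exchanges and
    # averages model parameters with num_peers neighbours drawn from
    # graph_class, instead of performing a global all-reduce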

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    optimizer.zero_grad()

    # dictionary used to encode training state
    state = {}
    update_state(state, {
        'epoch': 0, 'itr': 0, 'best_prec1': 0, 'is_best': True,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'elapsed_time': 0,
        'batch_meter': Meter(ptag='Time').__dict__,
        'data_meter': Meter(ptag='Data').__dict__,
        'nn_meter': Meter(ptag='Forward/Backward').__dict__
    })

    # module used to relaunch jobs and handle external termination signals
    cmanager = ClusterManager(rank=args.rank,
                              world_size=args.world_size,
                              model_tag=args.tag,
                              state=state,
                              all_workers=args.checkpoint_all)

    # resume from checkpoint
    if args.resume:
        if os.path.isfile(cmanager.checkpoint_fpath):
            log.info("=> loading checkpoint '{}'"
                     .format(cmanager.checkpoint_fpath))
            checkpoint = torch.load(cmanager.checkpoint_fpath)
            update_state(state, {
                'epoch': checkpoint['epoch'],
                'itr': checkpoint['itr'],
                'best_prec1': checkpoint['best_prec1'],
                'is_best': False,
                'state_dict': checkpoint['state_dict'],
                'optimizer': checkpoint['optimizer'],
                'elapsed_time': checkpoint['elapsed_time'],
                'batch_meter': checkpoint['batch_meter'],
                'data_meter': checkpoint['data_meter'],
                'nn_meter': checkpoint['nn_meter']
            })
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {}; itr {})"
                     .format(cmanager.checkpoint_fpath,
                             checkpoint['epoch'], checkpoint['itr']))
        else:
            log.info("=> no checkpoint found at '{}'"
                     .format(cmanager.checkpoint_fpath))

    # enable cuDNN auto-tuner to select the fastest convolution algorithms
    cudnn.benchmark = True

    # meters used to compute timing stats
    batch_meter = Meter(state['batch_meter'])
    data_meter = Meter(state['data_meter'])
    nn_meter = Meter(state['nn_meter'])

    # initialize log file
    if not args.resume:
        with open(args.out_fname, 'w') as f:
            print('BEGIN-TRAINING\n'
                  'World-Size,{ws}\n'
                  'Num-DLWorkers,{nw}\n'
                  'Batch-Size,{bs}\n'
                  'Epoch,itr,BT(s),avg:BT(s),std:BT(s),'
                  'NT(s),avg:NT(s),std:NT(s),'
                  'DT(s),avg:DT(s),std:DT(s),'
                  'Loss,avg:Loss,Prec@1,avg:Prec@1,Prec@5,avg:Prec@5,val'
                  .format(ws=args.world_size,
                          nw=args.num_dataloader_workers,
                          bs=args.batch_size), file=f)
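    # column abbreviations in the header above: BT = batch time,
    # NT = forward/backward (network) time, DT = data-loading time,
    # matching the 'Time', 'Forward/Backward', and 'Data' meters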

    # create distributed data loaders
    loader, sampler = make_dataloader(args, train=True)
    if not args.train_fast:
        val_loader = make_dataloader(args, train=False)

    # start all agents' training loop at same time
    model.block()

    start_itr = state['itr']
    start_epoch = state['epoch']
    elapsed_time = state['elapsed_time']
    begin_time = time.time() - state['elapsed_time']
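    # offset the start time by any previously elapsed time so that timing
    # statistics continue seamlessly across checkpoint restarts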

    epoch = start_epoch
    stopping_criterion = epoch >= args.num_epochs
    while not stopping_criterion:

        # deterministic seed used to load agent's subset of data
        sampler.set_epoch(epoch + args.seed * 90)
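        # the seed-dependent offset above presumably keeps the per-epoch data
        # shuffles distinct across runs launched with different seeds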
        train(model, criterion, optimizer,
              batch_meter, data_meter, nn_meter,
              loader, epoch, start_itr, begin_time)
        start_itr = 0

        if not args.train_fast:
            # update state after each epoch
            elapsed_time = time.time() - begin_time
            update_state(state, {
                'epoch': epoch + 1, 'itr': start_itr,
                'is_best': False,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'elapsed_time': elapsed_time,
                'batch_meter': batch_meter.__dict__,
                'data_meter': data_meter.__dict__,
                'nn_meter': nn_meter.__dict__
            })

            # evaluate on validation set and save checkpoint
            prec1 = validate(val_loader, model, criterion)
            with open(args.out_fname, '+a') as f:
                print('{ep},{itr},{bt},{nt},{dt},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{filler},{filler},'
                      '{val}'
                      .format(ep=epoch, itr=-1,
                              bt=batch_meter,
                              dt=data_meter, nt=nn_meter,
                              filler=-1, val=prec1), file=f)
            cmanager.save_checkpoint()

            # synchronize models at the end of validation run
            model.block()

        epoch += 1
        stopping_criterion = args.global_epoch >= args.num_epochs

    if args.train_fast:
        val_loader = make_dataloader(args, train=False)
        prec1 = validate(val_loader, model, criterion)
        log.info('Test accuracy: {}'.format(prec1))

    log.info('elapsed_time {0}'.format(elapsed_time))
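

# a minimal sketch of the usual script entry point; the actual file
# presumably ends with an equivalent guard
if __name__ == '__main__':
    main()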