gossip_sgd.py [200:259]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    optimizer.zero_grad()

    # dictionary used to encode training state
    state = {}
    update_state(state, {
            'epoch': 0, 'itr': 0, 'best_prec1': 0, 'is_best': True,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'elapsed_time': 0,
            'batch_meter': Meter(ptag='Time').__dict__,
            'data_meter': Meter(ptag='Data').__dict__,
            'nn_meter': Meter(ptag='Forward/Backward').__dict__
    })

    # module used to relaunch jobs and handle external termination signals
    cmanager = ClusterManager(rank=args.rank,
                              world_size=args.world_size,
                              model_tag=args.tag,
                              state=state,
                              all_workers=args.checkpoint_all)

    # resume from checkpoint
    if args.resume:
        if os.path.isfile(cmanager.checkpoint_fpath):
            log.info("=> loading checkpoint '{}'"
                     .format(cmanager.checkpoint_fpath))
            checkpoint = torch.load(cmanager.checkpoint_fpath)
            update_state(state, {
                          'epoch': checkpoint['epoch'],
                          'itr': checkpoint['itr'],
                          'best_prec1': checkpoint['best_prec1'],
                          'is_best': False,
                          'state_dict': checkpoint['state_dict'],
                          'optimizer': checkpoint['optimizer'],
                          'elapsed_time': checkpoint['elapsed_time'],
                          'batch_meter': checkpoint['batch_meter'],
                          'data_meter': checkpoint['data_meter'],
                          'nn_meter': checkpoint['nn_meter']
            })
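            # restore model weights and optimizer state (e.g. momentum buffers)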
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {}; itr {})"
                     .format(cmanager.checkpoint_fpath,
                             checkpoint['epoch'], checkpoint['itr']))
        else:
            log.info("=> no checkpoint found at '{}'"
                     .format(cmanager.checkpoint_fpath))

    # let cuDNN benchmark and cache the fastest convolution algorithms
    # (beneficial when input shapes stay fixed across iterations)
    cudnn.benchmark = True

    # meters used to compute timing stats
    batch_meter = Meter(state['batch_meter'])
    data_meter = Meter(state['data_meter'])
    nn_meter = Meter(state['nn_meter'])
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
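
For reference, a minimal sketch of the update_state helper and the Meter class assumed by this snippet (and by the identical block in gossip_sgd_adpsgd.py below). These are illustrative reconstructions based only on how they are called above, not the repository's actual implementations:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def update_state(state, update_dict):
    # mutate the shared dict in place so any object holding a reference to it
    # (e.g. the ClusterManager) always checkpoints the latest training state
    for key, value in update_dict.items():
        state[key] = value


class Meter(object):
    """Running-average meter for timing stats; can be rebuilt from its own
    __dict__ so the stats survive a checkpoint/resume cycle."""

    def __init__(self, init_dict=None, ptag=''):
        self.ptag = ptag
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0
        if init_dict is not None:
            self.__dict__.update(init_dict)

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -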



gossip_sgd_adpsgd.py [180:239]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    optimizer.zero_grad()

    # dictionary used to encode training state
    state = {}
    update_state(state, {
            'epoch': 0, 'itr': 0, 'best_prec1': 0, 'is_best': True,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'elapsed_time': 0,
            'batch_meter': Meter(ptag='Time').__dict__,
            'data_meter': Meter(ptag='Data').__dict__,
            'nn_meter': Meter(ptag='Forward/Backward').__dict__
    })

    # module used to relaunch jobs and handle external termination signals
    cmanager = ClusterManager(rank=args.rank,
                              world_size=args.world_size,
                              model_tag=args.tag,
                              state=state,
                              all_workers=args.checkpoint_all)

    # resume from checkpoint
    if args.resume:
        if os.path.isfile(cmanager.checkpoint_fpath):
            log.info("=> loading checkpoint '{}'"
                     .format(cmanager.checkpoint_fpath))
            checkpoint = torch.load(cmanager.checkpoint_fpath)
            update_state(state, {
                          'epoch': checkpoint['epoch'],
                          'itr': checkpoint['itr'],
                          'best_prec1': checkpoint['best_prec1'],
                          'is_best': False,
                          'state_dict': checkpoint['state_dict'],
                          'optimizer': checkpoint['optimizer'],
                          'elapsed_time': checkpoint['elapsed_time'],
                          'batch_meter': checkpoint['batch_meter'],
                          'data_meter': checkpoint['data_meter'],
                          'nn_meter': checkpoint['nn_meter']
            })
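            # restore model weights and optimizer state (e.g. momentum buffers)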
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            log.info("=> loaded checkpoint '{}' (epoch {}; itr {})"
                     .format(cmanager.checkpoint_fpath,
                             checkpoint['epoch'], checkpoint['itr']))
        else:
            log.info("=> no checkpoint found at '{}'"
                     .format(cmanager.checkpoint_fpath))

    # let cuDNN benchmark and cache the fastest convolution algorithms
    # (beneficial when input shapes stay fixed across iterations)
    cudnn.benchmark = True

    # meters used to compute timing stats
    batch_meter = Meter(state['batch_meter'])
    data_meter = Meter(state['data_meter'])
    nn_meter = Meter(state['nn_meter'])
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
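
The resume branch in both files reads a checkpoint written elsewhere by the ClusterManager. A minimal sketch of what that save side might look like is below; save_checkpoint is a hypothetical name, and the ClusterManager's actual checkpointing logic is not shown here:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
import os

import torch


def save_checkpoint(state, fpath):
    # the shared state dict (counters, model.state_dict(),
    # optimizer.state_dict(), meter __dict__s) is serialized as one object,
    # so torch.load() later returns exactly the keys read back through
    # update_state() in the resume branch above
    tmp_fpath = fpath + '.tmp'
    torch.save(state, tmp_fpath)
    # atomic rename so a crash mid-write cannot leave a truncated checkpoint
    os.replace(tmp_fpath, fpath)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
When resuming on a machine with a different device layout (or on CPU only), the torch.load() call above can be pinned with map_location, e.g. torch.load(fpath, map_location='cpu'); load_state_dict() then copies the loaded tensors into the model's existing parameters on their current devices.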



