def main()

in research/active_learning/main.py


def main():
    global args, best_acc1
    args = parser.parse_args()
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')
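    # Optionally restore saved hyperparameters from a checkpoint before building the model.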
    checkpoint = {}
    if args.resume != '':
        checkpoint = load_checkpoint(args.resume)
        args.loss_type = checkpoint['loss_type']
        args.feat_dim = checkpoint['feat_dim']
        best_acc1 = checkpoint['best_acc1']

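    # The training data directory contains a SQLite database of detections and labels.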
    db_path = os.path.join(args.train_data, os.path.basename(args.train_data)) + ".db"
    print(db_path)
    db = SqliteDatabase(db_path)
    proxy.initialize(db)
    db.connect()
    """
    to use full images
    train_query =  Detection.select(Detection.image_id,Oracle.label,Detection.kind).join(Oracle).order_by(fn.random()).limit(limit)
    
    train_dataset = SQLDataLoader('/lscratch/datasets/serengeti', is_training= True, num_workers= args.workers, 
            raw_size= args.raw_size, processed_size= args.processed_size)
    """
    train_dataset = SQLDataLoader(os.path.join(args.train_data, 'crops'), is_training=True, num_workers=args.workers,
                                  raw_size=args.raw_size, processed_size=args.processed_size)
    train_dataset.setKind(DetectionKind.UserDetection.value)
    if args.val_data is not None:
        val_dataset = SQLDataLoader(os.path.join(args.val_data, 'crops'), is_training=False, num_workers=args.workers)
    # num_classes = len(train_dataset.getClassesInfo()[0])
    num_classes = args.num_classes
    if args.balanced_P == -1:
        args.balanced_P = num_classes
    # print("Num Classes = " + str(num_classes))
    if args.loss_type.lower() in ('center', 'softmax'):
        train_loader = train_dataset.getSingleLoader(batch_size=args.batch_size)
        train_embd_loader = train_loader
        if args.val_data is not None:
            val_loader = val_dataset.getSingleLoader(batch_size=args.batch_size)
            val_embd_loader = val_loader
    else:
        train_loader = train_dataset.getBalancedLoader(P=args.balanced_P, K=args.balanced_K)
        train_embd_loader = train_dataset.getSingleLoader(batch_size=args.batch_size)
        if args.val_data is not None:
            val_loader = val_dataset.getBalancedLoader(P=args.balanced_P, K=args.balanced_K)
            val_embd_loader = val_dataset.getSingleLoader(batch_size=args.batch_size)

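    # Build the network and the criterion that match the requested loss type.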
    center_loss = None
    if args.loss_type.lower() in ('center', 'softmax'):
        model = torch.nn.DataParallel(SoftmaxNet(args.arch, args.feat_dim, num_classes, use_pretrained=args.pretrained)).cuda()
        if args.loss_type.lower() == 'center':
            criterion = CenterLoss(num_classes=num_classes, feat_dim=args.feat_dim)
            params = list(model.parameters()) + list(criterion.parameters())
        else:
            criterion = nn.CrossEntropyLoss().cuda()
            params = model.parameters()
    else:
        model = torch.nn.DataParallel(NormalizedEmbeddingNet(args.arch, args.feat_dim, use_pretrained=args.pretrained)).cuda()
        if args.loss_type.lower() == 'siamese':
            criterion = OnlineContrastiveLoss(args.margin, HardNegativePairSelector())
        else:
            criterion = OnlineTripletLoss(args.margin, RandomNegativeTripletSelector(args.margin))
        params = model.parameters()

    # define loss function (criterion) and optimizer
    optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    #optimizer = torch.optim.SGD(params, momentum = 0.9, lr = args.lr, weight_decay = args.weight_decay)
    start_epoch = 0

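    # When resuming, restore model weights (and the learned centers for center loss).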
    if checkpoint:
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        # optimizer.load_state_dict(checkpoint['optimizer'])
        if args.loss_type.lower() == 'center':
            criterion.load_state_dict(checkpoint['centers'])

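    # The Engine wraps the per-epoch training and validation steps; a checkpoint is
    # written at the end of every epoch.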
    e = Engine(model, criterion, optimizer, verbose=True, print_freq=args.print_freq)
    for epoch in range(start_epoch, args.epochs):
        # train for one epoch
        #adjust_lr(optimizer,epoch)
        e.train_one_epoch(train_loader, epoch, args.loss_type.lower() in ('center', 'softmax'))
        #if epoch % 1 == 0 and epoch > 0:
        #    a, b, c = e.predict(train_embd_loader, load_info = True, dim = args.feat_dim)
        #    plot_embedding(reduce_dimensionality(a), b, c, {})
        # evaluate on validation set
        if args.val_data is not None:
            e.validate(val_loader, args.loss_type.lower() == 'center')
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
            'loss_type': args.loss_type,
            'num_classes': args.num_classes,
            'feat_dim': args.feat_dim,
            'centers': criterion.state_dict() if args.loss_type.lower() == 'center' else None
        }, False, "%s%s_%s_%04d.tar" % (args.checkpoint_prefix, args.loss_type, args.arch, epoch))
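
The load_checkpoint and save_checkpoint helpers are called above but not defined in this file. A minimal sketch of behavior consistent with the call sites; the bodies here are assumptions for illustration, not the repository's actual implementation:

import torch

def load_checkpoint(path):
    # Assumed behavior: deserialize the state dict written by save_checkpoint,
    # returning the dict with keys like 'epoch', 'state_dict', 'loss_type', etc.
    return torch.load(path, map_location='cpu')

def save_checkpoint(state, is_best, filename):
    # Assumed behavior: persist the full training state each epoch. The is_best
    # flag (always False above) could additionally copy the file to a
    # "best model" snapshot.
    torch.save(state, filename)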