in research/active_learning/main.py [0:0]
def main():
    """Train an embedding / classification network on image crops indexed in SQLite.

    Steps:
      1. Parse CLI arguments with the module-level ``parser`` and optionally seed RNGs.
      2. Optionally resume from ``args.resume``. This restores ``loss_type``,
         ``feat_dim`` and the global ``best_acc1`` before anything is built.
      3. Open the SQLite database ``<train_data>/<basename(train_data)>.db`` and
         bind it to ``proxy``.
      4. Build the training (and, if ``args.val_data`` is set, validation) loaders
         over the ``crops`` sub-directories.
      5. Build the model/criterion pair selected by ``args.loss_type``
         (case-insensitive):
           * 'softmax' -> SoftmaxNet + CrossEntropyLoss
           * 'center'  -> SoftmaxNet + CenterLoss
           * 'siamese' -> NormalizedEmbeddingNet + OnlineContrastiveLoss
           * otherwise -> NormalizedEmbeddingNet + OnlineTripletLoss
         softmax/center train on plain batches; the others use P x K balanced batches.
      6. Train with Adam via ``Engine`` from ``start_epoch`` to ``args.epochs``,
         validating (if configured) and saving a checkpoint after every epoch.

    Side effects: rebinds the module globals ``args`` and, on resume, ``best_acc1``;
    prints to stdout; writes one checkpoint file per epoch via ``save_checkpoint``.
    """
    global args, best_acc1
    args = parser.parse_args()
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # Restore the settings that shape the model/criterion *before* building them,
    # so the saved state dicts loaded further down match.
    checkpoint = {}
    if args.resume:  # treat both '' and None as "do not resume"
        checkpoint = load_checkpoint(args.resume)
        args.loss_type = checkpoint['loss_type']
        args.feat_dim = checkpoint['feat_dim']
        # Assigns the *global* best_acc1 (declared above). A former typo
        # (`best_accl`) silently dropped this value on resume.
        best_acc1 = checkpoint['best_acc1']

    # args.loss_type is not modified after this point, so normalise it once.
    loss_type = args.loss_type.lower()
    uses_softmax_head = loss_type in ('center', 'softmax')

    # normpath strips a trailing separator; without it basename('data/') == ''
    # and the database path degenerates to 'data/.db'.
    db_name = os.path.basename(os.path.normpath(args.train_data))
    db_path = os.path.join(args.train_data, db_name) + ".db"
    print(db_path)
    db = SqliteDatabase(db_path)
    proxy.initialize(db)
    db.connect()

    # To train on full images instead of crops:
    #   train_query = Detection.select(Detection.image_id, Oracle.label, Detection.kind).join(Oracle).order_by(fn.random()).limit(limit)
    #   train_dataset = SQLDataLoader('/lscratch/datasets/serengeti', is_training=True, num_workers=args.workers,
    #                                 raw_size=args.raw_size, processed_size=args.processed_size)
    train_dataset = SQLDataLoader(os.path.join(args.train_data, 'crops'), is_training=True,
                                  num_workers=args.workers, raw_size=args.raw_size,
                                  processed_size=args.processed_size)
    train_dataset.setKind(DetectionKind.UserDetection.value)
    if args.val_data is not None:
        val_dataset = SQLDataLoader(os.path.join(args.val_data, 'crops'), is_training=False,
                                    num_workers=args.workers)

    # The class count comes from the CLI, not from the dataset
    # (was: len(train_dataset.getClassesInfo()[0])).
    num_classes = args.num_classes
    if args.balanced_P == -1:
        # -1 means "one class slot per class" in each balanced batch.
        args.balanced_P = num_classes

    # The *_embd_loader variables are plain (unbalanced) loaders kept for the
    # disabled embedding-plot code in the epoch loop.
    if uses_softmax_head:
        train_loader = train_dataset.getSingleLoader(batch_size=args.batch_size)
        train_embd_loader = train_loader
        if args.val_data is not None:
            val_loader = val_dataset.getSingleLoader(batch_size=args.batch_size)
            val_embd_loader = val_loader
    else:
        train_loader = train_dataset.getBalancedLoader(P=args.balanced_P, K=args.balanced_K)
        train_embd_loader = train_dataset.getSingleLoader(batch_size=args.batch_size)
        if args.val_data is not None:
            val_loader = val_dataset.getBalancedLoader(P=args.balanced_P, K=args.balanced_K)
            val_embd_loader = val_dataset.getSingleLoader(batch_size=args.batch_size)

    # Define the model, the loss function (criterion) and the parameters to optimise.
    if uses_softmax_head:
        model = torch.nn.DataParallel(
            SoftmaxNet(args.arch, args.feat_dim, num_classes,
                       use_pretrained=args.pretrained)).cuda()
        if loss_type == 'center':
            criterion = CenterLoss(num_classes=num_classes, feat_dim=args.feat_dim)
            # CenterLoss has its own parameters; optimise them together with the model.
            params = list(model.parameters()) + list(criterion.parameters())
        else:
            criterion = nn.CrossEntropyLoss().cuda()
            params = model.parameters()
    else:
        model = torch.nn.DataParallel(
            NormalizedEmbeddingNet(args.arch, args.feat_dim,
                                   use_pretrained=args.pretrained)).cuda()
        if loss_type == 'siamese':
            criterion = OnlineContrastiveLoss(args.margin, HardNegativePairSelector())
        else:
            criterion = OnlineTripletLoss(args.margin, RandomNegativeTripletSelector(args.margin))
        params = model.parameters()

    optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
    # Alternative: torch.optim.SGD(params, momentum=0.9, lr=args.lr, weight_decay=args.weight_decay)

    start_epoch = 0
    if checkpoint:
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        # NOTE(review): the optimizer state is saved below but deliberately(?) not
        # restored here (was: optimizer.load_state_dict(checkpoint['optimizer'])),
        # so a resumed run starts with a freshly constructed optimizer.
        if loss_type == 'center':
            criterion.load_state_dict(checkpoint['centers'])

    e = Engine(model, criterion, optimizer, verbose=True, print_freq=args.print_freq)
    for epoch in range(start_epoch, args.epochs):
        # train for one epoch (LR schedule disabled; was: adjust_lr(optimizer, epoch))
        e.train_one_epoch(train_loader, epoch, uses_softmax_head)
        # Embedding plot, disabled:
        #   a, b, c = e.predict(train_embd_loader, load_info=True, dim=args.feat_dim)
        #   plot_embedding(reduce_dimensionality(a), b, c, {})
        if args.val_data is not None:
            # NOTE(review): training passes True for both 'center' and 'softmax',
            # but validation only for 'center'; confirm this asymmetry is intended.
            e.validate(val_loader, loss_type == 'center')
        # NOTE(review): best_acc1 is not updated in this loop and is_best is always
        # False, so every checkpoint carries the value present at start/resume.
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
            'loss_type': args.loss_type,
            'num_classes': args.num_classes,
            'feat_dim': args.feat_dim,
            'centers': criterion.state_dict() if loss_type == 'center' else None,
        }, False, "%s%s_%s_%04d.tar" % (args.checkpoint_prefix, args.loss_type, args.arch, epoch))