in eval_linear.py [0:0]
def _build_loaders(args):
    """Build the ImageFolder train/val datasets and their DataLoaders.

    Images are read from ``<args.data>/train`` and ``<args.data>/val``.

    Returns ``(train_dataset, train_loader, val_loader)``. The train dataset is
    returned so the caller can size the classifier from ``train_dataset.classes``.
    """
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    # per-channel mean/std (the values commonly used for ImageNet-style data)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    if args.tencrops:
        # each sample becomes a stacked tensor of 10 normalized crops, so a
        # validation batch gains an extra crop dimension after the batch dim
        transformations_val = [
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [normalize(transforms.ToTensor()(crop)) for crop in crops])),
        ]
    else:
        transformations_val = [transforms.Resize(256),
                               transforms.CenterCrop(224),
                               transforms.ToTensor(),
                               normalize]
    transformations_train = [transforms.Resize(256),
                             transforms.CenterCrop(256),
                             transforms.RandomCrop(224),
                             transforms.RandomHorizontalFlip(),
                             transforms.ToTensor(),
                             normalize]

    train_dataset = datasets.ImageFolder(
        traindir,
        transform=transforms.Compose(transformations_train)
    )
    val_dataset = datasets.ImageFolder(
        valdir,
        transform=transforms.Compose(transformations_val)
    )

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)
    # validation uses half the training batch size; clamp to 1 so that
    # batch_size == 1 does not produce an invalid batch size of 0
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=max(1, args.batch_size // 2),
                                             shuffle=False,
                                             num_workers=args.workers,
                                             pin_memory=True)
    return train_dataset, train_loader, val_loader


def main():
    """Train a linear (logistic-regression) classifier on a frozen model.

    Parses the command line into the module-level ``args``. The ``global``
    declaration lets other functions in this module read it. The function
    then:

    - loads the model given by ``args.model`` and freezes ``model.features``;
    - trains a ``RegLog`` classifier for ``args.epochs`` epochs on
      ImageFolder data from ``args.data``.

    After every epoch it:

    - logs the validation loss, prec@1 and prec@5 under ``<args.exp>/log``;
    - writes ``<args.exp>/checkpoint.pth.tar``;
    - additionally writes ``<args.exp>/model_best.pth.tar`` when prec@1
      improves.
    """
    global args
    args = parser.parse_args()

    # fix random seeds
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    best_prec1 = 0

    # load model
    model = load_model(args.model)
    model.cuda()
    cudnn.benchmark = True

    # freeze the features layers: only the linear classifier below is trained
    for param in model.features.parameters():
        param.requires_grad = False

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda()

    # data loading
    train_dataset, train_loader, val_loader = _build_loaders(args)

    # logistic regression on top of the frozen model
    reglog = RegLog(args.conv, len(train_dataset.classes)).cuda()
    optimizer = torch.optim.SGD(
        filter(lambda x: x.requires_grad, reglog.parameters()),
        args.lr,
        momentum=args.momentum,
        # NOTE(review): args.weight_decay is used as a base-10 exponent
        # (e.g. -4 -> 1e-4); confirm against the argument parser.
        weight_decay=10**args.weight_decay
    )

    # create logs
    exp_log = os.path.join(args.exp, 'log')
    if not os.path.isdir(exp_log):
        os.makedirs(exp_log)
    loss_log = Logger(os.path.join(exp_log, 'loss_log'))
    prec1_log = Logger(os.path.join(exp_log, 'prec1'))
    prec5_log = Logger(os.path.join(exp_log, 'prec5'))

    for epoch in range(args.epochs):
        # train for one epoch
        train(train_loader, model, reglog, criterion, optimizer, epoch)
        # evaluate on validation set
        prec1, prec5, loss = validate(val_loader, model, reglog, criterion)
        loss_log.log(loss)
        prec1_log.log(prec1)
        prec5_log.log(prec5)

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        checkpoint = {
            'epoch': epoch + 1,
            # NOTE(review): architecture name is hard-coded; confirm it
            # matches the model loaded from args.model
            'arch': 'alexnet',
            'state_dict': model.state_dict(),
            # the backbone is frozen, so the trained classifier weights are
            # what the logged precision depends on; save them as well
            'reglog': reglog.state_dict(),
            'prec5': prec5,
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }
        # always keep checkpoint.pth.tar at the latest epoch, and also
        # snapshot the best-so-far epoch separately
        torch.save(checkpoint, os.path.join(args.exp, 'checkpoint.pth.tar'))
        if is_best:
            torch.save(checkpoint, os.path.join(args.exp, 'model_best.pth.tar'))