in main.py
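# Imports assumed by this function. `models`, `clustering`, `Logger`, and
# `UnifLabelSampler` are project-local modules/helpers (names follow the
# DeepCluster reference layout); `compute_features` and `train` are defined
# elsewhere in main.py.
import os
import time

import numpy as np
from sklearn.metrics import normalized_mutual_info_score
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import clustering
import models
from util import Logger, UnifLabelSampler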
def main(args):
    # fix random seeds for reproducibility
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)

    # CNN
    if args.verbose:
        print('Architecture: {}'.format(args.arch))
    model = models.__dict__[args.arch](sobel=args.sobel)
    # feature dimension: input size of the top layer we are about to discard
    fd = int(model.top_layer.weight.size()[1])
    model.top_layer = None
    model.features = torch.nn.DataParallel(model.features)
    model.cuda()
    cudnn.benchmark = True
    # create optimizer; args.wd is the base-10 exponent of the weight decay
    optimizer = torch.optim.SGD(
        filter(lambda x: x.requires_grad, model.parameters()),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=10**args.wd,
    )

    # define loss function
    criterion = nn.CrossEntropyLoss().cuda()
    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # remove top_layer parameters from checkpoint; iterate over a
            # copy of the keys, since deleting from a dict while iterating
            # over it raises a RuntimeError
            for key in list(checkpoint['state_dict']):
                if 'top_layer' in key:
                    del checkpoint['state_dict'][key]
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    # create checkpoint directory
    exp_check = os.path.join(args.exp, 'checkpoints')
    if not os.path.isdir(exp_check):
        os.makedirs(exp_check)

    # create cluster assignments log
    cluster_log = Logger(os.path.join(args.exp, 'clusters'))
    # preprocessing of data (standard ImageNet statistics)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    tra = [transforms.Resize(256),
           transforms.CenterCrop(224),
           transforms.ToTensor(),
           normalize]

    # load the data
    end = time.time()
    dataset = datasets.ImageFolder(args.data, transform=transforms.Compose(tra))
    if args.verbose:
        print('Load dataset: {0:.2f} s'.format(time.time() - end))

    # no shuffling: feature i must stay aligned with dataset.imgs[i] so that
    # cluster assignments can be mapped back to images
    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=args.batch,
                                             num_workers=args.workers,
                                             pin_memory=True)
    # clustering algorithm to use
    deepcluster = clustering.__dict__[args.clustering](args.nmb_cluster)

    # training convnet with DeepCluster
    for epoch in range(args.start_epoch, args.epochs):
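        # One DeepCluster epoch alternates two phases: (1) run the current
        # network over the whole dataset and cluster the resulting features
        # into pseudo-labels, then (2) train the network to predict those
        # pseudo-labels through a freshly initialized top layer.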
        # epoch timer (reported in the log below)
        end = time.time()

        # remove head: drop the top layer and the classifier's last module
        # (a ReLU in the reference models) so features come from the final
        # linear layer
        model.top_layer = None
        model.classifier = nn.Sequential(*list(model.classifier.children())[:-1])
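        # compute_features (defined elsewhere in main.py) runs a full forward
        # pass over the dataloader and collects one feature vector per image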
        # get the features for the whole dataset
        features = compute_features(dataloader, model, len(dataset))
        # cluster the features
        if args.verbose:
            print('Cluster the features')
        clustering_loss = deepcluster.cluster(features, verbose=args.verbose)

        # assign pseudo-labels
        if args.verbose:
            print('Assign pseudo labels')
        train_dataset = clustering.cluster_assign(deepcluster.images_lists,
                                                  dataset.imgs)
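        # Sampling uniformly over pseudo-labels (rather than over images)
        # keeps large clusters from dominating the training signal, which is
        # the motivation given in the DeepCluster paper.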
        # uniformly sample per target
        sampler = UnifLabelSampler(int(args.reassign * len(train_dataset)),
                                   deepcluster.images_lists)

        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=args.batch,
            num_workers=args.workers,
            sampler=sampler,
            pin_memory=True,
        )
        # set last fully connected layer: restore the ReLU removed above,
        # then add a fresh top layer sized to the current number of clusters
        mlp = list(model.classifier.children())
        mlp.append(nn.ReLU(inplace=True).cuda())
        model.classifier = nn.Sequential(*mlp)
        model.top_layer = nn.Linear(fd, len(deepcluster.images_lists))
        model.top_layer.weight.data.normal_(0, 0.01)
        model.top_layer.bias.data.zero_()
        model.top_layer.cuda()
        # train network with clusters as pseudo-labels; `end`, set at the
        # top of the loop, times the whole epoch for the log below
        loss = train(train_dataloader, model, criterion, optimizer, epoch)
        # print log
        if args.verbose:
            print('###### Epoch [{0}] ######\n'
                  'Time: {1:.3f} s\n'
                  'Clustering loss: {2:.3f}\n'
                  'ConvNet loss: {3:.3f}'
                  .format(epoch, time.time() - end, clustering_loss, loss))
            try:
                nmi = normalized_mutual_info_score(
                    clustering.arrange_clustering(deepcluster.images_lists),
                    clustering.arrange_clustering(cluster_log.data[-1])
                )
                print('NMI against previous assignment: {0:.3f}'.format(nmi))
            except IndexError:
                # no previous assignment to compare against on the first epoch
                pass
            print('#######################\n')
        # save running checkpoint
        torch.save({'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()},
                   os.path.join(args.exp, 'checkpoint.pth.tar'))

        # save cluster assignments
        cluster_log.log(deepcluster.images_lists)
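

# A minimal sketch of the argument parser this function expects. Flag names
# are inferred from the attribute accesses in main(); the defaults are
# illustrative assumptions, not the reference repository's values.
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description='DeepCluster training')
    parser.add_argument('--data', type=str, help='path to the image dataset')
    parser.add_argument('--arch', type=str, default='alexnet')
    parser.add_argument('--sobel', action='store_true')
    parser.add_argument('--clustering', type=str, default='Kmeans')
    parser.add_argument('--nmb_cluster', type=int, default=10000)
    parser.add_argument('--lr', type=float, default=0.05)
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--wd', type=float, default=-5,
                        help='log10 of the weight decay')
    parser.add_argument('--reassign', type=float, default=1.0)
    parser.add_argument('--workers', type=int, default=4)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--start_epoch', type=int, default=0)
    parser.add_argument('--batch', type=int, default=256)
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--exp', type=str, default='exp')
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--seed', type=int, default=31)
    return parser.parse_args()


if __name__ == '__main__':
    main(parse_args())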