def main()

in visu/activ-retrieval.py [0:0]


def main(args):
    """Copy, for every filter of one conv layer, the images that activate it most.

    Runs the frozen model over the ``ImageFolder`` dataset at ``args.data``.
    For each batch it records the per-filter scores returned by ``forward``
    into one array per filter (one entry per dataset image). It then copies
    the highest-scoring images of each filter into
    ``<args.exp>/conv<args.conv>/<filter>/``, named ``<rank>_<original name>``.

    Args:
        args: parsed command-line namespace. Reads ``exp``, ``conv``,
            ``model``, ``data``, ``workers`` and ``count``
            (``count <= 0`` keeps every image, sorted by activation).
    """
    # Single source of truth for the batch size: the slice bookkeeping
    # below depends on it matching the DataLoader's batch size.
    batch_size = 256

    # create repo
    repo = os.path.join(args.exp, 'conv' + str(args.conv))
    if not os.path.isdir(repo):
        os.makedirs(repo)

    # build model: moved to the GPU, parameters frozen, eval mode
    model = load_model(args.model)
    model.cuda()
    for params in model.parameters():
        params.requires_grad = False
    model.eval()

    # load data: resize, center-crop to 224, tensorize, normalize
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    tra = [transforms.Resize(256),
           transforms.CenterCrop(224),
           transforms.ToTensor(),
           normalize]

    # dataset
    dataset = datasets.ImageFolder(args.data, transform=transforms.Compose(tra))
    # The loader must not shuffle: batch i must cover dataset indices
    # [i * batch_size, i * batch_size + len(batch)). That ordering is what
    # lets the argsort below be mapped back onto ``dataset.imgs``.
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=False,
                                             num_workers=args.workers)

    # keys are filters and values are arrays with one activation score per
    # image of the dataset
    n_images = len(dataset)
    layers_activations = {}
    for i, (input_tensor, _) in enumerate(dataloader):
        # NOTE(review): ``volatile`` is the pre-0.4 PyTorch way to disable
        # autograd. Newer releases ignore it and use ``torch.no_grad()``
        # instead. The parameters are frozen above and the input does not
        # require grad, so presumably no graph is recorded either way.
        # Confirm if ``forward`` changes.
        input_var = torch.autograd.Variable(input_tensor.cuda(), volatile=True)
        activations = forward(model, args.conv, input_var)

        if i == 0:
            layers_activations = {filt: np.zeros(n_images) for filt in activations}
        # Take the slice length from the actual batch instead of recomputing
        # it from the batch size. The last, possibly partial, batch then
        # needs no special case.
        s_idx = i * batch_size
        e_idx = s_idx + input_tensor.size(0)
        for filt in activations:
            layers_activations[filt][s_idx:e_idx] = activations[filt].cpu().data.numpy()

        if i % 100 == 0:
            print('{0}/{1}'.format(i, len(dataloader)))

    # save top N images for each filter
    for filt in layers_activations:
        # assumes ``forward`` keys filters by str names, since os.path.join
        # requires them. TODO: confirm.
        repofilter = os.path.join(repo, filt)
        if not os.path.isdir(repofilter):
            os.mkdir(repofilter)
        # dataset indices sorted by decreasing activation
        top = np.argsort(layers_activations[filt])[::-1]
        if args.count > 0:
            top = top[:args.count]

        for pos, img in enumerate(top):
            src, _ = dataset.imgs[img]
            # basename instead of splitting on '/', which mishandles
            # Windows-style path separators
            copyfile(src, os.path.join(repofilter,
                                       "{}_{}".format(pos, os.path.basename(src))))