in main_deepclusterv2.py
def init_memory(dataloader, model):
    # `args` and `logger` are module-level globals in main_deepclusterv2.py.
    # The bank is sized assuming every batch holds args.batch_size samples,
    # i.e. the loader drops its last incomplete batch.
    size_memory_per_process = len(dataloader) * args.batch_size
    local_memory_index = torch.zeros(size_memory_per_process).long().cuda()
    local_memory_embeddings = torch.zeros(
        len(args.crops_for_assign), size_memory_per_process, args.feat_dim
    ).cuda()
    start_idx = 0
    with torch.no_grad():
        logger.info('Start initializing the memory banks')
        for index, inputs in dataloader:
            nmb_unique_idx = inputs[0].size(0)
            index = index.cuda(non_blocking=True)

            # get embeddings for every crop used for cluster assignment
            outputs = []
            for crop_idx in args.crops_for_assign:
                inp = inputs[crop_idx].cuda(non_blocking=True)
                outputs.append(model(inp)[0])

            # fill the memory bank: dataset indices plus one embedding bank
            # per assignment crop
            local_memory_index[start_idx : start_idx + nmb_unique_idx] = index
            for mb_idx, embeddings in enumerate(outputs):
                local_memory_embeddings[mb_idx][
                    start_idx : start_idx + nmb_unique_idx
                ] = embeddings
            start_idx += nmb_unique_idx
    logger.info('Initialization of the memory banks done.')
    return local_memory_index, local_memory_embeddings
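
For context, a minimal sketch of how this initializer might be wired up before the training loop, assuming a multi-crop dataset that yields (index, list_of_crops) pairs. The names train_dataset, sampler, and train_loader, and the args.workers field, are illustrative and not taken from this excerpt.

# Sketch of a call site: the loader must drop its last incomplete batch so
# that len(dataloader) * args.batch_size matches the bank size above.
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    sampler=sampler,
    batch_size=args.batch_size,
    num_workers=args.workers,
    pin_memory=True,
    drop_last=True,
)

# build the banks once before the first epoch; they can later be refreshed
# with up-to-date embeddings as training progresses
local_memory_index, local_memory_embeddings = init_memory(train_loader, model)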