in models/feat_pool.py [0:0]
def update(self, features, labels):
    """Enqueue one feature per class present in ``labels`` into the per-class FIFO queue.

    For each distinct label in the batch, the feature row at that label's
    first occurrence is appended as the newest entry of ``self.queue[label]``.
    The oldest entry (index 0 along dim 1) is dropped to make room.
    ``self.class_ptr[label]`` is then incremented and clamped to
    ``self.sample_num``. Labels that do not appear in the batch are untouched.

    Args:
        features: per-sample feature rows; row ``i`` is assumed to belong to
            ``labels.reshape(-1)[i]`` (presumably shape (N, D) -- TODO confirm
            against callers).
        labels: integer class indices, used to index ``self.queue`` and
            ``self.class_ptr``. Flattened before use.

    Side effects:
        May replace ``self.queue`` with a copy cast to ``features``' dtype.
        Mutates ``self.queue`` and ``self.class_ptr`` in place.
    """
    labels = labels.reshape(-1)  # reshape (not view) also accepts non-contiguous labels
    if labels.numel() == 0:
        # Nothing to enqueue, and argmax below would raise on an empty reduction dim.
        return

    queue_device = self.queue.device
    if features.device != queue_device:
        # Keep the queue where it lives and bring the batch to it. Target the
        # queue's own device (the one the check compares against), so the
        # torch.cat below never mixes devices.
        features = features.to(queue_device)
    if self.queue.dtype != features.dtype:
        self.queue = self.queue.type_as(features)

    unique_labels = torch.unique(labels)
    # (U, N) match matrix; argmax returns the first True column in each row,
    # i.e. the batch index of each unique label's first occurrence.
    first_idx = (unique_labels.view(-1, 1) == labels.view(1, -1)).int().argmax(dim=1)
    newest = features[first_idx][:, None, :]
    # NOTE(review): features are stored without .detach(); if callers pass
    # tensors that require grad, the queue will hold their autograd graph
    # across updates -- confirm callers detach or run under no_grad.
    # Shift out the oldest slot (index 0) and append the newest at the end.
    self.queue[unique_labels] = torch.cat((self.queue[unique_labels, 1:, :], newest), 1)
    self.class_ptr[unique_labels] = (self.class_ptr[unique_labels] + 1).clamp(max=self.sample_num)