in eval_retrieval_feature_extract.py
import os

import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

# get_loader, load_pretrained, resnet50, and logger are project-local helpers
# defined or imported elsewhere in this file.


def main(args):
    data_loader, total_num = get_loader(args)
    logger.info('using data: {}'.format(len(data_loader)))

    # Feature extractor: ResNet-50 with a 128-d output head (mlp=True).
    model_config_dict = dict(
        num_classes=128,
        mlp=True,
    )
    model = resnet50(**model_config_dict).cuda()
    model = DistributedDataParallel(model, device_ids=[args.local_rank])
    load_pretrained(args, model)
    model.eval()
    logger.info('model init done')

    all_feat = []
    all_feat_cls = np.zeros([len(data_loader)], dtype=np.int32)
    with torch.no_grad():
        for idx, (data, cls) in enumerate(data_loader):
            logger.info('{}/{}'.format(idx, len(data_loader)))
            # data: B * S * C * H * W
            data = data.cuda()
            feat = model(data, layer=args.layer, tsn_mode=True).view(-1)
            all_feat.append(feat.data.cpu().numpy())
            all_feat_cls[idx] = cls.item()

    # Save this rank's feature and label shards; rank 0 also records the total video count.
    all_feat = np.stack(all_feat, axis=0)
    np.save(os.path.join(args.output_dir, 'feature_{}_{}.npy'.format(args.datamode, args.local_rank)), all_feat)
    np.save(os.path.join(args.output_dir, 'feature_{}_cls_{}.npy'.format(args.datamode, args.local_rank)), all_feat_cls)
    if dist.get_rank() == 0:
        np.save(os.path.join(args.output_dir, 'vid_num_{}.npy'.format(args.datamode)), np.array([total_num]))
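
Each rank writes its own feature and label shards, so a downstream retrieval evaluation would need to gather them. A minimal sketch of that step, assuming the same output_dir/datamode and the number of ranks used at launch (load_all_features is a hypothetical helper, not part of the repository; sample ordering across shards depends on the loader's sampler and is not restored here):

import os
import numpy as np

def load_all_features(output_dir, datamode, world_size):
    """Concatenate the per-rank feature/label shards written by main()."""
    feats, labels = [], []
    for rank in range(world_size):
        feats.append(np.load(os.path.join(output_dir, 'feature_{}_{}.npy'.format(datamode, rank))))
        labels.append(np.load(os.path.join(output_dir, 'feature_{}_cls_{}.npy'.format(datamode, rank))))
    # Total number of videos recorded by rank 0; the concatenated shards may
    # exceed this if the distributed sampler padded the dataset.
    total_num = int(np.load(os.path.join(output_dir, 'vid_num_{}.npy'.format(datamode)))[0])
    return np.concatenate(feats, axis=0), np.concatenate(labels, axis=0), total_num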