def construct_tensor_dataset()

in svhn_based_exp/utils.py [0:0]


def construct_tensor_dataset(dataloader, nets, repres, models, batch_size, device="cuda"):
    """Run a dataloader through a set of networks and package the results.

    For every batch yielded by ``dataloader``, ``get_output_from_model`` is
    called once per ``(net, repre, model)`` triple. The per-triple outputs are
    concatenated along dim 1 and stored, together with the batch targets, in
    the order the dataloader yields them.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches that also exposes
            a ``.dataset`` attribute; ``len(dataloader.dataset)`` is the
            maximum number of rows the result can hold.
        nets, repres, models: Parallel sequences; only as many triples as the
            shortest of them provides are used (``zip`` semantics).
        batch_size: Kept for backward compatibility only. Row offsets are now
            taken from the actual size of each yielded batch, so a value that
            disagrees with the dataloader's real batch size no longer
            misplaces rows.
        device: Device that ``inputs`` and ``targets`` are moved to before
            being passed to ``get_output_from_model``.

    Returns:
        ``data_utils.TensorDataset(features, labels)``:
        - ``features`` has shape ``(num_rows, feature_dim)`` and the default
          floating dtype.
        - ``labels`` has shape ``(num_rows,)`` and dtype ``torch.long``.
        - ``num_rows`` is the number of samples the dataloader actually
          produced, which may be fewer than ``len(dataloader.dataset)``, e.g.
          with ``drop_last=True`` or a subset sampler.

    Raises:
        ValueError: If the dataloader yields no batches, or yields more
            samples than ``len(dataloader.dataset)``.

    Note:
        Every net that is used is switched to eval mode and left that way.
    """
    capacity = len(dataloader.dataset)
    # Build the triples once instead of re-zipping on every batch.
    triples = list(zip(nets, repres, models))
    for net, _repre, _model in triples:
        net.eval()

    output_tensor = None
    label_tensor = None
    filled = 0  # number of rows written so far
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = torch.cat(
                [get_output_from_model(inputs, net, repre, model, device)
                 for net, repre, model in triples],
                dim=1,
            )
            if output_tensor is None:
                # Feature width is only known after the first forward pass.
                output_tensor = torch.zeros([capacity, outputs.size(-1)])
                label_tensor = torch.zeros([capacity], dtype=torch.long)
            # Use the real batch length: the last batch may be short, and the
            # loader's batch size need not match the ``batch_size`` argument.
            n_rows = outputs.size(0)
            if filled + n_rows > capacity:
                raise ValueError(
                    "dataloader yielded more samples than len(dataloader.dataset)"
                )
            output_tensor[filled:filled + n_rows] = outputs
            label_tensor[filled:filled + n_rows] = targets
            filled += n_rows

    if output_tensor is None:
        raise ValueError("dataloader yielded no batches; cannot build a dataset")

    # Trim unfilled tail rows so they are not returned as zero-feature,
    # label-0 samples.
    return data_utils.TensorDataset(output_tensor[:filled], label_tensor[:filled])