def prepare_needed_items()

in sample_info/modules/ntk.py [0:0]


def _ntk_eigenvalues(ntk):
    """Return the real eigenvalues of the NTK matrix ``ntk`` as a 1-D tensor.

    ``torch.eig`` was deprecated and later removed from PyTorch, so prefer
    ``torch.linalg.eigvalsh`` when available. ``eigvalsh`` only reads the lower
    triangle, which is exact for a symmetric matrix; the NTK built from Jacobians
    is expected to be symmetric (J J^T) -- NOTE(review): confirm against
    ``compute_ntk``. On PyTorch releases without ``torch.linalg.eigvalsh`` this
    falls back to ``torch.eig`` and keeps the real parts (column 0), exactly as
    the previous code did.
    """
    linalg = getattr(torch, 'linalg', None)
    if linalg is not None and hasattr(linalg, 'eigvalsh'):
        return linalg.eigvalsh(ntk)
    eigenvalues, _ = torch.eig(ntk)
    return eigenvalues[:, 0]


def prepare_needed_items(model, train_data, test_data=None, projection='none', cpu=False,
                         batch_size=256, **kwargs):
    """Collect the quantities needed for NTK-based analysis of ``model``.

    Args:
        model: the network; its current parameters are treated as the initialization.
        train_data: training dataset; also iterated as ``(x, y)`` pairs to build labels.
        test_data: optional test dataset with the same layout. When ``None``, every
            test-side entry of the returned dict is ``None``.
        projection: forwarded to ``JacobianEstimator``.
        cpu: forwarded to the Jacobian/prediction helpers; when True the parameter
            snapshot is also placed on CPU.
        batch_size: batch size passed to ``utils.apply_on_dataset``.
        **kwargs: extra keyword arguments forwarded to ``JacobianEstimator``.

    Returns:
        dict with keys 'jacobian_estimator', 'train_jacobians', 'test_jacobians',
        'train_init_preds', 'test_init_preds', 'init_params', 'ntk',
        'test_train_ntk', 'train_Y', 'test_Y'.
    """
    jacobian_estimator = JacobianEstimator(projection=projection, **kwargs)
    train_jacobians = jacobian_estimator.compute_jacobian(model=model, dataset=train_data,
                                                          output_key='pred', cpu=cpu)
    test_jacobians = None
    if test_data is not None:
        test_jacobians = jacobian_estimator.compute_jacobian(model=model, dataset=test_data,
                                                             output_key='pred', cpu=cpu)

    train_init_preds = utils.apply_on_dataset(model=model, dataset=train_data, cpu=cpu,
                                              batch_size=batch_size)['pred']
    test_init_preds = None
    if test_data is not None:
        test_init_preds = utils.apply_on_dataset(model=model, dataset=test_data, cpu=cpu,
                                                 batch_size=batch_size)['pred']

    # Take a detached *copy* of the current parameters. A plain
    # dict(model.named_parameters()) aliases the live Parameters (and .to('cpu') is a
    # no-op for tensors already on CPU), so later in-place updates to the model would
    # silently change these "initial" values.
    init_params = {
        name: param.detach().to('cpu' if cpu else param.device, copy=True)
        for name, param in model.named_parameters()
    }

    ntk = compute_ntk(jacobians=train_jacobians)
    lamb = _ntk_eigenvalues(ntk)
    min_eig = torch.min(lamb).item()
    max_eig = torch.max(lamb).item()
    logging.info(f'Min eigenvalue of NTK: {min_eig:.3f}\t'
                 f'Max eigenvalue of NTK: {max_eig:.3f}')
    if min_eig < 0:
        logging.warning('The lowest eigenvalue of NTK is negative, consider adding at least small weight decay.')

    test_train_ntk = None
    if test_data is not None:
        test_train_ntk = compute_test_train_ntk(train_jacobians=train_jacobians,
                                                test_jacobians=test_jacobians)

    def extract_labels(data):
        # One flattened label vector per example, stacked along dim 0 and placed on
        # the same device as the NTK so downstream linear algebra needs no transfers.
        ys = [utils.to_tensor(y, device=ntk.device).view((-1,)) for x, y in data]
        return torch.stack(ys).float()

    train_Y = extract_labels(train_data)
    test_Y = None
    if test_data is not None:
        test_Y = extract_labels(test_data)

    return {
        'jacobian_estimator': jacobian_estimator,
        'train_jacobians': train_jacobians,
        'test_jacobians': test_jacobians,
        'train_init_preds': train_init_preds,
        'test_init_preds': test_init_preds,
        'init_params': init_params,
        'ntk': ntk,
        'test_train_ntk': test_train_ntk,
        'train_Y': train_Y,
        'test_Y': test_Y
    }