def get_weights_at_time_t()

in sample_info/modules/ntk.py [0:0]


def get_weights_at_time_t(t, eta, init_params, ntk, init_preds, Y, jacobians=None, continuous=False,
                          ntk_inv=None, large_model_regime=False, model=None, dataset=None, batch_size=256):
    """ Computes the weights at time t of the linearized (NTK) model trained with gradient descent.

    The result is w_t = w_0 + J^T r with r = -ntk^{-1} (I - exp_matrix) (f_0(X) - Y), where J is the Jacobian of
    the predictions at initialization (rows indexed by (example, output) pairs) and exp_matrix is produced by
    `compute_exp_matrix`.

    Since storing the full Jacobian can be impossible for large networks, the code can work in two regimes
    (given by large_model_regime):
        1. [small model regime] the full Jacobian should be passed using the `jacobians` argument.
        2. [large model regime] `model` and `dataset` should be specified, so we can go over the dataset batch
        by batch and accumulate J^T r without ever materializing J.

    :param t: time/iteration. If t=None, then final weights are going to be returned (assuming ntk is invertible).
    :param eta: the learning rate. NOTE: if network's loss is a mean over examples (instead of sum) then
                eta=learning_rate / n_samples.
    :param init_params: a dictionary containing parameters at initialization.
    :param ntk: the neural tangent kernel.  # (n_samples * n_outputs, n_samples * n_outputs)
    :param init_preds: predictions at initialization  # (n_samples, n_outputs)
    :param Y: labels  # (n_samples, n_outputs)
    :param jacobians: dictionary of Jacobians at initialization (small model regime only). Each value is reshaped
                      to (n_samples * n_outputs, n_dim).
    :param continuous: True => continuous GD, False => discrete GD
    :param ntk_inv: Inverse of the ntk matrix. This argument is optional; if None it is computed with torch.inverse.
    :param large_model_regime: which regime is it

    # the following parameters are used in the large model regime only.
    :param model: the model. Its train/eval mode is restored after the Jacobian pass.
    :param dataset: the training dataset for which ntk was computed for.
    :param batch_size: the batch_size argument to pass to the DataLoader.

    :return: a defaultdict mapping parameter names to weight tensors. Every parameter listed in `init_params`
             is returned with the shape of its initial value.
    :raises KeyError: if a parameter in `init_params` received no Jacobian contribution.
    """

    # check that the needed arguments are not None
    if large_model_regime:
        assert (model is not None) and (dataset is not None)
    else:
        assert (jacobians is not None)

    exp_matrix = compute_exp_matrix(t=t, eta=eta, ntk=ntk, continuous=continuous)

    # compute the ntk inverse and multiply it with (I - exp_matrix)(f(X) - Y)
    if ntk_inv is None:
        ntk_inv = torch.inverse(ntk)
    init_preds = init_preds.reshape((-1, 1))  # (n_samples * n_outputs, 1)
    Y = Y.reshape((-1, 1))  # (n_samples * n_outputs, 1)
    # Build the identity in the NTK's own dtype instead of always float32.
    identity_matrix = torch.eye(ntk.shape[0], dtype=ntk.dtype, device=ntk.device)
    rhs_vector = -torch.mm(ntk_inv, torch.mm(identity_matrix - exp_matrix, init_preds - Y))

    # Now we need to compute jacobians^T * rhs_vector. This corresponds to the sum of all output gradients weighted
    # with rhs_vector coefficients. Therefore, we can go over (examples, output) pairs and sum their gradients.
    out = defaultdict(lambda: None)

    # if the Jacobian is already computed we can just use it
    if not large_model_regime:
        for k, v in jacobians.items():
            v = v.reshape((v.shape[0], -1))  # (n_samples * n_outputs, n_dim)
            out[k] = torch.mm(v.T, rhs_vector)[:, 0]   # (n_dim,)
    else:  # we need to go over examples and compute the Jacobian
        # Remember the caller's mode so the model is left as we found it, even on error.
        was_training = model.training
        model.eval()
        loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False)

        # We want to compute \frac{\partial preds}{\partial W} and take a linear combination of its rows
        # with coefficients given by rhs_vector. To not go over all examples/outputs one by one, we can
        # compute \frac{\partial <preds, coefficients>}{\partial W} for a whole batch at once.
        example_output_index = 0
        try:
            for inputs_batch, labels_batch in tqdm(loader):
                if isinstance(inputs_batch, torch.Tensor):
                    inputs_batch = [inputs_batch]
                if not isinstance(labels_batch, list):
                    labels_batch = [labels_batch]

                with torch.set_grad_enabled(True):
                    outputs = model.forward(inputs=inputs_batch, labels=labels_batch, loader=loader)
                    preds = outputs['pred']
                    preds = preds.reshape((-1, 1))  # (n_examples * n_outputs, 1)
                    coefficients = rhs_vector[example_output_index:example_output_index + preds.shape[0]].to(
                        preds.device)
                    loss = torch.sum(preds * coefficients)

                cur_jacobians = torch.autograd.grad(loss, model.parameters())
                for (k, _), v in zip(model.named_parameters(), cur_jacobians):
                    v = v.detach().to(ntk.device).flatten()
                    if out[k] is None:
                        out[k] = v
                    else:
                        out[k] += v

                example_output_index += preds.shape[0]
        finally:
            model.train(was_training)

        # The dataset must produce exactly one prediction per entry of rhs_vector.
        assert example_output_index == rhs_vector.shape[0]

    # add the initialized values
    for k, v in init_params.items():
        delta = out.get(k)  # .get avoids inserting a None entry into the defaultdict
        if delta is None:
            # Without this check the failure would be an opaque "NoneType += Tensor" TypeError.
            raise KeyError("No Jacobian contribution was computed for parameter '{}'.".format(k))
        delta += v.flatten()
        out[k] = delta.reshape(v.shape)

    return out