def report_correl_all()

in neural/linear/stats.py [0:0]


def report_correl_all(Y_trues, Y_preds, U_trues, path):
    """Plot example responses and across-subject correlation scores, then save the figure.

    Correlation scores are computed per subject with ``R_score`` and then
    summarised across subjects as mean +/- standard deviation.

    Parameters
    ----------
    Y_trues : list of ndarray (n_epochs, n_times, n_channels_y)
        True responses, one array per subject.
    Y_preds : list of ndarray (n_epochs, n_times, n_channels_y)
        Predicted responses, one array per subject.
    U_trues : list of ndarray (n_epochs, n_times, n_channels_u)
        Stimuli, one array per subject. Only the first subject is plotted;
        the "all features" panel plots feature columns 0, 1 and 2, so at
        least 3 stimulus channels are expected.
    path : str or path-like
        Destination passed to ``Figure.savefig``.

    Raises
    ------
    ValueError
        If any of the input lists is empty (there is no subject to report).
    """
    # zip truncates to the shortest list; an empty result means nothing to plot.
    subjects = list(zip(Y_trues, Y_preds, U_trues))
    if not subjects:
        raise ValueError("report_correl_all needs at least one subject")

    rs_dynamic_epochs = list()
    rs_dynamic_evoked = list()
    rs_scalar = list()

    for Y_true, Y_pred, _ in subjects:
        # correlation along time, with the epoch axis reduced by R_score
        r_dynamic_epochs = R_score(Y_true, Y_pred, avg_out="epochs")

        # same metric on the evoked (epoch-averaged) responses; keepdims
        # preserves a singleton epoch axis so R_score sees the same rank
        r_dynamic_evoked = R_score(
            Y_true.mean(0, keepdims=True), Y_pred.mean(0, keepdims=True), avg_out="epochs")

        # single scalar summary per subject
        r_scalar = R_score(Y_true, Y_pred, avg_out="times").mean()

        rs_dynamic_epochs.append(r_dynamic_epochs)
        rs_dynamic_evoked.append(r_dynamic_evoked)
        rs_scalar.append(r_scalar)

    # aggregate across subjects
    r_dynamic_epochs_mean = np.mean(rs_dynamic_epochs, axis=0)
    r_dynamic_epochs_std = np.std(rs_dynamic_epochs, axis=0)

    r_dynamic_evoked_mean = np.mean(rs_dynamic_evoked, axis=0)
    r_dynamic_evoked_std = np.std(rs_dynamic_evoked, axis=0)

    r_scalar_mean = np.mean(rs_scalar)
    r_scalar_std = np.std(rs_scalar)

    fig, axes = plt.subplots(3, 3, figsize=(15, 5))

    # Mean response for a subject
    axes[0, 0].plot(Y_preds[0].mean(0))
    axes[0, 0].set_title("A Predicted Response (Evoked)")

    axes[1, 0].plot(Y_trues[0].mean(0))
    axes[1, 0].set_title("A True Response (Evoked)")

    axes[2, 0].plot(U_trues[0].mean(0)[:, 0], label="word presence")
    axes[2, 0].set_title("Stimulus (only onset shown here)")
    axes[2, 0].legend()

    # Response to one stimulus for a subject
    axes[0, 1].plot(Y_preds[0][0])
    axes[0, 1].set_title("A Predicted Response (Epoch 0)")

    axes[1, 1].plot(Y_trues[0][0])
    axes[1, 1].set_title("A True Response (Epoch 0)")

    axes[2, 1].plot(U_trues[0][0][:, 0], label="word presence")
    axes[2, 1].plot(U_trues[0][0][:, 1], label="word length")
    axes[2, 1].plot(U_trues[0][0][:, 2], label="word frequency")
    axes[2, 1].set_title("A Stimulus (all features)")
    axes[2, 1].legend(loc="upper right")

    # Dynamic Correlation score: mean +/- std across subjects
    axes[0, 2].plot(r_dynamic_epochs_mean, label="epoch-wise correlation", color="#B03A2E")
    axes[0, 2].fill_between(
        range(r_dynamic_epochs_mean.size),
        r_dynamic_epochs_mean - r_dynamic_epochs_std,
        r_dynamic_epochs_mean + r_dynamic_epochs_std,
        color="#F1948A",
        alpha=0.5)

    axes[0, 2].plot(r_dynamic_evoked_mean, label="evoked-wise correlation", color="#2874A6")
    axes[0, 2].fill_between(
        range(r_dynamic_evoked_mean.size),
        r_dynamic_evoked_mean - r_dynamic_evoked_std,
        r_dynamic_evoked_mean + r_dynamic_evoked_std,
        color="#85C1E9",
        alpha=0.5)

    axes[0, 2].legend()
    axes[0, 2].set_title("Correlation along time")

    # Scalar Correlation score (single bar in the middle slot, with std error bar)
    axes[1, 2].bar([0, 1, 2], [0, r_scalar_mean, 0], yerr=[0, r_scalar_std, 0])
    axes[1, 2].set_title("Correlation Score")

    # the bottom-right panel is unused
    axes[2, 2].axis("off")

    fig.tight_layout()
    fig.savefig(path)
    # close this specific figure (not just the "current" one) to free memory
    plt.close(fig)