def report_correl()

in neural/visuals.py [0:0]


def _mark_start(ax, start):
    """Draw a dashed vertical line at x=``start`` and tag it with "init"."""
    ax.axvline(x=start, ls="--")
    ax.text(x=start, y=0, s="init")


def _finish_time_panel(ax, title, start):
    """Apply the legend, title, tick density, grid and start marker to ``ax``."""
    ax.legend()
    ax.set_title(title)
    ax.locator_params(axis="x", nbins=20)
    ax.locator_params(axis="y", nbins=10)
    ax.grid()
    _mark_start(ax, start)


def _plot_score_distribution(ax, scores, evoked_scores, title, evoked_label):
    """Box-plot ``scores`` per column and overlay the trial-mean score line.

    ``scores`` is indexed as (rows, columns); one box is drawn per column.
    ``evoked_scores`` is averaged over its first axis before being plotted
    against the column index.
    """
    n_rows, n_cols = scores.shape[:2]
    df = pd.DataFrame({
        # Transposing first makes the flattened values column-major, which
        # lines them up with the repeated column labels below.
        "scores": scores.T.flatten(),
        "pca_labels": np.repeat(np.arange(n_cols), n_rows),
    })
    sns.boxplot(x="pca_labels", y="scores", data=df, ax=ax)
    ax.set_title(title)
    ax.plot(np.arange(n_cols), evoked_scores.mean(0), label=evoked_label)
    ax.legend(bbox_to_anchor=(0, 1), loc="upper left", ncol=1)


def report_correl(Y_true, Y_pred, path, start, ref=None):
    """Save a 2x4 figure comparing predicted and true responses.

    The panels show the trial-averaged and first-epoch responses, the
    correlation and relative-MSE scores along time, and the per-channel
    distribution of the time-averaged scores.

    Parameters
    ----------
    Y_true : ndarray (n_epochs, n_times, n_channels_y)
        Ground-truth responses.
    Y_pred : ndarray (n_epochs, n_times, n_channels_y)
        Predicted responses.
    path : str or path-like
        Destination handed to ``Figure.savefig``.
    start : int
        Time index marked "init" on the time panels; time-averaged scores
        are computed on samples ``start:`` only.
    ref : ndarray, optional
        Reference prediction scored against ``Y_true`` in place of the
        trial-mean baseline. When given, the correlation panel shows the
        ratio of prediction to reference correlation instead of both
        curves. Presumably the same shape as ``Y_pred`` -- TODO confirm.
    """
    # Trial mean of the ground truth, used as the baseline "prediction".
    evoked_true = Y_true.mean(0, keepdims=True)

    # Scores along time (averaged over epochs by R_score_v2).
    r_dynamic_epochs = R_score_v2(Y_true, Y_pred, avg_out="epochs")
    if ref is not None:
        r_average_epochs = R_score_v2(Y_true, ref, avg_out="epochs")
        ratio = (r_dynamic_epochs / r_average_epochs).mean(-1)
    else:
        r_average_epochs = R_score_v2(Y_true, evoked_true, avg_out="epochs")

    mse_dynamic_epochs = R_score_v2(
        Y_true, Y_pred, score="relativemse", avg_out="epochs")
    mse_average_epochs = R_score_v2(
        Y_true, evoked_true, score="relativemse", avg_out="epochs")

    # Scores averaged over time, restricted to samples from `start` onwards.
    true_tail = Y_true[:, start:, :]
    pred_tail = Y_pred[:, start:, :]
    true_tail_evoked = true_tail.mean(0, keepdims=True)
    pred_tail_evoked = pred_tail.mean(0, keepdims=True)

    r_average_times = R_score_v2(true_tail, pred_tail, avg_out="times")
    r_average_times_evoked = R_score_v2(
        true_tail_evoked, pred_tail_evoked, avg_out="times")
    mse_average_times = R_score_v2(
        true_tail, pred_tail, score="relativemse", avg_out="times")
    mse_average_times_evoked = R_score_v2(
        true_tail_evoked, pred_tail_evoked,
        score="relativemse", avg_out="times")

    fig, axes = plt.subplots(2, 4, figsize=(15, 5))
    # Always release the figure, even if a plotting or saving call fails;
    # otherwise repeated calls accumulate open figures.
    try:
        # Mean response
        axes[0, 0].plot(Y_pred.mean(0))
        axes[0, 0].set_title("Predicted Response (Evoked)")
        _mark_start(axes[0, 0], start)

        axes[1, 0].plot(Y_true.mean(0))
        axes[1, 0].set_title("True Response (Evoked)")

        # Response to one stimulus
        axes[0, 1].plot(Y_pred[0])
        axes[0, 1].set_title("Predicted Response (Epoch 0)")
        _mark_start(axes[0, 1], start)

        axes[1, 1].plot(Y_true[0])
        axes[1, 1].set_title("True Response (Epoch 0)")

        # Dynamic correlation score
        corr_ax = axes[0, 2]
        if ref is not None:
            # Labelled so that the legend() call below has an entry to show.
            corr_ax.plot(ratio, label="ratio of correlations (pred / ref)")
        else:
            corr_ax.plot(
                r_dynamic_epochs.mean(-1), label="epoch-wise correlation")
            corr_ax.plot(
                r_average_epochs.mean(-1), label="baseline correlation")
            corr_ax.set_ylim(0, 1)
        _finish_time_panel(corr_ax, "Correlation along time", start)

        # Distributional correlation score (per epoch + trial mean)
        _plot_score_distribution(
            axes[1, 2],
            r_average_times,
            r_average_times_evoked,
            title="Overall Correlation",
            evoked_label="corr of the trial-mean",
        )

        # Dynamic MSE score
        mse_ax = axes[0, 3]
        mse_ax.plot(mse_dynamic_epochs.mean(-1), label="epoch-wise mse")
        mse_ax.plot(mse_average_epochs.mean(-1), label="baseline mse")
        mse_ax.set_ylim(0, 1)
        _finish_time_panel(mse_ax, "Relative MSE along time", start)

        # Distributional MSE score (per epoch + trial mean)
        _plot_score_distribution(
            axes[1, 3],
            mse_average_times,
            mse_average_times_evoked,
            title="Overall Relative MSE",
            evoked_label="rel. MSE of the trial-mean",
        )

        fig.tight_layout()
        fig.savefig(path)
    finally:
        plt.close(fig)