def plot_weights()

in neural/linear/lin_model_template.py [0:0]


    def plot_weights(self, summarize=True, names_u=None):
        """Plot the model's forcing (U) and recurrence (Y) weights.

        Every figure is displayed with ``plt.show()`` and closed afterwards;
        nothing is returned.

        Args:
            summarize: If True, draw a single figure with two panels.
                - Top panel: for each U channel, the squared weights of
                  ``self.weights_u[:, :, channel]`` averaged over axis 0,
                  plotted against lag.
                - Bottom panel: the same for ``self.weights_y``.

                If False, draw two detailed figures: one subplot per U channel,
                then a two-column grid with one subplot per Y channel. Each
                subplot shows the raw weight curves.
            names_u: Optional display names for the U channels, indexed by
                channel number. When None or empty, ``"U Channel <i>"`` is
                used.
        """
        # None replaces the former mutable ``[]`` default; an explicitly
        # passed empty sequence still falls back to the generated names.
        if names_u is None or len(names_u) == 0:
            names_u = ["U Channel " + str(channel_u) for channel_u in range(self.n_channels_u)]

        if not summarize:
            # ---- forcing weights: one row per U channel ----
            if self.n_channels_u > 0:
                # squeeze=False keeps ``axes`` 2-D even for a single channel;
                # otherwise subplots() returns a bare Axes and indexing fails.
                fig, axes = plt.subplots(self.n_channels_u, 1, sharex=True, squeeze=False)

                for channel_u in range(self.n_channels_u):
                    ax = axes[channel_u, 0]
                    ax.set_title(names_u[channel_u] + " Weights")
                    ax.plot(self.weights_u[:, :, channel_u].T)

                axes[-1, 0].set_xlabel("Lags")
                fig.tight_layout()
                plt.show()
                plt.close(fig)

            # ---- recurrence weights: two-column grid, filled column by column ----
            n_y = self.n_channels_y
            if n_y > 0:
                # ceil(n_y / 2) rows: identical to the previous n_y // 2 layout
                # for even counts, but also fits odd counts. Previously those
                # raised IndexError, or built an empty grid when n_y == 1.
                n_rows = (n_y + 1) // 2
                fig, axes = plt.subplots(n_rows, 2, sharex=True, squeeze=False)

                for channel_y in range(n_y):
                    # First n_rows channels go in column 0, the rest in column 1.
                    col, row = divmod(channel_y, n_rows)
                    ax = axes[row, col]
                    ax.set_title("Y Channel " + str(channel_y) + " Weights")
                    ax.plot(self.weights_y[:, :, channel_y].T)

                for ax in axes[-1, :]:
                    ax.set_xlabel("Lags")

                if n_y % 2:
                    # Odd count: the bottom-right cell has no channel.
                    axes[-1, 1].set_visible(False)
                    if n_rows > 1:
                        # With sharex only the bottom row shows tick labels, so
                        # re-enable them on the new last visible axes of column 1.
                        axes[-2, 1].tick_params(labelbottom=True)
                        axes[-2, 1].set_xlabel("Lags")

                fig.tight_layout()
                plt.show()
                plt.close(fig)

        else:
            fig, axes = plt.subplots(2, 1, figsize=(8.15, 3.53))

            # Forcing weights. The squared weights are averaged over axis 0,
            # which the original code comments describe as the output channels.
            for channel_u in range(self.n_channels_u):
                axes[0].fill_between(
                    range(self.lag_u),
                    np.mean(self.weights_u[:, :, channel_u]**2, axis=0),
                    label=names_u[channel_u],
                    alpha=0.25)
            axes[0].set_title("Forcing Weights over Lags")
            axes[0].legend()

            # Recurrence weights, same reduction. The per-channel labels are
            # kept, but no legend is drawn for this panel (as before).
            for channel_y in range(self.n_channels_y):
                axes[1].fill_between(
                    range(self.lag_y),
                    np.mean(self.weights_y[:, :, channel_y]**2, axis=0),
                    label="channel " + str(channel_y),
                    alpha=0.25)
            axes[1].set_title("Recurrence Weights over Lags")
            axes[1].set_xlabel("Lags")

            fig.tight_layout()
            plt.show()
            plt.close(fig)