def plot_loss()

in ml3/shaped_sine_utils.py [0:0]


def plot_loss(extra, exp_folder, freq=0.5):
    """Compute the task loss and the learned meta-loss over a sweep of thetas.

    Despite its name, this function does not draw anything; it returns the
    data needed to plot the two loss landscapes against each other.

    For each theta in ``np.arange(-7.0, 7.0, 0.1)`` a ``ShapedSineModel(theta)``
    is evaluated on ``x = np.arange(-5.0, 5.0, 0.1)`` (as a column), and its
    predictions are scored against the target ``y = sin(freq * x)`` in two ways:
    the plain loss ``0.5 * (pred - y) ** 2`` and the loaded meta-network's
    output on ``cat([x, pred, y], dim=1)``. Both are averaged to a scalar.

    Args:
        extra: Checkpoint suffix; meta-network weights are read from
            ``{exp_folder}/ml3_loss_shaped_sine_{extra}.pt``.
        exp_folder: Directory containing the saved meta-network state dict.
        freq: Frequency of the target curve ``y = sin(freq * x)``.

    Returns:
        Tuple ``(theta_ranges, meta_loss_landscape, loss_landscape)``: the
        swept theta values, the mean meta-loss per theta, and the mean
        squared-error task loss per theta (both as NumPy arrays).
    """
    meta = MetaNetwork()
    # Results are converted with .numpy(), which is CPU-only, so load the
    # checkpoint onto CPU even if it was saved from a GPU run.
    checkpoint_path = f'{exp_folder}/ml3_loss_shaped_sine_{extra}.pt'
    meta.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    meta.eval()

    # NOTE(review): theta is presumably the model's frequency parameter
    # (the original code referenced ``pi.freq``); confirm in ShapedSineModel.
    theta_ranges = np.arange(-7.0, 7.0, 0.1)
    test_x = np.expand_dims(np.arange(-5.0, 5.0, 0.1), 1)  # shape (N, 1)
    test_y = np.sin(freq * test_x)
    x = torch.Tensor(test_x)
    y = torch.Tensor(test_y)

    loss_landscape = []
    meta_loss_landscape = []
    # Pure evaluation: no autograd graph is needed for either landscape.
    with torch.no_grad():
        for theta in theta_ranges:
            # One forward pass per theta feeds both loss computations.
            pi = ShapedSineModel(theta)
            pi_out = pi(x)

            task_loss = 0.5 * (pi_out - y) ** 2
            loss_landscape.append(task_loss.mean().detach().numpy())

            # The meta-network scores each (input, prediction, target) row.
            meta_input = torch.cat([x, pi_out, y], 1)
            meta_loss = meta(meta_input).mean()
            meta_loss_landscape.append(meta_loss.detach().numpy())

    return theta_ranges, np.array(meta_loss_landscape), np.array(loss_landscape)