def plot_loss()

in crop_yield_prediction/plot/plot_loss.py [0:0]


def _parse_prediction_log(prediction_log):
    """Parse a prediction log into per-(year, experiment) series.

    Returns a mapping ``series_name -> year -> experiment -> list of values``.
    Series names are 'train_epochs', the 'train_*' / 'valid_*' metrics listed
    below, and 'test_epochs', 'test_rmse', 'test_r2', 'test_corr'.

    The current year is taken from the most recent line starting with
    'Predict' (first four characters of its third token), and the current
    experiment from the most recent line starting with 'Experiment'; both are
    0 until such a line is seen. Metric values are read from fixed
    whitespace-token positions with their last character stripped
    (presumably a trailing separator such as ',' -- confirm against the
    logger that writes these files).
    """
    # (series name, token index) for values on a 'Training' line.
    train_fields = (
        ('train_loss', 4),
        ('train_super_loss', 7),
        ('train_unsuper_loss', 10),
    )
    # (series name, token index) for values on a 'Validation' line. The stride
    # changes from 3 to 2 after index 18, mirroring the log's own layout.
    valid_fields = (
        ('valid_loss', 3),
        ('valid_super_loss', 6),
        ('valid_unsuper_loss', 9),
        ('valid_l_n_loss', 12),
        ('valid_l_d_loss', 15),
        ('valid_l_nd_loss', 18),
        ('valid_sn_loss', 20),
        ('valid_tn_loss', 22),
        ('valid_norm_loss', 24),
        ('valid_rmse', 26),
        ('valid_r2', 28),
        ('valid_corr', 30),
    )

    series = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    year = 0
    exp = 0
    with open(prediction_log) as f:
        for line in f:
            line = line.strip()
            # Independent `if`s (not elif): a line may feed several series.
            if line.startswith('Predict'):
                year = int(line.split()[2][:4])
            if line.startswith('Experiment'):
                exp = int(line.split()[1])
            if 'Epoch' in line:
                series['train_epochs'][year][exp].append(int(line.split()[2]))
            if 'Training' in line:
                ws = line.split()
                for name, idx in train_fields:
                    series[name][year][exp].append(float(ws[idx][:-1]))
            if 'Validation' in line:
                ws = line.split()
                for name, idx in valid_fields:
                    series[name][year][exp].append(float(ws[idx][:-1]))
            if '(Test)' in line and 'epoch' in line:
                ws = line.split()
                series['test_epochs'][year][exp].append(int(ws[3][:-1]))
                series['test_rmse'][year][exp].append(float(ws[5][:-1]))
                series['test_r2'][year][exp].append(float(ws[7][:-1]))
                # Last token on the line: no trailing character to strip.
                series['test_corr'][year][exp].append(float(ws[9]))
    return series


def _save_line_plot(curves, title, path):
    """Draw each ``(x, y, label)`` curve on one figure and save it to *path*.

    The figure is closed even if drawing or saving raises, so a failure
    cannot leave curves on pyplot's current figure for the next plot.
    """
    try:
        for x, y, label in curves:
            plt.plot(x, y, label=label)
        plt.title(title, fontsize=8)
        plt.grid(True)
        plt.legend()
        plt.savefig(path, dpi=300)
    finally:
        plt.close()


def plot_loss(params, results_dir='../../results/spatial_temporal'):
    """Plot loss and metric curves parsed from one prediction log.

    Reads ``<results_dir>/prediction_logs/<params>``. For every year and
    experiment found in it, saves seven JPEG figures (total, supervised and
    unsupervised loss; the individual validation loss terms; RMSE; R2;
    correlation) into ``<results_dir>/plots/<params[:-4]>/``. The figures are
    named ``<year>_<experiment>_<suffix>.jpg``.

    Args:
        params: File name of the prediction log; also used as every figure's
            title. Its last four characters are dropped to name the output
            directory (presumably a 4-character extension such as '.txt' --
            confirm).
        results_dir: Root results directory. The default is the original
            hard-coded relative path, resolved against the current working
            directory.
    """
    out_dir = '{}/plots/{}'.format(results_dir, params[:-4])
    os.makedirs(out_dir, exist_ok=True)
    series = _parse_prediction_log('{}/prediction_logs/{}'.format(results_dir, params))

    # Each figure is (file-name suffix, curves); each curve is
    # (x-axis series, y-axis series, legend label).
    figures = (
        ('total_loss', (
            ('train_epochs', 'train_loss', 'Training'),
            ('train_epochs', 'valid_loss', 'Validation'),
        )),
        ('supervised_loss', (
            ('train_epochs', 'train_super_loss', 'Training'),
            ('train_epochs', 'valid_super_loss', 'Validation'),
        )),
        ('unsupervised_loss', (
            ('train_epochs', 'train_unsuper_loss', 'Training'),
            ('train_epochs', 'valid_unsuper_loss', 'Validation'),
        )),
        ('validation_various_losses', (
            ('train_epochs', 'valid_l_n_loss', 'l_n_loss'),
            ('train_epochs', 'valid_l_d_loss', 'l_d_loss'),
            ('train_epochs', 'valid_l_nd_loss', 'l_nd_loss'),
            ('train_epochs', 'valid_sn_loss', 'spatial_neighbor_loss'),
            ('train_epochs', 'valid_tn_loss', 'temporal_neighbor_loss'),
            ('train_epochs', 'valid_norm_loss', 'l2_norm_loss'),
        )),
        ('rmse', (
            ('train_epochs', 'valid_rmse', 'Validation'),
            ('test_epochs', 'test_rmse', 'Test'),
        )),
        ('r2', (
            ('train_epochs', 'valid_r2', 'Validation'),
            ('test_epochs', 'test_r2', 'Test'),
        )),
        ('corr', (
            ('train_epochs', 'valid_corr', 'Validation'),
            ('test_epochs', 'test_corr', 'Test'),
        )),
    )

    train_epochs = series['train_epochs']
    for year in train_epochs:
        # Iterate the experiment ids actually parsed instead of assuming they
        # are exactly 0..n-1. That assumption plotted an empty placeholder
        # experiment and dropped the last real one whenever ids started at 1.
        for exp in sorted(train_epochs[year]):
            for suffix, curves in figures:
                _save_line_plot(
                    [(series[x][year][exp], series[y][year][exp], label)
                     for x, y, label in curves],
                    params,
                    '{}/{}_{}_{}.jpg'.format(out_dir, year, exp, suffix),
                )