def finetuneimdb_dif_supervisions()

in panrep/plot_data.py [0:0]


def finetuneimdb_dif_supervisions():
    """Plot the IMDB fine-tuning results for different supervision settings.

    For every results file listed in ``files`` the pickled dictionary is
    loaded from ``results/universal_task/`` and handed to ``plot_results``,
    which shows nine figures per (n_b, rw) combination, each one plotting a
    different logged metric against the ``split_pct`` hyperparameter.

    Relies on the module-level names ``plt``, ``cycler``, ``fontP`` and
    ``pickle`` being available in the enclosing file.
    """
    from itertools import count, product

    def split_acc(acc):
        """Parse one logged result string into a tuple of ten numbers.

        The string is split on single spaces and the values are read from
        fixed token positions.  When token 56 equals ``'LPFT'`` the log has no
        separate LP entry: ``lp_mrr`` is reported as 0 and the remaining
        fields sit 7 tokens earlier than in the other layout.

        Returns:
            ``(panrep_acc, prft_tes_acc, finpanrep_acc, mrr, lp_mrr,
            lpft_mrr, entropy, mlp_acc_pr, mlp_acc_prft, finlogpanrep_acc)``

        Raises:
            IndexError / ValueError: if the string does not follow the
                expected layout.
        """
        tokens = acc.split(" ")

        def leading_float(index, separator):
            # Some tokens carry extra text after the number; keep only the
            # part before the separator ('~' or a newline).
            return float(tokens[index].split(separator)[0])

        panrep_acc = leading_float(4, "~")
        prft_tes_acc = float(tokens[24])
        finpanrep_acc = leading_float(30, "~")
        mrr = leading_float(51, '\n')

        lpft_only = tokens[56] == 'LPFT'
        # Without the 'LPFT' marker an extra LP entry precedes the LP-FT one,
        # pushing every later field 7 tokens further along.
        shift = 0 if lpft_only else 7
        lp_mrr = 0.0 if lpft_only else leading_float(58, '\n')
        lpft_mrr = leading_float(58 + shift, '\n')
        entropy = float(tokens[83 + shift])
        mlp_acc_pr = float(tokens[88 + shift])
        mlp_acc_prft = float(tokens[93 + shift])
        finlogpanrep_acc = leading_float(65 + shift, "~")

        return (panrep_acc, prft_tes_acc, finpanrep_acc, mrr, lp_mrr,
                lpft_mrr, entropy, mlp_acc_pr, mlp_acc_prft, finlogpanrep_acc)

    def plot_results(results, paramlist):
        """Show one figure per metric for every (n_b, rw) combination.

        Args:
            results: dict mapping hyperparameter tuples to logged result
                strings (parsed with ``split_acc``).  Each tuple is laid out
                as (n_epochs, n_fine_tune_epochs, n_layers, n_hidden,
                n_bases, fanout, lr, dropout, use_link_prediction,
                use_reconstruction_loss, use_infomax_loss, mask_links,
                use_self_loop, use_node_motif, num_cluster, single_layer,
                motif_cluster, k_fold, rw, ng_rate, only_ssl,
                test_edge_split).  The last key of the dict is ignored, as in
                the original code, which bound it to a variable named
                ``experiment`` -- presumably it is not a regular
                configuration entry (TODO confirm).
            paramlist: display label for each position of the key tuples.

        Raises:
            ValueError: if a configuration key does not have one entry per
                label in ``paramlist``.
        """
        config_keys = list(results.keys())[:-1]
        if not config_keys:
            # Nothing to plot (previously this crashed with an
            # IndexError / UnboundLocalError).
            return

        # Positions inside the hyperparameter tuple.
        lp_pos, rec_pos, infomax_pos, motif_pos = 8, 9, 10, 13
        rw_pos, bases_pos = 18, 4
        plot_over = 21  # x axis: test_edge_split

        # (index into the split_acc tuple, y label, title prefix), in the
        # order the figures are shown.
        figures = (
            (0, 'Macro-F1', "PR "),
            (9, 'Macro-F1', "PR-FT Log"),
            (2, 'Macro-F1', "PR-FT "),
            (1, 'Test Acc', "MLP "),
            (3, 'MRR', "PanRep "),
            (4, 'MRR', "PanRep-LP module"),
            (5, 'MRR', "PanRep LP-FT "),
            (7, 'Acc', "MLP acc of pr "),
            (8, 'Acc', "MLP acc of pr-ft"),
        )

        line_cycle = (cycler('color', list('kcbgyrgbrgykcmygbcg')) +
                      cycler('linestyle',
                             ['--', ':', '-.', '-', ':', '-', '--', '-', ':',
                              ':', '-.', '-', '--', ':', '-.', '-', '--',
                              ':', '-.']))

        # Collect the distinct values seen at every position and parse the
        # non-empty result strings.
        n_params = len(paramlist)
        value_sets = [set() for _ in range(n_params)]
        plots = {}
        for key in config_keys:
            if len(key) != n_params:
                raise ValueError(
                    "expected %d hyperparameters per key, got %d"
                    % (n_params, len(key)))
            for position, value in enumerate(key):
                value_sets[position].add(value)
            if len(results[key]) > 0:
                plots[key] = split_acc(results[key])

        # Hyperparameters that are not varied below keep the values of the
        # last configuration key.
        template = list(config_keys[-1])
        x_values = sorted(value_sets[plot_over])
        figure_numbers = count(1)

        for n_bases, rw in product(value_sets[bases_pos], value_sets[rw_pos]):
            experiment = paramlist[bases_pos] + ' : ' + str(n_bases) + \
                         paramlist[rw_pos] + ' : ' + str(rw)

            # Gather every curve once; all nine figures share the same
            # curves and differ only in the metric that is drawn.
            curves = []
            for lp, rec, infomax, motif in product(value_sets[lp_pos],
                                                   value_sets[rec_pos],
                                                   value_sets[infomax_pos],
                                                   value_sets[motif_pos]):
                label = (paramlist[infomax_pos] + ' : ' + str(float(infomax))
                         + paramlist[lp_pos] + ' : ' + str(int(lp))
                         + paramlist[rec_pos] + ' : ' + str(float(rec))
                         + paramlist[motif_pos] + ' : ' + str(float(motif)))

                lookup = list(template)
                lookup[bases_pos] = n_bases
                lookup[rw_pos] = rw
                lookup[lp_pos] = lp
                lookup[rec_pos] = rec
                lookup[infomax_pos] = infomax
                lookup[motif_pos] = motif

                rows = []
                for x in x_values:
                    lookup[plot_over] = x
                    row = plots.get(tuple(lookup))
                    if row is None:
                        # Incomplete curve: skip only this combination.  The
                        # original `break` also dropped every remaining motif
                        # value, so the outcome depended on set order.
                        break
                    rows.append(row)
                else:
                    curves.append((label, rows))

            for metric, ylabel, title in figures:
                plt.figure(num=next(figure_numbers), figsize=(8, 6))
                plt.rc('axes', prop_cycle=line_cycle)
                for _, rows in curves:
                    plt.plot(x_values, [row[metric] for row in rows])
                plt.legend([label for label, _ in curves], loc='center left',
                           bbox_to_anchor=(-0.1, 1.2), ncol=3, prop=fontP)
                plt.xlabel(paramlist[plot_over])
                plt.ylabel(ylabel)
                plt.title(title + experiment)
                plt.show()

    paramlist = "n_epochs,n_ft_ep, L, n_h, n_b, fanout, lr, dr, LP, R, I, mask_links," \
                " self_loop , M,num_cluster,single_layer, n_mt_cls,k_shot, rw, ng_rate, ssl,split_pct "
    paramlist = paramlist.split(',')

    # Earlier run: "imdb_preprocessed-2020-05-07-09:30:14.936903.pickle"
    files = ["imdb_preprocessed-2020-05-10-07:42:54.021860.pickle"]
    for f in files:
        # NOTE: unpickling can execute arbitrary code; only load result files
        # produced by trusted runs.
        with open("results/universal_task/" + f, 'rb') as handle:
            results = pickle.load(handle)
        plot_results(results, paramlist)