# def main()
#
# in research/pate_2018/ICLR2018/rdp_bucketized.py [0:0]


def main(argv):
  """Plots per-bin answered-query counts and saves the figure.

  Loads the votes array from --counts_file, keeps its first 4000 rows, and
  draws overlaid bar charts of per-bin counts over the range 0-100 (labelled
  as the percentage of teachers that agree). Depending on --plot:

    * 'small': 5 bins; LNMax counts plus one Confident-GNMax series.
    * 'large': 10 bins; LNMax counts, two Confident-GNMax series (moderate
      and aggressive) and, on a secondary log-scale axis, the per-query
      privacy cost of each bin.

  The figure is written to --plot_file and then shown interactively.

  Args:
    argv: Unused.

  Raises:
    ValueError: If --plot is neither 'small' nor 'large'.
  """
  del argv  # Unused.
  fin_name = os.path.expanduser(FLAGS.counts_file)
  print('Reading raw votes from ' + fin_name)
  sys.stdout.flush()

  votes = np.load(fin_name)
  votes = votes[:4000,]  # truncate to 4000 samples

  # NOTE(review): the two trailing numeric arguments of the helpers below are
  # presumably a threshold and a noise scale -- confirm against the helpers.
  if FLAGS.plot == 'small':
    bin_num = 5
    m_check = compute_expected_answered_per_bin(bin_num, votes, 3500, 1500)
  elif FLAGS.plot == 'large':
    bin_num = 10
    m_check = compute_expected_answered_per_bin(bin_num, votes, 3500, 1500)
    a_check = compute_expected_answered_per_bin(bin_num, votes, 5000, 1500)
    eps = compute_privacy_cost_per_bins(bin_num, votes, 100, 50)
  else:
    raise ValueError('--plot flag must be one of ["small", "large"]')

  counts = compute_count_per_bin(bin_num, votes)
  # Left edges of bin_num equal-width bins covering [0, 100).
  bins = np.linspace(0, 100, num=bin_num, endpoint=False)
  # Derived from bin_num so bars always tile the x-range exactly
  # (20 for 'small', 10 for 'large').
  bin_width = 100. / bin_num

  plt.close('all')
  fig, ax = plt.subplots()
  if FLAGS.plot == 'small':
    fig.set_figheight(5)
    fig.set_figwidth(5)
    ax.bar(
        bins,
        counts,
        bin_width,
        color='orangered',
        linestyle='dotted',
        linewidth=5,
        edgecolor='red',
        fill=False,
        alpha=.5,
        align='edge',
        label='LNMax answers')
    ax.bar(
        bins,
        m_check,
        bin_width,
        color='g',
        alpha=.5,
        linewidth=0,
        edgecolor='g',
        align='edge',
        label='Confident-GNMax\nanswers')
  elif FLAGS.plot == 'large':
    fig.set_figheight(4.7)
    fig.set_figwidth(7)
    ax.bar(
        bins,
        counts,
        bin_width,
        linestyle='dashed',
        linewidth=5,
        edgecolor='red',
        fill=False,
        alpha=.5,
        align='edge',
        label='LNMax answers')
    ax.bar(
        bins,
        m_check,
        bin_width,
        color='g',
        alpha=.5,
        linewidth=0,
        edgecolor='g',
        align='edge',
        label='Confident-GNMax\nanswers (moderate)')
    ax.bar(
        bins,
        a_check,
        bin_width,
        color='b',
        alpha=.5,
        align='edge',
        label='Confident-GNMax\nanswers (aggressive)')
    # Secondary y-axis (shared x) for the per-bin privacy cost, drawn as one
    # marker at the centre of each bin.
    ax2 = ax.twinx()
    bin_centers = bins + bin_width / 2
    ax2.plot(bin_centers, eps, 'ko', alpha=.8)
    ax2.set_ylim([1e-200, 1.])
    ax2.set_yscale('log')
    ax2.grid(False)
    ax2.set_yticks([1e-3, 1e-50, 1e-100, 1e-150, 1e-200])
    # tick_params expects a bool; the legacy string 'off' is not interpreted
    # as False by current matplotlib. Target ax2 explicitly rather than
    # relying on it being the current axes after twinx().
    ax2.tick_params(which='minor', right=False)
    ax2.set_ylabel(r'Per query privacy cost $\varepsilon$', fontsize=16)

  # The twin axis (if any) shares x with ax, so setting the limits on ax is
  # sufficient for both.
  ax.set_xlim([0, 100])
  ax.set_ylim([0, 2500])
  # ax.set_yscale('log')
  ax.set_xlabel('Percentage of teachers that agree', fontsize=16)
  ax.set_ylabel('Number of queries answered', fontsize=16)
  vals = ax.get_xticks()
  # Pin the tick positions before assigning labels; otherwise the fixed
  # labels can drift away from their ticks when the locator recomputes.
  ax.set_xticks(vals)
  ax.set_xticklabels([str(int(x)) + '%' for x in vals])
  # set_xticks may widen the view to include every tick; restore the range.
  ax.set_xlim([0, 100])
  ax.tick_params(labelsize=14, bottom=True, top=True, left=True, right=True)
  ax.legend(loc=2, prop={'size': 16})

  # simple: 'figures/noisy_thresholding_check_perf.pdf')
  # detailed: 'figures/noisy_thresholding_check_perf_details.pdf'

  print('Saving the graph to ' + FLAGS.plot_file)
  plt.savefig(os.path.expanduser(FLAGS.plot_file), bbox_inches='tight')
  plt.show()