def analyze_gnmax_conf_data_dep()

in research/pate_2018/ICLR2018/plot_partition.py [0:0]


def analyze_gnmax_conf_data_dep(votes, threshold, sigma1, sigma2, delta):
  """Accumulates the data-dependent RDP cost of a query sequence, by source.

  Queries are processed in order. After each one, the cumulative RDP of the
  two mechanism steps plus the smooth-sensitivity term is converted into an
  epsilon at the best order, and that epsilon is stored split into its
  contributions (step 1, step 2, smooth sensitivity, delta term).

  Args:
    votes: 2-D array with one row per query and one column per class. The sum
      of the first row is used as the number of teachers.
    threshold: Threshold passed to the step-1 helpers. If this or `sigma1` is
      None, step 1 is skipped and every query counts as answered.
    sigma1: Noise parameter passed to the step-1 helpers (see `threshold`).
    sigma2: Noise parameter passed to the step-2 (Gaussian) helpers.
    delta: Target delta passed to `pate.compute_eps_from_delta`.

  Returns:
    A tuple `(eps_partitioned, answered, ss_std_opt, order_opt)` of arrays
    with one entry per query:
      eps_partitioned: object array of `Partition` tuples.
      answered: running sum of the per-query answer probabilities.
      ss_std_opt: `ss * sigma_ss` evaluated at the selected order.
      order_opt: order returned by `pate.compute_eps_from_delta`.

  Side effects:
    Prints one progress line to stdout for every query except the first.
  """
  # Short list of orders.
  # orders = np.round(np.logspace(np.log10(20), np.log10(200), num=20))

  # Long list of orders.
  orders = np.concatenate((np.arange(20, 40, .2),
                           np.arange(40, 75, .5),
                           np.logspace(np.log10(75), np.log10(200), num=20)))

  n = votes.shape[0]
  num_classes = votes.shape[1]
  num_teachers = int(sum(votes[0,]))

  # Step 1 is only active when both of its parameters are supplied.
  use_threshold = threshold is not None and sigma1 is not None

  if use_threshold:
    # NOTE(review): step 1 uses the *_threshold helpers everywhere else but
    # the Gaussian data-independence check here -- confirm this is intended.
    is_data_ind_step1 = pate.is_data_independent_always_opt_gaussian(
        num_teachers, num_classes, sigma1, orders)
  else:
    is_data_ind_step1 = [True] * len(orders)

  is_data_ind_step2 = pate.is_data_independent_always_opt_gaussian(
      num_teachers, num_classes, sigma2, orders)

  # Per-query outputs. Partition entries are arbitrary Python objects, hence
  # the explicit object dtype; the float arrays start out as NaN.
  eps_partitioned = np.full(n, None, dtype=object)
  order_opt = np.full(n, np.nan)
  ss_std_opt = np.full(n, np.nan)
  answered = np.zeros(n)

  # Running totals, indexed by position in `orders`.
  rdp_step1_total = np.zeros(len(orders))
  rdp_step2_total = np.zeros(len(orders))

  ls_total = np.zeros((len(orders), num_teachers))
  answered_total = 0

  # Shared, never-mutated placeholder for orders whose bound is data
  # independent (allocated once instead of once per query and order).
  no_sensitivity = np.zeros(num_teachers)

  for i in range(n):
    v = votes[i,]

    if use_threshold:
      logq_step1 = pate.compute_logpr_answered(threshold, sigma1, v)
      rdp_step1_total += pate.compute_rdp_threshold(logq_step1, sigma1, orders)
    else:
      logq_step1 = 0.  # always answer

    pr_answered = np.exp(logq_step1)
    logq_step2 = pate.compute_logq_gaussian(v, sigma2)
    # Step 2 only contributes in proportion to the chance the query is
    # answered.
    rdp_step2_total += pr_answered * pate.rdp_gaussian(logq_step2, sigma2,
                                                       orders)

    answered_total += pr_answered

    rdp_ss = np.zeros(len(orders))
    ss_std = np.zeros(len(orders))

    for j, order in enumerate(orders):
      if not is_data_ind_step1[j]:
        ls_step1 = pate_ss.compute_local_sensitivity_bounds_threshold(
            v, num_teachers, threshold, sigma1, order)
      else:
        ls_step1 = no_sensitivity

      if not is_data_ind_step2[j]:
        ls_step2 = pate_ss.compute_local_sensitivity_bounds_gnmax(
            v, num_teachers, sigma2, order)
      else:
        ls_step2 = no_sensitivity

      ls_total[j,] += ls_step1 + pr_answered * ls_step2

      beta_ss = .49 / order

      ss = pate_ss.compute_discounted_max(beta_ss, ls_total[j,])
      # Float exponent so the cube root survives Python 2 integer division.
      # NOTE(review): this divides by `ss`; confirm compute_discounted_max
      # cannot return 0 when every accumulated sensitivity is zero.
      sigma_ss = ((order * math.exp(2 * beta_ss)) / ss) ** (1. / 3)
      rdp_ss[j] = pate_ss.compute_rdp_of_smooth_sensitivity_gaussian(
          beta_ss, sigma_ss, order)
      ss_std[j] = ss * sigma_ss

    rdp_total = rdp_step1_total + rdp_step2_total + rdp_ss

    answered[i] = answered_total
    _, order_opt[i] = pate.compute_eps_from_delta(orders, rdp_total, delta)
    order_idx = np.searchsorted(orders, order_opt[i])

    # Since optimal orders are always non-increasing, shrink orders array
    # and all cumulative arrays to speed up computation. `ls_total` and the
    # data-independence flags keep their full length; only their leading
    # entries are read from here on.
    if order_idx < len(orders):
      orders = orders[:order_idx + 1]
      rdp_step1_total = rdp_step1_total[:order_idx + 1]
      rdp_step2_total = rdp_step2_total[:order_idx + 1]

    partition = Partition(step1=rdp_step1_total[order_idx],
                          step2=rdp_step2_total[order_idx],
                          ss=rdp_ss[order_idx],
                          delta=-math.log(delta) / (order_opt[i] - 1))
    eps_partitioned[i] = partition
    ss_std_opt[i] = ss_std[order_idx]

    # Progress is reported for every query except the first.
    if i > 0:
      print('queries = {}, E[answered] = {:.2f}, E[eps] = {:.3f} +/- {:.3f} '
            'at order = {:.2f}. Contributions: delta = {:.3f}, step1 = {:.3f}, '
            'step2 = {:.3f}, ss = {:.3f}'.format(
                i + 1,
                answered[i],
                sum(partition),
                ss_std_opt[i],
                order_opt[i],
                partition.delta,
                partition.step1,
                partition.step2,
                partition.ss))
      sys.stdout.flush()

  return eps_partitioned, answered, ss_std_opt, order_opt