def main()

in research/gam/gam/experiments/run_train_gam.py
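
Entry point for training Graph Agreement Models by co-training: it seeds the random number generators, loads the data (freshly preprocessed or restored from a pickle cache), assembles a descriptive model name from the flags, prepares the output directories, instantiates the classification and agreement models, and hands everything to TrainerCotraining. The body below assumes the file's module-level imports and flag definitions (FLAGS, app, config, Dataset, load_data, get_model_cls, get_model_agr, TrainerCotraining).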


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  if FLAGS.logging_config:
    print('Setting logging configuration: ', FLAGS.logging_config)
    config.fileConfig(FLAGS.logging_config)

  # Set random seed.
  np.random.seed(FLAGS.seed)
  tf.set_random_seed(FLAGS.seed)

  ############################################################################
  #                               DATA                                       #
  ############################################################################
  # Create the folder for saving the preprocessed data, if needed.
  if not os.path.exists(FLAGS.data_output_dir):
    os.makedirs(FLAGS.data_output_dir)

  # Load and potentially preprocess data.
  if FLAGS.load_preprocessed:
    logging.info('Loading preprocessed data...')
    path = os.path.join(FLAGS.data_output_dir, FLAGS.filename_preprocessed_data)
    data = Dataset.load_from_pickle(path)
  else:
    data = load_data()
    if FLAGS.save_preprocessed:
      assert FLAGS.data_output_dir
      path = os.path.join(FLAGS.data_output_dir,
                          FLAGS.filename_preprocessed_data)
      data.save_to_pickle(path)
      logging.info('Preprocessed data saved to %s.', path)

  ############################################################################
  #                            PREPARE OUTPUTS                               #
  ############################################################################
  # Assemble a descriptive model name from the hyperparameter flags.
  model_name = FLAGS.model_cls
  model_name += ('_' + FLAGS.hidden_cls) if FLAGS.model_cls == 'mlp' else ''
  model_name += '-' + FLAGS.model_agr
  model_name += ('_' + FLAGS.hidden_agr) if FLAGS.model_agr == 'mlp' else ''
  model_name += '-aggr_' + FLAGS.aggregation_agr_inputs
  model_name += ('_' + FLAGS.hidden_aggreg) if FLAGS.hidden_aggreg else ''
  model_name += (
      '-add_%d-conf_%.2f-iterCls_%d-iterAgr_%d-batchCls_%d' %
      (FLAGS.num_samples_to_label, FLAGS.min_confidence_new_label,
       FLAGS.max_num_iter_cls, FLAGS.max_num_iter_agr, FLAGS.batch_size_cls))
  model_name += (('-wdecayCls_%.4f' %
                  FLAGS.weight_decay_cls) if FLAGS.weight_decay_cls else '')
  model_name += (('-wdecayAgr_%.4f' %
                  FLAGS.weight_decay_agr) if FLAGS.weight_decay_agr else '')
  model_name += '-LL_%s_LU_%s_UU_%s' % (str(
      FLAGS.reg_weight_ll), str(FLAGS.reg_weight_lu), str(FLAGS.reg_weight_uu))
  model_name += '-perfAgr' if FLAGS.use_perfect_agreement else ''
  model_name += '-perfCls' if FLAGS.use_perfect_classifier else ''
  model_name += '-keepProp' if FLAGS.keep_label_proportions else ''
  model_name += '-PenNegAgr' if FLAGS.penalize_neg_agr else ''
  model_name += '-transd' if not FLAGS.inductive else ''
  model_name += '-L2' if FLAGS.use_l2_cls else '-CE'
  model_name += '-seed_' + str(FLAGS.seed)
  model_name += FLAGS.experiment_suffix
  logging.info('Model name: %s', model_name)
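  # For illustration (hypothetical flag values, not defaults): with
  # model_cls='mlp', hidden_cls='128_64', model_agr='mlp', hidden_agr='128_64',
  # aggregation_agr_inputs='dist', no hidden_aggreg, num_samples_to_label=100,
  # min_confidence_new_label=0.9, max_num_iter_cls=max_num_iter_agr=10000,
  # batch_size_cls=128, no weight decay, reg_weight_ll=lu=uu=0.05, all boolean
  # options off, inductive=True, use_l2_cls=True, seed=123, and no suffix,
  # the assembled name is:
  # mlp_128_64-mlp_128_64-aggr_dist-add_100-conf_0.90-iterCls_10000-iterAgr_10000-batchCls_128-LL_0.05_LU_0.05_UU_0.05-L2-seed_123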

  # Create directories for model checkpoints, summaries, and
  # self-labeled data backup.
  summary_dir = os.path.join(FLAGS.output_dir, 'summaries', FLAGS.dataset_name,
                             model_name)
  checkpoints_dir = os.path.join(FLAGS.output_dir, 'checkpoints',
                                 FLAGS.dataset_name, model_name)
  data_dir = os.path.join(FLAGS.data_output_dir, 'data_checkpoints',
                          FLAGS.dataset_name, model_name)
  if not os.path.exists(checkpoints_dir):
    os.makedirs(checkpoints_dir)
  if not os.path.exists(data_dir):
    os.makedirs(data_dir)
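  # Illustrative layout (hypothetical values output_dir=/tmp/gam_outputs,
  # data_output_dir=/tmp/gam_data, dataset_name=cora):
  #   /tmp/gam_outputs/summaries/cora/<model_name>
  #   /tmp/gam_outputs/checkpoints/cora/<model_name>
  #   /tmp/gam_data/data_checkpoints/cora/<model_name>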

  ############################################################################
  #                            MODEL SETUP                                   #
  ############################################################################
  # Create the classification model based on the provided FLAGS.
  model_cls = get_model_cls(
      model_name=FLAGS.model_cls,
      data=data,
      dataset_name=FLAGS.dataset_name,
      hidden=FLAGS.hidden_cls)

  # Create agreement model.
  model_agr = get_model_agr(
      model_name=FLAGS.model_agr,
      dataset_name=FLAGS.dataset_name,
      hidden_aggreg=FLAGS.hidden_aggreg,
      aggregation_agr_inputs=FLAGS.aggregation_agr_inputs,
      hidden=FLAGS.hidden_agr)

  # Create the co-training trainer, which alternates between training the
  # classification and agreement models and self-labels new samples in each
  # co-training iteration.
  trainer = TrainerCotraining(
      model_cls=model_cls,
      model_agr=model_agr,
      max_num_iter_cotrain=FLAGS.max_num_iter_cotrain,
      min_num_iter_cls=FLAGS.min_num_iter_cls,
      max_num_iter_cls=FLAGS.max_num_iter_cls,
      num_iter_after_best_val_cls=FLAGS.num_iter_after_best_val_cls,
      min_num_iter_agr=FLAGS.min_num_iter_agr,
      max_num_iter_agr=FLAGS.max_num_iter_agr,
      num_iter_after_best_val_agr=FLAGS.num_iter_after_best_val_agr,
      num_samples_to_label=FLAGS.num_samples_to_label,
      min_confidence_new_label=FLAGS.min_confidence_new_label,
      keep_label_proportions=FLAGS.keep_label_proportions,
      num_warm_up_iter_agr=FLAGS.num_warm_up_iter_agr,
      optimizer=tf.train.AdamOptimizer,
      gradient_clip=FLAGS.gradient_clip,
      batch_size_agr=FLAGS.batch_size_agr,
      batch_size_cls=FLAGS.batch_size_cls,
      learning_rate_cls=FLAGS.learning_rate_cls,
      learning_rate_agr=FLAGS.learning_rate_agr,
      enable_summaries=True,
      enable_summaries_per_model=True,
      summary_dir=summary_dir,
      summary_step_cls=FLAGS.summary_step_cls,
      summary_step_agr=FLAGS.summary_step_agr,
      logging_step_cls=FLAGS.logging_step_cls,
      logging_step_agr=FLAGS.logging_step_agr,
      eval_step_cls=FLAGS.eval_step_cls,
      eval_step_agr=FLAGS.eval_step_agr,
      checkpoints_dir=checkpoints_dir,
      checkpoints_step=1,
      data_dir=data_dir,
      abs_loss_chg_tol=1e-10,
      rel_loss_chg_tol=1e-7,
      loss_chg_iter_below_tol=30,
      use_perfect_agr=FLAGS.use_perfect_agreement,
      use_perfect_cls=FLAGS.use_perfect_classifier,
      warm_start_cls=FLAGS.warm_start_cls,
      warm_start_agr=FLAGS.warm_start_agr,
      ratio_valid_agr=FLAGS.ratio_valid_agr,
      max_samples_valid_agr=FLAGS.max_samples_valid_agr,
      weight_decay_cls=FLAGS.weight_decay_cls,
      weight_decay_schedule_cls=FLAGS.weight_decay_schedule_cls,
      weight_decay_schedule_agr=FLAGS.weight_decay_schedule_agr,
      weight_decay_agr=FLAGS.weight_decay_agr,
      reg_weight_ll=FLAGS.reg_weight_ll,
      reg_weight_lu=FLAGS.reg_weight_lu,
      reg_weight_uu=FLAGS.reg_weight_uu,
      reg_weight_vat=FLAGS.reg_weight_vat,
      use_ent_min=FLAGS.use_ent_min,
      num_pairs_reg=FLAGS.num_pairs_reg,
      penalize_neg_agr=FLAGS.penalize_neg_agr,
      use_l2_cls=FLAGS.use_l2_cls,
      first_iter_original=FLAGS.first_iter_original,
      inductive=FLAGS.inductive,
      seed=FLAGS.seed,
      eval_acc_pred_by_agr=FLAGS.eval_acc_pred_by_agr,
      num_neighbors_pred_by_agr=FLAGS.num_neighbors_pred_by_agr,
      lr_decay_rate_cls=FLAGS.lr_decay_rate_cls,
      lr_decay_steps_cls=FLAGS.lr_decay_steps_cls,
      lr_decay_rate_agr=FLAGS.lr_decay_rate_agr,
      lr_decay_steps_agr=FLAGS.lr_decay_steps_agr,
      load_from_checkpoint=FLAGS.load_from_checkpoint)

  ############################################################################
  #                            TRAIN                                         #
  ############################################################################
  trainer.train(data)
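
A typical invocation might look like the following (illustrative only; the dataset name and paths are hypothetical, and the full set of required flags is defined at the top of the file):

python gam/experiments/run_train_gam.py \
  --dataset_name=cora \
  --data_output_dir=/tmp/gam_data \
  --output_dir=/tmp/gam_outputs \
  --seed=123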