in research/gam/gam/experiments/run_train_gam.py [0:0]
def main(argv):
  """Trains a GAM co-training model (classifier + agreement network).

  Loads (or preprocesses and caches) the dataset, derives a model name that
  encodes the hyperparameter configuration from FLAGS, prepares the output
  directories, instantiates the classification and agreement models, and runs
  co-training via TrainerCotraining.

  Args:
    argv: Command-line arguments as passed by absl.app.run; only the program
      name is expected.

  Raises:
    app.UsageError: If extra positional command-line arguments are supplied.
    ValueError: If --save_preprocessed is set without --data_output_dir.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  if FLAGS.logging_config:
    print('Setting logging configuration: ', FLAGS.logging_config)
    config.fileConfig(FLAGS.logging_config)

  # Set random seed for reproducibility across numpy and TensorFlow.
  np.random.seed(FLAGS.seed)
  tf.set_random_seed(FLAGS.seed)

  ############################################################################
  #                                 DATA                                     #
  ############################################################################
  # Create the folder where the preprocessed data is saved, if necessary.
  # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
  os.makedirs(FLAGS.data_output_dir, exist_ok=True)

  # Load and potentially preprocess data.
  if FLAGS.load_preprocessed:
    logging.info('Loading preprocessed data...')
    path = os.path.join(FLAGS.data_output_dir, FLAGS.filename_preprocessed_data)
    # NOTE(review): pickle deserialization — safe only because this loads a
    # locally written cache, never untrusted input.
    data = Dataset.load_from_pickle(path)
  else:
    data = load_data()
    if FLAGS.save_preprocessed:
      # The save path below is built from data_output_dir, so that is the
      # flag that must be set (the original code checked output_dir here).
      if not FLAGS.data_output_dir:
        raise ValueError(
            '--data_output_dir must be set when --save_preprocessed is used.')
      path = os.path.join(FLAGS.data_output_dir,
                          FLAGS.filename_preprocessed_data)
      data.save_to_pickle(path)
      logging.info('Preprocessed data saved to %s.', path)

  ############################################################################
  #                            PREPARE OUTPUTS                               #
  ############################################################################
  # Put together parameters to create a model name that uniquely identifies
  # this hyperparameter configuration (used to name output subdirectories).
  model_name = FLAGS.model_cls
  model_name += ('_' + FLAGS.hidden_cls) if FLAGS.model_cls == 'mlp' else ''
  model_name += '-' + FLAGS.model_agr
  model_name += ('_' + FLAGS.hidden_agr) if FLAGS.model_agr == 'mlp' else ''
  model_name += '-aggr_' + FLAGS.aggregation_agr_inputs
  model_name += ('_' + FLAGS.hidden_aggreg) if FLAGS.hidden_aggreg else ''
  model_name += (
      '-add_%d-conf_%.2f-iterCls_%d-iterAgr_%d-batchCls_%d' %
      (FLAGS.num_samples_to_label, FLAGS.min_confidence_new_label,
       FLAGS.max_num_iter_cls, FLAGS.max_num_iter_agr, FLAGS.batch_size_cls))
  model_name += (('-wdecayCls_%.4f' %
                  FLAGS.weight_decay_cls) if FLAGS.weight_decay_cls else '')
  model_name += (('-wdecayAgr_%.4f' %
                  FLAGS.weight_decay_agr) if FLAGS.weight_decay_agr else '')
  model_name += '-LL_%s_LU_%s_UU_%s' % (str(
      FLAGS.reg_weight_ll), str(FLAGS.reg_weight_lu), str(FLAGS.reg_weight_uu))
  model_name += '-perfAgr' if FLAGS.use_perfect_agreement else ''
  model_name += '-perfCls' if FLAGS.use_perfect_classifier else ''
  model_name += '-keepProp' if FLAGS.keep_label_proportions else ''
  model_name += '-PenNegAgr' if FLAGS.penalize_neg_agr else ''
  model_name += '-transd' if not FLAGS.inductive else ''
  model_name += '-L2' if FLAGS.use_l2_cls else '-CE'
  model_name += '-seed_' + str(FLAGS.seed)
  model_name += FLAGS.experiment_suffix
  logging.info('Model name: %s', model_name)

  # Create directories for model checkpoints, summaries, and
  # self-labeled data backup.
  summary_dir = os.path.join(FLAGS.output_dir, 'summaries', FLAGS.dataset_name,
                             model_name)
  checkpoints_dir = os.path.join(FLAGS.output_dir, 'checkpoints',
                                 FLAGS.dataset_name, model_name)
  data_dir = os.path.join(FLAGS.data_output_dir, 'data_checkpoints',
                          FLAGS.dataset_name, model_name)
  # NOTE(review): summary_dir is intentionally not created here, matching the
  # original behavior — presumably the trainer's summary writer creates it.
  os.makedirs(checkpoints_dir, exist_ok=True)
  os.makedirs(data_dir, exist_ok=True)

  ############################################################################
  #                              MODEL SETUP                                 #
  ############################################################################
  # Select the classification model based on the provided FLAGS.
  model_cls = get_model_cls(
      model_name=FLAGS.model_cls,
      data=data,
      dataset_name=FLAGS.dataset_name,
      hidden=FLAGS.hidden_cls)
  # Create agreement model.
  model_agr = get_model_agr(
      model_name=FLAGS.model_agr,
      dataset_name=FLAGS.dataset_name,
      hidden_aggreg=FLAGS.hidden_aggreg,
      aggregation_agr_inputs=FLAGS.aggregation_agr_inputs,
      hidden=FLAGS.hidden_agr)
  # Assemble the co-training loop; most knobs are forwarded from FLAGS, with a
  # few fixed convergence tolerances (abs/rel loss change) hard-coded below.
  trainer = TrainerCotraining(
      model_cls=model_cls,
      model_agr=model_agr,
      max_num_iter_cotrain=FLAGS.max_num_iter_cotrain,
      min_num_iter_cls=FLAGS.min_num_iter_cls,
      max_num_iter_cls=FLAGS.max_num_iter_cls,
      num_iter_after_best_val_cls=FLAGS.num_iter_after_best_val_cls,
      min_num_iter_agr=FLAGS.min_num_iter_agr,
      max_num_iter_agr=FLAGS.max_num_iter_agr,
      num_iter_after_best_val_agr=FLAGS.num_iter_after_best_val_agr,
      num_samples_to_label=FLAGS.num_samples_to_label,
      min_confidence_new_label=FLAGS.min_confidence_new_label,
      keep_label_proportions=FLAGS.keep_label_proportions,
      num_warm_up_iter_agr=FLAGS.num_warm_up_iter_agr,
      optimizer=tf.train.AdamOptimizer,
      gradient_clip=FLAGS.gradient_clip,
      batch_size_agr=FLAGS.batch_size_agr,
      batch_size_cls=FLAGS.batch_size_cls,
      learning_rate_cls=FLAGS.learning_rate_cls,
      learning_rate_agr=FLAGS.learning_rate_agr,
      enable_summaries=True,
      enable_summaries_per_model=True,
      summary_dir=summary_dir,
      summary_step_cls=FLAGS.summary_step_cls,
      summary_step_agr=FLAGS.summary_step_agr,
      logging_step_cls=FLAGS.logging_step_cls,
      logging_step_agr=FLAGS.logging_step_agr,
      eval_step_cls=FLAGS.eval_step_cls,
      eval_step_agr=FLAGS.eval_step_agr,
      checkpoints_dir=checkpoints_dir,
      checkpoints_step=1,
      data_dir=data_dir,
      abs_loss_chg_tol=1e-10,
      rel_loss_chg_tol=1e-7,
      loss_chg_iter_below_tol=30,
      use_perfect_agr=FLAGS.use_perfect_agreement,
      use_perfect_cls=FLAGS.use_perfect_classifier,
      warm_start_cls=FLAGS.warm_start_cls,
      warm_start_agr=FLAGS.warm_start_agr,
      ratio_valid_agr=FLAGS.ratio_valid_agr,
      max_samples_valid_agr=FLAGS.max_samples_valid_agr,
      weight_decay_cls=FLAGS.weight_decay_cls,
      weight_decay_schedule_cls=FLAGS.weight_decay_schedule_cls,
      weight_decay_schedule_agr=FLAGS.weight_decay_schedule_agr,
      weight_decay_agr=FLAGS.weight_decay_agr,
      reg_weight_ll=FLAGS.reg_weight_ll,
      reg_weight_lu=FLAGS.reg_weight_lu,
      reg_weight_uu=FLAGS.reg_weight_uu,
      reg_weight_vat=FLAGS.reg_weight_vat,
      use_ent_min=FLAGS.use_ent_min,
      num_pairs_reg=FLAGS.num_pairs_reg,
      penalize_neg_agr=FLAGS.penalize_neg_agr,
      use_l2_cls=FLAGS.use_l2_cls,
      first_iter_original=FLAGS.first_iter_original,
      inductive=FLAGS.inductive,
      seed=FLAGS.seed,
      eval_acc_pred_by_agr=FLAGS.eval_acc_pred_by_agr,
      num_neighbors_pred_by_agr=FLAGS.num_neighbors_pred_by_agr,
      lr_decay_rate_cls=FLAGS.lr_decay_rate_cls,
      lr_decay_steps_cls=FLAGS.lr_decay_steps_cls,
      lr_decay_rate_agr=FLAGS.lr_decay_rate_agr,
      lr_decay_steps_agr=FLAGS.lr_decay_steps_agr,
      load_from_checkpoint=FLAGS.load_from_checkpoint)

  ############################################################################
  #                                 TRAIN                                    #
  ############################################################################
  trainer.train(data)