in research/gam/gam/trainer/trainer_cotrain.py [0:0]
def train(self, data, **kwargs):
  # Create a wrapper around the dataset that also accounts for some
  # cotrain-specific attributes and functions.
  data = CotrainDataset(
      data,
      keep_label_proportions=self.keep_label_proportions,
      inductive=self.inductive)
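  # Note: based on the flag names, `keep_label_proportions` presumably keeps
  # the class distribution of self-labeled samples close to that of the
  # original labeled set, and `inductive` hides test samples during training
  # (see CotrainDataset for the exact semantics).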
  if os.path.exists(self.data_dir) and self.load_from_checkpoint:
    # If this session is restored from a previous run, then we load the
    # self-labeled data from the last checkpoint.
    logging.info('Number of labeled samples before restoring: %d',
                 data.num_train())
    logging.info('Restoring self-labeled data from %s...', self.data_dir)
    data.restore_state_from_file(self.data_dir)
    logging.info('Number of labeled samples after restoring: %d',
                 data.num_train())
  # Build graph.
  logging.info('Building graph...')
  # Create an iteration counter.
  iter_cotrain, iter_cotrain_update = self._create_counter()
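  # `iter_cotrain` is a TensorFlow variable holding the current co-train step
  # (read once before the loop below), and `iter_cotrain_update` is an op that
  # increments it, run at the end of every co-train iteration. The variable is
  # added to the saver below so the step count survives restarts.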
  if self.use_perfect_agr:
    # A perfect agreement model, used for debugging purposes.
    trainer_agr = TrainerPerfectAgreement(data=data)
  else:
    with tf.variable_scope('AgreementModel'):
      if self.always_agree:
        trainer_agr = TrainerAgreementAlwaysAgree(data=data)
      else:
        trainer_agr = TrainerAgreement(
            model=self.model_agr,
            data=data,
            optimizer=self.optimizer,
            gradient_clip=self.gradient_clip,
            min_num_iter=self.min_num_iter_agr,
            max_num_iter=self.max_num_iter_agr,
            num_iter_after_best_val=self.num_iter_after_best_val_agr,
            max_num_iter_cotrain=self.max_num_iter_cotrain,
            num_warm_up_iter=self.num_warm_up_iter_agr,
            warm_start=self.warm_start_agr,
            batch_size=self.batch_size_agr,
            enable_summaries=self.enable_summaries_per_model,
            summary_step=self.summary_step_agr,
            summary_dir=self.summary_dir,
            logging_step=self.logging_step_agr,
            eval_step=self.eval_step_agr,
            abs_loss_chg_tol=self.abs_loss_chg_tol,
            rel_loss_chg_tol=self.rel_loss_chg_tol,
            loss_chg_iter_below_tol=self.loss_chg_iter_below_tol,
            checkpoints_dir=self.checkpoints_dir,
            weight_decay=self.weight_decay_agr,
            weight_decay_schedule=self.weight_decay_schedule_agr,
            agree_by_default=False,
            percent_val=self.ratio_valid_agr,
            max_num_samples_val=self.max_samples_valid_agr,
            seed=self.seed,
            lr_decay_rate=self.lr_decay_rate_agr,
            lr_decay_steps=self.lr_decay_steps_agr,
            lr_initial=self.learning_rate_agr,
            use_graph=self.use_graph,
            add_negative_edges=self.add_negative_edges_agr)
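  # Build the classification model trainer. In the GAM setup, the agreement
  # model estimates whether a pair of samples shares the same label, and its
  # predictions regularize the classifier via pairwise penalty terms; based on
  # the parameter names, `reg_weight_ll`, `reg_weight_lu` and `reg_weight_uu`
  # presumably weight the labeled-labeled, labeled-unlabeled and
  # unlabeled-unlabeled pair losses, respectively.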
  if self.use_perfect_cls:
    # A perfect classification model used for debugging purposes.
    trainer_cls = TrainerPerfectClassification(data=data)
  else:
    with tf.variable_scope('ClassificationModel'):
      trainer_cls_class = (
          TrainerClassificationGCN
          if isinstance(self.model_cls, GCN) else TrainerClassification)
      trainer_cls = trainer_cls_class(
          model=self.model_cls,
          data=data,
          trainer_agr=trainer_agr,
          optimizer=self.optimizer,
          gradient_clip=self.gradient_clip,
          batch_size=self.batch_size_cls,
          min_num_iter=self.min_num_iter_cls,
          max_num_iter=self.max_num_iter_cls,
          num_iter_after_best_val=self.num_iter_after_best_val_cls,
          max_num_iter_cotrain=self.max_num_iter_cotrain,
          reg_weight_ll=self.reg_weight_ll,
          reg_weight_lu=self.reg_weight_lu,
          reg_weight_uu=self.reg_weight_uu,
          num_pairs_reg=self.num_pairs_reg,
          reg_weight_vat=self.reg_weight_vat,
          use_ent_min=self.use_ent_min,
          enable_summaries=self.enable_summaries_per_model,
          summary_step=self.summary_step_cls,
          summary_dir=self.summary_dir,
          logging_step=self.logging_step_cls,
          eval_step=self.eval_step_cls,
          abs_loss_chg_tol=self.abs_loss_chg_tol,
          rel_loss_chg_tol=self.rel_loss_chg_tol,
          loss_chg_iter_below_tol=self.loss_chg_iter_below_tol,
          warm_start=self.warm_start_cls,
          checkpoints_dir=self.checkpoints_dir,
          weight_decay=self.weight_decay_cls,
          weight_decay_schedule=self.weight_decay_schedule_cls,
          penalize_neg_agr=self.penalize_neg_agr,
          use_l2_classif=self.use_l2_classif,
          first_iter_original=self.first_iter_original,
          seed=self.seed,
          iter_cotrain=iter_cotrain,
          lr_decay_rate=self.lr_decay_rate_cls,
          lr_decay_steps=self.lr_decay_steps_cls,
          lr_initial=self.learning_rate_cls,
          use_graph=self.use_graph)
  # Create a saver which saves only the variables that we would need to
  # restore in case the training process is restarted.
  vars_to_save = ([iter_cotrain] + trainer_agr.vars_to_save +
                  trainer_cls.vars_to_save)
  saver = tf.train.Saver(vars_to_save)

  # Create a TensorFlow session. We allow soft placement in order to place
  # any supported ops on GPU. The allow_growth option lets our process claim
  # GPU memory gradually, as needed, instead of allocating all of it upfront.
  config = tf.ConfigProto(allow_soft_placement=True)
  config.gpu_options.allow_growth = True
  session = tf.Session(config=config)
  # Create a TensorFlow summary writer, shared by all models.
  summary_writer = tf.summary.FileWriter(self.summary_dir, session.graph)
  # Initialize the values of all variables.
  session.run(tf.global_variables_initializer())
  # If a checkpoint with the variables already exists, we restore them.
  if self.checkpoints_dir:
    checkpts_path_cotrain = os.path.join(self.checkpoints_dir, 'cotrain.ckpt')
    if os.path.exists(checkpts_path_cotrain):
      if self.load_from_checkpoint:
        saver.restore(session, checkpts_path_cotrain)
    else:
      os.makedirs(checkpts_path_cotrain)
  else:
    checkpts_path_cotrain = None
  # Create a progress bar showing how many samples are labeled.
  pbar = tqdm(
      total=data.num_samples - data.num_train(), desc='self-labeled nodes')
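  # The bar total is the number of initially unlabeled samples; it advances by
  # the number of newly self-labeled samples after each co-train iteration.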
  logging.info('Starting co-training...')
  step = session.run(iter_cotrain)
  stop = step >= self.max_num_iter_cotrain
  best_val_acc = -1
  test_acc_at_best = -1
  iter_at_best = -1
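  # Each co-train iteration: (1) train the agreement model, (2) train the
  # classification model regularized by the agreement model, and (3) extend
  # the labeled set by self-labeling. Stop when no new samples get labeled or
  # when max_num_iter_cotrain is reached.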
  while not stop:
    logging.info('----------------- Cotrain step %6d -----------------', step)

    # Train the agreement model.
    if self.first_iter_original and step == 0:
      logging.info('First iteration trains the original classifier. '
                   'No need to train the agreement model.')
      val_acc_agree = None
      acc_pred_by_agr = None
    else:
      val_acc_agree = trainer_agr.train(
          data, session=session, summary_writer=summary_writer)
      if self.eval_acc_pred_by_agr:
        # Evaluate the prediction accuracy of a majority-vote model that
        # relies on the agreement model.
        logging.info('Computing agreement majority vote predictions on '
                     'test data...')
        acc_pred_by_agr = trainer_agr.predict_label_by_agreement(
            session, data.get_indices_test(), self.num_neighbors_pred_by_agr)
      else:
        acc_pred_by_agr = None
    # Train the classification model.
    test_acc, val_acc = trainer_cls.train(
        data, session=session, summary_writer=summary_writer)
    if val_acc > best_val_acc:
      best_val_acc = val_acc
      test_acc_at_best = test_acc
      iter_at_best = step

    if self.enable_summaries:
      summary = tf.Summary()
      summary.value.add(tag='cotrain/test_acc', simple_value=test_acc)
      summary.value.add(tag='cotrain/val_acc', simple_value=val_acc)
      if val_acc_agree is not None:
        summary.value.add(
            tag='cotrain/val_acc_agree', simple_value=val_acc_agree)
      if acc_pred_by_agr is not None:
        summary.value.add(
            tag='cotrain/acc_predict_by_agreement',
            simple_value=acc_pred_by_agr)
      summary_writer.add_summary(summary, step)
      summary_writer.flush()

    logging.info(
        '--------- Cotrain step %6d | Accuracy val: %10.4f | '
        'Accuracy test: %10.4f ---------', step, val_acc, test_acc)
    logging.info(
        'Best validation acc: %.4f, corresponding test acc: %.4f at '
        'iteration %d', best_val_acc, test_acc_at_best, iter_at_best)
    if self.first_iter_original and step == 0:
      logging.info('No self-labeling because the first iteration trains the '
                   'original classifier for evaluation purposes.')
      step += 1
    else:
      # Extend the labeled set by self-labeling.
      logging.info('Self-labeling...')
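      # Note: `_extend_label_set` presumably picks the unlabeled samples whose
      # predicted labels the classifier is most confident about and adds them
      # to the training set (see its definition for the exact criterion).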
      selected_samples = self._extend_label_set(data, trainer_cls, session)
      # If no new data points are added to the training set, stop.
      num_new_labels = len(selected_samples)
      pbar.update(num_new_labels)
      if num_new_labels > 0:
        data.compute_dataset_statistics(selected_samples, summary_writer,
                                        step)
      else:
        logging.info('No new samples labeled. Stopping...')
        stop = True
      step += 1
      stop |= step >= self.max_num_iter_cotrain

    # Save model and dataset state in case of process preemption.
    if self.checkpoints_step and step % self.checkpoints_step == 0:
      self._save_state(saver, session, data, checkpts_path_cotrain)
    session.run(iter_cotrain_update)
  logging.info('________________________________________________________')
  logging.info(
      'Best validation acc: %.4f, corresponding test acc: %.4f at '
      'iteration %d', best_val_acc, test_acc_at_best, iter_at_best)
  pbar.close()
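For reference, a minimal sketch of what the `_create_counter` helper used
above might look like, inferred from how `iter_cotrain` and
`iter_cotrain_update` are used in `train` (an assumption; the actual
implementation lives elsewhere in this class):

def _create_counter(self):
  # Hypothetical sketch: a non-trainable TF1 scalar variable tracking the
  # co-train step, plus an op that increments it by one each time it is run.
  iter_cotrain = tf.get_variable(
      'iter_cotrain',
      shape=(),
      dtype=tf.int32,
      initializer=tf.zeros_initializer(),
      trainable=False)
  iter_cotrain_update = tf.assign_add(iter_cotrain, 1)
  return iter_cotrain, iter_cotrain_update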