in research/gam/gam/trainer/trainer_classification.py [0:0]
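# NOTE: this excerpt assumes the module-level imports defined at the top of
# the file (not shown here), at minimum something like:
#   import os
#   import numpy as np
#   import tensorflow as tf   # TF1.x graph-mode API
#   from absl import logging  # or the standard `logging` module
# plus the repo's own helpers (e.g. `get_loss_vat`, `entropy_y_x`) and the
# `Trainer` base class, whose exact import paths are not part of this excerpt.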
def __init__(self,
model,
data,
trainer_agr,
optimizer,
lr_initial,
batch_size,
min_num_iter,
max_num_iter,
num_iter_after_best_val,
max_num_iter_cotrain,
reg_weight_ll,
reg_weight_lu,
reg_weight_uu,
num_pairs_reg,
iter_cotrain,
reg_weight_vat=0.0,
use_ent_min=False,
enable_summaries=False,
summary_step=1,
summary_dir=None,
warm_start=False,
gradient_clip=None,
logging_step=1,
eval_step=1,
abs_loss_chg_tol=1e-10,
rel_loss_chg_tol=1e-7,
loss_chg_iter_below_tol=30,
checkpoints_dir=None,
weight_decay=None,
weight_decay_schedule=None,
penalize_neg_agr=False,
first_iter_original=True,
use_l2_classif=True,
seed=None,
lr_decay_steps=None,
lr_decay_rate=None,
use_graph=False):
super(TrainerClassification, self).__init__(
model=model,
abs_loss_chg_tol=abs_loss_chg_tol,
rel_loss_chg_tol=rel_loss_chg_tol,
loss_chg_iter_below_tol=loss_chg_iter_below_tol)
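# Store the training, regularization, and logging hyperparameters on the
# instance so they can be used later by the train and evaluation loops.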
self.data = data
self.trainer_agr = trainer_agr
self.batch_size = batch_size
self.min_num_iter = min_num_iter
self.max_num_iter = max_num_iter
self.num_iter_after_best_val = num_iter_after_best_val
self.max_num_iter_cotrain = max_num_iter_cotrain
self.enable_summaries = enable_summaries
self.summary_step = summary_step
self.summary_dir = summary_dir
self.warm_start = warm_start
self.gradient_clip = gradient_clip
self.logging_step = logging_step
self.eval_step = eval_step
self.checkpoint_path = (
os.path.join(checkpoints_dir, 'classif_best.ckpt')
if checkpoints_dir is not None else None)
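# Path where the best classification model (by validation accuracy) is
# checkpointed, if a checkpoints directory is provided.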
self.weight_decay_initial = weight_decay
self.weight_decay_schedule = weight_decay_schedule
self.num_pairs_reg = num_pairs_reg
self.reg_weight_ll = reg_weight_ll
self.reg_weight_lu = reg_weight_lu
self.reg_weight_uu = reg_weight_uu
self.reg_weight_vat = reg_weight_vat
self.use_ent_min = use_ent_min
self.penalize_neg_agr = penalize_neg_agr
self.use_l2_classif = use_l2_classif
self.first_iter_original = first_iter_original
self.iter_cotrain = iter_cotrain
self.lr_initial = lr_initial
self.lr_decay_steps = lr_decay_steps
self.lr_decay_rate = lr_decay_rate
self.use_graph = use_graph
# Build TensorFlow graph.
logging.info('Building classification TensorFlow graph...')
# Create placeholders.
# First obtain the features shape from the dataset, then prepend a batch
# dimension to it (`None`, so the batch size can vary).
features_shape = [None] + list(data.features_shape)
input_features = tf.placeholder(
tf.float32, shape=features_shape, name='input_features')
input_features_unlabeled = tf.placeholder(
tf.float32, shape=features_shape, name='input_features_unlabeled')
input_labels = tf.placeholder(tf.int64, shape=(None,), name='input_labels')
one_hot_labels = tf.one_hot(
input_labels, data.num_classes, name='input_labels_one_hot')
# Create a placeholder specifying if this is train time.
is_train = tf.placeholder_with_default(False, shape=[], name='is_train')
# Create variables and predictions.
with tf.variable_scope('predictions'):
encoding, variables_enc, reg_params_enc = (
self.model.get_encoding_and_params(
inputs=input_features, is_train=is_train))
self.variables = variables_enc
self.reg_params = reg_params_enc
predictions, variables_pred, reg_params_pred = (
self.model.get_predictions_and_params(
encoding=encoding, is_train=is_train))
self.variables.update(variables_pred)
self.reg_params.update(reg_params_pred)
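# Normalize the raw model outputs (typically via softmax) so they can be
# compared with one-hot labels and used to compute accuracy.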
normalized_predictions = self.model.normalize_predictions(predictions)
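# Remember the variable scope of the prediction layers so that the VAT loss
# can rebuild the same computation with reused (shared) variables.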
predictions_var_scope = tf.get_variable_scope()
# Create predictions on unlabeled data, which are only used for the VAT loss.
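# Batch statistics are intentionally not updated on this pass so the extra
# forward pass over unlabeled data does not perturb batch normalization.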
with tf.variable_scope('predictions', reuse=True):
encoding_unlabeled, _, _ = self.model.get_encoding_and_params(
inputs=input_features_unlabeled,
is_train=is_train,
update_batch_stats=False)
predictions_unlabeled, _, _ = (
self.model.get_predictions_and_params(
encoding=encoding_unlabeled, is_train=is_train))
# Create a variable for weight decay that may be updated.
weight_decay_var, weight_decay_update = self._create_weight_decay_var(
weight_decay, weight_decay_schedule)
# Create counter for classification iterations.
iter_cls_total, iter_cls_total_update = self._create_counter()
# Create loss.
with tf.name_scope('loss'):
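# Supervised loss on the labeled batch: either an L2 distance between the
# normalized predictions and the one-hot labels, or the model's own loss
# (typically cross-entropy).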
if self.use_l2_classif:
loss_supervised = tf.square(one_hot_labels - normalized_predictions)
loss_supervised = tf.reduce_sum(loss_supervised, axis=-1)
loss_supervised = tf.reduce_mean(loss_supervised)
else:
loss_supervised = self.model.get_loss(
predictions=predictions, targets=one_hot_labels, weight_decay=None)
# Agreement regularization loss.
loss_agr = self._get_agreement_reg_loss(data, is_train, features_shape)
# If the first co-train iteration trains the original model (for
# comparison purposes), then we do not add an agreement loss.
if self.first_iter_original:
loss_agr_weight = tf.cast(tf.greater(iter_cotrain, 0), tf.float32)
loss_agr = loss_agr * loss_agr_weight
# Weight decay loss.
loss_reg = 0.0
if weight_decay_var is not None:
for var in self.reg_params.values():
loss_reg += weight_decay_var * tf.nn.l2_loss(var)
# Adversarial loss, in case we want to add VAT on top of GAM.
loss_vat = get_loss_vat(input_features_unlabeled, predictions_unlabeled,
is_train, model, predictions_var_scope)
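# Guard against batches that contain no unlabeled samples, in which case the
# VAT loss is replaced by a constant 0.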
num_unlabeled = tf.shape(input_features_unlabeled)[0]
loss_vat = tf.cond(
tf.greater(num_unlabeled, 0), lambda: loss_vat, lambda: 0.0)
if self.use_ent_min:
# Use entropy minimization with VAT (i.e. VATENT).
loss_ent = entropy_y_x(predictions_unlabeled)
loss_vat = loss_vat + tf.cond(
tf.greater(num_unlabeled, 0), lambda: loss_ent, lambda: 0.0)
loss_vat = loss_vat * self.reg_weight_vat
if self.first_iter_original:
# Do not add the adversarial loss in the first iteration if
# the first iteration trains the plain baseline model.
weight_loss_vat = tf.cond(
tf.greater(iter_cotrain, 0), lambda: 1.0, lambda: 0.0)
loss_vat = loss_vat * weight_loss_vat
# Total loss.
loss_op = loss_supervised + loss_agr + loss_reg + loss_vat
# Create accuracy.
accuracy = tf.equal(tf.argmax(normalized_predictions, 1), input_labels)
accuracy = tf.reduce_mean(tf.cast(accuracy, tf.float32))
# Create TensorBoard summaries.
if self.enable_summaries:
summaries = [
tf.summary.scalar('loss_supervised', loss_supervised),
tf.summary.scalar('loss_agr', loss_agr),
tf.summary.scalar('loss_reg', loss_reg),
tf.summary.scalar('loss_total', loss_op)
]
self.summary_op = tf.summary.merge(summaries)
# Create learning rate schedule and optimizer.
self.global_step = tf.train.get_or_create_global_step()
if self.lr_decay_steps is not None and self.lr_decay_rate is not None:
self.lr = tf.train.exponential_decay(
self.lr_initial,
self.global_step,
self.lr_decay_steps,
self.lr_decay_rate,
staircase=True)
self.optimizer = optimizer(self.lr)
else:
self.optimizer = optimizer(lr_initial)
# Get trainable variables and compute gradients.
grads_and_vars = self.optimizer.compute_gradients(
loss_op,
tf.trainable_variables(scope=tf.get_default_graph().get_name_scope()))
# Clip gradients.
if self.gradient_clip:
gradients, variables = zip(*grads_and_vars)
gradients, _ = tf.clip_by_global_norm(gradients, self.gradient_clip)
grads_and_vars = tuple(zip(gradients, variables))
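# Make sure pending update ops (e.g. batch norm moving-average updates) run
# before the gradients are applied.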
with tf.control_dependencies(
tf.get_collection(
tf.GraphKeys.UPDATE_OPS,
scope=tf.get_default_graph().get_name_scope())):
train_op = self.optimizer.apply_gradients(
grads_and_vars, global_step=self.global_step)
# Create a saver for model variables.
trainable_vars = [v for _, v in grads_and_vars]
# Put together the subset of variables to save and restore from the best
# validation accuracy checkpoint while training the classification model in
# one co-training round.
vars_to_save = list(trainable_vars)
if isinstance(weight_decay_var, tf.Variable):
vars_to_save.append(weight_decay_var)
saver = tf.train.Saver(vars_to_save)
# Put together all variables that need to be saved in case the process is
# interrupted and needs to be restarted.
self.vars_to_save = [iter_cls_total, self.global_step]
if isinstance(weight_decay_var, tf.Variable):
self.vars_to_save.append(weight_decay_var)
if self.warm_start:
# `self.variables` is assumed to map variable names to variables.
self.vars_to_save.extend(list(self.variables.values()))
# More variables to be initialized after the session is created.
self.is_initialized = False
self.rng = np.random.RandomState(seed)
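# Keep handles to the placeholders, tensors and ops that the training and
# evaluation loops need to feed and fetch.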
self.input_features = input_features
self.input_features_unlabeled = input_features_unlabeled
self.input_labels = input_labels
self.predictions = predictions
self.normalized_predictions = normalized_predictions
self.weight_decay_var = weight_decay_var
self.weight_decay_update = weight_decay_update
self.iter_cls_total = iter_cls_total
self.iter_cls_total_update = iter_cls_total_update
self.accuracy = accuracy
self.train_op = train_op
self.loss_op = loss_op
self.saver = saver
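# Number of samples in the current batch, and an op that re-initializes the
# optimizer's internal state (e.g. slot variables), used to reset the
# optimizer, for example between co-training rounds.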
self.batch_size_actual = tf.shape(self.predictions)[0]
self.reset_optimizer = tf.variables_initializer(self.optimizer.variables())
self.is_train = is_train