in model/model.py [0:0]
def __init__(self, x_train, y_, num_tasks, opt, imp_method, synap_stgth, fisher_update_after, fisher_ema_decay, network_arch='FC-S',
             is_ATT_DATASET=False, x_test=None, attr=None):
    """
    Instantiate the model.

    `y_` is a list of per-task label tensors when imp_method == 'PNN' and a
    single tensor otherwise; `x_test`, when given, supplies the non-augmented
    input used at evaluation time; `attr` holds the class-attribute matrix
    used for zero-shot transfer.
    """
# Define some placeholders which are used to feed the data to the model
self.y_ = y_
if imp_method == 'PNN':
        self.total_classes = int(self.y_[0].get_shape()[1])
        self.train_phase = [tf.placeholder(tf.bool, name='train_phase_%d'%(i)) for i in range(num_tasks)]
self.output_mask = [tf.placeholder(dtype=tf.float32, shape=[self.total_classes]) for i in range(num_tasks)]
else:
self.total_classes = int(self.y_.get_shape()[1])
self.train_phase = tf.placeholder(tf.bool, name='train_phase')
        if imp_method in ('A-GEM', 'ER') and 'FC-' not in network_arch: # Only for Split-X setups
self.output_mask = [tf.placeholder(dtype=tf.float32, shape=[self.total_classes]) for i in range(num_tasks)]
self.mem_batch_size = tf.placeholder(dtype=tf.float32, shape=())
else:
self.output_mask = tf.placeholder(dtype=tf.float32, shape=[self.total_classes])
self.sample_weights = tf.placeholder(tf.float32, shape=[None])
self.task_id = tf.placeholder(dtype=tf.int32, shape=())
self.store_grad_batches = tf.placeholder(dtype=tf.float32, shape=())
self.keep_prob = tf.placeholder(dtype=tf.float32, shape=())
self.train_samples = tf.placeholder(dtype=tf.float32, shape=())
self.training_iters = tf.placeholder(dtype=tf.float32, shape=())
self.train_step = tf.placeholder(dtype=tf.float32, shape=())
self.violation_count = tf.Variable(0, dtype=tf.float32, trainable=False)
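    # Illustrative sketch (not from this file) of how these placeholders might be
    # fed at a training step, assuming `model` is an instance of this class and
    # `train_op` is one of the train ops it defines:
    #   feed_dict = {model.train_phase: True, model.keep_prob: 0.8,
    #                model.output_mask: np.ones(model.total_classes, dtype=np.float32),
    #                model.task_id: 0, model.train_step: step,
    #                model.training_iters: total_iters, model.train_samples: n_samples}
    #   sess.run(train_op, feed_dict=feed_dict)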
    self.is_ATT_DATASET = is_ATT_DATASET # Whether to use the standard ResNet-18 variant (for the CUB dataset)
if x_test is not None:
        # For the CUB dataset, use the augmented input (x_train) for training and the non-augmented input (x_test) for evaluation
self.x = tf.cond(self.train_phase, lambda: tf.identity(x_train), lambda: tf.identity(x_test))
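        # (tf.cond selects between the two input tensors at run time based on the
        # boolean train_phase placeholder, so one graph serves both training and eval)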
train_shape = x_train.get_shape().as_list()
x = tf.reshape(self.x, [-1, train_shape[1], train_shape[2], train_shape[3]])
else:
# We don't use data augmentation for other datasets
self.x = x_train
x = self.x
    # Class attributes for zero-shot transfer
self.class_attr = attr
if self.class_attr is not None:
self.attr_dims = int(self.class_attr.get_shape()[1])
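        # attr is expected to have shape [num_classes, attr_dims]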
# Save the arguments passed from the main script
self.opt = opt
self.num_tasks = num_tasks
self.imp_method = imp_method
self.fisher_update_after = fisher_update_after
self.fisher_ema_decay = fisher_ema_decay
self.network_arch = network_arch
    # A scalar constant for the synaptic strength (regularization coefficient)
self.synap_stgth = tf.constant(synap_stgth, shape=[1], dtype=tf.float32)
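    # For the importance-based methods this strength typically scales a quadratic
    # penalty on deviation from the previous tasks' weights, e.g. (EWC-style):
    #   loss = task_loss + synap_stgth * sum_i F_i * (w_i - w*_i)^2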
    self.triplet_loss_scale = 2.1 # Scaling factor for the triplet loss
    # Book-keeping variables for the importance-based (e.g. EWC, SI) and gradient-based (e.g. A-GEM) methods
self.weights_old = []
self.star_vars = []
self.small_omega_vars = []
self.big_omega_vars = []
self.big_omega_riemann_vars = []
self.fisher_diagonal_at_minima = []
self.hebbian_score_vars = []
self.running_fisher_vars = []
self.tmp_fisher_vars = []
self.max_fisher_vars = []
self.min_fisher_vars = []
self.max_score_vars = []
self.min_score_vars = []
self.normalized_score_vars = []
self.score_vars = []
self.normalized_fisher_at_minima_vars = []
self.weights_delta_old_vars = []
self.ref_grads = []
self.projected_gradients_list = []
if self.class_attr is not None:
self.loss_and_train_ops_for_attr_vector(x, self.y_)
else:
self.loss_and_train_ops_for_one_hot_vector(x, self.y_)
    # Set the operations to reset the optimizer when needed
self.reset_optimizer_ops()
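
# Illustrative construction sketch (argument values and the class name `Model`
# are assumptions, not taken from this file):
#   opt = tf.train.MomentumOptimizer(learning_rate=0.03, momentum=0.9)
#   model = Model(x_train, y_, num_tasks=5, opt=opt, imp_method='A-GEM',
#                 synap_stgth=0.0, fisher_update_after=50, fisher_ema_decay=0.9,
#                 network_arch='FC-S')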