in fc_permute_mnist.py [0:0]
def train_task_sequence(model, sess, args):
"""
Train and evaluate LLL system such that we only see a example once
Args:
Returns:
dict A dictionary containing mean and stds for the experiment
"""
# List to store accuracy for each run
runs = []
batch_size = args.batch_size
if model.imp_method == 'A-GEM' or model.imp_method == 'ER':
use_episodic_memory = True
else:
use_episodic_memory = False
# Loop over number of runs to average over
for runid in range(args.num_runs):
print('\t\tRun %d:'%(runid))
# Initialize the random seeds
np.random.seed(args.random_seed+runid)
# Load the permute mnist dataset
datasets = construct_permute_mnist(model.num_tasks)
episodic_mem_size = args.mem_size*model.num_tasks*TOTAL_CLASSES
# Initialize all the variables in the model
sess.run(tf.global_variables_initializer())
# Run the init ops
model.init_updates(sess)
# List to store accuracies for a run
evals = []
# List to store the classes that we have so far - used at test time
test_labels = np.arange(TOTAL_CLASSES)
if use_episodic_memory:
# Reserve a space for episodic memory
episodic_images = np.zeros([episodic_mem_size, INPUT_FEATURE_SIZE])
episodic_labels = np.zeros([episodic_mem_size, TOTAL_CLASSES])
count_cls = np.zeros(TOTAL_CLASSES, dtype=np.int32)
episodic_filled_counter = 0
examples_seen_so_far = 0
# Mask for softmax
# Since all the classes are present in all the tasks so nothing to mask
logit_mask = np.ones(TOTAL_CLASSES)
if model.imp_method == 'PNN':
pnn_train_phase = np.array(np.zeros(model.num_tasks), dtype=np.bool)
pnn_logit_mask = np.ones([model.num_tasks, TOTAL_CLASSES])
if COUNT_VIOLATIONS:
violation_count = np.zeros(model.num_tasks)
vc = 0
# Training loop for all the tasks
for task in range(len(datasets)):
print('\t\tTask %d:'%(task))
# If not the first task then restore weights from previous task
if(task > 0 and model.imp_method != 'PNN'):
model.restore(sess)
# Extract training images and labels for the current task
task_train_images = datasets[task]['train']['images']
task_train_labels = datasets[task]['train']['labels']
# If multi_task is set the train using datasets of all the tasks
if MULTI_TASK:
if task == 0:
for t_ in range(1, len(datasets)):
task_train_images = np.concatenate((task_train_images, datasets[t_]['train']['images']), axis=0)
task_train_labels = np.concatenate((task_train_labels, datasets[t_]['train']['labels']), axis=0)
else:
# Skip training for this task
continue
# Assign equal weights to all the examples
task_sample_weights = np.ones([task_train_labels.shape[0]], dtype=np.float32)
total_train_examples = task_train_images.shape[0]
# Randomly suffle the training examples
perm = np.arange(total_train_examples)
np.random.shuffle(perm)
train_x = task_train_images[perm][:args.examples_per_task]
train_y = task_train_labels[perm][:args.examples_per_task]
task_sample_weights = task_sample_weights[perm][:args.examples_per_task]
print('Received {} images, {} labels at task {}'.format(train_x.shape[0], train_y.shape[0], task))
# Array to store accuracies when training for task T
ftask = []
num_train_examples = train_x.shape[0]
# Train a task observing sequence of data
if args.train_single_epoch:
num_iters = num_train_examples // batch_size
else:
num_iters = args.train_iters
# Training loop for task T
for iters in range(num_iters):
if args.train_single_epoch and not args.cross_validate_mode:
if (iters < 10) or (iters < 100 and iters % 10 == 0) or (iters % 100 == 0):
# Snapshot the current performance across all tasks after each mini-batch
fbatch = test_task_sequence(model, sess, datasets, args.online_cross_val)
ftask.append(fbatch)
offset = (iters * batch_size) % (num_train_examples - batch_size)
residual = batch_size
if model.imp_method == 'PNN':
pnn_train_phase[:] = False
pnn_train_phase[task] = True
feed_dict = {model.x: train_x[offset:offset+batch_size], model.y_[task]: train_y[offset:offset+batch_size],
model.sample_weights: task_sample_weights[offset:offset+batch_size],
model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0}
train_phase_dict = {m_t: i_t for (m_t, i_t) in zip(model.train_phase, pnn_train_phase)}
logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, pnn_logit_mask)}
feed_dict.update(train_phase_dict)
feed_dict.update(logit_mask_dict)
else:
feed_dict = {model.x: train_x[offset:offset+batch_size], model.y_: train_y[offset:offset+batch_size],
model.sample_weights: task_sample_weights[offset:offset+batch_size],
model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0,
model.output_mask: logit_mask, model.train_phase: True}
if model.imp_method == 'VAN':
_, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict)
elif model.imp_method == 'PNN':
feed_dict[model.task_id] = task
_, loss = sess.run([model.train[task], model.unweighted_entropy[task]], feed_dict=feed_dict)
elif model.imp_method == 'FTR_EXT':
if task == 0:
_, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict)
else:
_, loss = sess.run([model.train_classifier, model.reg_loss], feed_dict=feed_dict)
elif model.imp_method == 'EWC':
# If first iteration of the first task then set the initial value of the running fisher
if task == 0 and iters == 0:
sess.run([model.set_initial_running_fisher], feed_dict=feed_dict)
# Update fisher after every few iterations
if (iters + 1) % model.fisher_update_after == 0:
sess.run(model.set_running_fisher)
sess.run(model.reset_tmp_fisher)
_, _, loss = sess.run([model.set_tmp_fisher, model.train, model.reg_loss], feed_dict=feed_dict)
elif model.imp_method == 'PI':
_, _, _, loss = sess.run([model.weights_old_ops_grouped, model.train, model.update_small_omega,
model.reg_loss], feed_dict=feed_dict)
elif model.imp_method == 'MAS':
_, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict)
elif model.imp_method == 'A-GEM':
if task == 0:
# Normal application of gradients
_, loss = sess.run([model.train_first_task, model.agem_loss], feed_dict=feed_dict)
else:
## Compute and store the reference gradients on the previous tasks
if episodic_filled_counter <= args.eps_mem_batch:
mem_sample_mask = np.arange(episodic_filled_counter)
else:
# Sample a random subset from episodic memory buffer
mem_sample_mask = np.random.choice(episodic_filled_counter, args.eps_mem_batch, replace=False) # Sample without replacement so that we don't sample an example more than once
# Store the reference gradient
sess.run(model.store_ref_grads, feed_dict={model.x: episodic_images[mem_sample_mask], model.y_: episodic_labels[mem_sample_mask],
model.keep_prob: 1.0, model.output_mask: logit_mask, model.train_phase: True})
if COUNT_VIOLATIONS:
vc, _, loss = sess.run([model.violation_count, model.train_subseq_tasks, model.agem_loss], feed_dict=feed_dict)
else:
# Compute the gradient for current task and project if need be
_, loss = sess.run([model.train_subseq_tasks, model.agem_loss], feed_dict=feed_dict)
# Put the batch in the ring buffer
for er_x, er_y_ in zip(train_x[offset:offset+residual], train_y[offset:offset+residual]):
cls = np.unique(np.nonzero(er_y_))[-1]
# Write the example at the location pointed by count_cls[cls]
cls_to_index_map = cls
with_in_task_offset = args.mem_size * cls_to_index_map
mem_index = count_cls[cls] + with_in_task_offset + episodic_filled_counter
episodic_images[mem_index] = er_x
episodic_labels[mem_index] = er_y_
count_cls[cls] = (count_cls[cls] + 1) % args.mem_size
elif model.imp_method == 'RWALK':
# If first iteration of the first task then set the initial value of the running fisher
if task == 0 and iters == 0:
sess.run([model.set_initial_running_fisher], feed_dict=feed_dict)
# Store the current value of the weights
sess.run(model.weights_delta_old_grouped)
# Update fisher and importance score after every few iterations
if (iters + 1) % model.fisher_update_after == 0:
# Update the importance score using distance in riemannian manifold
sess.run(model.update_big_omega_riemann)
# Now that the score is updated, compute the new value for running Fisher
sess.run(model.set_running_fisher)
# Store the current value of the weights
sess.run(model.weights_delta_old_grouped)
# Reset the delta_L
sess.run([model.reset_small_omega])
_, _, _, _, loss = sess.run([model.set_tmp_fisher, model.weights_old_ops_grouped,
model.train, model.update_small_omega, model.reg_loss], feed_dict=feed_dict)
elif model.imp_method == 'ER':
mem_filled_so_far = examples_seen_so_far if (examples_seen_so_far < episodic_mem_size) else episodic_mem_size
if mem_filled_so_far < args.eps_mem_batch:
er_mem_indices = np.arange(mem_filled_so_far)
else:
er_mem_indices = np.random.choice(mem_filled_so_far, args.eps_mem_batch, replace=False)
np.random.shuffle(er_mem_indices)
# Train on a batch of episodic memory first
er_train_x_batch = np.concatenate((episodic_images[er_mem_indices], train_x[offset:offset+residual]), axis=0)
er_train_y_batch = np.concatenate((episodic_labels[er_mem_indices], train_y[offset:offset+residual]), axis=0)
feed_dict = {model.x: er_train_x_batch, model.y_: er_train_y_batch,
model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 1.0,
model.output_mask: logit_mask, model.train_phase: True}
_, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict)
for er_x, er_y_ in zip(train_x[offset:offset+residual], train_y[offset:offset+residual]):
update_reservior(er_x, er_y_, episodic_images, episodic_labels, episodic_mem_size, examples_seen_so_far)
examples_seen_so_far += 1
if (iters % 100 == 0):
print('Step {:d} {:.3f}'.format(iters, loss))
if (math.isnan(loss)):
print('ERROR: NaNs NaNs Nans!!!')
sys.exit(0)
print('\t\t\t\tTraining for Task%d done!'%(task))
# Upaate the episodic memory filled counter
if use_episodic_memory:
episodic_filled_counter += args.mem_size * TOTAL_CLASSES
if model.imp_method == 'A-GEM' and COUNT_VIOLATIONS:
violation_count[task] = vc
print('Task {}: Violation Count: {}'.format(task, violation_count))
sess.run(model.reset_violation_count, feed_dict=feed_dict)
# Compute the inter-task updates, Fisher/ importance scores etc
# Don't calculate the task updates for the last task
if (task < (len(datasets) - 1)) or MEASURE_PERF_ON_EPS_MEMORY:
model.task_updates(sess, task, task_train_images, np.arange(TOTAL_CLASSES))
print('\t\t\t\tTask updates after Task%d done!'%(task))
if args.train_single_epoch and not args.cross_validate_mode:
fbatch = test_task_sequence(model, sess, datasets, False)
ftask.append(fbatch)
ftask = np.array(ftask)
else:
if MEASURE_PERF_ON_EPS_MEMORY:
eps_mem = {
'images': episodic_images,
'labels': episodic_labels,
}
# Measure perf on episodic memory
ftask = test_task_sequence(model, sess, eps_mem, args.online_cross_val)
else:
# List to store accuracy for all the tasks for the current trained model
ftask = test_task_sequence(model, sess, datasets, args.online_cross_val)
# Store the accuracies computed at task T in a list
evals.append(ftask)
# Reset the optimizer
model.reset_optimizer(sess)
#-> End for loop task
runs.append(np.array(evals))
# End for loop runid
runs = np.array(runs)
return runs