def train_task_sequence()

in conv_split_awa_hybrid.py


def train_task_sequence(model, sess, saver, datasets, class_attr, num_classes_per_task, cross_validate_mode, train_single_epoch, do_sampling, is_herding,  
        episodic_mem_size, train_iters, batch_size, num_runs, init_checkpoint, online_cross_val, random_seed):
    """
    Train and evaluate LLL system such that we only see a example once
    Args:
    Returns:
        dict    A dictionary containing mean and stds for the experiment
    """
    # List to store accuracy for each run
    runs = []
    task_labels_dataset = []

    break_training = 0
    # Loop over number of runs to average over
    for runid in range(num_runs):
        print('\t\tRun %d:'%(runid))

        # Initialize the random seeds
        np.random.seed(random_seed+runid)
        random.seed(random_seed+runid)

        # Get the task labels from the total number of tasks and full label space
        task_labels = []
        classes_per_task = num_classes_per_task
        classes_appearing_in_tasks = dict()
        for cls in range(TOTAL_CLASSES):
            classes_appearing_in_tasks[cls] = 0

        if online_cross_val:
            label_array = np.arange(TOTAL_CLASSES)
            for tt in range(model.num_tasks):
                offset = tt * classes_per_task
                task_labels.append(list(label_array[offset:offset+classes_per_task]))
        else:
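            # Evaluation mode: sample each task's labels (without replacement within a
            # task) from the classes left after reserving the first
            # K_FOR_CROSS_VAL*classes_per_task classes for cross-validation. A class may
            # reappear in later tasks; classes_appearing_in_tasks counts how often each
            # class is used so that load_task_specific_data_in_proportion (called below)
            # can apportion that class's examples across the tasks it appears in.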
            for tt in range(model.num_tasks):
                task_labels.append(random.sample(range(K_FOR_CROSS_VAL*classes_per_task, TOTAL_CLASSES), classes_per_task))
                for lab in task_labels[tt]:
                    classes_appearing_in_tasks[lab] += 1
                print('Task: {}, Labels:{}'.format(tt, task_labels[tt]))
            print('Class frequency in Tasks: {}'.format(classes_appearing_in_tasks))

        # Store the task labels
        task_labels_dataset.append(task_labels)

        # Initialize all the variables in the model
        sess.run(tf.global_variables_initializer())

        if PRETRAIN:
            # Load the variables from a checkpoint
            if model.network_arch == 'RESNET-B':
                # Define loader (weights which will be loaded from a checkpoint)
                restore_vars = [v for v in model.trainable_vars if 'fc' not in v.name and 'attr_embed' not in v.name]
                loader = tf.train.Saver(restore_vars)
                load(loader, sess, init_checkpoint)
            elif model.network_arch == 'VGG':
                # Load the pretrained weights from the npz file
                weights = np.load(init_checkpoint)
                keys = sorted(weights.keys())
                for i, key in enumerate(keys[:-2]): # Load everything except the last layer
                    sess.run(model.trainable_vars[i].assign(weights[key]))

        # Run the init ops
        model.init_updates(sess)

        # List to store accuracies for a run
        evals = []

        if model.imp_method == 'S-GEM':
            # List to store the episodic memories of the previous tasks
            task_based_memory = []

        if model.imp_method == 'A-GEM':
            # Reserve a space for episodic memory
            episodic_images = np.zeros([episodic_mem_size, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS])
            episodic_labels = np.zeros([episodic_mem_size, model.num_tasks*TOTAL_CLASSES])
            episodic_filled_counter = 0
            a_gem_logit_mask = np.zeros([model.num_tasks, model.total_classes])
            # Labels for all the tasks that we have seen in the past
            prev_task_labels = []
            prev_class_attrs = np.zeros([model.total_classes, class_attr.shape[1]])
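            # Note: episodic labels live in the joint multi-head label space
            # (num_tasks * TOTAL_CLASSES columns), so memory examples are replayed
            # against the head of the task they were stored from.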

        if do_sampling:
            # List to store important samples from the previous tasks
            last_task_x = None
            last_task_y_ = None

        # Mask for softmax
        logit_mask = np.zeros(model.total_classes)

        max_batch_dimension = 500

        # Dict to store the number of times a class has already been seen in the training
        class_seen_already = dict()
        for cls in range(TOTAL_CLASSES):
            class_seen_already[cls] = 0

        # Training loop for all the tasks
        for task in range(len(task_labels)):
            print('\t\tTask %d:'%(task))

            # If not the first task then restore weights from previous task
            if(task > 0):
                model.restore(sess)

            # Increment the class seen count
            for cls in task_labels[task]:
                class_seen_already[cls] += 1

            task_train_images, task_train_labels = load_task_specific_data_in_proportion(datasets[0]['train'], task_labels[task], classes_appearing_in_tasks, class_seen_already)
            print('Received {} images, {} labels at task {}'.format(task_train_images.shape[0], task_train_labels.shape[0], task))
            print('Unique labels in the task: {}'.format(np.unique(np.nonzero(task_train_labels)[1])))

            # Assign equal weights to all the examples
            task_sample_weights = np.ones([task_train_labels.shape[0]], dtype=np.float32)

            num_train_examples = task_train_images.shape[0]

            # Train the task by observing its data as a sequence
            logit_mask[:] = 0
            if train_single_epoch:
                # Ceiling operation
                num_iters = (num_train_examples + batch_size - 1) // batch_size
            else:
                num_iters = train_iters

            logit_mask_offset = task * TOTAL_CLASSES
            classes_adjusted_for_head = [cls + logit_mask_offset for cls in task_labels[task]]
            logit_mask[classes_adjusted_for_head] = 1.0
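            # Multi-head output: each task owns a contiguous slice of TOTAL_CLASSES
            # logits, and the mask above enables only the current task's classes
            # inside that slice.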

            # Randomly shuffle the training examples
            perm = np.arange(num_train_examples)
            np.random.shuffle(perm)
            train_x = task_train_images[perm]
            train_y = task_train_labels[perm]
            task_sample_weights = task_sample_weights[perm]

            # Array to store accuracies when training for task T
            if cross_validate_mode:
                # Because we will evaluate at the end
                ftask = 0
            elif train_single_epoch:
                # Because we will evaluate after the first few mini-batches of every task
                ftask = np.zeros([max_batch_dimension+1, model.num_tasks])
                batch_dim_count = 0
            else:
                # Multi-epoch because we will evaluate after every task
                ftask = []

            # Attribute mask
            masked_class_attrs = np.zeros([model.total_classes, class_attr.shape[1]])
            masked_class_attrs[classes_adjusted_for_head] = class_attr[task_labels[task]]
            #masked_class_attrs[task_labels[task]] = class_attr[task_labels[task]]

            # Number of iterations after which convergence is checked
            convergence_iters = int(num_iters * MEASURE_CONVERGENCE_AFTER)

            final_train_labels = np.zeros([batch_size, model.total_classes])
            head_offset = task * TOTAL_CLASSES
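            # Each mini-batch copies the one-hot labels (over TOTAL_CLASSES) into the
            # current task's slice of the joint label vector (see the loop below).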

            # Training loop for task T
            for iters in range(num_iters):

                if train_single_epoch and not cross_validate_mode:
                    if (iters < 11):
                        # Snapshot the current performance across all tasks for the first few mini-batches
                        fbatch = test_task_sequence(model, sess, datasets[0]['test'], class_attr, num_classes_per_task, task_labels, task, online_cross_val)
                        ftask[batch_dim_count] = fbatch
                        # Increment the batch_dim_count
                        batch_dim_count += 1
                        # Set the output labels over which the model needs to be trained
                        if model.imp_method == 'A-GEM':
                            a_gem_logit_mask[:] = 0
                            a_gem_logit_mask[task][classes_adjusted_for_head] = 1.0
                        else:
                            logit_mask[:] = 0
                            logit_mask[classes_adjusted_for_head] = 1.0
                        
                offset = iters * batch_size
                if (offset+batch_size <= num_train_examples):
                    residual = batch_size
                else:
                    residual = num_train_examples - offset

                final_train_labels[:residual, head_offset:head_offset+TOTAL_CLASSES] = train_y[offset:offset+residual]
                feed_dict = {model.x: train_x[offset:offset+residual], model.y_: final_train_labels[:residual], 
                        model.class_attr: masked_class_attrs,
                        model.sample_weights: task_sample_weights[offset:offset+residual],
                        model.training_iters: num_iters, model.train_step: iters, model.keep_prob: 0.5, 
                        model.train_phase: True}

                if model.imp_method == 'VAN':
                    feed_dict[model.output_mask] = logit_mask
                    _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'EWC' or model.imp_method == 'M-EWC':
                    feed_dict[model.output_mask] = logit_mask
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher], feed_dict=feed_dict)
                    # Update fisher after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        sess.run(model.set_running_fisher)
                        sess.run(model.reset_tmp_fisher)
                    
                    if (iters >= convergence_iters) and (model.imp_method == 'M-EWC'):
                        _, _, _, _, loss = sess.run([model.weights_old_ops_grouped, model.set_tmp_fisher, model.train, model.update_small_omega,
                            model.reg_loss], feed_dict=feed_dict)
                    else:
                        _, _, loss = sess.run([model.set_tmp_fisher, model.train, model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'PI':
                    feed_dict[model.output_mask] = logit_mask
                    _, _, _, loss = sess.run([model.weights_old_ops_grouped, model.train, model.update_small_omega, 
                                              model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'MAS':
                    feed_dict[model.output_mask] = logit_mask
                    _, loss = sess.run([model.train, model.reg_loss], feed_dict=feed_dict)

                elif model.imp_method == 'S-GEM':
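                    # S-GEM: enforce the GEM constraint against a single randomly chosen
                    # previous task; its episodic memory provides the reference gradient.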
                    if task == 0:
                        logit_mask[:] = 0
                        logit_mask[task_labels[task]] = 1.0
                        feed_dict[model.output_mask] = logit_mask
                        # Normal application of gradients
                        _, loss = sess.run([model.train_first_task, model.agem_loss], feed_dict=feed_dict)
                    else:
                        # Randomly sample a task from the previous tasks
                        prev_task = np.random.randint(0, task)
                        # Set the logit mask for the randomly sampled task
                        logit_mask[:] = 0
                        logit_mask[task_labels[prev_task]] = 1.0
                        prev_class_attrs = np.zeros_like(class_attr)
                        if online_cross_val:
                            attr_offset = prev_task * num_classes_per_task
                        else:
                            attr_offset = (prev_task + K_FOR_CROSS_VAL) * num_classes_per_task
                        prev_class_attrs[attr_offset:attr_offset+num_classes_per_task] = class_attr[attr_offset:attr_offset+num_classes_per_task]
                        # Store the reference gradient
                        sess.run(model.store_ref_grads, feed_dict={model.x: task_based_memory[prev_task]['images'], model.y_: task_based_memory[prev_task]['labels'],
                            model.class_attr: prev_class_attrs,
                            model.keep_prob: 1.0, model.output_mask: logit_mask, model.train_phase: True})
                        # Compute the gradient for current task and project if need be
                        logit_mask[:] = 0
                        logit_mask[task_labels[task]] = 1.0
                        feed_dict[model.output_mask] = logit_mask
                        _, loss = sess.run([model.train_subseq_tasks, model.agem_loss], feed_dict=feed_dict)

                elif model.imp_method == 'A-GEM':
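                    # A-GEM: a reference gradient g_ref is computed on a mini-batch drawn
                    # from the shared episodic memory; the current-task gradient g is used
                    # as-is when g . g_ref >= 0 and otherwise projected as
                    #     g <- g - (g . g_ref / g_ref . g_ref) * g_ref
                    # The projection itself is assumed to be implemented inside
                    # model.train_subseq_tasks; store_ref_grads only caches g_ref.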
                    if task == 0:
                        a_gem_logit_mask[:] = 0
                        a_gem_logit_mask[task][classes_adjusted_for_head] = 1.0
                        logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, a_gem_logit_mask)}
                        feed_dict.update(logit_mask_dict)
                        feed_dict[model.mem_batch_size] = batch_size
                        # Normal application of gradients
                        _, loss = sess.run([model.train_first_task, model.agem_loss], feed_dict=feed_dict)
                    else:
                        ## Compute and store the reference gradients on the previous tasks
                        # Set the mask for all the previous tasks so far
                        a_gem_logit_mask[:] = 0
                        for tt in range(task):
                            logit_mask_offset = tt * TOTAL_CLASSES
                            classes_adjusted_for_head = [cls + logit_mask_offset for cls in task_labels[tt]]
                            a_gem_logit_mask[tt][classes_adjusted_for_head] = 1.0

                        if KEEP_EPISODIC_MEMORY_FULL:
                            mem_sample_mask = np.random.choice(episodic_mem_size, EPS_MEM_BATCH_SIZE, replace=False) # Sample without replacement so that we don't sample an example more than once
                        else:
                            if episodic_filled_counter <= EPS_MEM_BATCH_SIZE:
                                mem_sample_mask = np.arange(episodic_filled_counter)
                            else:
                                # Sample a random subset from episodic memory buffer
                                mem_sample_mask = np.random.choice(episodic_filled_counter, EPS_MEM_BATCH_SIZE, replace=False) # Sample without replacement so that we don't sample an example more than once
                        # Store the reference gradient
                        ref_feed_dict = {model.x: episodic_images[mem_sample_mask], model.y_: episodic_labels[mem_sample_mask], model.class_attr: prev_class_attrs,  
                                model.keep_prob: 1.0, model.train_phase: True}
                        logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, a_gem_logit_mask)}
                        ref_feed_dict.update(logit_mask_dict)
                        ref_feed_dict[model.mem_batch_size] = float(len(mem_sample_mask))
                        sess.run(model.store_ref_grads, feed_dict=ref_feed_dict)
                        # Compute the gradient for current task and project if need be
                        a_gem_logit_mask[:] = 0
                        logit_mask_offset = task * TOTAL_CLASSES
                        classes_adjusted_for_head = [cls + logit_mask_offset for cls in task_labels[task]]
                        a_gem_logit_mask[task][classes_adjusted_for_head] = 1.0
                        logit_mask_dict = {m_t: i_t for (m_t, i_t) in zip(model.output_mask, a_gem_logit_mask)}
                        feed_dict.update(logit_mask_dict)
                        feed_dict[model.mem_batch_size] = batch_size
                        _, loss = sess.run([model.train_subseq_tasks, model.agem_loss], feed_dict=feed_dict)

                elif model.imp_method == 'RWALK':
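                    # RWalk: combines a running Fisher estimate (as in online EWC) with a
                    # path-based importance score accumulated along the optimization
                    # trajectory (distance in the Riemannian manifold).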
                    feed_dict[model.output_mask] = logit_mask
                    # If first iteration of the first task then set the initial value of the running fisher
                    if task == 0 and iters == 0:
                        sess.run([model.set_initial_running_fisher], feed_dict=feed_dict)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                    # Update fisher and importance score after every few iterations
                    if (iters + 1) % model.fisher_update_after == 0:
                        # Update the importance score using the distance in the Riemannian manifold
                        sess.run(model.update_big_omega_riemann)
                        # Now that the score is updated, compute the new value for running Fisher
                        sess.run(model.set_running_fisher)
                        # Store the current value of the weights
                        sess.run(model.weights_delta_old_grouped)
                        # Reset the delta_L
                        sess.run([model.reset_small_omega])

                    _, _, _, _, loss = sess.run([model.set_tmp_fisher, model.weights_old_ops_grouped, 
                        model.train, model.update_small_omega, model.reg_loss], feed_dict=feed_dict)

                if (iters % 100 == 0):
                    print('Step {:d} {:.3f}'.format(iters, loss))

                if (math.isnan(loss)):
                    print('ERROR: NaNs NaNs NaNs!!!')
                    break_training = 1
                    break

            print('\t\t\t\tTraining for Task%d done!'%(task))

            if model.imp_method == 'A-GEM':
                # Update the previous task labels
                prev_task_labels.extend(classes_adjusted_for_head)
                prev_class_attrs[classes_adjusted_for_head] = class_attr[task_labels[task]]

            if break_training:
                break

            # Compute the inter-task updates (Fisher information / importance scores, etc.)
            # Don't calculate the task updates for the last task
            if task < (len(task_labels) - 1):
                # TODO: For MAS, should the gradients be for current task or all the previous tasks
                model.task_updates(sess, task, task_train_images, task_labels[task], num_classes_per_task=num_classes_per_task, class_attr=class_attr, online_cross_val=online_cross_val) 
                print('\t\t\t\tTask updates after Task%d done!'%(task))

                # If importance method is '*-GEM' then store the episodic memory for the task
                if 'GEM' in model.imp_method:
                    data_to_sample_from = {
                            'images': task_train_images,
                            'labels': task_train_labels,
                            }
                    if model.imp_method == 'S-GEM':
                        # Get the important samples from the current task
                        if is_herding: # Sampling based on MoF
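                            # Herding (mean-of-features, iCaRL-style): select exemplars whose
                            # running feature mean stays close to the class mean; the actual
                            # selection is assumed to happen inside sample_from_dataset_icarl
                            # using the features computed below.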
                            # Compute the features of training data
                            features_dim = model.image_feature_dim
                            features = np.zeros([num_train_examples, features_dim])
                            samples_at_a_time = 32
                            residual = num_train_examples % samples_at_a_time
                            for i in range(num_train_examples// samples_at_a_time):
                                offset = i * samples_at_a_time
                                features[offset:offset+samples_at_a_time] = sess.run(model.features, feed_dict={model.x: task_train_images[offset:offset+samples_at_a_time],
                                    model.y_: task_train_labels[offset:offset+samples_at_a_time], model.keep_prob: 1.0,
                                    model.output_mask: logit_mask, model.train_phase: False})
                            if residual > 0:
                                # Process the leftover examples (also safe when the loop above did not run)
                                offset = num_train_examples - residual
                                features[offset:offset+residual] = sess.run(model.features, feed_dict={model.x: task_train_images[offset:offset+residual],
                                    model.y_: task_train_labels[offset:offset+residual], model.keep_prob: 1.0,
                                    model.output_mask: logit_mask, model.train_phase: False})
                            imp_images, imp_labels = sample_from_dataset_icarl(data_to_sample_from, features, task_labels[task], SAMPLES_PER_CLASS)
                        else: # Random sampling
                            # Uniform sampling: only keep examples from the current task
                            importance_array = np.ones(num_train_examples, dtype=np.float32)
                            imp_images, imp_labels = sample_from_dataset(data_to_sample_from, importance_array, task_labels[task], SAMPLES_PER_CLASS)
                        task_memory = {
                                'images': deepcopy(imp_images),
                                'labels': deepcopy(imp_labels),
                                }
                        task_based_memory.append(task_memory)

                    elif model.imp_method == 'A-GEM':
                        # Uniform sampling: only keep examples from the current task
                        importance_array = np.ones(num_train_examples, dtype=np.float32)
                        if KEEP_EPISODIC_MEMORY_FULL:
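                            # Buffer stays at episodic_mem_size; update_episodic_memory is
                            # assumed to overwrite existing slots with samples from the
                            # current task.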
                            update_episodic_memory(data_to_sample_from, importance_array, episodic_mem_size, task, episodic_images, episodic_labels)
                        else:
                            imp_images, imp_labels = sample_from_dataset(data_to_sample_from, importance_array, task_labels[task], SAMPLES_PER_CLASS)
                        if not KEEP_EPISODIC_MEMORY_FULL: # Fill the memory to always keep M/T samples per task
                            total_imp_samples = imp_images.shape[0]
                            eps_offset = task * total_imp_samples
                            episodic_images[eps_offset:eps_offset+total_imp_samples] = imp_images
                            episodic_labels[eps_offset:eps_offset+total_imp_samples, head_offset:head_offset+TOTAL_CLASSES] = imp_labels
                            episodic_filled_counter += total_imp_samples
                        print('Unique labels in the episodic memory: {}'.format(np.unique(np.nonzero(episodic_labels)[1])))
                        # Inspect episodic memory
                        if DEBUG_EPISODIC_MEMORY:
                            # Which labels are present in the memory
                            unique_labels = np.unique(np.nonzero(episodic_labels)[-1])
                            print('Unique labels present in the episodic memory: {}'.format(unique_labels))
                            print('Labels count:')
                            for lbl in unique_labels:
                                print('Label {}: {} samples'.format(lbl, np.where(np.nonzero(episodic_labels)[-1] == lbl)[0].size))
                            # Is there any space which is not filled
                            print('Empty space: {}'.format(np.where(np.sum(episodic_labels, axis=1) == 0)))
                        print('Episodic memory of {} images at task {} saved!'.format(episodic_images.shape[0], task))

                # If the sampling flag is set, store a few samples from the current task for use in later tasks
                if do_sampling:
                    # Uniform sampling: only keep examples from the current task
                    importance_array = np.ones([datasets[task]['train']['images'].shape[0]], dtype=np.float32)
                    # Get the important samples from the current task
                    imp_images, imp_labels = sample_from_dataset(datasets[task]['train'], importance_array, 
                            task_labels[task], SAMPLES_PER_CLASS)

                    if imp_images is not None:
                        if last_task_x is None:
                            last_task_x = imp_images
                            last_task_y_ = imp_labels
                        else:
                            last_task_x = np.concatenate((last_task_x, imp_images), axis=0)
                            last_task_y_ = np.concatenate((last_task_y_, imp_labels), axis=0)

                    # Delete the importance array now that you don't need it in the current run
                    del importance_array

                    print('\t\t\t\tEpisodic memory is saved for Task%d!'%(task))

            if cross_validate_mode:
                if (task == model.num_tasks - 1) or MULTI_TASK:
                    # List to store accuracy for all the tasks for the current trained model
                    ftask = test_task_sequence(model, sess, datasets[0]['test'], class_attr, num_classes_per_task, task_labels, task, online_cross_val)
            elif train_single_epoch:
                fbatch = test_task_sequence(model, sess, datasets[0]['test'], class_attr, num_classes_per_task, task_labels, task, False)
                ftask[batch_dim_count] = fbatch
                print('Task: {}, {}'.format(task, fbatch))
            else:
                # Multi-epoch training, so compute accuracy at the end
                ftask = test_task_sequence(model, sess, datasets[0]['test'], class_attr, num_classes_per_task, task_labels, task, online_cross_val)

            if SAVE_MODEL_PARAMS:
                save(saver, sess, SNAPSHOT_DIR, iters)

            if not cross_validate_mode:
                # Store the accuracies computed at task T in a list
                evals.append(np.array(ftask))

            # Reset the optimizer
            model.reset_optimizer(sess)

            #-> End for loop task

        if not cross_validate_mode:
            runs.append(np.array(evals))

        if break_training:
            break
        # End for loop runid
    if cross_validate_mode:
        return np.mean(ftask)
    else:
        runs = np.array(runs)
        return runs, task_labels_dataset
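
# Hypothetical usage sketch (illustrative values only; in the actual script the driver
# parses the command-line args, builds the model, datasets and saver, and then calls
# this function inside a session):
#
#   with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
#       runs, task_labels_dataset = train_task_sequence(
#               model, sess, saver, datasets, class_attr, num_classes_per_task,
#               cross_validate_mode=False, train_single_epoch=True, do_sampling=False,
#               is_herding=False, episodic_mem_size=2500, train_iters=2000,
#               batch_size=10, num_runs=5, init_checkpoint='path/to/pretrained_ckpt',
#               online_cross_val=False, random_seed=1234)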