def train()

in research/multi_representation_adversary/multi_representation_adversary/trainer.py


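The function body references several module-level names. A plausible set of imports, assuming the local modules live in the multi_representation_adversary package (the package paths and the absl logging backend are assumptions; learning_rate_scheduler is presumably defined elsewhere in trainer.py):

from absl import logging  # assumed: the logging.info calls below match absl's API
import tensorflow as tf

# Local modules; the package path is an assumption based on the file location above.
from multi_representation_adversary import data
from multi_representation_adversary import helper
from multi_representation_adversary import resnet
from multi_representation_adversary import selectors
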
def train(ckpt_dir=None,
          summary_dir=None,
          epochs=200,
          steps_per_epoch=351,  # 45000 // 128 for a 45k-example CIFAR-10 train split
          global_batch_size=128,
          model_fn=resnet.build_resnet_v1,
          lr_scheduler=learning_rate_scheduler,
          representation_list=(("identity", "none"),)):
  """Train a model with adversarial training in multiple representation spaces.

  Args:
    ckpt_dir: The directory to store model checkpoints.
    summary_dir: The directory to store training summaries.
    epochs: Maximum number of epochs to train for.
    steps_per_epoch: Number of training steps in each epoch.
    global_batch_size: Batch size across all processors/accelerators for each
      training step.
    model_fn: A callable which builds the model structure.
    lr_scheduler: A callable which returns the learning rate at any given epoch.
    representation_list: A list of (transform, attack) tuples representing the
      adversaries that this model should consider.
  """
  # Set up distributed training strategy first because all variables (model,
  # optimizer, etc) have to be created in the strategy's scope.
  strategy = tf.distribute.MirroredStrategy()
  with strategy.scope():
    model = model_fn(return_logits=True)  # Other params are set in gin
    optimizer = tf.keras.optimizers.SGD(learning_rate=lr_scheduler(0),
                                        momentum=0.9)
    loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

    def loss_fn(label, logit):
      # Normalize by global_batch_size, which differs from the usual
      # (per-replica) batch size in a distributed training environment.
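      # For example, with 2 replicas and global_batch_size=128, each replica
      # computes losses for 64 examples but divides their sum by 128, so
      # summing the per-replica gradients reproduces the average over the
      # full 128-example global batch.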
      return tf.nn.compute_average_loss(loss_obj(label, logit),
                                        global_batch_size=global_batch_size)

    metrics = [
        tf.keras.metrics.SparseCategoricalCrossentropy("loss",
                                                       from_logits=True),
        tf.keras.metrics.SparseCategoricalAccuracy("accuracy")]

    # Build a train and an eval (validation) step function (tf.function) for
    # each (transform, attack) tuple.
    representation_names = []
    train_step_fns, eval_step_fns = [], []
    for transform_name, attack_name in representation_list:
      representation_names.append(f"{transform_name}_{attack_name}")
      attack_fn = helper.build_attack_fn(model, transform_name, attack_name)
      train_step_fns.append(helper.build_train_step_fn(
          model, optimizer, loss_fn, metrics, attack_fn))
      eval_step_fns.append(helper.build_eval_step_fn(model, metrics, attack_fn))
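    # The selector decides, at each training step, which representation's
    # adversary to train against; its state is refreshed periodically from the
    # validation losses gathered below.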
    selector = selectors.construct_representation_selector(representation_names)

    # Create a checkpoint object for saving model weights and selector state.
    checkpoint = tf.train.Checkpoint(model=model, selector=selector)
    ckpt_mgr = tf.train.CheckpointManager(checkpoint, ckpt_dir,
                                          max_to_keep=None)
    restored_path = ckpt_mgr.restore_or_initialize()
    if restored_path:
      logging.info("Restored checkpoint %s", restored_path)
      start_epoch = int(restored_path.rsplit("-", 1)[-1])  # path like "ckpt-N"
      total_steps = start_epoch * steps_per_epoch
    else:
      logging.info("Model initialized")
      start_epoch, total_steps = 0, 0
      ckpt_mgr.save(0)

    train_dataset = data.get_training_dataset(global_batch_size)
    valid_dataset = data.get_validation_dataset(global_batch_size)

    with tf.summary.create_file_writer(summary_dir).as_default():
      for epoch in range(start_epoch + 1, epochs + 1):
        logging.info("Epoch %d", epoch)

        # Learning rate decay
        if lr_scheduler(epoch) != optimizer.learning_rate:
          optimizer.learning_rate = lr_scheduler(epoch)
          logging.info("New learning rate: %g", optimizer.learning_rate)

        # Training
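        # take(steps_per_epoch) caps the epoch at a fixed number of global
        # batches; the distributed dataset then splits each batch across replicas.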
        dist_dataset = strategy.experimental_distribute_dataset(
            train_dataset.take(steps_per_epoch))
        for x, y in dist_dataset:
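          # Pick which (transform, attack) adversary to train against this step.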
          selected_idx = selector.select(total_steps)
          train_step_fn = train_step_fns[selected_idx]
          per_replica_loss = strategy.run(train_step_fn, args=(x, y))
          loss_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                       per_replica_loss, axis=None)
          if total_steps % 50 == 0:
            tf.summary.scalar("train/batch_loss", loss_value, step=total_steps)
          total_steps += 1

        for metric in metrics:
          tf.summary.scalar(f"train/{metric.name}", metric.result(), step=epoch)
          metric.reset_states()

        # Maybe update the selector's state
        if selector.should_update(epoch):
          logging.info("Evaluate on validation set and update selector state")
          validation_losses = []
          dist_val_dataset = strategy.experimental_distribute_dataset(
              valid_dataset)
          for i, eval_step_fn in enumerate(eval_step_fns):
            for x, y in dist_val_dataset:
              strategy.run(eval_step_fn, args=(x, y))
            validation_losses.append(metrics[0].result())  # Crossentropy loss
            for metric in metrics:
              name = f"validation/{metric.name}/{representation_names[i]}"
              tf.summary.scalar(name, metric.result(), step=epoch)
              metric.reset_states()
          selector.update(epoch, validation_losses)

        # Save a checkpoint
        ckpt_mgr.save(epoch)
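
A minimal invocation sketch, assuming gin has already bound the model and data parameters this function relies on (the directory paths are placeholders, and only the default identity representation from the signature is used):

# Hypothetical call; gin configuration and directory paths are assumptions.
train(ckpt_dir="/tmp/mra/checkpoints",
      summary_dir="/tmp/mra/summaries",
      epochs=200,
      representation_list=(("identity", "none"),))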