def apply_nas_layers()

in tensor2tensor/models/neural_architecture_search/nas_model.py


def apply_nas_layers(input_tensor,
                     left_inputs,
                     left_layers,
                     left_activations,
                     left_output_dims,
                     left_norms,
                     right_inputs,
                     right_layers,
                     right_activations,
                     right_output_dims,
                     right_norms,
                     combiner_functions,
                     final_combiner_function,
                     num_cells,
                     nonpadding,
                     layer_registry,
                     mask_future,
                     hparams,
                     var_scope,
                     encoder_decoder_attention_bias=None,
                     encoder_cell_outputs=None,
                     decoder_self_attention_bias=None,
                     final_layer_norm=True,
                     enforce_fixed_output_sizes=True):
  """Applies layers with NasNet search space style branching.

  Args:
    input_tensor: Input [batch_size, input_length, hidden_dim] sequence tensor.
    left_inputs: Int list of left branch hidden layer input indexes.
    left_layers: String list of left branch layers.
    left_activations: String list of left branch activations.
    left_output_dims: String list of left branch output dimensions.
    left_norms: String list of left branch norms.
    right_inputs: Int list of right branch hidden layer input indexes.
    right_layers: String list of right branch layers.
    right_activations: String list of right branch activations.
    right_output_dims: String list of right branch output dimensions.
    right_norms: String list of right branch norms.
    combiner_functions: String list of branch combining functions.
    final_combiner_function: String. The final combiner function that combines
      all the unused hidden states in a cell.
    num_cells: The number of cells. This is the number of times the given
      layers will be repeated.
    nonpadding: Tensor with 1s at all nonpadding time step positions and 0s
      everywhere else.
    layer_registry: The LayerRegistry that holds all valid layers.
    mask_future: Whether or not to mask future sequence values.
    hparams: Hyperparameters for the model.
    var_scope: The variable scope name.
    encoder_decoder_attention_bias: The attention bias for decoder attending to
      `encoder_output`.
    encoder_cell_outputs: List of tensors. The encoder cell outputs, listed in
      order.
    decoder_self_attention_bias: The self attention bias for decoders. This
      needs to be set for decoders.
    final_layer_norm: Whether or not to apply a final layer_norm to the output
      of the model.
    enforce_fixed_output_sizes: Whether or not to automatically resize output
      dimensions to match the input dimension if `should_alter_output_dim()`
      returns True.

  Raises:
    ValueError: When branching inputs are not of the same length.
    ValueError: If item in left_norms is not LAYER_NORM_KEY or NO_NORM_KEY.
    ValueError: If item in right_norms is not LAYER_NORM_KEY or NO_NORM_KEY.

  Returns:
    Output of applied layers and list of each cell's outputs in order.
  """

  if not (len(left_inputs) == len(left_layers) == len(left_activations) ==
          len(left_output_dims) == len(left_norms) == len(right_inputs) ==
          len(right_layers) == len(right_activations) == len(right_output_dims)
          == len(right_norms) == len(combiner_functions)):
    raise ValueError("All branching inputs must be of the same length.")

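  # Hidden state 0 is the cell input; each layer appends one more state. Any
  # state that is never consumed as a left or right branch input is "unused"
  # and is merged by `final_combiner_function` at the end of each cell.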
  cell_output = None
  modified_left_inputs = [
      left_inputs[i]
      for i in range(len(left_inputs))
      if left_layers[i] != DEAD_BRANCH_KEY
  ]
  modified_right_inputs = [
      right_inputs[i]
      for i in range(len(right_inputs))
      if right_layers[i] != DEAD_BRANCH_KEY
  ]
  unused_cell_hidden_states = [
      i for i in range(len(left_inputs) + 1)
      if i not in modified_left_inputs and i not in modified_right_inputs
  ]
  assert unused_cell_hidden_states

  cell_outputs = []

  with tf.variable_scope(var_scope):
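    # Dims along which the attention dropout mask is broadcast (to save
    # memory); forwarded to every branch via `_apply_nas_branch`.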
    dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))

    for cell_num in range(num_cells):
      # h_0 is the input tensor.
      # Keep a dict for layer norm states.
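      # (`layer_norm_dict` is shared across both branches of every layer in
      # this cell and is threaded through `_apply_nas_branch` below.)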
      if cell_output is not None:
        cell_hidden_states = [cell_output]
      else:
        cell_hidden_states = [input_tensor]
      layer_norm_dict = {}

      with tf.variable_scope("cell_%d" % cell_num):

        for i, (left_input, left_layer_name, left_activation_name,
                left_output_dim, left_norm, right_input, right_layer_name,
                right_activation_name, right_output_dim, right_norm,
                combiner) in enumerate(
                    zip(left_inputs, left_layers, left_activations,
                        left_output_dims, left_norms, right_inputs,
                        right_layers, right_activations, right_output_dims,
                        right_norms, combiner_functions)):
          left_input = int(left_input)
          right_input = int(right_input)

          with tf.variable_scope("layer_%d" % i):

            assert not (left_layer_name == DEAD_BRANCH_KEY and
                        right_layer_name == DEAD_BRANCH_KEY)

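            # A DEAD_BRANCH_KEY branch is skipped entirely; the combiner step
            # below then passes the surviving branch through unchanged.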
            if left_layer_name != DEAD_BRANCH_KEY:

              left_raw_input_tensor = cell_hidden_states[left_input]
              left_input_dim = left_raw_input_tensor.shape.as_list()[-1]
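              # Optionally snap the branch output dim back to the input dim
              # (see `enforce_fixed_output_sizes` in the docstring).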
              if should_alter_output_dim(left_layer_name,
                                         enforce_fixed_output_sizes,
                                         left_input_dim, left_output_dim):
                left_output_dim = left_input_dim

              # First process the left branch.
              left_tensor = _apply_nas_branch(
                  norm=left_norm,
                  layer_norm_dict=layer_norm_dict,
                  hidden_states=cell_hidden_states,
                  nonpadding=nonpadding,
                  hparams=hparams,
                  input_index=left_input,
                  layer_name=left_layer_name,
                  activation_name=left_activation_name,
                  layer_registry=layer_registry,
                  output_dim=left_output_dim,
                  branch_scope_name="left_%s" % str(i),
                  mask_future=mask_future,
                  dropout_broadcast_dims=dropout_broadcast_dims,
                  encoder_decoder_attention_bias=encoder_decoder_attention_bias,
                  encoder_cell_outputs=encoder_cell_outputs,
                  decoder_self_attention_bias=decoder_self_attention_bias,
                  cell_number=cell_num)

            if right_layer_name != DEAD_BRANCH_KEY:
              right_raw_input_tensor = cell_hidden_states[right_input]
              right_input_dim = right_raw_input_tensor.shape.as_list()[-1]
              if should_alter_output_dim(right_layer_name,
                                         enforce_fixed_output_sizes,
                                         right_input_dim, right_output_dim):
                right_output_dim = right_input_dim
              # Next process the right branch.
              right_tensor = _apply_nas_branch(
                  norm=right_norm,
                  layer_norm_dict=layer_norm_dict,
                  hidden_states=cell_hidden_states,
                  nonpadding=nonpadding,
                  hparams=hparams,
                  input_index=right_input,
                  layer_name=right_layer_name,
                  activation_name=right_activation_name,
                  layer_registry=layer_registry,
                  output_dim=right_output_dim,
                  branch_scope_name="right_%s" % str(i),
                  mask_future=mask_future,
                  dropout_broadcast_dims=dropout_broadcast_dims,
                  encoder_decoder_attention_bias=encoder_decoder_attention_bias,
                  encoder_cell_outputs=encoder_cell_outputs,
                  decoder_self_attention_bias=decoder_self_attention_bias,
                  cell_number=cell_num)

            # Combine the branches.
            if left_layer_name == DEAD_BRANCH_KEY:
              hidden_tensor = right_tensor
            elif right_layer_name == DEAD_BRANCH_KEY:
              hidden_tensor = left_tensor
            else:
              hidden_tensor = COMBINER_FUNCTIONS[combiner]().combine_tensors(
                  [left_tensor, right_tensor])
            cell_hidden_states.append(hidden_tensor)

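      # Merge every hidden state that no branch consumed; the result is this
      # cell's output and becomes h_0 of the next cell.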
      states_to_combine = [
          cell_hidden_states[j] for j in unused_cell_hidden_states
      ]
      cell_output = COMBINER_FUNCTIONS[final_combiner_function](
      ).combine_tensors(states_to_combine)
      cell_outputs.append(cell_output)

  if final_layer_norm:
    final_output = common_layers.layer_preprocess(cell_output, hparams)
    cell_outputs = [
        common_layers.layer_preprocess(cell_output, hparams)
        for cell_output in cell_outputs
    ]
    return final_output, cell_outputs
  else:
    return cell_output, cell_outputs
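
Example call (illustrative sketch only). The layer, activation, and combiner
strings below are hypothetical placeholders rather than keys registered in this
module; `encoder_input`, `nonpadding`, `my_hparams`, and `my_layer_registry`
are assumed to exist, and the `*_KEY` constants are assumed to be accessible as
`nas_model` attributes. Only the argument structure mirrors the docstring
above.

# Hypothetical encoder-style invocation of apply_nas_layers.
from tensor2tensor.models.neural_architecture_search import nas_model

hidden_dim = 512  # Assumed model width; must match encoder_input's last dim.
encoder_output, encoder_cell_outputs = nas_model.apply_nas_layers(
    input_tensor=encoder_input,          # [batch, length, hidden_dim]
    left_inputs=[0, 1],                  # h_0 is the cell input tensor.
    left_layers=["<layer_key>", "<layer_key>"],
    left_activations=["<activation_key>", "<activation_key>"],
    left_output_dims=[hidden_dim, hidden_dim],
    left_norms=[nas_model.LAYER_NORM_KEY, nas_model.NO_NORM_KEY],
    right_inputs=[0, 1],
    right_layers=["<layer_key>", nas_model.DEAD_BRANCH_KEY],
    right_activations=["<activation_key>", "<activation_key>"],
    right_output_dims=[hidden_dim, hidden_dim],
    right_norms=[nas_model.NO_NORM_KEY, nas_model.NO_NORM_KEY],
    combiner_functions=["<combiner_key>", "<combiner_key>"],
    final_combiner_function="<combiner_key>",
    num_cells=2,
    nonpadding=nonpadding,               # [batch, length] 1/0 mask.
    layer_registry=my_layer_registry,    # A LayerRegistry of valid layers.
    mask_future=False,                   # True for decoder stacks.
    hparams=my_hparams,
    var_scope="encoder",
    final_layer_norm=True)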