in tensor2tensor/models/neural_architecture_search/nas_model.py [0:0]
def apply_nas_layers(input_tensor,
                     left_inputs,
                     left_layers,
                     left_activations,
                     left_output_dims,
                     left_norms,
                     right_inputs,
                     right_layers,
                     right_activations,
                     right_output_dims,
                     right_norms,
                     combiner_functions,
                     final_combiner_function,
                     num_cells,
                     nonpadding,
                     layer_registry,
                     mask_future,
                     hparams,
                     var_scope,
                     encoder_decoder_attention_bias=None,
                     encoder_cell_outputs=None,
                     decoder_self_attention_bias=None,
                     final_layer_norm=True,
                     enforce_fixed_output_sizes=True):
  """Applies layers with NasNet search space style branching.

  Each cell keeps a list of hidden states. State 0 is the cell input: the
  `input_tensor` for the first cell, and the previous cell's output after
  that. Layer `i` appends state `i + 1`, which is its left branch, its right
  branch, or both combined with `combiner_functions[i]`. A branch whose layer
  is DEAD_BRANCH_KEY is skipped. Every hidden state that no live branch reads
  is merged by `final_combiner_function` into the cell output. The whole cell
  is repeated `num_cells` times.

  Args:
    input_tensor: Input [batch_size, input_length, hidden_dim] sequence tensor.
    left_inputs: Int list of left branch hidden layer input indexes. Each entry
      is coerced with int().
    left_layers: String list of left branch layers.
    left_activations: String list of left branch activations.
    left_output_dims: List of left branch output dimensions.
    left_norms: String list of left branch norms.
    right_inputs: Int list of right branch hidden layer input indexes. Each
      entry is coerced with int().
    right_layers: String list of right branch layers.
    right_activations: String list of right branch activations.
    right_output_dims: List of right branch output dimensions.
    right_norms: String list of right branch norms.
    combiner_functions: String list of branch combining functions.
    final_combiner_function: String. The final combiner function that combines
      all the unused hidden layers in a cell.
    num_cells: The number of cells. This is the number of times the given
      layers will be repeated.
    nonpadding: Tensor with 1s at all nonpadding time step positions and 0s
      everywhere else.
    layer_registry: The LayerRegistry that holds all valid layers.
    mask_future: Whether or not to mask future sequence values.
    hparams: Hyperparameters for the model.
    var_scope: The variable scope name.
    encoder_decoder_attention_bias: The attention bias for decoder attending to
      `encoder_output`.
    encoder_cell_outputs: List of tensors. The encoder cell outputs, listed in
      order.
    decoder_self_attention_bias: The self attention bias for decoders. This
      needs to be set for decoders.
    final_layer_norm: Whether or not to apply a final layer_norm to the output
      of the model.
    enforce_fixed_output_sizes: Whether or not to automatically resize output
      dimensions to match the input dimension if `should_alter_output_dim()`
      returns True.

  Raises:
    ValueError: When branching inputs are not of the same length.
    ValueError: When both branches of the same layer are DEAD_BRANCH_KEY.
    ValueError: When every hidden state is consumed by some live branch, so
      there is nothing left for the final combiner.
    ValueError: If item in left_norms is not LAYER_NORM_KEY or NO_NORM_KEY.
    ValueError: If item in right_norms is not LAYER_NORM_KEY or NO_NORM_KEY.

  Returns:
    Output of applied layers and list of each cell's outputs in order.
  """
  if not (len(left_inputs) == len(left_layers) == len(left_activations) ==
          len(left_output_dims) == len(left_norms) == len(right_inputs) ==
          len(right_layers) == len(right_activations) == len(right_output_dims)
          == len(right_norms) == len(combiner_functions)):
    raise ValueError("All branching inputs must be of the same length.")

  # The per-layer loop has always indexed hidden states with int(index).
  # Normalize once so the "unused state" bookkeeping below compares the same
  # values the loop actually uses (raw non-int indexes would never match).
  left_input_indexes = [int(index) for index in left_inputs]
  right_input_indexes = [int(index) for index in right_inputs]

  # Validate up front (asserts vanish under `python -O`). A layer with two dead
  # branches would otherwise reuse a stale tensor from an earlier layer.
  for layer_index, (left_layer_name, right_layer_name) in enumerate(
      zip(left_layers, right_layers)):
    if (left_layer_name == DEAD_BRANCH_KEY and
        right_layer_name == DEAD_BRANCH_KEY):
      raise ValueError(
          "Layer %d has both branches set to DEAD_BRANCH_KEY." % layer_index)

  consumed_states = set()
  for indexes, layer_names in ((left_input_indexes, left_layers),
                               (right_input_indexes, right_layers)):
    consumed_states.update(
        index for index, layer_name in zip(indexes, layer_names)
        if layer_name != DEAD_BRANCH_KEY)
  # There are len(layers) + 1 hidden states per cell: the cell input plus one
  # per layer. The ones no live branch reads feed the final combiner, in
  # ascending order.
  unused_cell_hidden_states = [
      state_index for state_index in range(len(left_layers) + 1)
      if state_index not in consumed_states
  ]
  if not unused_cell_hidden_states:
    raise ValueError("Every hidden state is consumed by some branch; there is "
                     "nothing left for the final combiner.")

  cell_output = None
  cell_outputs = []
  with tf.variable_scope(var_scope):
    dropout_broadcast_dims = (
        common_layers.comma_separated_string_to_integer_list(
            getattr(hparams, "attention_dropout_broadcast_dims", "")))
    for cell_num in range(num_cells):
      # h_0 is the input tensor for the first cell and the previous cell's
      # output afterwards. layer_norm_dict holds layer norm states per cell.
      if cell_output is not None:
        cell_hidden_states = [cell_output]
      else:
        cell_hidden_states = [input_tensor]
      layer_norm_dict = {}
      # Shared keyword arguments for both branches of every layer in this
      # cell. cell_hidden_states is appended to in place, so the helper always
      # sees the states produced by earlier layers.
      shared_branch_kwargs = dict(
          layer_norm_dict=layer_norm_dict,
          hidden_states=cell_hidden_states,
          nonpadding=nonpadding,
          hparams=hparams,
          layer_registry=layer_registry,
          mask_future=mask_future,
          dropout_broadcast_dims=dropout_broadcast_dims,
          encoder_decoder_attention_bias=encoder_decoder_attention_bias,
          encoder_cell_outputs=encoder_cell_outputs,
          decoder_self_attention_bias=decoder_self_attention_bias,
          cell_number=cell_num)
      with tf.variable_scope("cell_%d" % cell_num):
        for i, (left_input, left_layer_name, left_activation_name,
                left_output_dim, left_norm, right_input, right_layer_name,
                right_activation_name, right_output_dim, right_norm,
                combiner) in enumerate(
                    zip(left_input_indexes, left_layers, left_activations,
                        left_output_dims, left_norms, right_input_indexes,
                        right_layers, right_activations, right_output_dims,
                        right_norms, combiner_functions)):
          with tf.variable_scope("layer_%d" % i):
            # First process the left branch, then the right one.
            if left_layer_name != DEAD_BRANCH_KEY:
              left_tensor = _resize_and_apply_nas_branch(
                  layer_name=left_layer_name,
                  input_index=left_input,
                  output_dim=left_output_dim,
                  enforce_fixed_output_sizes=enforce_fixed_output_sizes,
                  norm=left_norm,
                  activation_name=left_activation_name,
                  branch_scope_name="left_%s" % str(i),
                  **shared_branch_kwargs)
            if right_layer_name != DEAD_BRANCH_KEY:
              right_tensor = _resize_and_apply_nas_branch(
                  layer_name=right_layer_name,
                  input_index=right_input,
                  output_dim=right_output_dim,
                  enforce_fixed_output_sizes=enforce_fixed_output_sizes,
                  norm=right_norm,
                  activation_name=right_activation_name,
                  branch_scope_name="right_%s" % str(i),
                  **shared_branch_kwargs)
            # Combine the branches. The validation above guarantees at least
            # one of them is live.
            if left_layer_name == DEAD_BRANCH_KEY:
              hidden_tensor = right_tensor
            elif right_layer_name == DEAD_BRANCH_KEY:
              hidden_tensor = left_tensor
            else:
              hidden_tensor = COMBINER_FUNCTIONS[combiner]().combine_tensors(
                  [left_tensor, right_tensor])
            cell_hidden_states.append(hidden_tensor)
        states_to_combine = [
            cell_hidden_states[j] for j in unused_cell_hidden_states
        ]
        cell_output = COMBINER_FUNCTIONS[final_combiner_function](
        ).combine_tensors(states_to_combine)
        cell_outputs.append(cell_output)

  if final_layer_norm:
    # NOTE(review): final_output and the last entry of cell_outputs come from
    # separate layer_preprocess calls on the same tensor. Whether they share
    # parameters depends on layer_preprocess. Kept as-is to preserve the
    # existing variable layout.
    final_output = common_layers.layer_preprocess(cell_output, hparams)
    cell_outputs = [
        common_layers.layer_preprocess(output, hparams)
        for output in cell_outputs
    ]
    return final_output, cell_outputs
  else:
    return cell_output, cell_outputs


def _resize_and_apply_nas_branch(layer_name, input_index, output_dim,
                                 enforce_fixed_output_sizes, **branch_kwargs):
  """Applies one live NAS branch, matching output_dim to its input if needed.

  Args:
    layer_name: Name of the branch layer (must not be DEAD_BRANCH_KEY).
    input_index: Index into branch_kwargs["hidden_states"] of the branch input.
    output_dim: Requested output dimension for the branch.
    enforce_fixed_output_sizes: Forwarded to should_alter_output_dim().
    **branch_kwargs: Remaining keyword arguments for _apply_nas_branch(). Must
      include `hidden_states`.

  Returns:
    The tensor returned by _apply_nas_branch().
  """
  raw_input_tensor = branch_kwargs["hidden_states"][input_index]
  input_dim = raw_input_tensor.shape.as_list()[-1]
  if should_alter_output_dim(layer_name, enforce_fixed_output_sizes,
                             input_dim, output_dim):
    output_dim = input_dim
  return _apply_nas_branch(
      layer_name=layer_name,
      input_index=input_index,
      output_dim=output_dim,
      **branch_kwargs)