ma_policy/layers.py

import numpy as np
import tensorflow as tf

from ma_policy.util import shape_list


#################
# Pooling #######
#################

def entity_avg_pooling_masked(x, mask):
    '''
        Masks and pools x along the second to last dimension. Arguments have dimensions:
            x: batch x time x n_entities x n_features
            mask: batch x time x n_entities
    '''
    mask = tf.expand_dims(mask, -1)
    masked = x * mask
    summed = tf.reduce_sum(masked, -2)
    denom = tf.reduce_sum(mask, -2) + 1e-5
    return summed / denom


def entity_max_pooling_masked(x, mask):
    '''
        Masks and pools x along the second to last dimension. Arguments have dimensions:
            x: batch x time x n_entities x n_features
            mask: batch x time x n_entities
    '''
    mask = tf.expand_dims(mask, -1)
    has_unmasked_entities = tf.sign(tf.reduce_sum(mask, axis=-2, keepdims=True))
    offset = (mask - 1) * 1e9
    masked = (x + offset) * has_unmasked_entities
    return tf.reduce_max(masked, -2)


#################
# Concat Ops ####
#################

def entity_concat(inps):
    '''
        Concat 4D tensors along the third dimension. If a 3D tensor is in the list,
            treat it as a single entity and expand its third dimension.
        Args:
            inps (list of tensors): tensors to concatenate
    '''
    with tf.variable_scope('concat_entities'):
        shapes = [shape_list(_x) for _x in inps]
        # For inputs that don't have an entity dimension, add one.
        inps = [_x if len(_shape) == 4 else tf.expand_dims(_x, 2) for _x, _shape in zip(inps, shapes)]
        shapes = [shape_list(_x) for _x in inps]
        assert np.all([_shape[-1] == shapes[0][-1] for _shape in shapes]),\
            f"Some entities don't have the same outer or inner dimensions {shapes}"
        # Concatenate along entity dimension
        out = tf.concat(inps, -2)
    return out


def concat_entity_masks(inps, masks):
    '''
        Concats masks together. If a mask is None, it creates a tensor of 1's
            with shape (BS, T, NE) (or (BS, T, 1) for a 3D input).
        Args:
            inps (list of tensors): tensors that masks apply to
            masks (list of tensors): corresponding masks
    '''
    assert len(inps) == len(masks), "There should be the same number of inputs as masks"
    with tf.variable_scope('concat_masks'):
        shapes = [shape_list(_x) for _x in inps]
        new_masks = []
        for inp, mask in zip(inps, masks):
            if mask is None:
                inp_shape = shape_list(inp)
                if len(inp_shape) == 4:  # this is an entity tensor
                    new_masks.append(tf.ones(inp_shape[:3]))
                elif len(inp_shape) == 3:  # this is a pooled or main tensor. Set NE (outer dimension) to 1
                    new_masks.append(tf.ones(inp_shape[:2] + [1]))
            else:
                new_masks.append(mask)
        new_mask = tf.concat(new_masks, -1)
    return new_mask
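
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): how the concat and
# masked-pooling helpers above are typically combined. The shapes and the
# helper name `_pooling_example` are assumptions chosen for demonstration;
# assumes TensorFlow 1.x graph mode.
# ---------------------------------------------------------------------------
def _pooling_example():
    bs, t, ne, feat = 2, 3, 4, 8
    entities = tf.placeholder(tf.float32, [bs, t, ne, feat])  # entity tensor (BS, T, NE, feat)
    main = tf.placeholder(tf.float32, [bs, t, feat])          # 3D "main" tensor, treated as one entity
    entity_mask = tf.placeholder(tf.float32, [bs, t, ne])     # 1 = visible, 0 = masked out

    # Concatenate along the entity dimension; `main` gets an entity axis of size 1.
    x = entity_concat([entities, main])                                # (BS, T, NE + 1, feat)
    mask = concat_entity_masks([entities, main], [entity_mask, None])  # (BS, T, NE + 1)

    # Pool away the entity dimension, ignoring masked entities.
    avg_pooled = entity_avg_pooling_masked(x, mask)  # (BS, T, feat)
    max_pooled = entity_max_pooling_masked(x, mask)  # (BS, T, feat)
    return avg_pooled, max_pooled
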
#################
# Transformer ###
#################

def residual_sa_block(inp, mask, heads, n_embd,
                      layer_norm=False, post_sa_layer_norm=False,
                      n_mlp=1, qk_w=0.125, v_w=0.125, post_w=0.125,
                      mlp_w1=0.125, mlp_w2=0.125,
                      scope="residual_sa_block", reuse=False):
    '''
        Residual self attention block for entities.
        Notation:
            T  - Time
            NE - Number entities
        Args:
            inp (tf): (BS, T, NE, f)
            mask (tf): (BS, T, NE)
            heads (int) -- number of attention heads
            n_embd (int) -- dimension of queries, keys, and values will be n_embd / heads
            layer_norm (bool) -- normalize embedding prior to computing qkv
            n_mlp (int) -- number of mlp layers. If there is more than 1 mlp layer, we'll add a residual
                connection from after the first mlp to after the last mlp.
            qk_w, v_w, post_w, mlp_w1, mlp_w2 (float) -- scale for gaussian init for keys/queries, values,
                mlp post self attention, second mlp, and third mlp, respectively.
                Std will be sqrt(scale/n_embd)
            scope (string) -- tf scope
            reuse (bool) -- tf reuse
    '''
    with tf.variable_scope(scope, reuse=reuse):
        a = self_attention(inp, mask, heads, n_embd, layer_norm=layer_norm,
                           qk_w=qk_w, v_w=v_w, scope='self_attention', reuse=reuse)
        post_scale = np.sqrt(post_w / n_embd)
        post_a_mlp = tf.layers.dense(a,
                                     n_embd,
                                     kernel_initializer=tf.random_normal_initializer(stddev=post_scale),
                                     name="mlp1")
        x = inp + post_a_mlp
        if post_sa_layer_norm:
            with tf.variable_scope('post_a_layernorm'):
                x = tf.contrib.layers.layer_norm(x, begin_norm_axis=3)
        if n_mlp > 1:
            mlp = x
            mlp2_scale = np.sqrt(mlp_w1 / n_embd)
            mlp = tf.layers.dense(mlp,
                                  n_embd,
                                  kernel_initializer=tf.random_normal_initializer(stddev=mlp2_scale),
                                  name="mlp2")
        if n_mlp > 2:
            mlp3_scale = np.sqrt(mlp_w2 / n_embd)
            mlp = tf.layers.dense(mlp,
                                  n_embd,
                                  kernel_initializer=tf.random_normal_initializer(stddev=mlp3_scale),
                                  name="mlp3")
        if n_mlp > 1:
            x = x + mlp

    return x
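
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): wiring the residual
# self-attention block into a graph. Shapes, hyperparameters, and the helper
# name `_residual_sa_example` are assumptions chosen for demonstration;
# assumes TensorFlow 1.x graph mode.
# ---------------------------------------------------------------------------
def _residual_sa_example():
    bs, t, ne, feat = 2, 3, 4, 64
    entities = tf.placeholder(tf.float32, [bs, t, ne, feat])
    mask = tf.placeholder(tf.float32, [bs, t, ne])

    # The residual connection (inp + post_a_mlp) requires the input feature size
    # to equal n_embd, and n_embd must be divisible by heads (64 / 4 = 16 per head).
    out = residual_sa_block(entities, mask, heads=4, n_embd=64,
                            layer_norm=True, post_sa_layer_norm=True,
                            n_mlp=2, scope='example_residual_sa')
    return out  # (BS, T, NE, n_embd)
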
def self_attention(inp, mask, heads, n_embd, layer_norm=False, qk_w=1.0, v_w=0.01,
                   scope='', reuse=False):
    '''
        Self attention over entities.
        Notation:
            T  - Time
            NE - Number entities
        Args:
            inp (tf) -- tensor w/ shape (bs, T, NE, features)
            mask (tf) -- binary tensor with shape (bs, T, NE). For each batch x time,
                the inner matrix represents entity i's ability to see entity j
            heads (int) -- number of attention heads
            n_embd (int) -- dimension of queries, keys, and values will be n_embd / heads
            layer_norm (bool) -- normalize embedding prior to computing qkv
            qk_w, v_w (float) -- scale for gaussian init for keys/queries and values
                Std will be sqrt(scale/n_embd)
            scope (string) -- tf scope
            reuse (bool) -- tf reuse
    '''
    with tf.variable_scope(scope, reuse=reuse):
        bs, T, NE, features = shape_list(inp)
        # Put the mask in the correct format for the logit matrix
        entity_mask = None
        if mask is not None:
            with tf.variable_scope('expand_mask'):
                assert np.all(np.array(mask.get_shape().as_list()) == np.array(inp.get_shape().as_list()[:3])),\
                    f"Mask and input should have the same first 3 dimensions. {shape_list(mask)} -- {shape_list(inp)}"
                entity_mask = mask
                mask = tf.expand_dims(mask, -2)  # (BS, T, 1, NE)

        query, key, value = qkv_embed(inp, heads, n_embd, layer_norm=layer_norm, qk_w=qk_w, v_w=v_w, reuse=reuse)
        logits = tf.matmul(query, key, name="matmul_qk_parallel")  # (bs, T, heads, NE, NE)
        logits /= np.sqrt(n_embd / heads)
        softmax = stable_masked_softmax(logits, mask)
        att_sum = tf.matmul(softmax, value, name="matmul_softmax_value")  # (bs, T, heads, NE, features)
        with tf.variable_scope('flatten_heads'):
            out = tf.transpose(att_sum, (0, 1, 3, 2, 4))  # (bs, T, n_output_entities, heads, features)
            n_output_entities = shape_list(out)[2]
            out = tf.reshape(out, (bs, T, n_output_entities, n_embd))  # (bs, T, n_output_entities, n_embd)

    return out


def stable_masked_softmax(logits, mask):
    '''
        Args:
            logits (tf): tensor with shape (bs, T, heads, NE, NE)
            mask (tf): tensor with shape (bs, T, 1, NE)
    '''
    with tf.variable_scope('stable_softmax'):
        # Subtract a big number from the masked logits so they don't interfere with computing the max value
        if mask is not None:
            mask = tf.expand_dims(mask, 2)
            logits -= (1.0 - mask) * 1e10

        # Subtract the max logit from everything so we don't overflow
        logits -= tf.reduce_max(logits, axis=-1, keepdims=True)
        unnormalized_p = tf.exp(logits)

        # Mask the unnormalized probabilities, then normalize and remask
        if mask is not None:
            unnormalized_p *= mask
        normalized_p = unnormalized_p / (tf.reduce_sum(unnormalized_p, axis=-1, keepdims=True) + 1e-10)
        if mask is not None:
            normalized_p *= mask
    return normalized_p
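
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): what
# stable_masked_softmax does to masked entries. The constants and the helper
# name `_masked_softmax_example` are assumptions for demonstration; assumes
# TensorFlow 1.x graph mode.
# ---------------------------------------------------------------------------
def _masked_softmax_example():
    # One batch, one timestep, one head, two query entities, three key entities.
    logits = tf.constant(np.zeros((1, 1, 1, 2, 3), dtype=np.float32))
    # Mask in the (bs, T, 1, NE) layout produced by self_attention; the third
    # key entity is hidden from every query.
    mask = tf.constant(np.array([[[[1., 1., 0.]]]], dtype=np.float32))

    probs = stable_masked_softmax(logits, mask)
    # Masked keys get exactly zero probability and the rest renormalize,
    # so each attention row of `probs` is approximately [0.5, 0.5, 0.0].
    return probs
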
def qkv_embed(inp, heads, n_embd, layer_norm=False, qk_w=1.0, v_w=0.01, reuse=False):
    '''
        Compute queries, keys, and values.
        Args:
            inp (tf) -- tensor w/ shape (bs, T, NE, features)
            heads (int) -- number of attention heads
            n_embd (int) -- dimension of queries, keys, and values will be n_embd / heads
            layer_norm (bool) -- normalize embedding prior to computing qkv
            qk_w (float) -- Initialization scale for keys and queries. Actual scale will be
                sqrt(qk_w / #input features)
            v_w (float) -- Initialization scale for values. Actual scale will be
                sqrt(v_w / #input features)
            reuse (bool) -- tf reuse
    '''
    with tf.variable_scope('qkv_embed'):
        bs, T, NE, features = shape_list(inp)
        if layer_norm:
            with tf.variable_scope('pre_sa_layer_norm'):
                inp = tf.contrib.layers.layer_norm(inp, begin_norm_axis=3)

        # qk shape (bs x T x NE x h x n_embd/h)
        qk_scale = np.sqrt(qk_w / features)
        qk = tf.layers.dense(inp,
                             n_embd * 2,
                             kernel_initializer=tf.random_normal_initializer(stddev=qk_scale),
                             reuse=reuse,
                             name="qk_embed")  # (bs, T, NE, n_embd * 2)
        qk = tf.reshape(qk, (bs, T, NE, heads, n_embd // heads, 2))

        # (bs, T, NE, heads, n_embd / heads)
        query, key = [tf.squeeze(x, -1) for x in tf.split(qk, 2, -1)]

        v_scale = np.sqrt(v_w / features)
        value = tf.layers.dense(inp,
                                n_embd,
                                kernel_initializer=tf.random_normal_initializer(stddev=v_scale),
                                reuse=reuse,
                                name="v_embed")  # (bs, T, NE, n_embd)
        value = tf.reshape(value, (bs, T, NE, heads, n_embd // heads))

        query = tf.transpose(query, (0, 1, 3, 2, 4),
                             name="transpose_query")  # (bs, T, heads, NE, n_embd / heads)
        key = tf.transpose(key, (0, 1, 3, 4, 2),
                           name="transpose_key")  # (bs, T, heads, n_embd / heads, NE)
        value = tf.transpose(value, (0, 1, 3, 2, 4),
                             name="transpose_value")  # (bs, T, heads, NE, n_embd / heads)

    return query, key, value


##################
# 1D Convolution #
##################

def circ_conv1d(inp, **conv_kwargs):
    '''
        Circular 1D convolution over the entity dimension.
        Args:
            inp (tf) -- tensor w/ shape (BS, T, NE, features)
            conv_kwargs -- kwargs for tf.layers.conv1d. Must contain 'kernel_size', 'filters',
                and 'activation' (one of 'relu', 'tanh', or '').
    '''
    valid_activations = {'relu': tf.nn.relu, 'tanh': tf.tanh, '': None}
    assert 'kernel_size' in conv_kwargs, "Kernel size needs to be specified for circular convolution layer."
    conv_kwargs['activation'] = valid_activations[conv_kwargs['activation']]

    # Concatenate wrapped-around edges so the convolution is circular along the entity axis
    kernel_size = conv_kwargs['kernel_size']
    num_pad = kernel_size // 2
    inp_shape = shape_list(inp)
    inp_rs = tf.reshape(inp, shape=[inp_shape[0] * inp_shape[1]] + inp_shape[2:])  # (BS * T, NE, feats)
    inp_padded = tf.concat([inp_rs[..., -num_pad:, :], inp_rs, inp_rs[..., :num_pad, :]], -2)
    out = tf.layers.conv1d(inp_padded,
                           kernel_initializer=tf.contrib.layers.xavier_initializer(),
                           padding='valid',
                           **conv_kwargs)

    out = tf.reshape(out, shape=inp_shape[:3] + [conv_kwargs['filters']])
    return out


##################
# Misc ###########
##################

def layernorm(x, scope, epsilon=1e-5, reuse=False):
    '''
        Normalize state vector to be zero mean / unit variance + learned scale/shift
    '''
    with tf.variable_scope(scope, reuse=reuse):
        n_state = x.get_shape()[-1]
        gain = tf.get_variable('gain', [n_state], initializer=tf.constant_initializer(1))
        bias = tf.get_variable('bias', [n_state], initializer=tf.constant_initializer(0))
        mean = tf.reduce_mean(x, axis=[-1], keep_dims=True)
        variance = tf.reduce_mean(tf.square(x - mean), axis=[-1], keep_dims=True)
        norm_x = (x - mean) * tf.rsqrt(variance + epsilon)
        return norm_x * gain + bias
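
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): calling circ_conv1d.
# It expects Keras-style conv kwargs plus an 'activation' string from
# {'relu', 'tanh', ''}. The shapes and the helper name `_circ_conv_example`
# are assumptions for demonstration; assumes TensorFlow 1.x graph mode.
# ---------------------------------------------------------------------------
def _circ_conv_example():
    bs, t, ne, feat = 2, 3, 8, 16
    entities = tf.placeholder(tf.float32, [bs, t, ne, feat])

    # The entity axis is padded circularly, so entity 0 and entity NE - 1 are
    # treated as neighbours (useful for ring-structured observations).
    out = circ_conv1d(entities, filters=32, kernel_size=3, activation='relu')
    return out  # (BS, T, NE, filters)
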