in lingvo/core/attention.py [0:0]
def __init__(self, params):
"""Constructs an LocationSensitiveAttention object."""
super().__init__(params)
p = self.params
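# Quantized inference is enabled iff a default quantization domain is
# configured for this layer.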
self._is_quantized = p.qdomain.default is not None
assert not p.packed_input, ('Packed input is not supported yet for '
'LocationSensitiveAttention.')
if p.atten_dropout_prob != 0:
raise NotImplementedError('dropout is not supported')
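# The two logits helpers below implement additive (Bahdanau-style) attention
# with location features: for each source position i,
#   e_i = v^T tanh(W_s h_i + W_q q + W_f f_i)
# where h_i is the encoded source vector, q the query vector and f_i the
# convolutional location features derived from previous attention
# probabilities; W_s h_i, W_q q and W_f f_i correspond to
# concated_source_vecs, query_vec_reshaped / query_vec_transformed and
# location_hidden respectively.
# Shape abbreviations used in the comments below:
#   sl: source sequence length
#   sb: source batch size
#   tb: target batch size (a multiple of sb when each source sequence has
#       several target hypotheses, e.g. during beam search decoding)
#   hd: p.hidden_dim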
def AttenLogits(inputs):
"""Generates logits."""
fns = self.fns
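# Flatten all leading dimensions into one so the tensor can be fed to a
# 2-D matmul.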
def CollapseOutDim(x):
return tf.reshape(x, [-1, tf.shape(x)[-1]])
# => [sl, tb, location_feature_dim]
location_feats = tf.transpose(inputs.location_feats, [2, 0, 1])
location_hidden = fns.qmatmul(
CollapseOutDim(location_feats),
inputs.location_var,
qout_name='logits_mul')
sl = py_utils.GetShape(location_feats)[0]
tb = py_utils.GetShape(location_feats)[1]
hd = py_utils.GetShape(inputs.location_var)[1]
location_hidden = tf.reshape(location_hidden, [sl, tb, hd])
sb = py_utils.GetShape(inputs.query_vec_reshaped)[2]
bs_mult = py_utils.GetShape(inputs.query_vec_reshaped)[1]
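# Split the tb dimension into [tb/sb, sb] so location_hidden lines up with
# query_vec_reshaped and concated_source_vecs below.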
location_hidden = tf.reshape(location_hidden, [sl, bs_mult, sb, hd])
# Shape of summed is [sl, tb/sb, sb, hidden_dim].
summed = fns.qadd(
inputs.concated_source_vecs,
inputs.query_vec_reshaped,
qout_name='logits_add')
summed = fns.qadd(summed, location_hidden, qout_name='logits_bias')
summed = fns.qtanh(summed)
# logits is of shape [sl * tb/sb * sb, 1]. Compute the dot product between
# v and every row of 'summed', then reshape the result to shape
# [sl, tb/sb, sb].
logits = fns.qmatmul(
tf.reshape(summed, [-1, p.hidden_dim]),
tf.reshape(inputs.hidden_v, [p.hidden_dim, 1]),
qout_name='logits')
logits = tf.reshape(logits, py_utils.GetShape(summed)[:3])
return logits
def AttenLogitsSameBatchSize(inputs):
"""Generates logits.
Optimized code path for when the target and the source have the same batch
size.
Args:
inputs: a NestedMap containing:
- concated_source_vecs: Tensor of shape [sl, batch, dim]
- query_vec_transformed: Tensor of shape [batch, dim]
- hidden_v: Tensor of shape [dim]
- location_feats: Tensor of shape [batch, location_feature_dim, sl]
- location_var: Tensor of shape [location_feature_dim, dim]
Returns:
logits in the shape [sl, batch_size].
"""
def CollapseOutDim(x):
return tf.reshape(x, [-1, tf.shape(x)[-1]])
fns = self.fns
# => [sl, sb, location_feature_dim]
location_feats = tf.transpose(inputs.location_feats, [2, 0, 1])
location_hidden = fns.qmatmul(
CollapseOutDim(location_feats),
inputs.location_var,
qout_name='logits_mul')
sl = tf.shape(location_feats)[0]
tb = tf.shape(location_feats)[1]
hd = tf.shape(inputs.location_var)[1]
location_hidden = tf.reshape(location_hidden, [sl, tb, hd])
# Shape of summed is [sl, sb, hidden_dim].
summed = fns.qadd(
inputs.concated_source_vecs,
tf.expand_dims(inputs.query_vec_transformed, 0),
qout_name='logits_add')
summed = fns.qadd(summed, location_hidden, qout_name='logits_bias')
summed = fns.qtanh(summed)
# logits is of shape [sl * sb, 1]. Compute the dot product between v and
# every row of 'summed', then reshape the result to shape [sl, sb].
logits = fns.qmatmul(
tf.reshape(summed, [-1, p.hidden_dim]),
tf.reshape(inputs.hidden_v, [p.hidden_dim, 1]),
qout_name='logits')
logits = tf.reshape(logits, py_utils.GetShape(summed)[:2])
return logits
def Atten(hidden_var, query_var, source_padding, concated_source_vecs,
concated_source_contexts, query_vec, attention_state,
location_filter_var, location_var, per_step_source_padding):
"""Computes the attention context vector."""
p = self.params
# attention_state is of shape [batch, len(p.location_features), slen];
# it contains the previous and accumulated attention probabilities.
attention_state = py_utils.HasShape(attention_state,
[-1, len(p.location_features), -1])
fns = self.fns
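# Compute convolutional location features from the attention state;
# location_feats is of shape [batch, location_feature_dim, sl].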
location_feats = self._ApplyConv(attention_state, location_filter_var)
# concated_source_vecs is of shape [sl, sb, dims]
# concated_source_contexts is of shape [sb, sl, context_dim]
# query_vec is of shape [tb, dims]
sb = py_utils.GetShape(concated_source_vecs)[1]
tb = py_utils.GetShape(query_vec)[0]
multiplier = tb // sb
# concated_source_vecs is reshaped to [sl, 1, sb, hidden_dims]
concated_source_vecs = tf.expand_dims(concated_source_vecs, 1)
query_vec_transformed = fns.qmatmul(
query_vec, query_var, qout_name='atten_matmul')
# query_vec_transformed is reshaped to [1, tb/sb, sb, hidden_dim].
query_vec_reshaped = tf.reshape(query_vec_transformed,
[1, multiplier, sb, p.hidden_dim])
# logits is of shape [sl, tb/sb, sb]
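# Depending on the quantization mode, AttenLogits may be wrapped in a
# Defun; see _ConditionalCallDefun.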
logits = _ConditionalCallDefun(
self._is_quantized, AttenLogits,
py_utils.NestedMap(
concated_source_vecs=concated_source_vecs,
query_vec_reshaped=query_vec_reshaped,
hidden_v=hidden_var,
location_feats=location_feats,
location_var=location_var))
# Mask out the padded source positions.
# source_padding is of shape [sl, sb]; reshape it to [sl, 1, sb] so it can
# be combined with per_step_source_padding below.
source_padding = tf.expand_dims(source_padding, 1)
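# per_step_source_padding is [tb, sl]; transpose and reshape it to
# [sl, tb/sb, sb] so it can be added (with broadcasting) to source_padding.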
per_step_source_padding = tf.reshape(
tf.transpose(per_step_source_padding), [-1, multiplier, sb])
source_padding = tf.add(source_padding, per_step_source_padding)
source_padding = self.QRAct(source_padding,
quant_utils.QDistribution.PADDING)
# Reshape logits to a matrix of shape [tb, sl] and take the softmax to
# compute the probabilities.
logits = tf.transpose(tf.reshape(logits, [-1, tb]))
source_padding = tf.transpose(tf.reshape(source_padding, [-1, tb]))
probs = self._PaddedSoftmax(logits, source_padding)
# Reshape probs to be of shape [tb/sb, sb, sl].
probs_reshaped = tf.reshape(probs, [multiplier, sb, -1])
# Transpose probs to be of shape [sb, tb/sb, sl]
probs_reshaped = tf.transpose(probs_reshaped, [1, 0, 2])
# [sb, tb/sb, sl] * [sb, sl, context_dim] = [sb, tb/sb, context_dim]
summed = fns.qbatchmatmul(
tf.cast(probs_reshaped, concated_source_contexts.dtype),
concated_source_contexts,
qout_name='atten_context')
# summed is of shape [tb/sb, sb, context_dim]
summed = tf.transpose(summed, [1, 0, 2])
return tf.reshape(summed, [tb, -1]), probs
def AttenSameBatchSize(hidden_var, query_var, source_padding,
concated_source_vecs, concated_source_contexts,
query_vec, attention_state, location_filter_var,
location_var, per_step_source_padding):
"""Computes the attention context vector.
Optimized code path for when source and target have the same batch size.
"""
del per_step_source_padding
p = self.params
# attention_state is of shape [batch, len(p.location_features), slen];
# it contains the previous and accumulated attention probabilities.
attention_state = py_utils.HasShape(attention_state,
[-1, len(p.location_features), -1])
fns = self.fns
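# location_feats is of shape [batch, location_feature_dim, sl]; see the
# docstring of AttenLogitsSameBatchSize.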
location_feats = self._ApplyConv(attention_state, location_filter_var)
query_vec_transformed = fns.qmatmul(
query_vec, query_var, qout_name='atten_matmul')
# logits is of shape [sl, sb]
logits = _ConditionalCallDefun(
not self._is_quantized, AttenLogitsSameBatchSize,
py_utils.NestedMap(
concated_source_vecs=concated_source_vecs,
query_vec_transformed=query_vec_transformed,
hidden_v=hidden_var,
location_feats=location_feats,
location_var=location_var))
# => [sl, tb]
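# Restore the static shape, which may have been lost if the logits
# computation above was wrapped in a Defun.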
logits.set_shape(source_padding.shape)
# Transpose logits to shape [tb, sl] and take the softmax to compute the
# probabilities.
logits = tf.transpose(logits)
source_padding = tf.transpose(source_padding)
probs = self._PaddedSoftmax(logits, source_padding)
summed = fns.qbatchmatmul(
tf.cast(tf.expand_dims(probs, 1), concated_source_contexts.dtype),
concated_source_contexts,
qout_name='atten_context')
return tf.squeeze(summed, 1), probs
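# Select the context-vector function once at construction time: the
# optimized same-batch-size path when p.same_batch_size is set, otherwise
# the general path that allows the target batch to be a multiple of the
# source batch.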
if p.same_batch_size:
self._ctx_vec = AttenSameBatchSize
else:
self._ctx_vec = Atten
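# EncodeSource projects the source vectors to the hidden dimension with
# src_w and transposes the contexts to [batch, time, context_dim] for the
# batched matmul in Atten / AttenSameBatchSize.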
def EncodeSource(src_w, vecs, ctxs):
fns = self.fns
time, batch = py_utils.GetShape(vecs, 2)
ctxs = py_utils.HasShape(ctxs, [time, batch, -1])
transformed_vecs = tf.reshape(
fns.qmatmul(
tf.reshape(vecs, [-1, p.source_dim]),
src_w,
qout_name='encode_matmul'), [time, batch, -1])
transposed_ctxs = tf.transpose(ctxs, [1, 0, 2])
return transformed_vecs, transposed_ctxs
self._encode_source = EncodeSource