in lingvo/core/gshard_layers.py [0:0]
def Top2GatingOnLogits(inputs,
paddings,
logits,
num_devices,
experts_dim,
expert_capacity_dim,
fprop_dtype,
use_xla_sharding=True,
second_expert_policy='all',
second_expert_threshold=0.0,
legacy_mtf_behavior=True,
capacity_factor=None,
importance=None,
mask_dtype=None):
"""Computes Top-2 gating for Mixture-of-Experts.
There are two expected usages of this function:
1. used with xla_sharding. In this case, 'inputs' corresponds to a sharded
tensor across multiple tpu cores. The operations within this function are
automatically sharded/replicated across tpu cores.
2. used within other projects where 'inputs' is always local to one tpu
core. All computations below are carried out on one tpu core only. This
function tries to dispatch examples across tpu cores in such a way that
each expert is assigned no more than 'expert_capacity_dim' number of
examples.
Below, ` indicates the common way of splitting along a mesh dimension.
Dimensions cheat sheet::
G: group_dim
S: group_size_dim
E: number of experts
C: capacity per expert
M: model_dim (same as input_dim, same as output_dim)
B: original batch_dim
L: original sequence_length_dim
Note that for local_dispatch the original batch BLM is reshaped into GSM, and
each group `g = 0...G-1` is dispatched independently.
Args:
inputs: G`SM Tensor.
paddings: G`S Tensor.
logits: G`SE Tensor.
num_devices: number of MoE devices for local dispatch.
experts_dim: number of experts.
expert_capacity_dim: number of examples per minibatch (group) per expert.
Each example is typically a vector of size input_dim, representing an
embedded token or an element of a Transformer layer output.
fprop_dtype: activations datatype to use.
use_xla_sharding: bool, True if this function is used for the xla_sharding
case.
second_expert_policy: 'all', 'sampling' or 'random'.
- 'all': we greedily pick the 2nd expert.
- 'sampling': we sample the 2nd expert from the softmax.
- 'random': we dispatch to the second-best expert with probability
proportional to (weight / second_expert_threshold).
second_expert_threshold: threshold for probability normalization for
second_expert_policy == 'random'.
legacy_mtf_behavior: bool, True to match legacy mtf behavior exactly.
capacity_factor: if set, increases expert_capacity_dim to at least
(group_size * capacity_factor) / experts_dim
where `group_size` is the size of the S (group size) dimension of `inputs`. If
the value of expert_capacity_dim is already big enough, no change is made.
importance: input importance weights for routing (G`S Tensor or None).
mask_dtype: using bfloat16 for fprop_dtype can be problematic for mask
tensors; mask_dtype is a separate dtype to use for such tensors.
TODO(lepikhin): get rid of the legacy_mtf_behavior flag.
Returns:
A tuple (aux_loss, combine_tensor, dispatch_tensor).
- aux_loss: auxiliary loss, for equalizing the expert assignment ratios.
- combine_tensor: G`SEC Tensor for combining expert outputs.
- dispatch_tensor: G`SEC Tensor, scattering/dispatching inputs to
experts.
"""
if mask_dtype is None:
mask_dtype = fprop_dtype
if use_xla_sharding:
tf.logging.warning('Sharding propagation should be sufficient and Splits '
'within Top2GatingOnLogits are generally redundant.')
del inputs # inputs is currently not used.
# logits.dtype could be tf.float32
raw_gates = tf.nn.softmax(logits) # along E dim
if raw_gates.dtype != fprop_dtype:
raw_gates = tf.cast(raw_gates, fprop_dtype)
has_capacity_factor = (capacity_factor is not None and capacity_factor > 0)
if not has_capacity_factor or (expert_capacity_dim != 0):
tf.logging.warning(
'Please set expert_capacity_dim=0 and a non-zero capacity_factor '
'(current values: expert_capacity_dim=%s capacity_factor=%s).',
expert_capacity_dim, capacity_factor)
if has_capacity_factor:
# Determine expert capacity automatically depending on the input size
group_size_dim = int(logits.shape[1])
auto_expert_capacity = int((group_size_dim * capacity_factor) / experts_dim)
if auto_expert_capacity == 0:
auto_expert_capacity = 4
tf.logging.info('Setting min value to auto_expert_capacity=%s',
auto_expert_capacity)
if expert_capacity_dim < auto_expert_capacity:
expert_capacity_dim = auto_expert_capacity
# Round up to a multiple of 4 to avoid possible padding.
while expert_capacity_dim % 4:
expert_capacity_dim += 1
tf.logging.info(
'Setting expert_capacity_dim=%r (capacity_factor=%r '
'group_size_dim=%r experts_dim=%r name_scope=%r)',
expert_capacity_dim, capacity_factor, group_size_dim, experts_dim,
tf.get_default_graph().get_name_scope())
tpu_summary.scalar('expert_capacity', expert_capacity_dim)
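# Illustrative example: with group_size_dim=1000, capacity_factor=2.0 and
# experts_dim=16, auto_expert_capacity = int(2000 / 16) = 125, which the loop
# above rounds up to expert_capacity_dim=128 (the next multiple of 4).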
# Top-1 and top-2 gate values and expert indices for each input.
#
# GSK Tensors, K=2.
def _MaybeSplit(x):
if use_xla_sharding:
return gshard_utils.Split(x, 0, num_devices)
else:
return x
def _CreateOverCapacityRatioSummary(mask, position_in_expert, capacity, name):
with tf.name_scope('over_capacity'):
ge_capacity = tf.greater_equal(mask * position_in_expert, capacity)
over_capacity = tf.reduce_sum(tf.cast(ge_capacity, tf.float32))
over_capacity_ratio = over_capacity / tf.maximum(
tf.constant(1.0, dtype=tf.float32),
tf.cast(tf.reduce_sum(mask), tf.float32))
py_utils.AddTpuSummaryTensor(name, over_capacity)
tpu_summary.scalar(name, over_capacity, while_loop_reduce='sum')
name = name + '_ratio'
py_utils.AddTpuSummaryTensor(name, over_capacity_ratio)
tpu_summary.scalar(name, over_capacity_ratio, while_loop_reduce='mean')
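# The helper above counts assignments whose position within an expert is at or
# beyond capacity (i.e. assignments that get dropped below), and reports both
# the raw count and its ratio to the total number of assignments in the mask.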
# As pointed out by zhifengc@ this method needs to be refactored. lepikhin@
# and krikun@ will:
# - expand moe_spmd_test to compare Adafactor updates, slots on TPU
# including 2x2 with sharding
#
# - add more tests for policy="random"
#
# - add single step test for full size WMT model on CPU
#
# and then break this function into modules.
#
# GS
index_1 = tf.math.argmax(raw_gates, axis=-1, output_type=tf.int32)
index_1 = _MaybeSplit(index_1)
tpu_summary.tensor('index_1', index_1)
# GSE
mask_1 = tf.one_hot(index_1, experts_dim, dtype=mask_dtype)
mask_1 = _MaybeSplit(mask_1)
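# Illustrative example: with experts_dim=4 and index_1 = [[2, 0]] (one group of
# two tokens), mask_1 becomes [[[0, 0, 1, 0], [1, 0, 0, 0]]].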
density_1_proxy = raw_gates
if importance is not None:
importance_is_one = tf.equal(importance, 1.0)
mask_1 *= tf.expand_dims(tf.cast(importance_is_one, mask_1.dtype), -1)
density_1_proxy *= tf.expand_dims(
tf.cast(importance_is_one, density_1_proxy.dtype), -1)
else:
if len(mask_1.shape) == 3:
importance = tf.ones_like(mask_1[:, :, 0])
else:
importance = tf.ones_like(mask_1[:, :, :, 0])
if paddings is not None:
nonpaddings = 1.0 - paddings
mask_1 *= tf.expand_dims(tf.cast(nonpaddings, mask_1.dtype), -1)
density_1_proxy *= tf.expand_dims(
tf.cast(nonpaddings, density_1_proxy.dtype), -1)
importance = nonpaddings
gate_1 = tf.einsum('...GSE,...GSE->...GS', raw_gates,
tf.cast(mask_1, raw_gates.dtype))
gates_without_top_1 = raw_gates * (1.0 - tf.cast(mask_1, raw_gates.dtype))
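# gate_1[g, s] is the softmax probability of the selected top-1 expert, e.g.
# raw_gates = [0.1, 0.2, 0.6, 0.1] gives gate_1 = 0.6, and gates_without_top_1
# zeroes out that entry, leaving [0.1, 0.2, 0.0, 0.1] for the top-2 choice.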
if second_expert_policy == 'sampling':
# We directly sample the 2nd expert index from the softmax over the remaining
# experts, obtained by getting rid of the 1st expert already selected above. To
# do so, we set a very negative value on the logit corresponding to the 1st
# expert. Then we sample from the softmax (categorical) distribution using the
# Gumbel max trick.
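# Gumbel-max trick: if g_e ~ Gumbel(0, 1) i.i.d., then argmax_e(logit_e + g_e)
# is distributed as softmax(logits); masking out the 1st expert's logit before
# taking the noisy argmax therefore samples the 2nd expert from the
# renormalized distribution over the remaining experts.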
noise = _MaybeSplit(tf.random.uniform(logits.shape, dtype=logits.dtype))
# Generates standard Gumbel(0, 1) noise, GSE Tensors
noise = -tf.math.log(-tf.math.log(noise))
very_negative_logits = _MaybeSplit(
(tf.ones_like(logits) * logits.dtype.max *
tf.constant(-0.7, dtype=logits.dtype)))
# Gets rid of the first expert by setting its logit to be very negative
updated_logits = _MaybeSplit(
tf.where(mask_1 > 0.0, very_negative_logits, logits))
# Adds the Gumbel noise to the updated logits
noised_logits = _MaybeSplit(updated_logits + noise)
# Picks the index of the largest noised logit as the 2nd expert. This is
# equivalent to sampling from the softmax over the 2nd experts.
index_2 = tf.math.argmax(noised_logits, axis=-1, output_type=tf.int32)
else:
index_2 = tf.math.argmax(gates_without_top_1, axis=-1, output_type=tf.int32)
index_2 = _MaybeSplit(index_2)
mask_2 = tf.one_hot(index_2, experts_dim, dtype=mask_dtype)
mask_2 = _MaybeSplit(mask_2)
if paddings is not None:
importance_is_nonzero = tf.greater(importance, 0.0)
mask_2 *= tf.expand_dims(tf.cast(importance_is_nonzero, mask_2.dtype), -1)
gate_2 = tf.einsum('...GSE,...GSE->...GS', gates_without_top_1,
tf.cast(mask_2, gates_without_top_1.dtype))
if legacy_mtf_behavior:
# cl/298510175 moved this branch for gate_{1,2} denom calculation here.
#
# For policy=random, it's better to normalize gate_{1,2} before taking
# capacity into account and before potentially dropping second expert.
#
# According to mean_xent:
# MoE_512_102xen_PolicyAll_298510175
# MoE_512_102xen_PolicyRandom_298510175
#
# vs pre-cl/298510175
# MoE_512_102xen_PolicyRandom
# MoE_512_102xen_PolicyAll
#
# it substantially improves policy=random with threshold=0.5 which
# historically was better than policy="all"
#
# Also confirmed this by decoding
# nmt_train/m4/data/es_en/test.txt
# nmt_train/m4/data/ru_en/test.txt
# nmt_train/m4/data/zh_en/test.txt
# and improving BLEU
#
# moe_decode.MoE_512_102xen_PolicyRandom_298510175-160000.batch1024.beam4.c_dim4.ln0.8.rkv.mteval102
# 0.421443
# 0.327102
# 0.315693
# vs
# moe_decode.feb18_non_fig_snapshot_2626_MoE_512_102xen_PolicyRandom-190000.batch1024.beam4.c_dim4.ln0.8.rkv.mteval102
# 0.399232
# 0.310606
# 0.288229
#
# Additional comparison, see mean_xent with
# legacy_mtf_behavior=False models
# 3 - MoE_512_102xen_PolicyAll_LegacyFalse
# 6 - MoE_512_102xen_PolicyRandom_LegacyFalse
# shows that policy="random" gets worse with legacy_mtf_behavior=False, and
# is similar to pre-cl/298510175
# 4 - MoE_512_102xen_PolicyRandom
#
# gate_1 can become 0 due to Expert being out of capacity.
#
# gate_2 can become 0 due to
# second_expert_policy == 'random'
# or "out of capacity" scenario.
#
# Here we renormalize regardless of cases above.
denom = gate_1 + gate_2 + 1e-9
gate_1 /= denom
gate_2 /= denom
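# Illustrative example: gate_1 = 0.5 and gate_2 = 0.3 renormalize to roughly
# 0.625 and 0.375, so the two dispatch weights sum to ~1.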
# We reshape the mask as [X*S, E], and compute cumulative sums of
# assignment indicators for each expert index e \in 0..E-1 independently.
# First occurrence of assignment indicator is excluded, see exclusive=True
# flag below.
#
# tf.cumsum over S dim: mask_1 is a ...GSE tensor, potentially with an outer
# dim O.
position_in_expert_1 = tf.cumsum(mask_1, exclusive=True, axis=-2)
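# Illustrative example: if the column of mask_1 for expert e along S is
# [1, 0, 1, 1], the exclusive cumsum yields positions [0, 1, 1, 2], i.e. each
# token's slot index counts earlier assignments to the same expert.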
# Scalar: per-expert capacity, cast to the dtype of position_in_expert_1.
capacity = tf.cast(expert_capacity_dim, dtype=position_in_expert_1.dtype)
# GE Tensor (reducing S out of GSE tensor mask_1)
# density_1[:, e] represents assignment ratio (num assigned / total) to
# expert e as top_1 expert without taking capacity into account.
assert importance.dtype == fprop_dtype
if legacy_mtf_behavior:
density_denom = 1.0
else:
density_denom = tf.reduce_mean(importance, axis=(1))[:, tf.newaxis] + 1e-6
density_1 = tf.reduce_mean(
tf.cast(mask_1, fprop_dtype), axis=-2) / density_denom
# density_1_proxy[:, e] represents mean of raw_gates for expert e, including
# those of examples not assigned to e with top_k.
assert density_1_proxy.dtype == fprop_dtype
density_1_proxy = tf.reduce_mean(density_1_proxy, axis=-2) / density_denom
with tf.name_scope('aux_loss'):
# The MoE paper (https://arxiv.org/pdf/1701.06538.pdf) uses an aux loss of
# reduce_mean(density_1_proxy * density_1_proxy). Here we replace one of
# the density_1_proxy with the discrete density_1 following mesh_tensorflow.
aux_loss = tf.reduce_mean(density_1_proxy * density_1) # element-wise
aux_loss *= experts_dim * experts_dim # const coefficient
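# With perfectly balanced routing, density_1 and density_1_proxy are both ~1/E
# per expert, so their mean product is ~1/E^2 and the E^2 factor keeps aux_loss
# at ~1 independent of experts_dim.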
# Add the over capacity ratio for expert 1
_CreateOverCapacityRatioSummary(mask_1, position_in_expert_1, capacity,
'over_capacity_1')
mask_1 *= tf.cast(tf.less(position_in_expert_1, capacity), dtype=mask_1.dtype)
position_in_expert_1 = tf.einsum('...GSE,...GSE->...GS', position_in_expert_1,
mask_1)
# How many examples in this sequence go to this expert
mask_1_count = tf.einsum('...GSE->...GE', mask_1)
# GS Tensor - mostly ones, but zeros where a token did not fit under capacity.
mask_1_flat = tf.einsum('...GSE->...GS', mask_1)
assert mask_1_count.dtype == mask_dtype
assert mask_1_flat.dtype == mask_dtype
if second_expert_policy == 'all' or second_expert_policy == 'sampling':
pass
elif second_expert_policy == 'random':
# gate_2 is between 0 and 1, reminder:
#
# raw_gates = tf.nn.softmax(logits)
# index_1 = tf.math.argmax(raw_gates, axis=-1, output_type=tf.int32)
# mask_1 = tf.one_hot(index_1, experts_dim, dtype=fprop_dtype)
# gate_1 = tf.einsum('GSE,GSE->GS', raw_gates, mask_1)
#
# E.g. if gate_2 exceeds second_expert_threshold, then we definitely
# dispatch to second-best expert. Otherwise we dispatch with probability
# proportional to (gate_2 / threshold).
#
sampled_2 = tf.less(
_MaybeSplit(tf.random.uniform(gate_2.shape, dtype=gate_2.dtype)),
(gate_2 / max(second_expert_threshold, 1e-9)))
gate_2 *= tf.cast(sampled_2, gate_2.dtype)
mask_2 *= tf.cast(tf.expand_dims(sampled_2, -1), mask_2.dtype)
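# Illustrative example: with second_expert_threshold=0.5, gate_2 = 0.3 keeps
# the second expert with probability 0.6, while any gate_2 >= 0.5 always keeps
# it.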
else:
raise ValueError(second_expert_policy)
position_in_expert_2 = tf.cumsum(
mask_2, exclusive=True, axis=-2) + tf.expand_dims(mask_1_count, -2)
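# Second-expert positions start after all top-1 assignments to the same expert
# (the mask_1_count offset), so both choices share a single budget of
# expert_capacity_dim slots per expert.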
# Add the over capacity ratio for expert 2
_CreateOverCapacityRatioSummary(mask_2, position_in_expert_2, capacity,
'over_capacity_2')
mask_2 *= tf.cast(tf.less(position_in_expert_2, capacity), mask_2.dtype)
position_in_expert_2 = tf.einsum('...GSE,...GSE->...GS', position_in_expert_2,
mask_2)
mask_2_flat = tf.reduce_sum(mask_2, axis=-1)
# Equivalent non-einsum implementation:
#
# position_in_expert_2 *= mask_2
# position_in_expert_2 = tf.reduce_sum(
# position_in_expert_2, axis=-1, name='position_in_expert_2')
gate_1 *= tf.cast(mask_1_flat, gate_1.dtype)
gate_2 *= tf.cast(mask_2_flat, gate_2.dtype)
if not legacy_mtf_behavior:
denom = gate_1 + gate_2
# To avoid divide by 0.
denom = tf.where(denom > 0, denom, tf.ones_like(denom))
gate_1 /= denom
gate_2 /= denom
# GSC Tensor
assert position_in_expert_1.dtype == mask_dtype # could be float32 in tests
b = tf.one_hot(
tf.cast(position_in_expert_1, dtype=tf.int32),
expert_capacity_dim,
dtype=fprop_dtype,
name='one_hot_b_0')
# GSE Tensor
a = tf.expand_dims(gate_1 * tf.cast(mask_1_flat, fprop_dtype),
-1) * tf.one_hot(
index_1, experts_dim, dtype=fprop_dtype)
# GSEC Tensor
first_part_of_combine_tensor = tf.einsum(
'...GSE,...GSC->...GSEC', a, b, name='first_part_of_combine_tensor')
# GSC Tensor
assert position_in_expert_2.dtype == mask_dtype # could be float32 in tests
b = tf.one_hot(
tf.cast(position_in_expert_2, dtype=tf.int32),
expert_capacity_dim,
dtype=fprop_dtype,
name='one_hot_b_1')
# GSE Tensor
a = tf.expand_dims(gate_2 * tf.cast(mask_2_flat, fprop_dtype),
-1) * tf.one_hot(
index_2, experts_dim, dtype=fprop_dtype)
second_part_of_combine_tensor = tf.einsum(
'...GSE,...GSC->...GSEC', a, b, name='second_part_of_combine_tensor')
# GSEC Tensor
combine_tensor = tf.math.add(
first_part_of_combine_tensor,
second_part_of_combine_tensor,
name='combine_tensor')
combine_tensor = _MaybeSplit(combine_tensor)
# GSEC Tensor
dispatch_tensor = tf.cast(
tf.cast(combine_tensor, tf.bool), fprop_dtype, name='dispatch_tensor')
dispatch_tensor = _MaybeSplit(dispatch_tensor)
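# combine_tensor[g, s, e, c] holds the gating weight used to combine expert
# outputs back per token, while dispatch_tensor is its 0/1 indicator used to
# scatter inputs into (expert, capacity) slots.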
# TODO(yonghui): compute and return per-group aux_loss.
return aux_loss, combine_tensor, dispatch_tensor
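# Illustrative usage sketch (hypothetical shapes and arguments, not part of the
# original file):
#
#   logits = tf.random.normal([num_groups, group_size, num_experts])
#   aux_loss, combine_tensor, dispatch_tensor = Top2GatingOnLogits(
#       inputs=None,  # unused by this function
#       paddings=None,
#       logits=logits,
#       num_devices=1,
#       experts_dim=num_experts,
#       expert_capacity_dim=0,
#       fprop_dtype=tf.float32,
#       use_xla_sharding=False,
#       capacity_factor=2.0)
#   # Tokens could then be scattered to experts with an einsum such as
#   # tf.einsum('GSM,GSEC->EGCM', reshaped_inputs, dispatch_tensor).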