in training/flax/distil_whisper/layers.py [0:0]
def __call__(self, inputs: Array) -> Array:
    """Embeds the inputs along the last dimension.

    Args:
      inputs: input data, all dimensions are considered batch dimensions.
        Must have an integer or unsigned-integer dtype, after first being
        cast to `self.cast_input_dtype` when that attribute is set.

    Returns:
      Output which is embedded input data. The output shape follows the input,
      with an additional `features` dimension appended.

    Raises:
      ValueError: if `inputs` (after the optional cast) does not have an
        integer or unsigned-integer dtype.
    """
    # Compare against None rather than relying on truthiness: `numpy.dtype`
    # instances define `__len__` (number of structured fields, 0 for plain
    # dtypes), so `bool(np.dtype("int32"))` is False and a truthiness check
    # would silently skip the requested cast.
    if self.cast_input_dtype is not None:
        inputs = inputs.astype(self.cast_input_dtype)
    if not jnp.issubdtype(inputs.dtype, jnp.integer):
        raise ValueError("Input type must be an integer or unsigned integer.")

    # Both lookup strategies read the table in the module's compute dtype.
    embedding = jnp.asarray(self.embedding, self.dtype)

    if self.one_hot:
        # Lookup via matmul against a one-hot mask instead of a gather.
        # Each index is compared with 0..num_embeddings-1, so an index outside
        # that range yields an all-zero mask row and hence an all-zero output.
        iota = lax.iota(jnp.int32, self.num_embeddings)
        one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype)
        output = jnp.dot(one_hot, embedding)
    else:
        output = embedding[inputs]
        # NOTE(review): three logical axis names assume a rank-2 `inputs`
        # (batch, length); confirm behavior for other input ranks.
        output = with_sharding_constraint(output, ("batch", "length", "embed"))
    return output