in xformers/factory/block_factory.py [0:0]
def __init__(self, config: xFormerEncoderConfig, **kwargs):
    """Assemble the encoder block's sublayers from ``config``.

    Builds the multi-head attention and feedforward sublayers, wraps each
    one with the layer-norm/residual wrapper selected by
    ``config.layer_norm_style``, and optionally attaches a positional
    encoding.

    Args:
        config: Block configuration. This constructor reads ``dim_model``,
            ``layer_norm_style``, ``layer_position``, ``use_triton``,
            ``position_encoding_config``, ``multi_head_config`` and
            ``feedforward_config`` from it.
        **kwargs: Accepted by the signature but not read in this method.
    """
    super().__init__()

    # Reversible slots start out empty.
    # NOTE(review): presumably filled in elsewhere when reversible layers
    # are requested - confirm against the callers.
    self.reversible_f = None
    self.reversible_g = None

    norm_style = config.layer_norm_style
    model_dim = config.dim_model
    self.layer_norm_style = norm_style
    self.dim_model = model_dim

    # A positional encoding is only attached when one was requested AND this
    # block sits first in the stack; every other block gets None.
    pos_config = config.position_encoding_config
    if pos_config and config.layer_position.is_first():
        self.pose_encoding = build_positional_embedding(asdict(pos_config))
    else:
        self.pose_encoding = None

    # Small factory: given a sublayer, returns it wrapped with the LayerNorm
    # (pre- or post- style), the residual path, and the right dimensions.
    wrap_sublayer = _get_ln_factory(
        model_dim, norm_style, use_triton=config.use_triton
    )

    # Sublayers are created in the same order as before (attention, then
    # feedforward) so module registration and initialisation order hold.
    self.mha = build_multi_head_attention(config.multi_head_config)
    self.feedforward = build_feedforward(asdict(config.feedforward_config))

    # The wrappers take care of the layer norm style and the residual path.
    self.wrap_att = wrap_sublayer(self.mha)

    ff_wrapper: Union[Residual, PostNorm] = wrap_sublayer(self.feedforward)
    # With the pre-norm style, the last block gets an additional PostNorm
    # wrapped around its feedforward wrapper.
    if norm_style == LayerNormStyle.Pre and config.layer_position.is_last():
        ff_wrapper = PostNorm(model_dim, ff_wrapper)
    self.wrap_ff: Union[Residual, PostNorm] = ff_wrapper