easy_rec/python/layers/keras/mask_net.py

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging

import tensorflow as tf
from tensorflow.python.keras.layers import Activation
from tensorflow.python.keras.layers import Dense
from tensorflow.python.keras.layers import Layer

from easy_rec.python.layers.keras.blocks import MLP
from easy_rec.python.layers.keras.layer_norm import LayerNormalization
from easy_rec.python.layers.utils import Parameter


class MaskBlock(Layer):
  """MaskBlock used in MaskNet.

  Args:
    projection_dim: project dimension to reduce the computational cost.
      Default is `None` such that a full (`input_dim` by `aggregation_size`)
      matrix W is used. If enabled, a low-rank matrix W = U*V will be used,
      where U is of size `input_dim` by `projection_dim` and V is of size
      `projection_dim` by `aggregation_size`. `projection_dim` needs to be
      smaller than `aggregation_size`/2 to improve the model efficiency.
      In practice, we've observed that `projection_dim` = d/4 consistently
      preserved the accuracy of a full-rank version.
  """

  def __init__(self, params, name='mask_block', reuse=None, **kwargs):
    super(MaskBlock, self).__init__(name=name, **kwargs)
    self.config = params.get_pb_config()
    self.l2_reg = params.l2_regularizer
    self._projection_dim = params.get_or_default('projection_dim', None)
    self.reuse = reuse
    self.final_relu = Activation('relu', name='relu')

  def build(self, input_shape):
    if type(input_shape) in (tuple, list):
      assert len(input_shape) >= 2, 'MaskBlock must have at least two inputs'
      input_dim = int(input_shape[0][-1])
      mask_input_dim = int(input_shape[1][-1])
    else:
      input_dim, mask_input_dim = input_shape[-1], input_shape[-1]

    if self.config.HasField('reduction_factor'):
      aggregation_size = int(mask_input_dim * self.config.reduction_factor)
    elif self.config.HasField('aggregation_size'):
      aggregation_size = self.config.aggregation_size
    else:
      raise ValueError(
          'Need one of reduction factor or aggregation size for MaskBlock.')

    self.aggr_layer = Dense(
        aggregation_size,
        activation='relu',
        kernel_initializer='he_uniform',
        kernel_regularizer=self.l2_reg,
        name='aggregation')
    self.weight_layer = Dense(input_dim, name='weights')
    if self._projection_dim is not None:
      logging.info('%s project dim is %d', self.name, self._projection_dim)
      self.project_layer = Dense(
          self._projection_dim,
          kernel_regularizer=self.l2_reg,
          use_bias=False,
          name='project')
    if self.config.input_layer_norm:
      # It is recommended to apply layer norm before calling MaskBlock;
      # otherwise every call has to apply layer norm to the input again.
      if tf.__version__ >= '2.0':
        self.input_layer_norm = tf.keras.layers.LayerNormalization(
            name='input_ln')
      else:
        self.input_layer_norm = LayerNormalization(name='input_ln')

    if self.config.HasField('output_size'):
      self.output_layer = Dense(
          self.config.output_size, use_bias=False, name='output')
    if tf.__version__ >= '2.0':
      self.output_layer_norm = tf.keras.layers.LayerNormalization(
          name='output_ln')
    else:
      self.output_layer_norm = LayerNormalization(name='output_ln')
    super(MaskBlock, self).build(input_shape)

  def call(self, inputs, training=None, **kwargs):
    if type(inputs) in (tuple, list):
      net, mask_input = inputs[:2]
    else:
      net, mask_input = inputs, inputs

    if self.config.input_layer_norm:
      net = self.input_layer_norm(net)

    # instance-guided mask: aggregate the mask input (optionally through a
    # low-rank projection), then produce element-wise weights for `net`
    if self._projection_dim is None:
      aggr = self.aggr_layer(mask_input)
    else:
      u = self.project_layer(mask_input)
      aggr = self.aggr_layer(u)
    weights = self.weight_layer(aggr)
    masked_net = net * weights

    if not self.config.HasField('output_size'):
      return masked_net

    hidden = self.output_layer(masked_net)
    ln_hidden = self.output_layer_norm(hidden)
    return self.final_relu(ln_hidden)
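
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library): wiring a single
# MaskBlock by hand. The proto import path below is an assumption and may
# differ between easy_rec versions; `Parameter.make_from_pb` and the field
# names (`reduction_factor`, `output_size`) follow the checks in
# MaskBlock.build above. The input tensors are hypothetical placeholders.
#
#   from easy_rec.python.protos import layer_pb2  # assumed module path
#
#   conf = layer_pb2.MaskBlock(reduction_factor=2.0, output_size=64)
#   params = Parameter.make_from_pb(conf)
#   params.l2_regularizer = None
#   block = MaskBlock(params, name='demo_block')
#   feat = tf.keras.Input(shape=(128,))     # hidden state to be masked
#   mask_in = tf.keras.Input(shape=(128,))  # instance-guided mask input
#   out = block((feat, mask_in))            # -> shape (None, 64)
# ---------------------------------------------------------------------------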
class MaskNet(Layer):
  """MaskNet: Introducing Feature-Wise Multiplication to CTR Ranking Models by Instance-Guided Mask.

  Refer: https://arxiv.org/pdf/2102.07619.pdf
  """

  def __init__(self, params, name='mask_net', reuse=None, **kwargs):
    super(MaskNet, self).__init__(name=name, **kwargs)
    self.reuse = reuse
    self.params = params
    self.config = params.get_pb_config()
    if self.config.HasField('mlp'):
      p = Parameter.make_from_pb(self.config.mlp)
      p.l2_regularizer = params.l2_regularizer
      self.mlp = MLP(p, name='mlp', reuse=reuse)
    else:
      self.mlp = None

    self.mask_layers = []
    for i, block_conf in enumerate(self.config.mask_blocks):
      params = Parameter.make_from_pb(block_conf)
      params.l2_regularizer = self.params.l2_regularizer
      mask_layer = MaskBlock(params, name='block_%d' % i, reuse=self.reuse)
      self.mask_layers.append(mask_layer)

    if self.config.input_layer_norm:
      if tf.__version__ >= '2.0':
        self.input_layer_norm = tf.keras.layers.LayerNormalization(
            name='input_ln')
      else:
        self.input_layer_norm = LayerNormalization(name='input_ln')

  def call(self, inputs, training=None, **kwargs):
    if self.config.input_layer_norm:
      inputs = self.input_layer_norm(inputs)

    if self.config.use_parallel:
      # parallel mode: every block masks the same input,
      # and the block outputs are concatenated
      mask_outputs = [
          mask_layer((inputs, inputs)) for mask_layer in self.mask_layers
      ]
      all_mask_outputs = tf.concat(mask_outputs, axis=1)
      if self.mlp is not None:
        return self.mlp(all_mask_outputs, training=training)
      return all_mask_outputs
    else:
      # serial mode: block i masks the output of block i-1, while the mask
      # itself is always computed from the raw inputs
      net = inputs
      for mask_layer in self.mask_layers:
        net = mask_layer((net, inputs))
      if self.mlp is not None:
        return self.mlp(net, training=training)
      return net
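
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): MaskNet in parallel mode. The proto
# import path and the `mlp.hidden_units` field are assumptions that may
# differ across easy_rec versions; `use_parallel`, `input_layer_norm` and
# `mask_blocks` follow MaskNet.call above.
#
#   from easy_rec.python.protos import layer_pb2  # assumed module path
#
#   conf = layer_pb2.MaskNet(use_parallel=True, input_layer_norm=True)
#   for _ in range(3):
#     conf.mask_blocks.add(reduction_factor=2.0, output_size=64)
#   conf.mlp.hidden_units.extend([256, 128])  # assumed MLP proto field
#   params = Parameter.make_from_pb(conf)
#   params.l2_regularizer = None
#   mask_net = MaskNet(params, name='demo_mask_net')
#   x = tf.keras.Input(shape=(512,))
#   y = mask_net(x)  # 3 parallel blocks of 64 dims each -> MLP([256, 128])
# ---------------------------------------------------------------------------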