easy_rec/python/layers/keras/interaction.py (230 lines of code) (raw):

# -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import tensorflow as tf from easy_rec.python.utils.activation import get_activation class FM(tf.keras.layers.Layer): """Factorization Machine models pairwise (order-2) feature interactions without linear term and bias. References - [Factorization Machines](https://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf) Input shape. - List of 2D tensor with shape: ``(batch_size,embedding_size)``. - Or a 3D tensor with shape: ``(batch_size,field_size,embedding_size)`` Output shape - 2D tensor with shape: ``(batch_size, 1)``. """ def __init__(self, params, name='fm', reuse=None, **kwargs): super(FM, self).__init__(name=name, **kwargs) self.use_variant = params.get_or_default('use_variant', False) def call(self, inputs, **kwargs): if type(inputs) == list: emb_dims = set(map(lambda x: int(x.shape[-1]), inputs)) if len(emb_dims) != 1: dims = ','.join([str(d) for d in emb_dims]) raise ValueError('all embedding dim must be equal in FM layer:' + dims) with tf.name_scope(self.name): fea = tf.stack(inputs, axis=1) else: assert inputs.shape.ndims == 3, 'input of FM layer must be a 3D tensor or a list of 2D tensors' fea = inputs with tf.name_scope(self.name): square_of_sum = tf.square(tf.reduce_sum(fea, axis=1)) sum_of_square = tf.reduce_sum(tf.square(fea), axis=1) cross_term = tf.subtract(square_of_sum, sum_of_square) if self.use_variant: cross_term = 0.5 * cross_term else: cross_term = 0.5 * tf.reduce_sum(cross_term, axis=-1, keepdims=True) return cross_term class DotInteraction(tf.keras.layers.Layer): """Dot interaction layer of DLRM model.. See theory in the DLRM paper: https://arxiv.org/pdf/1906.00091.pdf, section 2.1.3. Sparse activations and dense activations are combined. Dot interaction is applied to a batch of input Tensors [e1,...,e_k] of the same dimension and the output is a batch of Tensors with all distinct pairwise dot products of the form dot(e_i, e_j) for i <= j if self self_interaction is True, otherwise dot(e_i, e_j) i < j. Attributes: self_interaction: Boolean indicating if features should self-interact. If it is True, then the diagonal entries of the interaction metric are also taken. skip_gather: An optimization flag. If it's set then the upper triangle part of the dot interaction matrix dot(e_i, e_j) is set to 0. The resulting activations will be of dimension [num_features * num_features] from which half will be zeros. Otherwise activations will be only lower triangle part of the interaction matrix. The later saves space but is much slower. name: String name of the layer. """ def __init__(self, params, name=None, reuse=None, **kwargs): super(DotInteraction, self).__init__(name=name, **kwargs) self._self_interaction = params.get_or_default('self_interaction', False) self._skip_gather = params.get_or_default('skip_gather', False) def call(self, inputs, **kwargs): """Performs the interaction operation on the tensors in the list. The tensors represent as transformed dense features and embedded categorical features. Pre-condition: The tensors should all have the same shape. Args: inputs: List of features with shapes [batch_size, feature_dim]. Returns: activations: Tensor representing interacted features. It has a dimension `num_features * num_features` if skip_gather is True, otherside `num_features * (num_features + 1) / 2` if self_interaction is True and `num_features * (num_features - 1) / 2` if self_interaction is False. """ if isinstance(inputs, (list, tuple)): # concat_features shape: batch_size, num_features, feature_dim try: concat_features = tf.stack(inputs, axis=1) except (ValueError, tf.errors.InvalidArgumentError) as e: raise ValueError('Input tensors` dimensions must be equal, original' 'error message: {}'.format(e)) else: assert inputs.shape.ndims == 3, 'input of dot func must be a 3D tensor or a list of 2D tensors' concat_features = inputs batch_size = tf.shape(concat_features)[0] # Interact features, select lower-triangular portion, and re-shape. xactions = tf.matmul(concat_features, concat_features, transpose_b=True) num_features = xactions.shape[-1] ones = tf.ones_like(xactions) if self._self_interaction: # Selecting lower-triangular portion including the diagonal. lower_tri_mask = tf.linalg.band_part(ones, -1, 0) upper_tri_mask = ones - lower_tri_mask out_dim = num_features * (num_features + 1) // 2 else: # Selecting lower-triangular portion not included the diagonal. upper_tri_mask = tf.linalg.band_part(ones, 0, -1) lower_tri_mask = ones - upper_tri_mask out_dim = num_features * (num_features - 1) // 2 if self._skip_gather: # Setting upper triangle part of the interaction matrix to zeros. activations = tf.where( condition=tf.cast(upper_tri_mask, tf.bool), x=tf.zeros_like(xactions), y=xactions) out_dim = num_features * num_features else: activations = tf.boolean_mask(xactions, lower_tri_mask) activations = tf.reshape(activations, (batch_size, out_dim)) return activations class Cross(tf.keras.layers.Layer): """Cross Layer in Deep & Cross Network to learn explicit feature interactions. A layer that creates explicit and bounded-degree feature interactions efficiently. The `call` method accepts `inputs` as a tuple of size 2 tensors. The first input `x0` is the base layer that contains the original features (usually the embedding layer); the second input `xi` is the output of the previous `Cross` layer in the stack, i.e., the i-th `Cross` layer. For the first `Cross` layer in the stack, x0 = xi. The output is x_{i+1} = x0 .* (W * xi + bias + diag_scale * xi) + xi, where .* designates elementwise multiplication, W could be a full-rank matrix, or a low-rank matrix U*V to reduce the computational cost, and diag_scale increases the diagonal of W to improve training stability ( especially for the low-rank case). References: 1. [R. Wang et al.](https://arxiv.org/pdf/2008.13535.pdf) See Eq. (1) for full-rank and Eq. (2) for low-rank version. 2. [R. Wang et al.](https://arxiv.org/pdf/1708.05123.pdf) Example: ```python # after embedding layer in a functional model: input = tf.keras.Input(shape=(None,), name='index', dtype=tf.int64) x0 = tf.keras.layers.Embedding(input_dim=32, output_dim=6) x1 = Cross()(x0, x0) x2 = Cross()(x0, x1) logits = tf.keras.layers.Dense(units=10)(x2) model = tf.keras.Model(input, logits) ``` Args: projection_dim: project dimension to reduce the computational cost. Default is `None` such that a full (`input_dim` by `input_dim`) matrix W is used. If enabled, a low-rank matrix W = U*V will be used, where U is of size `input_dim` by `projection_dim` and V is of size `projection_dim` by `input_dim`. `projection_dim` need to be smaller than `input_dim`/2 to improve the model efficiency. In practice, we've observed that `projection_dim` = d/4 consistently preserved the accuracy of a full-rank version. diag_scale: a non-negative float used to increase the diagonal of the kernel W by `diag_scale`, that is, W + diag_scale * I, where I is an identity matrix. use_bias: whether to add a bias term for this layer. If set to False, no bias term will be used. preactivation: Activation applied to output matrix of the layer, before multiplication with the input. Can be used to control the scale of the layer's outputs and improve stability. kernel_initializer: Initializer to use on the kernel matrix. bias_initializer: Initializer to use on the bias vector. kernel_regularizer: Regularizer to use on the kernel matrix. bias_regularizer: Regularizer to use on bias vector. Input shape: A tuple of 2 (batch_size, `input_dim`) dimensional inputs. Output shape: A single (batch_size, `input_dim`) dimensional output. """ def __init__(self, params, name='cross', reuse=None, **kwargs): super(Cross, self).__init__(name=name, **kwargs) self._projection_dim = params.get_or_default('projection_dim', None) self._diag_scale = params.get_or_default('diag_scale', 0.0) self._use_bias = params.get_or_default('use_bias', True) preactivation = params.get_or_default('preactivation', None) preact = get_activation(preactivation) self._preactivation = tf.keras.activations.get(preact) kernel_initializer = params.get_or_default('kernel_initializer', 'truncated_normal') self._kernel_initializer = tf.keras.initializers.get(kernel_initializer) bias_initializer = params.get_or_default('bias_initializer', 'zeros') self._bias_initializer = tf.keras.initializers.get(bias_initializer) kernel_regularizer = params.get_or_default('kernel_regularizer', None) self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) bias_regularizer = params.get_or_default('bias_regularizer', None) self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer) self._input_dim = None self._supports_masking = True if self._diag_scale < 0: # pytype: disable=unsupported-operands raise ValueError( '`diag_scale` should be non-negative. Got `diag_scale` = {}'.format( self._diag_scale)) def build(self, input_shape): last_dim = input_shape[0][-1] if self._projection_dim is None: self._dense = tf.keras.layers.Dense( last_dim, kernel_initializer=_clone_initializer(self._kernel_initializer), bias_initializer=self._bias_initializer, kernel_regularizer=self._kernel_regularizer, bias_regularizer=self._bias_regularizer, use_bias=self._use_bias, dtype=self.dtype, activation=self._preactivation, ) else: self._dense_u = tf.keras.layers.Dense( self._projection_dim, kernel_initializer=_clone_initializer(self._kernel_initializer), kernel_regularizer=self._kernel_regularizer, use_bias=False, dtype=self.dtype, ) self._dense_v = tf.keras.layers.Dense( last_dim, kernel_initializer=_clone_initializer(self._kernel_initializer), bias_initializer=self._bias_initializer, kernel_regularizer=self._kernel_regularizer, bias_regularizer=self._bias_regularizer, use_bias=self._use_bias, dtype=self.dtype, activation=self._preactivation, ) super(Cross, self).build(input_shape) # Be sure to call this somewhere! def call(self, inputs, **kwargs): """Computes the feature cross. Args: inputs: The input tensor(x0, x) - x0: The input tensor - x: Optional second input tensor. If provided, the layer will compute crosses between x0 and x; if not provided, the layer will compute crosses between x0 and itself. Returns: Tensor of crosses. """ if isinstance(inputs, (list, tuple)): x0, x = inputs else: x0, x = inputs, inputs if not self.built: self.build(x0.shape) if x0.shape[-1] != x.shape[-1]: raise ValueError( '`x0` and `x` dimension mismatch! Got `x0` dimension {}, and x ' 'dimension {}. This case is not supported yet.'.format( x0.shape[-1], x.shape[-1])) if self._projection_dim is None: prod_output = self._dense(x) else: prod_output = self._dense_v(self._dense_u(x)) # prod_output = tf.cast(prod_output, self.compute_dtype) if self._diag_scale: prod_output = prod_output + self._diag_scale * x return x0 * prod_output + x def get_config(self): config = { 'projection_dim': self._projection_dim, 'diag_scale': self._diag_scale, 'use_bias': self._use_bias, 'preactivation': tf.keras.activations.serialize(self._preactivation), 'kernel_initializer': tf.keras.initializers.serialize(self._kernel_initializer), 'bias_initializer': tf.keras.initializers.serialize(self._bias_initializer), 'kernel_regularizer': tf.keras.regularizers.serialize(self._kernel_regularizer), 'bias_regularizer': tf.keras.regularizers.serialize(self._bias_regularizer), } base_config = super(Cross, self).get_config() return dict(list(base_config.items()) + list(config.items())) class CIN(tf.keras.layers.Layer): """Compressed Interaction Network(CIN) module in xDeepFM model. CIN layer is aimed at achieving high-order feature interactions at vector-wise level rather than bit-wise level. Reference: [xDeepFM](https://arxiv.org/pdf/1803.05170) xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems """ def __init__(self, params, name='cin', reuse=None, **kwargs): super(CIN, self).__init__(name=name, **kwargs) self._name = name self._hidden_feature_sizes = list( params.get_or_default('hidden_feature_sizes', [])) assert isinstance(self._hidden_feature_sizes, list) and len( self._hidden_feature_sizes ) > 0, 'parameter hidden_feature_sizes must be a list of int with length greater than 0' kernel_regularizer = params.get_or_default('kernel_regularizer', None) self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer) bias_regularizer = params.get_or_default('bias_regularizer', None) self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer) def build(self, input_shape): if len(input_shape) != 3: raise ValueError( 'Unexpected inputs dimensions %d, expect to be 3 dimensions' % (len(input_shape))) hidden_feature_sizes = [input_shape[1] ] + [h for h in self._hidden_feature_sizes] tfv1 = tf.compat.v1 if tf.__version__ >= '2.0' else tf with tfv1.variable_scope(self._name): self.kernel_list = [ tfv1.get_variable( name='cin_kernel_%d' % i, shape=[ hidden_feature_sizes[i + 1], hidden_feature_sizes[i], hidden_feature_sizes[0] ], initializer=tf.initializers.he_normal(), regularizer=self._kernel_regularizer, trainable=True) for i in range(len(self._hidden_feature_sizes)) ] self.bias_list = [ tfv1.get_variable( name='cin_bias_%d' % i, shape=[hidden_feature_sizes[i + 1]], initializer=tf.keras.initializers.Zeros, regularizer=self._bias_regularizer, trainable=True) for i in range(len(self._hidden_feature_sizes)) ] super(CIN, self).build(input_shape) def call(self, input, **kwargs): """Computes the compressed feature maps. Args: input: The 3D input tensor with shape (b, h0, d), where b is batch_size, h0 is the number of features, d is the feature embedding dimension. Returns: 2D tensor of compressed feature map with shape (b, featuremap_num), where b is the batch_size, featuremap_num is sum of the hidden layer sizes """ x_0 = input x_i = input x_0_expanded = tf.expand_dims(x_0, 1) pooled_feature_map_list = [] for i in range(len(self._hidden_feature_sizes)): hk = self._hidden_feature_sizes[i] x_i_expanded = tf.expand_dims(x_i, 2) intermediate_tensor = tf.multiply(x_0_expanded, x_i_expanded) intermediate_tensor_expanded = tf.expand_dims(intermediate_tensor, 1) intermediate_tensor_expanded = tf.tile(intermediate_tensor_expanded, [1, hk, 1, 1, 1]) feature_map_elementwise = tf.multiply( intermediate_tensor_expanded, tf.expand_dims(tf.expand_dims(self.kernel_list[i], -1), 0)) feature_map = tf.reduce_sum( tf.reduce_sum(feature_map_elementwise, axis=3), axis=2) feature_map = tf.add( feature_map, tf.expand_dims(tf.expand_dims(self.bias_list[i], axis=-1), axis=0)) feature_map = tf.nn.relu(feature_map) x_i = feature_map pooled_feature_map_list.append(tf.reduce_sum(feature_map, axis=-1)) return tf.concat( pooled_feature_map_list, axis=-1) # shape = (b, h1 + ... + hk) def get_config(self): pass def _clone_initializer(initializer): return initializer.__class__.from_config(initializer.get_config())