easy_rec/python/layers/senet.py (45 lines of code) (raw):

# -*- encoding:utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. import tensorflow as tf if tf.__version__ >= '2.0': tf = tf.compat.v1 class SENet: """Squeeze and Excite Network. Input shape - A list of 2D tensor with shape: ``(batch_size,embedding_size)``. The ``embedding_size`` of each field can have different value. Args: num_fields: int, number of fields. num_squeeze_group: int, number of groups for squeeze. reduction_ratio: int, reduction ratio for squeeze. l2_reg: float, l2 regularizer for embedding. name: str, name of the layer. """ def __init__(self, num_fields, num_squeeze_group, reduction_ratio, l2_reg, name='SENet'): self.num_fields = num_fields self.num_squeeze_group = num_squeeze_group self.reduction_ratio = reduction_ratio self._l2_reg = l2_reg self._name = name def __call__(self, inputs): g = self.num_squeeze_group f = self.num_fields r = self.reduction_ratio reduction_size = max(1, f * g * 2 // r) emb_size = 0 for input in inputs: emb_size += int(input.shape[-1]) group_embs = [ tf.reshape(emb, [-1, g, int(emb.shape[-1]) // g]) for emb in inputs ] squeezed = [] for emb in group_embs: squeezed.append(tf.reduce_max(emb, axis=-1)) # [B, g] squeezed.append(tf.reduce_mean(emb, axis=-1)) # [B, g] z = tf.concat(squeezed, axis=1) # [bs, field_size * num_groups * 2] reduced = tf.layers.dense( inputs=z, units=reduction_size, kernel_regularizer=self._l2_reg, activation='relu', name='%s/reduce' % self._name) excited_weights = tf.layers.dense( inputs=reduced, units=emb_size, kernel_initializer='glorot_normal', name='%s/excite' % self._name) # Re-weight inputs = tf.concat(inputs, axis=-1) output = inputs * excited_weights return output