easy_rec/python/layers/keras/embedding.py

# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Fused embedding layer."""
import tensorflow as tf
from tensorflow.python.keras.layers import Embedding
from tensorflow.python.keras.layers import Layer


def _combine(embeddings, weights, comb_fn):
  # embeddings shape: [B, N, D]
  if callable(comb_fn):
    return comb_fn(embeddings, axis=1)
  if weights is None:
    return tf.reduce_mean(embeddings, axis=1)
  if isinstance(weights, tf.SparseTensor):
    if weights.dtype == tf.string:
      weights = tf.sparse.to_dense(weights, default_value='0')
      weights = tf.string_to_number(weights)
    else:
      weights = tf.sparse.to_dense(weights, default_value=0.0)
  # Normalize the weights along the value dimension, then take a weighted sum.
  sum_weights = tf.reduce_sum(weights, axis=1, keepdims=True)
  weights = tf.expand_dims(weights / sum_weights, axis=-1)
  return tf.reduce_sum(embeddings * weights, axis=1)


class EmbeddingLayer(Layer):

  def __init__(self, params, name='embedding_layer', reuse=None, **kwargs):
    super(EmbeddingLayer, self).__init__(name=name, **kwargs)
    params.check_required(['vocab_size', 'embedding_dim'])
    vocab_size = int(params.vocab_size)
    combiner = params.get_or_default('combiner', 'weight')
    if combiner == 'mean':
      self.combine_fn = tf.reduce_mean
    elif combiner == 'sum':
      self.combine_fn = tf.reduce_sum
    elif combiner == 'max':
      self.combine_fn = tf.reduce_max
    elif combiner == 'min':
      self.combine_fn = tf.reduce_min
    elif combiner == 'weight':
      self.combine_fn = 'weight'
    else:
      raise ValueError('unsupported embedding combiner: ' + combiner)
    self.embed_dim = int(params.embedding_dim)
    self.embedding = Embedding(vocab_size, self.embed_dim)
    self.do_concat = params.get_or_default('concat', True)

  def call(self, inputs, training=None, **kwargs):
    inputs, weights = inputs
    # Merge the inputs of all features into one index tensor.
    flat_inputs = [tf.reshape(input_field, [-1]) for input_field in inputs]
    all_indices = tf.concat(flat_inputs, axis=0)
    # Do a single embedding lookup on the shared embedding table.
    all_embeddings = self.embedding(all_indices)
    # Compute the embedding of each feature: record how many lookup ids each
    # feature contributed, so the fused result can be split back apart.
    is_multi = []
    split_sizes = []
    for input_field in inputs:
      assert input_field.shape.ndims <= 2, 'dims of embedding layer input must be <= 2'
      input_shape = tf.shape(input_field)
      size = input_shape[0]
      if input_field.shape.ndims > 1:
        size *= input_shape[-1]
        is_multi.append(True)
      else:
        is_multi.append(False)
      split_sizes.append(size)
    embeddings = tf.split(all_embeddings, split_sizes, axis=0)
    for i in range(len(embeddings)):
      if is_multi[i]:
        batch_size = tf.shape(inputs[i])[0]
        # Combine the multiple values of a feature into one embedding;
        # fall back to zeros when the feature has no values at all.
        embeddings[i] = tf.cond(
            tf.equal(tf.size(embeddings[i]), 0),
            lambda: tf.zeros([batch_size, self.embed_dim]),
            lambda: _combine(
                tf.reshape(embeddings[i], [batch_size, -1, self.embed_dim]),
                weights[i], self.combine_fn))
    if self.do_concat:
      embeddings = tf.concat(embeddings, axis=-1)
    print('Embedding layer:', self.name, embeddings)
    return embeddings
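

# A minimal usage sketch, assuming a `_Params` stub that mimics EasyRec's
# config-backed parameter object (`check_required` / `get_or_default` are the
# only methods EmbeddingLayer relies on). The stub and the toy inputs below
# are hypothetical and for illustration only; real code passes parameters
# parsed from the pipeline config.
if __name__ == '__main__':

  class _Params(object):

    def __init__(self, **kwargs):
      self._kv = kwargs

    def __getattr__(self, key):
      return self._kv[key]

    def check_required(self, keys):
      for key in keys:
        assert key in self._kv, 'missing required param: %s' % key

    def get_or_default(self, key, default):
      return self._kv.get(key, default)

  layer = EmbeddingLayer(
      _Params(vocab_size=100, embedding_dim=8, combiner='mean'))
  # One multi-valued feature of shape [B, T] and one single-valued feature of
  # shape [B]; weights are None, so the multi-valued ids are mean-pooled.
  ids_multi = tf.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int64)
  ids_single = tf.constant([7, 8], dtype=tf.int64)
  outputs = layer(([ids_multi, ids_single], [None, None]))
  # With concat=True the per-feature embeddings are concatenated: [2, 16].
  print(outputs.shape)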