easy_rec/python/layers/capsule_layer.py
# -*- encoding:utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import numpy as np
import tensorflow as tf

if tf.__version__ >= '2.0':
  tf = tf.compat.v1


class CapsuleLayer(object):
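  """Capsule layer with dynamic routing.

  Groups a variable-length behaviour sequence into at most `max_k`
  high-level capsule vectors via iterative routing.
  """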

  def __init__(self, capsule_config, is_training):
    # max_seq_len: maximum behaviour sequence length (history length)
    self._max_seq_len = capsule_config.max_seq_len
    # max_k: maximum number of high capsules
    self._max_k = capsule_config.max_k
    # high_dim: high capsule vector dimension
    self._high_dim = capsule_config.high_dim
    # num_iters: number of routing iterations
    self._num_iters = capsule_config.num_iters
    # routing_logits_scale: scale applied to the recomputed routing logits
    self._routing_logits_scale = capsule_config.routing_logits_scale
    # routing_logits_stddev: stddev (train) / upper bound (eval) of the
    # routing logits initialization
    self._routing_logits_stddev = capsule_config.routing_logits_stddev
    # squash_pow: exponent applied to the squash scale factor
    self._squash_pow = capsule_config.squash_pow
    # scale_ratio: extra multiplier on the squash scale factor
    self._scale_ratio = capsule_config.scale_ratio
    # const_caps_num: if true, always use max_k capsules instead of
    # log(seq_len)
    self._const_caps_num = capsule_config.const_caps_num
    self._is_training = is_training

  def squash(self, inputs):
    """Squash inputs over the last dimension."""
    input_norm = tf.reduce_sum(tf.square(inputs), keepdims=True, axis=-1)
    input_norm_eps = tf.maximum(input_norm, 1e-8)
    scale_factor = tf.pow(input_norm_eps / (1 + input_norm_eps),
                          self._squash_pow) * \
        self._scale_ratio / tf.sqrt(input_norm_eps)
    tf.summary.histogram('capsule/squash_scale_factor', scale_factor)
    return scale_factor * inputs

  def _build_capsule_simi(self, high_capsules, capsule_num):
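    """Average pairwise cosine similarity between valid capsules.

    L2-normalizes the capsules, computes the per-sample mean pairwise dot
    product via ||sum_i v_i||^2 - sum_i ||v_i||^2, maps it from [-1, 1] to
    [0, 1], and averages over samples with more than one capsule.
    """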
    high_capsule_mask = tf.sequence_mask(capsule_num,
                                         tf.shape(high_capsules)[1])
    high_capsules = high_capsules * tf.to_float(high_capsule_mask[:, :, None])
    high_capsules = tf.nn.l2_normalize(high_capsules, axis=-1)
    sum_sqr = tf.square(tf.reduce_sum(high_capsules, axis=1))
    sqr_sum = tf.reduce_sum(tf.square(high_capsules), axis=1)
    simi = sum_sqr - sqr_sum
    div = tf.maximum(tf.to_float(capsule_num * (capsule_num - 1)), 1.0)
    simi = tf.reduce_sum(simi, axis=1) / div
    is_multi = tf.to_float(capsule_num > 1)
    avg_simi = tf.reduce_sum((simi + 1) * is_multi) / \
        (2.0 * tf.reduce_sum(is_multi))
    return avg_simi

  def __call__(self, seq_feas, seq_lens):
    """Capsule layer implementation.

    Args:
      seq_feas: tensor of shape batch_size x seq_len x low_fea_dim (bsd),
        padded or clipped to self._max_seq_len internally.
      seq_lens: tensor of shape batch_size, valid sequence lengths.

    Returns:
      high_capsules: tensor of shape batch_size x max_k x high_dim.
      num_high_capsules: tensor of shape batch_size, the number of valid
        high capsules per sample.
    """
    # pad or clip the sequence dimension to max_seq_len
    seq_feas = tf.cond(
        tf.greater(tf.shape(seq_feas)[1], self._max_seq_len),
        lambda: seq_feas[:, :self._max_seq_len, :],
        lambda: tf.cond(
            tf.less(tf.shape(seq_feas)[1], self._max_seq_len),
            lambda: tf.pad(seq_feas, [[0, 0], [
                0, self._max_seq_len - tf.shape(seq_feas)[1]
            ], [0, 0]]),
            lambda: seq_feas))
    seq_lens = tf.minimum(seq_lens, self._max_seq_len)
    batch_size = tf.shape(seq_lens)[0]
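    # Routing logits are initialized randomly rather than to zeros; at
    # inference time a fixed numpy seed makes the initialization (and hence
    # the routing result) deterministic across runs. The stop_gradient below
    # reflects that routing logits are updated by the routing iterations
    # themselves, not by backprop.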
    # max_seq_len x max_num_high_capsule (sh)
    if self._is_training:
      routing_logits = tf.truncated_normal(
          [batch_size, self._max_seq_len, self._max_k],
          stddev=self._routing_logits_stddev)
    else:
      np.random.seed(28)
      routing_logits = tf.constant(
          np.random.uniform(
              high=self._routing_logits_stddev,
              size=[self._max_seq_len, self._max_k]),
          dtype=tf.float32)
      routing_logits = tf.tile(routing_logits[None, :, :], [batch_size, 1, 1])
    routing_logits = tf.stop_gradient(routing_logits)
    # batch_size x max_seq_len x max_k (bsh)
    low_fea_dim = seq_feas.get_shape()[-1]
    # bilinear mapping from low capsule space to high capsule space:
    # low_fea_dim x high_dim (de)
    bilinear_matrix = tf.get_variable(
        dtype=tf.float32, shape=[low_fea_dim, self._high_dim], name='capsule/S')
    # map sequence features to the high capsule space
    seq_feas_high = tf.tensordot(seq_feas, bilinear_matrix, axes=1)
    seq_feas_high_stop = tf.stop_gradient(seq_feas_high)
    seq_feas_high_norm = tf.nn.l2_normalize(seq_feas_high_stop, -1)
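    # Gradient-stopped / normalized copies are used inside the routing
    # iterations; only the final iteration (below) lets gradients flow into
    # the bilinear matrix.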
    if self._const_caps_num:
      logging.info('will use constant number of capsules: %d' % self._max_k)
      num_high_capsules = tf.zeros_like(seq_lens, dtype=tf.int32) + self._max_k
    else:
      logging.info(
          'will use log(seq_len) number of capsules, max_capsules: %d' %
          self._max_k)
      num_high_capsules = tf.maximum(
          1, tf.minimum(self._max_k,
                        tf.to_int32(tf.log(tf.to_float(seq_lens)))))
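      # e.g. seq_len = 8  -> floor(ln 8)  = 2 capsules,
      #      seq_len = 55 -> floor(ln 55) = 4 capsules (clipped to [1, max_k])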
    # batch_size x max_seq_len (bs)
    mask = tf.sequence_mask(seq_lens, self._max_seq_len)
    mask = tf.cast(mask, tf.float32)
    # batch_size x max_k (bh)
    mask_cap = tf.sequence_mask(num_high_capsules, self._max_k)
    mask_cap = tf.cast(mask_cap, tf.float32)
    # batch_size x max_seq_len x 1 (bs1)
    # max_seq_thresh = (mask[:, :, None] * 2 - 1) * 1e32
    # batch_size x 1 x max_k (b1h)
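    # valid capsule slots map to +1e32 and padded slots to -1e32, so the
    # tf.minimum + softmax below drives the routing probability of padded
    # capsules to (numerically) zero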
    max_cap_thresh = (tf.cast(mask_cap[:, None, :], tf.float32) * 2 - 1) * 1e32
    for iter_id in range(self._num_iters):
      # batch_size x max_seq_len x max_k (bsh)
      routing_logits = tf.minimum(routing_logits, max_cap_thresh)
      routing_logits = tf.nn.softmax(routing_logits, axis=2)
      routing_logits = routing_logits * mask[:, :, None]
      logits_simi = self._build_capsule_simi(routing_logits, seq_lens)
      tf.summary.scalar('capsule/rlogits_simi_%d' % iter_id, logits_simi)
      seq_fea_simi = self._build_capsule_simi(seq_feas_high_stop, seq_lens)
      tf.summary.scalar('capsule/seq_fea_simi_%d' % iter_id, seq_fea_simi)
      # batch_size x max_k x high_dim (bse,bsh->bhe)
      high_capsules = tf.einsum(
          'bse,bsh->bhe', seq_feas_high_stop
          if iter_id + 1 < self._num_iters else seq_feas_high, routing_logits)
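      # the last iteration uses the non-stopped features, so the bilinear
      # matrix receives gradients through the final capsules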
      if iter_id + 1 == self._num_iters:
        capsule_simi = self._build_capsule_simi(high_capsules,
                                                num_high_capsules)
        tf.summary.scalar('capsule/simi_%d' % iter_id, capsule_simi)
        tf.summary.scalar('capsule/before_squash',
                          tf.reduce_mean(tf.norm(high_capsules, axis=-1)))
        high_capsules = self.squash(high_capsules)
        tf.summary.scalar('capsule/after_squash',
                          tf.reduce_mean(tf.norm(high_capsules, axis=-1)))
        capsule_simi_final = self._build_capsule_simi(high_capsules,
                                                      num_high_capsules)
        tf.summary.scalar('capsule/simi_final', capsule_simi_final)
        break
      # batch_size x max_k x high_dim (bhe)
      high_capsules = tf.nn.l2_normalize(high_capsules, -1)
      capsule_simi = self._build_capsule_simi(high_capsules, num_high_capsules)
      tf.summary.scalar('capsule/simi_%d' % iter_id, capsule_simi)
      # update routing logits: batch_size x max_seq_len x max_k (bse,bhe->bsh)
      if self._routing_logits_scale > 0:
        if iter_id == 0:
          logging.info('routing_logits_scale = %.2f' %
                       self._routing_logits_scale)
        routing_logits = tf.einsum('bse,bhe->bsh', seq_feas_high_norm,
                                   high_capsules) * self._routing_logits_scale
      else:
        routing_logits = tf.einsum('bse,bhe->bsh', seq_feas_high_stop,
                                   high_capsules)
    # zero out padded capsules
    high_capsule_mask = tf.sequence_mask(num_high_capsules, self._max_k)
    high_capsules = high_capsules * tf.to_float(high_capsule_mask[:, :, None])
    return high_capsules, num_high_capsules
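

# Minimal usage sketch (illustrative; `capsule_config` stands for any object
# exposing the fields read in __init__, e.g. the parsed protobuf config):
#
#   capsule_layer = CapsuleLayer(capsule_config, is_training=True)
#   # seq_feas: [batch_size, seq_len, low_fea_dim], seq_lens: [batch_size]
#   high_capsules, num_high_capsules = capsule_layer(seq_feas, seq_lens)
#   # high_capsules: [batch_size, max_k, high_dim]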