# recommenders/models/deeprec/models/sequential/rnn_cell_implement.py
def call(self, inputs, state):
    """Run one step of the time-aware LSTM cell.

    The last two columns of ``inputs`` are treated as scalar time features
    (presumably time-since-now and time-since-last-event scores — confirm
    against the caller that builds the input tensor); they are split off and
    the remainder is fed to a standard LSTM whose input, forget, and output
    paths are modulated by two learned time gates.

    Args:
        inputs: 2-D tensor of shape ``[batch, input_size + 2]``; the final
            two columns are the time features described above.
        state: Previous cell state. Either an ``LSTMStateTuple`` ``(c, m)``
            when ``self._state_is_tuple`` is set, or a single concatenated
            tensor ``[batch, num_units + num_proj]`` otherwise.

    Returns:
        A pair ``(m, new_state)`` where ``m`` is the output (projected to
        ``num_proj`` units when a projection is configured) and ``new_state``
        matches the format of ``state``.

    Raises:
        ValueError: If the static input size cannot be inferred from
            ``inputs``.
    """
    # Peel the two appended time features off the input vector.
    time_now_score = tf.expand_dims(inputs[:, -1], -1)
    time_last_score = tf.expand_dims(inputs[:, -2], -1)
    inputs = inputs[:, :-2]
    num_proj = self._num_units if self._num_proj is None else self._num_proj
    sigmoid = math_ops.sigmoid

    # Recover (c_prev, m_prev) from either the tuple or the concatenated
    # state representation.
    if self._state_is_tuple:
        (c_prev, m_prev) = state
    else:
        c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
        m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size is None:
        raise ValueError("Could not infer input size from inputs.get_shape()[-1]")

    # Lazily create the time-gate parameters on the first call; the guard on
    # a single attribute stands in for the whole group created together.
    if self._time_kernel_w1 is None:
        scope = vs.get_variable_scope()
        with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
            with vs.variable_scope(unit_scope):
                # Element-wise weights/biases applied to the raw time scores.
                self._time_input_w1 = vs.get_variable(
                    "_time_input_w1", shape=[self._num_units], dtype=dtype
                )
                self._time_input_bias1 = vs.get_variable(
                    "_time_input_bias1", shape=[self._num_units], dtype=dtype
                )
                self._time_input_w2 = vs.get_variable(
                    "_time_input_w2", shape=[self._num_units], dtype=dtype
                )
                self._time_input_bias2 = vs.get_variable(
                    "_time_input_bias2", shape=[self._num_units], dtype=dtype
                )
                # Dense weights combining the LSTM input with each time
                # embedding to form the two time-gate pre-activations.
                self._time_kernel_w1 = vs.get_variable(
                    "_time_kernel_w1",
                    shape=[input_size, self._num_units],
                    dtype=dtype,
                )
                self._time_kernel_t1 = vs.get_variable(
                    "_time_kernel_t1",
                    shape=[self._num_units, self._num_units],
                    dtype=dtype,
                )
                self._time_bias1 = vs.get_variable(
                    "_time_bias1", shape=[self._num_units], dtype=dtype
                )
                self._time_kernel_w2 = vs.get_variable(
                    "_time_kernel_w2",
                    shape=[input_size, self._num_units],
                    dtype=dtype,
                )
                self._time_kernel_t2 = vs.get_variable(
                    "_time_kernel_t2",
                    shape=[self._num_units, self._num_units],
                    dtype=dtype,
                )
                self._time_bias2 = vs.get_variable(
                    "_time_bias2", shape=[self._num_units], dtype=dtype
                )
                # Weights injecting the time embeddings into the output gate.
                self._o_kernel_t1 = vs.get_variable(
                    "_o_kernel_t1",
                    shape=[self._num_units, self._num_units],
                    dtype=dtype,
                )
                self._o_kernel_t2 = vs.get_variable(
                    "_o_kernel_t2",
                    shape=[self._num_units, self._num_units],
                    dtype=dtype,
                )

    # Embed each scalar time score into num_units dims with a tanh layer.
    time_now_input = tf.nn.tanh(
        time_now_score * self._time_input_w1 + self._time_input_bias1
    )
    time_last_input = tf.nn.tanh(
        time_last_score * self._time_input_w2 + self._time_input_bias2
    )
    # Time-gate pre-activations; passed through sigmoid below to scale the
    # new-input and forget contributions to the cell state.
    time_now_state = (
        math_ops.matmul(inputs, self._time_kernel_w1)
        + math_ops.matmul(time_now_input, self._time_kernel_t1)
        + self._time_bias1
    )
    time_last_state = (
        math_ops.matmul(inputs, self._time_kernel_w2)
        + math_ops.matmul(time_last_input, self._time_kernel_t2)
        + self._time_bias2
    )

    # Lazily build the fused input/hidden -> 4*num_units linear map,
    # optionally sharded across partitions.
    if self._linear1 is None:
        scope = vs.get_variable_scope()
        with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
            if self._num_unit_shards is not None:
                unit_scope.set_partitioner(
                    partitioned_variables.fixed_size_partitioner(
                        self._num_unit_shards
                    )
                )
            self._linear1 = _Linear([inputs, m_prev], 4 * self._num_units, True)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    lstm_matrix = self._linear1([inputs, m_prev])
    i, j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=4, axis=1)
    # The output gate additionally sees both time embeddings.
    o = (
        o
        + math_ops.matmul(time_now_input, self._o_kernel_t1)
        + math_ops.matmul(time_last_input, self._o_kernel_t2)
    )

    # Diagonal connections
    # Peephole weights are created lazily as well (guarded on _w_f_diag).
    if self._use_peepholes and not self._w_f_diag:
        scope = vs.get_variable_scope()
        with vs.variable_scope(scope, initializer=self._initializer) as unit_scope:
            with vs.variable_scope(unit_scope):
                self._w_f_diag = vs.get_variable(
                    "w_f_diag", shape=[self._num_units], dtype=dtype
                )
                self._w_i_diag = vs.get_variable(
                    "w_i_diag", shape=[self._num_units], dtype=dtype
                )
                self._w_o_diag = vs.get_variable(
                    "w_o_diag", shape=[self._num_units], dtype=dtype
                )

    # New cell state: standard LSTM update with each term additionally
    # scaled by its sigmoid time gate (last-time gate on the forget path,
    # now-time gate on the input path).
    if self._use_peepholes:
        c = sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * sigmoid(
            time_last_state
        ) * c_prev + sigmoid(i + self._w_i_diag * c_prev) * sigmoid(
            time_now_state
        ) * self._activation(
            j
        )
    else:
        c = sigmoid(f + self._forget_bias) * sigmoid(
            time_last_state
        ) * c_prev + sigmoid(i) * sigmoid(time_now_state) * self._activation(j)

    if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

    # Output: gated, activated cell state (with optional output peephole).
    if self._use_peepholes:
        m = sigmoid(o + self._w_o_diag * c) * self._activation(c)
    else:
        m = sigmoid(o) * self._activation(c)

    # Optional learned projection of the output down/up to num_proj units,
    # built lazily in its own "projection" scope.
    if self._num_proj is not None:
        if self._linear2 is None:
            scope = vs.get_variable_scope()
            with vs.variable_scope(scope, initializer=self._initializer):
                with vs.variable_scope("projection") as proj_scope:
                    if self._num_proj_shards is not None:
                        proj_scope.set_partitioner(
                            partitioned_variables.fixed_size_partitioner(
                                self._num_proj_shards
                            )
                        )
                    self._linear2 = _Linear(m, self._num_proj, False)
        m = self._linear2(m)

        if self._proj_clip is not None:
            # pylint: disable=invalid-unary-operand-type
            m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
            # pylint: enable=invalid-unary-operand-type

    # Repackage the state in the same format it arrived in.
    new_state = (
        LSTMStateTuple(c, m)
        if self._state_is_tuple
        else array_ops.concat([c, m], 1)
    )
    return m, new_state