easy_rec/python/layers/keras/data_augment.py (100 lines of code) (raw):
# -*- encoding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import tensorflow as tf
from tensorflow.python.keras.layers import Layer
from easy_rec.python.utils.shape_utils import get_shape_list
# Under TF 2.x, rebind the module-level ``tf`` name to the v1 compat API so
# the graph-mode calls used below (e.g. ``tf.variable_scope`` and
# ``tf.get_variable`` in ``SeqAugment.call``) keep resolving.
# NOTE(review): this is a lexicographic string comparison of the version
# string, not a numeric one; it is correct for 1.x / 2.x version strings.
if tf.__version__ >= '2.0':
  tf = tf.compat.v1
def item_mask(aug_data, length, mask_emb, mask_rate):
  """Overwrite a random subset of the valid items with ``mask_emb``.

  ``floor(length * mask_rate)`` of the first ``length`` rows are chosen at
  random and replaced by the mask embedding; padding rows (index >= length)
  are never touched.

  Args:
    aug_data: 2-D tensor of one sequence's item embeddings, one row per
      position.
    length: scalar int32 tensor, number of valid (non-padding) rows.
    mask_emb: single-row embedding; it is tiled to one row per position, so
      it is expected to be shaped [1, emb_dim].
    mask_rate: fraction of the valid rows to mask.

  Returns:
    ``(masked_seq, length)``: a tensor shaped like ``aug_data`` and the
    unchanged valid length.
  """
  n_masked = tf.cast(
      tf.math.floor(tf.cast(length, dtype=tf.float32) * mask_rate),
      dtype=tf.int32)
  total_rows = tf.shape(aug_data)[0]
  # n_masked True flags followed by False flags over the valid prefix, then
  # randomly permuted so the masked positions are scattered.
  is_masked = tf.random.shuffle(tf.sequence_mask(n_masked, length))
  # Extend with all-False flags so padding rows keep their original values.
  is_masked = tf.concat(
      [is_masked, tf.sequence_mask(0, total_rows - length)], axis=0)
  replacement = tf.tile(mask_emb, [total_rows, 1])
  # Rank-1 condition selects whole rows (v1 `where` semantics).
  return tf.where(is_masked, replacement, aug_data), length
def item_crop(aug_data, length, crop_rate):
  """Keep a random contiguous window of the valid items (crop augmentation).

  ``num_left = floor(length * crop_rate)`` consecutive rows are taken from a
  uniformly random start in ``[0, length - num_left]`` (inclusive), moved to
  the front of the output, and every remaining row is filled with zeros.

  Args:
    aug_data: 2-D tensor of one sequence's item embeddings; rows at index
      ``>= length`` are padding.
    length: scalar int32 tensor, number of valid rows.
    crop_rate: fraction of the valid rows to keep.

  Returns:
    ``(cropped_item_seq, num_left)``: a tensor with as many rows as
    ``aug_data`` (kept rows first, zeros after) and the new valid length.
  """
  length1 = tf.cast(length, dtype=tf.float32)
  max_len, _ = get_shape_list(aug_data)
  max_length = tf.cast(max_len, dtype=tf.int32)
  num_left = tf.cast(tf.math.floor(length1 * crop_rate), dtype=tf.int32)
  # Valid window starts are 0 .. length - num_left *inclusive*, but the
  # integer `maxval` of tf.random.uniform is exclusive, hence the `+ 1`.
  # Without it the window ending at the last valid item is never drawn, and
  # length == 0 gives minval == maxval, which raises InvalidArgumentError.
  crop_begin = tf.random.uniform([],
                                 minval=0,
                                 maxval=length - num_left + 1,
                                 dtype=tf.int32)
  zeros = tf.zeros_like(aug_data)
  x = aug_data[crop_begin:crop_begin + num_left]
  y = zeros[:max_length - num_left]
  cropped = tf.concat([x, y], axis=0)
  # When the window ends exactly at the last row (crop_begin + num_left ==
  # max_length) the alternative below yields the same rows: the tail of
  # aug_data followed by crop_begin zero rows.
  cropped_item_seq = tf.where(
      crop_begin + num_left < max_length, cropped,
      tf.concat([aug_data[crop_begin:], zeros[:crop_begin]], axis=0))
  return cropped_item_seq, num_left
def item_reorder(aug_data, length, reorder_rate):
  """Shuffle a random contiguous window of the valid items (reorder augmentation).

  A window of ``num_reorder = floor(length * reorder_rate)`` rows starting at
  a uniformly random position in ``[0, length - num_reorder]`` (inclusive) is
  permuted randomly; all rows outside the window keep their position.

  Args:
    aug_data: tensor of one sequence's item embeddings, one row per position;
      rows at index ``>= length`` are padding.
    length: scalar int32 tensor, number of valid rows.
    reorder_rate: fraction of the valid rows to shuffle.

  Returns:
    ``(reordered_item_seq, length)``: a tensor shaped like ``aug_data`` and
    the unchanged valid length.
  """
  length1 = tf.cast(length, dtype=tf.float32)
  num_reorder = tf.cast(tf.math.floor(length1 * reorder_rate), dtype=tf.int32)
  # Valid window starts are 0 .. length - num_reorder *inclusive*; the integer
  # `maxval` is exclusive, hence the `+ 1`. Without it the window touching the
  # last valid item is never chosen, and length == 0 (or reorder_rate >= 1)
  # gives minval == maxval, which raises InvalidArgumentError.
  reorder_begin = tf.random.uniform([],
                                    minval=0,
                                    maxval=length - num_reorder + 1,
                                    dtype=tf.int32)
  shuffle_index = tf.range(reorder_begin, reorder_begin + num_reorder)
  shuffle_index = tf.random.shuffle(shuffle_index)
  x = tf.range(get_shape_list(aug_data)[0])
  left = tf.slice(x, [0], [reorder_begin])
  right = tf.slice(x, [reorder_begin + num_reorder], [-1])
  # A permutation of all row indices: identity outside the window, shuffled
  # inside it.
  reordered_item_index = tf.concat([left, shuffle_index, right], axis=0)
  # Row i of aug_data is written to position reordered_item_index[i].
  reordered_item_seq = tf.scatter_nd(
      tf.expand_dims(reordered_item_index, axis=1), aug_data,
      tf.shape(aug_data))
  return reordered_item_seq, length
def augment_fn(x, aug_param, mask):
  """Apply one randomly chosen augmentation to a single sequence.

  Eligible augmentations are crop (``crop_rate < 1.0``), mask
  (``mask_rate > 0``) and reorder (``reorder_rate > 0``), in that order.
  With none eligible the input is returned unchanged; with exactly one it is
  always applied; otherwise one is drawn uniformly at random.

  Args:
    x: ``(seq, length)`` pair for one example.
    aug_param: config object providing ``crop_rate``, ``mask_rate`` and
      ``reorder_rate``.
    mask: embedding passed to ``item_mask`` for masked positions.

  Returns:
    ``(aug_seq, aug_len)`` produced by the selected augmentation.
  """
  seq, length = x

  enabled = []
  if aug_param.crop_rate < 1.0:
    enabled.append(lambda: item_crop(seq, length, aug_param.crop_rate))
  if aug_param.mask_rate > 0:
    enabled.append(lambda: item_mask(seq, length, mask, aug_param.mask_rate))
  if aug_param.reorder_rate > 0:
    enabled.append(lambda: item_reorder(seq, length, aug_param.reorder_rate))

  if not enabled:
    return seq, length
  if len(enabled) == 1:
    return enabled[0]()

  pick = tf.random.uniform([],
                           minval=0,
                           maxval=len(enabled),
                           dtype=tf.int32)

  def _select(idx):
    # Nested conds: pick == idx -> enabled[idx], else try the next one; the
    # last candidate is the final fallback.
    if idx == len(enabled) - 1:
      return enabled[idx]()
    return tf.cond(
        tf.equal(pick, idx), enabled[idx], lambda: _select(idx + 1))

  return _select(0)
def sequence_augment(seq_input, seq_len, mask, aug_param):
  """Augment every example of a batch independently via ``augment_fn``.

  Args:
    seq_input: float32 batch of sequence embeddings; each slice along axis 0
      is augmented separately.
    seq_len: per-example valid lengths (cast to int32).
    mask: mask embedding forwarded to ``augment_fn``.
    aug_param: augmentation config forwarded to ``augment_fn``.

  Returns:
    ``(aug_seq, aug_len)``: the augmented batch reshaped to
    ``tf.shape(seq_input)`` and the int32 lengths after augmentation.
  """
  def _augment_one(example):
    return augment_fn(example, aug_param, mask)

  int_len = tf.cast(seq_len, dtype=tf.int32)
  out_seq, out_len = tf.map_fn(
      _augment_one,
      elems=(seq_input, int_len),
      dtype=(tf.float32, tf.int32))
  # Presumably to restore the shape information map_fn does not carry over.
  out_seq = tf.reshape(out_seq, tf.shape(seq_input))
  return out_seq, out_len
class SeqAugment(Layer):
  """Do data augmentation for input sequence embedding.

  Each example in the batch receives one of the crop / mask / reorder
  augmentations selected by ``augment_fn``. The embedding used for masking is
  a trainable variable named ``mask`` created under this layer's variable
  scope.
  """

  def __init__(self, params, name='seq_aug', reuse=None, **kwargs):
    super(SeqAugment, self).__init__(name=name, **kwargs)
    self.reuse = reuse
    # Augmentation config read by augment_fn (crop_rate / mask_rate /
    # reorder_rate).
    self.seq_aug_params = params.get_pb_config()

  def call(self, inputs, training=None, **kwargs):
    """Augment a batch of sequences.

    Args:
      inputs: list/tuple whose first two entries are the sequence embeddings
        and their lengths; further entries are ignored.
      training: unused here.

    Returns:
      ``(aug_seq, aug_len)`` from ``sequence_augment``.
    """
    assert isinstance(inputs, (list, tuple))
    seq_input, seq_len = inputs[:2]
    emb_dim = int(seq_input.shape[-1])
    with tf.variable_scope(self.name, reuse=self.reuse):
      mask_emb = tf.get_variable(
          'mask', [1, emb_dim], dtype=tf.float32, trainable=True)
    return sequence_augment(seq_input, seq_len, mask_emb, self.seq_aug_params)