in python/singa/tensor.py [0:0]
def repeat(self, repeats, axis=None):
    '''Repeat the elements of this tensor.

    Args:
        repeats (int or sequence of int): how many times each element is
            repeated. A sequence must provide one non-negative count per
            element along ``axis``; a one-element sequence is treated like
            its single int and applied to every element.
        axis (int, optional): the axis along which to repeat; negative
            values count from the last axis. If None (the default), the
            tensor is flattened and the result is 1-D; in that case
            ``repeats`` must be a single int (or a one-element sequence).

    Returns:
        a new Tensor holding the repeated data

    Raises:
        ValueError: if ``repeats`` is not an int or a tuple/list of ints,
            contains a negative count, is a multi-element sequence while
            ``axis`` is None, or has a length different from
            ``self.shape[axis]``; or if ``axis`` is out of range.
    '''
    import numbers

    # A one-element sequence broadcasts exactly like a scalar: the scalar
    # path hands Repeat the same one-element list and sizes the result as
    # shape[axis] * count, so it must not be sized with sum() instead.
    if isinstance(repeats, (tuple, list)) and len(repeats) == 1:
        repeats = repeats[0]

    if isinstance(repeats, numbers.Integral):
        if repeats < 0:
            raise ValueError(
                "'repeats' should not be negative: {}".format(repeats))
        counts = [int(repeats)]
    elif isinstance(repeats, (tuple, list)):
        for rep in repeats:
            if not isinstance(rep, numbers.Integral):
                raise ValueError(
                    "'repeats' should only contain ints: {}".format(repeats))
            if rep < 0:
                raise ValueError(
                    "'repeats' should not be negative: {}".format(repeats))
        counts = [int(rep) for rep in repeats]
    else:
        # Also rejects complex numbers, which used to be let through here
        # only to fail with an unrelated TypeError on the `< 0` comparison.
        raise ValueError('repeats should be int or sequence')

    if axis is None:
        if len(counts) != 1:
            raise ValueError(
                "when axis is None, 'repeats' should be int: {}".format(
                    repeats))
        t_shape = (product(self.shape) * counts[0],)
        # NOTE(review): 9999 is the axis value this method has always passed
        # to Repeat to request a flattened result -- presumably the
        # backend's "no axis" sentinel; confirm against the C++ side.
        backend_axis = 9999
    else:
        t_ndim = self.ndim()
        # Validate before normalising: an axis below -ndim used to skip
        # every branch and return an uninitialised Tensor.
        if not -t_ndim <= axis < t_ndim:
            raise ValueError(
                "axis {} is out of bounds for tensor of dimension {}".format(
                    axis, t_ndim))
        if axis < 0:
            axis += t_ndim
        shape = list(self.shape)
        if len(counts) == 1:
            # Broadcast count: every element along the axis repeats equally.
            shape[axis] = self.shape[axis] * counts[0]
        elif len(counts) == self.shape[axis]:
            # Per-element counts: the axis grows to the total of all counts.
            shape[axis] = sum(counts)
        else:
            raise ValueError(
                "'repeats' has {} entries but axis {} has size {}".format(
                    len(counts), axis, self.shape[axis]))
        t_shape = tuple(shape)
        backend_axis = axis

    t = Tensor()
    t.shape = t_shape
    t.data = self.data.Repeat(counts, backend_axis)
    return t