def repeat()

in python/singa/tensor.py [0:0]


    def repeat(self, repeats, axis=None):
        '''Repeat elements of a tensor (numpy.repeat-style).

        Args:
            repeats (int or sequence of int): the number of times to repeat.
                An int repeats every element (or every slice along ``axis``)
                the same number of times. A sequence gives one count per
                slice along ``axis``; its length must equal
                ``self.shape[axis]``, except that a length-1 sequence is
                treated exactly like the int it contains.
            axis (int or None): the axis along which to repeat; negative
                values count from the last axis. If None (the default), the
                result is flattened to 1-D, and ``repeats`` must be an int
                (or a length-1 sequence).

        Returns:
            a new Tensor holding the repeated data.

        Raises:
            ValueError: if ``repeats`` is not an int or a sequence of ints,
                contains a negative count, has a length that does not match
                ``self.shape[axis]``, is a multi-element sequence while
                ``axis`` is None, or if ``axis`` is out of range.
        '''
        import numbers

        # A one-element sequence reaches the backend as the same one-element
        # list the int path builds, so it must get the int path's shape.
        if isinstance(repeats, (tuple, list)) and len(repeats) == 1:
            repeats = repeats[0]

        t_ndim = self.ndim()
        if axis is not None:
            # Without this check an axis < -ndim used to skip every branch
            # and silently return an empty Tensor.
            if not -t_ndim <= axis < t_ndim:
                raise ValueError(
                    "'axis' {} is out of range for a tensor with {} "
                    "dimension(s)".format(axis, t_ndim))
            if axis < 0:
                axis += t_ndim

        if isinstance(repeats, numbers.Integral):
            if repeats < 0:
                raise ValueError(
                    "'repeats' should not be negative: {}".format(repeats))
            counts = [int(repeats)]
            if axis is None:
                new_shape = (product(self.shape) * counts[0],)
                # 9999 is the value this method hands the backend when no
                # axis is given; the shape above is flattened accordingly.
                backend_axis = 9999
            else:
                t_shape = list(self.shape)
                t_shape[axis] = self.shape[axis] * counts[0]
                new_shape = tuple(t_shape)
                backend_axis = axis
        elif isinstance(repeats, (tuple, list)):
            for rep in repeats:
                if not isinstance(rep, numbers.Integral) or rep < 0:
                    raise ValueError(
                        "'repeats' should be non-negative ints: {}".format(
                            repeats))
            if axis is None:
                raise ValueError(
                    "when axis is None, 'repeats' should be int: {}".format(
                        repeats))
            if len(repeats) != self.shape[axis]:
                raise ValueError(
                    "'repeats' has {} entries but axis {} has size {}".format(
                        len(repeats), axis, self.shape[axis]))
            counts = [int(rep) for rep in repeats]
            t_shape = list(self.shape)
            # Each slice i along `axis` appears counts[i] times.
            t_shape[axis] = sum(counts)
            new_shape = tuple(t_shape)
            backend_axis = axis
        else:
            raise ValueError('repeats should be int or sequence')

        t = Tensor()
        t.shape = new_shape
        t.data = self.data.Repeat(counts, backend_axis)
        return t