def tensor_flatten_pad()

in Synthesis_incorporation/models/prediction_model.py

Flattens a tensor, optionally value-encodes it, zero-pads it to a fixed
embedding width, and appends optional type and shape encodings, using -1 as
a separator token between segments.
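
The method relies on the file's module-level imports; a minimal set for this
excerpt, assuming the conventional alias for torch.nn.functional:

    import torch
    import torch.nn.functional as F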


    def tensor_flatten_pad(
        self, tensor, embed_size=None, shape_embed_size=None, isNoise=False
    ):
        # isNoise is accepted but not used within this method.
        if embed_size is None:
            embed_size = self.embedding_size
        if shape_embed_size is None:
            shape_embed_size = self.shape_embedding_size

        if not isinstance(tensor, torch.Tensor):
            tensor = torch.tensor(tensor)

        t_flatten = torch.flatten(tensor)

        if self.use_value_encoding:
            t_flatten = self.encode_values_to_code(t_flatten)

        # Zero-pad the flattened values out to embed_size. If the tensor has
        # more than embed_size elements, the negative pad truncates them.
        padding_length = embed_size - t_flatten.shape[-1]
        p1d = (0, padding_length)  # pad only the last dimension
        t_pad = F.pad(input=t_flatten, pad=p1d, mode='constant', value=0)

        if self.use_type_encoding:
            # Encode the dtype as a small integer code: 0 = default (ints),
            # 1 = bool, 2 = float32. Note torch.float only matches float32;
            # other float widths fall through to the default code.
            type_padding = 0
            if tensor.dtype == torch.bool:
                type_padding = 1
            elif tensor.dtype == torch.float:
                type_padding = 2

        # Shape embedding: encode the tensor's shape, padded to a fixed width.
        if self.use_shape_encoding:
            # tensor is guaranteed to be a torch.Tensor by the conversion above.
            t_shape = list(tensor.shape)
            # Leave one slot for the trailing -1 separator so the shape block
            # totals shape_embed_size.
            padding_length = shape_embed_size - 1 - len(t_shape)
            p1d = (0, padding_length)  # pad only the last dimension
            s_pad = F.pad(input=torch.tensor(t_shape), pad=p1d, mode='constant', value=0)

            t_pad_list = t_pad.tolist()
            s_pad_list = s_pad.tolist()

            # Layout: [type_code, -1, values..., -1, shape..., -1], with -1
            # separating the segments.
            if self.use_type_encoding:
                tensor_embedding = torch.tensor(
                    [type_padding] + [-1] + t_pad_list + [-1] + s_pad_list + [-1]
                )
            else:
                tensor_embedding = torch.tensor(t_pad_list + [-1] + s_pad_list + [-1])
        else:
            t_pad_list = t_pad.tolist()
            if self.use_type_encoding:
                tensor_embedding = torch.tensor([type_padding] + [-1] + t_pad_list + [-1])
            else:
                tensor_embedding = torch.tensor(t_pad_list + [-1])

        return tensor_embedding.float()
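
A minimal usage sketch of the resulting layout. The `model` instance and the
configuration values below are hypothetical, chosen only to illustrate the
segment structure; value encoding is assumed off, type and shape encoding on:

    # Assumed configuration (mirrors the attributes read above):
    # model.embedding_size = 6, model.shape_embedding_size = 4
    # model.use_value_encoding = False
    # model.use_type_encoding = model.use_shape_encoding = True
    t = torch.ones(2, 2)               # float32, 4 elements
    emb = model.tensor_flatten_pad(t)
    # Layout: [type_code, -1, values padded to 6, -1, shape padded to 3, -1]
    # emb -> [2., -1., 1., 1., 1., 1., 0., 0., -1., 2., 2., 0., -1.]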