in torchbiggraph/entitylist.py [0:0]
def __init__(self, tensor: LongTensorType, tensor_list: TensorList) -> None:
    """Validate and store a long tensor together with a TensorList.

    Args:
        tensor: a 1-dimensional tensor that must be an instance of
            ``torch.LongTensor`` or ``torch.cuda.LongTensor``.
        tensor_list: a ``TensorList`` whose length equals ``tensor.shape[0]``.

    Raises:
        TypeError: if ``tensor`` is not a ``torch.LongTensor`` /
            ``torch.cuda.LongTensor``, or if ``tensor_list`` is not a
            ``TensorList``.
        ValueError: if ``tensor`` is not 1-dimensional, or if its length
            differs from ``len(tensor_list)``.
    """
    # isinstance against the legacy tensor-type objects checks the dtype as
    # well as the backend, so e.g. a CPU float tensor is rejected here.
    if not isinstance(tensor, (torch.LongTensor, torch.cuda.LongTensor)):
        if isinstance(tensor, torch.Tensor):
            # type(tensor) is typically just torch.Tensor regardless of
            # dtype/device, which would make the message useless; report the
            # attributes that actually caused the rejection instead.
            got = "tensor of dtype %s, layout %s on device %s" % (
                tensor.dtype,
                tensor.layout,
                tensor.device,
            )
        else:
            got = type(tensor)
        raise TypeError("Expected long tensor as first argument, got %s" % got)
    if not isinstance(tensor_list, TensorList):
        raise TypeError(
            "Expected tensor list as second argument, got %s" % type(tensor_list)
        )
    if tensor.dim() != 1:
        raise ValueError(
            "Expected 1-dimensional tensor, got %d-dimensional one" % tensor.dim()
        )
    if tensor.shape[0] != len(tensor_list):
        raise ValueError(
            "The tensor and tensor list have different lengths: %d != %d"
            % (tensor.shape[0], len(tensor_list))
        )
    # TODO We could check that, for all i, we have either tensor[i] < 0 or
    # tensor_list[i] empty, however it's expensive and we're already doing
    # something similar at retrieval inside to_tensor(_list).
    self.tensor: LongTensorType = tensor
    self.tensor_list: TensorList = tensor_list