in torchbiggraph/edgelist.py [0:0]
def __getitem__(self, index: Union[int, slice, LongTensorType]) -> "EdgeList":
    """Return a new edge list made of the edges selected by ``index``.

    ``lhs``, ``rhs`` and (when ``has_weight()`` is true) ``weight`` are all
    indexed with the same ``index``. ``rel`` is indexed too, unless
    ``has_scalar_relation_type()`` is true, in which case the very same
    ``rel`` object is reused. The result is constructed via ``type(self)``,
    so subclasses are preserved.

    Args:
        index: a Python ``int``, a ``slice``, or a 1-dimensional CPU/CUDA
            long tensor. ``bool`` is rejected even though it subclasses
            ``int``.

    Returns:
        A new ``type(self)`` instance built from the indexed components.

    Raises:
        TypeError: if ``index`` is not an int (bool excluded), a slice or a
            long tensor.
        ValueError: if ``index`` is a long tensor that is not 1-dimensional.
    """
    long_tensor_types = (torch.LongTensor, torch.cuda.LongTensor)
    is_long_tensor = isinstance(index, long_tensor_types)
    # bool is a subclass of int, but indexing a tensor with True/False inserts
    # a new leading dimension instead of selecting an element, which would
    # silently produce malformed components; treat it as an invalid type.
    if isinstance(index, bool) or not (
        is_long_tensor or isinstance(index, (int, slice))
    ):
        raise TypeError(
            "Index can only be int, slice or long tensor, got %s" % type(index)
        )
    if is_long_tensor and index.dim() != 1:
        raise ValueError(
            "Long tensor index must be 1-dimensional, got %d-dimensional"
            % (index.dim(),)
        )
    # NOTE(review): an int index is forwarded unchanged to every component.
    # On a plain tensor (e.g. rel/weight) that drops a dimension; confirm that
    # lhs/rhs and the constructor agree on the shape of a single-edge result.
    sub_lhs = self.lhs[index]
    sub_rhs = self.rhs[index]
    # A scalar relation type applies to every edge, so it is shared as-is.
    sub_rel = self.rel if self.has_scalar_relation_type() else self.rel[index]
    sub_weight = self.weight[index] if self.has_weight() else None
    return type(self)(sub_lhs, sub_rhs, sub_rel, sub_weight)