def __init__()

in graphlearn_torch/python/data/graph.py [0:0]


  def __init__(self,
               edge_index: Union[TensorDataType,
                                 Tuple[TensorDataType, TensorDataType]],
               edge_ids: Optional[TensorDataType] = None,
               edge_weights: Optional[TensorDataType] = None,
               input_layout: str = 'COO',
               layout: Literal['CSR', 'CSC'] = 'CSR'):
    r"""Build the graph topology in the requested compressed layout.

    Args:
      edge_index: A pair ``(row, col)`` (or anything indexable as
        ``edge_index[0]`` / ``edge_index[1]``), converted to ``torch.int64``.
        Its meaning depends on ``input_layout``:

        - ``'COO'``: ``row`` and ``col`` are per-edge indices of equal length.
        - ``'CSR'``: ``row`` is the pointer array, ``col`` the indices.
        - ``'CSC'``: ``row`` is the indices, ``col`` the pointer array.
      edge_ids: Optional per-edge ids (converted to ``torch.int64``). When
        omitted, ``torch.arange(num_edges)`` is used on ``row``'s device.
      edge_weights: Optional per-edge weights (converted to ``torch.float``).
      input_layout: Layout of ``edge_index``: ``'COO'``, ``'CSR'`` or
        ``'CSC'`` (case-insensitive).
      layout: Target layout to store, ``'CSR'`` or ``'CSC'``
        (case-insensitive).

    Raises:
      RuntimeError: If ``input_layout`` or ``layout`` is not recognised.
      AssertionError: If the COO ``row``/``col`` lengths differ, or if
        ``edge_ids`` / ``edge_weights`` do not have one entry per edge.
    """
    edge_index = convert_to_tensor(edge_index, dtype=torch.int64)
    row, col = edge_index[0], edge_index[1]

    input_layout = str(input_layout).upper()
    if input_layout == 'COO':
      assert row.numel() == col.numel(), \
        'COO row and col must have the same number of elements'
      num_edges = row.numel()
    elif input_layout == 'CSR':
      # `col` holds the per-edge indices; `row` is the pointer array.
      num_edges = col.numel()
    elif input_layout == 'CSC':
      # `row` holds the per-edge indices; `col` is the pointer array.
      num_edges = row.numel()
    else:
      raise RuntimeError(f"'{self.__class__.__name__}': got "
                         f"invalid edge layout {input_layout}")

    # Normalise the target layout the same way as `input_layout`. Without
    # this, e.g. layout='csr' or layout='COO' would match none of the branches
    # below and leave `_indptr` / `_indices` unset.
    layout = str(layout).upper()
    if layout not in ('CSR', 'CSC'):
      raise RuntimeError(f"'{self.__class__.__name__}': got "
                         f"invalid target layout {layout}")

    edge_ids = convert_to_tensor(edge_ids, dtype=torch.int64)
    if edge_ids is None:
      edge_ids = torch.arange(num_edges, dtype=torch.int64, device=row.device)
    else:
      assert edge_ids.numel() == num_edges, \
        'edge_ids must have one entry per edge'

    edge_weights = convert_to_tensor(edge_weights, dtype=torch.float)
    if edge_weights is not None:
      assert edge_weights.numel() == num_edges, \
        'edge_weights must have one entry per edge'

    self._layout = layout

    # Input already in the target layout: store it as given, no conversion.
    if input_layout == layout:
      if layout == 'CSR':
        self._indptr, self._indices = row, col
      else:
        self._indices, self._indptr = row, col
      self._edge_ids = edge_ids
      self._edge_weights = edge_weights
      return

    # Expand the compressed (pointer) side back to per-edge indices so that
    # (row, col) is in COO form before converting to the target layout.
    if input_layout == 'CSR':
      row = ptr2ind(row)
    elif input_layout == 'CSC':
      col = ptr2ind(col)

    if layout == 'CSR':
      self._indptr, self._indices, self._edge_ids, self._edge_weights = \
        coo_to_csr(row, col, edge_id=edge_ids, edge_weight=edge_weights)
    else:
      self._indices, self._indptr, self._edge_ids, self._edge_weights = \
        coo_to_csc(row, col, edge_id=edge_ids, edge_weight=edge_weights)