def from_csr()

in runtime/native/python/treelite_runtime/predictor.py [0:0]


  def from_csr(cls, csr, rbegin=None, rend=None):
    """
    Get a sparse batch from a subset of rows in a CSR (Compressed Sparse Row)
    matrix. The subset is given by the range ``[rbegin, rend)``.

    Parameters
    ----------
    csr : object of class :py:class:`treelite.DMatrix` or \
          :py:class:`scipy.sparse.csr_matrix`
        data matrix. Accessed by duck typing: only the attributes ``shape``,
        ``data``, ``indices`` and ``indptr`` are read.
    rbegin : :py:class:`int <python:int>`, optional
        the index of the first row in the subset. If missing, set to 0.
    rend : :py:class:`int <python:int>`, optional
        one past the index of the last row in the subset. If missing, set to
        the end of the matrix.

    Returns
    -------
    sparse_batch : :py:class:`Batch`
        a sparse batch consisting of rows ``[rbegin, rend)``

    Raises
    ------
    ValueError
        if ``csr`` has no ``shape`` attribute, ``csr.shape`` is not
        subscriptable, or ``csr.shape`` has fewer than two entries
    TreeliteError
        if ``rbegin >= rend``, ``rbegin < 0``, or ``rend`` exceeds the number
        of rows in ``csr``
    """
    # use duck typing so as to accommodate both scipy.sparse.csr_matrix
    # and DMatrix without explicitly importing any of them
    try:
      num_row = csr.shape[0]
      num_col = csr.shape[1]
    except AttributeError:
      raise ValueError('csr must contain shape attribute')
    except TypeError:
      raise ValueError('csr.shape must be of tuple type')
    except IndexError:
      raise ValueError('csr.shape must be of length 2 (indicating 2D matrix)')
    rbegin = rbegin if rbegin is not None else 0
    rend = rend if rend is not None else num_row
    if rbegin >= rend:
      raise TreeliteError('rbegin must be less than rend')
    if rbegin < 0:
      raise TreeliteError('rbegin must be nonnegative')
    if rend > num_row:
      raise TreeliteError(
          'rend must be less than or equal to number of rows in csr')

    # compute submatrix with rows [rbegin, rend)
    ibegin = csr.indptr[rbegin]
    iend = csr.indptr[rend]
    # np.asarray copies only when the dtype/layout actually has to change.
    # np.array(..., copy=False) must not be used here: under NumPy >= 2 it
    # means "never copy" and raises ValueError for e.g. float64 -> float32.
    data_subset = np.asarray(csr.data[ibegin:iend],
                             dtype=np.float32, order='C')
    indices_subset = np.asarray(csr.indices[ibegin:iend],
                                dtype=np.uint32, order='C')
    # shift the row pointers so the first nonzero of row `rbegin` sits at
    # offset 0 of data_subset / indices_subset
    row_ptr = np.asarray(csr.indptr[rbegin:(rend + 1)]) - ibegin
    indptr_subset = np.asarray(row_ptr, dtype=np.uintp, order='C')

    batch = Batch()
    batch.handle = ctypes.c_void_p()
    batch.kind = 'sparse'
    _check_call(_LIB.TreeliteAssembleSparseBatch(
        data_subset.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        indices_subset.ctypes.data_as(ctypes.POINTER(ctypes.c_uint32)),
        indptr_subset.ctypes.data_as(ctypes.POINTER(ctypes.c_size_t)),
        ctypes.c_size_t(rend - rbegin),
        ctypes.c_size_t(num_col),
        ctypes.byref(batch.handle)))
    # save handles for internal arrays; presumably the native batch refers to
    # these buffers rather than copying them -- confirm against the C API
    batch.data = data_subset
    batch.indices = indices_subset
    batch.indptr = indptr_subset
    # save pointer to csr so that it doesn't get garbage-collected prematurely
    # (when no conversion was needed, the subsets above are views into csr)
    batch.csr = csr
    return batch