in runtime/native/python/treelite_runtime/predictor.py [0:0]
def from_csr(cls, csr, rbegin=None, rend=None):
    """
    Get a sparse batch from a subset of rows in a CSR (Compressed Sparse Row)
    matrix. The subset is given by the range ``[rbegin, rend)``.

    Parameters
    ----------
    csr : object of class :py:class:`treelite.DMatrix` or \
          :py:class:`scipy.sparse.csr_matrix`
        data matrix; must expose ``shape``, ``data``, ``indices`` and
        ``indptr`` attributes
    rbegin : :py:class:`int <python:int>`, optional
        the index of the first row in the subset. If missing, set to 0.
    rend : :py:class:`int <python:int>`, optional
        one past the index of the last row in the subset. If missing, set to
        the end of the matrix.

    Returns
    -------
    sparse_batch : :py:class:`Batch`
        a sparse batch consisting of rows ``[rbegin, rend)``

    Raises
    ------
    ValueError
        if ``csr`` has no ``shape`` attribute, or ``csr.shape`` cannot be
        indexed at positions 0 and 1
    TreeliteError
        if the range ``[rbegin, rend)`` is empty or falls outside the rows
        of ``csr``
    """
    # use duck typing so as to accommodate both scipy.sparse.csr_matrix
    # and DMatrix without explicitly importing any of them
    try:
        num_row = csr.shape[0]
        num_col = csr.shape[1]
    except AttributeError:
        raise ValueError('csr must contain shape attribute')
    except TypeError:
        raise ValueError('csr.shape must be of tuple type')
    except IndexError:
        raise ValueError('csr.shape must be of length 2 (indicating 2D matrix)')

    rbegin = rbegin if rbegin is not None else 0
    rend = rend if rend is not None else num_row
    if rbegin >= rend:
        raise TreeliteError('rbegin must be less than rend')
    if rbegin < 0:
        raise TreeliteError('rbegin must be nonnegative')
    if rend > num_row:
        # rend is an exclusive bound, so rend == num_row is valid
        raise TreeliteError('rend must not exceed the number of rows in csr')

    # compute submatrix with rows [rbegin, rend)
    ibegin = csr.indptr[rbegin]
    iend = csr.indptr[rend]
    # np.asarray copies only when a dtype or memory-layout conversion is
    # required.  np.array(..., copy=False) must not be used here: since
    # NumPy 2.0 it raises ValueError whenever a copy cannot be avoided, which
    # is the usual case for float64 data / int32 indices.
    data_subset = np.asarray(csr.data[ibegin:iend],
                             dtype=np.float32, order='C')
    indices_subset = np.asarray(csr.indices[ibegin:iend],
                                dtype=np.uint32, order='C')
    # shift the row pointers so that the first selected row starts at offset 0
    indptr_subset = np.asarray(csr.indptr[rbegin:(rend + 1)] - ibegin,
                               dtype=np.uintp, order='C')

    batch = Batch()
    batch.handle = ctypes.c_void_p()
    batch.kind = 'sparse'
    # batch.handle is passed by reference so the native call can fill it in
    _check_call(_LIB.TreeliteAssembleSparseBatch(
        data_subset.ctypes.data_as(ctypes.POINTER(ctypes.c_float)),
        indices_subset.ctypes.data_as(ctypes.POINTER(ctypes.c_uint32)),
        indptr_subset.ctypes.data_as(ctypes.POINTER(ctypes.c_size_t)),
        ctypes.c_size_t(rend - rbegin),
        ctypes.c_size_t(num_col),
        ctypes.byref(batch.handle)))
    # save handles for internal arrays so the buffers whose addresses were
    # passed above stay referenced for as long as the batch exists
    batch.data = data_subset
    batch.indices = indices_subset
    batch.indptr = indptr_subset
    # save pointer to csr so that it doesn't get garbage-collected prematurely
    # (when no conversion was needed the subsets are views into csr's arrays)
    batch.csr = csr
    return batch