in src/mlm/batchify.py [0:0]
def _pad_arrs_to_max_length(arrs, pad_axis, pad_val, use_shared_mem, dtype, round_to=None):
"""Inner Implementation of the Pad batchify
Parameters
----------
arrs : list
pad_axis : int
pad_val : number
use_shared_mem : bool, default False
Returns
-------
ret : NDArray
original_length : NDArray
"""
if isinstance(arrs[0], mx.nd.NDArray):
dtype = arrs[0].dtype if dtype is None else dtype
arrs = [arr.asnumpy() for arr in arrs]
elif not isinstance(arrs[0], np.ndarray):
arrs = [np.asarray(ele) for ele in arrs]
else:
dtype = arrs[0].dtype if dtype is None else dtype
original_length = [ele.shape[pad_axis] for ele in arrs]
max_size = max(original_length)
if round_to is not None:
max_size = round_to * math.ceil(max_size / round_to)
ret_shape = list(arrs[0].shape)
ret_shape[pad_axis] = max_size
ret_shape = (len(arrs), ) + tuple(ret_shape)
ret = np.full(shape=ret_shape, fill_value=pad_val, dtype=dtype)
for i, arr in enumerate(arrs):
if arr.shape[pad_axis] == max_size:
ret[i] = arr
else:
slices = [slice(None) for _ in range(arr.ndim)]
slices[pad_axis] = slice(0, arr.shape[pad_axis])
if slices[pad_axis].start != slices[pad_axis].stop:
slices = [slice(i, i + 1)] + slices
ret[tuple(slices)] = arr
# ctx = mx.Context('cpu_shared', 0) if use_shared_mem else mx.cpu()
# ret = mx.nd.array(ret, ctx=ctx, dtype=dtype)
original_length = np.array(original_length, dtype=np.int32)
# original_length = mx.nd.array(original_length, ctx=ctx, dtype=np.int32)
return ret, original_length