def _pad_arrs_to_max_length()

in src/mlm/batchify.py [0:0]


def _pad_arrs_to_max_length(arrs, pad_axis, pad_val, use_shared_mem, dtype, round_to=None):
    """Inner implementation of the Pad batchify: stack arrays padded to a common length.

    Every element of ``arrs`` is copied into a freshly allocated array whose
    ``pad_axis`` dimension equals the largest ``shape[pad_axis]`` in the batch
    (optionally rounded up to a multiple of ``round_to``); all remaining
    positions hold ``pad_val``.

    Parameters
    ----------
    arrs : list
        Non-empty list of ``mx.nd.NDArray``, ``numpy.ndarray`` or nested
        sequences accepted by ``numpy.asarray``. The type of the first element
        selects the conversion path. Elements are expected to share the number
        of dimensions and the sizes of every axis other than ``pad_axis``.
    pad_axis : int
        Axis of each element to pad. Negative values index the element's own
        shape, not the batched output shape.
    pad_val : number
        Value written into the padded positions.
    use_shared_mem : bool
        Currently unused; kept for signature compatibility with callers.
    dtype : str or numpy.dtype or None
        dtype of the output. When None, the first element's dtype is used for
        ``NDArray``/``ndarray`` inputs; for plain sequences a dtype able to
        represent every non-empty element and ``pad_val`` is inferred.
    round_to : int, optional
        If given, the padded length is rounded up to a multiple of this value.

    Returns
    -------
    ret : numpy.ndarray
        Array of shape ``(len(arrs),) + element_shape`` where the ``pad_axis``
        dimension of ``element_shape`` is the padded length.
    original_length : numpy.ndarray
        int32 array holding each element's unpadded ``shape[pad_axis]``.
    """
    first = arrs[0]
    if isinstance(first, np.ndarray):
        dtype = first.dtype if dtype is None else dtype
        # np.asarray does not copy ndarrays, but it also converts any plain
        # sequences that appear later in the batch so `.shape` works below.
        arrs = [np.asarray(ele) for ele in arrs]
    elif isinstance(first, mx.nd.NDArray):
        dtype = first.dtype if dtype is None else dtype
        arrs = [arr.asnumpy() for arr in arrs]
    else:
        arrs = [np.asarray(ele) for ele in arrs]
        if dtype is None:
            # Plain sequences carry no declared dtype; without this, np.full
            # would infer it from pad_val alone (e.g. truncating floats to int).
            dtype = _infer_sequence_dtype(arrs, pad_val)

    original_length = [ele.shape[pad_axis] for ele in arrs]
    max_size = max(original_length)
    if round_to is not None:
        # Integer ceiling division: no float rounding for large sizes.
        max_size = -(-max_size // round_to) * round_to

    element_shape = list(arrs[0].shape)
    element_shape[pad_axis] = max_size
    ret = np.full((len(arrs),) + tuple(element_shape), pad_val, dtype=dtype)

    for i, arr in enumerate(arrs):
        # Write arr into the leading arr.shape[pad_axis] positions of row i;
        # everything after that along pad_axis keeps pad_val.
        index = [slice(None)] * arr.ndim
        index[pad_axis] = slice(0, arr.shape[pad_axis])
        ret[(i, *index)] = arr

    return ret, np.array(original_length, dtype=np.int32)


def _infer_sequence_dtype(arrs, pad_val):
    """Return a dtype able to hold every non-empty array in ``arrs`` and ``pad_val``.

    Empty arrays are skipped because ``np.asarray([])`` defaults to float64,
    which would otherwise silently promote an all-integer batch. Returns None
    when every array is empty, leaving ``np.full`` to infer from ``pad_val``.
    """
    common = None
    for ele in arrs:
        if ele.size:
            # Pairwise promotion avoids the argument-count limit of
            # np.result_type(*many_arrays) on large batches.
            common = ele.dtype if common is None else np.promote_types(common, ele.dtype)
    if common is None:
        return None
    return np.result_type(common, pad_val)