def _load_rows_into_mem()

in petastorm/pytorch.py [0:0]


def _load_rows_into_mem(reader, transform_fn, rows_capacity):
    """Load up to ``rows_capacity`` rows from ``reader`` into memory.

    :param reader: petastorm Reader instance.
    :param transform_fn: transform function which converts batches from the reader to PyTorch tensors
    :param rows_capacity: max number of rows to be loaded into memory (truncated to real size if capacity
        is larger than total number of rows in reader).
    :return: (keys, buffer): keys is a dict_keys storing column names and buffer is a list storing loaded
        rows. Both are ``None`` if the reader yielded no rows at all.
    """
    n_rows = 0
    buffer_full = False
    buffer = None
    keys = None

    for row in reader:
        if buffer_full:
            break
        # Default collate does not work nicely on namedtuples and treat them as lists
        # Using dict will result in the yielded structures being dicts as well
        row_as_dict = row._asdict()

        # Promote some types that are incompatible with pytorch to be pytorch friendly.
        _sanitize_pytorch_types(row_as_dict)

        for k, v in row_as_dict.items():
            # Non-batched readers yield one row at a time; wrap the value in a list
            # so transform_fn always receives a batch-shaped input.
            if not reader.batched_output:
                row_as_dict[k] = transform_fn([v])
            else:
                row_as_dict[k] = transform_fn(v)

        # `is None` (not truthiness): an empty dict_keys is falsy and would be
        # needlessly reassigned on every iteration otherwise.
        if keys is None:
            keys = row_as_dict.keys()

        # Add rows to buffer
        items = list(row_as_dict.values())
        batch_size = len(items[0])
        expected_rows = n_rows + batch_size
        last_row = batch_size

        if rows_capacity <= expected_rows:
            # This batch fills (or overflows) the buffer; keep only the rows that fit.
            buffer_full = True
            last_row = rows_capacity - n_rows
            expected_rows = rows_capacity
        if buffer is None:
            # Lazily initialize buffer as one pre-sized empty tensor per column,
            # matching each column's trailing shape, dtype and device.
            buffer = [torch.empty((rows_capacity,) + v.shape[1:], dtype=v.dtype, device=v.device)
                      for v in items]
        # Copy new items into buffer
        for i, v in enumerate(items):
            buffer[i][n_rows:expected_rows] = v[:last_row]
        n_rows = expected_rows

    # At this point, dataloader has enough rows stored in memory.
    # Stop the reader rather than draining the remainder of reader.
    # If reader has infinite epochs, draining would be a deadlock.
    reader.stop()
    reader.join()
    # Truncate empty tensors if capacity is larger than total rows.
    # Guard buffer is not None: an empty reader leaves buffer unset, and iterating
    # None here would raise TypeError.
    if buffer is not None and n_rows < rows_capacity:
        for i in range(len(buffer)):
            buffer[i] = buffer[i][:n_rows]
    return (keys, buffer)