torchbenchmark/util/prefetch.py (
8
lines of code) (
raw
):
def prefetch_loader(loader, device): result = [] for data in loader: items = [] for item in data: items.append(item.to(device)) result.append(tuple(items)) return result