in data_loader.py
def __init__(self, dataset_dir, n_workers=8, batch_size=8, n_epochs=1, max_queue_size=16):
    assert n_workers >= batch_size, "Number of workers must be equal to or greater than batch size"
    self.dataset_dir = dataset_dir
    self.n_workers = n_workers
    self.n_epochs = n_epochs
    self.batch_size = batch_size
    self.max_queue_size = max_queue_size
    # Derive unique demonstration ids from the .mp4 files in the dataset directory
    unique_ids = glob.glob(os.path.join(dataset_dir, "*.mp4"))
    unique_ids = list(set([os.path.basename(x).split(".")[0] for x in unique_ids]))
    self.unique_ids = unique_ids
    # Create tuples of (video_path, json_path) for each unique_id
    demonstration_tuples = []
    for unique_id in unique_ids:
        video_path = os.path.abspath(os.path.join(dataset_dir, unique_id + ".mp4"))
        json_path = os.path.abspath(os.path.join(dataset_dir, unique_id + ".jsonl"))
        demonstration_tuples.append((video_path, json_path))
    assert n_workers <= len(demonstration_tuples), f"n_workers should be less than or equal to the number of demonstrations ({len(demonstration_tuples)})"
    # Repeat the dataset n_epochs times, reshuffling the order for each epoch
    self.demonstration_tuples = []
    for _ in range(n_epochs):
        random.shuffle(demonstration_tuples)
        self.demonstration_tuples += demonstration_tuples
    self.task_queue = Queue()
    self.n_steps_processed = 0
    for trajectory_id, task in enumerate(self.demonstration_tuples):
        self.task_queue.put((trajectory_id, *task))
    # One None sentinel per worker tells it that no tasks remain
    for _ in range(n_workers):
        self.task_queue.put(None)
    # Each worker gets its own bounded output queue, so a slow consumer
    # backpressures the workers instead of letting memory grow unbounded
    self.output_queues = [Queue(maxsize=max_queue_size) for _ in range(n_workers)]
    self.quit_workers_event = Event()
    self.processes = [
        Process(
            target=data_loader_worker,
            args=(
                self.task_queue,
                output_queue,
                self.quit_workers_event,
            ),
            daemon=True,
        )
        for output_queue in self.output_queues
    ]
    for process in self.processes:
        process.start()
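
For context, `data_loader_worker` is defined elsewhere in data_loader.py. A minimal sketch of the contract this constructor sets up is shown below: each worker pulls (trajectory_id, video_path, json_path) tasks until it sees a None sentinel, streams per-step results into its own bounded output queue, and exits early if quit_workers_event is set. The payload shape and the elided video-frame decoding are assumptions for illustration, not the actual implementation.

import json

def data_loader_worker(task_queue, output_queue, quit_workers_event):
    # Pull tasks until the None sentinel placed by the constructor arrives.
    while True:
        task = task_queue.get()
        if task is None:
            break
        trajectory_id, video_path, json_path = task
        # Each .jsonl file holds one action record per line.
        with open(json_path) as f:
            steps = [json.loads(line) for line in f]
        for step_idx, step in enumerate(steps):
            if quit_workers_event.is_set():
                return  # parent requested shutdown; stop immediately
            # A decoded frame from video_path would normally accompany each
            # step; None stands in for it in this sketch.
            output_queue.put((trajectory_id, step_idx, None, step))
        # Signal the end of this trajectory to the consumer.
        output_queue.put(None)

Because a put on a full queue blocks, max_queue_size bounds how far any worker can run ahead of the consumer.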