in Models/exprsynth/model.py [0:0]
def __raw_batches_from_chunks_iterator(self, data_chunk_paths: List[RichPath], is_train: bool=False) -> Iterable[Tuple[Dict[str, Any], int, int]]:
    """Stream minibatches built from samples read out of data chunks.

    Chunks are obtained from ``read_data_chunks`` (with chunk order shuffled
    when ``is_train``). In training mode up to 25 chunks are kept open at once
    and samples are drawn from them round-robin, each chunk's samples in a
    randomly shuffled order; otherwise one chunk is open at a time and its
    samples are read in stored order. Chunks with no samples are skipped.

    Each sample is passed to ``self._extend_minibatch_by_sample``; whenever that
    reports the batch as finished, the batch is yielded and a fresh one is
    started via ``self._init_minibatch``. A final partial batch is yielded if it
    contains any samples.

    Args:
        data_chunk_paths: Paths of the data chunks to read.
        is_train: If True, shuffle chunk and sample order and keep many chunks
            open simultaneously to mix samples across chunks.

    Yields:
        Tuples ``(batch_data, num_samples_in_batch, samples_used_so_far)``, where
        ``samples_used_so_far`` counts the samples in all batches yielded so far,
        including this one.
    """
    from collections import deque

    chunk_iterator = read_data_chunks(data_chunk_paths, shuffle_chunks=is_train, num_workers=5, max_queue_size=25)
    # Each open chunk is a pair (chunk data, deque of sample indices not yet read).
    open_chunks = []  # type: List[Tuple[Any, deque]]

    def open_new_chunk() -> None:
        """Append the next non-empty chunk to ``open_chunks``; do nothing if no chunks remain."""
        while True:
            try:
                new_chunk = next(chunk_iterator)
            except StopIteration:
                return
            num_samples_in_chunk = len(new_chunk)
            if num_samples_in_chunk == 0:
                # An empty chunk has no sample to read; opening it would make the
                # main loop index into an empty index list.
                continue
            chunk_sample_idx_list = np.arange(num_samples_in_chunk)
            if is_train:
                np.random.shuffle(chunk_sample_idx_list)
            open_chunks.append((new_chunk, deque(chunk_sample_idx_list)))
            return

    # Keep a handful of chunks open (many when training, for better mixing):
    for _ in range(25 if is_train else 1):
        open_new_chunk()

    # Start at -1 so the first round-robin step reads from chunk 0.
    cur_chunk_idx = -1
    cur_batch_data = {}  # type: Dict[str, Any]
    self._init_minibatch(cur_batch_data)
    samples_used_so_far = 0
    while open_chunks:
        # Read in round-robin fashion from the open chunks:
        cur_chunk_idx = (cur_chunk_idx + 1) % len(open_chunks)
        cur_chunk_data, cur_remaining_idxs = open_chunks[cur_chunk_idx]
        cur_sample = cur_chunk_data[cur_remaining_idxs.popleft()]
        cur_batch_data['samples_in_batch'] += 1

        # Close an exhausted chunk and try to open a replacement (appended at the end):
        if not cur_remaining_idxs:
            del open_chunks[cur_chunk_idx]
            # Step back so the next increment lands on the chunk that shifted into
            # the freed slot instead of skipping it.
            cur_chunk_idx -= 1
            open_new_chunk()  # silently does nothing if we are out of chunks

        # Add sample to current minibatch. Yield and prepare a fresh one if it is full now:
        batch_finished = self._extend_minibatch_by_sample(cur_batch_data, cur_sample)
        if batch_finished:
            samples_used_so_far += cur_batch_data['samples_in_batch']
            yield cur_batch_data, cur_batch_data['samples_in_batch'], samples_used_so_far
            cur_batch_data = {}
            self._init_minibatch(cur_batch_data)

    # Return the last open, incomplete batch if it's non-empty:
    if cur_batch_data['samples_in_batch'] > 0:
        samples_used_so_far += cur_batch_data['samples_in_batch']
        yield cur_batch_data, cur_batch_data['samples_in_batch'], samples_used_so_far