in long_term/pace_network.py [0:0]
def _prepare_next_batch(self, batch_size, chunk_length, dataset, sequences):
    """Yield randomly sampled batches of fixed-length chunks cut from splines.

    Sequences are drawn with replacement, with probability proportional to
    their spline length (``spline.size()``). A uniformly random window of
    ``chunk_length`` frames is cut from each drawn sequence. One pass yields
    ``ceil(len(sequences) / batch_size)`` batches.

    Args:
        batch_size: Number of chunks per batch.
        chunk_length: Number of consecutive frames per chunk.
        dataset: Nested mapping indexed as ``dataset[subject][action]``. Each
            entry must contain a ``'spline'`` object exposing ``size()``.
        sequences: Indexable sequence of ``(subject, action)`` pairs to
            sample from.

    Yields:
        ``(batch_in, batch_out)``: float32 arrays of shape
        ``(batch_size, chunk_length, 2)`` and ``(batch_size, chunk_length, 4)``.
        NOTE: the same two buffers are reused and overwritten in place on
        every iteration. Copy them if a batch must outlive the next step.

    Raises:
        KeyError: If any listed sequence has no precomputed ``'spline'``.
        ValueError: If a sampled sequence has fewer than ``chunk_length``
            frames.
    """
    batch_in = np.zeros((batch_size, chunk_length, 2), dtype='float32')
    batch_out = np.zeros((batch_size, chunk_length, 4), dtype='float32')
    pseudo_passes = (len(sequences) + batch_size - 1) // batch_size

    # Weight each sequence by its length so that longer ones are sampled more often.
    sizes = []
    for subject, action in sequences:
        if 'spline' not in dataset[subject][action]:
            raise KeyError('No splines found. Perhaps you forgot to compute them?')
        sizes.append(dataset[subject][action]['spline'].size())
    probs = np.array(sizes) / np.sum(sizes)

    for _ in range(pseudo_passes):
        idxs = np.random.choice(len(sequences), size=batch_size, replace=True, p=probs)
        for i, idx in enumerate(idxs):
            # Index the original sequence rather than np.array(sequences).
            # The array conversion coerces every key to one common dtype
            # (e.g. int subjects become strings), which breaks the dataset
            # lookups. It also rebuilds the array on every pass.
            subject, action = sequences[idx]
            # Pick a random chunk from this sequence
            spline = dataset[subject][action]['spline']
            full_seq_length = spline.size()
            max_index = full_seq_length - chunk_length + 1
            if max_index <= 0:
                raise ValueError(
                    'Sequence ({}, {}) has {} frames, fewer than chunk_length={}'.format(
                        subject, action, full_seq_length, chunk_length))
            start_idx = np.random.randint(0, max_index)
            end_idx = start_idx + chunk_length
            inputs, outputs = PaceNetwork._extract_features(spline)
            batch_in[i], batch_out[i] = inputs[start_idx:end_idx], outputs[start_idx:end_idx]
        yield batch_in, batch_out