in mlsh_code/rollouts.py [0:0]
def split_segments(seg, macrolen, num_subpolicies):
    """Partition a rollout segment into one sub-batch per subpolicy.

    Step ``t`` of the segment is attributed to subpolicy
    ``seg["macro_ac"][t // macrolen]``, i.e. each macro action owns a window
    of ``macrolen`` consecutive steps. Within each sub-batch, steps keep their
    original order.

    Args:
        seg: dict with per-step sequences "ob", "ac", "adv", "tdlamret"
            (presumably all of the same length -- not checked here) and a
            per-macro-step sequence "macro_ac" of subpolicy indices.
        macrolen: number of steps covered by one macro action (must be >= 1).
        num_subpolicies: number of sub-batches to produce.

    Returns:
        A list of ``num_subpolicies`` dicts with keys "ob", "adv", "tdlamret"
        and "ac". "adv" and "tdlamret" are float32. A subpolicy that owns no
        steps gets empty arrays whose trailing dimensions match the input.

    Raises:
        ValueError: if ``macrolen`` < 1.
        IndexError: if any macro action is outside ``[0, num_subpolicies)``,
            or if there are too few macro actions to cover every step.
    """
    if macrolen < 1:
        raise ValueError("macrolen must be >= 1, got %r" % (macrolen,))

    ob = np.asarray(seg["ob"])
    ac = np.asarray(seg["ac"])
    adv = np.asarray(seg["adv"], dtype='float32')
    tdlamret = np.asarray(seg["tdlamret"], dtype='float32')
    macro_ac = np.asarray(seg["macro_ac"], dtype=int)

    # Validate every chosen macro action, not just the ones that own steps.
    if macro_ac.size and (macro_ac.min() < 0 or macro_ac.max() >= num_subpolicies):
        raise IndexError(
            "macro action out of range [0, %d): min=%d, max=%d"
            % (num_subpolicies, macro_ac.min(), macro_ac.max()))

    # Subpolicy owning each step. Sizes are derived from the steps actually
    # present rather than len(macro_ac) * macrolen. Otherwise a trailing,
    # partially filled macro window would leave padding rows (copies of ob[0]
    # with zero adv/tdlamret) in the sub-batch. Too few macro actions makes
    # this fancy index go out of bounds, which raises IndexError.
    owner = macro_ac[np.arange(len(ob)) // macrolen]

    subpols = []
    for subpol in range(num_subpolicies):
        mask = owner == subpol
        subpols.append({
            "ob": ob[mask],
            "adv": adv[mask],
            "tdlamret": tdlamret[mask],
            "ac": ac[mask],
        })
    return subpols