def split_segments()

in mlsh_code/rollouts.py [0:0]


def split_segments(seg, macrolen, num_subpolicies):
    """Partition a rollout segment into one batch per sub-policy.

    Every entry of ``seg["macro_ac"]`` selects the sub-policy that is
    responsible for a run of ``macrolen`` consecutive steps.  The step-wise
    data in ``seg`` is regrouped so that each sub-policy receives only the
    steps it was selected for, in their original order.

    Args:
        seg: Mapping with the keys ``"macro_ac"``, ``"ob"``, ``"ac"``,
            ``"adv"`` and ``"tdlamret"``.  ``"macro_ac"`` is a sequence of
            integer sub-policy indices; the other four are indexed per step
            (presumably NumPy arrays of equal length -- confirm with the
            caller).
        macrolen: Number of steps covered by one macro action (an int).
        num_subpolicies: Number of sub-policies, i.e. length of the result.

    Returns:
        A list of ``num_subpolicies`` dicts with the keys ``"ob"``,
        ``"adv"``, ``"tdlamret"`` and ``"ac"``.  Each array holds
        ``macrolen * (number of times the sub-policy was selected)`` slots;
        ``"adv"`` and ``"tdlamret"`` are float32.  If ``seg["ob"]`` has
        fewer steps than that total, the slots that are never written keep
        their initial fill (the first observation / action, and zeros).
    """
    # Capacity of each sub-policy's batch: macrolen steps per selection.
    capacity = [0] * num_subpolicies
    for macro_ac in seg["macro_ac"]:
        capacity[macro_ac] += macrolen

    # Pre-allocate the output arrays.  "ob" and "ac" are seeded with copies
    # of the first step so they inherit its per-step shape and dtype; the
    # template is only looked up when size > 0, so an empty segment is fine.
    subpols = []
    for size in capacity:
        subpols.append({
            "ob": np.array([seg["ob"][0] for _ in range(size)]),
            "adv": np.zeros(size, 'float32'),
            "tdlamret": np.zeros(size, 'float32'),
            "ac": np.array([seg["ac"][0] for _ in range(size)]),
        })

    # Scatter every step into the next free slot of its sub-policy.
    fields = ("ob", "adv", "tdlamret", "ac")
    next_slot = [0] * num_subpolicies
    for step in range(len(seg["ob"])):
        # Floor division keeps the macro-action index in exact integer
        # arithmetic (no round trip through float).
        mac = seg["macro_ac"][step // macrolen]
        batch = subpols[mac]
        slot = next_slot[mac]
        for key in fields:
            batch[key][slot] = seg[key][step]
        next_slot[mac] = slot + 1
    return subpols