def traj_segment_generator()

in mlsh_code/rollouts.py [0:0]


def traj_segment_generator(pi, sub_policies, env, macrolen, horizon, stochastic, args):
    """Yield fixed-length rollout segments collected with a two-level policy.

    Every ``macrolen`` environment steps the master policy ``pi`` chooses
    which entry of ``sub_policies`` acts; the chosen sub-policy then produces
    the action for each individual step. The generator never terminates on
    its own: a segment is yielded each time ``horizon`` further steps have
    been recorded.

    Args:
        pi: Master policy. ``pi.act(stochastic, ob)`` must return
            ``(subpolicy_index, value_prediction)``.
        sub_policies: Sequence of sub-policies. Each
            ``act(stochastic, ob)`` must return ``(action, value_prediction)``.
        env: Environment exposing ``action_space.sample()``, ``reset()``,
            ``step(ac)`` returning ``(ob, rew, new, info)``, and ``render()``.
        macrolen: Number of environment steps between master decisions.
        horizon: Number of environment steps per yielded segment.
        stochastic: Forwarded unchanged to every ``act`` call.
        args: Object with two attributes: ``replay`` (when truthy, the
            environment is rendered until the first episode of the current
            segment finishes) and ``force_subpolicy`` (when not ``None``,
            overrides whichever sub-policy index was selected).

    Yields:
        dict: Copies (via ``np.copy``) of the rollout buffers.
        ``"ob"``, ``"rew"``, ``"vpred"``, ``"new"`` and ``"ac"`` have
        ``horizon`` entries; ``"macro_ac"`` and ``"macro_vpred"`` have
        ``ceil(horizon / macrolen)`` entries; ``"ep_rets"`` and ``"ep_lens"``
        hold the return and length of every episode closed since the
        previous yield.
    """
    explore_prob = 0.1  # chance of replacing the master's choice at random
    replay = args.replay

    # Sampled only to give the action buffer its shape and dtype.
    ac = env.action_space.sample()
    # A single reset is enough: its observation shapes the observation buffer
    # and is also the first observation of the first episode.
    ob = env.reset()
    new = True  # the first recorded step starts an episode

    # Both are overwritten by the master decision taken at t == 0.
    cur_subpolicy = 0
    macro_vpred = 0
    macro_horizon = math.ceil(horizon / macrolen)

    cur_ep_ret = 0
    cur_ep_len = 0
    ep_rets = []
    ep_lens = []

    # Rollout buffers, reused (overwritten in place) for every segment.
    obs = np.array([ob for _ in range(horizon)])
    rews = np.zeros(horizon, 'float32')
    vpreds = np.zeros(horizon, 'float32')
    news = np.zeros(horizon, 'int32')
    acs = np.array([ac for _ in range(horizon)])
    macro_acs = np.zeros(macro_horizon, 'int32')
    macro_vpreds = np.zeros(macro_horizon, 'float32')

    t = 0
    while True:
        at_macro_start = t % macrolen == 0
        if at_macro_start:
            cur_subpolicy, macro_vpred = pi.act(stochastic, ob)

            # Random exploration over sub-policies; the draw is made even
            # when a sub-policy is forced below.
            if np.random.uniform() < explore_prob:
                cur_subpolicy = np.random.randint(0, len(sub_policies))
            if args.force_subpolicy is not None:
                cur_subpolicy = args.force_subpolicy

        ac, vpred = sub_policies[cur_subpolicy].act(stochastic, ob)

        # The buffers are full once t reaches a positive multiple of horizon.
        # The policies have already been queried for step t at this point;
        # those results are written into slot 0 after the yield.
        if t > 0 and t % horizon == 0:
            segment = {"ob" : obs, "rew" : rews, "vpred" : vpreds, "new" : news, "ac" : acs, "ep_rets" : ep_rets, "ep_lens" : ep_lens, "macro_ac" : macro_acs, "macro_vpred" : macro_vpreds}
            yield {key: np.copy(val) for key, val in segment.items()}
            ep_rets = []
            ep_lens = []

        i = t % horizon
        obs[i] = ob
        vpreds[i] = vpred
        news[i] = new
        acs[i] = ac
        if at_macro_start:
            macro_i = i // macrolen
            macro_acs[macro_i] = cur_subpolicy
            macro_vpreds[macro_i] = macro_vpred

        ob, rew, new, _ = env.step(ac)
        rews[i] = rew

        # Render only until the first episode of this segment has finished.
        if replay and not ep_rets:
            env.render()

        cur_ep_ret += rew
        cur_ep_len += 1
        # An episode is closed (and the environment reset) only when ``new``
        # is reported on the last step of a macro window; otherwise stepping
        # simply continues.
        if new and (t + 1) % macrolen == 0:
            ep_rets.append(cur_ep_ret)
            ep_lens.append(cur_ep_len)
            cur_ep_ret = 0
            cur_ep_len = 0
            ob = env.reset()
        t += 1