mlsh_code/rollouts.py (120 lines of code) (raw):

import math
import time

import numpy as np

# Probability with which the master's chosen sub-policy is replaced by a
# uniformly random one at each macro step.
_MACRO_EXPLORATION_PROB = 0.1


def traj_segment_generator(pi, sub_policies, env, macrolen, horizon, stochastic, args):
    """Endlessly yield rollout segments of ``horizon`` environment steps.

    Every ``macrolen`` steps the master policy ``pi`` picks which sub-policy
    acts; the chosen sub-policy then produces the primitive actions.

    Args:
        pi: master policy; ``pi.act(stochastic, ob)`` returns
            ``(sub_policy_index, value_prediction)``.
        sub_policies: sequence of policies whose ``act(stochastic, ob)``
            returns ``(action, value_prediction)``.
        env: environment providing ``reset()``, ``step(ac)`` (returning a
            4-tuple ``(ob, rew, new, info)``), ``render()`` and
            ``action_space.sample()``.
        macrolen: number of environment steps between master decisions.
        horizon: number of environment steps per yielded segment.
        stochastic: forwarded unchanged to every ``act`` call.
        args: object with attributes ``replay`` (render while no episode has
            finished in the current segment) and ``force_subpolicy`` (when
            not ``None``, overrides the sub-policy choice).

    Yields:
        dict of ``np.copy`` snapshots with keys ``"ob"``, ``"rew"``,
        ``"vpred"``, ``"new"``, ``"ac"``, ``"ep_rets"``, ``"ep_lens"``,
        ``"macro_ac"`` and ``"macro_vpred"``.
    """
    replay = args.replay
    t = 0
    ac = env.action_space.sample()  # only used to shape/type the action buffer
    new = True
    ob = env.reset()

    cur_subpolicy = 0
    macro_vpred = 0
    macro_horizon = math.ceil(horizon / macrolen)

    cur_ep_ret = 0
    cur_ep_len = 0
    ep_rets = []
    ep_lens = []

    # History buffers, overwritten in place segment after segment.
    obs = np.array([ob for _ in range(horizon)])
    rews = np.zeros(horizon, 'float32')
    vpreds = np.zeros(horizon, 'float32')
    news = np.zeros(horizon, 'int32')
    acs = np.array([ac for _ in range(horizon)])
    macro_acs = np.zeros(macro_horizon, 'int32')
    macro_vpreds = np.zeros(macro_horizon, 'float32')

    # NOTE(review): this second reset looks redundant (the env was reset just
    # above); it is kept so the environment sees the same call sequence.
    ob = env.reset()

    while True:
        at_macro_step = t % macrolen == 0
        if at_macro_step:
            cur_subpolicy, macro_vpred = pi.act(stochastic, ob)
            if np.random.uniform() < _MACRO_EXPLORATION_PROB:
                cur_subpolicy = np.random.randint(0, len(sub_policies))
            if args.force_subpolicy is not None:
                cur_subpolicy = args.force_subpolicy

        ac, vpred = sub_policies[cur_subpolicy].act(stochastic, ob)

        # The segment is emitted *before* step t is written, so the buffers
        # hold exactly the previous ``horizon`` steps at this point.
        if t > 0 and t % horizon == 0:
            segment = {
                "ob": obs,
                "rew": rews,
                "vpred": vpreds,
                "new": news,
                "ac": acs,
                "ep_rets": ep_rets,
                "ep_lens": ep_lens,
                "macro_ac": macro_acs,
                "macro_vpred": macro_vpreds,
            }
            # Copy so the consumer is not affected by later in-place writes.
            yield {key: np.copy(val) for key, val in segment.items()}
            ep_rets = []
            ep_lens = []

        i = t % horizon
        obs[i] = ob
        vpreds[i] = vpred
        news[i] = new
        acs[i] = ac
        if at_macro_step:
            macro_acs[i // macrolen] = cur_subpolicy
            macro_vpreds[i // macrolen] = macro_vpred

        ob, rew, new, _info = env.step(ac)
        rews[i] = rew

        # Render only until the first episode of the current segment finishes.
        if replay and not ep_rets:
            env.render()

        cur_ep_ret += rew
        cur_ep_len += 1
        # An episode is closed (and the env reset) only when the done flag
        # coincides with the last step of a macro block.
        if new and (t + 1) % macrolen == 0:
            ep_rets.append(cur_ep_ret)
            ep_lens.append(cur_ep_len)
            cur_ep_ret = 0
            cur_ep_len = 0
            ob = env.reset()
        t += 1


def _gae_advantages(rew, vpred, new, gamma, lam):
    """Return GAE(lambda) advantages as a float32 array of ``len(rew)``.

    ``vpred`` and ``new`` must hold one more entry than ``rew``; entry
    ``t + 1`` is the bootstrap value / episode-start flag following step ``t``.
    """
    T = len(rew)
    adv = np.empty(T, 'float32')
    lastgaelam = 0
    for t in reversed(range(T)):
        nonterminal = 1 - new[t + 1]
        delta = rew[t] + gamma * vpred[t + 1] * nonterminal - vpred[t]
        adv[t] = lastgaelam = delta + gamma * lam * nonterminal * lastgaelam
    return adv


def add_advantage_macro(seg, macrolen, gamma, lam):
    """Add master-level advantage data to ``seg`` in place.

    Rewards are summed over each block of ``macrolen`` steps (so
    ``len(seg["rew"])`` must be a multiple of ``macrolen``) and GAE(lambda) is
    run over those blocks using ``seg["macro_vpred"]``.

    Writes ``seg["macro_adv"]``, ``seg["macro_tdlamret"]`` and
    ``seg["macro_ob"]`` (every ``macrolen``-th observation). Returns ``None``.
    """
    # The appended trailing 0 only feeds the bootstrap term of the last block.
    new = np.append(seg["new"][0::macrolen], 0)
    vpred = np.append(seg["macro_vpred"], 0)
    rew = np.sum(seg["rew"].reshape(-1, macrolen), axis=1)

    seg["macro_adv"] = _gae_advantages(rew, vpred, new, gamma, lam)
    seg["macro_tdlamret"] = seg["macro_adv"] + seg["macro_vpred"]
    seg["macro_ob"] = seg["ob"][0::macrolen]


def prepare_allrolls(allrolls, macrolen, gamma, lam, num_subpolicies):
    """Merge rollouts, compute step-level GAE and split by sub-policy.

    All rollouts are concatenated key by key along axis 0 into
    ``allrolls[0]`` (which is modified in place), ``"adv"`` and ``"tdlamret"``
    are added to it, and the result is handed to ``split_segments``.

    Args:
        allrolls: non-empty list of segment dicts sharing the same keys.
        macrolen: number of environment steps per master decision.
        gamma: discount factor.
        lam: GAE lambda.
        num_subpolicies: number of sub-policies to split into.

    Returns:
        The list produced by ``split_segments``.
    """
    merged = allrolls[0]
    if len(allrolls) > 1:
        # One concatenate per key instead of re-copying the growing array for
        # every additional rollout.
        for key in merged:
            merged[key] = np.concatenate([roll[key] for roll in allrolls], axis=0)

    new = np.append(merged["new"], 0)
    vpred = np.append(merged["vpred"], 0)
    merged["adv"] = _gae_advantages(merged["rew"], vpred, new, gamma, lam)
    merged["tdlamret"] = merged["adv"] + merged["vpred"]

    return split_segments(merged, macrolen, num_subpolicies)


def split_segments(seg, macrolen, num_subpolicies):
    """Partition the steps of ``seg`` by the sub-policy that produced them.

    Step ``i`` belongs to sub-policy ``seg["macro_ac"][i // macrolen]``.

    Returns:
        A list of ``num_subpolicies`` dicts with keys ``"ob"``, ``"adv"``,
        ``"tdlamret"`` and ``"ac"`` holding that sub-policy's steps in their
        original order (``"adv"`` / ``"tdlamret"`` as float32). A sub-policy
        that was never selected gets zero-length arrays.

    Raises:
        IndexError: if ``seg["macro_ac"]`` contains an index outside
            ``[0, num_subpolicies)``.
    """
    n_steps = len(seg["ob"])
    owner = np.repeat(np.asarray(seg["macro_ac"]), macrolen)[:n_steps]
    if owner.size and (owner.min() < 0 or owner.max() >= num_subpolicies):
        raise IndexError("macro_ac contains a sub-policy index out of range")

    obs = np.asarray(seg["ob"])
    acs = np.asarray(seg["ac"])
    advs = np.asarray(seg["adv"], dtype='float32')
    tdlams = np.asarray(seg["tdlamret"], dtype='float32')

    subpols = []
    for idx in range(num_subpolicies):
        mask = owner == idx  # boolean indexing copies and keeps step order
        subpols.append({
            "ob": obs[mask],
            "adv": advs[mask],
            "tdlamret": tdlams[mask],
            "ac": acs[mask],
        })
    return subpols