# spinup/algos/pytorch/vpg/vpg.py
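# The function below assumes the module-level imports of the full Spinning Up
# source file; a minimal set sufficient for this excerpt is sketched here
# (module paths follow the spinup package layout). Note that VPGBuffer, the
# GAE-Lambda experience buffer used below, is defined earlier in the same file
# and is not reproduced in this excerpt.
import time

import numpy as np
import torch
from torch.optim import Adam

import spinup.algos.pytorch.vpg.core as core
from spinup.utils.logx import EpochLogger
from spinup.utils.mpi_pytorch import setup_pytorch_for_mpi, sync_params, mpi_avg_grads
from spinup.utils.mpi_tools import proc_id, num_procs
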
def vpg(env_fn, actor_critic=core.MLPActorCritic, ac_kwargs=dict(), seed=0,
steps_per_epoch=4000, epochs=50, gamma=0.99, pi_lr=3e-4,
vf_lr=1e-3, train_v_iters=80, lam=0.97, max_ep_len=1000,
logger_kwargs=dict(), save_freq=10):
"""
Vanilla Policy Gradient
(with GAE-Lambda for advantage estimation)
Args:
env_fn : A function which creates a copy of the environment.
The environment must satisfy the OpenAI Gym API.
actor_critic: The constructor method for a PyTorch Module with a
``step`` method, an ``act`` method, a ``pi`` module, and a ``v``
module. The ``step`` method should accept a batch of observations
and return:
=========== ================ ======================================
Symbol Shape Description
=========== ================ ======================================
``a`` (batch, act_dim) | Numpy array of actions for each
| observation.
``v`` (batch,) | Numpy array of value estimates
| for the provided observations.
``logp_a`` (batch,) | Numpy array of log probs for the
| actions in ``a``.
=========== ================ ======================================
The ``act`` method behaves the same as ``step`` but only returns ``a``.
The ``pi`` module's forward call should accept a batch of
observations and optionally a batch of actions, and return:
=========== ================ ======================================
Symbol Shape Description
=========== ================ ======================================
``pi`` N/A | Torch Distribution object, containing
| a batch of distributions describing
| the policy for the provided observations.
``logp_a`` (batch,) | Optional (only returned if batch of
| actions is given). Tensor containing
| the log probability, according to
| the policy, of the provided actions.
| If actions not given, will contain
| ``None``.
=========== ================ ======================================
The ``v`` module's forward call should accept a batch of observations
and return:
=========== ================ ======================================
Symbol Shape Description
=========== ================ ======================================
``v`` (batch,) | Tensor containing the value estimates
| for the provided observations. (Critical:
| make sure to flatten this!)
=========== ================ ======================================
ac_kwargs (dict): Any kwargs appropriate for the ActorCritic object
you provided to VPG.
seed (int): Seed for random number generators.
steps_per_epoch (int): Number of steps of interaction (state-action pairs)
for the agent and the environment in each epoch.
epochs (int): Number of epochs of interaction (equivalent to
number of policy updates) to perform.
gamma (float): Discount factor. (Always between 0 and 1.)
pi_lr (float): Learning rate for policy optimizer.
vf_lr (float): Learning rate for value function optimizer.
train_v_iters (int): Number of gradient descent steps to take on
value function per epoch.
lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
close to 1.)
max_ep_len (int): Maximum length of trajectory / episode / rollout.
logger_kwargs (dict): Keyword args for EpochLogger.
save_freq (int): How often (in terms of gap between epochs) to save
the current policy and value function.
"""
# Special function to avoid certain slowdowns from PyTorch + MPI combo.
setup_pytorch_for_mpi()
# Set up logger and save configuration
logger = EpochLogger(**logger_kwargs)
logger.save_config(locals())
# Random seed
seed += 10000 * proc_id()
torch.manual_seed(seed)
np.random.seed(seed)
# Instantiate environment
env = env_fn()
obs_dim = env.observation_space.shape
act_dim = env.action_space.shape
# Create actor-critic module
ac = actor_critic(env.observation_space, env.action_space, **ac_kwargs)
# Sync params across processes
sync_params(ac)
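    # (sync_params broadcasts the root process's initial weights to the other
    # MPI processes, so every worker starts from identical parameters)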
# Count variables
var_counts = tuple(core.count_vars(module) for module in [ac.pi, ac.v])
logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n'%var_counts)
# Set up experience buffer
local_steps_per_epoch = int(steps_per_epoch / num_procs())
buf = VPGBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)
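    # When a trajectory is finished, the buffer computes GAE-Lambda advantages
    #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    #   A_t     = sum_{l >= 0} (gamma * lam)^l * delta_{t+l}
    # and discounted rewards-to-go, which serve as targets for the value function.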
# Set up function for computing VPG policy loss
def compute_loss_pi(data):
obs, act, adv, logp_old = data['obs'], data['act'], data['adv'], data['logp']
# Policy loss
pi, logp = ac.pi(obs, act)
loss_pi = -(logp * adv).mean()
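        # Minimizing this "loss" is gradient ascent on E[log pi(a|s) * A]; its
        # gradient is the policy gradient estimate, but the loss value itself
        # is not a meaningful measure of performance.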
# Useful extra info
approx_kl = (logp_old - logp).mean().item()
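        # Sample estimate of KL(pi_old || pi_new) over the batch; in VPG it is
        # logged purely as a diagnostic and does not constrain the update.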
ent = pi.entropy().mean().item()
pi_info = dict(kl=approx_kl, ent=ent)
return loss_pi, pi_info
# Set up function for computing value loss
def compute_loss_v(data):
obs, ret = data['obs'], data['ret']
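        # MSE regression of V(s) onto the discounted rewards-to-go ('ret')
        # computed by the buffer.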
return ((ac.v(obs) - ret)**2).mean()
# Set up optimizers for policy and value function
pi_optimizer = Adam(ac.pi.parameters(), lr=pi_lr)
vf_optimizer = Adam(ac.v.parameters(), lr=vf_lr)
# Set up model saving
logger.setup_pytorch_saver(ac)
def update():
data = buf.get()
# Get loss and info values before update
pi_l_old, pi_info_old = compute_loss_pi(data)
pi_l_old = pi_l_old.item()
v_l_old = compute_loss_v(data).item()
# Train policy with a single step of gradient descent
pi_optimizer.zero_grad()
loss_pi, pi_info = compute_loss_pi(data)
loss_pi.backward()
mpi_avg_grads(ac.pi) # average grads across MPI processes
pi_optimizer.step()
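        # VPG takes exactly one policy gradient step per batch: the gradient
        # estimate is only valid for the policy that collected the data, so
        # further steps on the same batch would rely on stale, off-policy samples.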
# Value function learning
for i in range(train_v_iters):
vf_optimizer.zero_grad()
loss_v = compute_loss_v(data)
loss_v.backward()
mpi_avg_grads(ac.v) # average grads across MPI processes
vf_optimizer.step()
# Log changes from update
        kl, ent = pi_info['kl'], pi_info_old['ent']  # entropy from before the update, KL measured after it
logger.store(LossPi=pi_l_old, LossV=v_l_old,
KL=kl, Entropy=ent,
DeltaLossPi=(loss_pi.item() - pi_l_old),
DeltaLossV=(loss_v.item() - v_l_old))
# Prepare for interaction with environment
start_time = time.time()
o, ep_ret, ep_len = env.reset(), 0, 0
# Main loop: collect experience in env and update/log each epoch
for epoch in range(epochs):
for t in range(local_steps_per_epoch):
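            # Per the interface described in the docstring, ac.step returns
            # numpy arrays (action, value estimate, log-prob); it is typically
            # implemented without gradient tracking (e.g. under torch.no_grad()).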
a, v, logp = ac.step(torch.as_tensor(o, dtype=torch.float32))
next_o, r, d, _ = env.step(a)
ep_ret += r
ep_len += 1
# save and log
buf.store(o, a, r, v, logp)
logger.store(VVals=v)
# Update obs (critical!)
o = next_o
timeout = ep_len == max_ep_len
terminal = d or timeout
epoch_ended = t==local_steps_per_epoch-1
if terminal or epoch_ended:
                if epoch_ended and not terminal:
print('Warning: trajectory cut off by epoch at %d steps.'%ep_len, flush=True)
# if trajectory didn't reach terminal state, bootstrap value target
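                # (a timeout or epoch cut-off is not a true termination, so the
                # tail of the return is approximated by V(s); a genuine terminal
                # state contributes no future reward, hence v = 0)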
if timeout or epoch_ended:
_, v, _ = ac.step(torch.as_tensor(o, dtype=torch.float32))
else:
v = 0
buf.finish_path(v)
if terminal:
# only save EpRet / EpLen if trajectory finished
logger.store(EpRet=ep_ret, EpLen=ep_len)
o, ep_ret, ep_len = env.reset(), 0, 0
# Save model
if (epoch % save_freq == 0) or (epoch == epochs-1):
logger.save_state({'env': env}, None)
# Perform VPG update!
update()
# Log info about epoch
logger.log_tabular('Epoch', epoch)
logger.log_tabular('EpRet', with_min_and_max=True)
logger.log_tabular('EpLen', average_only=True)
logger.log_tabular('VVals', with_min_and_max=True)
logger.log_tabular('TotalEnvInteracts', (epoch+1)*steps_per_epoch)
logger.log_tabular('LossPi', average_only=True)
logger.log_tabular('LossV', average_only=True)
logger.log_tabular('DeltaLossPi', average_only=True)
logger.log_tabular('DeltaLossV', average_only=True)
logger.log_tabular('Entropy', average_only=True)
logger.log_tabular('KL', average_only=True)
logger.log_tabular('Time', time.time()-start_time)
logger.dump_tabular()
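

# Example usage: a minimal sketch, not part of the excerpt above. It assumes
# `gym` is installed with the classic (obs, reward, done, info) step API used
# by this code, that 'CartPole-v1' is an available environment id, and that
# core.MLPActorCritic accepts a `hidden_sizes` kwarg as in the full Spinning Up
# repo (the full source file instead exposes an argparse-based CLI entry point).
if __name__ == '__main__':
    import gym
    vpg(lambda: gym.make('CartPole-v1'),
        actor_critic=core.MLPActorCritic,
        ac_kwargs=dict(hidden_sizes=[64, 64]),
        steps_per_epoch=4000, epochs=10, seed=0)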