rts/game_MC/actor_critic_changed.py
def update(self, mi, batch, stats):
    ''' Actor critic model '''
    m = mi["model"]
    args = self.args

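    # T is the rollout length (number of time steps in the batch).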
    T = batch["a"].size(0)

    state_curr = m(batch.hist(T - 1))
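    # Bootstrap the discounted return with the value estimate of the last frame.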
    self.discounted_reward.setR(state_curr["V"].squeeze().data, stats)

    next_h = state_curr["h"].data

    policies = [0] * T
    policies[T - 1] = state_curr["pi"].data

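    # Walk the rollout backwards from t = T-2 down to 0, feeding rewards into
    # the discounted-return tracker and accumulating the per-step losses.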
    for t in range(T - 2, -1, -1):
        bht = batch.hist(t)
        state_curr = m.forward(bht)

        # go through the sample and get the rewards.
        a = batch["a"][t]
        V = state_curr["V"].squeeze()

        R = self.discounted_reward.feed(
            dict(r=batch["r"][t], terminal=batch["terminal"][t]),
            stats=stats)

        pi = state_curr["pi"]
        policies[t] = pi.data

        overall_err = None

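        # Standard actor-critic terms: a policy-gradient loss weighted by the
        # advantage (R - V) plus a value-fitting loss.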
        if not args.fixed_policy:
            overall_err = self.pg.feed(
                R - V.data, state_curr, bht, stats, old_pi_s=bht)
            overall_err += self.value_matcher.feed(dict(V=V, target=R), stats)

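        # Optional hidden-state smoothness term: predict the next hidden state
        # from (h_t, a_t) and penalize its distance to the hidden state
        # actually observed at t + 1.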
        if args.h_smooth:
            curr_h = state_curr["h"]
            # Block gradient
            curr_h = Variable(curr_h.data)
            future_pred = m.transition(curr_h, a)
            pred_h = future_pred["hf"]
            predict_err = self.prediction_loss(pred_h, Variable(next_h))

            overall_err = add_err(overall_err, predict_err)
            stats["predict_err"].feed(predict_err.data[0])

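        # Optional contrastive value term: the value decoded from the state
        # predicted for the taken action is encouraged to rank above the value
        # decoded for a randomly sampled other action, and to regress towards
        # the discounted return R. This branch reuses curr_h and pred_h, so it
        # assumes args.h_smooth is also enabled.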
        if args.contrastive_V:
            # Sample an action other than the current action.
            prob = pi.data.clone().fill_(1 / (pi.size(1) - 1))
            # Make the selected entry zero.
            prob.scatter_(1, a.view(-1, 1), 0.0)
            other_a = prob.multinomial(1)

            other_future_pred = m.transition(curr_h, other_a)
            other_pred_h = other_future_pred["hf"]

            # Make sure the predicted values are lower than the gt one
            # (we might need to add prob?)
            # Stop the gradient.
            pi_V = m.decision(pred_h.data)
            pi_V_other = m.decision(other_pred_h.data)
            all_one = R.clone().view(-1, 1).fill_(1.0)
            rank_err = self.rank_loss(pi_V["V"], pi_V_other["V"], Variable(all_one))
            value_err = self.prediction_loss(pi_V["V"], Variable(R))

            stats["rank_err"].feed(rank_err.data[0])
            stats["value_err"].feed(value_err.data[0])

            overall_err = add_err(overall_err, rank_err)
            overall_err = add_err(overall_err, value_err)

        if overall_err is not None:
            overall_err.backward()

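        # The hidden state at step t becomes the prediction target for
        # step t - 1 on the next (backward) iteration.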
        next_h = state_curr["h"].data

        if overall_err is not None:
            stats["cost"].feed(overall_err.data[0])
        # print("[%d]: reward=%.4f, sum_reward=%.2f, acc_reward=%.4f, value_err=%.4f, policy_err=%.4f"
        #       % (i, r.mean(), r.sum(), R.mean(), value_err.data[0], policy_err.data[0]))

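    # Optionally unroll the latent transition model from the first frame and
    # force the policy decoded at each step to match either the recorded
    # policy (h_match_policy) or the action actually taken (h_match_action).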
    if args.h_match_policy or args.h_match_action:
        state_curr = m.forward(batch.hist(0))
        h = state_curr["h"]
        if args.fixed_policy:
            h = Variable(h.data)

        total_policy_err = None

        for t in range(0, T - 1):
            # forwarded policy should be identical with current policy
            V_pi = m.decision_fix_weight(h)
            a = batch["a"][t]
            pi_h = V_pi["pi"]

            # Nothing to learn when t = 0
            if t > 0:
                if args.h_match_policy:
                    policy_err = self.policy_match_loss(pi_h, Variable(policies[t]))
                    stats["policy_match_err%d" % t].feed(policy_err.data[0])
                elif args.h_match_action:
                    # Add normalization constant
                    logpi_h = (pi_h + args.min_prob).log()
                    policy_err = self.policy_max_action_loss(logpi_h, Variable(a))
                    stats["policy_match_a_err%d" % t].feed(policy_err.data[0])

                total_policy_err = add_err(total_policy_err, policy_err)

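            # Advance the latent state one step with the transition model.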
            future_pred = m.transition(h, a)
            h = future_pred["hf"]

        if total_policy_err is not None:
            total_policy_err.backward()
            stats["total_policy_match_err"].feed(total_policy_err.data[0])