# behavioural_cloning.py
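# The function below assumes that imports, hyperparameters and the
# load_model_parameters() helper are defined earlier in behavioural_cloning.py.
# A minimal sketch of that preamble (module paths follow the VPT example code's
# layout; the hyperparameter values are illustrative placeholders, not the
# tuned ones from the original script):
import time

import gym
import minerl  # noqa: F401 -- assumed import that registers the MineRL BASALT environments
import numpy as np
import torch as th

from agent import MineRLAgent          # assumed module layout
from data_loader import DataLoader     # assumed module layout
from lib.tree_util import tree_map     # assumed module layout

DEVICE = "cuda"
LEARNING_RATE = 1e-4     # placeholder; the original takes its value from the VPT paper
WEIGHT_DECAY = 0.01      # placeholder
MAX_GRAD_NORM = 5.0      # placeholder
BATCH_SIZE = 16          # placeholder
N_WORKERS = 8            # placeholder
EPOCHS = 2               # placeholder
LOSS_REPORT_RATE = 100   # placeholder, in batches

# load_model_parameters() (assumed to be defined elsewhere in this file) unpickles
# the .model file and returns (policy_kwargs, pi_head_kwargs) for MineRLAgent.

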
def behavioural_cloning_train(data_dir, in_model, in_weights, out_weights):
    agent_policy_kwargs, agent_pi_head_kwargs = load_model_parameters(in_model)

    # To create model with the right environment.
    # All basalt environments have the same settings, so any of them works here
    env = gym.make("MineRLBasaltFindCave-v0")
    agent = MineRLAgent(env, device=DEVICE, policy_kwargs=agent_policy_kwargs, pi_head_kwargs=agent_pi_head_kwargs)
    agent.load_weights(in_weights)
    env.close()

    policy = agent.policy
    # Materialise the parameter generator into a list: the optimizer would
    # otherwise exhaust it, leaving clip_grad_norm_ below with nothing to clip.
    trainable_parameters = list(policy.parameters())

    # Parameters taken from the OpenAI VPT paper
    optimizer = th.optim.Adam(
        trainable_parameters,
        lr=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY
    )

    data_loader = DataLoader(
        dataset_dir=data_dir,
        n_workers=N_WORKERS,
        batch_size=BATCH_SIZE,
        n_epochs=EPOCHS
    )
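
    # This DataLoader is the example code's own loader rather than
    # torch.utils.data.DataLoader. Each batch it yields is assumed to hold
    # BATCH_SIZE samples of (frame, env action, episode id), possibly drawn
    # from different trajectories; the episode id is what lets the loop below
    # carry each trajectory's recurrent state across consecutive batches.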

    start_time = time.time()

    # Keep track of the hidden state per episode/trajectory.
    # DataLoader provides unique id for each episode, which will
    # be different even for the same trajectory when it is loaded
    # up again
    episode_hidden_states = {}
    dummy_first = th.from_numpy(np.array((False,))).to(DEVICE)

    loss_sum = 0
    for batch_i, (batch_images, batch_actions, batch_episode_id) in enumerate(data_loader):
        batch_loss = 0
        for image, action, episode_id in zip(batch_images, batch_actions, batch_episode_id):
            agent_action = agent._env_action_to_agent(action, to_torch=True, check_if_null=True)
            if agent_action is None:
                # Action was null
                continue

            agent_obs = agent._env_obs_to_agent({"pov": image})
            if episode_id not in episode_hidden_states:
                # TODO need to clean up this hidden state after worker is done with the work item.
                # Leaks memory, but not tooooo much at these scales (will be a problem later).
                episode_hidden_states[episode_id] = policy.initial_state(1)
            agent_state = episode_hidden_states[episode_id]

            pi_distribution, v_prediction, new_agent_state = policy.get_output_for_observation(
                agent_obs,
                agent_state,
                dummy_first
            )
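
            # The value-head output (v_prediction) is not needed for
            # behavioural cloning; only the action distribution is used below.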
            log_prob = policy.get_logprob_of_action(pi_distribution, agent_action)

            # Make sure we do not try to backprop through sequence
            # (fails with current accumulation)
            new_agent_state = tree_map(lambda x: x.detach(), new_agent_state)
            episode_hidden_states[episode_id] = new_agent_state

            # Finally, update the agent to increase the probability of the
            # taken action.
            # Remember to take mean over batch losses
            loss = -log_prob / BATCH_SIZE
            batch_loss += loss.item()
            loss.backward()
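
        # Gradients from every sample in this batch have accumulated through the
        # per-sample backward() calls above, so clip once and take a single
        # optimizer step per batch.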
        th.nn.utils.clip_grad_norm_(trainable_parameters, MAX_GRAD_NORM)
        optimizer.step()
        optimizer.zero_grad()

        loss_sum += batch_loss
        if batch_i % LOSS_REPORT_RATE == 0:
            time_since_start = time.time() - start_time
            print(f"Time: {time_since_start:.2f}, Batches: {batch_i}, Avg loss: {loss_sum / LOSS_REPORT_RATE:.4f}")
            loss_sum = 0

    state_dict = policy.state_dict()
    th.save(state_dict, out_weights)
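

# A minimal, hypothetical command-line entry point for this function; the
# argument names are assumptions for illustration, not necessarily the ones
# used by the original script.
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--data-dir", type=str, required=True, help="Directory with the demonstration recordings")
    parser.add_argument("--in-model", type=str, required=True, help="Path to the .model file of the pretrained VPT agent")
    parser.add_argument("--in-weights", type=str, required=True, help="Path to the .weights file to start finetuning from")
    parser.add_argument("--out-weights", type=str, required=True, help="Where to save the finetuned .weights file")
    args = parser.parse_args()

    behavioural_cloning_train(args.data_dir, args.in_model, args.in_weights, args.out_weights)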