behavioural_cloning.py (86 lines of code) (raw):

# Basic behavioural cloning # Note: this uses gradient accumulation in batches of ones # to perform training. # This will fit inside even smaller GPUs (tested on 8GB one), # but is slow. # NOTE: This is _not_ the original code used for VPT! # This is merely to illustrate how to fine-tune the models and includes # the processing steps used. # This will likely be much worse than what original VPT did: # we are not training on full sequences, but only one step at a time to save VRAM. from argparse import ArgumentParser import pickle import time import gym import minerl import torch as th import numpy as np from agent import PI_HEAD_KWARGS, MineRLAgent from data_loader import DataLoader from lib.tree_util import tree_map EPOCHS = 2 # Needs to be <= number of videos BATCH_SIZE = 8 # Ideally more than batch size to create # variation in datasets (otherwise, you will # get a bunch of consecutive samples) # Decrease this (and batch_size) if you run out of memory N_WORKERS = 12 DEVICE = "cuda" LOSS_REPORT_RATE = 100 LEARNING_RATE = 0.000181 WEIGHT_DECAY = 0.039428 MAX_GRAD_NORM = 5.0 def load_model_parameters(path_to_model_file): agent_parameters = pickle.load(open(path_to_model_file, "rb")) policy_kwargs = agent_parameters["model"]["args"]["net"]["args"] pi_head_kwargs = agent_parameters["model"]["args"]["pi_head_opts"] pi_head_kwargs["temperature"] = float(pi_head_kwargs["temperature"]) return policy_kwargs, pi_head_kwargs def behavioural_cloning_train(data_dir, in_model, in_weights, out_weights): agent_policy_kwargs, agent_pi_head_kwargs = load_model_parameters(in_model) # To create model with the right environment. # All basalt environments have the same settings, so any of them works here env = gym.make("MineRLBasaltFindCave-v0") agent = MineRLAgent(env, device=DEVICE, policy_kwargs=agent_policy_kwargs, pi_head_kwargs=agent_pi_head_kwargs) agent.load_weights(in_weights) env.close() policy = agent.policy trainable_parameters = policy.parameters() # Parameters taken from the OpenAI VPT paper optimizer = th.optim.Adam( trainable_parameters, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY ) data_loader = DataLoader( dataset_dir=data_dir, n_workers=N_WORKERS, batch_size=BATCH_SIZE, n_epochs=EPOCHS ) start_time = time.time() # Keep track of the hidden state per episode/trajectory. # DataLoader provides unique id for each episode, which will # be different even for the same trajectory when it is loaded # up again episode_hidden_states = {} dummy_first = th.from_numpy(np.array((False,))).to(DEVICE) loss_sum = 0 for batch_i, (batch_images, batch_actions, batch_episode_id) in enumerate(data_loader): batch_loss = 0 for image, action, episode_id in zip(batch_images, batch_actions, batch_episode_id): agent_action = agent._env_action_to_agent(action, to_torch=True, check_if_null=True) if agent_action is None: # Action was null continue agent_obs = agent._env_obs_to_agent({"pov": image}) if episode_id not in episode_hidden_states: # TODO need to clean up this hidden state after worker is done with the work item. # Leaks memory, but not tooooo much at these scales (will be a problem later). episode_hidden_states[episode_id] = policy.initial_state(1) agent_state = episode_hidden_states[episode_id] pi_distribution, v_prediction, new_agent_state = policy.get_output_for_observation( agent_obs, agent_state, dummy_first ) log_prob = policy.get_logprob_of_action(pi_distribution, agent_action) # Make sure we do not try to backprop through sequence # (fails with current accumulation) new_agent_state = tree_map(lambda x: x.detach(), new_agent_state) episode_hidden_states[episode_id] = new_agent_state # Finally, update the agent to increase the probability of the # taken action. # Remember to take mean over batch losses loss = -log_prob / BATCH_SIZE batch_loss += loss.item() loss.backward() th.nn.utils.clip_grad_norm_(trainable_parameters, MAX_GRAD_NORM) optimizer.step() optimizer.zero_grad() loss_sum += batch_loss if batch_i % LOSS_REPORT_RATE == 0: time_since_start = time.time() - start_time print(f"Time: {time_since_start:.2f}, Batches: {batch_i}, Avrg loss: {loss_sum / LOSS_REPORT_RATE:.4f}") loss_sum = 0 state_dict = policy.state_dict() th.save(state_dict, out_weights) if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("--data-dir", type=str, required=True, help="Path to the directory containing recordings to be trained on") parser.add_argument("--in-model", required=True, type=str, help="Path to the .model file to be finetuned") parser.add_argument("--in-weights", required=True, type=str, help="Path to the .weights file to be finetuned") parser.add_argument("--out-weights", required=True, type=str, help="Path where finetuned weights will be saved") args = parser.parse_args() behavioural_cloning_train(args.data_dir, args.in_model, args.in_weights, args.out_weights)