{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training Batch Reinforcement Learning Policies with Amazon SageMaker RL and Coach library\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n",
"For many real-world problems, the reinforcement learning (RL) agent needs to learn from historical data generated by a previously deployed policy. For example, we may have historical data of experts playing games, of users interacting with a website, or from sensors in a control system. This notebook shows how to use batch RL to train a new policy from an offline dataset [1]. We use the gym `CartPole-v0` environment as a simulated stand-in for a real system to generate the offline dataset, and we train the RL agents with Amazon SageMaker RL.\n",
"\n",
"We may want to evaluate the policy learned from historical data before deploying it. Since a simulator is not available in every use case, we need to estimate how good the learned policy is by using held-out historical data. This is called off-policy evaluation (OPE) or counterfactual evaluation. In this notebook, we evaluate the policy during training using several off-policy evaluation metrics.\n",
"\n",
"We can deploy the policy to a SageMaker hosting endpoint. However, some use cases do not require a persistent serving endpoint with sub-second latency. Here we demonstrate how to deploy the policy with [SageMaker Batch Transform](https://docs.aws.amazon.com/sagemaker/latest/dg/batch-transform.html), which runs inference on large volumes of input state features with high throughput."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Pre-requisites\n",
"\n",
"### Roles and permissions\n",
"\n",
"To get started, we import the Python libraries we need and set up the environment with a few prerequisites for permissions and configuration."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sagemaker\n",
"import boto3\n",
"import sys\n",
"import os\n",
"import glob\n",
"import re\n",
"import subprocess\n",
"from IPython.display import HTML\n",
"import time\n",
"from time import gmtime, strftime\n",
"\n",
"sys.path.append(\"common\")\n",
"from misc import get_execution_role, wait_for_s3_object\n",
"from sagemaker.rl import RLEstimator, RLToolkit, RLFramework\n",
"\n",
"# install gym environments if needed\n",
"!pip install gym\n",
"from env_utils import VectoredGymEnvironment"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Set up S3 buckets\n",
"\n",
"Set up the connection to the S3 bucket that you want to use for checkpoints and metadata."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# S3 bucket\n",
"sage_session = sagemaker.session.Session()\n",
"s3_bucket = sage_session.default_bucket()\n",
"region_name = sage_session.boto_region_name\n",
"s3_output_path = \"s3://{}/\".format(s3_bucket) # SDK appends the job name and output folder\n",
"print(\"S3 bucket path: {}\".format(s3_output_path))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define Variables \n",
"\n",
"We define variables such as the job prefix for the training jobs *and the image path for the container (only when this is BYOC).*"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# prefix for job names; the SDK appends a timestamp to make each job name unique\n",
"job_name_prefix = \"rl-batch-cartpole\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Configure settings\n",
"\n",
"You can run your RL training jobs on a SageMaker notebook instance or on your own machine. In both scenarios, you can run the following cells in either `local` or `SageMaker` mode. The `local` mode uses the SageMaker Python SDK to run your code in a local container before deploying to SageMaker. This can speed up iterative testing and debugging while using the same familiar Python SDK interface. To use it, set `local_mode = True`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"\n",
"# run in local mode?\n",
"local_mode = False"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create an IAM role\n",
"Either get the execution role with `role = sagemaker.get_execution_role()` when running from a SageMaker notebook or, when running from a local machine, use the utility method `role = get_execution_role()` to create an execution role."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" role = sagemaker.get_execution_role()\n",
"except Exception:\n",
" role = get_execution_role()\n",
"\n",
"print(\"Using IAM role arn: {}\".format(role))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Install docker for `local` mode\n",
"\n",
"In order to work in `local` mode, you need to have docker installed. When running from your local machine, please make sure that you have docker and docker-compose (for local CPU machines) and nvidia-docker (for local GPU machines) installed. Alternatively, when running from a SageMaker notebook instance, you can simply run the following script to install the dependencies.\n",
"\n",
"Note that you can only run a single local notebook at a time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# only run from SageMaker notebook instance\n",
"if local_mode:\n",
" !/bin/bash ./common/setup.sh"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Collect offline data\n",
"\n",
"In order to do batch RL training, we first need to prepare a dataset generated by a deployed policy. In real-world scenarios, customers can collect this offline data by letting the already deployed agent interact with the live environment. In this notebook, we use the OpenAI gym `CartPole-v0` environment to mimic a live environment and a random policy with a uniform action distribution to mimic a deployed agent. By interacting with multiple environments simultaneously, we can gather more trajectories.\n",
"\n",
"Here is a short introduction to the cart-pole balancing problem, where a pole is attached by an un-actuated joint to a cart moving along a frictionless track.\n",
"\n",
"1. *Objective*: Prevent the pole from falling over\n",
"2. *Environment*: The environment used in this example is part of OpenAI Gym, corresponding to the version of the cart-pole problem described by Barto, Sutton, and Anderson [2]\n",
"3. *State*: Cart position, cart velocity, pole angle, pole velocity at tip\n",
"4. *Action*: Push cart to the left, push cart to the right\n",
"5. *Reward*: Reward is 1 for every step taken, including the termination step"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# create 100 environments to collect rollout data\n",
"NUM_ENVS = 100\n",
"NUM_EPISODES = 5\n",
"vectored_envs = VectoredGymEnvironment(\"CartPole-v0\", NUM_ENVS)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we have 100 `CartPole-v0` environments ready. We'll collect 5 episodes from each environment, so we'll have 500 episodes of training data. We start from a random policy that generates the same uniform action probabilities regardless of the state features."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# define a random policy: uniform action probabilities in every environment\n",
"action_probs = [[1 / 2, 1 / 2] for _ in range(NUM_ENVS)]\n",
"df = vectored_envs.collect_rollouts_with_given_action_probs(\n",
" action_probs=action_probs, num_episodes=NUM_EPISODES\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# the rollout dataframe contains the columns: action, action_probs, episode_id, reward, cumulative_rewards, state_features\n",
"# cumulative_rewards is only reported at the last step of each episode\n",
"df.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can use the average cumulative reward of the random policy as a baseline for the Batch RL trained policy. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# average cumulative rewards for each episode\n",
"avg_rewards = df[\"cumulative_rewards\"].sum() / (NUM_ENVS * NUM_EPISODES)\n",
"print(\n",
" \"Average cumulative rewards over {} episodes rollouts was {}.\".format(\n",
" (NUM_ENVS * NUM_EPISODES), avg_rewards\n",
" )\n",
")"
]
},
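{
"cell_type": "markdown",
"metadata": {},
"source": [
"The baseline above divides the sum of `cumulative_rewards` by the number of episodes, which yields the average episode return only if each episode contributes exactly one nonzero `cumulative_rewards` entry, at its last step. The next cell checks that reasoning on a tiny hand-made dataframe; the toy values are hypothetical, and the cross-check assumes `episode_id` identifies each episode uniquely."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"# hypothetical toy rollout: episode 0 lasts 3 steps, episode 1 lasts 2 steps\n",
"toy_df = pd.DataFrame(\n",
"    {\n",
"        \"episode_id\": [0, 0, 0, 1, 1],\n",
"        \"reward\": [1.0, 1.0, 1.0, 1.0, 1.0],\n",
"        # cumulative reward recorded only at the final step of each episode\n",
"        \"cumulative_rewards\": [0.0, 0.0, 3.0, 0.0, 2.0],\n",
"    }\n",
")\n",
"\n",
"num_toy_episodes = toy_df[\"episode_id\"].nunique()\n",
"avg_from_cumulative = toy_df[\"cumulative_rewards\"].sum() / num_toy_episodes\n",
"avg_from_groupby = toy_df.groupby(\"episode_id\")[\"reward\"].sum().mean()\n",
"print(avg_from_cumulative, avg_from_groupby)  # both 2.5"
]
},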
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Save Dataframe as CSV for Batch RL Training\n",
"\n",
"Coach batch RL supports reading off-policy data in CSV format, so we dump our collected rollout data to a CSV file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# dump dataframe as csv file\n",
"df.to_csv(\"src/cartpole_dataset.csv\", index=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configure the presets for RL algorithm \n",
"\n",
"The presets that configure the batch RL training jobs are defined in the `preset-cartpole-ddqnbcq.py` file, which is located in the `src` directory. Using the preset file, you can define agent parameters to select the specific agent algorithm. You can also set the environment parameters, define the schedule and visualization parameters, and define the graph manager. The schedule presets define the number of heat-up steps, periodic evaluation steps, and training steps between evaluations.\n",
"\n",
"These can be overridden at runtime by specifying the `RLCOACH_PRESET` hyperparameter. Additionally, it can be used to define custom hyperparameters. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"!pygmentize src/preset-cartpole-ddqnbcq.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, we use DDQN [6] to update the policy in an off-policy manner and combine it with BCQ [5] to address the extrapolation error caused by inaccurately estimated values for state-action pairs that do not appear in the dataset. The training is completely offline."
]
},
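{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the combination more concrete, the next cell sketches how a double-DQN target can be restricted in the spirit of BCQ for discrete actions: next actions whose estimated behavior probability is small relative to the most likely action are excluded from the argmax. This is a simplified NumPy illustration with hypothetical inputs (the `threshold` value and array names are our own assumptions), not Coach's DDQN-BCQ implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def ddqn_bcq_targets(\n",
"    rewards, dones, q_online_next, q_target_next, behavior_probs_next, gamma=0.99, threshold=0.3\n",
"):\n",
"    \"\"\"Double-DQN targets with a BCQ-style filter on the candidate next actions.\"\"\"\n",
"    # relative likelihood of each next action under the (estimated) behavior policy\n",
"    ratio = behavior_probs_next / behavior_probs_next.max(axis=1, keepdims=True)\n",
"    allowed = ratio >= threshold\n",
"    # the online network selects the best allowed action (double-DQN action selection)\n",
"    masked_q = np.where(allowed, q_online_next, -np.inf)\n",
"    best_actions = masked_q.argmax(axis=1)\n",
"    # the target network evaluates that action (double-DQN action evaluation)\n",
"    next_values = q_target_next[np.arange(len(best_actions)), best_actions]\n",
"    return rewards + gamma * (1.0 - dones) * next_values\n",
"\n",
"\n",
"bcq_targets = ddqn_bcq_targets(\n",
"    rewards=np.array([1.0, 1.0]),\n",
"    dones=np.array([0.0, 1.0]),\n",
"    q_online_next=np.array([[1.0, 5.0], [2.0, 0.0]]),\n",
"    q_target_next=np.array([[0.5, 4.0], [3.0, 1.0]]),\n",
"    behavior_probs_next=np.array([[0.9, 0.1], [0.5, 0.5]]),\n",
"    gamma=0.9,\n",
")\n",
"# approximately [1.45, 1.0]: in row 0, action 1 is filtered out because the behavior policy rarely takes it\n",
"print(bcq_targets)"
]
},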
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Write the Training Code \n",
"\n",
"The training code is in the file `train-coach.py`, which is located in the `src` directory. \n",
"It first imports the environment and preset files, and then defines the `main()` function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"!pygmentize src/train-coach.py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the RL model using the Python SDK Script mode\n",
"\n",
"If you are using local mode, the training will run on the notebook instance. When using SageMaker for training, you can select a GPU or CPU instance. The `RLEstimator` is used for training RL jobs. \n",
"\n",
"1. Specify the source directory where the environment, presets, and training code are located.\n",
"2. Specify the training code as the entry point.\n",
"3. Define the training parameters such as the instance count, base job name, and S3 output path. \n",
"4. Specify the hyperparameters for the RL agent algorithm. The `RLCOACH_PRESET` can be used to specify the RL agent algorithm you want to use. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"\n",
"if local_mode:\n",
" instance_type = \"local\"\n",
"else:\n",
" instance_type = \"ml.m4.xlarge\"\n",
"\n",
"estimator = RLEstimator(\n",
" entry_point=\"train-coach.py\",\n",
" source_dir=\"src\",\n",
" dependencies=[\"common/sagemaker_rl\"],\n",
" toolkit=RLToolkit.COACH,\n",
" toolkit_version=\"1.0.0\",\n",
" framework=RLFramework.TENSORFLOW,\n",
" role=role,\n",
" instance_type=instance_type,\n",
" instance_count=1,\n",
" output_path=s3_output_path,\n",
" base_job_name=job_name_prefix,\n",
" hyperparameters={\"RLCOACH_PRESET\": \"preset-cartpole-ddqnbcq\", \"save_model\": 1},\n",
")\n",
"estimator.fit()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Store intermediate training output and model checkpoints \n",
"\n",
"The output from the training job above is stored on S3. The intermediate folder contains GIFs and metadata from the training. We'll need this metadata for metrics visualization and model evaluation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"job_name = estimator._current_job_name\n",
"print(\"Job name: {}\".format(job_name))\n",
"\n",
"s3_url = \"s3://{}/{}\".format(s3_bucket, job_name)\n",
"\n",
"if local_mode:\n",
" output_tar_key = \"{}/output.tar.gz\".format(job_name)\n",
"else:\n",
" output_tar_key = \"{}/output/output.tar.gz\".format(job_name)\n",
"\n",
"intermediate_folder_key = \"{}/output/intermediate/\".format(job_name)\n",
"output_url = \"s3://{}/{}\".format(s3_bucket, output_tar_key)\n",
"intermediate_url = \"s3://{}/{}\".format(s3_bucket, intermediate_folder_key)\n",
"\n",
"print(\"S3 job path: {}\".format(s3_url))\n",
"print(\"Output.tar.gz location: {}\".format(output_url))\n",
"print(\"Intermediate folder path: {}\".format(intermediate_url))\n",
"\n",
"tmp_dir = \"/tmp/{}\".format(job_name)\n",
"os.makedirs(tmp_dir, exist_ok=True)\n",
"print(\"Created local folder {}\".format(tmp_dir))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Plot metrics for training job\n",
"We can pull the off-policy evaluation (OPE) metrics of the training job and plot them to see how the model's performance changes over time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"csv_file_name = \"worker_0.batch_rl_graph.main_level.main_level.agent_0.csv\"\n",
"key = os.path.join(intermediate_folder_key, csv_file_name)\n",
"wait_for_s3_object(s3_bucket, key, tmp_dir, training_job_name=job_name)\n",
"\n",
"csv_file = \"{}/{}\".format(tmp_dir, csv_file_name)\n",
"df = pd.read_csv(csv_file)\n",
"# keep only the rows where both OPE metrics were logged\n",
"df = df.dropna(subset=[\"Sequential Doubly Robust\", \"Weighted Importance Sampling\"])\n",
"\n",
"plt.figure(figsize=(12, 5))\n",
"plt.xlabel(\"Number of epochs\")\n",
"\n",
"ax1 = df[\"Weighted Importance Sampling\"].plot(color=\"blue\", grid=True, label=\"WIS\")\n",
"ax2 = df[\"Sequential Doubly Robust\"].plot(color=\"red\", grid=True, secondary_y=True, label=\"SDR\")\n",
"\n",
"h1, l1 = ax1.get_legend_handles_labels()\n",
"h2, l2 = ax2.get_legend_handles_labels()\n",
"\n",
"plt.legend(h1 + h2, l1 + l2, loc=1)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Off-policy evaluation methods estimate the performance of the currently trained policy without interacting with a simulator or live environment, using only the dataset collected by another policy. Here we show two of these OPE metrics: WIS (Weighted Importance Sampling) [3] and SDR (Sequential Doubly Robust) [4]. As the plot shows, these metrics improve as the learning agent iterates over the given dataset."
]
},
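{
"cell_type": "markdown",
"metadata": {},
"source": [
"For intuition about the two metrics, the next cell gives minimal NumPy sketches of trajectory-wise weighted importance sampling [3] and of the doubly robust recursion for a single episode [4] (averaging per-episode doubly robust values gives the dataset-level estimate). These are illustrative implementations with hypothetical function names and toy inputs, not the code Coach uses to produce the plotted values. The doubly robust sketch assumes you already have a model's estimates `q_taken` (estimated Q-value of the logged action) and `v_est` (estimated state value under the evaluated policy)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"\n",
"def weighted_importance_sampling(episodes, gamma=1.0):\n",
"    \"\"\"Trajectory-wise WIS estimate of the evaluated policy's expected return.\"\"\"\n",
"    weights, returns = [], []\n",
"    for ep in episodes:\n",
"        ratios = np.asarray(ep[\"target_probs\"], dtype=float) / np.asarray(ep[\"behavior_probs\"], dtype=float)\n",
"        rewards = np.asarray(ep[\"rewards\"], dtype=float)\n",
"        weights.append(np.prod(ratios))  # importance weight of the whole trajectory\n",
"        returns.append(np.sum(gamma ** np.arange(len(rewards)) * rewards))\n",
"    weights, returns = np.asarray(weights), np.asarray(returns)\n",
"    if weights.sum() == 0:\n",
"        raise ValueError(\"all importance weights are zero\")\n",
"    # normalizing by the sum of weights (instead of the episode count) makes the estimator weighted\n",
"    return float(np.sum(weights * returns) / np.sum(weights))\n",
"\n",
"\n",
"def doubly_robust_episode(rewards, target_probs, behavior_probs, q_taken, v_est, gamma=1.0):\n",
"    \"\"\"Doubly robust value estimate for one episode, computed backwards in time.\"\"\"\n",
"    value = 0.0\n",
"    for t in reversed(range(len(rewards))):\n",
"        rho = target_probs[t] / behavior_probs[t]\n",
"        # model-based baseline plus an importance-weighted correction of the model's error\n",
"        value = v_est[t] + rho * (rewards[t] + gamma * value - q_taken[t])\n",
"    return value\n",
"\n",
"\n",
"toy_episodes = [\n",
"    {\"target_probs\": [1.0, 1.0], \"behavior_probs\": [0.5, 0.5], \"rewards\": [1.0, 1.0]},\n",
"    {\"target_probs\": [0.0], \"behavior_probs\": [0.5], \"rewards\": [1.0]},\n",
"]\n",
"wis_estimate = weighted_importance_sampling(toy_episodes)  # (4 * 2 + 0 * 1) / (4 + 0) = 2.0\n",
"dr_no_model = doubly_robust_episode([1.0], [1.0], [0.5], q_taken=[0.0], v_est=[0.0])  # reduces to IS: 2.0\n",
"dr_exact_model = doubly_robust_episode([1.0], [1.0], [0.5], q_taken=[1.0], v_est=[1.0])  # correction vanishes: 1.0\n",
"print(wis_estimate, dr_no_model, dr_exact_model)"
]
},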
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluation of RL models\n",
"\n",
"To evaluate the model trained with off-policy data, we need to measure the cumulative rewards the agent collects by interacting with the environment. We use the last checkpointed model to run the evaluation of the RL agent. We use a different preset file here, `preset-cartpole-ddqnbcq-env.py`, to let the RL agent interact with the environment and collect rewards.\n",
"\n",
"### Load checkpointed model\n",
"\n",
"The checkpoint is passed to the evaluation / inference job through the `checkpoint` channel. In local mode, we can simply use the local directory, whereas in SageMaker mode, the checkpoint needs to be uploaded to S3 first."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tarfile\n",
"\n",
"wait_for_s3_object(s3_bucket, output_tar_key, tmp_dir, training_job_name=job_name)\n",
"\n",
"output_tar_path = \"{}/output.tar.gz\".format(tmp_dir)\n",
"if not os.path.isfile(output_tar_path):\n",
"    raise FileNotFoundError(\"File output.tar.gz not found\")\n",
"# extract the training output (including the checkpoint) into tmp_dir\n",
"with tarfile.open(output_tar_path, \"r:gz\") as tar:\n",
"    tar.extractall(tmp_dir)\n",
"\n",
"if local_mode:\n",
" checkpoint_dir = \"{}/data/checkpoint\".format(tmp_dir)\n",
"else:\n",
" checkpoint_dir = \"{}/checkpoint\".format(tmp_dir)\n",
"\n",
"print(\"Checkpoint directory {}\".format(checkpoint_dir))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if local_mode:\n",
" checkpoint_path = \"file://{}\".format(checkpoint_dir)\n",
" print(\"Local checkpoint file path: {}\".format(checkpoint_path))\n",
"else:\n",
" checkpoint_path = \"s3://{}/{}/checkpoint/\".format(s3_bucket, job_name)\n",
" if not os.listdir(checkpoint_dir):\n",
" raise FileNotFoundError(\"Checkpoint files not found under the path\")\n",
" os.system(\"aws s3 cp --recursive {} {}\".format(checkpoint_dir, checkpoint_path))\n",
" print(\"S3 checkpoint file path: {}\".format(checkpoint_path))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"estimator_eval = RLEstimator(\n",
" entry_point=\"evaluate-coach.py\",\n",
" source_dir=\"src\",\n",
" dependencies=[\"common/sagemaker_rl\"],\n",
" toolkit=RLToolkit.COACH,\n",
" toolkit_version=\"1.0.0\",\n",
" framework=RLFramework.TENSORFLOW,\n",
" role=role,\n",
" instance_type=instance_type,\n",
" instance_count=1,\n",
" output_path=s3_output_path,\n",
" base_job_name=job_name_prefix,\n",
" hyperparameters={\"RLCOACH_PRESET\": \"preset-cartpole-ddqnbcq-env\", \"evaluate_steps\": 1000},\n",
")\n",
"\n",
"\n",
"estimator_eval.fit({\"checkpoint\": checkpoint_path})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Batch Transform\n",
"\n",
"As we can see from the evaluation job above, the trained agent gets a total reward of around `200`, compared to a total reward of around `25` in our offline dataset. Therefore, we can confirm that the agent has learned a better policy from the off-policy data.\n",
"\n",
"With the trained model, we can run a SageMaker Batch Transform job, where customers provide large volumes of input state features and get predictions with high throughput."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"from sagemaker.tensorflow.model import TensorFlowModel\n",
"\n",
"if local_mode:\n",
" sage_session = sagemaker.local.LocalSession()\n",
"\n",
"# Create SageMaker model entity by using model data generated by the estimator\n",
"model = TensorFlowModel(\n",
" model_data=estimator.model_data,\n",
" framework_version=\"1.15\",\n",
" sagemaker_session=sage_session,\n",
" role=role,\n",
")\n",
"\n",
"prefix = \"batch_test\"\n",
"\n",
"# setup input data prefix and output data prefix for batch transform\n",
"batch_input = \"s3://{}/{}/{}/input/\".format(\n",
" s3_bucket, job_name, prefix\n",
") # The location of the test dataset\n",
"batch_output = \"s3://{}/{}/{}/output/\".format(\n",
" s3_bucket, job_name, prefix\n",
") # The location to store the results of the batch transform job\n",
"print(\"Input path for batch transform: {}\".format(batch_input))\n",
"print(\"Output path for batch transform: {}\".format(batch_output))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, we use the states of the environments as input for the Batch Transform."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"file_name = \"env_states_{}.json\".format(int(time.time()))\n",
"# resetting the environments\n",
"vectored_envs.reset_all_envs()\n",
"# dump environment states into jsonlines file\n",
"vectored_envs.dump_environment_states(tmp_dir, file_name)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to use SageMaker Batch Transform, we first need to upload the input data from the local machine to the S3 bucket."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%time\n",
"from pathlib import Path\n",
"\n",
"local_input_file_path = Path(tmp_dir) / file_name\n",
"s3_input_file_path = batch_input + file_name  # plain string concatenation: pathlib would collapse the '//' in 's3://'\n",
"print(\n",
" \"Copy file from local path '{}' to s3 path '{}'\".format(\n",
" local_input_file_path, s3_input_file_path\n",
" )\n",
")\n",
"assert os.system(\"aws s3 cp {} {}\".format(local_input_file_path, s3_input_file_path)) == 0\n",
"print(\"S3 batch input file path: {}\".format(s3_input_file_path))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Similar to how we launch a training job on SageMaker, we can start a batch transform job in either `local` mode or `SageMaker` mode."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if local_mode:\n",
" instance_type = \"local\"\n",
"else:\n",
" instance_type = \"ml.m4.xlarge\"\n",
"\n",
"transformer = model.transformer(\n",
" instance_count=1,\n",
" instance_type=instance_type,\n",
" output_path=batch_output,\n",
" assemble_with=\"Line\",\n",
" accept=\"application/jsonlines\",\n",
" strategy=\"SingleRecord\",\n",
")\n",
"\n",
"transformer.transform(\n",
" data=batch_input,\n",
" data_type=\"S3Prefix\",\n",
" content_type=\"application/jsonlines\",\n",
" split_type=\"Line\",\n",
" join_source=\"Input\",\n",
")\n",
"\n",
"transformer.wait()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After the batch transform job finishes, we can download the prediction output from the S3 bucket to the local machine."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import subprocess\n",
"\n",
"# get the latest generated output file\n",
"cmd = \"aws s3 ls {} --recursive | sort | tail -n 1\".format(batch_output)\n",
"result = subprocess.check_output(cmd, shell=True).decode(\"utf-8\").split(\" \")[-1].strip()\n",
"local_output_file_path = Path(tmp_dir) / f\"{file_name}.out\"\n",
"s3_output_file_path = \"s3://{}/{}\".format(s3_bucket, result)\n",
"print(\n",
" \"Copy file from s3 path '{}' to local path '{}'\".format(\n",
" s3_output_file_path, local_output_file_path\n",
" )\n",
")\n",
"os.system(\"aws s3 cp {} {}\".format(s3_output_file_path, local_output_file_path))\n",
"print(\"S3 batch output file local path: {}\".format(local_output_file_path))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# read the downloaded predictions, one JSON line per record\n",
"with open(local_output_file_path, encoding=\"utf-8\") as f:\n",
"    results = f.read().splitlines()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"results[:10]"
]
},
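{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each output line corresponds to one input record, and because the job was started with `join_source=\"Input\"`, the input record is joined to its inference result. The next cell is a hypothetical sketch of turning such lines into greedy actions. It assumes each line is a JSON object with the model response under a `SageMakerOutput` key in the TensorFlow Serving REST shape `{\"predictions\": [[score_0, score_1]]}`; the sample lines are made up, so inspect `results[:10]` and adjust `output_key` and the indexing if your output looks different."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"\n",
"def greedy_actions_from_output(lines, output_key=\"SageMakerOutput\"):\n",
"    \"\"\"Return the index of the highest-scoring action for every non-empty output line.\"\"\"\n",
"    actions = []\n",
"    for line in lines:\n",
"        if not line.strip():\n",
"            continue  # skip blank lines, e.g. after a trailing newline\n",
"        record = json.loads(line)\n",
"        # assumed layout: {\"SageMakerOutput\": {\"predictions\": [[score_0, score_1, ...]]}, ...}\n",
"        scores = record[output_key][\"predictions\"][0]\n",
"        actions.append(scores.index(max(scores)))\n",
"    return actions\n",
"\n",
"\n",
"# hypothetical sample lines in the assumed format\n",
"sample_lines = [\n",
"    '{\"observation\": [0.01, 0.02, 0.03, 0.04], \"SageMakerOutput\": {\"predictions\": [[0.2, 0.7]]}}',\n",
"    \"\",\n",
"    '{\"observation\": [0.05, 0.06, 0.07, 0.08], \"SageMakerOutput\": {\"predictions\": [[1.5, -0.3]]}}',\n",
"]\n",
"print(greedy_actions_from_output(sample_lines))  # [1, 0]"
]
},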
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, we use simulated environments to collect rollout data from a random policy. Assuming the updated policy is now deployed, we can use Batch Transform to collect rollout data from this policy. \n",
"\n",
"Here are the steps to collect rollout data with Batch Transform:\n",
"1. Use Batch Transform to get action predictions, given observation features from the live environment at timestep *t*\n",
"2. The deployed agent takes the suggested actions in the environment (simulator / real) at timestep *t*\n",
"3. The environment returns new observation features at timestep *t+1*\n",
"4. Return to step 1 and use Batch Transform to get action predictions at timestep *t+1*\n",
"\n",
"This iterative procedure lets us collect data covering whole episodes, similar to what we showed at the beginning of the notebook. Once enough data has been collected, we can use it to start another batch RL training job.\n",
"\n",
"Batch Transform works well when many episodes interact with the environments concurrently. A typical use case is an email campaign, where each email user is an independent episode interacting with the deployed policy. Batch Transform can efficiently collect rollout data from millions of user contexts concurrently. The collected rollout data can then be supplied to batch RL training to train a better policy to serve the email users."
]
},
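{
"cell_type": "markdown",
"metadata": {},
"source": [
"The four steps above can be sketched as a loop that batches the observations of all still-running episodes, asks the policy for actions in one call, and steps each environment. In the next cell, `ToyEnv` and `predict_actions` are hypothetical stand-ins: in practice, `predict_actions` would write the observations to S3, run a Batch Transform job, and parse its output, and the environments would be the live system. A real collection loop would also need to log the action probabilities for off-policy evaluation, which this sketch omits."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class ToyEnv:\n",
"    \"\"\"Hypothetical environment that ends each episode after a fixed number of steps.\"\"\"\n",
"\n",
"    def __init__(self, horizon):\n",
"        self.horizon = horizon\n",
"        self.t = 0\n",
"\n",
"    def reset(self):\n",
"        self.t = 0\n",
"        return [0.0]\n",
"\n",
"    def step(self, action):\n",
"        self.t += 1\n",
"        return [float(self.t)], 1.0, self.t >= self.horizon  # next_obs, reward, done\n",
"\n",
"\n",
"def predict_actions(observations):\n",
"    \"\"\"Hypothetical placeholder for one Batch Transform call over a batch of observations.\"\"\"\n",
"    return [0 for _ in observations]\n",
"\n",
"\n",
"def collect_rollouts(envs, predict_fn):\n",
"    obs = [env.reset() for env in envs]\n",
"    active = list(range(len(envs)))\n",
"    transitions = []\n",
"    while active:\n",
"        # step 1: get actions for every running episode in one batch\n",
"        actions = predict_fn([obs[i] for i in active])\n",
"        still_active = []\n",
"        for i, action in zip(active, actions):\n",
"            # steps 2-3: act in the environment and observe the next state\n",
"            next_obs, reward, done = envs[i].step(action)\n",
"            transitions.append({\"episode_id\": i, \"state\": obs[i], \"action\": action, \"reward\": reward})\n",
"            obs[i] = next_obs\n",
"            if not done:\n",
"                still_active.append(i)\n",
"        # step 4: repeat with the episodes that have not terminated yet\n",
"        active = still_active\n",
"    return transitions\n",
"\n",
"\n",
"toy_transitions = collect_rollouts([ToyEnv(horizon=3), ToyEnv(horizon=5)], predict_actions)\n",
"print(len(toy_transitions))  # 3 + 5 = 8 transitions"
]
},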
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Reference\n",
"\n",
"1. Batch Reinforcement Learning with Coach: https://github.com/NervanaSystems/coach/blob/master/tutorials/4.%20Batch%20Reinforcement%20Learning.ipynb\n",
"2. AG Barto, RS Sutton and CW Anderson, \"Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problems\", IEEE Transactions on Systems, Man, and Cybernetics, 1983.\n",
"3. Thomas, Philip, Georgios Theocharous, and Mohammad Ghavamzadeh. \"High confidence policy improvement.\" International Conference on Machine Learning. 2015.\n",
"4. Jiang, Nan, and Lihong Li. \"Doubly robust off-policy value evaluation for reinforcement learning.\" arXiv preprint arXiv:1511.03722 (2015).\n",
"5. Fujimoto, Scott, David Meger, and Doina Precup. \"Off-policy deep reinforcement learning without exploration.\" arXiv preprint arXiv:1812.02900 (2018).\n",
"6. Van Hasselt, Hado, Arthur Guez, and David Silver. \"Deep reinforcement learning with double q-learning.\" Thirtieth AAAI conference on artificial intelligence. 2016."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "conda_mxnet_p36",
"language": "python",
"name": "conda_mxnet_p36"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
}
},
"nbformat": 4,
"nbformat_minor": 2
}