{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Cart-pole Balancing Model with Amazon SageMaker and Coach library\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n", "\n", "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n", "\n", "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "---\n", "## Introduction\n", "\n", "In this notebook we start from the cart-pole balancing problem, where a pole is attached by an un-actuated joint to a cart moving along a frictionless track. Instead of applying control theory, this example solves the problem with reinforcement learning on Amazon SageMaker and Coach.\n", "\n", "(For a similar cart-pole example using Ray RLlib, see this [link](../rl_cartpole_ray/rl_cartpole_ray_gymEnv.ipynb). Another cart-pole example using the Coach library and offline data can be found [here](../rl_cartpole_batch_coach/rl_cartpole_batch_coach.ipynb).)\n", "\n", "1. *Objective*: Prevent the pole from falling over\n", "2. *Environment*: The environment used in this example is part of OpenAI Gym and corresponds to the version of the cart-pole problem described by Barto, Sutton, and Anderson [1]\n", "3. *State*: Cart position, cart velocity, pole angle, pole velocity at tip\n", "4. *Action*: Push cart to the left, push cart to the right\n", "5. *Reward*: Reward is 1 for every step taken, including the termination step\n", "\n", "References\n", "\n", "1. 
AG Barto, RS Sutton, and CW Anderson, \"Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problems\", IEEE Transactions on Systems, Man, and Cybernetics, 1983." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Prerequisites\n", "\n", "### Imports\n", "\n", "To get started, we import the Python libraries we need and set up the environment with a few prerequisites for permissions and configuration." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "import sagemaker\n", "import boto3\n", "import sys\n", "import os\n", "import glob\n", "import re\n", "import numpy as np\n", "import subprocess\n", "from IPython.display import HTML\n", "import time\n", "from time import gmtime, strftime\n", "\n", "sys.path.append(\"common\")\n", "from misc import get_execution_role, wait_for_s3_object\n", "from sagemaker.rl import RLEstimator, RLToolkit, RLFramework" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Set up the S3 bucket\n", "\n", "Set up the linkage and authentication to the S3 bucket that you want to use for checkpoints and metadata. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sage_session = sagemaker.session.Session()\n", "s3_bucket = sage_session.default_bucket()\n", "s3_output_path = \"s3://{}/\".format(s3_bucket)\n", "print(\"S3 bucket path: {}\".format(s3_output_path))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Define Variables \n", "\n", "We define variables such as the job prefix for the training jobs *and the image path for the container (only when this is BYOC).*" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# create unique job name\n", "job_name_prefix = \"rl-cart-pole\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Configure where training happens\n", "\n", "You can run your RL training jobs on a SageMaker notebook instance or on your own machine. In both of these scenarios, you can run the following in either local or SageMaker modes. The local mode uses the SageMaker Python SDK to run your code in a local container before deploying to SageMaker. This can speed up iterative testing and debugging while using the same familiar Python SDK interface. You just need to set local_mode = True." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "# run in local mode?\n", "local_mode = False\n", "\n", "if local_mode:\n", " instance_type = \"local\"\n", "else:\n", " instance_type = \"ml.m4.4xlarge\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Create an IAM role\n", "\n", "Either get the execution role when running from a SageMaker notebook instance `role = sagemaker.get_execution_role()` or, when running from local notebook instance, use utils method `role = get_execution_role()` to create an execution role." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "try:\n", "    role = sagemaker.get_execution_role()\n", "except ValueError:\n", "    # not running on a SageMaker notebook instance; fall back to the local utility\n", "    role = get_execution_role()\n", "\n", "print(\"Using IAM role arn: {}\".format(role))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Install docker for `local` mode\n", "\n", "In order to work in `local` mode, you need to have docker installed. When running from your local machine, make sure that you have docker or docker-compose (for local CPU machines) and nvidia-docker (for local GPU machines) installed. Alternatively, when running from a SageMaker notebook instance, you can simply run the following script to install dependencies.\n", "\n", "Note that you can run only one local-mode notebook at a time." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# only run from SageMaker notebook instance\n", "if local_mode:\n", "    !/bin/bash ./common/setup.sh" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set up the environment\n", "\n", "The cart-pole environment used in this example is part of OpenAI Gym." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Configure the presets for RL algorithm \n", "\n", "The presets that configure the RL training jobs are defined in the `preset-cartpole-clippedppo.py` file, which is also in the `src` directory. Using the preset file, you can define agent parameters to select the specific agent algorithm. You can also set the environment parameters, define the schedule and visualization parameters, and define the graph manager. The schedule parameters define the number of heatup steps, periodic evaluation steps, and training steps between evaluations.\n", "\n", "The preset is selected at runtime with the `RLCOACH_PRESET` hyperparameter, and individual preset values can be overridden with additional hyperparameters, as the training cell below does with `rl.agent_params.algorithm.discount`. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pygmentize src/preset-cartpole-clippedppo.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Write the Training Code \n", "\n", "The training code is written in the file \u201ctrain-coach.py\u201d which is uploaded in the /src directory. \n", "First import the environment files and the preset files, and then define the main() function. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pygmentize src/train-coach.py" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the RL model using the Python SDK Script mode\n", "\n", "If you are using local mode, the training will run on the notebook instance. When using SageMaker for training, you can select a GPU or CPU instance. The RLEstimator is used for training RL jobs. \n", "\n", "1. Specify the source directory where the environment, presets and training code is uploaded.\n", "2. Specify the entry point as the training code \n", "3. Specify the choice of RL toolkit and framework. This automatically resolves to the ECR path for the RL Container. \n", "4. Define the training parameters such as the instance count, job name, S3 path for output and job name. \n", "5. Specify the hyperparameters for the RL agent algorithm. The RLCOACH_PRESET can be used to specify the RL agent algorithm you want to use. \n", "6. Define the metrics definitions that you are interested in capturing in your logs. These can also be visualized in CloudWatch and SageMaker Notebooks. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "estimator = RLEstimator(\n", " entry_point=\"train-coach.py\",\n", " source_dir=\"src\",\n", " dependencies=[\"common/sagemaker_rl\"],\n", " toolkit=RLToolkit.COACH,\n", " toolkit_version=\"0.11.0\",\n", " framework=RLFramework.MXNET,\n", " role=role,\n", " instance_type=instance_type,\n", " instance_count=1,\n", " output_path=s3_output_path,\n", " base_job_name=job_name_prefix,\n", " hyperparameters={\n", " \"RLCOACH_PRESET\": \"preset-cartpole-clippedppo\",\n", " \"rl.agent_params.algorithm.discount\": 0.9,\n", " \"rl.evaluation_steps:EnvironmentEpisodes\": 8,\n", " \"improve_steps\": 10000,\n", " \"save_model\": 1,\n", " },\n", ")\n", "\n", "estimator.fit(wait=local_mode)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Store intermediate training output and model checkpoints \n", "\n", "The output from the training job above is stored on S3. The intermediate folder contains gifs and metadata of the training." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "job_name = estimator._current_job_name\n", "print(\"Job name: {}\".format(job_name))\n", "\n", "s3_url = \"s3://{}/{}\".format(s3_bucket, job_name)\n", "\n", "if local_mode:\n", "    output_tar_key = \"{}/output.tar.gz\".format(job_name)\n", "else:\n", "    output_tar_key = \"{}/output/output.tar.gz\".format(job_name)\n", "\n", "intermediate_folder_key = \"{}/output/intermediate/\".format(job_name)\n", "output_url = \"s3://{}/{}\".format(s3_bucket, output_tar_key)\n", "intermediate_url = \"s3://{}/{}\".format(s3_bucket, intermediate_folder_key)\n", "\n", "print(\"S3 job path: {}\".format(s3_url))\n", "print(\"Output.tar.gz location: {}\".format(output_url))\n", "print(\"Intermediate folder path: {}\".format(intermediate_url))\n", "\n", "tmp_dir = \"/tmp/{}\".format(job_name)\n", "os.makedirs(tmp_dir, exist_ok=True)  # unlike `mkdir`, does not fail if the folder already exists\n", "print(\"Created local folder {}\".format(tmp_dir))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Visualization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Plot metrics for training job\n", "We can pull the training reward metric and plot it to see how the model's performance changes over time."
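, "\n", "\n", "Per-episode training rewards can be noisy, so a rolling mean often makes the trend easier to see. The cell below is a minimal sketch using pandas `Series.rolling`; `add_rolling_mean` is a hypothetical helper, and its default window of 10 episodes is an arbitrary choice." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "\n", "def add_rolling_mean(df, column=\"Training Reward\", window=10):\n", "    \"\"\"Hypothetical helper: return a copy of `df` with a rolling mean of `column`.\"\"\"\n", "    smoothed = df.copy()\n", "    # min_periods=1 keeps the first (window - 1) rows instead of producing NaN\n", "    smoothed[column + \" (rolling mean)\"] = smoothed[column].rolling(window=window, min_periods=1).mean()\n", "    return smoothed" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Once the next cell has loaded `df`, you can plot both series with `add_rolling_mean(df).plot(x=\"Episode #\", y=[\"Training Reward\", \"Training Reward (rolling mean)\"], figsize=(12, 5))`."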
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline\n", "import pandas as pd\n", "\n", "csv_file_name = \"worker_0.simple_rl_graph.main_level.main_level.agent_0.csv\"\n", "key = os.path.join(intermediate_folder_key, csv_file_name)\n", "wait_for_s3_object(s3_bucket, key, tmp_dir, training_job_name=job_name)\n", "\n", "csv_file = \"{}/{}\".format(tmp_dir, csv_file_name)\n", "df = pd.read_csv(csv_file)\n", "df = df.dropna(subset=[\"Training Reward\"])\n", "x_axis = \"Episode #\"\n", "y_axis = \"Training Reward\"\n", "\n", "ax = df.plot(x=x_axis, y=y_axis, figsize=(12, 5), legend=True, style=\"b-\")\n", "ax.set_ylabel(y_axis)\n", "ax.set_xlabel(x_axis);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize the rendered GIFs\n", "The most recent GIF from training is displayed. Change `gif_index` below to visualize other generated files." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "key = os.path.join(intermediate_folder_key, \"gifs\")\n", "wait_for_s3_object(s3_bucket, key, tmp_dir, training_job_name=job_name)\n", "print(\"Copied GIF files to {}\".format(tmp_dir))\n", "\n", "glob_pattern = os.path.join(tmp_dir, \"*.gif\")\n", "gifs = glob.glob(glob_pattern)\n", "# sort numerically by the episode number embedded in each file name (\"...episode-<n>_...\")\n", "extract_episode = lambda path: int(re.search(r\".*episode-(\\d+)_.*\", path, re.IGNORECASE).group(1))\n", "gifs.sort(key=extract_episode)\n", "print(\"GIFs found:\\n{}\".format(\"\\n\".join([os.path.basename(gif) for gif in gifs])))\n", "\n", "# visualize a specific episode\n", "gif_index = -1  # -1 selects the most recent GIF\n", "gif_filepath = gifs[gif_index]\n", "gif_filename = os.path.basename(gif_filepath)\n", "print(\"Selected GIF: {}\".format(gif_filename))\n", "os.system(\"mkdir -p ./src/tmp/ && cp {} ./src/tmp/{}.gif\".format(gif_filepath, gif_filename))\n", "HTML('<img 
src=\"./src/tmp/{}.gif\">'.format(gif_filename))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Evaluation of RL models\n", "\n", "We use the last checkpointed model to run evaluation for the RL agent. \n", "\n", "### Load checkpointed model\n", "\n", "Checkpointed data from the previously trained model is passed to the evaluation job through the `checkpoint` channel. In local mode, we can simply use the local directory, whereas in SageMaker mode the checkpoint must first be uploaded to S3." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "wait_for_s3_object(s3_bucket, output_tar_key, tmp_dir, training_job_name=job_name)\n", "\n", "if not os.path.isfile(\"{}/output.tar.gz\".format(tmp_dir)):\n", "    raise FileNotFoundError(\"File output.tar.gz not found\")\n", "os.system(\"tar -xvzf {}/output.tar.gz -C {}\".format(tmp_dir, tmp_dir))\n", "\n", "if local_mode:\n", "    checkpoint_dir = \"{}/data/checkpoint\".format(tmp_dir)\n", "else:\n", "    checkpoint_dir = \"{}/checkpoint\".format(tmp_dir)\n", "\n", "print(\"Checkpoint directory {}\".format(checkpoint_dir))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "if local_mode:\n", "    checkpoint_path = \"file://{}\".format(checkpoint_dir)\n", "    print(\"Local checkpoint file path: {}\".format(checkpoint_path))\n", "else:\n", "    checkpoint_path = \"s3://{}/{}/checkpoint/\".format(s3_bucket, job_name)\n", "    if not os.listdir(checkpoint_dir):\n", "        raise FileNotFoundError(\"No checkpoint files found in {}\".format(checkpoint_dir))\n", "    os.system(\"aws s3 cp --recursive {} {}\".format(checkpoint_dir, checkpoint_path))\n", "    print(\"S3 checkpoint file path: {}\".format(checkpoint_path))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run the evaluation step\n", "\n", "Use the checkpointed model to run the evaluation step. 
" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "estimator_eval = RLEstimator(\n", " role=role,\n", " source_dir=\"src/\",\n", " dependencies=[\"common/sagemaker_rl\"],\n", " toolkit=RLToolkit.COACH,\n", " toolkit_version=\"0.11.0\",\n", " framework=RLFramework.MXNET,\n", " entry_point=\"evaluate-coach.py\",\n", " instance_count=1,\n", " instance_type=instance_type,\n", " base_job_name=job_name_prefix + \"-evaluation\",\n", " hyperparameters={\"RLCOACH_PRESET\": \"preset-cartpole-clippedppo\", \"evaluate_steps\": 2000},\n", ")\n", "\n", "estimator_eval.fit({\"checkpoint\": checkpoint_path})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Visualize the output \n", "\n", "Optionally, you can run the steps defined earlier to visualize the output " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model deployment" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since we specified MXNet when configuring the RLEstimator, the MXNet deployment container will be used for hosting." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictor = estimator.deploy(\n", " initial_instance_count=1, instance_type=instance_type, entry_point=\"deploy-mxnet-coach.py\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can test the endpoint with 2 samples observations. Starting with the cart stationary in the center of the environment, but the pole to the right and falling. Since the environment vector was of the form `[cart_position, cart_velocity, pole_angle, pole_velocity]` and we used observation normalization in our preset, we choose an observation of `[0, 0, 2, 2]`. Since we're deploying a PPO model, our model returns both state value and actions." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "value, action = predictor.predict(np.array([0.0, 0.0, 2.0, 2.0]))\n", "action" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The policy assigns the higher probability to pushing the cart to the right (the second value), which helps recover the falling pole. With the pole tilted the other way, the policy should similarly favor pushing to the left." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "value, action = predictor.predict(np.array([0.0, 0.0, -2.0, -2.0]))\n", "action" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Clean up endpoint" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "predictor.delete_endpoint()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Notebook CI Test Results\n", "\n", "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2, which is shown at the top of the notebook.\n", "\n", "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n", "\n", "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n", "\n", "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n", "\n", "![This ca-central-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n", "\n", "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n", "\n", "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n", "\n", "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n", "\n", "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n", "\n", "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n", "\n", "![This eu-north-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n", "\n", "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n", "\n", "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n", "\n", "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n", "\n", "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n", "\n", "![This ap-south-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/reinforcement_learning|rl_cartpole_coach|rl_cartpole_coach_gymEnv.ipynb)\n" ] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "conda_mxnet_p36", "language": "python", "name": "conda_mxnet_p36" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" }, "notice": "Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License." }, "nbformat": 4, "nbformat_minor": 2 }