training/sagemaker-automatic-model-tuning/hpo_blazingtext_text_classification_20_newsgroups.ipynb (1,192 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Text Classification with Amazon SageMaker BlazingText and Hyperparameter Tuning\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n", "\n", "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n", "\n", "---" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "Automatic model tuning, also known as hyperparameter tuning, finds the best version of a model by running many training jobs that test a range of hyperparameters on your dataset. You choose the tunable hyperparameters, a range of values for each, and an objective metric, which must be one of the metrics that the algorithm computes. Automatic model tuning then searches the chosen hyperparameter ranges for the combination of values that produces the model that optimizes the objective metric.\n", "\n", "\n", "## Introduction\n", "\n", "Text classification can be used for use cases such as sentiment analysis, spam detection, and hashtag prediction. This notebook demonstrates how to use SageMaker BlazingText for supervised binary or multi-class text classification with single or multiple labels per example. BlazingText can train a model on more than a billion words in a couple of minutes using a multi-core CPU or a GPU, while achieving performance on par with state-of-the-art deep learning text classification algorithms. BlazingText extends the `fastText` text classifier to leverage GPU acceleration using custom `CUDA` kernels." 
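, "\n\nAs a rough illustration of the search idea only (SageMaker's default tuning strategy is Bayesian optimization, not the plain random search shown here), the sketch below samples each hyperparameter uniformly from its range, scores every sampled configuration with an objective function, and keeps the best one. The `random_search` and `toy_objective` functions are hypothetical: `toy_objective` stands in for a real training job and its objective metric.

```python
import random


def random_search(objective, ranges, n_trials, seed=0):
    # ranges maps each hyperparameter name to a (low, high) tuple of floats.
    rng = random.Random(seed)
    best_params, best_score = None, float('-inf')
    for _ in range(n_trials):
        # Sample one value per hyperparameter, uniformly within its range.
        params = {name: rng.uniform(low, high) for name, (low, high) in ranges.items()}
        score = objective(params)  # stands in for one training job
        if score > best_score:  # we maximize the objective metric
            best_params, best_score = params, score
    return best_params, best_score


def toy_objective(params):
    # Hypothetical objective that peaks at learning_rate = 0.1.
    return 1.0 - abs(params['learning_rate'] - 0.1)


best, score = random_search(toy_objective, {'learning_rate': (0.05, 0.15)}, n_trials=20)
print(best, score)
```

Each call to `objective` corresponds to one training job, which is why the `max_jobs` limit set later in this notebook bounds both the cost and the thoroughness of the search."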
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install Python packages" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "!{sys.executable} -m pip install \"scikit_learn==0.20.0\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Setup\n", "\n", "Let's start by specifying:\n", "\n", "- The S3 bucket and prefix that you want to use for training and model data. The bucket should be in the same region as the notebook instance, training, and hosting. If you don't specify a bucket, the SageMaker SDK creates a default bucket, following a predefined naming convention, in the same region.\n", "- The IAM role ARN used to give SageMaker access to your data. It can be fetched using the **get_execution_role** method from the SageMaker Python SDK." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "isConfigCell": true }, "outputs": [], "source": [ "import sagemaker\n", "from sagemaker import get_execution_role\n", "import json\n", "import boto3\n", "import pandas as pd\n", "import re\n", "import string\n", "from sklearn.model_selection import train_test_split\n", "\n", "sess = sagemaker.Session()\n", "\n", "# This is the role that SageMaker uses to access AWS resources (S3, CloudWatch) on your behalf\n", "role = get_execution_role()\n", "print(role)\n", "\n", "bucket = sess.default_bucket() # Replace with your own bucket name if needed\n", "print(bucket)\n", "prefix = \"blazingtext/supervised/20_newsgroups\" # Replace with the prefix under which you want to store the data if needed" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data Preparation\n", "\n", "Now we'll download, from a public S3 bucket, the dataset on which we want to train the text classification model. 
BlazingText expects a single preprocessed text file with space-separated tokens, in which each line contains a single sentence together with its label(s), each label prefixed with `__label__`.\n", "\n", "In this example, we train the text classification model on the [`20 newsgroups dataset`](http://qwone.com/~jason/20Newsgroups/), which consists of about 20,000 messages taken from 20 Usenet newsgroups." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import shutil\n", "\n", "data_dir = \"20_newsgroups_bulk\"\n", "if os.path.exists(data_dir): # clean up the existing data folder\n", " shutil.rmtree(data_dir)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!aws s3 cp s3://sagemaker-sample-files/datasets/text/20_newsgroups/20_newsgroups_bulk.tar.gz ." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!tar xzf 20_newsgroups_bulk.tar.gz\n", "!ls 20_newsgroups_bulk" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "file_list = [os.path.join(data_dir, f) for f in os.listdir(data_dir)]\n", "print(\"Number of files:\", len(file_list))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "documents_count = 0\n", "for file in file_list:\n", " df = pd.read_csv(file, header=None, names=[\"text\"])\n", " documents_count = documents_count + df.shape[0]\n", "print(\"Number of documents:\", documents_count)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "categories_list = [f.split(\"/\")[1] for f in file_list]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "categories_list" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's inspect the dataset to understand how the data and the 
labels are provided. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv(\"./20_newsgroups_bulk/rec.motorcycles\", header=None, names=[\"text\"])\n", "df" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df[\"text\"][0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv(\"./20_newsgroups_bulk/comp.sys.mac.hardware\", header=None, names=[\"text\"])\n", "df" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "df[\"text\"][0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see above, there is a single file for each class in the dataset. Each record is a plain-text message with a header, a body, a footer, and quoted text from earlier messages. We need to process these records into a suitable data format." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Preprocessing\n", "We need to preprocess the training data into the **space-separated tokenized text** format that the `BlazingText` algorithm consumes. As mentioned previously, the class label(s) must be prefixed with `__label__` and appear on the same line as the original sentence. We'll use the `nltk` library to tokenize the input sentences from the `20 newsgroups dataset`." 
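, "\n\nFor illustration, the hypothetical helper below (not part of this notebook's pipeline or of BlazingText) builds one line in the format described above: each label gets the `__label__` prefix, and the labels and tokens are joined by single spaces on one line. The second call shows a multi-label line, assuming the fastText convention of listing several prefixed labels at the start of the line.

```python
def to_blazingtext_line(labels, tokens):
    # Prefix every label with __label__, then append the space-separated tokens.
    prefixed = ['__label__' + label for label in labels]
    return ' '.join(prefixed + list(tokens))


single = to_blazingtext_line(['science'], ['voyager', 'stable', 'trajectory'])
multi = to_blazingtext_line(['computer', 'sales'], ['modem', 'card'])
print(single)  # __label__science voyager stable trajectory
print(multi)  # __label__computer __label__sales modem card
```"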
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Download the `nltk` tokenizer models and the English stop-word list." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import nltk\n", "from nltk.corpus import stopwords\n", "from nltk.tokenize import word_tokenize\n", "\n", "nltk.download(\"punkt\")\n", "nltk.download(\"punkt_tab\") # tokenizer data used by word_tokenize in newer NLTK releases\n", "nltk.download(\"stopwords\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.datasets.twenty_newsgroups import (\n", " strip_newsgroup_header,\n", " strip_newsgroup_quoting,\n", " strip_newsgroup_footer,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following function removes the header, the footer, and any quoted text from earlier messages." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def strip_newsgroup_item(item):\n", " item = strip_newsgroup_header(item)\n", " item = strip_newsgroup_quoting(item)\n", " item = strip_newsgroup_footer(item)\n", " return item" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The following function lowercases each text; removes bracketed text, URLs, HTML tags, punctuation, line breaks, and words that contain digits; and then keeps only tokens that are non-numeric, longer than two characters, and not stop words." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Let's get the set of English stop words from the NLTK library\n", "stop_words = set(stopwords.words(\"english\"))\n", "\n", "\n", "def process_text(texts):\n", " final_text_list = []\n", " for text in texts:\n", "\n", " # Treat missing values (e.g. NaN) as empty strings\n", " if not isinstance(text, str):\n", " text = \"\"\n", "\n", " filtered_sentence = []\n", "\n", " # Lowercase\n", " text = text.lower()\n", "\n", " # Remove leading/trailing whitespace, bracketed text, URLs, HTML tags, punctuation, line breaks, and words containing digits\n", " text = text.strip()\n", " text = re.sub(r\"\\[.*?\\]\", \"\", text)\n", " text = re.sub(r\"https?://\\S+|www\\.\\S+\", \"\", text)\n", " text = re.sub(\"<.*?>+\", \"\", text)\n", " text = re.sub(\"[%s]\" % re.escape(string.punctuation), \"\", text)\n", " text = re.sub(\"\\n\", \" \", text) # replace line breaks with spaces so that words don't merge\n", " text = re.sub(r\"\\w*\\d\\w*\", \"\", text)\n", "\n", " for w in word_tokenize(text):\n", " # We are applying some custom filtering here, feel free to try different things\n", " # Keep the token if it is not numeric, is longer than 2 characters, and is not a stop word\n", " if (not w.isnumeric()) and (len(w) > 2) and (w not in stop_words):\n", " filtered_sentence.append(w)\n", " final_string = \" \".join(filtered_sentence) # final string of cleaned words\n", "\n", " final_text_list.append(final_string)\n", "\n", " return final_text_list" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we read each of the `20_newsgroups` dataset files, apply the `strip_newsgroup_item` and `process_text` functions defined earlier, and aggregate all data into one DataFrame." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "frames = []\n", "\n", "for file in file_list:\n", " print(f\"Processing {file}\")\n", " label = file.split(\"/\")[1]\n", " df = pd.read_csv(file, header=None, names=[\"text\"])\n", " df[\"text\"] = df[\"text\"].apply(strip_newsgroup_item)\n", " df[\"text\"] = process_text(df[\"text\"].tolist())\n", " df[\"label\"] = label\n", " frames.append(df)\n", "\n", "# DataFrame.append was removed in pandas 2.0, so concatenate all frames once\n", "all_categories_df = pd.concat(frames, ignore_index=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's inspect how many categories there are in our dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_categories_df[\"label\"].value_counts()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The dataset has 20 categories, many of which are closely related subtopics, so we combine the subcategories into broader groups." 
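, "\n\nThe explicit mapping in the next cell follows the newsgroup naming hierarchy, so the same 20-to-6 merge can also be written as ordered prefix rules. The `PREFIX_RULES` list and `coarse_label` function below are a hypothetical alternative, not used elsewhere in this notebook; like `replace`, `coarse_label` leaves names that match no rule unchanged.

```python
# Ordered (prefix, group) rules; the first matching prefix wins.
PREFIX_RULES = [
    ('talk.politics.', 'politics'),
    ('talk.religion.', 'religion'),
    ('soc.religion.', 'religion'),
    ('alt.atheism', 'religion'),
    ('rec.', 'recreational'),
    ('comp.', 'computer'),
    ('misc.forsale', 'sales'),
    ('sci.', 'science'),
]


def coarse_label(newsgroup):
    for prefix, group in PREFIX_RULES:
        if newsgroup.startswith(prefix):
            return group
    return newsgroup  # unknown names pass through unchanged


for name in ['rec.sport.hockey', 'talk.religion.misc', 'comp.windows.x', 'misc.forsale']:
    print(name, '->', coarse_label(name))
```

With pandas, the same rules could be applied with `all_categories_df['label'].map(coarse_label)`; the explicit dictionary has the advantage that every mapping is visible at a glance."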
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "label_map = {\n", " # politics\n", " \"talk.politics.misc\": \"politics\",\n", " \"talk.politics.guns\": \"politics\",\n", " \"talk.politics.mideast\": \"politics\",\n", " # recreational\n", " \"rec.sport.hockey\": \"recreational\",\n", " \"rec.sport.baseball\": \"recreational\",\n", " \"rec.autos\": \"recreational\",\n", " \"rec.motorcycles\": \"recreational\",\n", " # religion\n", " \"soc.religion.christian\": \"religion\",\n", " \"talk.religion.misc\": \"religion\",\n", " \"alt.atheism\": \"religion\",\n", " # computer\n", " \"comp.windows.x\": \"computer\",\n", " \"comp.sys.ibm.pc.hardware\": \"computer\",\n", " \"comp.os.ms-windows.misc\": \"computer\",\n", " \"comp.graphics\": \"computer\",\n", " \"comp.sys.mac.hardware\": \"computer\",\n", " # sales\n", " \"misc.forsale\": \"sales\",\n", " # science\n", " \"sci.crypt\": \"science\",\n", " \"sci.electronics\": \"science\",\n", " \"sci.med\": \"science\",\n", " \"sci.space\": \"science\",\n", "}\n", "\n", "# Assign the result instead of calling replace(..., inplace=True) on a column, which newer pandas versions warn about\n", "all_categories_df[\"label\"] = all_categories_df[\"label\"].replace(label_map)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are now left with 6 categories." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_categories_df[\"label\"].value_counts()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's count the number of words in each row." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_categories_df[\"word_count\"] = all_categories_df[\"text\"].apply(lambda x: len(str(x).split()))\n", "all_categories_df.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's get basic statistics about the dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_categories_df[\"word_count\"].describe()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that the mean is around 86 words. However, there are outliers, such as a text with 6179 words, which can make it harder for the model to perform well. We will drop those rows." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's drop empty rows first." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "no_text = all_categories_df[all_categories_df[\"word_count\"] == 0]\n", "print(len(no_text))\n", "\n", "# drop these rows\n", "all_categories_df.drop(no_text.index, inplace=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let's drop the rows that are longer than 128 words, so that long outliers don't skew training." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "long_text = all_categories_df[all_categories_df[\"word_count\"] > 128]\n", "print(len(long_text))\n", "\n", "# drop these rows\n", "all_categories_df.drop(long_text.index, inplace=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_categories_df[\"label\"].value_counts()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's get basic statistics about the dataset after removing the outliers." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_categories_df[\"word_count\"].describe()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The word counts now fall within a much narrower range." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we drop the `word_count` column, as we no longer need it." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_categories_df.drop(columns=\"word_count\", inplace=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "all_categories_df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We partition the dataset into an 80% training set and a 20% validation set." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train, validation = train_test_split(all_categories_df, test_size=0.2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def save_to_csv_with_prefix(df, file_name):\n", " # Build one line per record, \"__label__<label> <text>\", as BlazingText expects\n", " records = (\"__label__\" + df[\"label\"] + \" \" + df[\"text\"]).tolist()\n", " print(len(records))\n", " # Write the records without modifying the input DataFrame\n", " with open(file_name, \"w\") as f:\n", " for element in records:\n", " f.write(element + \"\\n\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "save_to_csv_with_prefix(train, \"20_newsgroups.train\")\n", "save_to_csv_with_prefix(validation, \"20_newsgroups.validation\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at the first few lines of the training and validation files to see how the data and labels look after preprocessing." 
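, "\n\nTo sanity-check files like these before uploading them, a small parser can split each line back into its labels and tokens. `parse_blazingtext_line` and `check_lines` are hypothetical helpers written for this sketch; the check only verifies that every non-empty line starts with at least one `__label__` token, as the format described earlier requires.

```python
PREFIX = '__label__'


def parse_blazingtext_line(line):
    # Split a line into its leading __label__ tokens and the remaining text tokens.
    tokens = line.split()
    labels = []
    while tokens and tokens[0].startswith(PREFIX):
        labels.append(tokens.pop(0)[len(PREFIX):])
    return labels, tokens


def check_lines(lines):
    # Return the 1-based numbers of non-empty lines that carry no label.
    bad = []
    for number, line in enumerate(lines, start=1):
        if line.strip() and not parse_blazingtext_line(line)[0]:
            bad.append(number)
    return bad


sample = ['__label__science voyager stable trajectory', 'modem card without label']
print(parse_blazingtext_line(sample[0]))  # (['science'], ['voyager', 'stable', 'trajectory'])
print(check_lines(sample))  # [2]
```

Because iterating over an open file yields its lines, the same check can be run on `20_newsgroups.train` by passing it the file object; an empty result means every line carries a label."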
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!head -n 3 20_newsgroups.train" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!head -n 3 20_newsgroups.validation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we upload both files to S3 so that SageMaker training jobs can read them. We use the SageMaker Python SDK to upload them to the bucket and prefix set above." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_channel = prefix + \"/train\"\n", "validation_channel = prefix + \"/validation\"\n", "\n", "sess.upload_data(path=\"20_newsgroups.train\", bucket=bucket, key_prefix=train_channel)\n", "sess.upload_data(path=\"20_newsgroups.validation\", bucket=bucket, key_prefix=validation_channel)\n", "\n", "s3_train_data = \"s3://{}/{}\".format(bucket, train_channel)\n", "s3_validation_data = \"s3://{}/{}\".format(bucket, validation_channel)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we set an S3 output location where the training jobs write their model artifacts." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "s3_output_location = \"s3://{}/{}/output\".format(bucket, prefix)\n", "print(s3_output_location)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Set up hyperparameter tuning job\n", "With the setup complete, we can configure training. We first look up the BlazingText container image for the current region, and then create an `Estimator` object, which the tuning job uses to launch training jobs." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "region_name = boto3.Session().region_name" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "container = sagemaker.image_uris.retrieve(\"blazingtext\", region_name, \"1\")\n", "print(\"Using SageMaker BlazingText container: {} ({})\".format(container, region_name))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training the BlazingText model for supervised text classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "BlazingText supports a *supervised* mode for text classification. It extends the `fastText` text classifier to leverage GPU acceleration using custom `CUDA` kernels.\n", "The model can be trained on more than a billion words in a couple of minutes using a multi-core CPU or a GPU, while achieving performance on par with state-of-the-art deep learning text classification algorithms.\n", "For more information, see the [algorithm documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/blazingtext.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's define the SageMaker `Estimator` with the resource configuration and hyperparameters for training a text classifier on the `20 newsgroups` dataset in `supervised` mode on an `ml.c4.4xlarge` instance." 
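, "\n\nThe estimator in the next cell sets `early_stopping`, `patience`, and `min_epochs`. As a rough mental model only (the exact stopping rule that BlazingText applies is an assumption here), patience-based early stopping always trains for at least `min_epochs` epochs and then stops once validation accuracy has failed to improve for `patience` consecutive epochs. The hypothetical `stop_epoch` helper below simulates that rule on a made-up list of per-epoch accuracies.

```python
def stop_epoch(accuracies, min_epochs, patience):
    # Return the 1-based epoch after which training would stop.
    best = float('-inf')
    epochs_without_improvement = 0
    for epoch, acc in enumerate(accuracies, start=1):
        if acc > best:
            best = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        # Only consider stopping once min_epochs have run.
        if epoch >= min_epochs and epochs_without_improvement >= patience:
            return epoch
    return len(accuracies)  # early stopping never triggered: all epochs ran


history = [0.70, 0.75, 0.78, 0.80, 0.80, 0.79, 0.80, 0.79, 0.81]  # made-up values
print(stop_epoch(history, min_epochs=5, patience=4))  # 8
```

With `min_epochs=5` and `patience=4`, as in the estimator, this made-up history stops after epoch 8, because accuracy last improved at epoch 4."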
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "estimator = sagemaker.estimator.Estimator(\n", " container,\n", " role,\n", " instance_count=1,\n", " instance_type=\"ml.c4.4xlarge\",\n", " volume_size=30,\n", " max_run=360000,\n", " input_mode=\"File\",\n", " output_path=s3_output_location,\n", " hyperparameters={\n", " \"mode\": \"supervised\",\n", " \"epochs\": 25,\n", " \"min_count\": 2,\n", " \"early_stopping\": True,\n", " \"patience\": 4,\n", " \"min_epochs\": 5,\n", " \"word_ngrams\": 1,\n", " },\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once we've defined our estimator, we can specify the hyperparameters we'd like to tune and their possible values. There are three types of hyperparameter ranges:\n", "- Categorical parameters take one value from a discrete set. We define this by passing the list of possible values to `CategoricalParameter(list)`.\n", "- Continuous parameters can take any real value between the minimum and maximum value, defined by `ContinuousParameter(min, max)`.\n", "- Integer parameters can take any integer value between the minimum and maximum value, defined by `IntegerParameter(min, max)`.\n", "\n", "*Note: if possible, it's almost always best to specify a value as the least restrictive type. For example, tuning the learning rate as a continuous value between 0.01 and 0.2 is likely to yield a better result than tuning it as a categorical parameter with values 0.01, 0.1, 0.15, or 0.2.*\n", "\n", "Refer to [BlazingText Hyperparameters](https://docs.aws.amazon.com/sagemaker/latest/dg/blazingtext_hyperparameters.html) in the Amazon SageMaker documentation for the complete list of hyperparameters." 
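, "\n\nTo make the three range types concrete, here is a toy sampler. The `Categorical`, `Continuous`, and `Integer` classes are hypothetical stand-ins written for this sketch, not the SageMaker SDK classes, and uniform random sampling is only one possible strategy; the point is what values each type can produce. The `learning_rate` and `vector_dim` bounds mirror the ones used in the next cell, while the `word_ngrams` choices are made up for illustration.

```python
import random


class Categorical:
    def __init__(self, values):
        self.values = list(values)

    def sample(self, rng):
        return rng.choice(self.values)  # one value from a discrete set


class Continuous:
    def __init__(self, low, high):
        self.low, self.high = low, high

    def sample(self, rng):
        return rng.uniform(self.low, self.high)  # any real number in [low, high]


class Integer:
    def __init__(self, low, high):
        self.low, self.high = low, high

    def sample(self, rng):
        return rng.randint(self.low, self.high)  # any integer in [low, high], inclusive


ranges = {
    'learning_rate': Continuous(0.05, 0.15),
    'vector_dim': Integer(32, 300),
    'word_ngrams': Categorical([1, 2, 3]),
}
rng = random.Random(42)
config = {name: param.sample(rng) for name, param in ranges.items()}
print(config)
```"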
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.tuner import (\n", " IntegerParameter,\n", " CategoricalParameter,\n", " ContinuousParameter,\n", " HyperparameterTuner,\n", ")\n", "\n", "hyperparameter_ranges = {\n", " \"learning_rate\": ContinuousParameter(0.05, 0.15),\n", " \"vector_dim\": IntegerParameter(32, 300),\n", "}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we specify the objective metric that we'd like to optimize. For a custom algorithm, a metric definition also includes the regular expression (regex) needed to extract the metric from the training job's CloudWatch logs. Because we are using the built-in `BlazingText` algorithm in `supervised` mode, it already emits the predefined `validation:accuracy` metric (the `train:mean_rho` metric is reported only in the Word2Vec modes), so we only need to specify the metric name, not a regex. If you bring your own algorithm, it emits its own metrics; in that case, pass metric definitions (each a metric name plus a regex) through the tuner's `metric_definitions` argument, so that SageMaker knows how to extract those metrics from your CloudWatch logs." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "objective_metric_name = \"validation:accuracy\"\n", "objective_type = \"Maximize\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we'll create a `HyperparameterTuner` object, to which we pass:\n", "- The `BlazingText` estimator we created above\n", "- Our hyperparameter ranges\n", "- The objective metric name and objective type\n", "- Tuning resource limits: the total number of training jobs to run (`max_jobs`) and how many training jobs can run in parallel (`max_parallel_jobs`)." 
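, "\n\nA back-of-the-envelope check on the resource limits used in the next cell: assuming, for simplicity, that every training job takes about the same time, `max_jobs=6` with `max_parallel_jobs=2` needs at least ceil(6 / 2) = 3 sequential rounds, so the wall-clock time is roughly three training-job durations plus overhead. The functions below are hypothetical helpers, and the 4-minute job duration is a made-up placeholder.

```python
import math


def estimate_rounds(max_jobs, max_parallel_jobs):
    # Minimum number of sequential rounds if all jobs take equally long.
    return math.ceil(max_jobs / max_parallel_jobs)


def estimate_minutes(max_jobs, max_parallel_jobs, minutes_per_job):
    return estimate_rounds(max_jobs, max_parallel_jobs) * minutes_per_job


print(estimate_rounds(6, 2))  # 3
print(estimate_minutes(6, 2, minutes_per_job=4))  # 12, with a made-up 4 minutes per job
```

Raising `max_parallel_jobs` shortens the run, but with Bayesian search each new job can only learn from jobs that have already finished, so high parallelism reduces how much the search benefits from earlier results."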
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tuner = HyperparameterTuner(\n", " estimator,\n", " objective_metric_name,\n", " hyperparameter_ranges,\n", " max_jobs=6,\n", " max_parallel_jobs=2,\n", " objective_type=objective_type,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With the hyperparameter ranges set up, we connect our data channels to the algorithm. To do this, we create a `sagemaker.inputs.TrainingInput` object for each data channel and put these objects in a dictionary, which the algorithm consumes." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_data = sagemaker.inputs.TrainingInput(\n", " s3_train_data,\n", " distribution=\"FullyReplicated\",\n", " content_type=\"text/plain\",\n", " s3_data_type=\"S3Prefix\",\n", ")\n", "validation_data = sagemaker.inputs.TrainingInput(\n", " s3_validation_data,\n", " distribution=\"FullyReplicated\",\n", " content_type=\"text/plain\",\n", " s3_data_type=\"S3Prefix\",\n", ")\n", "data_channels = {\"train\": train_data, \"validation\": validation_data}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now have our `Estimator` object with its hyperparameters, the hyperparameter ranges to tune, and our data channels linked to the algorithm. The only remaining step is to run the tuning job, which launches the individual training jobs. Each training job involves a few steps. First, the instance that we requested when creating the `Estimator` is provisioned and set up with the appropriate libraries. Then, the data from our channels is downloaded to the instance. Once this is done, training begins. Provisioning and downloading the data take some time, depending on the size of the data, so it might be a few minutes before the training jobs start producing logs. 
Once a training job has run `min_epochs` epochs, its logs also report the accuracy on the validation data for every epoch; this metric is a proxy for the quality of the model.\n", "\n", "When a training job finishes, its model artifacts are written to the S3 location set as `output_path` in the estimator." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Launch hyperparameter tuning job\n", "Now we can launch the hyperparameter tuning job by calling the `fit()` method. After the hyperparameter tuning job is created, we can track its progress in the SageMaker console until it completes.\n", "\n", "This should take around 12 minutes to complete." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "tuner.fit(inputs=data_channels, logs=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Analyze Results of a Hyperparameter Tuning job\n", "\n", "Once you have completed a tuning job (or even while the job is still running), you can use the code below to analyze the results and understand how each hyperparameter affects the quality of the model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sm_client = boto3.Session().client(\"sagemaker\")\n", "\n", "tuning_job_name = tuner.latest_tuning_job.name\n", "tuning_job_name" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Track hyperparameter tuning job progress\n", "After you launch a tuning job, you can check its progress by calling the `describe_hyper_parameter_tuning_job` API. The response is a dictionary that describes the current state of the tuning job. You can call `list_training_jobs_for_hyper_parameter_tuning_job` to get a detailed list of the training jobs that the tuning job launched." 
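, "\n\nIf you prefer to poll from the notebook instead of the console, a loop like the one below waits until the tuning job reaches a terminal state. It is a sketch: `wait_for_tuning_job` is a hypothetical helper that takes the describe function as a parameter, so it can be tried offline with a stub; in this notebook you would pass `sm_client.describe_hyper_parameter_tuning_job`. It treats the documented `HyperParameterTuningJobStatus` values `Completed`, `Failed`, and `Stopped` as terminal.

```python
import time

TERMINAL_STATES = {'Completed', 'Failed', 'Stopped'}


def wait_for_tuning_job(describe, job_name, poll_seconds=60, sleep=time.sleep):
    # Poll until the tuning job reaches a terminal state, then return the last response.
    while True:
        result = describe(HyperParameterTuningJobName=job_name)
        status = result['HyperParameterTuningJobStatus']
        counters = result.get('TrainingJobStatusCounters', {})
        print(status, 'completed training jobs:', counters.get('Completed', 0))
        if status in TERMINAL_STATES:
            return result
        sleep(poll_seconds)
```

For example: `final = wait_for_tuning_job(sm_client.describe_hyper_parameter_tuning_job, tuning_job_name)`."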
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tuning_job_result = sm_client.describe_hyper_parameter_tuning_job(\n", " HyperParameterTuningJobName=tuning_job_name\n", ")\n", "\n", "status = tuning_job_result[\"HyperParameterTuningJobStatus\"]\n", "if status != \"Completed\":\n", " print(\"Reminder: the tuning job has not been completed.\")\n", "\n", "job_count = tuning_job_result[\"TrainingJobStatusCounters\"][\"Completed\"]\n", "print(\"%d training jobs have completed\" % job_count)\n", "\n", "is_minimize = (\n", " tuning_job_result[\"HyperParameterTuningJobConfig\"][\"HyperParameterTuningJobObjective\"][\"Type\"]\n", " != \"Maximize\"\n", ")\n", "objective_name = tuning_job_result[\"HyperParameterTuningJobConfig\"][\n", " \"HyperParameterTuningJobObjective\"\n", "][\"MetricName\"]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pprint import pprint\n", "\n", "if tuning_job_result.get(\"BestTrainingJob\", None):\n", " print(\"Best model found so far:\")\n", " pprint(tuning_job_result[\"BestTrainingJob\"])\n", "else:\n", " print(\"No training jobs have reported results yet.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Fetch all results as `DataFrame`\n", "We can list the hyperparameters and objective metrics of all training jobs and pick out the training job with the best objective metric." 
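, "\n\nThe analytics helper in the next cell builds this table for us. To show what picking the best job amounts to, the sketch below works on records shaped like the `TrainingJobSummaries` returned by boto3's `list_training_jobs_for_hyper_parameter_tuning_job`, where each job that reported a result carries a `FinalHyperParameterTuningJobObjectiveMetric` with a `Value`. `best_job` is a hypothetical helper, the sample records are made up, and jobs without a final metric (for example, failed ones) are skipped.

```python
def best_job(summaries, minimize=False):
    # Keep only jobs that reported a final objective value.
    scored = [s for s in summaries if 'FinalHyperParameterTuningJobObjectiveMetric' in s]
    if not scored:
        return None

    def value(summary):
        return summary['FinalHyperParameterTuningJobObjectiveMetric']['Value']

    return min(scored, key=value) if minimize else max(scored, key=value)


summaries = [  # made-up example records
    {'TrainingJobName': 'job-001',
     'TunedHyperParameters': {'learning_rate': '0.07', 'vector_dim': '64'},
     'FinalHyperParameterTuningJobObjectiveMetric': {'MetricName': 'validation:accuracy', 'Value': 0.81}},
    {'TrainingJobName': 'job-002',
     'TunedHyperParameters': {'learning_rate': '0.12', 'vector_dim': '250'},
     'FinalHyperParameterTuningJobObjectiveMetric': {'MetricName': 'validation:accuracy', 'Value': 0.84}},
    {'TrainingJobName': 'job-003', 'TrainingJobStatus': 'Failed'},
]
print(best_job(summaries)['TrainingJobName'])  # job-002
```

With real data, collect the summaries across pages (the list call returns a `NextToken` when more results remain) and pass `minimize=is_minimize` to respect the objective type read from the tuning job description."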
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "tuner_analytics = sagemaker.HyperparameterTuningJobAnalytics(tuning_job_name)\n", "\n", "full_df = tuner_analytics.dataframe()\n", "df = full_df # stays empty if no training jobs have reported yet\n", "\n", "if len(full_df) > 0:\n", " df = full_df[full_df[\"FinalObjectiveValue\"] > -float(\"inf\")]\n", " if len(df) > 0:\n", " df = df.sort_values(\"FinalObjectiveValue\", ascending=is_minimize)\n", " print(\"Number of training jobs with valid objective: %d\" % len(df))\n", " print({\"lowest\": min(df[\"FinalObjectiveValue\"]), \"highest\": max(df[\"FinalObjectiveValue\"])})\n", " pd.set_option(\"display.max_colwidth\", None) # Don't truncate TrainingJobName\n", " else:\n", " print(\"No training jobs have reported valid results yet.\")\n", "\n", "df" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Deploy the best trained model\n", "Once tuning is done, we can deploy the best model found by the tuning job as an Amazon SageMaker real-time hosted endpoint, which lets us make predictions (inference) with the model. We don't have to host on the same instance type that we used for training: inference usually needs less compute power than training, and because an endpoint stays up for a long time, it's advisable to choose a cheaper instance for inference.\n", "\n", "- `ml.c4.4xlarge` - Compute Optimized instances are ideal for compute-bound applications that benefit from high-performance processors.\n", "- `ml.m4.xlarge` - General purpose instances provide a balance of compute, memory, and networking resources, and can be used for a variety of workloads." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sagemaker.serializers import JSONSerializer\n", "\n", "text_classifier = tuner.deploy(\n", " initial_instance_count=1, instance_type=\"ml.m4.xlarge\", serializer=JSONSerializer()\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Use JSON format for inference\n", "BlazingText supports `application/json` as the content type for inference. The payload sent to the endpoint is a JSON object whose `instances` key holds the list of sentences to classify." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sentences = [\n", " \"The modem is an internal AT/(E)ISA 8-bit card (just a little longer than a half-card).\",\n", " \"In the cage I usually wave to bikers. They usually don't wave back. My wife thinks it's strange but I don't care.\",\n", " \"Voyager has the unusual luck to be on a stable trajectory out of the solar system.\",\n", "]\n", "\n", "# Apply the same text processing that we used to prepare the training data\n", "processed_sentences = process_text(sentences)\n", "\n", "print(processed_sentences)\n", "\n", "payload = {\"instances\": processed_sentences}\n", "\n", "response = text_classifier.predict(payload)\n", "\n", "predictions = json.loads(response)\n", "print(json.dumps(predictions, indent=2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By default, the model returns only one prediction: the one with the highest probability. 
To retrieve the top k predictions, set `k` in the configuration as shown below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Request the 2 most probable labels for each sentence\n", "payload = {\"instances\": processed_sentences, \"configuration\": {\"k\": 2}}\n", "\n", "response = text_classifier.predict(payload)\n", "\n", "predictions = json.loads(response)\n", "print(json.dumps(predictions, indent=2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Clean up\n", "Delete the endpoint when it is no longer in use: per the [SageMaker pricing page](https://aws.amazon.com/sagemaker/pricing/), endpoints are billed for as long as they are deployed.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "text_classifier.delete_endpoint()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Notebook CI Test Results\n", "\n", "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2, which is shown at the top of the notebook.\n", "\n", "![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n", "\n", "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n", "\n", "![This us-west-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n", "\n", "![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n", "\n", "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n", "\n", "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n", "\n", "![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n", "\n", "![This eu-west-3 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n", "\n", "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n", "\n", "![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n", "\n", "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n", "\n", "![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n", "\n", "![This ap-northeast-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n", "\n", "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n", "\n", "![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/hyperparameter_tuning|blazingtext_text_classification_20_newsgroups|hpo_blazingtext_text_classification_20_newsgroups.ipynb)\n" ] } ], "metadata": { "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (Data Science)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" }, "notice": "Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License." }, "nbformat": 4, "nbformat_minor": 4 }