{
"cells": [
{
"cell_type": "markdown",
"id": "a0293b07-d50b-440f-a8e9-255aa925a6d2",
"metadata": {
"tags": []
},
"source": [
"# Wine Quality Classification"
]
},
{
"cell_type": "markdown",
"id": "3a464362-a480-4af5-9da9-80d1d7fb32aa",
"metadata": {},
"source": [
"<table align=\"left\">\n",
"<td>\n",
"<a href=\"https://github.com/GoogleCloudPlatform/ai-ml-recipes/blob/main/notebooks/classification/logistic_regression/wine_quality_classification_mlr.ipynb\">\n",
"<img src=\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\" alt=\"GitHub logo\">\n",
"View on GitHub\n",
"</a>\n",
"</td>\n",
"<td>\n",
"<a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/ai-ml-recipes/main/notebooks/classification/logistic_regression/wine_quality_classification_mlr.ipynb\">\n",
"<img src=\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\" alt=\"Vertex AI logo\">\n",
"Open in Vertex AI Workbench\n",
"</a>\n",
"</td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "7bb9457a-4696-4ebb-bb20-697436d93e82",
"metadata": {},
"source": [
"## Overview"
]
},
{
"cell_type": "markdown",
"id": "e59039a0-b3c3-4023-830f-662605bea056",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
"### Multinomial Logistic Regression\n",
"\n",
"<strong>Type</strong>: Classification\n",
"\n",
"<strong>UCI Open Source Dataset</strong>: [Wine Quality](https://archive.ics.uci.edu/dataset/186/wine+quality)\n",
"\n",
"The dataset contains red and white vinho verde wine samples from the north of Portugal, along with quality scores and the results of physicochemical tests ([Cortez et al., 2009](http://www3.dsi.uminho.pt/pcortez/wine/)). This notebook uses the white wine subset.\n",
"\n",
"<strong>Problem</strong>: Imagine you are a wine specialist who is looking for an automated way to categorize the wines you find based on wine quality data from physicochemical tests. You could use a machine learning algorithm to train a model that would be able to predict the quality of a wine based on its physicochemical properties. This would allow you to quickly and easily categorize new wines that you find, without having to manually taste them.\n",
"\n",
"Here are some of the benefits of using an automated wine categorization system:\n",
"\n",
"- <strong>Speed</strong>: An automated system can categorize wines much faster than a human can. This is especially beneficial for wine retailers and distributors who need to quickly categorize large numbers of wines.\n",
"- <strong>Accuracy</strong>: An automated system can be more accurate than a human when it comes to categorizing wines. This is because the system is not influenced by personal biases or preferences.\n",
"- <strong>Consistency</strong>: An automated system will consistently categorize wines in the same way, which can help to ensure that customers are getting the wines they expect.\n",
"\n",
"If you are a wine specialist who is looking for an efficient and accurate way to categorize wines, then an automated system may be the perfect solution for you."
]
},
{
"cell_type": "markdown",
"id": "1f8e5893-eb68-4721-b083-da5b40f78ac4",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5377b0f3-bfdd-4153-a4b4-889962048eb4",
"metadata": {},
"outputs": [],
"source": [
"from pyspark.sql import SparkSession\n",
"from pyspark.sql.functions import col, countDistinct, isnan, when, count, udf\n",
"from pyspark.sql.types import StringType\n",
"\n",
"from pyspark.mllib.stat import Statistics\n",
"from pyspark.ml.feature import StringIndexer, StandardScaler\n",
"from pyspark.ml.classification import LogisticRegression\n",
"from pyspark.ml.linalg import Vectors"
]
},
{
"cell_type": "markdown",
"id": "16aae3c7-b31d-42d6-bbe1-defe4057b24f",
"metadata": {},
"source": [
"### Load dataset from public metastore"
]
},
{
"cell_type": "markdown",
"id": "7cd099d7-9b79-412f-9cc6-421b281b2d14",
"metadata": {
"tags": []
},
"source": [
"## Exploratory Data Analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5eee0eaa-4890-4a7a-bd96-fb59ee658b38",
"metadata": {},
"outputs": [],
"source": [
"spark = SparkSession.builder \\\n",
" .appName(\"Multinomial logistic regression Wine Quality\") \\\n",
" .enableHiveSupport() \\\n",
" .getOrCreate()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7aefa1a-12f7-41cd-81d6-f94610f157b7",
"metadata": {},
"outputs": [],
"source": [
"df = spark.read.option(\"header\", True).csv(\"gs://dataproc-metastore-public-binaries/winequality_white/\")"
]
},
{
"cell_type": "markdown",
"id": "0690c904",
"metadata": {
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"|fixed_acidity|volatile_acidity|citric_acid|residual_sugar|chlorides|free_sulfur_dioxide|total_sulfur_dioxide|density| pH|sulphates|alcohol|quality|\n",
"|-------------|----------------|-----------|--------------|---------|-------------------|--------------------|-------|----|---------|-------|-------|\n",
"| 7.0| 0.27| 0.36| 20.7| 0.045| 45.0| 170.0| 1.001| 3.0| 0.45| 8.8| 6|\n",
"| 6.3| 0.3| 0.34| 1.6| 0.049| 14.0| 132.0| 0.994| 3.3| 0.49| 9.5| 6|\n",
"| 8.1| 0.28| 0.4| 6.9| 0.05| 30.0| 97.0| 0.9951|3.26| 0.44| 10.1| 6|\n",
"| 7.2| 0.23| 0.32| 8.5| 0.058| 47.0| 186.0| 0.9956|3.19| 0.4| 9.9| 6|\n",
"| 7.2| 0.23| 0.32| 8.5| 0.058| 47.0| 186.0| 0.9956|3.19| 0.4| 9.9| 6|"
]
},
{
"cell_type": "markdown",
"id": "ee393d03-4f05-42ce-b9bd-d773a57f6494",
"metadata": {},
"source": [
"### DataFrame Column Data Types"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b1fab229-deb4-49e8-8754-3b03f707a098",
"metadata": {},
"outputs": [],
"source": [
"df = df.select(*(col(c).cast(\"float\").alias(c) for c in df.columns))\n",
"df = df.withColumn(\"quality\", col(\"quality\").cast(\"int\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3649f263-ea02-4660-91bf-5cec489dd1fd",
"metadata": {},
"outputs": [],
"source": [
"df.printSchema()"
]
},
{
"cell_type": "markdown",
"id": "329201b2-9b22-4c41-940f-cefdfc92b525",
"metadata": {},
"source": [
"### Summary Statistics\n",
"\n",
"At this point, all columns contain numerical values. For numerical features, we are often interested in summary statistics such as the count, mean, standard deviation, minimum, and maximum."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b9378955-696c-4a1f-a44f-6401dbd5d5e2",
"metadata": {},
"outputs": [],
"source": [
"df.describe().show(5,8)"
]
},
{
"cell_type": "markdown",
"id": "ff21c0b9-1b62-4e4c-b9c8-673d802ef0c5",
"metadata": {},
"source": [
"Let's take a closer look at the target column by using the `.groupby()` function."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f60cd56e-1d9f-43f3-84f4-8a0b696a620b",
"metadata": {},
"outputs": [],
"source": [
"# Count samples per quality score; show up to 10 rows so no score is cut off\n",
"df.groupby('quality')\\\n",
"    .count()\\\n",
"    .orderBy('quality')\\\n",
"    .show(10, 50)"
]
},
{
"cell_type": "markdown",
"id": "658e23ae-3904-4fdc-99b7-539cf4aa5609",
"metadata": {},
"source": [
"We can see here that the data is <b>imbalanced</b> for our target. <b>Imbalanced</b> data is a common problem in machine learning, where the number of samples in one class is much larger than the number of samples in another class. This can make it difficult to train a model that can accurately predict the minority class. There are a number of techniques that can be used to handle imbalanced data, including:\n",
"\n",
"- <b>Resampling</b>: This involves increasing the number of samples in the minority class or decreasing the number of samples in the majority class. This can be done by oversampling the minority class (creating new samples), undersampling the majority class (removing samples), or a combination of both.\n",
"- <b>Cost-sensitive learning</b>: This involves assigning different costs to misclassifications of different classes. This can help to focus the model on correctly classifying the minority class.\n",
"- <b>Ensemble learning</b>: This involves training multiple models on different subsets of the data and then combining the predictions of the models. This can help to improve the accuracy of the model on the minority class."
]
},
{
"cell_type": "markdown",
"id": "101507c8-b4ab-432d-8b0f-0eaeadbc55d7",
"metadata": {},
"source": [
"We will <b>resample</b> the data to balance the dataset. Before doing so, we need to check the data for issues, such as missing or corrupted values, and resolve any that we find. Once the data is clean, we can resample it."
]
},
{
"cell_type": "markdown",
"id": "8ddec5b6-35ec-4142-9045-4a3e049fdec2",
"metadata": {},
"source": [
"### Let's summarize our data by row, column, features, unique, and missing values."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e6a4a9c8-11cf-4652-91ce-56af6232bae5",
"metadata": {},
"outputs": [],
"source": [
"print(\"Rows     : \", df.count())\n",
"print(\"Columns  : \", len(df.columns))\n",
"print(\"\\nFeatures : \\n\", df.columns)\n",
"\n",
"# show() prints the table itself and returns None, so call it on its own line\n",
"print(\"\\nUnique values : \")\n",
"expression = [countDistinct(c).alias(c) for c in df.columns]\n",
"df.select(*expression).show()\n",
"\n",
"# Count NaN or null entries per column\n",
"print(\"\\nMissing values : \")\n",
"df.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in df.columns]).show()"
]
},
{
"cell_type": "markdown",
"id": "90f22c89-83b6-4820-a72e-5eb12e393ee1",
"metadata": {},
"source": [
"There are no missing values or other data issues, so we can resample the data."
]
},
{
"cell_type": "markdown",
"id": "7456c90b-303f-411f-8299-9857ca1c1fea",
"metadata": {},
"source": [
"### Distribution of Features"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e9ee5820-cc29-410e-9d69-bbd774248e95",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Convert to pandas once instead of on every loop iteration\n",
"pdf = df.toPandas()\n",
"\n",
"fig = plt.figure(figsize=(20,10))\n",
"st = fig.suptitle(\"Distribution of Features\",\n",
"                  fontsize=20,\n",
"                  verticalalignment='center')\n",
"\n",
"# One histogram per column on a 3x4 grid; the loop variable is not named `col`\n",
"# so that pyspark.sql.functions.col is not shadowed\n",
"for num, name in enumerate(pdf.columns, start=1):\n",
"    ax = fig.add_subplot(3,4,num)\n",
"    ax.hist(pdf[name])\n",
"    plt.grid(False)\n",
"    plt.xticks(rotation=45,fontsize=10)\n",
"    plt.yticks(fontsize=10)\n",
"    plt.title(name.upper(),fontsize=20)\n",
"\n",
"plt.tight_layout()\n",
"st.set_y(0.95)\n",
"fig.subplots_adjust(top=0.85,hspace = 0.4)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "a040f021-3f5b-40e5-98ed-a5069f133578",
"metadata": {},
"source": [
"A large share of the features appear to follow a roughly normal distribution. The normal distribution, also known as the Gaussian distribution, is a continuous probability distribution that is widely used in statistical modeling and machine learning. It is a bell-shaped curve that is symmetrical around its mean and is characterized by its mean and standard deviation.\n",
"\n",
"Logistic regression does not assume that the features are normally distributed. Still, features that are roughly symmetric and free of extreme outliers are convenient for model building: they behave well under standardization and tend to make optimization more stable. Some other methods, such as linear discriminant analysis, do assume multivariate normal features."
]
},
{
"cell_type": "markdown",
"id": "24e7c3c3-c5b4-43e1-a283-bcaa9dc232a9",
"metadata": {},
"source": [
"### Pearson Correlation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "afe091c9-b28f-43ba-aa47-edfb8abb0acd",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"col_names = df.columns\n",
"features = df.rdd.map(lambda row: row[0:])\n",
"corr_mat= Statistics.corr(features, method=\"pearson\")\n",
"corr_df = pd.DataFrame(corr_mat)\n",
"corr_df.index, corr_df.columns = col_names, col_names\n",
"\n",
"corr_df"
]
},
{
"cell_type": "markdown",
"id": "babc3903-8764-4766-97dd-80f690da76b0",
"metadata": {},
"source": [
"A Pearson correlation coefficient of 0.7 or greater is generally considered to be a strong correlation. This means that there is a high degree of linear relationship between the two variables. For example, if the correlation coefficient between height and weight is 0.7, this means that taller people tend to be heavier, on average.\n",
"\n",
"However, it is important to note that the strength of a correlation coefficient can vary depending on the specific variables being measured. For example, a correlation coefficient of 0.7 might be considered strong for some variables, but weak for others.\n",
"\n",
"Here we can see that <b>residual_sugar</b> has a <b>strong correlation</b> with <b>density</b>."
]
},
{
"cell_type": "markdown",
"id": "637ddba0-5fae-4383-ba2d-cc3461555216",
"metadata": {},
"source": [
"## Feature engineering"
]
},
{
"cell_type": "markdown",
"id": "0aa2e35c-7306-4e3a-a6da-ced25b3a4271",
"metadata": {},
"source": [
"<b>Feature engineering</b> is the process of transforming raw data into features that are more informative and useful for machine learning algorithms. This can involve a variety of tasks, such as:\n",
"\n",
"- <b>Data transformation</b>: This involves transforming the data into a format that is more suitable for machine learning algorithms. For example, categorical data can be encoded as numerical data, and continuous data can be discretized.\n",
"- <b>Feature selection</b>: This involves selecting the most important features from the data set. This can be done using a variety of techniques, such as statistical significance tests and feature importance scores.\n",
"- <b>Feature creation</b>: This involves creating new features from the existing data. This can be done by combining existing features, or by creating derived features that are based on the relationships between different features."
]
},
{
"cell_type": "markdown",
"id": "e62cde88-e34d-4fe7-8958-b11595e1c22b",
"metadata": {},
"source": [
"### Bucketing a quality label"
]
},
{
"cell_type": "markdown",
"id": "d1da8857-0aac-446a-8334-700846c50a38",
"metadata": {},
"source": [
"- Poor for less than 4\n",
"- Normal between 4 and 5 (inclusive)\n",
"- Good between 6 and 7 (inclusive)\n",
"- Excellent for more than 7"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e3a493c4-890c-4956-a5b7-15441a8f55e6",
"metadata": {},
"outputs": [],
"source": [
"@udf(returnType=StringType())\n",
"def quality_score_to_label(score: int):\n",
"    # Bucket the numeric quality score into one of four labels\n",
"    if score >= 0 and score < 4:\n",
"        return 'poor'\n",
"    elif score >= 4 and score <= 5:\n",
"        return 'normal'\n",
"    elif score >= 6 and score <= 7:\n",
"        return 'good'\n",
"    else:\n",
"        return 'excellent'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "36151911-2585-4bd4-b9a4-b8ba4269f4c8",
"metadata": {},
"outputs": [],
"source": [
"df.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f5f0363a-ba47-4ea1-b068-e7863f3e7d96",
"metadata": {},
"outputs": [],
"source": [
"df.printSchema()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9731bd6-51b9-4dbc-ab25-10aed0667f14",
"metadata": {},
"outputs": [],
"source": [
"df_buc = df.withColumn('quality_label', quality_score_to_label('quality'))\n",
"df_buc = df_buc.drop('quality')\n",
"df_buc.show()"
]
},
{
"cell_type": "markdown",
"id": "36d2123e-de88-486e-b3ba-c61dd3878397",
"metadata": {},
"source": [
"### LabelIndexer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5705db46-a58c-4387-a11d-bf86ce02b07c",
"metadata": {},
"outputs": [],
"source": [
"label_indexer = StringIndexer()\\\n",
" .setInputCol (\"quality_label\")\\\n",
" .setOutputCol (\"quality\")\n",
"\n",
"label_indexer_model = label_indexer.fit(df_buc)\n",
"label_indexer_df = label_indexer_model.transform(df_buc)\n",
"\n",
"label_indexer_df = label_indexer_df.drop('quality_label')\n",
"\n",
"label_indexer_df.show(5,50)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "84086f8f-1c9d-4dea-8370-8136607f0c48",
"metadata": {},
"outputs": [],
"source": [
"label_indexer_df.groupby('quality').\\\n",
" count().\\\n",
" show(5,50)"
]
},
{
"cell_type": "markdown",
"id": "9014e0be-2875-4a65-9346-a2acb841df1e",
"metadata": {},
"source": [
"### Resampling [Optional step]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "abb63754-1120-4e01-8695-a6f0bf4ab59d",
"metadata": {},
"outputs": [],
"source": [
"from imblearn.over_sampling import SMOTE\n",
"\n",
"# Convert to pandas once; the indexed label column `quality` is the last column\n",
"pdf_indexed = label_indexer_df.toPandas()\n",
"X, y = pdf_indexed.iloc[:, :-1], pdf_indexed.iloc[:, [-1]]\n",
"\n",
"# Oversample the minority classes with synthetic samples\n",
"sm = SMOTE(k_neighbors=6)\n",
"X_res, y_res = sm.fit_resample(X, y)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b65f558-175e-4f2f-94f4-ce450143d5ae",
"metadata": {},
"outputs": [],
"source": [
"df_res = spark.createDataFrame(pd.concat([X_res, y_res], axis=1))\n",
"df_res.show(5,50)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "85f8830c-41e0-45e8-8fd1-6e699ec47824",
"metadata": {},
"outputs": [],
"source": [
"df_res.groupby('quality').\\\n",
" count().\\\n",
" show(5,50)"
]
},
{
"cell_type": "markdown",
"id": "1ecb7ca4-83cd-4c53-b292-92d7fb020b82",
"metadata": {},
"source": [
"### Vectorizing to Prepare Data for Machine Learning"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8817295-3915-45e1-aee0-e4176964b296",
"metadata": {},
"outputs": [],
"source": [
"df_vec= df_res.rdd.map(lambda x:(Vectors.dense(x[0:-1]), x[-1])).toDF([\"vectorized_features\", \"label\"])\n",
"df_vec.show(5,50)"
]
},
{
"cell_type": "markdown",
"id": "306b45d3-75d8-44ef-885c-694af2dcf9e7",
"metadata": {},
"source": [
"### Standardization"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b64d1119-55dc-4586-a6d3-e85d4c0b9ab8",
"metadata": {},
"outputs": [],
"source": [
"scaler = StandardScaler()\\\n",
" .setInputCol (\"vectorized_features\")\\\n",
" .setOutputCol (\"features\")\n",
" \n",
"scaler_model = scaler.fit(df_vec)\n",
"scaler_df = scaler_model.transform(df_vec)\n",
"\n",
"scaler_df = scaler_df.select('features', 'label')\n",
"scaler_df.show(5,50)"
]
},
{
"cell_type": "markdown",
"id": "c9f8e115-9797-490f-a49a-5b577b492b6b",
"metadata": {},
"source": [
"## Model Choice"
]
},
{
"cell_type": "markdown",
"id": "74f15d40-f150-4ade-a1d4-2e87c974420b",
"metadata": {},
"source": [
"<b>Multinomial logistic regression</b> is a type of logistic regression that can be used for multi-class classification problems. In the case of wine quality classification, there are 4 classes (poor, normal, good, and excellent), so multinomial logistic regression is a good choice for modeling this problem.\n",
"\n",
"The physicochemical tests can be used to measure the various properties of wine, such as acidity, alcohol content, and sugar content. These properties can then be used as features in the multinomial logistic regression model.\n",
"\n",
"Here are some of the advantages of using multinomial logistic regression for wine quality classification:\n",
"\n",
"- It is a relatively simple model that is easy to understand and interpret.\n",
"- It is a very flexible model that can be used to model a variety of different types of data.\n",
"- It is a very efficient model that can be estimated quickly and easily.\n",
"\n",
"Here are some of the disadvantages of using multinomial logistic regression for wine quality classification:\n",
"\n",
"- It may not be as accurate as some other models, such as support vector machines or decision trees.\n",
"- It may not be able to capture the nonlinear relationships between the features and the class labels.\n",
"\n",
"In addition to multinomial logistic regression, there are a number of other models that could be used for wine quality classification. Some of these other models include support vector machines, decision trees, and random forests. However, multinomial logistic regression is a good starting point for wine quality classification because it is a simple, flexible, and efficient model.\n"
]
},
{
"cell_type": "markdown",
"id": "88e5a896-ebf7-41d1-9087-ed0e99e2161d",
"metadata": {},
"source": [
"## Model Training"
]
},
{
"cell_type": "markdown",
"id": "ec542669-b32b-4874-bdc9-b88dc799da1c",
"metadata": {},
"source": [
"### Train/Test Split data "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6143644-a1ce-4cc5-aa53-5f80975ec77e",
"metadata": {},
"outputs": [],
"source": [
"# Split training and test data\n",
"training, test = scaler_df.randomSplit([0.8, 0.2])\n",
"print (\"Training instances\", training.count(), \"Test instances\", test.count())"
]
},
{
"cell_type": "markdown",
"id": "5d352991-de8d-4004-b99f-970c78540180",
"metadata": {
"tags": []
},
"source": [
"### Model Training phase"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1359b31a-56b6-4cda-8005-f5cdf86532d6",
"metadata": {},
"outputs": [],
"source": [
"lr = LogisticRegression(\n",
" family=\"multinomial\", \n",
" featuresCol = 'features', \n",
" labelCol = 'label',\n",
" maxIter=200,\n",
" elasticNetParam=1.0, \n",
" tol=1e-6, \n",
" standardization=True,\n",
" fitIntercept=True\n",
")\n",
" \n",
"# Fit the model\n",
"lrModel = lr.fit(training)\n",
"\n",
"# Print the coefficients and intercept for multinomial logistic regression\n",
"print(\"Coefficients: \\n\" + str(lrModel.coefficientMatrix))\n",
"print(\"Intercept: \" + str(lrModel.interceptVector))\n",
"\n",
"trainingSummary = lrModel.summary\n",
"\n",
"# for multiclass, we can inspect metrics on a per-label basis\n",
"print(\"False positive rate by label:\")\n",
"for i, rate in enumerate(trainingSummary.falsePositiveRateByLabel):\n",
" print(\"label %d: %s\" % (i, rate))\n",
"\n",
"print(\"True positive rate by label:\")\n",
"for i, rate in enumerate(trainingSummary.truePositiveRateByLabel):\n",
" print(\"label %d: %s\" % (i, rate))\n",
"\n",
"print(\"Precision by label:\")\n",
"for i, prec in enumerate(trainingSummary.precisionByLabel):\n",
" print(\"label %d: %s\" % (i, prec))\n",
"\n",
"print(\"Recall by label:\")\n",
"for i, rec in enumerate(trainingSummary.recallByLabel):\n",
" print(\"label %d: %s\" % (i, rec))\n",
"\n",
"print(\"F-measure by label:\")\n",
"for i, f in enumerate(trainingSummary.fMeasureByLabel()):\n",
" print(\"label %d: %s\" % (i, f))\n",
"\n",
"accuracy = trainingSummary.accuracy\n",
"falsePositiveRate = trainingSummary.weightedFalsePositiveRate\n",
"truePositiveRate = trainingSummary.weightedTruePositiveRate\n",
"fMeasure = trainingSummary.weightedFMeasure()\n",
"precision = trainingSummary.weightedPrecision\n",
"recall = trainingSummary.weightedRecall\n",
"print(\"Accuracy: %s\\nFPR: %s\\nTPR: %s\\nF-measure: %s\\nPrecision: %s\\nRecall: %s\"\n",
" % (accuracy, falsePositiveRate, truePositiveRate, fMeasure, precision, recall))"
]
},
{
"cell_type": "markdown",
"id": "c01713d7-4e24-493f-a56b-a76a7e31c473",
"metadata": {},
"source": [
"## Model Evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "526b2585-a6d4-406a-806e-abece14959fd",
"metadata": {},
"outputs": [],
"source": [
"predictions = lrModel.transform(test)"
]
},
{
"cell_type": "markdown",
"id": "e2112684-c70c-4500-887b-4a2e5a9c2fff",
"metadata": {
"tags": []
},
"source": [
"### Confusion Matrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc951bd4-be34-4ec0-bfb9-f0e84365295d",
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"import numpy as np\n",
"\n",
"# Class indices produced by the StringIndexer\n",
"class_names = [0, 1, 2, 3]\n",
"\n",
"def plot_confusion_matrix(cm, classes,\n",
" normalize=False,\n",
" title='Confusion matrix',\n",
" cmap=plt.cm.Blues):\n",
" if normalize:\n",
" cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
" print(\"Normalized confusion matrix\")\n",
" else:\n",
" print('Confusion matrix, without normalization')\n",
"\n",
" print(cm)\n",
"\n",
" plt.imshow(cm, interpolation='nearest', cmap=cmap)\n",
" plt.title(title)\n",
" plt.colorbar()\n",
" tick_marks = np.arange(len(classes))\n",
" plt.xticks(tick_marks, classes, rotation=45)\n",
" plt.yticks(tick_marks, classes)\n",
"\n",
" fmt = '.2f' if normalize else 'd'\n",
" thresh = cm.max() / 2.\n",
" for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
" plt.text(j, i, format(cm[i, j], fmt),\n",
" horizontalalignment=\"center\",\n",
" color=\"white\" if cm[i, j] > thresh else \"black\")\n",
"\n",
" plt.tight_layout()\n",
" plt.ylabel('True label')\n",
" plt.xlabel('Predicted label')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "268757d2-09a8-422e-9097-df2420a6f7ab",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix\n",
"import numpy as np\n",
"\n",
"y_true = predictions.select(\"label\")\n",
"y_true = y_true.toPandas()\n",
"\n",
"y_pred = predictions.select(\"prediction\")\n",
"y_pred = y_pred.toPandas()\n",
"\n",
"cnf_matrix = confusion_matrix(\n",
" y_true,\n",
" y_pred,\n",
" labels=class_names\n",
")\n",
"\n",
"plt.figure()\n",
"plot_confusion_matrix(\n",
" cnf_matrix,\n",
" classes=class_names,\n",
" title='Confusion matrix'\n",
")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "97723165-a6ba-4753-9230-3c108759917c",
"metadata": {},
"source": [
"## Prediction"
]
},
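{
"cell_type": "code",
"execution_count": null,
"id": "00f680a2-e97b-40d1-b9ae-7b45cbd7bf5f",
"metadata": {},
"outputs": [],
"source": [
"from pyspark.ml.feature import IndexToString\n",
"# `prediction` is a StringIndexer class index, not a quality score, so decode it with the fitted indexer's labels\n",
"to_label = IndexToString(inputCol='prediction', outputCol='prediction_label', labels=label_indexer_model.labels)\n",
"predictions = to_label.transform(predictions).drop('prediction').withColumnRenamed('prediction_label', 'prediction')"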
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8cec4da7-e539-4245-a5c0-68a40f3aef43",
"metadata": {},
"outputs": [],
"source": [
"predictions.select(\n",
" 'label',\n",
" 'features',\n",
" 'rawPrediction',\n",
" 'prediction',\n",
" 'probability'\n",
").toPandas()\\\n",
".head(10)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 on cluster-dpms (Remote)",
"language": "python",
"name": "7a42ae6b283a613cfabce6f7-python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}