community-artifacts/Model-selection/Cross_validation_v1.ipynb (1,130 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Cross validation using general function\n", "\n", "Examples for \n", "http://madlib.apache.org/docs/latest/group__grp__validation.html" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n", " \"You should import from traitlets.config instead.\", ShimWarning)\n", "/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n", " warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n" ] } ], "source": [ "%load_ext sql" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "u'Connected: fmcquillan@madlib'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Greenplum Database 5.4.0 on GCP (demo machine)\n", "#%sql postgresql://gpadmin@35.184.232.200:5432/madlib\n", " \n", "# PostgreSQL local\n", "%sql postgresql://fmcquillan@localhost:5432/madlib" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1 rows affected.\n" ] }, { "data": { "text/html": [ "<table>\n", " <tr>\n", " <th>version</th>\n", " </tr>\n", " <tr>\n", " <td>MADlib version: 1.15.1, git revision: rc/1.15.1-rc1, cmake configuration time: Wed Oct 10 04:29:25 UTC 2018, build type: Release, build system: Darwin-17.7.0, C compiler: Clang, C++ compiler: Clang</td>\n", " </tr>\n", "</table>" ], "text/plain": [ "[(u'MADlib version: 1.15.1, git revision: rc/1.15.1-rc1, cmake configuration time: Wed Oct 10 04:29:25 UTC 2018, build type: Release, build 
system: Darwin-17.7.0, C compiler: Clang, C++ compiler: Clang',)]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%sql select madlib.version();\n", "#%sql select version();" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 1. Create data set\n", "House prices and characteristics." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Done.\n", "Done.\n", "36 rows affected.\n", "36 rows affected.\n" ] }, { "data": { "text/html": [ "<table>\n", " <tr>\n", " <th>id</th>\n", " <th>tax</th>\n", " <th>bedroom</th>\n", " <th>bath</th>\n", " <th>size</th>\n", " <th>lot</th>\n", " <th>zipcode</th>\n", " <th>price</th>\n", " <th>high_priced</th>\n", " </tr>\n", " <tr>\n", " <td>1</td>\n", " <td>590</td>\n", " <td>2</td>\n", " <td>1.0</td>\n", " <td>770</td>\n", " <td>22100</td>\n", " <td>94301</td>\n", " <td>50000</td>\n", " <td>False</td>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>1050</td>\n", " <td>3</td>\n", " <td>2.0</td>\n", " <td>1410</td>\n", " <td>12000</td>\n", " <td>94301</td>\n", " <td>85000</td>\n", " <td>False</td>\n", " </tr>\n", " <tr>\n", " <td>3</td>\n", " <td>20</td>\n", " <td>3</td>\n", " <td>1.0</td>\n", " <td>1060</td>\n", " <td>3500</td>\n", " <td>94301</td>\n", " <td>22500</td>\n", " <td>False</td>\n", " </tr>\n", " <tr>\n", " <td>4</td>\n", " <td>870</td>\n", " <td>2</td>\n", " <td>2.0</td>\n", " <td>1300</td>\n", " <td>17500</td>\n", " <td>94301</td>\n", " <td>90000</td>\n", " <td>False</td>\n", " </tr>\n", " <tr>\n", " <td>5</td>\n", " <td>1320</td>\n", " <td>3</td>\n", " <td>2.0</td>\n", " <td>1500</td>\n", " <td>30000</td>\n", " <td>94301</td>\n", " <td>133000</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>6</td>\n", " <td>1350</td>\n", " <td>2</td>\n", " <td>1.0</td>\n", " <td>820</td>\n", " <td>25700</td>\n", " <td>94301</td>\n", " <td>90500</td>\n", " <td>False</td>\n", " </tr>\n", " 
<tr>\n", " <td>7</td>\n", " <td>2790</td>\n", " <td>3</td>\n", " <td>2.5</td>\n", " <td>2130</td>\n", " <td>25000</td>\n", " <td>94301</td>\n", " <td>260000</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>8</td>\n", " <td>680</td>\n", " <td>2</td>\n", " <td>1.0</td>\n", " <td>1170</td>\n", " <td>22000</td>\n", " <td>94301</td>\n", " <td>142500</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>9</td>\n", " <td>1840</td>\n", " <td>3</td>\n", " <td>2.0</td>\n", " <td>1500</td>\n", " <td>19000</td>\n", " <td>94301</td>\n", " <td>160000</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>10</td>\n", " <td>3680</td>\n", " <td>4</td>\n", " <td>2.0</td>\n", " <td>2790</td>\n", " <td>20000</td>\n", " <td>94301</td>\n", " <td>240000</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>11</td>\n", " <td>1660</td>\n", " <td>3</td>\n", " <td>1.0</td>\n", " <td>1030</td>\n", " <td>17500</td>\n", " <td>94301</td>\n", " <td>87000</td>\n", " <td>False</td>\n", " </tr>\n", " <tr>\n", " <td>12</td>\n", " <td>1620</td>\n", " <td>3</td>\n", " <td>2.0</td>\n", " <td>1250</td>\n", " <td>20000</td>\n", " <td>94301</td>\n", " <td>118600</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>13</td>\n", " <td>3100</td>\n", " <td>3</td>\n", " <td>2.0</td>\n", " <td>1760</td>\n", " <td>38000</td>\n", " <td>94301</td>\n", " <td>140000</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>14</td>\n", " <td>2070</td>\n", " <td>2</td>\n", " <td>3.0</td>\n", " <td>1550</td>\n", " <td>14000</td>\n", " <td>94301</td>\n", " <td>148000</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>15</td>\n", " <td>650</td>\n", " <td>3</td>\n", " <td>1.5</td>\n", " <td>1450</td>\n", " <td>12000</td>\n", " <td>94301</td>\n", " <td>65000</td>\n", " <td>False</td>\n", " </tr>\n", " <tr>\n", " <td>16</td>\n", " <td>770</td>\n", " <td>2</td>\n", " <td>2.0</td>\n", " <td>1300</td>\n", " <td>17500</td>\n", " <td>76010</td>\n", " <td>91000</td>\n", " <td>False</td>\n", " 
</tr>\n", " <tr>\n", " <td>17</td>\n", " <td>1220</td>\n", " <td>3</td>\n", " <td>2.0</td>\n", " <td>1500</td>\n", " <td>30000</td>\n", " <td>76010</td>\n", " <td>132300</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>18</td>\n", " <td>1150</td>\n", " <td>2</td>\n", " <td>1.0</td>\n", " <td>820</td>\n", " <td>25700</td>\n", " <td>76010</td>\n", " <td>91100</td>\n", " <td>False</td>\n", " </tr>\n", " <tr>\n", " <td>19</td>\n", " <td>2690</td>\n", " <td>3</td>\n", " <td>2.5</td>\n", " <td>2130</td>\n", " <td>25000</td>\n", " <td>76010</td>\n", " <td>260011</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>20</td>\n", " <td>780</td>\n", " <td>2</td>\n", " <td>1.0</td>\n", " <td>1170</td>\n", " <td>22000</td>\n", " <td>76010</td>\n", " <td>141800</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>21</td>\n", " <td>1910</td>\n", " <td>3</td>\n", " <td>2.0</td>\n", " <td>1500</td>\n", " <td>19000</td>\n", " <td>76010</td>\n", " <td>160900</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>22</td>\n", " <td>3600</td>\n", " <td>4</td>\n", " <td>2.0</td>\n", " <td>2790</td>\n", " <td>20000</td>\n", " <td>76010</td>\n", " <td>239000</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>23</td>\n", " <td>1600</td>\n", " <td>3</td>\n", " <td>1.0</td>\n", " <td>1030</td>\n", " <td>17500</td>\n", " <td>76010</td>\n", " <td>81010</td>\n", " <td>False</td>\n", " </tr>\n", " <tr>\n", " <td>24</td>\n", " <td>1590</td>\n", " <td>3</td>\n", " <td>2.0</td>\n", " <td>1250</td>\n", " <td>20000</td>\n", " <td>76010</td>\n", " <td>117910</td>\n", " <td>False</td>\n", " </tr>\n", " <tr>\n", " <td>25</td>\n", " <td>3200</td>\n", " <td>3</td>\n", " <td>2.0</td>\n", " <td>1760</td>\n", " <td>38000</td>\n", " <td>76010</td>\n", " <td>141100</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>26</td>\n", " <td>2270</td>\n", " <td>2</td>\n", " <td>3.0</td>\n", " <td>1550</td>\n", " <td>14000</td>\n", " <td>76010</td>\n", " <td>148011</td>\n", " 
<td>True</td>\n", " </tr>\n", " <tr>\n", " <td>27</td>\n", " <td>750</td>\n", " <td>3</td>\n", " <td>1.5</td>\n", " <td>1450</td>\n", " <td>12000</td>\n", " <td>76010</td>\n", " <td>66000</td>\n", " <td>False</td>\n", " </tr>\n", " <tr>\n", " <td>28</td>\n", " <td>2690</td>\n", " <td>3</td>\n", " <td>2.5</td>\n", " <td>2130</td>\n", " <td>25000</td>\n", " <td>76010</td>\n", " <td>260011</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>29</td>\n", " <td>780</td>\n", " <td>2</td>\n", " <td>1.0</td>\n", " <td>1170</td>\n", " <td>22000</td>\n", " <td>76010</td>\n", " <td>141800</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>30</td>\n", " <td>1910</td>\n", " <td>3</td>\n", " <td>2.0</td>\n", " <td>1500</td>\n", " <td>19000</td>\n", " <td>76010</td>\n", " <td>160900</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>31</td>\n", " <td>3600</td>\n", " <td>4</td>\n", " <td>2.0</td>\n", " <td>2790</td>\n", " <td>20000</td>\n", " <td>76010</td>\n", " <td>239000</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>32</td>\n", " <td>1600</td>\n", " <td>3</td>\n", " <td>1.0</td>\n", " <td>1030</td>\n", " <td>17500</td>\n", " <td>76010</td>\n", " <td>81010</td>\n", " <td>False</td>\n", " </tr>\n", " <tr>\n", " <td>33</td>\n", " <td>1590</td>\n", " <td>3</td>\n", " <td>2.0</td>\n", " <td>1250</td>\n", " <td>20000</td>\n", " <td>76010</td>\n", " <td>117910</td>\n", " <td>False</td>\n", " </tr>\n", " <tr>\n", " <td>34</td>\n", " <td>3200</td>\n", " <td>3</td>\n", " <td>2.0</td>\n", " <td>1760</td>\n", " <td>38000</td>\n", " <td>76010</td>\n", " <td>141100</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>35</td>\n", " <td>2270</td>\n", " <td>2</td>\n", " <td>3.0</td>\n", " <td>1550</td>\n", " <td>14000</td>\n", " <td>76010</td>\n", " <td>148011</td>\n", " <td>True</td>\n", " </tr>\n", " <tr>\n", " <td>36</td>\n", " <td>750</td>\n", " <td>3</td>\n", " <td>1.5</td>\n", " <td>1450</td>\n", " <td>12000</td>\n", " <td>76010</td>\n", " 
<td>66000</td>\n", " <td>False</td>\n", " </tr>\n", "</table>" ], "text/plain": [ "[(1, 590, 2, 1.0, 770, 22100, 94301, 50000, False),\n", " (2, 1050, 3, 2.0, 1410, 12000, 94301, 85000, False),\n", " (3, 20, 3, 1.0, 1060, 3500, 94301, 22500, False),\n", " (4, 870, 2, 2.0, 1300, 17500, 94301, 90000, False),\n", " (5, 1320, 3, 2.0, 1500, 30000, 94301, 133000, True),\n", " (6, 1350, 2, 1.0, 820, 25700, 94301, 90500, False),\n", " (7, 2790, 3, 2.5, 2130, 25000, 94301, 260000, True),\n", " (8, 680, 2, 1.0, 1170, 22000, 94301, 142500, True),\n", " (9, 1840, 3, 2.0, 1500, 19000, 94301, 160000, True),\n", " (10, 3680, 4, 2.0, 2790, 20000, 94301, 240000, True),\n", " (11, 1660, 3, 1.0, 1030, 17500, 94301, 87000, False),\n", " (12, 1620, 3, 2.0, 1250, 20000, 94301, 118600, True),\n", " (13, 3100, 3, 2.0, 1760, 38000, 94301, 140000, True),\n", " (14, 2070, 2, 3.0, 1550, 14000, 94301, 148000, True),\n", " (15, 650, 3, 1.5, 1450, 12000, 94301, 65000, False),\n", " (16, 770, 2, 2.0, 1300, 17500, 76010, 91000, False),\n", " (17, 1220, 3, 2.0, 1500, 30000, 76010, 132300, True),\n", " (18, 1150, 2, 1.0, 820, 25700, 76010, 91100, False),\n", " (19, 2690, 3, 2.5, 2130, 25000, 76010, 260011, True),\n", " (20, 780, 2, 1.0, 1170, 22000, 76010, 141800, True),\n", " (21, 1910, 3, 2.0, 1500, 19000, 76010, 160900, True),\n", " (22, 3600, 4, 2.0, 2790, 20000, 76010, 239000, True),\n", " (23, 1600, 3, 1.0, 1030, 17500, 76010, 81010, False),\n", " (24, 1590, 3, 2.0, 1250, 20000, 76010, 117910, False),\n", " (25, 3200, 3, 2.0, 1760, 38000, 76010, 141100, True),\n", " (26, 2270, 2, 3.0, 1550, 14000, 76010, 148011, True),\n", " (27, 750, 3, 1.5, 1450, 12000, 76010, 66000, False),\n", " (28, 2690, 3, 2.5, 2130, 25000, 76010, 260011, True),\n", " (29, 780, 2, 1.0, 1170, 22000, 76010, 141800, True),\n", " (30, 1910, 3, 2.0, 1500, 19000, 76010, 160900, True),\n", " (31, 3600, 4, 2.0, 2790, 20000, 76010, 239000, True),\n", " (32, 1600, 3, 1.0, 1030, 17500, 76010, 81010, False),\n", " (33, 1590, 3, 
2.0, 1250, 20000, 76010, 117910, False),\n", " (34, 3200, 3, 2.0, 1760, 38000, 76010, 141100, True),\n", " (35, 2270, 2, 3.0, 1550, 14000, 76010, 148011, True),\n", " (36, 750, 3, 1.5, 1450, 12000, 76010, 66000, False)]" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%sql\n", "DROP TABLE IF EXISTS houses;\n", "\n", "CREATE TABLE houses ( id INT,\n", " tax INT,\n", " bedroom INT,\n", " bath FLOAT,\n", " size INT,\n", " lot INT,\n", " zipcode INT,\n", " price INT,\n", " high_priced BOOLEAN\n", " );\n", "\n", "INSERT INTO houses (id, tax, bedroom, bath, price, size, lot, zipcode, high_priced) VALUES\n", "(1 , 590 , 2 , 1 , 50000 , 770 , 22100 , 94301, 'f'::boolean),\n", "(2 , 1050 , 3 , 2 , 85000 , 1410 , 12000 , 94301, 'f'::boolean),\n", "(3 , 20 , 3 , 1 , 22500 , 1060 , 3500 , 94301, 'f'::boolean),\n", "(4 , 870 , 2 , 2 , 90000 , 1300 , 17500 , 94301, 'f'::boolean),\n", "(5 , 1320 , 3 , 2 , 133000 , 1500 , 30000 , 94301, 't'::boolean),\n", "(6 , 1350 , 2 , 1 , 90500 , 820 , 25700 , 94301, 'f'::boolean),\n", "(7 , 2790 , 3 , 2.5 , 260000 , 2130 , 25000 , 94301, 't'::boolean),\n", "(8 , 680 , 2 , 1 , 142500 , 1170 , 22000 , 94301, 't'::boolean),\n", "(9 , 1840 , 3 , 2 , 160000 , 1500 , 19000 , 94301, 't'::boolean),\n", "(10 , 3680 , 4 , 2 , 240000 , 2790 , 20000 , 94301, 't'::boolean),\n", "(11 , 1660 , 3 , 1 , 87000 , 1030 , 17500 , 94301, 'f'::boolean),\n", "(12 , 1620 , 3 , 2 , 118600 , 1250 , 20000 , 94301, 't'::boolean),\n", "(13 , 3100 , 3 , 2 , 140000 , 1760 , 38000 , 94301, 't'::boolean),\n", "(14 , 2070 , 2 , 3 , 148000 , 1550 , 14000 , 94301, 't'::boolean),\n", "(15 , 650 , 3 , 1.5 , 65000 , 1450 , 12000 , 94301, 'f'::boolean),\n", "(16 , 770 , 2 , 2 , 91000 , 1300 , 17500 , 76010, 'f'::boolean),\n", "(17 , 1220 , 3 , 2 , 132300 , 1500 , 30000 , 76010, 't'::boolean),\n", "(18 , 1150 , 2 , 1 , 91100 , 820 , 25700 , 76010, 'f'::boolean),\n", "(19 , 2690 , 3 , 2.5 , 260011 , 2130 , 25000 , 76010, 
't'::boolean),\n", "(20 , 780 , 2 , 1 , 141800 , 1170 , 22000 , 76010, 't'::boolean),\n", "(21 , 1910 , 3 , 2 , 160900 , 1500 , 19000 , 76010, 't'::boolean),\n", "(22 , 3600 , 4 , 2 , 239000 , 2790 , 20000 , 76010, 't'::boolean),\n", "(23 , 1600 , 3 , 1 , 81010 , 1030 , 17500 , 76010, 'f'::boolean),\n", "(24 , 1590 , 3 , 2 , 117910 , 1250 , 20000 , 76010, 'f'::boolean),\n", "(25 , 3200 , 3 , 2 , 141100 , 1760 , 38000 , 76010, 't'::boolean),\n", "(26 , 2270 , 2 , 3 , 148011 , 1550 , 14000 , 76010, 't'::boolean),\n", "(27 , 750 , 3 , 1.5 , 66000 , 1450 , 12000 , 76010, 'f'::boolean),\n", "(28 , 2690 , 3 , 2.5 , 260011 , 2130 , 25000 , 76010, 't'::boolean),\n", "(29 , 780 , 2 , 1 , 141800 , 1170 , 22000 , 76010, 't'::boolean),\n", "(30 , 1910 , 3 , 2 , 160900 , 1500 , 19000 , 76010, 't'::boolean),\n", "(31 , 3600 , 4 , 2 , 239000 , 2790 , 20000 , 76010, 't'::boolean),\n", "(32 , 1600 , 3 , 1 , 81010 , 1030 , 17500 , 76010, 'f'::boolean),\n", "(33 , 1590 , 3 , 2 , 117910 , 1250 , 20000 , 76010, 'f'::boolean),\n", "(34 , 3200 , 3 , 2 , 141100 , 1760 , 38000 , 76010, 't'::boolean),\n", "(35 , 2270 , 2 , 3 , 148011 , 1550 , 14000 , 76010, 't'::boolean),\n", "(36 , 750 , 3 , 1.5 , 66000 , 1450 , 12000 , 76010, 'f'::boolean);\n", "\n", "SELECT * FROM houses ORDER BY id;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2. 
Elastic net\n", "\n", "Note that elastic net also has a built-in cross validation function for selecting the elastic net control parameter alpha and the regularization value lambda; see\n", "http://madlib.apache.org/docs/latest/group__grp__elasticnet.html\n", "\n", "But here we use the general function to explore lambda values:" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Done.\n", "1 rows affected.\n" ] }, { "data": { "text/html": [ "<table>\n", " <tr>\n", " <th>cross_validation_general</th>\n", " </tr>\n", " <tr>\n", " <td></td>\n", " </tr>\n", "</table>" ], "text/plain": [ "[('',)]" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%sql\n", "DROP TABLE IF EXISTS houses_cv_results;\n", "\n", "SELECT madlib.cross_validation_general(\n", " -- modelling_func\n", " 'madlib.elastic_net_train',\n", " \n", " -- modelling_params\n", " '{%data%, %model%, price, \"array[tax, bath, size]\", gaussian, 0.5, lambda, TRUE, NULL, fista,\n", " \"{eta = 2, max_stepsize = 2, use_active_set = t}\",\n", " NULL, 10000, 1e-6}'::varchar[],\n", " \n", " -- modelling_params_type\n", " '{varchar, varchar, varchar, varchar, varchar, double precision,\n", " double precision, boolean, varchar, varchar, varchar, varchar,\n", " integer, double precision}'::varchar[],\n", " \n", " -- param_explored\n", " 'lambda',\n", " \n", " -- explore_values\n", " '{0.1, 0.2}'::varchar[],\n", " \n", " -- predict_func\n", " 'madlib.elastic_net_predict',\n", " \n", " -- predict_params\n", " '{%model%, %data%, %id%, %prediction%}'::varchar[],\n", " \n", " -- predict_params_type\n", " '{text, text, text, text}'::varchar[],\n", " \n", " -- metric_func\n", " 'madlib.mse_error',\n", " \n", " -- metric_params\n", " '{%prediction%, %data%, %id%, price, %error%}'::varchar[],\n", " \n", " -- metric_params_type\n", " '{varchar, varchar, varchar, varchar, varchar}'::varchar[],\n", " \n", " -- 
data_tbl\n", " 'houses',\n", " \n", " -- data_id\n", " 'id',\n", " \n", " -- id_is_random\n", " FALSE,\n", " \n", " -- validation_result\n", " 'houses_cv_results',\n", " \n", " -- data_cols\n", " NULL,\n", " \n", " -- fold_num\n", " 3\n", ");" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2 rows affected.\n" ] }, { "data": { "text/html": [ "<table>\n", " <tr>\n", " <th>lambda</th>\n", " <th>mean_squared_error_avg</th>\n", " <th>mean_squared_error_stddev</th>\n", " </tr>\n", " <tr>\n", " <td>0.1</td>\n", " <td>1194685622.16</td>\n", " <td>366687470.78</td>\n", " </tr>\n", " <tr>\n", " <td>0.2</td>\n", " <td>1181768409.98</td>\n", " <td>352203200.758</td>\n", " </tr>\n", "</table>" ], "text/plain": [ "[(0.1, 1194685622.1604, 366687470.779826),\n", " (0.2, 1181768409.98238, 352203200.758414)]" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%sql\n", "SELECT * FROM houses_cv_results;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 3. 
Logistic regression\n", "\n", "Here we use the general function to explore the maximum number of iterations (`max_iter`):" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Done.\n", "1 rows affected.\n" ] }, { "data": { "text/html": [ "<table>\n", " <tr>\n", " <th>cross_validation_general</th>\n", " </tr>\n", " <tr>\n", " <td></td>\n", " </tr>\n", "</table>" ], "text/plain": [ "[('',)]" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%sql\n", "DROP TABLE IF EXISTS houses_logregr_cv;\n", "\n", "SELECT madlib.cross_validation_general(\n", " -- modelling_func\n", " 'madlib.logregr_train',\n", " \n", " -- modelling_params\n", " '{%data%, %model%, high_priced, \"ARRAY[1, bedroom, bath, size]\", NULL, max_iter}'::varchar[],\n", " \n", " -- modelling_params_type\n", " '{varchar, varchar, varchar, varchar, varchar, integer}'::varchar[],\n", " \n", " -- param_explored\n", " 'max_iter',\n", " \n", " -- explore_values\n", " '{2, 10, 40, 100}'::varchar[],\n", " \n", " -- predict_func\n", " 'madlib.cv_logregr_predict',\n", " \n", " -- predict_params\n", " '{%model%, %data%, \"ARRAY[1, bedroom, bath, size]\", id, %prediction%}'::varchar[],\n", " \n", " -- predict_params_type\n", " '{varchar, varchar,varchar,varchar,varchar}'::varchar[],\n", " \n", " -- metric_func\n", " 'madlib.misclassification_avg',\n", " \n", " -- metric_params\n", " '{%prediction%, %data%, id, high_priced, %error%}'::varchar[],\n", " \n", " -- metric_params_type\n", " '{varchar, varchar, varchar, varchar, varchar}'::varchar[],\n", " \n", " -- data_tbl\n", " 'houses',\n", " \n", " -- data_id\n", " 'id',\n", " \n", " -- id_is_random\n", " FALSE,\n", " \n", " -- validation_result\n", " 'houses_logregr_cv',\n", " \n", " -- data_cols\n", " NULL,\n", " \n", " -- fold_num\n", " 5\n", ");" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", 
"output_type": "stream", "text": [ "4 rows affected.\n" ] }, { "data": { "text/html": [ "<table>\n", " <tr>\n", " <th>max_iter</th>\n", " <th>error_rate_avg</th>\n", " <th>error_rate_stddev</th>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>0.19642857142857142857</td>\n", " <td>0.0818317088384971429780598253843971801653</td>\n", " </tr>\n", " <tr>\n", " <td>10</td>\n", " <td>0.22142857142857142857</td>\n", " <td>0.0731925054711399884549944979733273803475</td>\n", " </tr>\n", " <tr>\n", " <td>40</td>\n", " <td>0.22142857142857142857</td>\n", " <td>0.0731925054711399884549944979733273803475</td>\n", " </tr>\n", " <tr>\n", " <td>100</td>\n", " <td>0.22142857142857142857</td>\n", " <td>0.0731925054711399884549944979733273803475</td>\n", " </tr>\n", "</table>" ], "text/plain": [ "[(2, Decimal('0.19642857142857142857'), Decimal('0.0818317088384971429780598253843971801653')),\n", " (10, Decimal('0.22142857142857142857'), Decimal('0.0731925054711399884549944979733273803475')),\n", " (40, Decimal('0.22142857142857142857'), Decimal('0.0731925054711399884549944979733273803475')),\n", " (100, Decimal('0.22142857142857142857'), Decimal('0.0731925054711399884549944979733273803475'))]" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%sql\n", "SELECT * FROM houses_logregr_cv;" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 4. Decision tree\n", "\n", "Here we use the general function to explore tree depth. 
First we need to create a wrapper function for predict that does a column rename:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Done.\n" ] }, { "data": { "text/plain": [ "[]" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%sql\n", "-- Wrapper used as predict_func in the decision tree cross validation below. It runs\n", "-- madlib.tree_predict into output_table and then renames column orig_column of that\n", "-- table to new_column.\n", "CREATE OR REPLACE FUNCTION tree_predict_rename_col(model_table VARCHAR, data_table VARCHAR, output_table VARCHAR,\n", " orig_column VARCHAR, new_column VARCHAR)\n", "RETURNS VOID AS $$\n", "BEGIN\n", " -- %L renders each table name as a properly escaped SQL string literal\n", " EXECUTE format('SELECT madlib.tree_predict(%L, %L, %L)', model_table, data_table, output_table);\n", " -- output_table is spliced in verbatim (%s) so a schema-qualified name keeps working,\n", " -- while the column names are quoted as identifiers (%I)\n", " EXECUTE format('ALTER TABLE %s RENAME COLUMN %I TO %I', output_table, orig_column, new_column);\n", "END\n", "$$ LANGUAGE plpgsql VOLATILE;" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Done.\n", "1 rows affected.\n", "4 rows affected.\n" ] }, { "data": { "text/html": [ "<table>\n", " <tr>\n", " <th>max_depth</th>\n", " <th>error_rate_avg</th>\n", " <th>error_rate_stddev</th>\n", " </tr>\n", " <tr>\n", " <td>1</td>\n", " <td>0.16785714285714285714</td>\n", " <td>0.1208494593977759334440468761527971235643</td>\n", " </tr>\n", " <tr>\n", " <td>2</td>\n", " <td>0.16785714285714285714</td>\n", " <td>0.1208494593977759334440468761527971235643</td>\n", " </tr>\n", " <tr>\n", " <td>3</td>\n", " <td>0.17142857142857142857</td>\n", " <td>0.1564921592871903181329101774752513216155</td>\n", " </tr>\n", " <tr>\n", " <td>4</td>\n", " <td>0.14285714285714285714</td>\n", " <td>0.1428571428571428571449999999999999999999</td>\n", " </tr>\n", "</table>" ], "text/plain": [ "[(1, Decimal('0.16785714285714285714'), Decimal('0.1208494593977759334440468761527971235643')),\n", " (2, Decimal('0.16785714285714285714'), Decimal('0.1208494593977759334440468761527971235643')),\n", " (3, Decimal('0.17142857142857142857'), 
Decimal('0.1564921592871903181329101774752513216155')),\n", " (4, Decimal('0.14285714285714285714'), Decimal('0.1428571428571428571449999999999999999999'))]" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%%sql\n", "DROP TABLE IF EXISTS houses_dt_cv;\n", "\n", "SELECT madlib.cross_validation_general(\n", " -- modelling_func\n", " 'madlib.tree_train',\n", " \n", " -- modelling_params\n", " '{%data%, %model%, id, high_priced, \"bedroom, bath, size, zipcode\", NULL, NULL, NULL, NULL, max_depth, 1, 1, 10}'::varchar[],\n", " \n", " -- modelling_params_type\n", " '{varchar, varchar, varchar, varchar, varchar, varchar, varchar, varchar, varchar, integer, integer, integer, integer}'::varchar[],\n", " \n", " -- param_explored\n", " 'max_depth',\n", " \n", " -- explore_values\n", " '{1, 2, 3, 4}'::varchar[],\n", " \n", " -- predict_func\n", " 'tree_predict_rename_col',\n", " \n", " -- predict_params (the wrapper renames column estimated_high_priced of the prediction table to prediction)\n", " '{%model%, %data%, %prediction%, estimated_high_priced, prediction}'::varchar[],\n", " \n", " -- predict_params_type\n", " '{varchar,varchar,varchar,varchar,varchar}'::varchar[],\n", " \n", " -- metric_func\n", " 'madlib.misclassification_avg',\n", " \n", " -- metric_params\n", " '{%prediction%, %data%, id, high_priced, %error%}'::varchar[],\n", " \n", " -- metric_params_type\n", " '{varchar, varchar, varchar, varchar, varchar}'::varchar[],\n", " \n", " -- data_tbl\n", " 'houses',\n", " \n", " -- data_id\n", " 'id',\n", " \n", " -- id_is_random\n", " FALSE,\n", " \n", " -- validation_result\n", " 'houses_dt_cv',\n", " \n", " -- data_cols\n", " NULL,\n", " \n", " -- fold_num\n", " 5\n", ");\n", "SELECT * FROM houses_dt_cv;" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", 
"version": "2.7.12" } }, "nbformat": 4, "nbformat_minor": 1 }