community-artifacts/Model-selection/Cross_validation_v1.ipynb (1,130 lines of code) (raw):
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
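"# Cross validation using the general function\n",
"\n",
"Examples of k-fold cross validation using MADlib's general-purpose `cross_validation_general()` function. Reference documentation:\n",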
"http://madlib.apache.org/docs/latest/group__grp__validation.html"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/config.py:13: ShimWarning: The `IPython.config` package has been deprecated since IPython 4.0. You should import from traitlets.config instead.\n",
" \"You should import from traitlets.config instead.\", ShimWarning)\n",
"/Users/fmcquillan/anaconda/lib/python2.7/site-packages/IPython/utils/traitlets.py:5: UserWarning: IPython.utils.traitlets has moved to a top-level traitlets package.\n",
" warn(\"IPython.utils.traitlets has moved to a top-level traitlets package.\")\n"
]
}
],
"source": [
"%load_ext sql"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"u'Connected: fmcquillan@madlib'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Greenplum Database 5.4.0 on GCP (demo machine)\n",
"#%sql postgresql://gpadmin@35.184.232.200:5432/madlib\n",
" \n",
"# PostgreSQL local\n",
"%sql postgresql://fmcquillan@localhost:5432/madlib"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>version</th>\n",
" </tr>\n",
" <tr>\n",
" <td>MADlib version: 1.15.1, git revision: rc/1.15.1-rc1, cmake configuration time: Wed Oct 10 04:29:25 UTC 2018, build type: Release, build system: Darwin-17.7.0, C compiler: Clang, C++ compiler: Clang</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(u'MADlib version: 1.15.1, git revision: rc/1.15.1-rc1, cmake configuration time: Wed Oct 10 04:29:25 UTC 2018, build type: Release, build system: Darwin-17.7.0, C compiler: Clang, C++ compiler: Clang',)]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check: confirm MADlib is installed in the connected database and report its version.\n",
"# To see the database server version instead, run SELECT version(); through %sql.\n",
"%sql SELECT madlib.version();"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1. Create data set\n",
"House prices and characteristics."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"Done.\n",
"36 rows affected.\n",
"36 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>id</th>\n",
" <th>tax</th>\n",
" <th>bedroom</th>\n",
" <th>bath</th>\n",
" <th>size</th>\n",
" <th>lot</th>\n",
" <th>zipcode</th>\n",
" <th>price</th>\n",
" <th>high_priced</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>590</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>770</td>\n",
" <td>22100</td>\n",
" <td>94301</td>\n",
" <td>50000</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>1050</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>1410</td>\n",
" <td>12000</td>\n",
" <td>94301</td>\n",
" <td>85000</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>20</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>1060</td>\n",
" <td>3500</td>\n",
" <td>94301</td>\n",
" <td>22500</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>870</td>\n",
" <td>2</td>\n",
" <td>2.0</td>\n",
" <td>1300</td>\n",
" <td>17500</td>\n",
" <td>94301</td>\n",
" <td>90000</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5</td>\n",
" <td>1320</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>1500</td>\n",
" <td>30000</td>\n",
" <td>94301</td>\n",
" <td>133000</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6</td>\n",
" <td>1350</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>820</td>\n",
" <td>25700</td>\n",
" <td>94301</td>\n",
" <td>90500</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7</td>\n",
" <td>2790</td>\n",
" <td>3</td>\n",
" <td>2.5</td>\n",
" <td>2130</td>\n",
" <td>25000</td>\n",
" <td>94301</td>\n",
" <td>260000</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8</td>\n",
" <td>680</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>1170</td>\n",
" <td>22000</td>\n",
" <td>94301</td>\n",
" <td>142500</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9</td>\n",
" <td>1840</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>1500</td>\n",
" <td>19000</td>\n",
" <td>94301</td>\n",
" <td>160000</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>3680</td>\n",
" <td>4</td>\n",
" <td>2.0</td>\n",
" <td>2790</td>\n",
" <td>20000</td>\n",
" <td>94301</td>\n",
" <td>240000</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11</td>\n",
" <td>1660</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>1030</td>\n",
" <td>17500</td>\n",
" <td>94301</td>\n",
" <td>87000</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12</td>\n",
" <td>1620</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>1250</td>\n",
" <td>20000</td>\n",
" <td>94301</td>\n",
" <td>118600</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13</td>\n",
" <td>3100</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>1760</td>\n",
" <td>38000</td>\n",
" <td>94301</td>\n",
" <td>140000</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14</td>\n",
" <td>2070</td>\n",
" <td>2</td>\n",
" <td>3.0</td>\n",
" <td>1550</td>\n",
" <td>14000</td>\n",
" <td>94301</td>\n",
" <td>148000</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15</td>\n",
" <td>650</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" <td>94301</td>\n",
" <td>65000</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16</td>\n",
" <td>770</td>\n",
" <td>2</td>\n",
" <td>2.0</td>\n",
" <td>1300</td>\n",
" <td>17500</td>\n",
" <td>76010</td>\n",
" <td>91000</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17</td>\n",
" <td>1220</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>1500</td>\n",
" <td>30000</td>\n",
" <td>76010</td>\n",
" <td>132300</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18</td>\n",
" <td>1150</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>820</td>\n",
" <td>25700</td>\n",
" <td>76010</td>\n",
" <td>91100</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19</td>\n",
" <td>2690</td>\n",
" <td>3</td>\n",
" <td>2.5</td>\n",
" <td>2130</td>\n",
" <td>25000</td>\n",
" <td>76010</td>\n",
" <td>260011</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20</td>\n",
" <td>780</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>1170</td>\n",
" <td>22000</td>\n",
" <td>76010</td>\n",
" <td>141800</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>21</td>\n",
" <td>1910</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>1500</td>\n",
" <td>19000</td>\n",
" <td>76010</td>\n",
" <td>160900</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>22</td>\n",
" <td>3600</td>\n",
" <td>4</td>\n",
" <td>2.0</td>\n",
" <td>2790</td>\n",
" <td>20000</td>\n",
" <td>76010</td>\n",
" <td>239000</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>23</td>\n",
" <td>1600</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>1030</td>\n",
" <td>17500</td>\n",
" <td>76010</td>\n",
" <td>81010</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24</td>\n",
" <td>1590</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>1250</td>\n",
" <td>20000</td>\n",
" <td>76010</td>\n",
" <td>117910</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25</td>\n",
" <td>3200</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>1760</td>\n",
" <td>38000</td>\n",
" <td>76010</td>\n",
" <td>141100</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26</td>\n",
" <td>2270</td>\n",
" <td>2</td>\n",
" <td>3.0</td>\n",
" <td>1550</td>\n",
" <td>14000</td>\n",
" <td>76010</td>\n",
" <td>148011</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>27</td>\n",
" <td>750</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" <td>76010</td>\n",
" <td>66000</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>28</td>\n",
" <td>2690</td>\n",
" <td>3</td>\n",
" <td>2.5</td>\n",
" <td>2130</td>\n",
" <td>25000</td>\n",
" <td>76010</td>\n",
" <td>260011</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>29</td>\n",
" <td>780</td>\n",
" <td>2</td>\n",
" <td>1.0</td>\n",
" <td>1170</td>\n",
" <td>22000</td>\n",
" <td>76010</td>\n",
" <td>141800</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>30</td>\n",
" <td>1910</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>1500</td>\n",
" <td>19000</td>\n",
" <td>76010</td>\n",
" <td>160900</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>31</td>\n",
" <td>3600</td>\n",
" <td>4</td>\n",
" <td>2.0</td>\n",
" <td>2790</td>\n",
" <td>20000</td>\n",
" <td>76010</td>\n",
" <td>239000</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>32</td>\n",
" <td>1600</td>\n",
" <td>3</td>\n",
" <td>1.0</td>\n",
" <td>1030</td>\n",
" <td>17500</td>\n",
" <td>76010</td>\n",
" <td>81010</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>33</td>\n",
" <td>1590</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>1250</td>\n",
" <td>20000</td>\n",
" <td>76010</td>\n",
" <td>117910</td>\n",
" <td>False</td>\n",
" </tr>\n",
" <tr>\n",
" <td>34</td>\n",
" <td>3200</td>\n",
" <td>3</td>\n",
" <td>2.0</td>\n",
" <td>1760</td>\n",
" <td>38000</td>\n",
" <td>76010</td>\n",
" <td>141100</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>35</td>\n",
" <td>2270</td>\n",
" <td>2</td>\n",
" <td>3.0</td>\n",
" <td>1550</td>\n",
" <td>14000</td>\n",
" <td>76010</td>\n",
" <td>148011</td>\n",
" <td>True</td>\n",
" </tr>\n",
" <tr>\n",
" <td>36</td>\n",
" <td>750</td>\n",
" <td>3</td>\n",
" <td>1.5</td>\n",
" <td>1450</td>\n",
" <td>12000</td>\n",
" <td>76010</td>\n",
" <td>66000</td>\n",
" <td>False</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, 590, 2, 1.0, 770, 22100, 94301, 50000, False),\n",
" (2, 1050, 3, 2.0, 1410, 12000, 94301, 85000, False),\n",
" (3, 20, 3, 1.0, 1060, 3500, 94301, 22500, False),\n",
" (4, 870, 2, 2.0, 1300, 17500, 94301, 90000, False),\n",
" (5, 1320, 3, 2.0, 1500, 30000, 94301, 133000, True),\n",
" (6, 1350, 2, 1.0, 820, 25700, 94301, 90500, False),\n",
" (7, 2790, 3, 2.5, 2130, 25000, 94301, 260000, True),\n",
" (8, 680, 2, 1.0, 1170, 22000, 94301, 142500, True),\n",
" (9, 1840, 3, 2.0, 1500, 19000, 94301, 160000, True),\n",
" (10, 3680, 4, 2.0, 2790, 20000, 94301, 240000, True),\n",
" (11, 1660, 3, 1.0, 1030, 17500, 94301, 87000, False),\n",
" (12, 1620, 3, 2.0, 1250, 20000, 94301, 118600, True),\n",
" (13, 3100, 3, 2.0, 1760, 38000, 94301, 140000, True),\n",
" (14, 2070, 2, 3.0, 1550, 14000, 94301, 148000, True),\n",
" (15, 650, 3, 1.5, 1450, 12000, 94301, 65000, False),\n",
" (16, 770, 2, 2.0, 1300, 17500, 76010, 91000, False),\n",
" (17, 1220, 3, 2.0, 1500, 30000, 76010, 132300, True),\n",
" (18, 1150, 2, 1.0, 820, 25700, 76010, 91100, False),\n",
" (19, 2690, 3, 2.5, 2130, 25000, 76010, 260011, True),\n",
" (20, 780, 2, 1.0, 1170, 22000, 76010, 141800, True),\n",
" (21, 1910, 3, 2.0, 1500, 19000, 76010, 160900, True),\n",
" (22, 3600, 4, 2.0, 2790, 20000, 76010, 239000, True),\n",
" (23, 1600, 3, 1.0, 1030, 17500, 76010, 81010, False),\n",
" (24, 1590, 3, 2.0, 1250, 20000, 76010, 117910, False),\n",
" (25, 3200, 3, 2.0, 1760, 38000, 76010, 141100, True),\n",
" (26, 2270, 2, 3.0, 1550, 14000, 76010, 148011, True),\n",
" (27, 750, 3, 1.5, 1450, 12000, 76010, 66000, False),\n",
" (28, 2690, 3, 2.5, 2130, 25000, 76010, 260011, True),\n",
" (29, 780, 2, 1.0, 1170, 22000, 76010, 141800, True),\n",
" (30, 1910, 3, 2.0, 1500, 19000, 76010, 160900, True),\n",
" (31, 3600, 4, 2.0, 2790, 20000, 76010, 239000, True),\n",
" (32, 1600, 3, 1.0, 1030, 17500, 76010, 81010, False),\n",
" (33, 1590, 3, 2.0, 1250, 20000, 76010, 117910, False),\n",
" (34, 3200, 3, 2.0, 1760, 38000, 76010, 141100, True),\n",
" (35, 2270, 2, 3.0, 1550, 14000, 76010, 148011, True),\n",
" (36, 750, 3, 1.5, 1450, 12000, 76010, 66000, False)]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"-- Toy data set of house characteristics, sale price and a boolean high_priced label.\n",
"-- id is the row key passed as data_id to the cross validation calls below;\n",
"-- price is the regression target and high_priced the classification target.\n",
"DROP TABLE IF EXISTS houses;\n",
"\n",
"CREATE TABLE houses (\n",
"    id          INT,\n",
"    tax         INT,\n",
"    bedroom     INT,\n",
"    bath        FLOAT,\n",
"    size        INT,\n",
"    lot         INT,\n",
"    zipcode     INT,\n",
"    price       INT,\n",
"    high_priced BOOLEAN\n",
");\n",
"\n",
"-- Values are listed in the same column order as the table definition above.\n",
"INSERT INTO houses (id, tax, bedroom, bath, size, lot, zipcode, price, high_priced) VALUES\n",
"( 1,  590, 2,   1,  770, 22100, 94301,  50000, FALSE),\n",
"( 2, 1050, 3,   2, 1410, 12000, 94301,  85000, FALSE),\n",
"( 3,   20, 3,   1, 1060,  3500, 94301,  22500, FALSE),\n",
"( 4,  870, 2,   2, 1300, 17500, 94301,  90000, FALSE),\n",
"( 5, 1320, 3,   2, 1500, 30000, 94301, 133000, TRUE),\n",
"( 6, 1350, 2,   1,  820, 25700, 94301,  90500, FALSE),\n",
"( 7, 2790, 3, 2.5, 2130, 25000, 94301, 260000, TRUE),\n",
"( 8,  680, 2,   1, 1170, 22000, 94301, 142500, TRUE),\n",
"( 9, 1840, 3,   2, 1500, 19000, 94301, 160000, TRUE),\n",
"(10, 3680, 4,   2, 2790, 20000, 94301, 240000, TRUE),\n",
"(11, 1660, 3,   1, 1030, 17500, 94301,  87000, FALSE),\n",
"(12, 1620, 3,   2, 1250, 20000, 94301, 118600, TRUE),\n",
"(13, 3100, 3,   2, 1760, 38000, 94301, 140000, TRUE),\n",
"(14, 2070, 2,   3, 1550, 14000, 94301, 148000, TRUE),\n",
"(15,  650, 3, 1.5, 1450, 12000, 94301,  65000, FALSE),\n",
"(16,  770, 2,   2, 1300, 17500, 76010,  91000, FALSE),\n",
"(17, 1220, 3,   2, 1500, 30000, 76010, 132300, TRUE),\n",
"(18, 1150, 2,   1,  820, 25700, 76010,  91100, FALSE),\n",
"(19, 2690, 3, 2.5, 2130, 25000, 76010, 260011, TRUE),\n",
"(20,  780, 2,   1, 1170, 22000, 76010, 141800, TRUE),\n",
"(21, 1910, 3,   2, 1500, 19000, 76010, 160900, TRUE),\n",
"(22, 3600, 4,   2, 2790, 20000, 76010, 239000, TRUE),\n",
"(23, 1600, 3,   1, 1030, 17500, 76010,  81010, FALSE),\n",
"(24, 1590, 3,   2, 1250, 20000, 76010, 117910, FALSE),\n",
"(25, 3200, 3,   2, 1760, 38000, 76010, 141100, TRUE),\n",
"(26, 2270, 2,   3, 1550, 14000, 76010, 148011, TRUE),\n",
"(27,  750, 3, 1.5, 1450, 12000, 76010,  66000, FALSE),\n",
"(28, 2690, 3, 2.5, 2130, 25000, 76010, 260011, TRUE),\n",
"(29,  780, 2,   1, 1170, 22000, 76010, 141800, TRUE),\n",
"(30, 1910, 3,   2, 1500, 19000, 76010, 160900, TRUE),\n",
"(31, 3600, 4,   2, 2790, 20000, 76010, 239000, TRUE),\n",
"(32, 1600, 3,   1, 1030, 17500, 76010,  81010, FALSE),\n",
"(33, 1590, 3,   2, 1250, 20000, 76010, 117910, FALSE),\n",
"(34, 3200, 3,   2, 1760, 38000, 76010, 141100, TRUE),\n",
"(35, 2270, 2,   3, 1550, 14000, 76010, 148011, TRUE),\n",
"(36,  750, 3, 1.5, 1450, 12000, 76010,  66000, FALSE);\n",
"\n",
"SELECT * FROM houses ORDER BY id;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Elastic net\n",
"\n",
"Note that elastic net also has built-in cross validation for selecting the elastic net control parameter alpha and the regularization value lambda; see\n",
"http://madlib.apache.org/docs/latest/group__grp__elasticnet.html\n",
"\n",
"Here, however, we use the general function to explore values of lambda:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>cross_validation_general</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_cv_results;\n",
"\n",
"SELECT madlib.cross_validation_general(\n",
" -- modelling_func: elastic net training function\n",
" 'madlib.elastic_net_train',\n",
" \n",
" -- modelling_params: elastic_net_train arguments (gaussian family, alpha 0.5, fista optimizer); the bare token lambda is replaced by each explore value\n",
" '{%data%, %model%, price, \"array[tax, bath, size]\", gaussian, 0.5, lambda, TRUE, NULL, fista,\n",
" \"{eta = 2, max_stepsize = 2, use_active_set = t}\",\n",
" NULL, 10000, 1e-6}'::varchar[],\n",
" \n",
" -- modelling_params_type: one SQL type per modelling_params entry (14 entries)\n",
" '{varchar, varchar, varchar, varchar, varchar, double precision,\n",
" double precision, boolean, varchar, varchar, varchar, varchar,\n",
" integer, double precision}'::varchar[],\n",
" \n",
" -- param_explored: placeholder token in modelling_params to vary\n",
" 'lambda',\n",
" \n",
" -- explore_values: candidate lambda values\n",
" '{0.1, 0.2}'::varchar[],\n",
" \n",
" -- predict_func: elastic net prediction function\n",
" 'madlib.elastic_net_predict',\n",
" \n",
" -- predict_params: placeholders filled in by the framework for each fold\n",
" '{%model%, %data%, %id%, %prediction%}'::varchar[],\n",
" \n",
" -- predict_params_type\n",
" '{text, text, text, text}'::varchar[],\n",
" \n",
" -- metric_func: mean squared error against price (reported as mean_squared_error_avg / mean_squared_error_stddev)\n",
" 'madlib.mse_error',\n",
" \n",
" -- metric_params: predictions compared with the true price column\n",
" '{%prediction%, %data%, %id%, price, %error%}'::varchar[],\n",
" \n",
" -- metric_params_type\n",
" '{varchar, varchar, varchar, varchar, varchar}'::varchar[],\n",
" \n",
" -- data_tbl: input table\n",
" 'houses',\n",
" \n",
" -- data_id: unique row id column of data_tbl\n",
" 'id',\n",
" \n",
" -- id_is_random: FALSE, ids 1..36 are sequential rather than randomly assigned\n",
" FALSE,\n",
" \n",
" -- validation_result: output table, one row per explored lambda value\n",
" 'houses_cv_results',\n",
" \n",
" -- data_cols: NULL = all columns of data_tbl\n",
" NULL,\n",
" \n",
" -- fold_num: number of cross validation folds\n",
" 3\n",
");"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>lambda</th>\n",
" <th>mean_squared_error_avg</th>\n",
" <th>mean_squared_error_stddev</th>\n",
" </tr>\n",
" <tr>\n",
" <td>0.1</td>\n",
" <td>1194685622.16</td>\n",
" <td>366687470.78</td>\n",
" </tr>\n",
" <tr>\n",
" <td>0.2</td>\n",
" <td>1181768409.98</td>\n",
" <td>352203200.758</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(0.1, 1194685622.1604, 366687470.779826),\n",
" (0.2, 1181768409.98238, 352203200.758414)]"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"-- Mean and standard deviation of the per-fold MSE for each lambda.\n",
"-- ORDER BY makes the row order deterministic (a plain SELECT guarantees none, e.g. on Greenplum).\n",
"SELECT * FROM houses_cv_results ORDER BY lambda;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 3. Logistic regression \n",
"\n",
"Here we use the general function to explore the maximum number of iterations (`max_iter`):"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>cross_validation_general</th>\n",
" </tr>\n",
" <tr>\n",
" <td></td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[('',)]"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"DROP TABLE IF EXISTS houses_logregr_cv;\n",
"\n",
"SELECT madlib.cross_validation_general(\n",
" -- modelling_func: logistic regression training function\n",
" 'madlib.logregr_train',\n",
" \n",
" -- modelling_params: logregr_train arguments (no grouping); the bare token max_iter is replaced by each explore value\n",
" '{%data%, %model%, high_priced, \"ARRAY[1, bedroom, bath, size]\", NULL, max_iter}'::varchar[],\n",
" \n",
" -- modelling_params_type: one SQL type per modelling_params entry (6 entries)\n",
" '{varchar, varchar, varchar, varchar, varchar, integer}'::varchar[],\n",
" \n",
" -- param_explored: placeholder token in modelling_params to vary\n",
" 'max_iter',\n",
" \n",
" -- explore_values: candidate maximum iteration counts\n",
" '{2, 10, 40, 100}'::varchar[],\n",
" \n",
" -- predict_func: logistic regression prediction function applied to each held-out fold\n",
" 'madlib.cv_logregr_predict',\n",
" \n",
" -- predict_params: same independent-variable expression as in training, with id as the row id column\n",
" '{%model%, %data%, \"ARRAY[1, bedroom, bath, size]\", id, %prediction%}'::varchar[],\n",
" \n",
" -- predict_params_type\n",
" '{varchar, varchar,varchar,varchar,varchar}'::varchar[],\n",
" \n",
" -- metric_func: average misclassification rate (reported as error_rate_avg / error_rate_stddev)\n",
" 'madlib.misclassification_avg',\n",
" \n",
" -- metric_params: predictions compared with the high_priced label, rows matched on id\n",
" '{%prediction%, %data%, id, high_priced, %error%}'::varchar[],\n",
" \n",
" -- metric_params_type\n",
" '{varchar, varchar, varchar, varchar, varchar}'::varchar[],\n",
" \n",
" -- data_tbl: input table\n",
" 'houses',\n",
" \n",
" -- data_id: unique row id column of data_tbl\n",
" 'id',\n",
" \n",
" -- id_is_random: FALSE, ids 1..36 are sequential rather than randomly assigned\n",
" FALSE,\n",
" \n",
" -- validation_result: output table, one row per explored max_iter value\n",
" 'houses_logregr_cv',\n",
" \n",
" -- data_cols: NULL = all columns of data_tbl\n",
" NULL,\n",
" \n",
" -- fold_num: number of cross validation folds\n",
" 5\n",
");"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>max_iter</th>\n",
" <th>error_rate_avg</th>\n",
" <th>error_rate_stddev</th>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.19642857142857142857</td>\n",
" <td>0.0818317088384971429780598253843971801653</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10</td>\n",
" <td>0.22142857142857142857</td>\n",
" <td>0.0731925054711399884549944979733273803475</td>\n",
" </tr>\n",
" <tr>\n",
" <td>40</td>\n",
" <td>0.22142857142857142857</td>\n",
" <td>0.0731925054711399884549944979733273803475</td>\n",
" </tr>\n",
" <tr>\n",
" <td>100</td>\n",
" <td>0.22142857142857142857</td>\n",
" <td>0.0731925054711399884549944979733273803475</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(2, Decimal('0.19642857142857142857'), Decimal('0.0818317088384971429780598253843971801653')),\n",
" (10, Decimal('0.22142857142857142857'), Decimal('0.0731925054711399884549944979733273803475')),\n",
" (40, Decimal('0.22142857142857142857'), Decimal('0.0731925054711399884549944979733273803475')),\n",
" (100, Decimal('0.22142857142857142857'), Decimal('0.0731925054711399884549944979733273803475'))]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"-- Mean and standard deviation of the per-fold misclassification rate for each max_iter.\n",
"-- ORDER BY makes the row order deterministic (a plain SELECT guarantees none, e.g. on Greenplum).\n",
"SELECT * FROM houses_logregr_cv ORDER BY max_iter;"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4. Decision tree\n",
"\n",
"Here we use the general function to explore tree depth (`max_depth`). First we need a wrapper around `madlib.tree_predict` that renames its output column (in the call below, `estimated_high_priced` is renamed to `prediction`):"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n"
]
},
{
"data": {
"text/plain": [
"[]"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"-- Prediction wrapper for cross_validation_general. It runs madlib.tree_predict into\n",
"-- output_table, then renames column orig_column to new_column (presumably so the\n",
"-- metric step finds the column name it expects; see predict_params in the next cell).\n",
"CREATE OR REPLACE FUNCTION tree_predict_rename_col(model_table VARCHAR, data_table VARCHAR, output_table VARCHAR,\n",
"                                                   orig_column VARCHAR, new_column VARCHAR)\n",
"RETURNS VOID AS $$\n",
"BEGIN\n",
"    -- Call tree_predict directly, passing the arguments as values rather than splicing\n",
"    -- them into SQL text, so names containing quotes cannot break (or inject into) the call.\n",
"    PERFORM madlib.tree_predict(model_table, data_table, output_table);\n",
"\n",
"    -- output_table is spliced in with %s because it may be schema-qualified;\n",
"    -- the column names are quoted as identifiers with %I.\n",
"    EXECUTE format('ALTER TABLE %s RENAME COLUMN %I TO %I', output_table, orig_column, new_column);\n",
"END\n",
"$$ LANGUAGE plpgsql VOLATILE;"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Done.\n",
"1 rows affected.\n",
"4 rows affected.\n"
]
},
{
"data": {
"text/html": [
"<table>\n",
" <tr>\n",
" <th>max_depth</th>\n",
" <th>error_rate_avg</th>\n",
" <th>error_rate_stddev</th>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.16785714285714285714</td>\n",
" <td>0.1208494593977759334440468761527971235643</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.16785714285714285714</td>\n",
" <td>0.1208494593977759334440468761527971235643</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3</td>\n",
" <td>0.17142857142857142857</td>\n",
" <td>0.1564921592871903181329101774752513216155</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4</td>\n",
" <td>0.14285714285714285714</td>\n",
" <td>0.1428571428571428571449999999999999999999</td>\n",
" </tr>\n",
"</table>"
],
"text/plain": [
"[(1, Decimal('0.16785714285714285714'), Decimal('0.1208494593977759334440468761527971235643')),\n",
" (2, Decimal('0.16785714285714285714'), Decimal('0.1208494593977759334440468761527971235643')),\n",
" (3, Decimal('0.17142857142857142857'), Decimal('0.1564921592871903181329101774752513216155')),\n",
" (4, Decimal('0.14285714285714285714'), Decimal('0.1428571428571428571449999999999999999999'))]"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%%sql\n",
"-- Cross validate a decision tree over max_depth, using the tree_predict_rename_col\n",
"-- wrapper (created in the previous cell) as the prediction function.\n",
"DROP TABLE IF EXISTS houses_dt_cv;\n",
"\n",
"SELECT madlib.cross_validation_general(\n",
" -- modelling_func: decision tree training function\n",
" 'madlib.tree_train',\n",
" \n",
" -- modelling_params: tree_train arguments in order (source, output, id column, dependent variable,\n",
" -- features, excluded features, split criterion, grouping cols, weights, max_depth, min_split,\n",
" -- min_bucket, num_splits); the bare token max_depth is replaced by each explore value\n",
" '{%data%, %model%, id, high_priced, \"bedroom, bath, size, zipcode\", NULL, NULL, NULL, NULL, max_depth, 1, 1, 10}'::varchar[],\n",
" \n",
" -- modelling_params_type: one SQL type per modelling_params entry (13 entries),\n",
" -- cast to varchar[] like every other array argument in this notebook\n",
" '{varchar, varchar, varchar, varchar, varchar, varchar, varchar, varchar, varchar, integer, integer, integer, integer}'::varchar[],\n",
" \n",
" -- param_explored: placeholder token in modelling_params to vary\n",
" 'max_depth',\n",
" \n",
" -- explore_values: candidate tree depths\n",
" '{1, 2, 3, 4}'::varchar[],\n",
" \n",
" -- predict_func: wrapper that runs madlib.tree_predict and renames its output column\n",
" 'tree_predict_rename_col',\n",
" \n",
" -- predict_params: model, data, output table, then rename estimated_high_priced to prediction\n",
" '{%model%, %data%, %prediction%, estimated_high_priced, prediction}'::varchar[],\n",
" \n",
" -- predict_params_type\n",
" '{varchar,varchar,varchar,varchar,varchar}'::varchar[],\n",
" \n",
" -- metric_func: average misclassification rate (reported as error_rate_avg / error_rate_stddev)\n",
" 'madlib.misclassification_avg',\n",
" \n",
" -- metric_params: predictions compared with the high_priced label, rows matched on id\n",
" '{%prediction%, %data%, id, high_priced, %error%}'::varchar[],\n",
" \n",
" -- metric_params_type\n",
" '{varchar, varchar, varchar, varchar, varchar}'::varchar[],\n",
" \n",
" -- data_tbl: input table\n",
" 'houses',\n",
" \n",
" -- data_id: unique row id column of data_tbl\n",
" 'id',\n",
" \n",
" -- id_is_random: FALSE, ids 1..36 are sequential rather than randomly assigned\n",
" FALSE,\n",
" \n",
" -- validation_result: output table, one row per explored max_depth value\n",
" 'houses_dt_cv',\n",
" \n",
" -- data_cols: NULL = all columns of data_tbl\n",
" NULL,\n",
" \n",
" -- fold_num: number of cross validation folds\n",
" 5\n",
");\n",
"\n",
"-- ORDER BY makes the row order deterministic (a plain SELECT guarantees none, e.g. on Greenplum).\n",
"SELECT * FROM houses_dt_cv ORDER BY max_depth;"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 2",
"language": "python",
"name": "python2"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.12"
}
},
"nbformat": 4,
"nbformat_minor": 1
}