<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"/> <meta content="IE=edge" http-equiv="X-UA-Compatible"/> <meta content="width=device-width, initial-scale=1" name="viewport"/> <meta content="Matrix Factorization" property="og:title"> <meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image"> <meta content="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/og-logo.png" property="og:image:secure_url"> <meta content="Matrix Factorization" property="og:description"/> <title>Matrix Factorization — mxnet documentation</title> <link crossorigin="anonymous" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/css/bootstrap.min.css" integrity="sha384-1q8mTJOASx8j1Au+a5WDVnPi2lkFfwwEAa8hDDdjZlpLegxhjVME1fgjWPGmkzs7" rel="stylesheet"/> <link href="https://maxcdn.bootstrapcdn.com/font-awesome/4.5.0/css/font-awesome.min.css" rel="stylesheet"/> <link href="../../_static/basic.css" rel="stylesheet" type="text/css"> <link href="../../_static/pygments.css" rel="stylesheet" type="text/css"> <link href="../../_static/mxnet.css" rel="stylesheet" type="text/css"/> <script type="text/javascript"> var DOCUMENTATION_OPTIONS = { URL_ROOT: '../../', VERSION: '', COLLAPSE_INDEX: false, FILE_SUFFIX: '.html', HAS_SOURCE: true, SOURCELINK_SUFFIX: '.txt' }; </script> <script src="https://code.jquery.com/jquery-1.11.1.min.js" type="text/javascript"></script> <script src="../../_static/underscore.js" type="text/javascript"></script> <script src="../../_static/searchtools_custom.js" type="text/javascript"></script> <script src="../../_static/doctools.js" type="text/javascript"></script> <script src="../../_static/selectlang.js" type="text/javascript"></script> <script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML" type="text/javascript"></script> <script type="text/javascript"> jQuery(function() { Search.loadIndex("/versions/0.11.0/searchindex.js"); 
Search.init();}); </script> <!-- --> <!-- <script type="text/javascript" src="../../_static/jquery.js"></script> --> <!-- --> <!-- <script type="text/javascript" src="../../_static/underscore.js"></script> --> <!-- --> <!-- <script type="text/javascript" src="../../_static/doctools.js"></script> --> <!-- --> <!-- <script type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.0/MathJax.js?config=TeX-AMS-MML_HTMLorMML"></script> --> <!-- --> <link href="../../genindex.html" rel="index" title="Index"> <link href="../../search.html" rel="search" title="Search"/> <link href="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-icon.png" rel="icon" type="image/png"/> </link></link></link></meta></meta></meta></head> <body background="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet-background-compressed.jpeg" role="document"> <div class="content-block"><div class="navbar navbar-fixed-top"> <div class="container" id="navContainer"> <div class="innder" id="header-inner"> <h1 id="logo-wrap"> <a href="../../" id="logo"><img src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/mxnet_logo.png"/></a> </h1> <nav class="nav-bar" id="main-nav"> <a class="main-nav-link" href="/versions/0.11.0/get_started/install.html">Install</a> <span id="dropdown-menu-position-anchor"> <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Gluon <span class="caret"></span></a> <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu"> <li><a class="main-nav-link" href="/versions/0.11.0/tutorials/gluon/gluon.html">About</a></li> <li><a class="main-nav-link" href="https://www.d2l.ai/">Dive into Deep Learning</a></li> <li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li> <li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li> </ul> </span> <span 
id="dropdown-menu-position-anchor"> <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">API <span class="caret"></span></a> <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu"> <li><a class="main-nav-link" href="/versions/0.11.0/api/python/index.html">Python</a></li> <li><a class="main-nav-link" href="/versions/0.11.0/api/c++/index.html">C++</a></li> <li><a class="main-nav-link" href="/versions/0.11.0/api/julia/index.html">Julia</a></li> <li><a class="main-nav-link" href="/versions/0.11.0/api/perl/index.html">Perl</a></li> <li><a class="main-nav-link" href="/versions/0.11.0/api/r/index.html">R</a></li> <li><a class="main-nav-link" href="/versions/0.11.0/api/scala/index.html">Scala</a></li> </ul> </span> <span id="dropdown-menu-position-anchor-docs"> <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Docs <span class="caret"></span></a> <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-docs"> <li><a class="main-nav-link" href="/versions/0.11.0/how_to/faq.html">FAQ</a></li> <li><a class="main-nav-link" href="/versions/0.11.0/tutorials/index.html">Tutorials</a> <li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/v0.11.0/example">Examples</a></li> <li><a class="main-nav-link" href="/versions/0.11.0/architecture/index.html">Architecture</a></li> <li><a class="main-nav-link" href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home">Developer Wiki</a></li> <li><a class="main-nav-link" href="/versions/0.11.0/model_zoo/index.html">Model Zoo</a></li> <li><a class="main-nav-link" href="https://github.com/onnx/onnx-mxnet">ONNX</a></li> </li></ul> </span> <span id="dropdown-menu-position-anchor-community"> <a aria-expanded="true" aria-haspopup="true" class="main-nav-link dropdown-toggle" data-toggle="dropdown" href="#" role="button">Community <span 
class="caret"></span></a> <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu-community"> <li><a class="main-nav-link" href="http://discuss.mxnet.io">Forum</a></li> <li><a class="main-nav-link" href="https://github.com/apache/incubator-mxnet/tree/v0.11.0">Github</a></li> <li><a class="main-nav-link" href="/versions/0.11.0/community/contribute.html">Contribute</a></li> </ul> </span> <span id="dropdown-menu-position-anchor-version" style="position: relative"><a href="#" class="main-nav-link dropdown-toggle" data-toggle="dropdown" role="button" aria-haspopup="true" aria-expanded="true">0.11.0<span class="caret"></span></a><ul id="package-dropdown-menu" class="dropdown-menu"><li><a href="/">master</a></li><li><a href="/versions/1.7.0/">1.7.0</a></li><li><a href=/versions/1.6.0/>1.6.0</a></li><li><a href=/versions/1.5.0/>1.5.0</a></li><li><a href=/versions/1.4.1/>1.4.1</a></li><li><a href=/versions/1.3.1/>1.3.1</a></li><li><a href=/versions/1.2.1/>1.2.1</a></li><li><a href=/versions/1.1.0/>1.1.0</a></li><li><a href=/versions/1.0.0/>1.0.0</a></li><li><a href=/versions/0.12.1/>0.12.1</a></li><li><a href=/versions/0.11.0/>0.11.0</a></li></ul></span></nav> <script> function getRootPath(){ return "../../" } </script> <div class="burgerIcon dropdown"> <a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button">☰</a> <ul class="dropdown-menu" id="burgerMenu"> <li><a href="/versions/0.11.0/get_started/install.html">Install</a></li> <li><a class="main-nav-link" href="/versions/0.11.0/tutorials/index.html">Tutorials</a></li> <li class="dropdown-submenu dropdown"> <a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Gluon</a> <ul class="dropdown-menu navbar-menu" id="package-dropdown-menu"> <li><a class="main-nav-link" href="/versions/0.11.0/tutorials/gluon/gluon.html">About</a></li> <li><a class="main-nav-link" href="http://gluon.mxnet.io">The Straight Dope (Tutorials)</a></li> 
<li><a class="main-nav-link" href="https://gluon-cv.mxnet.io">GluonCV Toolkit</a></li> <li><a class="main-nav-link" href="https://gluon-nlp.mxnet.io/">GluonNLP Toolkit</a></li> </ul> </li> <li class="dropdown-submenu"> <a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">API</a> <ul class="dropdown-menu"> <li><a class="main-nav-link" href="/versions/0.11.0/api/python/index.html">Python</a></li> <li><a class="main-nav-link" href="/versions/0.11.0/api/c++/index.html">C++</a></li> <li><a class="main-nav-link" href="/versions/0.11.0/api/julia/index.html">Julia</a></li> <li><a class="main-nav-link" href="/versions/0.11.0/api/perl/index.html">Perl</a></li> <li><a class="main-nav-link" href="/versions/0.11.0/api/r/index.html">R</a></li> <li><a class="main-nav-link" href="/versions/0.11.0/api/scala/index.html">Scala</a></li> </ul> </li> <li class="dropdown-submenu"> <a aria-expanded="true" aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" tabindex="-1">Docs</a> <ul class="dropdown-menu"> <li><a href="/versions/0.11.0/how_to/faq.html" tabindex="-1">FAQ</a></li> <li><a href="/versions/0.11.0/tutorials/index.html" tabindex="-1">Tutorials</a></li> <li><a href="https://github.com/apache/incubator-mxnet/tree/v0.11.0/example" tabindex="-1">Examples</a></li> <li><a href="/versions/0.11.0/architecture/index.html" tabindex="-1">Architecture</a></li> <li><a href="https://cwiki.apache.org/confluence/display/MXNET/Apache+MXNet+Home" tabindex="-1">Developer Wiki</a></li> <li><a href="/versions/0.11.0/model_zoo/index.html" tabindex="-1">Gluon Model Zoo</a></li> <li><a href="https://github.com/onnx/onnx-mxnet" tabindex="-1">ONNX</a></li> </ul> </li> <li class="dropdown-submenu dropdown"> <a aria-haspopup="true" class="dropdown-toggle burger-link" data-toggle="dropdown" href="#" role="button" tabindex="-1">Community</a> <ul class="dropdown-menu"> <li><a 
href="http://discuss.mxnet.io" tabindex="-1">Forum</a></li> <li><a href="https://github.com/apache/incubator-mxnet/tree/v0.11.0" tabindex="-1">Github</a></li> <li><a href="/versions/0.11.0/community/contribute.html" tabindex="-1">Contribute</a></li> </ul> </li> <li id="dropdown-menu-position-anchor-version-mobile" class="dropdown-submenu" style="position: relative"><a href="#" tabindex="-1">0.11.0</a><ul class="dropdown-menu"><li><a tabindex="-1" href=/>master</a></li><li><a tabindex="-1" href=/versions/1.6.0/>1.6.0</a></li><li><a tabindex="-1" href=/versions/1.5.0/>1.5.0</a></li><li><a tabindex="-1" href=/versions/1.4.1/>1.4.1</a></li><li><a tabindex="-1" href=/versions/1.3.1/>1.3.1</a></li><li><a tabindex="-1" href=/versions/1.2.1/>1.2.1</a></li><li><a tabindex="-1" href=/versions/1.1.0/>1.1.0</a></li><li><a tabindex="-1" href=/versions/1.0.0/>1.0.0</a></li><li><a tabindex="-1" href=/versions/0.12.1/>0.12.1</a></li><li><a tabindex="-1" href=/versions/0.11.0/>0.11.0</a></li></ul></li></ul> </div> <div class="plusIcon dropdown"> <a class="dropdown-toggle" data-toggle="dropdown" href="#" role="button"><span aria-hidden="true" class="glyphicon glyphicon-plus"></span></a> <ul class="dropdown-menu dropdown-menu-right" id="plusMenu"></ul> </div> <div id="search-input-wrap"> <form action="../../search.html" autocomplete="off" class="" method="get" role="search"> <div class="form-group inner-addon left-addon"> <i class="glyphicon glyphicon-search"></i> <input class="form-control" name="q" placeholder="Search" type="text"/> </div> <input name="check_keywords" type="hidden" value="yes"> <input name="area" type="hidden" value="default"/> </input></form> <div id="search-preview"></div> </div> <div id="searchIcon"> <span aria-hidden="true" class="glyphicon glyphicon-search"></span> </div> <!-- <div id="lang-select-wrap"> --> <!-- <label id="lang-select-label"> --> <!-- <\!-- <i class="fa fa-globe"></i> -\-> --> <!-- <span></span> --> <!-- </label> --> <!-- <select 
id="lang-select"> --> <!-- <option value="en">Eng</option> --> <!-- <option value="zh">中文</option> --> <!-- </select> --> <!-- </div> --> <!-- <a id="mobile-nav-toggle"> <span class="mobile-nav-toggle-bar"></span> <span class="mobile-nav-toggle-bar"></span> <span class="mobile-nav-toggle-bar"></span> </a> --> </div> </div> </div> <script type="text/javascript"> $('body').css('background', 'white'); </script> <div class="container"> <div class="row"> <div aria-label="main navigation" class="sphinxsidebar leftsidebar" role="navigation"> <div class="sphinxsidebarwrapper"> <ul> <li class="toctree-l1"><a class="reference internal" href="../../api/python/index.html">Python Documents</a></li> <li class="toctree-l1"><a class="reference internal" href="../../api/r/index.html">R Documents</a></li> <li class="toctree-l1"><a class="reference internal" href="../../api/julia/index.html">Julia Documents</a></li> <li class="toctree-l1"><a class="reference internal" href="../../api/c++/index.html">C++ Documents</a></li> <li class="toctree-l1"><a class="reference internal" href="../../api/scala/index.html">Scala Documents</a></li> <li class="toctree-l1"><a class="reference internal" href="../../api/perl/index.html">Perl Documents</a></li> <li class="toctree-l1"><a class="reference internal" href="../../how_to/index.html">HowTo Documents</a></li> <li class="toctree-l1"><a class="reference internal" href="../../architecture/index.html">System Documents</a></li> <li class="toctree-l1"><a class="reference internal" href="../index.html">Tutorials</a></li> </ul> </div> </div> <div class="content"> <div class="page-tracker"></div> <div class="section" id="matrix-factorization"> <span id="matrix-factorization"></span><h1>Matrix Factorization<a class="headerlink" href="#matrix-factorization" title="Permalink to this headline">¶</a></h1> <p>In a recommendation system, there is a group of users and a set of items. 
Given that each user has rated some items in the system, we would like to predict how users would rate the items they have not yet rated, so that we can make recommendations to them.</p> <p>Matrix factorization is one of the most widely used algorithms in recommendation systems. It discovers latent features underlying the interactions between two different kinds of entities, here users and items.</p> <p>Assume we assign a k-dimensional vector to each user and a k-dimensional vector to each item such that the dot product of these two vectors approximates the user’s rating of that item. Learning the user and item vectors directly amounts to a low-rank, SVD-like factorization of the user-item rating matrix. We can also try to learn the latent features using multi-layer neural networks.</p> <p>In this tutorial, we will work through the steps to implement these ideas in MXNet.</p> <div class="section" id="prepare-data"> <span id="prepare-data"></span><h2>Prepare Data<a class="headerlink" href="#prepare-data" title="Permalink to this headline">¶</a></h2> <p>We use the <a class="reference external" href="http://grouplens.org/datasets/movielens/">MovieLens</a> data here, but the approach applies to other datasets as well. Each row of this dataset contains a user ID, a movie ID, a rating, and a timestamp; we only use the first three fields. We first define a batch, which holds n such tuples. 
It also provides name and shape information to MXNet about the data and label.</p> <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Batch</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data_names</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">label_names</span><span class="p">,</span> <span class="n">label</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">data</span> <span class="bp">self</span><span class="o">.</span><span class="n">label</span> <span class="o">=</span> <span class="n">label</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_names</span> <span class="o">=</span> <span class="n">data_names</span> <span class="bp">self</span><span class="o">.</span><span class="n">label_names</span> <span class="o">=</span> <span class="n">label_names</span> <span class="nd">@property</span> <span class="k">def</span> <span class="nf">provide_data</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="p">[(</span><span class="n">n</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data_names</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">)]</span> <span 
class="nd">@property</span> <span class="k">def</span> <span class="nf">provide_label</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="p">[(</span><span class="n">n</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">label_names</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">label</span><span class="p">)]</span> </pre></div> </div> <p>Then we define a data iterator, which returns a batch of tuples each time.</p> <div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">mxnet</span> <span class="kn">as</span> <span class="nn">mx</span> <span class="kn">import</span> <span class="nn">random</span> <span class="k">class</span> <span class="nc">Batch</span><span class="p">(</span><span class="nb">object</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data_names</span><span class="p">,</span> <span class="n">data</span><span class="p">,</span> <span class="n">label_names</span><span class="p">,</span> <span class="n">label</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">data</span> <span class="bp">self</span><span class="o">.</span><span class="n">label</span> <span class="o">=</span> <span class="n">label</span> <span class="bp">self</span><span class="o">.</span><span class="n">data_names</span> <span class="o">=</span> 
<span class="n">data_names</span> <span class="bp">self</span><span class="o">.</span><span class="n">label_names</span> <span class="o">=</span> <span class="n">label_names</span> <span class="nd">@property</span> <span class="k">def</span> <span class="nf">provide_data</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="p">[(</span><span class="n">n</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data_names</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">)]</span> <span class="nd">@property</span> <span class="k">def</span> <span class="nf">provide_label</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="p">[(</span><span class="n">n</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="k">for</span> <span class="n">n</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">label_names</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">label</span><span class="p">)]</span> <span class="k">class</span> <span class="nc">DataIter</span><span class="p">(</span><span class="n">mx</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">DataIter</span><span class="p">):</span> <span class="k">def</span> 
<span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">fname</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">):</span> <span class="nb">super</span><span class="p">(</span><span class="n">DataIter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">line</span> <span class="ow">in</span> <span class="nb">file</span><span class="p">(</span><span class="n">fname</span><span class="p">):</span> <span class="n">tks</span> <span class="o">=</span> <span class="n">line</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">'</span><span class="se">\t</span><span class="s1">'</span><span class="p">)</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">tks</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">4</span><span class="p">:</span> <span class="k">continue</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="nb">int</span><span class="p">(</span><span class="n">tks</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="nb">int</span><span class="p">(</span><span class="n">tks</span><span class="p">[</span><span class="mi">1</span><span class="p">]),</span> <span class="nb">float</span><span 
class="p">(</span><span class="n">tks</span><span class="p">[</span><span class="mi">2</span><span class="p">])))</span> <span class="bp">self</span><span class="o">.</span><span class="n">provide_data</span> <span class="o">=</span> <span class="p">[(</span><span class="s1">'user'</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">)),</span> <span class="p">(</span><span class="s1">'item'</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">))]</span> <span class="bp">self</span><span class="o">.</span><span class="n">provide_label</span> <span class="o">=</span> <span class="p">[(</span><span class="s1">'score'</span><span class="p">,</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">))]</span> <span class="k">def</span> <span class="fm">__iter__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span> <span class="n">users</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">items</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">scores</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span 
class="p">):</span> <span class="n">j</span> <span class="o">=</span> <span class="n">k</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">+</span> <span class="n">i</span> <span class="n">user</span><span class="p">,</span> <span class="n">item</span><span class="p">,</span> <span class="n">score</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="n">users</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">user</span><span class="p">)</span> <span class="n">items</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">item</span><span class="p">)</span> <span class="n">scores</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">score</span><span class="p">)</span> <span class="n">data_all</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">users</span><span class="p">),</span> <span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">items</span><span class="p">)]</span> <span class="n">label_all</span> <span class="o">=</span> <span class="p">[</span><span class="n">mx</span><span class="o">.</span><span class="n">nd</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scores</span><span class="p">)]</span> <span class="n">data_names</span> <span class="o">=</span> <span class="p">[</span><span class="s1">'user'</span><span class="p">,</span> <span class="s1">'item'</span><span 
class="p">]</span> <span class="n">label_names</span> <span class="o">=</span> <span class="p">[</span><span class="s1">'score'</span><span class="p">]</span> <span class="n">data_batch</span> <span class="o">=</span> <span class="n">Batch</span><span class="p">(</span><span class="n">data_names</span><span class="p">,</span> <span class="n">data_all</span><span class="p">,</span> <span class="n">label_names</span><span class="p">,</span> <span class="n">label_all</span><span class="p">)</span> <span class="k">yield</span> <span class="n">data_batch</span> <span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">)</span> </pre></div> </div> <p>Now we download the data and provide a function to obtain the data iterator:</p> <div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span> <span class="kn">import</span> <span class="nn">urllib</span> <span class="kn">import</span> <span class="nn">zipfile</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="s1">'ml-100k.zip'</span><span class="p">):</span> <span class="n">urllib</span><span class="o">.</span><span class="n">urlretrieve</span><span class="p">(</span><span class="s1">'http://files.grouplens.org/datasets/movielens/ml-100k.zip'</span><span class="p">,</span> <span class="s1">'ml-100k.zip'</span><span class="p">)</span> <span class="k">with</span> <span class="n">zipfile</span><span class="o">.</span><span class="n">ZipFile</span><span class="p">(</span><span class="s2">"ml-100k.zip"</span><span 
class="p">,</span><span class="s2">"r"</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span> <span class="n">f</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="s2">"./"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">get_data</span><span class="p">(</span><span class="n">batch_size</span><span class="p">):</span> <span class="k">return</span> <span class="p">(</span><span class="n">DataIter</span><span class="p">(</span><span class="s1">'./ml-100k/u1.base'</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">),</span> <span class="n">DataIter</span><span class="p">(</span><span class="s1">'./ml-100k/u1.test'</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">))</span> </pre></div> </div> <p>Finally we calculate the numbers of users and items for later use.</p> <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">max_id</span><span class="p">(</span><span class="n">fname</span><span class="p">):</span> <span class="n">mu</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">mi</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">line</span> <span class="ow">in</span> <span class="nb">file</span><span class="p">(</span><span class="n">fname</span><span class="p">):</span> <span class="n">tks</span> <span class="o">=</span> <span class="n">line</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s1">'</span><span class="se">\t</span><span class="s1">'</span><span class="p">)</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">tks</span><span class="p">)</span> <span class="o">!=</span> <span 
class="mi">4</span><span class="p">:</span> <span class="k">continue</span> <span class="n">mu</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">mu</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="n">tks</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span> <span class="n">mi</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">mi</span><span class="p">,</span> <span class="nb">int</span><span class="p">(</span><span class="n">tks</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span> <span class="k">return</span> <span class="n">mu</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">mi</span> <span class="o">+</span> <span class="mi">1</span> <span class="n">max_user</span><span class="p">,</span> <span class="n">max_item</span> <span class="o">=</span> <span class="n">max_id</span><span class="p">(</span><span class="s1">'./ml-100k/u.data'</span><span class="p">)</span> <span class="p">(</span><span class="n">max_user</span><span class="p">,</span> <span class="n">max_item</span><span class="p">)</span> </pre></div> </div> </div> <div class="section" id="optimization"> <span id="optimization"></span><h2>Optimization<a class="headerlink" href="#optimization" title="Permalink to this headline">¶</a></h2> <p>We first implement the RMSE (root-mean-square error) measurement, which is commonly used by matrix factorization.</p> <div class="highlight-python"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">math</span> <span class="k">def</span> <span class="nf">RMSE</span><span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">pred</span><span class="p">):</span> <span class="n">ret</span> <span class="o">=</span> <span class="mf">0.0</span> <span 
class="n">n</span> <span class="o">=</span> <span class="mf">0.0</span>
    <span class="n">pred</span> <span class="o">=</span> <span class="n">pred</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">label</span><span class="p">)):</span>
        <span class="n">ret</span> <span class="o">+=</span> <span class="p">(</span><span class="n">label</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="n">pred</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="o">*</span> <span class="p">(</span><span class="n">label</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">-</span> <span class="n">pred</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
        <span class="n">n</span> <span class="o">+=</span> <span class="mf">1.0</span>
    <span class="k">return</span> <span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">ret</span> <span class="o">/</span> <span class="n">n</span><span class="p">)</span>
</pre></div> </div> <p>Then we define a general training function, adapted from the image classification example. It fits the network with SGD (momentum 0.9, weight decay 0.0001) and evaluates RMSE on the test set.</p> <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="n">network</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">num_epoch</span><span class="p">,</span> <span class="n">learning_rate</span><span class="p">):</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span
class="n">model</span><span class="o">.</span><span class="n">FeedForward</span><span class="p">(</span>
        <span class="n">ctx</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">gpu</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span>  <span class="c1"># use mx.cpu() if no GPU is available</span>
        <span class="n">symbol</span> <span class="o">=</span> <span class="n">network</span><span class="p">,</span>
        <span class="n">num_epoch</span> <span class="o">=</span> <span class="n">num_epoch</span><span class="p">,</span>
        <span class="n">learning_rate</span> <span class="o">=</span> <span class="n">learning_rate</span><span class="p">,</span>
        <span class="n">wd</span> <span class="o">=</span> <span class="mf">0.0001</span><span class="p">,</span>
        <span class="n">momentum</span> <span class="o">=</span> <span class="mf">0.9</span><span class="p">)</span>

    <span class="n">train</span><span class="p">,</span> <span class="n">test</span> <span class="o">=</span> <span class="n">get_data</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span>

    <span class="kn">import</span> <span class="nn">logging</span>
    <span class="n">head</span> <span class="o">=</span> <span class="s1">'</span><span class="si">%(asctime)-15s</span><span class="s1"> </span><span class="si">%(message)s</span><span class="s1">'</span>
    <span class="n">logging</span><span class="o">.</span><span class="n">basicConfig</span><span class="p">(</span><span class="n">level</span><span class="o">=</span><span class="n">logging</span><span class="o">.</span><span class="n">DEBUG</span><span class="p">,</span> <span class="n">format</span><span class="o">=</span><span class="n">head</span><span class="p">)</span>

    <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X</span> <span class="o">=</span> <span class="n">train</span><span class="p">,</span>
              <span class="n">eval_data</span> <span class="o">=</span> <span
class="n">test</span><span class="p">,</span>
              <span class="n">eval_metric</span> <span class="o">=</span> <span class="n">RMSE</span><span class="p">,</span>
              <span class="n">batch_end_callback</span><span class="o">=</span><span class="n">mx</span><span class="o">.</span><span class="n">callback</span><span class="o">.</span><span class="n">Speedometer</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">20000</span><span class="o">//</span><span class="n">batch_size</span><span class="p">))</span>
</pre></div> </div> </div> <div class="section" id="networks"> <span id="networks"></span><h2>Networks<a class="headerlink" href="#networks" title="Permalink to this headline">¶</a></h2> <p>Now we try several networks. The first learns the user and item latent vectors directly as embeddings and predicts a score by their inner product.</p> <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">plain_net</span><span class="p">(</span><span class="n">k</span><span class="p">):</span>
    <span class="c1"># input</span>
    <span class="n">user</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'user'</span><span class="p">)</span>
    <span class="n">item</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'item'</span><span class="p">)</span>
    <span class="n">score</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'score'</span><span class="p">)</span>
    <span class="c1"># user feature lookup</span>
    <span class="n">user</span> <span class="o">=</span> <span class="n">mx</span><span
class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">user</span><span class="p">,</span> <span class="n">input_dim</span> <span class="o">=</span> <span class="n">max_user</span><span class="p">,</span> <span class="n">output_dim</span> <span class="o">=</span> <span class="n">k</span><span class="p">)</span>
    <span class="c1"># item feature lookup</span>
    <span class="n">item</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">item</span><span class="p">,</span> <span class="n">input_dim</span> <span class="o">=</span> <span class="n">max_item</span><span class="p">,</span> <span class="n">output_dim</span> <span class="o">=</span> <span class="n">k</span><span class="p">)</span>
    <span class="c1"># predict by the inner product: an elementwise product followed by a sum</span>
    <span class="n">pred</span> <span class="o">=</span> <span class="n">user</span> <span class="o">*</span> <span class="n">item</span>
    <span class="n">pred</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">sum_axis</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">pred</span><span class="p">,</span> <span class="n">axis</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">pred</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Flatten</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">pred</span><span
class="p">)</span>
    <span class="c1"># loss layer</span>
    <span class="n">pred</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">LinearRegressionOutput</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">pred</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="n">score</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">pred</span>

<span class="n">train</span><span class="p">(</span><span class="n">plain_net</span><span class="p">(</span><span class="mi">64</span><span class="p">),</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">num_epoch</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=.</span><span class="mo">05</span><span class="p">)</span>
</pre></div> </div> <p>Next, we use a two-layer neural network to learn the latent vectors: each embedding passes through a ReLU activation and then a fully connected layer, and the score is again predicted by the inner product of the resulting user and item vectors.</p> <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">get_one_layer_mlp</span><span class="p">(</span><span class="n">hidden</span><span class="p">,</span> <span class="n">k</span><span class="p">):</span>
    <span class="c1"># input</span>
    <span class="n">user</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'user'</span><span class="p">)</span>
    <span class="n">item</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span
class="s1">'item'</span><span class="p">)</span>
    <span class="n">score</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'score'</span><span class="p">)</span>
    <span class="c1"># user latent features</span>
    <span class="n">user</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">user</span><span class="p">,</span> <span class="n">input_dim</span> <span class="o">=</span> <span class="n">max_user</span><span class="p">,</span> <span class="n">output_dim</span> <span class="o">=</span> <span class="n">k</span><span class="p">)</span>
    <span class="n">user</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">user</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)</span>
    <span class="n">user</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">user</span><span class="p">,</span> <span class="n">num_hidden</span> <span class="o">=</span> <span class="n">hidden</span><span class="p">)</span>
    <span class="c1"># item latent features</span>
    <span class="n">item</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span
class="n">Embedding</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">item</span><span class="p">,</span> <span class="n">input_dim</span> <span class="o">=</span> <span class="n">max_item</span><span class="p">,</span> <span class="n">output_dim</span> <span class="o">=</span> <span class="n">k</span><span class="p">)</span>
    <span class="n">item</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">item</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)</span>
    <span class="n">item</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">item</span><span class="p">,</span> <span class="n">num_hidden</span> <span class="o">=</span> <span class="n">hidden</span><span class="p">)</span>
    <span class="c1"># predict by the inner product</span>
    <span class="n">pred</span> <span class="o">=</span> <span class="n">user</span> <span class="o">*</span> <span class="n">item</span>
    <span class="n">pred</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">sum_axis</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">pred</span><span class="p">,</span> <span class="n">axis</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">pred</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span
class="o">.</span><span class="n">Flatten</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">pred</span><span class="p">)</span>
    <span class="c1"># loss layer</span>
    <span class="n">pred</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">LinearRegressionOutput</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">pred</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="n">score</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">pred</span>

<span class="n">train</span><span class="p">(</span><span class="n">get_one_layer_mlp</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">),</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">num_epoch</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=.</span><span class="mo">05</span><span class="p">)</span>
</pre></div> </div> <p>To reduce over-fitting, we add a dropout layer (drop probability 0.5) after each fully connected layer and train a larger network with 512-dimensional embeddings and 256 hidden units.</p> <div class="highlight-python"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">get_one_layer_dropout_mlp</span><span class="p">(</span><span class="n">hidden</span><span class="p">,</span> <span class="n">k</span><span class="p">):</span>
    <span class="c1"># input</span>
    <span class="n">user</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'user'</span><span class="p">)</span>
    <span class="n">item</span> <span class="o">=</span> <span class="n">mx</span><span
class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'item'</span><span class="p">)</span>
    <span class="n">score</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="s1">'score'</span><span class="p">)</span>
    <span class="c1"># user latent features</span>
    <span class="n">user</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">user</span><span class="p">,</span> <span class="n">input_dim</span> <span class="o">=</span> <span class="n">max_user</span><span class="p">,</span> <span class="n">output_dim</span> <span class="o">=</span> <span class="n">k</span><span class="p">)</span>
    <span class="n">user</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">user</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)</span>
    <span class="n">user</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">user</span><span class="p">,</span> <span class="n">num_hidden</span> <span class="o">=</span> <span class="n">hidden</span><span class="p">)</span>
    <span class="n">user</span> <span class="o">=</span> <span class="n">mx</span><span
class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="n">user</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
    <span class="c1"># item latent features</span>
    <span class="n">item</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">item</span><span class="p">,</span> <span class="n">input_dim</span> <span class="o">=</span> <span class="n">max_item</span><span class="p">,</span> <span class="n">output_dim</span> <span class="o">=</span> <span class="n">k</span><span class="p">)</span>
    <span class="n">item</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">item</span><span class="p">,</span> <span class="n">act_type</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)</span>
    <span class="n">item</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">FullyConnected</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">item</span><span class="p">,</span> <span class="n">num_hidden</span> <span class="o">=</span> <span class="n">hidden</span><span class="p">)</span>
    <span class="n">item</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span
class="n">data</span><span class="o">=</span><span class="n">item</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
    <span class="c1"># predict by the inner product</span>
    <span class="n">pred</span> <span class="o">=</span> <span class="n">user</span> <span class="o">*</span> <span class="n">item</span>
    <span class="n">pred</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">sum_axis</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">pred</span><span class="p">,</span> <span class="n">axis</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">pred</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">Flatten</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">pred</span><span class="p">)</span>
    <span class="c1"># loss layer</span>
    <span class="n">pred</span> <span class="o">=</span> <span class="n">mx</span><span class="o">.</span><span class="n">symbol</span><span class="o">.</span><span class="n">LinearRegressionOutput</span><span class="p">(</span><span class="n">data</span> <span class="o">=</span> <span class="n">pred</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="n">score</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">pred</span>

<span class="n">train</span><span class="p">(</span><span class="n">get_one_layer_dropout_mlp</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span
class="n">num_epoch</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=.</span><span class="mo">05</span><span class="p">)</span> </pre></div> </div> <div class="btn-group" role="group"> <div class="download-btn"><a download="matrix_factorization.ipynb" href="matrix_factorization.ipynb"><span class="glyphicon glyphicon-download-alt"></span> matrix_factorization.ipynb</a></div></div></div> </div> </div> </div> <div aria-label="main navigation" class="sphinxsidebar rightsidebar" role="navigation"> <div class="sphinxsidebarwrapper"> <h3><a href="../../index.html">Table Of Contents</a></h3> <ul> <li><a class="reference internal" href="#">Matrix Factorization</a><ul> <li><a class="reference internal" href="#prepare-data">Prepare Data</a></li> <li><a class="reference internal" href="#optimization">Optimization</a></li> <li><a class="reference internal" href="#networks">Networks</a></li> </ul> </li> </ul> </div> </div> </div><div class="footer"> <div class="section-disclaimer"> <div class="container"> <div> <img height="60" src="https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/image/apache_incubator_logo.png"/> <p> Apache MXNet is an effort undergoing incubation at The Apache Software Foundation (ASF), <strong>sponsored by the <i>Apache Incubator</i></strong>. Incubation is required of all newly accepted projects until a further review indicates that the infrastructure, communications, and decision making process have stabilized in a manner consistent with other successful ASF projects. While incubation status is not necessarily a reflection of the completeness or stability of the code, it does indicate that the project has yet to be fully endorsed by the ASF. 
</p> <p> "Copyright © 2017-2018, The Apache Software Foundation Apache MXNet, MXNet, Apache, the Apache feather, and the Apache MXNet project logo are either registered trademarks or trademarks of the Apache Software Foundation." </p> </div> </div> </div> </div> <!-- pagename != index --> </div> <script crossorigin="anonymous" integrity="sha384-0mSbJDEHialfmuBBQP6A4Qrprq5OVfW37PRR3j5ELqxss1yVqOtnepnHVP9aJ7xS" src="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.6/js/bootstrap.min.js"></script> <script src="../../_static/js/sidebar.js" type="text/javascript"></script> <script src="../../_static/js/search.js" type="text/javascript"></script> <script src="../../_static/js/navbar.js" type="text/javascript"></script> <script src="../../_static/js/clipboard.min.js" type="text/javascript"></script> <script src="../../_static/js/copycode.js" type="text/javascript"></script> <script src="../../_static/js/page.js" type="text/javascript"></script> <script src="../../_static/js/docversion.js" type="text/javascript"></script> <script type="text/javascript"> $('body').ready(function () { $('body').css('visibility', 'visible'); }); </script> </body> </html>