session1/deep-learning.ipynb (987 lines of code) (raw):

{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "542c9b3d-4661-4172-8ffe-67cd9a27e0d4",
   "metadata": {},
   "source": [
    "# Utils & Setup\n",
    "- all taken from micrograd (https://github.com/karpathy/micrograd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5f23a85d-0b52-4eac-ae5c-30526ed217c7",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "27d0d364-27aa-4ae2-90fb-b44eae300b87",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# you can ignore this cell\n",
    "\n",
    "from graphviz import Digraph\n",
    "\n",
    "def trace(root):\n",
    "    \"\"\"Collect every node reachable from `root` via `_prev`, plus each (child, parent) edge.\n",
    "\n",
    "    Returns a (nodes, edges) pair of sets; edges are (child, parent) tuples.\n",
    "    \"\"\"\n",
    "    # builds a set of all nodes and edges in a graph\n",
    "    nodes, edges = set(), set()\n",
    "    def build(v):\n",
    "        if v not in nodes:\n",
    "            nodes.add(v)\n",
    "            for child in v._prev:\n",
    "                edges.add((child, v))\n",
    "                build(child)\n",
    "    build(root)\n",
    "    return nodes, edges\n",
    "\n",
    "def draw_dot(root, format='svg'):\n",
    "    \"\"\"Build a left-to-right graphviz Digraph of the expression graph ending at `root`.\n",
    "\n",
    "    Each Value becomes a record node showing its label, data and grad; each value\n",
    "    produced by an operation also gets a separate node for that op. `format` is\n",
    "    passed straight to graphviz as the output format.\n",
    "    \"\"\"\n",
    "    dot = Digraph(format=format, graph_attr={'rankdir': 'LR'}) # LR = left to right\n",
    "\n",
    "    nodes, edges = trace(root)\n",
    "    for n in nodes:\n",
    "        uid = str(id(n))\n",
    "        # for any value in the graph, create a rectangular ('record') node for it\n",
    "        dot.node(name = uid, label = \"{ %s | data %.4f | grad %.4f }\" % (n.label, n.data, n.grad), shape='record')\n",
    "        if n._op:\n",
    "            # if this value is a result of some operation, create an op node for it\n",
    "            dot.node(name = uid + n._op, label = n._op)\n",
    "            # and connect this node to it\n",
    "            dot.edge(uid + n._op, uid)\n",
    "\n",
    "    for n1, n2 in edges:\n",
    "        # connect n1 to the op node of n2\n",
    "        dot.edge(str(id(n1)), str(id(n2)) + n2._op)\n",
    "\n",
    "    return dot\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fc3bf60-138a-43da-afbd-774c71e8c53d",
   "metadata": {},
   "source": [
    "# Quick tour of micrograd"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10ef059a-7741-484f-9e5f-06e3743a95f5",
   "metadata": {},
   "source": [
    "## The Value class"
   ]
  },
  {
   "cell_type": "code",
"execution_count": 3,
   "id": "be63a1b8-1c24-4a42-b630-49ff4f53eb65",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Value:\n",
    "    \"\"\"Smaller subset of the full Value from micrograd.\n",
    "\n",
    "    Wraps a scalar `data` and remembers the operand nodes (`_prev`) and the\n",
    "    operation (`_op`) that produced it, so `backward()` can fill in `grad`\n",
    "    (d output / d this node) for every node via the chain rule.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data, _children=(), _op='', label=''):\n",
    "        self.data = data\n",
    "        self.grad = 0.0\n",
    "        self._backward = lambda: None  # replaced by the op that creates this node\n",
    "        self._prev = set(_children)\n",
    "        self._op = _op\n",
    "        self.label = label\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"Value(data={self.data})\"\n",
    "\n",
    "    def __add__(self, other):\n",
    "        # wrap plain numbers so `v + 1.0` works\n",
    "        other = other if isinstance(other, Value) else Value(other)\n",
    "        out = Value(self.data + other.data, (self, other), '+')\n",
    "\n",
    "        def _backward():\n",
    "            self.grad += 1.0 * out.grad\n",
    "            other.grad += 1.0 * out.grad\n",
    "        out._backward = _backward\n",
    "\n",
    "        return out\n",
    "\n",
    "    def __mul__(self, other):\n",
    "        # wrap plain numbers so `v * 2.0` works\n",
    "        other = other if isinstance(other, Value) else Value(other)\n",
    "        out = Value(self.data * other.data, (self, other), '*')\n",
    "\n",
    "        def _backward():\n",
    "            self.grad += other.data * out.grad\n",
    "            other.grad += self.data * out.grad\n",
    "        out._backward = _backward\n",
    "\n",
    "        return out\n",
    "\n",
    "    def tanh(self):\n",
    "        x = self.data\n",
    "        # equal to (exp(2x) - 1) / (exp(2x) + 1), but that form raises\n",
    "        # OverflowError once 2x exceeds ~709; math.tanh saturates to +/-1\n",
    "        t = math.tanh(x)\n",
    "        out = Value(t, (self, ), 'tanh')\n",
    "\n",
    "        def _backward():\n",
    "            # d/dx tanh(x) = 1 - tanh(x)^2\n",
    "            self.grad += (1 - t**2) * out.grad\n",
    "        out._backward = _backward\n",
    "\n",
    "        return out\n",
    "\n",
    "    def backward(self):\n",
    "        # builds a topological sort using DFS (recursion)\n",
    "        topo = []\n",
    "        visited = set()\n",
    "        def build_topo(v):\n",
    "            if v not in visited:\n",
    "                visited.add(v)\n",
    "                for child in v._prev:\n",
    "                    build_topo(child)\n",
    "                topo.append(v)\n",
    "        build_topo(self)\n",
    "\n",
    "        # seed d(self)/d(self) = 1, then run each node's local chain rule from\n",
    "        # the output back towards the leaves. NOTE: grads accumulate (+=), so\n",
    "        # reset them to 0 before calling backward() again on the same graph.\n",
    "        self.grad = 1.0\n",
    "        for node in reversed(topo):\n",
    "            node._backward()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58cc7b67-6733-4837-8026-57aa0185d450",
   "metadata": {},
   "source": [
    "## Sample expression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "95891e0a-a235-4881-8487-00c18f50a93e",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Value(data=-8.0)"
      ]
     },
"execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# build a sample expression\n", "a = Value(2.0, label='a')\n", "b = Value(-3.0, label='b')\n", "c = Value(10.0, label='c')\n", "e = a*b; e.label = 'e'\n", "d = e + c; d.label = 'd'\n", "f = Value(-2.0, label='f')\n", "L = d * f; L.label = 'L'\n", "L" ] }, { "cell_type": "markdown", "id": "7eb22cf7-9f3a-4816-b077-b843b3340a96", "metadata": {}, "source": [ "# How does deep learning work?\n", "- Let's build the expression for a neuron that takes two inputs, x1 and x2.\n", "- There are two weights plus a bias, for a total of 3 parameters.\n", "- The activation function is tanh, an S-shaped curve that squashes its input into (-1, 1)." ] }, { "cell_type": "markdown", "id": "68bb549a-190e-4305-982d-e7259c582a40", "metadata": {}, "source": [ "## Math expression for a 2-input neuron" ] }, { "cell_type": "code", "execution_count": 5, "id": "105ab959-ffcb-459a-8933-3169c0fa9b07", "metadata": { "tags": [] }, "outputs": [], "source": [ "# inputs x1,x2\n", "x1 = Value(2.0, label='x1')\n", "x2 = Value(0.0, label='x2')\n", "# weights w1,w2\n", "w1 = Value(-3.0, label='w1')\n", "w2 = Value(1.0, label='w2')\n", "# bias of the neuron\n", "b = Value(6.8813735870195432, label='b')\n", "# x1*w1 + x2*w2 + b\n", "x1w1 = x1*w1; x1w1.label = 'x1*w1'\n", "x2w2 = x2*w2; x2w2.label = 'x2*w2'\n", "x1w1x2w2 = x1w1 + x2w2; x1w1x2w2.label = 'x1*w1 + x2*w2'\n", "n = x1w1x2w2 + b; n.label = 'n'\n", "o = n.tanh(); o.label = 'o'" ] }, { "cell_type": "code", "execution_count": 6, "id": "ec16b0ef-40b9-4666-b81b-1b1a2d0fb7a0", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "0.7071067811865476" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "o.data # tanh(2*-3 + 0*1 + 6.88)" ] }, { "cell_type": "code", "execution_count": 7, "id": "00b1b311-192a-4e95-857c-769fbb5d75cf", "metadata": { "tags": [] }, "outputs": [ { "data": { "image/svg+xml": [ "<?xml version=\"1.0\" 
encoding=\"UTF-8\" standalone=\"no\"?>\n", "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", "<!-- Generated by graphviz version 2.43.0 (0)\n", " -->\n", "<!-- Title: %3 Pages: 1 -->\n", "<svg width=\"1846pt\" height=\"210pt\"\n", " viewBox=\"0.00 0.00 1845.69 210.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 206)\">\n", "<title>%3</title>\n", "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-206 1841.69,-206 1841.69,4 -4,4\"/>\n", "<!-- 140390867926560 -->\n", "<g id=\"node1\" class=\"node\">\n", "<title>140390867926560</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"827.5,-137.5 827.5,-173.5 1059.5,-173.5 1059.5,-137.5 827.5,-137.5\"/>\n", "<text text-anchor=\"middle\" x=\"840.5\" y=\"-151.8\" font-family=\"Times,serif\" font-size=\"14.00\">b</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"853.5,-137.5 853.5,-173.5 \"/>\n", "<text text-anchor=\"middle\" x=\"904.5\" y=\"-151.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 6.8814</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"955.5,-137.5 955.5,-173.5 \"/>\n", "<text text-anchor=\"middle\" x=\"1007.5\" y=\"-151.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n", "</g>\n", "<!-- 140390867925264+ -->\n", "<g id=\"node8\" class=\"node\">\n", "<title>140390867925264+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"1176\" cy=\"-127.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"1176\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 140390867926560&#45;&gt;140390867925264+ -->\n", "<g id=\"edge8\" class=\"edge\">\n", "<title>140390867926560&#45;&gt;140390867925264+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1059.75,-141.5C1088.57,-138 1117.4,-134.5 
1139.01,-131.87\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1139.5,-135.34 1149.01,-130.66 1138.66,-128.39 1139.5,-135.34\"/>\n", "</g>\n", "<!-- 140390867774576 -->\n", "<g id=\"node2\" class=\"node\">\n", "<title>140390867774576</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"4.5,-55.5 4.5,-91.5 244.5,-91.5 244.5,-55.5 4.5,-55.5\"/>\n", "<text text-anchor=\"middle\" x=\"21.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">x1</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"38.5,-55.5 38.5,-91.5 \"/>\n", "<text text-anchor=\"middle\" x=\"89.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 2.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"140.5,-55.5 140.5,-91.5 \"/>\n", "<text text-anchor=\"middle\" x=\"192.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n", "</g>\n", "<!-- 140390867923392* -->\n", "<g id=\"node14\" class=\"node\">\n", "<title>140390867923392*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"312\" cy=\"-73.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"312\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 140390867774576&#45;&gt;140390867923392* -->\n", "<g id=\"edge10\" class=\"edge\">\n", "<title>140390867774576&#45;&gt;140390867923392*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M244.55,-73.5C255.31,-73.5 265.54,-73.5 274.62,-73.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"274.7,-77 284.7,-73.5 274.7,-70 274.7,-77\"/>\n", "</g>\n", "<!-- 140390867925744 -->\n", "<g id=\"node3\" class=\"node\">\n", "<title>140390867925744</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"377.5,-110.5 377.5,-146.5 645.5,-146.5 645.5,-110.5 377.5,-110.5\"/>\n", "<text text-anchor=\"middle\" x=\"408.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">x2*w2</text>\n", "<polyline fill=\"none\" stroke=\"black\" 
points=\"439.5,-110.5 439.5,-146.5 \"/>\n", "<text text-anchor=\"middle\" x=\"490.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"541.5,-110.5 541.5,-146.5 \"/>\n", "<text text-anchor=\"middle\" x=\"593.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n", "</g>\n", "<!-- 140390867924832+ -->\n", "<g id=\"node12\" class=\"node\">\n", "<title>140390867924832+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"711\" cy=\"-100.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"711\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 140390867925744&#45;&gt;140390867924832+ -->\n", "<g id=\"edge14\" class=\"edge\">\n", "<title>140390867925744&#45;&gt;140390867924832+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M639.53,-110.49C652.09,-108.71 663.99,-107.02 674.3,-105.56\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"675.03,-108.99 684.44,-104.12 674.05,-102.06 675.03,-108.99\"/>\n", "</g>\n", "<!-- 140390867925744* -->\n", "<g id=\"node4\" class=\"node\">\n", "<title>140390867925744*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"312\" cy=\"-128.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"312\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 140390867925744*&#45;&gt;140390867925744 -->\n", "<g id=\"edge1\" class=\"edge\">\n", "<title>140390867925744*&#45;&gt;140390867925744</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M339.23,-128.5C347.26,-128.5 356.72,-128.5 366.99,-128.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"367.08,-132 377.08,-128.5 367.08,-125 367.08,-132\"/>\n", "</g>\n", "<!-- 140390867926272 -->\n", "<g id=\"node5\" class=\"node\">\n", "<title>140390867926272</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1606.69,-109.5 1606.69,-145.5 
1837.69,-145.5 1837.69,-109.5 1606.69,-109.5\"/>\n", "<text text-anchor=\"middle\" x=\"1619.19\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">o</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1631.69,-109.5 1631.69,-145.5 \"/>\n", "<text text-anchor=\"middle\" x=\"1682.69\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.7071</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1733.69,-109.5 1733.69,-145.5 \"/>\n", "<text text-anchor=\"middle\" x=\"1785.69\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n", "</g>\n", "<!-- 140390867926272tanh -->\n", "<g id=\"node6\" class=\"node\">\n", "<title>140390867926272tanh</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"1538.85\" cy=\"-127.5\" rx=\"31.7\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"1538.85\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">tanh</text>\n", "</g>\n", "<!-- 140390867926272tanh&#45;&gt;140390867926272 -->\n", "<g id=\"edge2\" class=\"edge\">\n", "<title>140390867926272tanh&#45;&gt;140390867926272</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1571.02,-127.5C1578.61,-127.5 1587.19,-127.5 1596.31,-127.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1596.4,-131 1606.4,-127.5 1596.4,-124 1596.4,-131\"/>\n", "</g>\n", "<!-- 140390867925264 -->\n", "<g id=\"node7\" class=\"node\">\n", "<title>140390867925264</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1239,-109.5 1239,-145.5 1471,-145.5 1471,-109.5 1239,-109.5\"/>\n", "<text text-anchor=\"middle\" x=\"1252\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">n</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1265,-109.5 1265,-145.5 \"/>\n", "<text text-anchor=\"middle\" x=\"1316\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.8814</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1367,-109.5 1367,-145.5 \"/>\n", "<text 
text-anchor=\"middle\" x=\"1419\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n", "</g>\n", "<!-- 140390867925264&#45;&gt;140390867926272tanh -->\n", "<g id=\"edge11\" class=\"edge\">\n", "<title>140390867925264&#45;&gt;140390867926272tanh</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1471.17,-127.5C1480.08,-127.5 1488.66,-127.5 1496.53,-127.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1496.78,-131 1506.78,-127.5 1496.78,-124 1496.78,-131\"/>\n", "</g>\n", "<!-- 140390867925264+&#45;&gt;140390867925264 -->\n", "<g id=\"edge3\" class=\"edge\">\n", "<title>140390867925264+&#45;&gt;140390867925264</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1203.12,-127.5C1210.53,-127.5 1219.14,-127.5 1228.4,-127.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1228.69,-131 1238.69,-127.5 1228.69,-124 1228.69,-131\"/>\n", "</g>\n", "<!-- 140390867926800 -->\n", "<g id=\"node9\" class=\"node\">\n", "<title>140390867926800</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"0,-0.5 0,-36.5 249,-36.5 249,-0.5 0,-0.5\"/>\n", "<text text-anchor=\"middle\" x=\"19\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">w1</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"38,-0.5 38,-36.5 \"/>\n", "<text text-anchor=\"middle\" x=\"91.5\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">data &#45;3.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"145,-0.5 145,-36.5 \"/>\n", "<text text-anchor=\"middle\" x=\"197\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n", "</g>\n", "<!-- 140390867926800&#45;&gt;140390867923392* -->\n", "<g id=\"edge6\" class=\"edge\">\n", "<title>140390867926800&#45;&gt;140390867923392*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M214.42,-36.55C226.15,-39.54 237.93,-42.87 249,-46.5 259.46,-49.93 270.57,-54.47 280.46,-58.84\"/>\n", "<polygon fill=\"black\" stroke=\"black\" 
points=\"279.31,-62.17 289.87,-63.12 282.21,-55.79 279.31,-62.17\"/>\n", "</g>\n", "<!-- 140390867774768 -->\n", "<g id=\"node10\" class=\"node\">\n", "<title>140390867774768</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"4.5,-165.5 4.5,-201.5 244.5,-201.5 244.5,-165.5 4.5,-165.5\"/>\n", "<text text-anchor=\"middle\" x=\"21.5\" y=\"-179.8\" font-family=\"Times,serif\" font-size=\"14.00\">x2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"38.5,-165.5 38.5,-201.5 \"/>\n", "<text text-anchor=\"middle\" x=\"89.5\" y=\"-179.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"140.5,-165.5 140.5,-201.5 \"/>\n", "<text text-anchor=\"middle\" x=\"192.5\" y=\"-179.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n", "</g>\n", "<!-- 140390867774768&#45;&gt;140390867925744* -->\n", "<g id=\"edge7\" class=\"edge\">\n", "<title>140390867774768&#45;&gt;140390867925744*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M218.13,-165.46C228.63,-162.77 239.1,-159.78 249,-156.5 259.71,-152.96 271.05,-148.16 281.07,-143.54\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"282.67,-146.66 290.2,-139.21 279.67,-140.33 282.67,-146.66\"/>\n", "</g>\n", "<!-- 140390867924832 -->\n", "<g id=\"node11\" class=\"node\">\n", "<title>140390867924832</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"774,-82.5 774,-118.5 1113,-118.5 1113,-82.5 774,-82.5\"/>\n", "<text text-anchor=\"middle\" x=\"838\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">x1*w1 + x2*w2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"902,-82.5 902,-118.5 \"/>\n", "<text text-anchor=\"middle\" x=\"955.5\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">data &#45;6.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1009,-82.5 1009,-118.5 \"/>\n", "<text text-anchor=\"middle\" x=\"1061\" y=\"-96.8\" font-family=\"Times,serif\" 
font-size=\"14.00\">grad 0.0000</text>\n", "</g>\n", "<!-- 140390867924832&#45;&gt;140390867925264+ -->\n", "<g id=\"edge9\" class=\"edge\">\n", "<title>140390867924832&#45;&gt;140390867925264+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1098.25,-118.51C1113.17,-120.26 1127.12,-121.89 1138.92,-123.27\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1138.79,-126.78 1149.13,-124.47 1139.6,-119.83 1138.79,-126.78\"/>\n", "</g>\n", "<!-- 140390867924832+&#45;&gt;140390867924832 -->\n", "<g id=\"edge4\" class=\"edge\">\n", "<title>140390867924832+&#45;&gt;140390867924832</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M738.44,-100.5C745.81,-100.5 754.42,-100.5 763.84,-100.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"763.94,-104 773.94,-100.5 763.94,-97 763.94,-104\"/>\n", "</g>\n", "<!-- 140390867923392 -->\n", "<g id=\"node13\" class=\"node\">\n", "<title>140390867923392</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"375,-55.5 375,-91.5 648,-91.5 648,-55.5 375,-55.5\"/>\n", "<text text-anchor=\"middle\" x=\"406\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">x1*w1</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"437,-55.5 437,-91.5 \"/>\n", "<text text-anchor=\"middle\" x=\"490.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">data &#45;6.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"544,-55.5 544,-91.5 \"/>\n", "<text text-anchor=\"middle\" x=\"596\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n", "</g>\n", "<!-- 140390867923392&#45;&gt;140390867924832+ -->\n", "<g id=\"edge13\" class=\"edge\">\n", "<title>140390867923392&#45;&gt;140390867924832+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M644.24,-91.51C655.12,-93 665.4,-94.4 674.45,-95.64\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"674.11,-99.13 684.49,-97.01 675.05,-92.19 674.11,-99.13\"/>\n", "</g>\n", "<!-- 
140390867923392*&#45;&gt;140390867923392 -->\n", "<g id=\"edge5\" class=\"edge\">\n", "<title>140390867923392*&#45;&gt;140390867923392</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M339.23,-73.5C346.7,-73.5 355.41,-73.5 364.87,-73.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"364.98,-77 374.98,-73.5 364.98,-70 364.98,-77\"/>\n", "</g>\n", "<!-- 140390867772896 -->\n", "<g id=\"node15\" class=\"node\">\n", "<title>140390867772896</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"2.5,-110.5 2.5,-146.5 246.5,-146.5 246.5,-110.5 2.5,-110.5\"/>\n", "<text text-anchor=\"middle\" x=\"21.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">w2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"40.5,-110.5 40.5,-146.5 \"/>\n", "<text text-anchor=\"middle\" x=\"91.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 1.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"142.5,-110.5 142.5,-146.5 \"/>\n", "<text text-anchor=\"middle\" x=\"194.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n", "</g>\n", "<!-- 140390867772896&#45;&gt;140390867925744* -->\n", "<g id=\"edge12\" class=\"edge\">\n", "<title>140390867772896&#45;&gt;140390867925744*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M246.64,-128.5C256.7,-128.5 266.26,-128.5 274.79,-128.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"275,-132 285,-128.5 275,-125 275,-132\"/>\n", "</g>\n", "</g>\n", "</svg>\n" ], "text/plain": [ "<graphviz.graphs.Digraph at 0x7faf4bcda610>" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "draw_dot(o)" ] }, { "cell_type": "markdown", "id": "fbd6d6ee-79a2-4d46-972f-d1777da855f9", "metadata": {}, "source": [ "## Let's backpropagate!\n", "- we care about the derivatives of o with respect to w1, w2, and b (how much o changes when each one is nudged)\n", "- these are written do/dw1, do/dw2, and do/db, respectively" ] }, { "cell_type": "code", "execution_count": 8, "id": 
"e6504316-3cbb-45ff-8c1a-e762669f0f5c", "metadata": { "tags": [] }, "outputs": [], "source": [ "o.backward() # will set o.grad to 1 and recurse" ] }, { "cell_type": "code", "execution_count": 9, "id": "60c13ab3-e237-4c69-bdab-234197ef1cf6", "metadata": { "tags": [] }, "outputs": [ { "data": { "image/svg+xml": [ "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n", "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n", " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n", "<!-- Generated by graphviz version 2.43.0 (0)\n", " -->\n", "<!-- Title: %3 Pages: 1 -->\n", "<svg width=\"1846pt\" height=\"210pt\"\n", " viewBox=\"0.00 0.00 1845.69 210.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n", "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 206)\">\n", "<title>%3</title>\n", "<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-206 1841.69,-206 1841.69,4 -4,4\"/>\n", "<!-- 140390867926560 -->\n", "<g id=\"node1\" class=\"node\">\n", "<title>140390867926560</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"827.5,-137.5 827.5,-173.5 1059.5,-173.5 1059.5,-137.5 827.5,-137.5\"/>\n", "<text text-anchor=\"middle\" x=\"840.5\" y=\"-151.8\" font-family=\"Times,serif\" font-size=\"14.00\">b</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"853.5,-137.5 853.5,-173.5 \"/>\n", "<text text-anchor=\"middle\" x=\"904.5\" y=\"-151.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 6.8814</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"955.5,-137.5 955.5,-173.5 \"/>\n", "<text text-anchor=\"middle\" x=\"1007.5\" y=\"-151.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 140390867925264+ -->\n", "<g id=\"node8\" class=\"node\">\n", "<title>140390867925264+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"1176\" cy=\"-127.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"1176\" 
y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 140390867926560&#45;&gt;140390867925264+ -->\n", "<g id=\"edge8\" class=\"edge\">\n", "<title>140390867926560&#45;&gt;140390867925264+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1059.75,-141.5C1088.57,-138 1117.4,-134.5 1139.01,-131.87\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1139.5,-135.34 1149.01,-130.66 1138.66,-128.39 1139.5,-135.34\"/>\n", "</g>\n", "<!-- 140390867774576 -->\n", "<g id=\"node2\" class=\"node\">\n", "<title>140390867774576</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"2.5,-55.5 2.5,-91.5 246.5,-91.5 246.5,-55.5 2.5,-55.5\"/>\n", "<text text-anchor=\"middle\" x=\"19.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">x1</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"36.5,-55.5 36.5,-91.5 \"/>\n", "<text text-anchor=\"middle\" x=\"87.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 2.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"138.5,-55.5 138.5,-91.5 \"/>\n", "<text text-anchor=\"middle\" x=\"192.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad &#45;1.5000</text>\n", "</g>\n", "<!-- 140390867923392* -->\n", "<g id=\"node14\" class=\"node\">\n", "<title>140390867923392*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"312\" cy=\"-73.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"312\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 140390867774576&#45;&gt;140390867923392* -->\n", "<g id=\"edge10\" class=\"edge\">\n", "<title>140390867774576&#45;&gt;140390867923392*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M246.64,-73.5C256.7,-73.5 266.26,-73.5 274.79,-73.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"275,-77 285,-73.5 275,-70 275,-77\"/>\n", "</g>\n", "<!-- 140390867925744 -->\n", "<g id=\"node3\" class=\"node\">\n", 
"<title>140390867925744</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"377.5,-110.5 377.5,-146.5 645.5,-146.5 645.5,-110.5 377.5,-110.5\"/>\n", "<text text-anchor=\"middle\" x=\"408.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">x2*w2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"439.5,-110.5 439.5,-146.5 \"/>\n", "<text text-anchor=\"middle\" x=\"490.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"541.5,-110.5 541.5,-146.5 \"/>\n", "<text text-anchor=\"middle\" x=\"593.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 140390867924832+ -->\n", "<g id=\"node12\" class=\"node\">\n", "<title>140390867924832+</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"711\" cy=\"-100.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"711\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n", "</g>\n", "<!-- 140390867925744&#45;&gt;140390867924832+ -->\n", "<g id=\"edge14\" class=\"edge\">\n", "<title>140390867925744&#45;&gt;140390867924832+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M639.53,-110.49C652.09,-108.71 663.99,-107.02 674.3,-105.56\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"675.03,-108.99 684.44,-104.12 674.05,-102.06 675.03,-108.99\"/>\n", "</g>\n", "<!-- 140390867925744* -->\n", "<g id=\"node4\" class=\"node\">\n", "<title>140390867925744*</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"312\" cy=\"-128.5\" rx=\"27\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"312\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n", "</g>\n", "<!-- 140390867925744*&#45;&gt;140390867925744 -->\n", "<g id=\"edge1\" class=\"edge\">\n", "<title>140390867925744*&#45;&gt;140390867925744</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M339.23,-128.5C347.26,-128.5 356.72,-128.5 
366.99,-128.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"367.08,-132 377.08,-128.5 367.08,-125 367.08,-132\"/>\n", "</g>\n", "<!-- 140390867926272 -->\n", "<g id=\"node5\" class=\"node\">\n", "<title>140390867926272</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1606.69,-109.5 1606.69,-145.5 1837.69,-145.5 1837.69,-109.5 1606.69,-109.5\"/>\n", "<text text-anchor=\"middle\" x=\"1619.19\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">o</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1631.69,-109.5 1631.69,-145.5 \"/>\n", "<text text-anchor=\"middle\" x=\"1682.69\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.7071</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1733.69,-109.5 1733.69,-145.5 \"/>\n", "<text text-anchor=\"middle\" x=\"1785.69\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n", "</g>\n", "<!-- 140390867926272tanh -->\n", "<g id=\"node6\" class=\"node\">\n", "<title>140390867926272tanh</title>\n", "<ellipse fill=\"none\" stroke=\"black\" cx=\"1538.85\" cy=\"-127.5\" rx=\"31.7\" ry=\"18\"/>\n", "<text text-anchor=\"middle\" x=\"1538.85\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">tanh</text>\n", "</g>\n", "<!-- 140390867926272tanh&#45;&gt;140390867926272 -->\n", "<g id=\"edge2\" class=\"edge\">\n", "<title>140390867926272tanh&#45;&gt;140390867926272</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1571.02,-127.5C1578.61,-127.5 1587.19,-127.5 1596.31,-127.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1596.4,-131 1606.4,-127.5 1596.4,-124 1596.4,-131\"/>\n", "</g>\n", "<!-- 140390867925264 -->\n", "<g id=\"node7\" class=\"node\">\n", "<title>140390867925264</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"1239,-109.5 1239,-145.5 1471,-145.5 1471,-109.5 1239,-109.5\"/>\n", "<text text-anchor=\"middle\" x=\"1252\" y=\"-123.8\" font-family=\"Times,serif\" 
font-size=\"14.00\">n</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1265,-109.5 1265,-145.5 \"/>\n", "<text text-anchor=\"middle\" x=\"1316\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.8814</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1367,-109.5 1367,-145.5 \"/>\n", "<text text-anchor=\"middle\" x=\"1419\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 140390867925264&#45;&gt;140390867926272tanh -->\n", "<g id=\"edge11\" class=\"edge\">\n", "<title>140390867925264&#45;&gt;140390867926272tanh</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1471.17,-127.5C1480.08,-127.5 1488.66,-127.5 1496.53,-127.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1496.78,-131 1506.78,-127.5 1496.78,-124 1496.78,-131\"/>\n", "</g>\n", "<!-- 140390867925264+&#45;&gt;140390867925264 -->\n", "<g id=\"edge3\" class=\"edge\">\n", "<title>140390867925264+&#45;&gt;140390867925264</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1203.12,-127.5C1210.53,-127.5 1219.14,-127.5 1228.4,-127.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1228.69,-131 1238.69,-127.5 1228.69,-124 1228.69,-131\"/>\n", "</g>\n", "<!-- 140390867926800 -->\n", "<g id=\"node9\" class=\"node\">\n", "<title>140390867926800</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"0,-0.5 0,-36.5 249,-36.5 249,-0.5 0,-0.5\"/>\n", "<text text-anchor=\"middle\" x=\"19\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">w1</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"38,-0.5 38,-36.5 \"/>\n", "<text text-anchor=\"middle\" x=\"91.5\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">data &#45;3.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"145,-0.5 145,-36.5 \"/>\n", "<text text-anchor=\"middle\" x=\"197\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n", "</g>\n", "<!-- 
140390867926800&#45;&gt;140390867923392* -->\n", "<g id=\"edge6\" class=\"edge\">\n", "<title>140390867926800&#45;&gt;140390867923392*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M214.42,-36.55C226.15,-39.54 237.93,-42.87 249,-46.5 259.46,-49.93 270.57,-54.47 280.46,-58.84\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"279.31,-62.17 289.87,-63.12 282.21,-55.79 279.31,-62.17\"/>\n", "</g>\n", "<!-- 140390867774768 -->\n", "<g id=\"node10\" class=\"node\">\n", "<title>140390867774768</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"4.5,-165.5 4.5,-201.5 244.5,-201.5 244.5,-165.5 4.5,-165.5\"/>\n", "<text text-anchor=\"middle\" x=\"21.5\" y=\"-179.8\" font-family=\"Times,serif\" font-size=\"14.00\">x2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"38.5,-165.5 38.5,-201.5 \"/>\n", "<text text-anchor=\"middle\" x=\"89.5\" y=\"-179.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"140.5,-165.5 140.5,-201.5 \"/>\n", "<text text-anchor=\"middle\" x=\"192.5\" y=\"-179.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 140390867774768&#45;&gt;140390867925744* -->\n", "<g id=\"edge7\" class=\"edge\">\n", "<title>140390867774768&#45;&gt;140390867925744*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M218.13,-165.46C228.63,-162.77 239.1,-159.78 249,-156.5 259.71,-152.96 271.05,-148.16 281.07,-143.54\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"282.67,-146.66 290.2,-139.21 279.67,-140.33 282.67,-146.66\"/>\n", "</g>\n", "<!-- 140390867924832 -->\n", "<g id=\"node11\" class=\"node\">\n", "<title>140390867924832</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"774,-82.5 774,-118.5 1113,-118.5 1113,-82.5 774,-82.5\"/>\n", "<text text-anchor=\"middle\" x=\"838\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">x1*w1 + x2*w2</text>\n", "<polyline fill=\"none\" stroke=\"black\" 
points=\"902,-82.5 902,-118.5 \"/>\n", "<text text-anchor=\"middle\" x=\"955.5\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">data &#45;6.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"1009,-82.5 1009,-118.5 \"/>\n", "<text text-anchor=\"middle\" x=\"1061\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 140390867924832&#45;&gt;140390867925264+ -->\n", "<g id=\"edge9\" class=\"edge\">\n", "<title>140390867924832&#45;&gt;140390867925264+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M1098.25,-118.51C1113.17,-120.26 1127.12,-121.89 1138.92,-123.27\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"1138.79,-126.78 1149.13,-124.47 1139.6,-119.83 1138.79,-126.78\"/>\n", "</g>\n", "<!-- 140390867924832+&#45;&gt;140390867924832 -->\n", "<g id=\"edge4\" class=\"edge\">\n", "<title>140390867924832+&#45;&gt;140390867924832</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M738.44,-100.5C745.81,-100.5 754.42,-100.5 763.84,-100.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"763.94,-104 773.94,-100.5 763.94,-97 763.94,-104\"/>\n", "</g>\n", "<!-- 140390867923392 -->\n", "<g id=\"node13\" class=\"node\">\n", "<title>140390867923392</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"375,-55.5 375,-91.5 648,-91.5 648,-55.5 375,-55.5\"/>\n", "<text text-anchor=\"middle\" x=\"406\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">x1*w1</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"437,-55.5 437,-91.5 \"/>\n", "<text text-anchor=\"middle\" x=\"490.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">data &#45;6.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"544,-55.5 544,-91.5 \"/>\n", "<text text-anchor=\"middle\" x=\"596\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n", "</g>\n", "<!-- 140390867923392&#45;&gt;140390867924832+ -->\n", "<g id=\"edge13\" 
class=\"edge\">\n", "<title>140390867923392&#45;&gt;140390867924832+</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M644.24,-91.51C655.12,-93 665.4,-94.4 674.45,-95.64\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"674.11,-99.13 684.49,-97.01 675.05,-92.19 674.11,-99.13\"/>\n", "</g>\n", "<!-- 140390867923392*&#45;&gt;140390867923392 -->\n", "<g id=\"edge5\" class=\"edge\">\n", "<title>140390867923392*&#45;&gt;140390867923392</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M339.23,-73.5C346.7,-73.5 355.41,-73.5 364.87,-73.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"364.98,-77 374.98,-73.5 364.98,-70 364.98,-77\"/>\n", "</g>\n", "<!-- 140390867772896 -->\n", "<g id=\"node15\" class=\"node\">\n", "<title>140390867772896</title>\n", "<polygon fill=\"none\" stroke=\"black\" points=\"2.5,-110.5 2.5,-146.5 246.5,-146.5 246.5,-110.5 2.5,-110.5\"/>\n", "<text text-anchor=\"middle\" x=\"21.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">w2</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"40.5,-110.5 40.5,-146.5 \"/>\n", "<text text-anchor=\"middle\" x=\"91.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 1.0000</text>\n", "<polyline fill=\"none\" stroke=\"black\" points=\"142.5,-110.5 142.5,-146.5 \"/>\n", "<text text-anchor=\"middle\" x=\"194.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n", "</g>\n", "<!-- 140390867772896&#45;&gt;140390867925744* -->\n", "<g id=\"edge12\" class=\"edge\">\n", "<title>140390867772896&#45;&gt;140390867925744*</title>\n", "<path fill=\"none\" stroke=\"black\" d=\"M246.64,-128.5C256.7,-128.5 266.26,-128.5 274.79,-128.5\"/>\n", "<polygon fill=\"black\" stroke=\"black\" points=\"275,-132 285,-128.5 275,-125 275,-132\"/>\n", "</g>\n", "</g>\n", "</svg>\n" ], "text/plain": [ "<graphviz.graphs.Digraph at 0x7faf4bcff3a0>" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ 
"draw_dot(o) # now the graph shows the gradient of o with respect to each node" ] }, { "cell_type": "markdown", "id": "97610480-b29b-4928-b57b-b7fe8a417ea5", "metadata": {}, "source": [ "## Manually verify some of the gradients" ] }, { "cell_type": "code", "execution_count": 10, "id": "c65bd5fd-4687-42b8-8b27-8fe3cda6a58a", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "0.4999999999999999" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# o = tanh(n)\n", "# do/dn = 1 - o**2\n", "(1 - o.data**2)" ] }, { "cell_type": "code", "execution_count": 11, "id": "8b409694-f2d3-4a4b-a8f2-a2ba6976b785", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "0.4999999999999999" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# our first chain rule example!\n", "# intuition: \"If a car travels twice as fast as a bicycle and the bicycle is four times as fast as a walking man, then the car travels 2 × 4 = 8 times as fast as the man\"\n", "\n", "# o = tanh(n)\n", "# n = x1*w1 + x2*w2 + b\n", "# do/db = do/dn * dn/db, where dn/db = 1 because b is simply added into n\n", "(1 - o.data**2) * 1 " ] }, { "cell_type": "code", "execution_count": 12, "id": "7e563dc8-3fac-4152-b3c9-05d9760bb6cf", "metadata": {}, "outputs": [], "source": [ "# exercise: derive do/dj, where j = x1*w1 + x2*w2\n" ] }, { "cell_type": "markdown", "id": "2f5aa4b2-388d-440f-b4eb-dda1167047c3", "metadata": {}, "source": [ "# In PyTorch\n", "- And finally we can see that PyTorch is just doing the same thing!"
] }, { "cell_type": "code", "execution_count": 13, "id": "bc4eed18-9092-45c6-a7bb-8e37dbd2ceaf", "metadata": { "tags": [] }, "outputs": [], "source": [ "# this is what that same expression looks like in PyTorch\n", "# note: torch.Tensor([...]) builds a float32 tensor (PyTorch's default dtype), so b is rounded to float32 before .double() runs;\n", "# that is why the numbers below differ from micrograd's float64 results by roughly 1e-7\n", "# (torch.tensor([...], dtype=torch.double) would keep full float64 precision)\n", "x1 = torch.Tensor([2.0]).double() ; x1.requires_grad = True\n", "x2 = torch.Tensor([0.0]).double() ; x2.requires_grad = True\n", "w1 = torch.Tensor([-3.0]).double() ; w1.requires_grad = True\n", "w2 = torch.Tensor([1.0]).double() ; w2.requires_grad = True\n", "b = torch.Tensor([6.8813735870195432]).double() ; b.requires_grad = True\n", "n = x1*w1 + x2*w2 + b\n", "o = torch.tanh(n)" ] }, { "cell_type": "markdown", "id": "0bd63dfc-8a33-4a78-be75-9d37495b3c71", "metadata": {}, "source": [ "## Verify output and backpropagate in PyTorch" ] }, { "cell_type": "code", "execution_count": 14, "id": "fbddfa4e-b4a2-460e-9e6d-dae70ec4d658", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.7071066904050358\n" ] } ], "source": [ "print(o.data.item())\n", "o.backward()" ] }, { "cell_type": "markdown", "id": "9535b1a2-2f43-4389-a597-88288f5fa0cf", "metadata": {}, "source": [ "## Are the numbers correct?"
] }, { "cell_type": "code", "execution_count": 15, "id": "4261bf49-6af8-42a1-88aa-36b52dc7b9cc", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "---\n", "x2 0.5000001283844369\n", "w2 0.0\n", "x1 -1.5000003851533106\n", "w1 1.0000002567688737\n" ] } ], "source": [ "print('---')\n", "print('x2', x2.grad.item())\n", "print('w2', w2.grad.item())\n", "print('x1', x1.grad.item())\n", "print('w1', w1.grad.item())" ] }, { "cell_type": "code", "execution_count": null, "id": "46e36130-4ed1-4c70-82a1-fa32d8304b6a", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5" } }, "nbformat": 4, "nbformat_minor": 5 }