session1/deep-learning.ipynb (987 lines of code) (raw):
{
"cells": [
{
"cell_type": "markdown",
"id": "542c9b3d-4661-4172-8ffe-67cd9a27e0d4",
"metadata": {},
"source": [
"# Utils & Setup\n",
"- all taken from micrograd (https://github.com/karpathy/micrograd)"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5f23a85d-0b52-4eac-ae5c-30526ed217c7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import math\n",
"\n",
"import torch\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "27d0d364-27aa-4ae2-90fb-b44eae300b87",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# you can ignore this cell\n",
"\n",
"from graphviz import Digraph\n",
"\n",
"def trace(root):\n",
"    \"\"\"Collect every node reachable from `root` through `_prev`, plus the edges between them.\n",
"\n",
"    Returns a `(nodes, edges)` pair of sets; each edge is a `(child, parent)` tuple.\n",
"    \"\"\"\n",
"    nodes, edges = set(), set()\n",
"    pending = [root]  # explicit stack instead of recursion\n",
"    while pending:\n",
"        current = pending.pop()\n",
"        if current in nodes:\n",
"            continue  # already expanded via another path\n",
"        nodes.add(current)\n",
"        for child in current._prev:\n",
"            edges.add((child, current))\n",
"            pending.append(child)\n",
"    return nodes, edges\n",
"\n",
"def draw_dot(root, format='svg'):\n",
"    \"\"\"Build a graphviz Digraph of the expression graph that ends at `root`.\n",
"\n",
"    Each value becomes a record node showing its label, data and grad. A value\n",
"    produced by an operation also gets a small op node that feeds into it, and\n",
"    the operands are wired to that op node. `format` is passed straight to\n",
"    `Digraph(format=...)`.\n",
"    \"\"\"\n",
"    nodes, edges = trace(root)\n",
"    graph = Digraph(format=format, graph_attr={'rankdir': 'LR'})  # LR = left to right\n",
"\n",
"    for value in nodes:\n",
"        value_id = str(id(value))\n",
"        # rectangular ('record') node for the value itself\n",
"        record = \"{ %s | data %.4f | grad %.4f }\" % (value.label, value.data, value.grad)\n",
"        graph.node(name=value_id, label=record, shape='record')\n",
"        if not value._op:\n",
"            continue  # leaf value: no operation produced it\n",
"        # op node for the operation that produced this value, wired into it\n",
"        op_id = value_id + value._op\n",
"        graph.node(name=op_id, label=value._op)\n",
"        graph.edge(op_id, value_id)\n",
"\n",
"    for child, parent in edges:\n",
"        # operands point at the op node of the value they produce\n",
"        graph.edge(str(id(child)), str(id(parent)) + parent._op)\n",
"\n",
"    return graph\n"
]
},
{
"cell_type": "markdown",
"id": "1fc3bf60-138a-43da-afbd-774c71e8c53d",
"metadata": {},
"source": [
"# Quick tour of micrograd"
]
},
{
"cell_type": "markdown",
"id": "10ef059a-7741-484f-9e5f-06e3743a95f5",
"metadata": {},
"source": [
"## The Value class"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "be63a1b8-1c24-4a42-b630-49ff4f53eb65",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"class Value:\n",
"    \"\"\"Scalar that remembers how it was computed so gradients can flow back through it.\n",
"\n",
"    Smaller subset of the full Value from micrograd.\n",
"\n",
"    Attributes:\n",
"        data: the wrapped number.\n",
"        grad: accumulated derivative of the value `backward()` was called on\n",
"            with respect to this one; starts at 0.0.\n",
"        label: optional display name.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, data, _children=(), _op='', label=''):\n",
"        self.data = data\n",
"        self.grad = 0.0\n",
"        # leaves have nothing to propagate; each operation replaces this with a closure\n",
"        self._backward = lambda: None\n",
"        self._prev = set(_children)\n",
"        self._op = _op\n",
"        self.label = label\n",
"\n",
"    def __repr__(self):\n",
"        return f\"Value(data={self.data})\"\n",
"\n",
"    def __add__(self, other):\n",
"        \"\"\"Return self + other; `other` may be a Value or a plain number.\"\"\"\n",
"        other = other if isinstance(other, Value) else Value(other)\n",
"        out = Value(self.data + other.data, (self, other), '+')\n",
"\n",
"        def _backward():\n",
"            # d(a + b)/da = d(a + b)/db = 1\n",
"            self.grad += 1.0 * out.grad\n",
"            other.grad += 1.0 * out.grad\n",
"        out._backward = _backward\n",
"\n",
"        return out\n",
"\n",
"    def __radd__(self, other):\n",
"        \"\"\"Support `number + Value` (and therefore `sum(values)`).\"\"\"\n",
"        return self + other\n",
"\n",
"    def __mul__(self, other):\n",
"        \"\"\"Return self * other; `other` may be a Value or a plain number.\"\"\"\n",
"        other = other if isinstance(other, Value) else Value(other)\n",
"        out = Value(self.data * other.data, (self, other), '*')\n",
"\n",
"        def _backward():\n",
"            # d(a * b)/da = b and d(a * b)/db = a\n",
"            self.grad += other.data * out.grad\n",
"            other.grad += self.data * out.grad\n",
"        out._backward = _backward\n",
"\n",
"        return out\n",
"\n",
"    def __rmul__(self, other):\n",
"        \"\"\"Support `number * Value`.\"\"\"\n",
"        return self * other\n",
"\n",
"    def tanh(self):\n",
"        \"\"\"Return tanh(self) as a new Value.\"\"\"\n",
"        # math.tanh saturates to +/-1.0 for large |x|, whereas computing\n",
"        # (e^(2x) - 1) / (e^(2x) + 1) by hand raises OverflowError once 2x exceeds ~709.\n",
"        t = math.tanh(self.data)\n",
"        out = Value(t, (self, ), 'tanh')\n",
"\n",
"        def _backward():\n",
"            # d tanh(x)/dx = 1 - tanh(x)^2\n",
"            self.grad += (1 - t**2) * out.grad\n",
"        out._backward = _backward\n",
"\n",
"        return out\n",
"\n",
"    def backward(self):\n",
"        \"\"\"Backpropagate from this value into everything it was computed from.\n",
"\n",
"        Sets `self.grad` to 1.0, then calls each node's `_backward` in reverse\n",
"        topological order. Gradients are accumulated with `+=` and are not\n",
"        reset here, so calling this twice on the same graph adds to the\n",
"        gradients that are already stored.\n",
"        \"\"\"\n",
"        # builds a topological sort using DFS (recursion)\n",
"        topo = []\n",
"        visited = set()\n",
"        def build_topo(v):\n",
"            if v not in visited:\n",
"                visited.add(v)\n",
"                for child in v._prev:\n",
"                    build_topo(child)\n",
"                topo.append(v)  # appended only after everything it depends on\n",
"        build_topo(self)\n",
"\n",
"        self.grad = 1.0\n",
"        for node in reversed(topo):\n",
"            node._backward()"
]
},
{
"cell_type": "markdown",
"id": "58cc7b67-6733-4837-8026-57aa0185d450",
"metadata": {},
"source": [
"## Sample expression"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "95891e0a-a235-4881-8487-00c18f50a93e",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"Value(data=-8.0)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# build a sample expression: L = (a*b + c) * f\n",
"a = Value(2.0, label='a')\n",
"b = Value(-3.0, label='b')\n",
"c = Value(10.0, label='c')\n",
"f = Value(-2.0, label='f')\n",
"\n",
"# intermediate results are labelled after they are created\n",
"e = a * b\n",
"e.label = 'e'\n",
"d = e + c\n",
"d.label = 'd'\n",
"L = d * f\n",
"L.label = 'L'\n",
"\n",
"L"
]
},
{
"cell_type": "markdown",
"id": "7eb22cf7-9f3a-4816-b077-b843b3340a96",
"metadata": {},
"source": [
"# How does deep learning work?\n",
"- Let's create the expression of a neuron that takes two inputs, x1 and x2.\n",
"- There are two weights plus a bias, for a total of 3 parameters.\n",
"- The activation function will be a tanh (it's an S-looking curve)."
]
},
{
"cell_type": "markdown",
"id": "68bb549a-190e-4305-982d-e7259c582a40",
"metadata": {},
"source": [
"## Math expression for a 2-input neuron"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "105ab959-ffcb-459a-8933-3169c0fa9b07",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# inputs x1, x2\n",
"x1 = Value(2.0, label='x1')\n",
"x2 = Value(0.0, label='x2')\n",
"\n",
"# weights w1, w2\n",
"w1 = Value(-3.0, label='w1')\n",
"w2 = Value(1.0, label='w2')\n",
"\n",
"# bias of the neuron; with these inputs it puts n at about 0.8814, so tanh(n) is about 0.7071\n",
"b = Value(6.8813735870195432, label='b')\n",
"\n",
"# weighted inputs\n",
"x1w1 = x1 * w1\n",
"x1w1.label = 'x1*w1'\n",
"x2w2 = x2 * w2\n",
"x2w2.label = 'x2*w2'\n",
"\n",
"# pre-activation: n = x1*w1 + x2*w2 + b\n",
"x1w1x2w2 = x1w1 + x2w2\n",
"x1w1x2w2.label = 'x1*w1 + x2*w2'\n",
"n = x1w1x2w2 + b\n",
"n.label = 'n'\n",
"\n",
"# output: squash n through the tanh activation\n",
"o = n.tanh()\n",
"o.label = 'o'"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "ec16b0ef-40b9-4666-b81b-1b1a2d0fb7a0",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"0.7071067811865476"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"o.data # tanh(2*-3 + 0*1 + 6.88)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "00b1b311-192a-4e95-857c-769fbb5d75cf",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.43.0 (0)\n",
" -->\n",
"<!-- Title: %3 Pages: 1 -->\n",
"<svg width=\"1846pt\" height=\"210pt\"\n",
" viewBox=\"0.00 0.00 1845.69 210.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 206)\">\n",
"<title>%3</title>\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-206 1841.69,-206 1841.69,4 -4,4\"/>\n",
"<!-- 140390867926560 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>140390867926560</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"827.5,-137.5 827.5,-173.5 1059.5,-173.5 1059.5,-137.5 827.5,-137.5\"/>\n",
"<text text-anchor=\"middle\" x=\"840.5\" y=\"-151.8\" font-family=\"Times,serif\" font-size=\"14.00\">b</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"853.5,-137.5 853.5,-173.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"904.5\" y=\"-151.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 6.8814</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"955.5,-137.5 955.5,-173.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"1007.5\" y=\"-151.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140390867925264+ -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>140390867925264+</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"1176\" cy=\"-127.5\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"1176\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n",
"</g>\n",
"<!-- 140390867926560->140390867925264+ -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>140390867926560->140390867925264+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M1059.75,-141.5C1088.57,-138 1117.4,-134.5 1139.01,-131.87\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"1139.5,-135.34 1149.01,-130.66 1138.66,-128.39 1139.5,-135.34\"/>\n",
"</g>\n",
"<!-- 140390867774576 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>140390867774576</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"4.5,-55.5 4.5,-91.5 244.5,-91.5 244.5,-55.5 4.5,-55.5\"/>\n",
"<text text-anchor=\"middle\" x=\"21.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">x1</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"38.5,-55.5 38.5,-91.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"89.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 2.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"140.5,-55.5 140.5,-91.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"192.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140390867923392* -->\n",
"<g id=\"node14\" class=\"node\">\n",
"<title>140390867923392*</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"312\" cy=\"-73.5\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"312\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n",
"</g>\n",
"<!-- 140390867774576->140390867923392* -->\n",
"<g id=\"edge10\" class=\"edge\">\n",
"<title>140390867774576->140390867923392*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M244.55,-73.5C255.31,-73.5 265.54,-73.5 274.62,-73.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"274.7,-77 284.7,-73.5 274.7,-70 274.7,-77\"/>\n",
"</g>\n",
"<!-- 140390867925744 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>140390867925744</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"377.5,-110.5 377.5,-146.5 645.5,-146.5 645.5,-110.5 377.5,-110.5\"/>\n",
"<text text-anchor=\"middle\" x=\"408.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">x2*w2</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"439.5,-110.5 439.5,-146.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"490.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"541.5,-110.5 541.5,-146.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"593.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140390867924832+ -->\n",
"<g id=\"node12\" class=\"node\">\n",
"<title>140390867924832+</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"711\" cy=\"-100.5\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"711\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n",
"</g>\n",
"<!-- 140390867925744->140390867924832+ -->\n",
"<g id=\"edge14\" class=\"edge\">\n",
"<title>140390867925744->140390867924832+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M639.53,-110.49C652.09,-108.71 663.99,-107.02 674.3,-105.56\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"675.03,-108.99 684.44,-104.12 674.05,-102.06 675.03,-108.99\"/>\n",
"</g>\n",
"<!-- 140390867925744* -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>140390867925744*</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"312\" cy=\"-128.5\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"312\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n",
"</g>\n",
"<!-- 140390867925744*->140390867925744 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>140390867925744*->140390867925744</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M339.23,-128.5C347.26,-128.5 356.72,-128.5 366.99,-128.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"367.08,-132 377.08,-128.5 367.08,-125 367.08,-132\"/>\n",
"</g>\n",
"<!-- 140390867926272 -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>140390867926272</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"1606.69,-109.5 1606.69,-145.5 1837.69,-145.5 1837.69,-109.5 1606.69,-109.5\"/>\n",
"<text text-anchor=\"middle\" x=\"1619.19\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">o</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1631.69,-109.5 1631.69,-145.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"1682.69\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.7071</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1733.69,-109.5 1733.69,-145.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"1785.69\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140390867926272tanh -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>140390867926272tanh</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"1538.85\" cy=\"-127.5\" rx=\"31.7\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"1538.85\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">tanh</text>\n",
"</g>\n",
"<!-- 140390867926272tanh->140390867926272 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>140390867926272tanh->140390867926272</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M1571.02,-127.5C1578.61,-127.5 1587.19,-127.5 1596.31,-127.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"1596.4,-131 1606.4,-127.5 1596.4,-124 1596.4,-131\"/>\n",
"</g>\n",
"<!-- 140390867925264 -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>140390867925264</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"1239,-109.5 1239,-145.5 1471,-145.5 1471,-109.5 1239,-109.5\"/>\n",
"<text text-anchor=\"middle\" x=\"1252\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">n</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1265,-109.5 1265,-145.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"1316\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.8814</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1367,-109.5 1367,-145.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"1419\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140390867925264->140390867926272tanh -->\n",
"<g id=\"edge11\" class=\"edge\">\n",
"<title>140390867925264->140390867926272tanh</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M1471.17,-127.5C1480.08,-127.5 1488.66,-127.5 1496.53,-127.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"1496.78,-131 1506.78,-127.5 1496.78,-124 1496.78,-131\"/>\n",
"</g>\n",
"<!-- 140390867925264+->140390867925264 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>140390867925264+->140390867925264</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M1203.12,-127.5C1210.53,-127.5 1219.14,-127.5 1228.4,-127.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"1228.69,-131 1238.69,-127.5 1228.69,-124 1228.69,-131\"/>\n",
"</g>\n",
"<!-- 140390867926800 -->\n",
"<g id=\"node9\" class=\"node\">\n",
"<title>140390867926800</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"0,-0.5 0,-36.5 249,-36.5 249,-0.5 0,-0.5\"/>\n",
"<text text-anchor=\"middle\" x=\"19\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">w1</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"38,-0.5 38,-36.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"91.5\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">data -3.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"145,-0.5 145,-36.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"197\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140390867926800->140390867923392* -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>140390867926800->140390867923392*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M214.42,-36.55C226.15,-39.54 237.93,-42.87 249,-46.5 259.46,-49.93 270.57,-54.47 280.46,-58.84\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"279.31,-62.17 289.87,-63.12 282.21,-55.79 279.31,-62.17\"/>\n",
"</g>\n",
"<!-- 140390867774768 -->\n",
"<g id=\"node10\" class=\"node\">\n",
"<title>140390867774768</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"4.5,-165.5 4.5,-201.5 244.5,-201.5 244.5,-165.5 4.5,-165.5\"/>\n",
"<text text-anchor=\"middle\" x=\"21.5\" y=\"-179.8\" font-family=\"Times,serif\" font-size=\"14.00\">x2</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"38.5,-165.5 38.5,-201.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"89.5\" y=\"-179.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"140.5,-165.5 140.5,-201.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"192.5\" y=\"-179.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140390867774768->140390867925744* -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>140390867774768->140390867925744*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M218.13,-165.46C228.63,-162.77 239.1,-159.78 249,-156.5 259.71,-152.96 271.05,-148.16 281.07,-143.54\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"282.67,-146.66 290.2,-139.21 279.67,-140.33 282.67,-146.66\"/>\n",
"</g>\n",
"<!-- 140390867924832 -->\n",
"<g id=\"node11\" class=\"node\">\n",
"<title>140390867924832</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"774,-82.5 774,-118.5 1113,-118.5 1113,-82.5 774,-82.5\"/>\n",
"<text text-anchor=\"middle\" x=\"838\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">x1*w1 + x2*w2</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"902,-82.5 902,-118.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"955.5\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">data -6.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1009,-82.5 1009,-118.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"1061\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140390867924832->140390867925264+ -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>140390867924832->140390867925264+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M1098.25,-118.51C1113.17,-120.26 1127.12,-121.89 1138.92,-123.27\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"1138.79,-126.78 1149.13,-124.47 1139.6,-119.83 1138.79,-126.78\"/>\n",
"</g>\n",
"<!-- 140390867924832+->140390867924832 -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>140390867924832+->140390867924832</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M738.44,-100.5C745.81,-100.5 754.42,-100.5 763.84,-100.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"763.94,-104 773.94,-100.5 763.94,-97 763.94,-104\"/>\n",
"</g>\n",
"<!-- 140390867923392 -->\n",
"<g id=\"node13\" class=\"node\">\n",
"<title>140390867923392</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"375,-55.5 375,-91.5 648,-91.5 648,-55.5 375,-55.5\"/>\n",
"<text text-anchor=\"middle\" x=\"406\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">x1*w1</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"437,-55.5 437,-91.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"490.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">data -6.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"544,-55.5 544,-91.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"596\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140390867923392->140390867924832+ -->\n",
"<g id=\"edge13\" class=\"edge\">\n",
"<title>140390867923392->140390867924832+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M644.24,-91.51C655.12,-93 665.4,-94.4 674.45,-95.64\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"674.11,-99.13 684.49,-97.01 675.05,-92.19 674.11,-99.13\"/>\n",
"</g>\n",
"<!-- 140390867923392*->140390867923392 -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>140390867923392*->140390867923392</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M339.23,-73.5C346.7,-73.5 355.41,-73.5 364.87,-73.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"364.98,-77 374.98,-73.5 364.98,-70 364.98,-77\"/>\n",
"</g>\n",
"<!-- 140390867772896 -->\n",
"<g id=\"node15\" class=\"node\">\n",
"<title>140390867772896</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"2.5,-110.5 2.5,-146.5 246.5,-146.5 246.5,-110.5 2.5,-110.5\"/>\n",
"<text text-anchor=\"middle\" x=\"21.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">w2</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"40.5,-110.5 40.5,-146.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"91.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 1.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"142.5,-110.5 142.5,-146.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"194.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140390867772896->140390867925744* -->\n",
"<g id=\"edge12\" class=\"edge\">\n",
"<title>140390867772896->140390867925744*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M246.64,-128.5C256.7,-128.5 266.26,-128.5 274.79,-128.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"275,-132 285,-128.5 275,-125 275,-132\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7faf4bcda610>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"draw_dot(o)"
]
},
{
"cell_type": "markdown",
"id": "fbd6d6ee-79a2-4d46-972f-d1777da855f9",
"metadata": {},
"source": [
"## Let's backpropagate!\n",
"- we care about the derivative of o with respect to w1, w2, and b\n",
"- those are called do/dw1; do/dw2; do/db respectively"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e6504316-3cbb-45ff-8c1a-e762669f0f5c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
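"# NOTE: gradients accumulate with +=, so run this cell only once per freshly built graph\n",
"o.backward() # sets o.grad to 1.0, then runs each node's _backward in reverse topological order"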
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "60c13ab3-e237-4c69-bdab-234197ef1cf6",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 2.43.0 (0)\n",
" -->\n",
"<!-- Title: %3 Pages: 1 -->\n",
"<svg width=\"1846pt\" height=\"210pt\"\n",
" viewBox=\"0.00 0.00 1845.69 210.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 206)\">\n",
"<title>%3</title>\n",
"<polygon fill=\"white\" stroke=\"transparent\" points=\"-4,4 -4,-206 1841.69,-206 1841.69,4 -4,4\"/>\n",
"<!-- 140390867926560 -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>140390867926560</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"827.5,-137.5 827.5,-173.5 1059.5,-173.5 1059.5,-137.5 827.5,-137.5\"/>\n",
"<text text-anchor=\"middle\" x=\"840.5\" y=\"-151.8\" font-family=\"Times,serif\" font-size=\"14.00\">b</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"853.5,-137.5 853.5,-173.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"904.5\" y=\"-151.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 6.8814</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"955.5,-137.5 955.5,-173.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"1007.5\" y=\"-151.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n",
"</g>\n",
"<!-- 140390867925264+ -->\n",
"<g id=\"node8\" class=\"node\">\n",
"<title>140390867925264+</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"1176\" cy=\"-127.5\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"1176\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n",
"</g>\n",
"<!-- 140390867926560->140390867925264+ -->\n",
"<g id=\"edge8\" class=\"edge\">\n",
"<title>140390867926560->140390867925264+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M1059.75,-141.5C1088.57,-138 1117.4,-134.5 1139.01,-131.87\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"1139.5,-135.34 1149.01,-130.66 1138.66,-128.39 1139.5,-135.34\"/>\n",
"</g>\n",
"<!-- 140390867774576 -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>140390867774576</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"2.5,-55.5 2.5,-91.5 246.5,-91.5 246.5,-55.5 2.5,-55.5\"/>\n",
"<text text-anchor=\"middle\" x=\"19.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">x1</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"36.5,-55.5 36.5,-91.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"87.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 2.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"138.5,-55.5 138.5,-91.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"192.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad -1.5000</text>\n",
"</g>\n",
"<!-- 140390867923392* -->\n",
"<g id=\"node14\" class=\"node\">\n",
"<title>140390867923392*</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"312\" cy=\"-73.5\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"312\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n",
"</g>\n",
"<!-- 140390867774576->140390867923392* -->\n",
"<g id=\"edge10\" class=\"edge\">\n",
"<title>140390867774576->140390867923392*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M246.64,-73.5C256.7,-73.5 266.26,-73.5 274.79,-73.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"275,-77 285,-73.5 275,-70 275,-77\"/>\n",
"</g>\n",
"<!-- 140390867925744 -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>140390867925744</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"377.5,-110.5 377.5,-146.5 645.5,-146.5 645.5,-110.5 377.5,-110.5\"/>\n",
"<text text-anchor=\"middle\" x=\"408.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">x2*w2</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"439.5,-110.5 439.5,-146.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"490.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"541.5,-110.5 541.5,-146.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"593.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n",
"</g>\n",
"<!-- 140390867924832+ -->\n",
"<g id=\"node12\" class=\"node\">\n",
"<title>140390867924832+</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"711\" cy=\"-100.5\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"711\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">+</text>\n",
"</g>\n",
"<!-- 140390867925744->140390867924832+ -->\n",
"<g id=\"edge14\" class=\"edge\">\n",
"<title>140390867925744->140390867924832+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M639.53,-110.49C652.09,-108.71 663.99,-107.02 674.3,-105.56\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"675.03,-108.99 684.44,-104.12 674.05,-102.06 675.03,-108.99\"/>\n",
"</g>\n",
"<!-- 140390867925744* -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>140390867925744*</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"312\" cy=\"-128.5\" rx=\"27\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"312\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">*</text>\n",
"</g>\n",
"<!-- 140390867925744*->140390867925744 -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>140390867925744*->140390867925744</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M339.23,-128.5C347.26,-128.5 356.72,-128.5 366.99,-128.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"367.08,-132 377.08,-128.5 367.08,-125 367.08,-132\"/>\n",
"</g>\n",
"<!-- 140390867926272 -->\n",
"<g id=\"node5\" class=\"node\">\n",
"<title>140390867926272</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"1606.69,-109.5 1606.69,-145.5 1837.69,-145.5 1837.69,-109.5 1606.69,-109.5\"/>\n",
"<text text-anchor=\"middle\" x=\"1619.19\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">o</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1631.69,-109.5 1631.69,-145.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"1682.69\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.7071</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1733.69,-109.5 1733.69,-145.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"1785.69\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n",
"</g>\n",
"<!-- 140390867926272tanh -->\n",
"<g id=\"node6\" class=\"node\">\n",
"<title>140390867926272tanh</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"1538.85\" cy=\"-127.5\" rx=\"31.7\" ry=\"18\"/>\n",
"<text text-anchor=\"middle\" x=\"1538.85\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">tanh</text>\n",
"</g>\n",
"<!-- 140390867926272tanh->140390867926272 -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>140390867926272tanh->140390867926272</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M1571.02,-127.5C1578.61,-127.5 1587.19,-127.5 1596.31,-127.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"1596.4,-131 1606.4,-127.5 1596.4,-124 1596.4,-131\"/>\n",
"</g>\n",
"<!-- 140390867925264 -->\n",
"<g id=\"node7\" class=\"node\">\n",
"<title>140390867925264</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"1239,-109.5 1239,-145.5 1471,-145.5 1471,-109.5 1239,-109.5\"/>\n",
"<text text-anchor=\"middle\" x=\"1252\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">n</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1265,-109.5 1265,-145.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"1316\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.8814</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1367,-109.5 1367,-145.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"1419\" y=\"-123.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n",
"</g>\n",
"<!-- 140390867925264->140390867926272tanh -->\n",
"<g id=\"edge11\" class=\"edge\">\n",
"<title>140390867925264->140390867926272tanh</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M1471.17,-127.5C1480.08,-127.5 1488.66,-127.5 1496.53,-127.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"1496.78,-131 1506.78,-127.5 1496.78,-124 1496.78,-131\"/>\n",
"</g>\n",
"<!-- 140390867925264+->140390867925264 -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>140390867925264+->140390867925264</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M1203.12,-127.5C1210.53,-127.5 1219.14,-127.5 1228.4,-127.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"1228.69,-131 1238.69,-127.5 1228.69,-124 1228.69,-131\"/>\n",
"</g>\n",
"<!-- 140390867926800 -->\n",
"<g id=\"node9\" class=\"node\">\n",
"<title>140390867926800</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"0,-0.5 0,-36.5 249,-36.5 249,-0.5 0,-0.5\"/>\n",
"<text text-anchor=\"middle\" x=\"19\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">w1</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"38,-0.5 38,-36.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"91.5\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">data -3.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"145,-0.5 145,-36.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"197\" y=\"-14.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 1.0000</text>\n",
"</g>\n",
"<!-- 140390867926800->140390867923392* -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>140390867926800->140390867923392*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M214.42,-36.55C226.15,-39.54 237.93,-42.87 249,-46.5 259.46,-49.93 270.57,-54.47 280.46,-58.84\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"279.31,-62.17 289.87,-63.12 282.21,-55.79 279.31,-62.17\"/>\n",
"</g>\n",
"<!-- 140390867774768 -->\n",
"<g id=\"node10\" class=\"node\">\n",
"<title>140390867774768</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"4.5,-165.5 4.5,-201.5 244.5,-201.5 244.5,-165.5 4.5,-165.5\"/>\n",
"<text text-anchor=\"middle\" x=\"21.5\" y=\"-179.8\" font-family=\"Times,serif\" font-size=\"14.00\">x2</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"38.5,-165.5 38.5,-201.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"89.5\" y=\"-179.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 0.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"140.5,-165.5 140.5,-201.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"192.5\" y=\"-179.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n",
"</g>\n",
"<!-- 140390867774768->140390867925744* -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>140390867774768->140390867925744*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M218.13,-165.46C228.63,-162.77 239.1,-159.78 249,-156.5 259.71,-152.96 271.05,-148.16 281.07,-143.54\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"282.67,-146.66 290.2,-139.21 279.67,-140.33 282.67,-146.66\"/>\n",
"</g>\n",
"<!-- 140390867924832 -->\n",
"<g id=\"node11\" class=\"node\">\n",
"<title>140390867924832</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"774,-82.5 774,-118.5 1113,-118.5 1113,-82.5 774,-82.5\"/>\n",
"<text text-anchor=\"middle\" x=\"838\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">x1*w1 + x2*w2</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"902,-82.5 902,-118.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"955.5\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">data -6.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"1009,-82.5 1009,-118.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"1061\" y=\"-96.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n",
"</g>\n",
"<!-- 140390867924832->140390867925264+ -->\n",
"<g id=\"edge9\" class=\"edge\">\n",
"<title>140390867924832->140390867925264+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M1098.25,-118.51C1113.17,-120.26 1127.12,-121.89 1138.92,-123.27\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"1138.79,-126.78 1149.13,-124.47 1139.6,-119.83 1138.79,-126.78\"/>\n",
"</g>\n",
"<!-- 140390867924832+->140390867924832 -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>140390867924832+->140390867924832</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M738.44,-100.5C745.81,-100.5 754.42,-100.5 763.84,-100.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"763.94,-104 773.94,-100.5 763.94,-97 763.94,-104\"/>\n",
"</g>\n",
"<!-- 140390867923392 -->\n",
"<g id=\"node13\" class=\"node\">\n",
"<title>140390867923392</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"375,-55.5 375,-91.5 648,-91.5 648,-55.5 375,-55.5\"/>\n",
"<text text-anchor=\"middle\" x=\"406\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">x1*w1</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"437,-55.5 437,-91.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"490.5\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">data -6.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"544,-55.5 544,-91.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"596\" y=\"-69.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.5000</text>\n",
"</g>\n",
"<!-- 140390867923392->140390867924832+ -->\n",
"<g id=\"edge13\" class=\"edge\">\n",
"<title>140390867923392->140390867924832+</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M644.24,-91.51C655.12,-93 665.4,-94.4 674.45,-95.64\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"674.11,-99.13 684.49,-97.01 675.05,-92.19 674.11,-99.13\"/>\n",
"</g>\n",
"<!-- 140390867923392*->140390867923392 -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>140390867923392*->140390867923392</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M339.23,-73.5C346.7,-73.5 355.41,-73.5 364.87,-73.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"364.98,-77 374.98,-73.5 364.98,-70 364.98,-77\"/>\n",
"</g>\n",
"<!-- 140390867772896 -->\n",
"<g id=\"node15\" class=\"node\">\n",
"<title>140390867772896</title>\n",
"<polygon fill=\"none\" stroke=\"black\" points=\"2.5,-110.5 2.5,-146.5 246.5,-146.5 246.5,-110.5 2.5,-110.5\"/>\n",
"<text text-anchor=\"middle\" x=\"21.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">w2</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"40.5,-110.5 40.5,-146.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"91.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">data 1.0000</text>\n",
"<polyline fill=\"none\" stroke=\"black\" points=\"142.5,-110.5 142.5,-146.5 \"/>\n",
"<text text-anchor=\"middle\" x=\"194.5\" y=\"-124.8\" font-family=\"Times,serif\" font-size=\"14.00\">grad 0.0000</text>\n",
"</g>\n",
"<!-- 140390867772896->140390867925744* -->\n",
"<g id=\"edge12\" class=\"edge\">\n",
"<title>140390867772896->140390867925744*</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M246.64,-128.5C256.7,-128.5 266.26,-128.5 274.79,-128.5\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"275,-132 285,-128.5 275,-125 275,-132\"/>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7faf4bcff3a0>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"draw_dot(o) # now the graph shows the gradient of o with respect to each node"
]
},
{
"cell_type": "markdown",
"id": "97610480-b29b-4928-b57b-b7fe8a417ea5",
"metadata": {},
"source": [
"## Manually verify some of the gradients"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "c65bd5fd-4687-42b8-8b27-8fe3cda6a58a",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"0.4999999999999999"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# o = tanh(n)\n",
"# do/dn = 1 - o**2\n",
"(1 - o.data**2)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "8b409694-f2d3-4a4b-a8f2-a2ba6976b785",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"0.4999999999999999"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# our first chain rule example!\n",
"# intuition: \"If a car travels twice as fast as a bicycle and the bicycle is four times as fast as a walking man, then the car travels 2 × 4 = 8 times as fast as the man\"\n",
"\n",
"# o = tanh(n)\n",
"# n = x1*w1 + x2*w2 + b\n",
"# do/db = do/dn * dn/db\n",
"(1 - o.data**2) * 1 "
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "7e563dc8-3fac-4152-b3c9-05d9760bb6cf",
"metadata": {},
"outputs": [],
"source": [
"# exercise: derive do/dj, where j = x1*w1 + x2*w2\n"
]
},
{
"cell_type": "markdown",
"id": "2f5aa4b2-388d-440f-b4eb-dda1167047c3",
"metadata": {},
"source": [
"# In PyTorch\n",
"- And finally, we can see that PyTorch is doing the same thing!"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "bc4eed18-9092-45c6-a7bb-8e37dbd2ceaf",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# this is what the same expression looks like in PyTorch\n",
"x1 = torch.Tensor([2.0]).double() ; x1.requires_grad = True\n",
"x2 = torch.Tensor([0.0]).double() ; x2.requires_grad = True\n",
"w1 = torch.Tensor([-3.0]).double() ; w1.requires_grad = True\n",
"w2 = torch.Tensor([1.0]).double() ; w2.requires_grad = True\n",
"b = torch.Tensor([6.8813735870195432]).double() ; b.requires_grad = True\n",
"n = x1*w1 + x2*w2 + b\n",
"o = torch.tanh(n)"
]
},
{
"cell_type": "markdown",
"id": "0bd63dfc-8a33-4a78-be75-9d37495b3c71",
"metadata": {},
"source": [
"## Verify output and backpropagate in PyTorch"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "fbddfa4e-b4a2-460e-9e6d-dae70ec4d658",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.7071066904050358\n"
]
}
],
"source": [
"print(o.data.item())\n",
"o.backward()"
]
},
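{
"cell_type": "markdown",
"id": "9535b1a2-2f43-4389-a597-88288f5fa0cf",
"metadata": {},
"source": [
"## Are the numbers correct?\n",
"- Yes: they agree with the micrograd gradients to within about 1e-6. The small difference arises because `torch.Tensor([...])` creates a float32 tensor by default, so `b` is rounded to single precision before `.double()` widens it."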
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "4261bf49-6af8-42a1-88aa-36b52dc7b9cc",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"---\n",
"x2 0.5000001283844369\n",
"w2 0.0\n",
"x1 -1.5000003851533106\n",
"w1 1.0000002567688737\n"
]
}
],
"source": [
"print('---')\n",
"print('x2', x2.grad.item())\n",
"print('w2', w2.grad.item())\n",
"print('x1', x1.grad.item())\n",
"print('w1', w1.grad.item())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "46e36130-4ed1-4c70-82a1-fa32d8304b6a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}