torchbenchmark/models/LearningToPaint/LearningToPaint.ipynb (393 lines of code) (raw):

{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "learningtopaint.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/hzwer/LearningToPaint/blob/master/LearningToPaint.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "metadata": {
        "id": "TFN3oT1Hkjfs",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!git clone https://github.com/hzwer/LearningToPaint.git"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Dp7N29tGkwQs",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "%cd LearningToPaint/"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "qTbhmFyawzhO",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Testing\n",
        "\n",
        "Download the pretrained renderer and actor, paint a sample image, and turn the generated frames into a video."
      ]
    },
    {
      "metadata": {
        "id": "z0wTTzOEbvps",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!wget \"https://drive.google.com/uc?export=download&id=1-7dVdjCIZIxh8hHJnGTK-RA1-jL1tor4\" -O renderer.pkl"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Pfd53Hw2cfaY",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!wget \"https://drive.google.com/uc?export=download&id=1a3vpKgjCVXHON4P7wodqhCgCMPgg1KeR\" -O actor.pkl"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "QZpb3_3QiMZw",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!wget -U NoSuchBrowser/1.0 -O image/test.png https://raw.githubusercontent.com/hzwer/LearningToPaint/master/image/Trump.png"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "brX4ZlQoc9ss",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!python3 baseline/test.py --max_step=80 --actor=actor.pkl --renderer=renderer.pkl --img=image/test.png --divide=5"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "tLM4U6F0_yjV",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!ffmpeg -r 30 -f image2 -i output/generated%d.png -s 512x512 -c:v libx264 -pix_fmt yuv420p video.mp4 -q:v 0 -q:a 0"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ekY7HcBeh8zl",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "from IPython.display import display, Image\n",
        "import moviepy.editor as mpy\n",
        "display(mpy.ipython_display('video.mp4', height=256, max_duration=100.))\n",
        "display(Image('output/generated399.png'))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "d2mAkgRjwwuf",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "## Training\n",
        "\n",
        "Download CelebA into `data/`, train the neural renderer, then train the painting agent. `tensorboardX` is installed first, into the same `python3` interpreter that runs the training scripts. The `%%writefile` cell below overwrites the repository's `baseline/env.py` before agent training."
      ]
    },
    {
      "metadata": {
        "id": "VZWjNmD23gKm",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!python3 -m pip install tensorboardX"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "_-p0NhqyTqO_",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!mkdir -p data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "XXAV9RwkTwKh",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "%cd data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "IzZUVjdrET2G",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!gdown https://drive.google.com/uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "zgguAW3eETVd",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!unzip -q img_align_celeba.zip"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "XBH--DY-sK8V",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!rm img_align_celeba.zip"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "u6mVpjvBvzrb",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "%cd .."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "-PYJVt8pc6BP",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!python3 baseline/train_renderer.py"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ehnzhWn9GG4I",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "%%writefile baseline/env.py\n",
        "import sys\n",
        "import json\n",
        "import torch\n",
        "import numpy as np\n",
        "import argparse\n",
        "import torchvision.transforms as transforms\n",
        "import cv2\n",
        "from DRL.ddpg import decode\n",
        "from utils.util import *\n",
        "from PIL import Image\n",
        "from torchvision import transforms, utils\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "# Training-time augmentation: random horizontal flip (input is an HWC uint8 array).\n",
        "aug = transforms.Compose(\n",
        "    [transforms.ToPILImage(),\n",
        "     transforms.RandomHorizontalFlip(),\n",
        "     ])\n",
        "\n",
        "width = 128\n",
        "convas_area = width * width\n",
        "\n",
        "# Module-level image caches filled by Paint.load_data().\n",
        "img_train = []\n",
        "img_test = []\n",
        "train_num = 0\n",
        "test_num = 0\n",
        "\n",
        "\n",
        "class Paint:\n",
        "    \"\"\"Batched painting environment: strokes are rendered onto a canvas to approach a target image.\"\"\"\n",
        "\n",
        "    def __init__(self, batch_size, max_step):\n",
        "        self.batch_size = batch_size\n",
        "        self.max_step = max_step\n",
        "        self.action_space = (13)\n",
        "        self.observation_space = (self.batch_size, width, width, 7)\n",
        "        self.test = False\n",
        "\n",
        "    def load_data(self):\n",
        "        \"\"\"Load the first 200000 CelebA images, resized to width x width.\n",
        "\n",
        "        Images 000001-002001 form the test split; the rest form the training split.\n",
        "        Raises FileNotFoundError naming the file if an image cannot be read.\n",
        "        \"\"\"\n",
        "        global train_num, test_num\n",
        "        for i in range(200000):\n",
        "            img_id = '%06d' % (i + 1)\n",
        "            path = './data/img_align_celeba/' + img_id + '.jpg'\n",
        "            try:\n",
        "                img = cv2.imread(path, cv2.IMREAD_UNCHANGED)\n",
        "                # cv2.imread returns None for a missing/unreadable file; report the path\n",
        "                # instead of letting cv2.resize fail with an opaque error.\n",
        "                if img is None:\n",
        "                    raise FileNotFoundError('could not read image: ' + path)\n",
        "                img = cv2.resize(img, (width, width))\n",
        "                if i > 2000:\n",
        "                    train_num += 1\n",
        "                    img_train.append(img)\n",
        "                else:\n",
        "                    test_num += 1\n",
        "                    img_test.append(img)\n",
        "            finally:\n",
        "                if (i + 1) % 10000 == 0:\n",
        "                    print('loaded {} images'.format(i + 1))\n",
        "        print('finish loading data, {} training images, {} testing images'.format(str(train_num), str(test_num)))\n",
        "\n",
        "    def pre_data(self, id, test):\n",
        "        \"\"\"Return image `id` of the test/train split as a channels-first array; train images are augmented.\"\"\"\n",
        "        if test:\n",
        "            img = img_test[id]\n",
        "        else:\n",
        "            img = aug(img_train[id])\n",
        "        img = np.asarray(img)\n",
        "        return np.transpose(img, (2, 0, 1))\n",
        "\n",
        "    def reset(self, test=False, begin_num=False):\n",
        "        \"\"\"Start a new episode: blank canvas, fresh batch of targets; returns the first observation.\"\"\"\n",
        "        self.test = test\n",
        "        self.imgid = [0] * self.batch_size\n",
        "        self.gt = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8).to(device)\n",
        "        for i in range(self.batch_size):\n",
        "            if test:\n",
        "                # begin_num offsets into the test split (the default False acts as 0).\n",
        "                id = (i + begin_num) % test_num\n",
        "            else:\n",
        "                id = np.random.randint(train_num)\n",
        "            self.imgid[i] = id\n",
        "            self.gt[i] = torch.tensor(self.pre_data(id, test))\n",
        "        self.tot_reward = ((self.gt.float() / 255) ** 2).mean(1).mean(1).mean(1)\n",
        "        self.stepnum = 0\n",
        "        self.canvas = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8).to(device)\n",
        "        self.lastdis = self.ini_dis = self.cal_dis()\n",
        "        return self.observation()\n",
        "\n",
        "    def observation(self):\n",
        "        \"\"\"Stack canvas, target and step counter along dim 1 (3 + 3 + 1 = 7 channels).\"\"\"\n",
        "        # canvas B * 3 * width * width\n",
        "        # gt B * 3 * width * width\n",
        "        # T B * 1 * width * width -- uint8, so step counts above 255 cannot be represented\n",
        "        T = torch.ones([self.batch_size, 1, width, width], dtype=torch.uint8) * self.stepnum\n",
        "        return torch.cat((self.canvas, self.gt, T.to(device)), 1)  # canvas, img, T\n",
        "\n",
        "    def cal_trans(self, s, t):\n",
        "        return (s.transpose(0, 3) * t).transpose(0, 3)\n",
        "\n",
        "    def step(self, action):\n",
        "        \"\"\"Render `action` onto the canvas; return (observation, reward, per-item done flags, None).\"\"\"\n",
        "        self.canvas = (decode(action, self.canvas.float() / 255) * 255).byte()\n",
        "        self.stepnum += 1\n",
        "        ob = self.observation()\n",
        "        done = (self.stepnum == self.max_step)\n",
        "        reward = self.cal_reward()  # np.array([0.] * self.batch_size)\n",
        "        return ob.detach(), reward, np.array([done] * self.batch_size), None\n",
        "\n",
        "    def cal_dis(self):\n",
        "        \"\"\"Per-item mean squared difference between canvas and target, on pixels scaled to [0, 1].\"\"\"\n",
        "        return (((self.canvas.float() - self.gt.float()) / 255) ** 2).mean(1).mean(1).mean(1)\n",
        "\n",
        "    def cal_reward(self):\n",
        "        \"\"\"Reward = decrease in distance since the last step, normalised by the episode's initial distance.\"\"\"\n",
        "        dis = self.cal_dis()\n",
        "        reward = (self.lastdis - dis) / (self.ini_dis + 1e-8)\n",
        "        self.lastdis = dis\n",
        "        return to_numpy(reward)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "0kwVmo6yv1w3",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!python3 baseline/train.py --max_step=200 --debug --batch_size=96"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}