torchbenchmark/models/LearningToPaint/LearningToPaint.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "learningtopaint.ipynb",
"version": "0.3.2",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/hzwer/LearningToPaint/blob/master/LearningToPaint.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"metadata": {
"id": "TFN3oT1Hkjfs",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"!git clone https://github.com/hzwer/LearningToPaint.git"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "Dp7N29tGkwQs",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"cd LearningToPaint/"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "qTbhmFyawzhO",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Testing "
]
},
{
"metadata": {
"id": "z0wTTzOEbvps",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"!wget \"https://drive.google.com/uc?export=download&id=1-7dVdjCIZIxh8hHJnGTK-RA1-jL1tor4\" -O renderer.pkl"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "Pfd53Hw2cfaY",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"!wget \"https://drive.google.com/uc?export=download&id=1a3vpKgjCVXHON4P7wodqhCgCMPgg1KeR\" -O actor.pkl"
],
"execution_count": 0,
"outputs": []
},
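{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"If the direct wget calls above return Google Drive's HTML confirmation page instead of the pickle files, the same checkpoints can usually be fetched with gdown (the tool already used for CelebA in the Training section below). This is an optional fallback; the file IDs are copied from the wget URLs above."
]
},
{
"metadata": {
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# Optional fallback in case the wget downloads above did not produce valid pickles.\n",
"!pip install -q gdown\n",
"!gdown https://drive.google.com/uc?id=1-7dVdjCIZIxh8hHJnGTK-RA1-jL1tor4 -O renderer.pkl\n",
"!gdown https://drive.google.com/uc?id=1a3vpKgjCVXHON4P7wodqhCgCMPgg1KeR -O actor.pkl"
],
"execution_count": 0,
"outputs": []
},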
{
"metadata": {
"id": "QZpb3_3QiMZw",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"!wget -U NoSuchBrowser/1.0 -O image/test.png https://raw.githubusercontent.com/hzwer/LearningToPaint/master/image/Trump.png"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "brX4ZlQoc9ss",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"!python3 baseline/test.py --max_step=80 --actor=actor.pkl --renderer=renderer.pkl --img=image/test.png --divide=5"
],
"execution_count": 0,
"outputs": []
},
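{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Optional sanity check: list the frames that baseline/test.py wrote to output/. The ffmpeg and display cells below assume files named generatedN.png in that folder."
]
},
{
"metadata": {
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"# Count and preview the painted frames produced by the previous cell (optional).\n",
"!ls output/ | wc -l\n",
"!ls output/ | head"
],
"execution_count": 0,
"outputs": []
},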
{
"metadata": {
"id": "tLM4U6F0_yjV",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"!ffmpeg -r 30 -f image2 -i output/generated%d.png -s 512x512 -c:v libx264 -pix_fmt yuv420p video.mp4 -q:v 0 -q:a 0"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "ekY7HcBeh8zl",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"from IPython.display import display, Image\n",
"import moviepy.editor as mpy\n",
"display(mpy.ipython_display('video.mp4', height=256, max_duration=100.))\n",
"display(Image('output/generated399.png'))"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "d2mAkgRjwwuf",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Training"
]
},
{
"metadata": {
"id": "_-p0NhqyTqO_",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"!mkdir data"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "XXAV9RwkTwKh",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"cd data"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"colab_type": "code",
"id": "IzZUVjdrET2G",
"colab": {}
},
"cell_type": "code",
"source": [
"!gdown https://drive.google.com/uc?id=0B7EVK8r0v71pZjFTYXZWM3FlRnM"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"colab_type": "code",
"id": "zgguAW3eETVd",
"colab": {}
},
"cell_type": "code",
"source": [
"!unzip img_align_celeba.zip"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "XBH--DY-sK8V",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"!rm img_align_celeba.zip"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "u6mVpjvBvzrb",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"cd .."
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "-PYJVt8pc6BP",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"!python3 baseline/train_renderer.py"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "VZWjNmD23gKm",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"!pip install tensorboardX"
],
"execution_count": 0,
"outputs": []
},
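{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Optional: tensorboardX writes event files that can be viewed with the TensorBoard notebook extension. The log directory below is only an assumed placeholder; point --logdir at wherever baseline/train.py actually writes its event files."
]
},
{
"metadata": {
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"%load_ext tensorboard\n",
"# NOTE: 'train_log' is an assumed placeholder, not a path taken from this notebook;\n",
"# replace it with the directory where baseline/train.py writes its tensorboardX events.\n",
"%tensorboard --logdir train_log"
],
"execution_count": 0,
"outputs": []
},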
{
"metadata": {
"id": "ehnzhWn9GG4I",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"%%writefile baseline/env.py\n",
"import sys\n",
"import json\n",
"import torch\n",
"import numpy as np\n",
"import argparse\n",
"import torchvision.transforms as transforms\n",
"import cv2\n",
"from DRL.ddpg import decode\n",
"from utils.util import *\n",
"from PIL import Image\n",
"from torchvision import transforms, utils\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"aug = transforms.Compose(\n",
" [transforms.ToPILImage(),\n",
" transforms.RandomHorizontalFlip(),\n",
" ])\n",
"\n",
"width = 128\n",
"convas_area = width * width\n",
"\n",
"img_train = []\n",
"img_test = []\n",
"train_num = 0\n",
"test_num = 0\n",
"\n",
"class Paint:\n",
" def __init__(self, batch_size, max_step):\n",
" self.batch_size = batch_size\n",
" self.max_step = max_step\n",
" self.action_space = (13)\n",
" self.observation_space = (self.batch_size, width, width, 7)\n",
" self.test = False\n",
" \n",
" def load_data(self):\n",
" # CelebA\n",
" global train_num, test_num\n",
" for i in range(200000):\n",
" img_id = '%06d' % (i + 1)\n",
" try:\n",
" img = cv2.imread('./data/img_align_celeba/' + img_id + '.jpg', cv2.IMREAD_UNCHANGED)\n",
" img = cv2.resize(img, (width, width))\n",
" if i > 2000: \n",
" train_num += 1\n",
" img_train.append(img)\n",
" else:\n",
" test_num += 1\n",
" img_test.append(img)\n",
" finally:\n",
" if (i + 1) % 10000 == 0: \n",
" print('loaded {} images'.format(i + 1))\n",
" print('finish loading data, {} training images, {} testing images'.format(str(train_num), str(test_num)))\n",
" \n",
" def pre_data(self, id, test):\n",
" if test:\n",
" img = img_test[id]\n",
" else:\n",
" img = img_train[id]\n",
" if not test:\n",
" img = aug(img)\n",
" img = np.asarray(img)\n",
" return np.transpose(img, (2, 0, 1))\n",
" \n",
" def reset(self, test=False, begin_num=False):\n",
" self.test = test\n",
" self.imgid = [0] * self.batch_size\n",
" self.gt = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8).to(device)\n",
" for i in range(self.batch_size):\n",
" if test:\n",
" id = (i + begin_num) % test_num\n",
" else:\n",
" id = np.random.randint(train_num)\n",
" self.imgid[i] = id\n",
" self.gt[i] = torch.tensor(self.pre_data(id, test))\n",
" self.tot_reward = ((self.gt.float() / 255) ** 2).mean(1).mean(1).mean(1)\n",
" self.stepnum = 0\n",
" self.canvas = torch.zeros([self.batch_size, 3, width, width], dtype=torch.uint8).to(device)\n",
" self.lastdis = self.ini_dis = self.cal_dis()\n",
" return self.observation()\n",
" \n",
" def observation(self):\n",
" # canvas B * 3 * width * width\n",
" # gt B * 3 * width * width\n",
" # T B * 1 * width * width\n",
" ob = []\n",
" T = torch.ones([self.batch_size, 1, width, width], dtype=torch.uint8) * self.stepnum\n",
" return torch.cat((self.canvas, self.gt, T.to(device)), 1) # canvas, img, T\n",
"\n",
" def cal_trans(self, s, t):\n",
" return (s.transpose(0, 3) * t).transpose(0, 3)\n",
" \n",
" def step(self, action):\n",
" self.canvas = (decode(action, self.canvas.float() / 255) * 255).byte()\n",
" self.stepnum += 1\n",
" ob = self.observation()\n",
" done = (self.stepnum == self.max_step)\n",
" reward = self.cal_reward() # np.array([0.] * self.batch_size)\n",
" return ob.detach(), reward, np.array([done] * self.batch_size), None\n",
"\n",
" def cal_dis(self):\n",
" return (((self.canvas.float() - self.gt.float()) / 255) ** 2).mean(1).mean(1).mean(1)\n",
" \n",
" def cal_reward(self):\n",
" dis = self.cal_dis()\n",
" reward = (self.lastdis - dis) / (self.ini_dis + 1e-8)\n",
" self.lastdis = dis\n",
" return to_numpy(reward)\n"
],
"execution_count": 0,
"outputs": []
},
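{
"metadata": {
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"For reference, the environment written above stacks the canvas, the target image, and a step-count plane into a 7-channel observation, and rewards each stroke by how much it reduces the per-image mean squared error between canvas $C_t$ and target $I$ (both scaled to [0, 1]), normalized by the initial error, as computed in cal_dis and cal_reward:\n",
"\n",
"$$d_t = \\mathrm{mean}\\big((C_t - I)^2\\big), \\qquad r_t = \\frac{d_{t-1} - d_t}{d_0 + 10^{-8}}$$"
]
},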
{
"metadata": {
"id": "0kwVmo6yv1w3",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"!python3 baseline/train.py --max_step=200 --debug --batch_size=96"
],
"execution_count": 0,
"outputs": []
}
]
}