torchbenchmark/models/LearningToPaint/baseline_modelfree/test.py (128 lines of code) (raw):

# Inference driver: paints an input image with a trained LearningToPaint actor
# and neural renderer, writing every intermediate stroke frame to ./output.
import argparse
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from DRL.actor import *
from Renderer.stroke_gen import *
from Renderer.model import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
width = 128  # side length in pixels of one canvas tile

parser = argparse.ArgumentParser(description='Learning to Paint')
parser.add_argument('--max_step', default=40, type=int, help='max length for episode')
parser.add_argument('--actor', default='./model/Paint-run1/actor.pkl', type=str, help='Actor model')
parser.add_argument('--renderer', default='./renderer.pkl', type=str, help='renderer model')
parser.add_argument('--img', default='image/test.png', type=str, help='test image')
parser.add_argument('--imgid', default=0, type=int, help='set begin number for generated image')
parser.add_argument('--divide', default=4, type=int, help='divide the target image to get better resolution')
args = parser.parse_args()

canvas_cnt = args.divide * args.divide  # number of tiles in the divided phase
T = torch.ones([1, 1, width, width], dtype=torch.float32).to(device)

img = cv2.imread(args.img, cv2.IMREAD_COLOR)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None instead
    # of raising; fail here with a clear message rather than on img.shape.
    raise FileNotFoundError('could not read image: {}'.format(args.img))
origin_shape = (img.shape[1], img.shape[0])  # (w, h): the dsize order cv2.resize expects

# CoordConv channels: channel 0 holds the normalized row index, channel 1 the
# normalized column index, both spanning [0, 1]. The ramp is computed in
# float64 and then stored into the float32 tensor, which matches assigning each
# Python-float value i / (width - 1.) element by element.
_ramp = torch.arange(width, dtype=torch.float64) / (width - 1.)
coord = torch.zeros([1, 2, width, width])
coord[0, 0] = _ramp.view(width, 1).expand(width, width)
coord[0, 1] = _ramp.view(1, width).expand(width, width)
coord = coord.to(device)  # Coordconv

# Neural renderer: decode() feeds it 10 stroke parameters per stroke.
Decoder = FCN()
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
Decoder.load_state_dict(torch.load(args.renderer, map_location=device))


def decode(x, canvas):  # b * (10 + 3)
    """Render a bundle of 5 strokes per canvas onto ``canvas``.

    Args:
        x: actor output, viewed as rows of 13 values. The first 10 are stroke
            parameters fed to ``Decoder``; the last 3 are the stroke color.
            Each group of 5 consecutive rows belongs to the same canvas.
        canvas: current canvas; must broadcast against (b, 3, width, width).

    Returns:
        ``(canvas, res)``: the final canvas, and a list of the 5 intermediate
        canvases, one after each stroke is composited.
    """
    x = x.view(-1, 10 + 3)
    stroke = 1 - Decoder(x[:, :10])
    stroke = stroke.view(-1, width, width, 1)
    color_stroke = stroke * x[:, -3:].view(-1, 1, 1, 3)
    stroke = stroke.permute(0, 3, 1, 2)
    color_stroke = color_stroke.permute(0, 3, 1, 2)
    stroke = stroke.view(-1, 5, 1, width, width)
    color_stroke = color_stroke.view(-1, 5, 3, width, width)
    res = []
    for i in range(5):
        # "Over" compositing: stroke is the opacity mask and color_stroke the
        # color premultiplied by that mask.
        canvas = canvas * (1 - stroke[:, i]) + color_stroke[:, i]
        res.append(canvas)
    return canvas, res


def small2large(x):
    """Stitch tiles back into one image.

    (d * d, width, width, C) -> (d * width, d * width, C), where d is
    ``args.divide``. Tile ``p * d + q`` lands at tile-row p, tile-column q.
    """
    x = x.reshape(args.divide, args.divide, width, width, -1)
    x = np.transpose(x, (0, 2, 1, 3, 4))
    x = x.reshape(args.divide * width, args.divide * width, -1)
    return x


def large2small(x):
    """Cut an image into tiles (inverse of ``small2large``; 3 channels only).

    (d * width, d * width, 3) -> (d * d, width, width, 3), where d is
    ``args.divide``.
    """
    x = x.reshape(args.divide, width, args.divide, width, 3)
    x = np.transpose(x, (0, 2, 1, 3, 4))
    x = x.reshape(canvas_cnt, width, width, 3)
    return x


def smooth(img):
    """Soften the seams between adjacent tiles of a stitched image, in place.

    Every pixel on either side of each tile boundary is replaced by the mean of
    its 3x3 neighbourhood. Pixels are updated one at a time, so later pixels
    see already-smoothed neighbours. Pixels on the outer image border are left
    untouched. Returns ``img`` (the same array, modified).
    """
    def smooth_pix(img, tx, ty):
        if tx == args.divide * width - 1 or ty == args.divide * width - 1 or tx == 0 or ty == 0:
            return img
        img[tx, ty] = (img[tx, ty] + img[tx + 1, ty] + img[tx, ty + 1] + img[tx - 1, ty] + img[tx, ty - 1]
                       + img[tx + 1, ty - 1] + img[tx - 1, ty + 1] + img[tx - 1, ty - 1] + img[tx + 1, ty + 1]) / 9
        return img

    for p in range(args.divide):
        for q in range(args.divide):
            x = p * width
            y = q * width
            # Vertical seam: last column of this tile and first of the next one.
            for k in range(width):
                img = smooth_pix(img, x + k, y + width - 1)
                if q != args.divide - 1:
                    img = smooth_pix(img, x + k, y + width)
            # Horizontal seam: last row of this tile and first of the one below.
            for k in range(width):
                img = smooth_pix(img, x + width - 1, y + k)
                if p != args.divide - 1:
                    img = smooth_pix(img, x + width, y + k)
    return img


def save_img(res, imgid, divide=False):
    """Write a canvas tensor to ``output/generated<imgid>.png``.

    Args:
        res: canvas tensor of shape (n, 3, width, width), values expected in [0, 1].
        imgid: frame number used in the output filename.
        divide: if True, ``res`` holds d * d tiles that are stitched together
            and seam-smoothed; otherwise only ``res[0]`` is written.

    The image is resized back to the original input resolution before writing.
    """
    output = res.detach().cpu().numpy()  # d * d, 3, width, width
    output = np.transpose(output, (0, 2, 3, 1))
    if divide:
        output = small2large(output)
        output = smooth(output)
    else:
        output = output[0]
    # Guard the float -> uint8 cast: out-of-range floats do not convert to a
    # defined uint8 value. This is a no-op when values are already in [0, 1].
    output = (np.clip(output, 0, 1) * 255).astype('uint8')
    output = cv2.resize(output, origin_shape)
    cv2.imwrite(os.path.join('output', 'generated' + str(imgid) + '.png'), output)


def _paint(canvas, target, step_map, coord_map, label, divide=False):
    """Run ``args.max_step`` actor/renderer steps, saving every stroke frame.

    Each step feeds the actor a 9-channel input: canvas (3) + target (3) +
    normalized step (1) + coordinates (2). It renders the resulting 5-stroke
    bundle with ``decode``, prints the mean squared error against ``target``,
    and saves the 5 intermediate canvases, advancing ``args.imgid`` once per
    frame. Returns the canvas after the last step.
    """
    for i in range(args.max_step):
        stepnum = step_map * i / args.max_step
        actions = actor(torch.cat([canvas, target, stepnum, coord_map], 1))
        canvas, res = decode(actions, canvas)
        print('{} step {}, L2Loss = {}'.format(label, i, ((canvas - target) ** 2).mean()))
        for frame in res:
            save_img(frame, args.imgid, divide)
            args.imgid += 1
    return canvas


actor = ResNet(9, 18, 65)  # action_bundle = 5, 65 = 5 * 13
actor.load_state_dict(torch.load(args.actor, map_location=device))
actor = actor.to(device).eval()
Decoder = Decoder.to(device).eval()

canvas = torch.zeros([1, 3, width, width]).to(device)

# Target cut into canvas_cnt tiles: (d * d, 3, width, width), scaled to [0, 1].
patch_img = cv2.resize(img, (width * args.divide, width * args.divide))
patch_img = large2small(patch_img)
patch_img = np.transpose(patch_img, (0, 3, 1, 2))
patch_img = torch.tensor(patch_img).to(device).float() / 255.

# Whole target downscaled to a single tile: (1, 3, width, width), scaled to [0, 1].
img = cv2.resize(img, (width, width))
img = img.reshape(1, width, width, 3)
img = np.transpose(img, (0, 3, 1, 2))
img = torch.tensor(img).to(device).float() / 255.

os.makedirs('output', exist_ok=True)

with torch.no_grad():
    if args.divide != 1:
        # The step budget is split evenly between the coarse and tiled phases.
        args.max_step = args.max_step // 2
    canvas = _paint(canvas, img, T, coord, 'canvas')
    if args.divide != 1:
        # Upsample the coarse result, cut it into tiles, then refine each tile
        # against its own patch of the target.
        canvas = canvas[0].detach().cpu().numpy()
        canvas = np.transpose(canvas, (1, 2, 0))
        canvas = cv2.resize(canvas, (width * args.divide, width * args.divide))
        canvas = large2small(canvas)
        canvas = np.transpose(canvas, (0, 3, 1, 2))
        canvas = torch.tensor(canvas).to(device).float()
        coord = coord.expand(canvas_cnt, 2, width, width)
        T = T.expand(canvas_cnt, 1, width, width)
        canvas = _paint(canvas, patch_img, T, coord, 'divided canvas', True)