torchbenchmark/models/LearningToPaint/__init__.py
import cv2
import torch
import random
import numpy as np
from .baseline.Renderer.model import FCN
from .baseline.DRL.evaluator import Evaluator
from .baseline.utils.util import *
from .baseline.DRL.ddpg import DDPG
from .baseline.DRL.multi import fastenv
from ...util.model import BenchmarkModel
from typing import Tuple
from torchbenchmark.tasks import REINFORCEMENT_LEARNING
from argparse import Namespace
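# cuDNN settings: allow non-deterministic algorithms and enable benchmark-mode
# autotuning, which speeds up convolutions for the fixed 128 x 128 input size.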
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
class Model(BenchmarkModel):
task = REINFORCEMENT_LEARNING.OTHER_RL
DEFAULT_TRAIN_BSIZE = 96
DEFAULT_EVAL_BSIZE = 96
def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)
        # Train: these hyperparameters are taken from the original training script.
# Source: https://arxiv.org/pdf/1903.04411.pdf
# Code: https://github.com/megvii-research/ICCV2019-LearningToPaint/blob/master/baseline/train.py
self.args = Namespace(**{
'validate_episodes': 5,
'validate_interval': 50,
'max_step': 40,
'discount': 0.95**5,
'episode_train_times': 10,
'noise_factor': 0.0,
'tau': 0.001,
'rmsize': 800,
})
        # Train: the original inputs are CelebFaces (CelebA) images resized to 128 x 128,
        # sampled from roughly 200,000 images during training. The benchmark substitutes
        # 2,000 random tensors of the same shape.
self.width = 128
self.image_examples = torch.rand(2000, 3, self.width, self.width)
        # LearningToPaint includes actor, critic, and discriminator models; the
        # FCN Decoder is the neural renderer that turns stroke parameters into
        # 128 x 128 stroke images.
self.Decoder = FCN()
self.step = 0
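        # fastenv batches painting environments: each step renders the agent's
        # stroke actions onto the canvases through the Decoder and returns the
        # new observations, rewards, and done flags.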
self.env = fastenv(max_episode_length=self.args.max_step, env_batch=self.batch_size,
images=self.image_examples, device=self.device, Decoder=self.Decoder)
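        # DDPG agent: actor/critic networks plus a replay memory whose capacity
        # is controlled by rmsize.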
self.agent = DDPG(batch_size=self.batch_size, env_batch=self.batch_size,
max_step=self.args.max_step, tau=self.args.tau, discount=self.args.discount,
rmsize=self.args.rmsize, device=self.device, Decoder=self.Decoder)
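        # The Evaluator rolls out full episodes with the current policy and
        # reports average reward and distance to the target images.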
self.evaluate = Evaluator(args=self.args, env_batch=self.batch_size, writer=None)
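        # Seed the benchmark with an initial observation so the first
        # select_action/observe calls in get_module/train see valid state.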
self.observation = self.env.reset()
self.agent.reset(self.observation, self.args.noise_factor)
if test == "train":
self.agent.train()
elif test == "eval":
self.agent.eval()
def get_module(self):
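        # Take one environment step and record the transition so the replay
        # memory has something to sample from.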
action = self.agent.select_action(self.observation, noise_factor=self.args.noise_factor)
self.observation, reward, done, _ = self.env.step(action)
self.agent.observe(reward, self.observation, done, self.step)
state, action, reward, \
next_state, terminal = self.agent.memory.sample_batch(self.batch_size, self.device)
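        # Actor input: 6 canvas/target channels scaled to [0, 1], one
        # normalized step-count channel, and a 2-channel coordinate grid.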
state = torch.cat((state[:, :6].float() / 255, state[:, 6:7].float() / self.args.max_step,
self.agent.coord.expand(state.shape[0], 2, 128, 128)), 1)
return self.agent.actor, (state, )
def set_module(self, new_model):
self.agent.actor = new_model
def train(self, niter=1):
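        # Each iteration takes one environment step; at the end of an episode
        # the agent runs episode_train_times policy updates and the environment
        # is reset.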
episode = episode_steps = 0
for _ in range(niter):
episode_steps += 1
if self.observation is None:
self.observation = self.env.reset()
self.agent.reset(self.observation, self.args.noise_factor)
action = self.agent.select_action(self.observation, noise_factor=self.args.noise_factor)
self.observation, reward, done, _ = self.env.step(action)
self.agent.observe(reward, self.observation, done, self.step)
            # End of an episode: optionally evaluate, run policy updates, then reset.
            if episode_steps >= self.args.max_step and self.args.max_step:
                # [optional] evaluate the current policy every validate_interval episodes
if episode > 0 and self.args.validate_interval > 0 and \
episode % self.args.validate_interval == 0:
reward, dist = self.evaluate(self.env, self.agent.select_action)
tot_Q = 0.
tot_value_loss = 0.
                lr = (3e-4, 1e-3)  # learning rates passed to update_policy, as in the reference training script
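                # Run episode_train_times gradient updates on batches sampled
                # from the replay memory.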
for i in range(self.args.episode_train_times):
Q, value_loss = self.agent.update_policy(lr)
tot_Q += Q.data.cpu().numpy()
tot_value_loss += value_loss.data.cpu().numpy()
# reset
self.observation = None
episode_steps = 0
episode += 1
self.step += 1
    def eval(self, niter=1) -> Tuple[torch.Tensor, torch.Tensor]:
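        # Roll out validation episodes with the evaluator and return the last
        # reward/distance pair as tensors.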
for _ in range(niter):
reward, dist = self.evaluate(self.env, self.agent.select_action)
return (torch.tensor(reward), torch.tensor(dist))
def _set_mode(self, train):
if train:
self.agent.train()
else:
self.agent.eval()
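
# Minimal usage sketch (illustrative only; the torchbenchmark harness normally
# constructs and drives Model itself):
#
#   m = Model(test="eval", device="cuda")
#   reward, dist = m.eval(niter=1)
#
#   m = Model(test="train", device="cuda")
#   m.train(niter=1)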