# cfvpy/selfplay.py
def run_trainer(self):
# Pin the logger version to 0 so that restarted jobs continue the same run.
if self.is_master:
logger = pl_logging.TestTubeLogger(save_dir=os.getcwd(), version=0)
# Storing the whole dict to preserve ref_models.
datagen = self.initialize_datagen()
context = datagen["context"]
replay = datagen["replay"]
policy_replay = datagen["policy_replay"]
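# NOTE: "context" is assumed to own the background data-generation workers,
# while "replay" and "policy_replay" are the value and policy replay buffers
# they feed.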
if self.cfg.data.train_preload:
# Must preload data before starting generators to avoid deadlocks.
_preload_data(self.cfg.data.train_preload, replay)
preloaded_size = replay.size()
else:
preloaded_size = 0
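# preloaded_size is subtracted in compute_gen_bps below so that generation
# speed only counts frames produced by the live workers.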
self.opt, self.policy_opt = self.configure_optimizers()
self.scheduler = self.configure_scheduler(self.opt)
context.start()
if self.cfg.benchmark_data_gen:
# Benchmark generation speed and exit.
time.sleep(self.cfg.benchmark_data_gen)
context.terminate()
size = replay.num_add()
logging.info(
"BENCHMARK size %s speed %.2f", size, size / context.running_time
)
return
train_size = self.cfg.data.train_epoch_size or 128 * 1000
logging.info("Train set size (forced): %s", train_size)
assert self.cfg.data.train_batch_size
batch_size = self.cfg.data.train_batch_size
epoch_size = train_size // batch_size
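# epoch_size is the number of optimizer steps per epoch: one epoch consumes
# train_size frames in batches of batch_size.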
if self.is_master:
val_datasets = []
logging.info(
"model size is %s",
sum(p.numel() for p in self.net.parameters() if p.requires_grad),
)
save_dir = pathlib.Path("ckpt")
if self.is_master and not save_dir.exists():
logging.info(f"Creating savedir: {save_dir}")
save_dir.mkdir(parents=True)
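# Burn-in: wait until each active buffer holds at least two batches before
# training starts, polling every 30 seconds (presumably so that early
# sample() calls do not starve).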
burn_in_frames = batch_size * 2
while replay.size() < burn_in_frames or (
policy_replay is not None and policy_replay.size() < burn_in_frames
):
logging.info(
"warming up replay buffer: %d/%d", replay.size(), burn_in_frames
)
if policy_replay is not None:
logging.info(
"warming up POLICY replay buffer: %d/%d",
policy_replay.size(),
burn_in_frames,
)
time.sleep(30)
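# Generation throughput in batches/sec, assuming context.running_time is the
# wall-clock time in seconds since context.start(). Preloaded frames are
# excluded from the value-buffer rate.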
def compute_gen_bps():
return (
(replay.num_add() - preloaded_size) / context.running_time / batch_size
)
def compute_gen_bps_policy():
return policy_replay.num_add() / context.running_time / batch_size
metrics = None
num_decays = 0
for epoch in range(self.cfg.max_epochs):
self.train_timer.start("start")
if (
epoch % self.cfg.decrease_lr_every == self.cfg.decrease_lr_every - 1
and self.scheduler is None
):
if (
not self.cfg.decrease_lr_times
or num_decays < self.cfg.decrease_lr_times
):
for param_group in self.opt.param_groups:
param_group["lr"] /= 2
num_decays += 1
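# Periodically freeze a set of replay samples as an extra validation set.
# These are drawn from the same buffer that training samples from, so they
# are not strictly held out.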
if (
self.cfg.create_validation_set_every
and self.is_master
and epoch % self.cfg.create_validation_set_every == 0
):
logging.info("Adding new validation set")
val_batches = [
replay.sample(batch_size, "cpu")[0]
for _ in range(512 * 100 // batch_size)
]
val_datasets.append((f"valid_snapshot_{epoch:04d}", val_batches))
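# Optionally dump the replay buffer to disk as a supervised dataset; skipped
# at epoch 0 when data was preloaded, presumably to avoid re-saving the
# preload.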
if (
self.cfg.selfplay.dump_dataset_every_epochs
and epoch % self.cfg.selfplay.dump_dataset_every_epochs == 0
and (not self.cfg.data.train_preload or epoch > 0)
):
dataset_folder = pathlib.Path("dumped_data").resolve()
dataset_folder.mkdir(exist_ok=True, parents=True)
dataset_path = dataset_folder / f"data_{epoch:03d}.dat"
logging.info(
"Saving replay buffer as supervised dataset to %s", dataset_path
)
replay.save(str(dataset_path))
metrics = {}
metrics["optim/lr"] = next(iter(self.opt.param_groups))["lr"]
metrics["epoch"] = epoch
counters = collections.defaultdict(cfvpy.utils.FractionCounter)
if self.cfg.grad_clip:
counters["optim/grad_max"] = cfvpy.utils.MaxCounter()
if self.cfg.train_policy:
counters["optim_policy/grad_max"] = cfvpy.utils.MaxCounter()
use_progress_bar = not heyhi.is_on_slurm() or self.cfg.show_progress_bar
train_loader = range(epoch_size)
train_device = self.device
train_iter = tqdm.tqdm(train_loader) if use_progress_bar else train_loader
training_start = time.time()
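# Throttle training so that each generated frame is reused at most
# train_gen_ratio times on average, i.e. wait until
# replay.num_add() * ratio >= train_size * (epoch + 1).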
if self.cfg.train_gen_ratio:
while True:
if replay.num_add() * self.cfg.train_gen_ratio >= train_size * (
epoch + 1
):
break
logging.info(
"Throttling to satisfy |replay| * ratio >= train_size * epochs:"
" %s * %s >= %s %s",
replay.num_add(),
self.cfg.train_gen_ratio,
train_size,
epoch + 1,
)
time.sleep(60)
assert self.cfg.replay.use_priority is False, "Not supported"
value_loss = policy_loss = 0 # For progress bar.
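# Value and policy batches are interleaved: odd iterations train against
# policy_replay (when it exists), even iterations against the value replay.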
for iter_id in train_iter:
self.train_timer.start("train-get_batch")
use_policy_net = iter_id % 2 == 1 and policy_replay is not None
if use_policy_net:
batch, _ = policy_replay.sample(batch_size, train_device)
suffix = "_policy"
else:
batch, _ = replay.sample(batch_size, train_device)
suffix = ""
self.train_timer.start("train-forward")
self.net.train()
loss_dict = self._compute_loss_dict(
batch, train_device, use_policy_net, timer_prefix="train-"
)
self.train_timer.start("train-backward")
loss = loss_dict["loss"]
opt = self.policy_opt if use_policy_net else self.opt
params = (
self.get_policy_params()
if use_policy_net
else self.get_value_params()
)
opt.zero_grad()
loss.backward()
if self.cfg.grad_clip:
g_norm = clip_grad_norm_(params, self.cfg.grad_clip)
else:
g_norm = None
opt.step()
loss.item() # Force sync.
self.train_timer.start("train-rest")
if g_norm is not None:
g_norm = g_norm.item()
counters[f"optim{suffix}/grad_max"].update(g_norm)
counters[f"optim{suffix}/grad_mean"].update(g_norm)
counters[f"optim{suffix}/grad_clip_ratio"].update(
int(g_norm >= self.cfg.grad_clip - 1e-5)
)
counters[f"loss{suffix}/train"].update(loss)
for num_cards, partial_data in loss_dict["partials"].items():
counters[f"loss{suffix}/train_{num_cards}"].update(
partial_data["loss_sum"], partial_data["count"],
)
counters[f"val{suffix}/train_{num_cards}"].update(
partial_data["val_sum"], partial_data["count"],
)
counters[f"shares{suffix}/train_{num_cards}"].update(
partial_data["count"], batch_size
)
if use_progress_bar:
if use_policy_net:
policy_loss = loss.detach().item()
else:
value_loss = loss.detach().item()
pbar_fields = dict(
policy_loss=policy_loss,
value_loss=value_loss,
buffer_size=replay.size(),
gen_bps=compute_gen_bps(),
)
if policy_replay is not None:
pbar_fields["pol_buffer_size"] = policy_replay.size()
train_iter.set_postfix(**pbar_fields)
if self.cfg.fake_training:
# Generation benchmarking mode in which training is
# skipped. The goal is to measure generation speed without
# sample() calls.
break
if self.cfg.fake_training:
# Make a fake training epoch take about a minute.
time.sleep(60)
if len(train_loader) > 0:
metrics["bps/train"] = len(train_loader) / (
time.time() - training_start
)
metrics["bps/train_examples"] = metrics["bps/train"] * batch_size
logging.info(
"[Train] epoch %d complete, avg error is %f",
epoch,
counters["loss/train"].value(),
)
if self.scheduler is not None:
self.scheduler.step()
for name, counter in counters.items():
metrics[name] = counter.value()
metrics["buffer/size"] = replay.size()
metrics["buffer/added"] = replay.num_add()
metrics["bps/gen"] = compute_gen_bps()
metrics["bps/gen_examples"] = metrics["bps/gen"] * batch_size
if policy_replay is not None:
metrics["buffer/policy_size"] = policy_replay.size()
metrics["buffer/policy_added"] = policy_replay.num_add()
metrics["bps/gen_policy"] = compute_gen_bps_policy()
metrics["bps/gen_policy_examples"] = (
metrics["bps/gen_policy"] * batch_size
)
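# Publish the current weights to the data-generation workers. The hardcoded
# "epoch < 15" syncs every epoch early on, presumably because the network
# changes fastest at the start of training.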
if (epoch + 1) % self.cfg.selfplay.network_sync_epochs == 0 or epoch < 15:
logging.info("Copying current network to the eval network")
for model_locker in datagen["model_lockers"]:
model_locker.update_model(self.get_model())
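# Optionally shrink the buffer at configured epochs, keeping a
# purging_share_keep fraction of it (but never less than burn_in_frames).
# pop_until is assumed to evict the oldest entries first.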
if self.cfg.purging_epochs and (epoch + 1) in self.cfg.purging_epochs:
new_size = max(
burn_in_frames,
int((self.cfg.purging_share_keep or 0.0) * replay.size()),
)
logging.info(
"Going to purge everything but %d elements in the buffer", new_size,
)
replay.pop_until(new_size)
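# Every 10 epochs the master evaluates the frozen validation snapshots and
# checkpoints both a state_dict and a TorchScript export of the model.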
if self.is_master and epoch % 10 == 0:
with torch.no_grad():
for i, (name, val_loader) in enumerate(val_datasets):
self.train_timer.start("valid-acc-extra")
eval_errors = []
val_iter = (
tqdm.tqdm(val_loader, desc="Eval")
if use_progress_bar
else val_loader
)
for data in val_iter:
self.net.eval()
loss = self._compute_loss_dict(
data, train_device, use_policy_net=False
)["loss"]
eval_errors.append(loss.detach().item())
current_error = sum(eval_errors) / len(eval_errors)
logging.info(
"[Eval] epoch %d complete, data is %s, avg error is %f",
epoch,
name,
current_error,
)
metrics[f"loss/{name}"] = current_error
self.train_timer.start("valid-trace")
ckpt_path = save_dir / f"epoch{epoch}.ckpt"
torch.save(self.get_model().state_dict(), ckpt_path)
bin_path = ckpt_path.with_suffix(".torchscript")
torch.jit.save(torch.jit.script(self.get_model()), str(bin_path))
self.train_timer.start("valid-exploit")
if self.cfg.exploit and epoch % 20 == 0:
bin_path = pathlib.Path("tmp.torchscript")
torch.jit.save(torch.jit.script(self.get_model()), str(bin_path))
(
exploitability,
mse_net_traverse,
mse_fp_traverse,
) = cfvpy.rela.compute_stats_with_net(
create_mdp_config(self.cfg.env), str(bin_path)
)
logging.info(
"Exploitability to leaf (epoch=%d): %.2f", epoch, exploitability
)
metrics["exploitability_last"] = exploitability
metrics["eval_mse/net_reach"] = mse_net_traverse
metrics["eval_mse/fp_reach"] = mse_fp_traverse
if len(train_loader) > 0:
metrics["bps/loop"] = len(train_loader) / (time.time() - training_start)
total = 1e-5  # Small epsilon to avoid division by zero below.
for k, v in self.train_timer.timings.items():
metrics[f"timing/{k}"] = v / (epoch + 1)
total += v
for k, v in self.train_timer.timings.items():
metrics[f"timing_pct/{k}"] = v * 100 / total
logging.info("Metrics: %s", metrics)
if self.is_master:
logger.log_metrics(metrics)
logger.save()
return metrics