in train_stpp.py [0:0]
def _main(rank, world_size, args, savepath, logger):
    if rank == 0:
        logger.info(args)
        logger.info(f"Saving to {savepath}")
        tb_writer = SummaryWriter(os.path.join(savepath, "tb_logdir"))

    device = torch.device(f'cuda:{rank:d}' if torch.cuda.is_available() else 'cpu')

    if rank == 0:
        if device.type == 'cuda':
            logger.info('Found {} CUDA devices.'.format(torch.cuda.device_count()))
            for i in range(torch.cuda.device_count()):
                props = torch.cuda.get_device_properties(i)
                logger.info('{} \t Memory: {:.2f}GB'.format(props.name, props.total_memory / (1024**3)))
        else:
            logger.info('WARNING: Using device {}'.format(device))
    t0, t1 = map(lambda x: cast(x, device), get_t0_t1(args.data))

    train_set = load_data(args.data, split="train")
    val_set = load_data(args.data, split="val")
    test_set = load_data(args.data, split="test")

    train_epoch_iter = EpochBatchIterator(
        dataset=train_set,
        collate_fn=datasets.spatiotemporal_events_collate_fn,
        batch_sampler=train_set.batch_by_size(args.max_events),
        seed=args.seed + rank,
    )
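    # batch_by_size presumably packs sequences so that each batch holds at most
    # args.max_events events in total, letting the batch size adapt to sequence
    # length. A rough sketch of that kind of grouping (hypothetical; the real
    # logic lives in the dataset class):
    #
    #     batches, cur, cur_events = [], [], 0
    #     for idx, seq in enumerate(sequences):
    #         if cur and cur_events + len(seq) > max_events:
    #             batches.append(cur)
    #             cur, cur_events = [], 0
    #         cur.append(idx)
    #         cur_events += len(seq)
    #     if cur:
    #         batches.append(cur)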
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=args.test_bsz,
        shuffle=False,
        collate_fn=datasets.spatiotemporal_events_collate_fn,
    )
    test_loader = torch.utils.data.DataLoader(
        test_set,
        batch_size=args.test_bsz,
        shuffle=False,
        collate_fn=datasets.spatiotemporal_events_collate_fn,
    )
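    # spatiotemporal_events_collate_fn presumably pads variable-length event
    # sequences to a common length and returns (event_times, spatial_locations,
    # input_mask), matching how batches are unpacked in the training loop below.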
    if rank == 0:
        logger.info(f"{len(train_set)} training examples, {len(val_set)} val examples, {len(test_set)} test examples")

    x_dim = get_dim(args.data)
if args.model == "jumpcnf" and args.tpp == "neural":
model = JumpCNFSpatiotemporalModel(dim=x_dim,
hidden_dims=list(map(int, args.hdims.split("-"))),
tpp_hidden_dims=list(map(int, args.tpp_hdims.split("-"))),
actfn=args.actfn,
tpp_cond=args.tpp_cond,
tpp_style=args.tpp_style,
tpp_actfn=args.tpp_actfn,
share_hidden=args.share_hidden,
solve_reverse=args.solve_reverse,
tol=args.tol,
otreg_strength=args.otreg_strength,
tpp_otreg_strength=args.tpp_otreg_strength,
layer_type=args.layer_type,
).to(device)
elif args.model == "attncnf" and args.tpp == "neural":
model = SelfAttentiveCNFSpatiotemporalModel(dim=x_dim,
hidden_dims=list(map(int, args.hdims.split("-"))),
tpp_hidden_dims=list(map(int, args.tpp_hdims.split("-"))),
actfn=args.actfn,
tpp_cond=args.tpp_cond,
tpp_style=args.tpp_style,
tpp_actfn=args.tpp_actfn,
share_hidden=args.share_hidden,
solve_reverse=args.solve_reverse,
l2_attn=args.l2_attn,
tol=args.tol,
otreg_strength=args.otreg_strength,
tpp_otreg_strength=args.tpp_otreg_strength,
layer_type=args.layer_type,
lowvar_trace=not args.naive_hutch,
).to(device)
elif args.model == "cond_gmm" and args.tpp == "neural":
model = JumpGMMSpatiotemporalModel(dim=x_dim,
hidden_dims=list(map(int, args.hdims.split("-"))),
tpp_hidden_dims=list(map(int, args.tpp_hdims.split("-"))),
actfn=args.actfn,
tpp_cond=args.tpp_cond,
tpp_style=args.tpp_style,
tpp_actfn=args.tpp_actfn,
share_hidden=args.share_hidden,
tol=args.tol,
tpp_otreg_strength=args.tpp_otreg_strength,
).to(device)
else:
# Mix and match between spatial and temporal models.
if args.tpp == "poisson":
tpp_model = HomogeneousPoissonPointProcess()
elif args.tpp == "hawkes":
tpp_model = HawkesPointProcess()
elif args.tpp == "correcting":
tpp_model = SelfCorrectingPointProcess()
elif args.tpp == "neural":
tpp_hidden_dims = list(map(int, args.tpp_hdims.split("-")))
tpp_model = NeuralPointProcess(
cond_dim=x_dim, hidden_dims=tpp_hidden_dims, cond=args.tpp_cond, style=args.tpp_style, actfn=args.tpp_actfn,
otreg_strength=args.tpp_otreg_strength, tol=args.tol)
else:
raise ValueError(f"Invalid tpp model {args.tpp}")
if args.model == "gmm":
model = CombinedSpatiotemporalModel(GaussianMixtureSpatialModel(), tpp_model).to(device)
elif args.model == "cnf":
model = CombinedSpatiotemporalModel(
IndependentCNF(dim=x_dim, hidden_dims=list(map(int, args.hdims.split("-"))),
layer_type=args.layer_type, actfn=args.actfn, tol=args.tol, otreg_strength=args.otreg_strength,
squash_time=True),
tpp_model).to(device)
elif args.model == "tvcnf":
model = CombinedSpatiotemporalModel(
IndependentCNF(dim=x_dim, hidden_dims=list(map(int, args.hdims.split("-"))),
layer_type=args.layer_type, actfn=args.actfn, tol=args.tol, otreg_strength=args.otreg_strength),
tpp_model).to(device)
elif args.model == "jumpcnf":
model = CombinedSpatiotemporalModel(
JumpCNF(dim=x_dim, hidden_dims=list(map(int, args.hdims.split("-"))),
layer_type=args.layer_type, actfn=args.actfn, tol=args.tol, otreg_strength=args.otreg_strength),
tpp_model).to(device)
elif args.model == "attncnf":
model = CombinedSpatiotemporalModel(
SelfAttentiveCNF(dim=x_dim, hidden_dims=list(map(int, args.hdims.split("-"))),
layer_type=args.layer_type, actfn=args.actfn, l2_attn=args.l2_attn, tol=args.tol, otreg_strength=args.otreg_strength),
tpp_model).to(device)
else:
raise ValueError(f"Invalid model {args.model}")
    params = []
    attn_params = []
    for name, p in model.named_parameters():
        if "self_attns" in name:
            attn_params.append(p)
        else:
            params.append(p)

    optimizer = torch.optim.AdamW([
        {"params": params},
        {"params": attn_params},
    ], lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.98))
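    # The self-attention parameters get their own param group, presumably so
    # they can be given different hyperparameters (e.g. learning rate) later;
    # both groups currently share the same settings. betas=(0.9, 0.98) follows
    # the common Transformer training recipe.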
    if rank == 0:
        ema = utils.ExponentialMovingAverage(model)

    model = DDP(model, device_ids=[rank], find_unused_parameters=True)

    if rank == 0:
        logger.info(model)
    begin_itr = 0
    checkpt_path = os.path.join(savepath, "model.pth")
    if os.path.exists(checkpt_path):
        # Restarted run: a checkpoint already exists in this run's save directory.
        if rank == 0:
            logger.info(f"Resuming checkpoint from {checkpt_path}")
        checkpt = torch.load(checkpt_path, map_location="cpu")
        model.module.load_state_dict(checkpt["state_dict"])
        optimizer.load_state_dict(checkpt["optim_state_dict"])
        begin_itr = checkpt["itr"] + 1
    elif args.resume:
        # Fresh run: warm-start from an explicitly given checkpoint instead.
        if rank == 0:
            logger.info(f"Resuming model from {args.resume}")
        checkpt = torch.load(args.resume, map_location="cpu")
        model.module.load_state_dict(checkpt["state_dict"])
        optimizer.load_state_dict(checkpt["optim_state_dict"])
        begin_itr = checkpt["itr"] + 1
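    # A checkpoint already sitting in savepath (an interrupted run) takes
    # precedence over --resume; either way the model weights, optimizer state,
    # and iteration counter are all restored.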
    space_loglik_meter = utils.RunningAverageMeter(0.98)
    time_loglik_meter = utils.RunningAverageMeter(0.98)
    gradnorm_meter = utils.RunningAverageMeter(0.98)

    model.train()
    start_time = time.time()
    iteration_counter = itertools.count(begin_itr)
    begin_epoch = begin_itr // len(train_epoch_iter)
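    # len(train_epoch_iter) is the number of batches per epoch, so iteration
    # begin_itr falls in epoch begin_itr // len(train_epoch_iter); e.g. with
    # 250 batches per epoch, resuming at begin_itr = 1000 restarts in epoch 4.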
    for epoch in range(begin_epoch, math.ceil(args.num_iterations / len(train_epoch_iter))):
        batch_iter = train_epoch_iter.next_epoch_itr(shuffle=True)
        for batch in batch_iter:
            itr = next(iteration_counter)

            optimizer.zero_grad()

            event_times, spatial_locations, input_mask = map(lambda x: cast(x, device), batch)
            N, T = input_mask.shape  # (batch size, padded sequence length)
            num_events = input_mask.sum()
            if num_events == 0:
                raise RuntimeError("Got batch with no observations.")

            space_loglik, time_loglik = model(event_times, spatial_locations, input_mask, t0, t1)

            # Normalize log-likelihoods by the number of observed events.
            space_loglik = space_loglik.sum() / num_events
            time_loglik = time_loglik.sum() / num_events
            loglik = time_loglik + space_loglik

            space_loglik_meter.update(space_loglik.item())
            time_loglik_meter.update(time_loglik.item())

            loss = loglik.mul(-1.0)  # negative per-event log-likelihood (a scalar)
            loss.backward()

            # Set the learning rate before the optimizer step.
            total_itrs = math.ceil(args.num_iterations / len(train_epoch_iter)) * len(train_epoch_iter)
            lr = learning_rate_schedule(itr, args.warmup_itrs, args.lr, total_itrs)
            set_learning_rate(optimizer, lr)
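
            # learning_rate_schedule is defined elsewhere in this file. A common
            # shape for such a schedule (an assumption, not the verified
            # implementation) is linear warmup followed by cosine decay:
            #
            #     def learning_rate_schedule(itr, warmup_itrs, base_lr, total_itrs):
            #         if itr < warmup_itrs:
            #             return base_lr * itr / max(warmup_itrs, 1)
            #         progress = (itr - warmup_itrs) / max(total_itrs - warmup_itrs, 1)
            #         return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))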
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.gradclip).item()
            gradnorm_meter.update(grad_norm)

            optimizer.step()

            if rank == 0:
                # Update the EMA shadow weights (only rank 0 keeps an EMA copy).
                if itr > 0.8 * args.num_iterations:
                    ema.apply()
                else:
                    ema.apply(decay=0.0)
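
            # Presumably ema.apply(decay=0.0) just copies the current weights into
            # the shadow copy, so the EMA only starts averaging (with its default
            # decay) over the final 20% of training; before that, the shadow copy
            # simply tracks the live weights.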
            if rank == 0:
                tb_writer.add_scalar("train/lr", lr, itr)
                tb_writer.add_scalar("train/temporal_loss", time_loglik.item(), itr)
                tb_writer.add_scalar("train/spatial_loss", space_loglik.item(), itr)
                tb_writer.add_scalar("train/grad_norm", grad_norm, itr)

            if itr % args.logfreq == 0:
                elapsed_time = time.time() - start_time

                # Average the number of function evaluations (NFE) across devices.
                nfe = 0
                for m in model.modules():
                    if isinstance(m, (TimeVariableCNF, TimeVariableODE)):
                        nfe += m.nfe
                nfe = torch.tensor(nfe).to(device)
                dist.all_reduce(nfe, op=dist.ReduceOp.SUM)
                nfe = nfe // world_size

                # Sum memory usage across devices.
                mem = torch.tensor(memory_usage_psutil()).float().to(device)
                dist.all_reduce(mem, op=dist.ReduceOp.SUM)

                if rank == 0:
                    logger.info(
                        f"Iter {itr} | Epoch {epoch} | LR {lr:.5f} | Time {elapsed_time:.1f}"
                        f" | Temporal {time_loglik_meter.val:.4f}({time_loglik_meter.avg:.4f})"
                        f" | Spatial {space_loglik_meter.val:.4f}({space_loglik_meter.avg:.4f})"
                        f" | GradNorm {gradnorm_meter.val:.2f}({gradnorm_meter.avg:.2f})"
                        f" | NFE {nfe.item()}"
                        f" | Mem {mem.item():.2f} MB")

                    tb_writer.add_scalar("train/nfe", nfe, itr)
                    tb_writer.add_scalar("train/time_per_itr", elapsed_time / args.logfreq, itr)

                start_time = time.time()
            if rank == 0 and itr % args.testfreq == 0:
                # ema.swap()
                val_space_loglik, val_time_loglik = validate(model, val_loader, t0, t1, device)
                test_space_loglik, test_time_loglik = validate(model, test_loader, t0, t1, device)
                # ema.swap()
                logger.info(
                    f"[Test] Iter {itr} | Val Temporal {val_time_loglik:.4f} | Val Spatial {val_space_loglik:.4f}"
                    f" | Test Temporal {test_time_loglik:.4f} | Test Spatial {test_space_loglik:.4f}")

                tb_writer.add_scalar("val/temporal_loss", val_time_loglik, itr)
                tb_writer.add_scalar("val/spatial_loss", val_space_loglik, itr)
                tb_writer.add_scalar("test/temporal_loss", test_time_loglik, itr)
                tb_writer.add_scalar("test/spatial_loss", test_space_loglik, itr)
                torch.save({
                    "itr": itr,
                    "state_dict": model.module.state_dict(),
                    "optim_state_dict": optimizer.state_dict(),
                    "ema_params": ema.ema_params,
                }, checkpt_path)

                start_time = time.time()
    if rank == 0:
        tb_writer.close()
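
# A minimal sketch of how a (rank, world_size, ...) worker like _main is
# typically launched (an assumption; the real entry point and process-group
# setup live elsewhere in train_stpp.py, and MASTER_ADDR/MASTER_PORT must be
# set in the environment):
#
#     import torch.distributed as dist
#     import torch.multiprocessing as mp
#
#     def main_wrapper(rank, world_size, args, savepath, logger):
#         dist.init_process_group("nccl", rank=rank, world_size=world_size)
#         _main(rank, world_size, args, savepath, logger)
#         dist.destroy_process_group()
#
#     mp.spawn(main_wrapper, args=(world_size, args, savepath, logger), nprocs=world_size)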