in train_deep_sdf.py [0:0]
def main_function(experiment_directory, continue_from, batch_split):
    """Train a DeepSDF decoder and per-scene latent codes for one experiment.

    Loads the experiment specifications from ``experiment_directory``, builds
    the decoder network (wrapped in ``DataParallel``) and one learnable latent
    code per training scene, then jointly optimizes both with Adam for
    ``specs["NumEpochs"]`` epochs. Numbered snapshots of the model, optimizer
    and latent codes are written at the configured checkpoint epochs; the
    "latest" snapshot and the training logs are written every
    ``LogFrequency`` epochs.

    Args:
        experiment_directory: Directory holding the experiment specifications;
            snapshots and logs are written here as well.
        continue_from: Name of a saved snapshot to resume from (``".pth"`` is
            appended when loading the optimizer and latent codes), or ``None``
            to start a fresh run.
        batch_split: Number of chunks each batch is split into. Gradients from
            all chunks are accumulated before a single optimizer step.

    Raises:
        RuntimeError: If, when resuming, the epochs stored in the model,
            optimizer and latent-code snapshots disagree.
    """
    logging.debug("running " + experiment_directory)

    specs = ws.load_experiment_specifications(experiment_directory)

    logging.info("Experiment description: \n" + specs["Description"])

    data_source = specs["DataSource"]
    train_split_file = specs["TrainSplit"]

    arch = __import__("networks." + specs["NetworkArch"], fromlist=["Decoder"])

    logging.debug(specs["NetworkSpecs"])

    latent_size = specs["CodeLength"]

    # Epochs that get a numbered snapshot: every SnapshotFrequency epochs plus
    # any explicitly requested extras. Only used for membership tests.
    snapshot_frequency = specs["SnapshotFrequency"]
    checkpoints = set(
        range(snapshot_frequency, specs["NumEpochs"] + 1, snapshot_frequency)
    )
    checkpoints.update(specs["AdditionalSnapshots"])

    lr_schedules = get_learning_rate_schedules(specs)

    grad_clip = get_spec_with_default(specs, "GradientClipNorm", None)
    if grad_clip is not None:
        logging.debug("clipping gradients to max norm {}".format(grad_clip))

    # The two save helpers are closures over `decoder`, `optimizer_all` and
    # `lat_vecs`, which are bound further down; they are only invoked from the
    # training loop, after those names exist.
    def save_latest(epoch):
        save_model(experiment_directory, "latest.pth", decoder, epoch)
        save_optimizer(experiment_directory, "latest.pth", optimizer_all, epoch)
        save_latent_vectors(experiment_directory, "latest.pth", lat_vecs, epoch)

    def save_checkpoints(epoch):
        filename = str(epoch) + ".pth"
        save_model(experiment_directory, filename, decoder, epoch)
        save_optimizer(experiment_directory, filename, optimizer_all, epoch)
        save_latent_vectors(experiment_directory, filename, lat_vecs, epoch)

    def signal_handler(sig, frame):
        # Exit on SIGINT (Ctrl-C). This handler saves nothing itself.
        logging.info("Stopping early...")
        sys.exit(0)

    def adjust_learning_rate(lr_schedules, optimizer, epoch):
        # Param group i is driven by schedule i (group 0: decoder weights,
        # group 1: latent codes -- see the optimizer construction below).
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)

    signal.signal(signal.SIGINT, signal_handler)

    num_samp_per_scene = specs["SamplesPerScene"]
    scene_per_batch = specs["ScenesPerBatch"]
    clamp_dist = specs["ClampingDistance"]
    minT = -clamp_dist
    maxT = clamp_dist
    enforce_minmax = True

    do_code_regularization = get_spec_with_default(specs, "CodeRegularization", True)
    code_reg_lambda = get_spec_with_default(specs, "CodeRegularizationLambda", 1e-4)

    code_bound = get_spec_with_default(specs, "CodeBound", None)

    decoder = arch.Decoder(latent_size, **specs["NetworkSpecs"]).cuda()

    logging.info("training with {} GPU(s)".format(torch.cuda.device_count()))

    # NOTE: DataParallel is applied unconditionally, even with a single GPU.
    decoder = torch.nn.DataParallel(decoder)

    num_epochs = specs["NumEpochs"]
    log_frequency = get_spec_with_default(specs, "LogFrequency", 10)

    with open(train_split_file, "r") as f:
        train_split = json.load(f)

    sdf_dataset = deep_sdf.data.SDFSamples(
        data_source, train_split, num_samp_per_scene, load_ram=False
    )

    num_data_loader_threads = get_spec_with_default(specs, "DataLoaderThreads", 1)
    logging.debug("loading data with {} threads".format(num_data_loader_threads))

    sdf_loader = data_utils.DataLoader(
        sdf_dataset,
        batch_size=scene_per_batch,
        shuffle=True,
        num_workers=num_data_loader_threads,
        drop_last=True,
    )

    logging.debug("torch num_threads: {}".format(torch.get_num_threads()))

    num_scenes = len(sdf_dataset)

    logging.info("There are {} scenes".format(num_scenes))

    logging.debug(decoder)

    # One learnable latent code per training scene. When CodeBound is set,
    # `max_norm` makes Embedding renormalize looked-up codes to that norm.
    lat_vecs = torch.nn.Embedding(num_scenes, latent_size, max_norm=code_bound)
    torch.nn.init.normal_(
        lat_vecs.weight.data,
        0.0,
        get_spec_with_default(specs, "CodeInitStdDev", 1.0) / math.sqrt(latent_size),
    )

    logging.debug(
        "initialized with mean magnitude {}".format(
            get_mean_latent_vector_magnitude(lat_vecs)
        )
    )

    # Summed L1; each chunk's loss is divided by the batch's total sample
    # count below, so the accumulated loss is a per-sample mean.
    loss_l1 = torch.nn.L1Loss(reduction="sum")

    optimizer_all = torch.optim.Adam(
        [
            {
                "params": decoder.parameters(),
                "lr": lr_schedules[0].get_learning_rate(0),
            },
            {
                "params": lat_vecs.parameters(),
                "lr": lr_schedules[1].get_learning_rate(0),
            },
        ]
    )

    loss_log = []
    lr_log = []
    lat_mag_log = []
    timing_log = []
    param_mag_log = {}

    start_epoch = 1

    if continue_from is not None:

        logging.info('continuing from "{}"'.format(continue_from))

        lat_epoch = load_latent_vectors(
            experiment_directory, continue_from + ".pth", lat_vecs
        )

        model_epoch = ws.load_model_parameters(
            experiment_directory, continue_from, decoder
        )

        optimizer_epoch = load_optimizer(
            experiment_directory, continue_from + ".pth", optimizer_all
        )

        loss_log, lr_log, timing_log, lat_mag_log, param_mag_log, log_epoch = load_logs(
            experiment_directory
        )

        # The logs were saved at a different epoch than the snapshot being
        # resumed; clip them to the model's epoch so they stay aligned.
        if log_epoch != model_epoch:
            loss_log, lr_log, timing_log, lat_mag_log, param_mag_log = clip_logs(
                loss_log, lr_log, timing_log, lat_mag_log, param_mag_log, model_epoch
            )

        if not (model_epoch == optimizer_epoch and model_epoch == lat_epoch):
            raise RuntimeError(
                "epoch mismatch: {} vs {} vs {} vs {}".format(
                    model_epoch, optimizer_epoch, lat_epoch, log_epoch
                )
            )

        start_epoch = model_epoch + 1

        logging.debug("loaded")

    logging.info("starting from epoch {}".format(start_epoch))

    logging.info(
        "Number of decoder parameters: {}".format(
            sum(p.data.nelement() for p in decoder.parameters())
        )
    )
    logging.info(
        "Number of shape code parameters: {} (# codes {}, code dim {})".format(
            lat_vecs.num_embeddings * lat_vecs.embedding_dim,
            lat_vecs.num_embeddings,
            lat_vecs.embedding_dim,
        )
    )

    for epoch in range(start_epoch, num_epochs + 1):

        start = time.time()

        logging.info("epoch {}...".format(epoch))

        decoder.train()

        adjust_learning_rate(lr_schedules, optimizer_all, epoch)

        for sdf_data, indices in sdf_loader:

            # Flatten to one row per sample: columns 0-2 are xyz, column 3 is
            # the ground-truth SDF. Assumes each scene contributes exactly
            # num_samp_per_scene rows (TODO confirm against SDFSamples).
            sdf_data = sdf_data.reshape(-1, 4)

            num_sdf_samples = sdf_data.shape[0]

            sdf_data.requires_grad = False

            xyz = sdf_data[:, 0:3]
            sdf_gt = sdf_data[:, 3].unsqueeze(1)

            if enforce_minmax:
                sdf_gt = torch.clamp(sdf_gt, minT, maxT)

            # Repeat each scene index once per sample so every xyz row is
            # paired with its scene's latent code, then split all three
            # tensors into the same row chunks.
            xyz = torch.chunk(xyz, batch_split)
            indices = torch.chunk(
                indices.unsqueeze(-1).repeat(1, num_samp_per_scene).view(-1),
                batch_split,
            )
            sdf_gt = torch.chunk(sdf_gt, batch_split)

            batch_loss = 0.0

            optimizer_all.zero_grad()

            # torch.chunk may return FEWER than `batch_split` chunks when the
            # row count does not divide evenly, so iterate over the chunks
            # that actually exist rather than range(batch_split).
            for xyz_chunk, indices_chunk, sdf_gt_chunk in zip(xyz, indices, sdf_gt):

                batch_vecs = lat_vecs(indices_chunk)

                decoder_input = torch.cat([batch_vecs, xyz_chunk], dim=1)

                pred_sdf = decoder(decoder_input)

                if enforce_minmax:
                    pred_sdf = torch.clamp(pred_sdf, minT, maxT)

                chunk_loss = loss_l1(pred_sdf, sdf_gt_chunk.cuda()) / num_sdf_samples

                if do_code_regularization:
                    l2_size_loss = torch.sum(torch.norm(batch_vecs, dim=1))
                    # Regularization weight ramps up linearly over the first
                    # 100 epochs, then stays at code_reg_lambda.
                    reg_loss = (
                        code_reg_lambda * min(1, epoch / 100) * l2_size_loss
                    ) / num_sdf_samples

                    chunk_loss = chunk_loss + reg_loss.cuda()

                # Gradients accumulate across chunks; one optimizer step is
                # taken per batch after this loop.
                chunk_loss.backward()

                batch_loss += chunk_loss.item()

            logging.debug("loss = {}".format(batch_loss))

            loss_log.append(batch_loss)

            # Only the decoder's gradients are clipped; latent codes are not.
            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(decoder.parameters(), grad_clip)

            optimizer_all.step()

        end = time.time()

        seconds_elapsed = end - start
        timing_log.append(seconds_elapsed)

        lr_log.append([schedule.get_learning_rate(epoch) for schedule in lr_schedules])

        lat_mag_log.append(get_mean_latent_vector_magnitude(lat_vecs))

        append_parameter_magnitudes(param_mag_log, decoder)

        if epoch in checkpoints:
            save_checkpoints(epoch)

        if epoch % log_frequency == 0:

            save_latest(epoch)
            save_logs(
                experiment_directory,
                loss_log,
                lr_log,
                timing_log,
                lat_mag_log,
                param_mag_log,
                epoch,
            )