in train_deep_sdf.py [0:0]
def load_latent_vectors(experiment_directory, filename, lat_vecs):
    """Restore saved latent codes into ``lat_vecs`` and return the stored epoch.

    The file ``filename`` is looked up inside
    ``ws.get_latent_codes_dir(experiment_directory)`` and must be a
    ``torch.load``-able dict with the keys ``"latent_codes"`` and ``"epoch"``.
    Two layouts of ``"latent_codes"`` are accepted:

    * a ``state_dict`` for ``lat_vecs`` (current format), restored with
      ``lat_vecs.load_state_dict``;
    * a raw tensor (legacy format) with one code per entry along dim 0, of
      shape ``(num_embeddings, 1, embedding_dim)`` or
      ``(num_embeddings, embedding_dim)``; it is copied into
      ``lat_vecs.weight`` in place.

    Args:
        experiment_directory: Root directory of the experiment.
        filename: Name of the latent-state file inside the latent-codes dir.
        lat_vecs: Module to fill; presumably a ``torch.nn.Embedding`` (it must
            expose ``num_embeddings``, ``embedding_dim``, ``weight`` and
            ``load_state_dict``).

    Returns:
        The ``"epoch"`` value stored in the file.

    Raises:
        FileNotFoundError: If the latent state file does not exist.
        ValueError: If a legacy tensor does not match ``lat_vecs`` in the
            number of codes or the code dimensionality.
    """
    full_filename = os.path.join(
        ws.get_latent_codes_dir(experiment_directory), filename
    )

    if not os.path.isfile(full_filename):
        raise FileNotFoundError(
            'latent state file "{}" does not exist'.format(full_filename)
        )

    # Load straight onto the device of the target weights so a file written
    # on a GPU can still be read on a host without that device; both branches
    # below copy into existing parameters, so the result is unchanged.
    data = torch.load(full_filename, map_location=lat_vecs.weight.device)
    codes = data["latent_codes"]

    if isinstance(codes, torch.Tensor):
        # Legacy format (kept for backwards compatibility): a raw tensor of
        # codes rather than a state_dict.
        num_codes = codes.size(0)
        if num_codes != lat_vecs.num_embeddings:
            raise ValueError(
                "num latent codes mismatched: {} vs {}".format(
                    lat_vecs.num_embeddings, num_codes
                )
            )

        # Older checkpoints kept a singleton middle axis, i.e. (N, 1, D);
        # drop it so both (N, 1, D) and (N, D) are handled uniformly.
        if codes.dim() == 3 and codes.size(1) == 1:
            codes = codes.squeeze(1)
        if codes.dim() != 2 or codes.size(1) != lat_vecs.embedding_dim:
            raise ValueError("latent code dimensionality mismatch")

        # One bulk in-place copy instead of a per-row Python loop; copy_
        # also handles any dtype/device conversion.
        lat_vecs.weight.data.copy_(codes)
    else:
        lat_vecs.load_state_dict(codes)

    return data["epoch"]