in kats/models/globalmodel/ensemble.py [0:0]
def save_model(self, file_name: str) -> None:
    """Save the ensemble of global models to a file with ``joblib.dump``.

    The saved payload is a dict containing, for every trained global model,
    the ``state_dict`` of its ``rnn`` (when ``params.model_type == "rnn"``) or
    of its ``encoder`` and ``decoder`` (when ``params.model_type == "s2s"``).
    It also holds the serialized global-model parameters
    (``params.to_string()``), ``gm_info``, ``test_ids`` and the ensemble
    settings (``splits``, ``overlap``, ``replicate``, ``multi``,
    ``max_core``, ``ensemble_type``).

    Args:
        file_name: A string representing the file address and file name.

    Raises:
        ValueError: If no global model has been trained yet, or if anything
            fails while collecting the model state or writing the file. In
            the latter case the underlying exception is chained as
            ``__cause__``.
    """
    if not self.gm_models:
        msg = "Please train global models before saving GMEnsemble."
        logging.error(msg)
        raise ValueError(msg)

    try:
        # Clean up unnecessary info (reset NN states) before extracting the
        # state dicts. A plain loop is used because only the side effect
        # matters.
        for gm in self.gm_models:
            gm._reset_nn_states()

        model_type = self.params.model_type
        state_dict = (
            [gm.rnn.state_dict() for gm in self.gm_models]
            if model_type == "rnn"
            else None
        )
        encoder_dict = (
            [gm.encoder.state_dict() for gm in self.gm_models]
            if model_type == "s2s"
            else None
        )
        decoder_dict = (
            [gm.decoder.state_dict() for gm in self.gm_models]
            if model_type == "s2s"
            else None
        )
        gmparam_string = self.params.to_string()
        gmensemble_params = {
            attr: getattr(self, attr)
            for attr in (
                "splits",
                "overlap",
                "replicate",
                "multi",
                "max_core",
                "ensemble_type",
            )
        }
        info = {
            "state_dict": state_dict,
            "encoder_dict": encoder_dict,
            "decoder_dict": decoder_dict,
            "gmparam_string": gmparam_string,
            "gm_info": self.gm_info,
            "test_ids": self.test_ids,
            "gmensemble_params": gmensemble_params,
        }
        with open(file_name, "wb") as f:
            joblib.dump(info, f)
    except Exception as e:
        msg = f"Fail to save GMEnsemble to {file_name} with Exception {e}."
        logging.error(msg)
        # Chain the original error so its traceback is not lost.
        raise ValueError(msg) from e

    logging.info(f"Successfully save GMEnsemble to {file_name}.")