def save_model()

in kats/models/globalmodel/ensemble.py [0:0]


    def save_model(self, file_name: str) -> None:
        """Save ensemble model to file.

        The file is written with ``joblib.dump`` and contains a dict with:

        * ``state_dict``: per-model ``rnn.state_dict()`` when
          ``params.model_type == "rnn"``, otherwise ``None``.
        * ``encoder_dict`` / ``decoder_dict``: per-model encoder/decoder
          ``state_dict()`` when ``params.model_type == "s2s"``, otherwise
          ``None``.
        * ``gmparam_string``: ``self.params.to_string()``.
        * ``gm_info`` and ``test_ids``: copied from the ensemble.
        * ``gmensemble_params``: the ensemble attributes ``splits``,
          ``overlap``, ``replicate``, ``multi``, ``max_core`` and
          ``ensemble_type``.

        Note: ``_reset_nn_states()`` is called on every trained model before
        serializing, so this method mutates the models in place.

        Args:
            file_name: A string representing the file address and file name.

        Raises:
            ValueError: If no global models have been trained, or if any step
                of building or writing the payload fails. In the latter case
                the underlying exception is chained as ``__cause__``.
        """

        if not self.gm_models:
            msg = "Please train global models before saving GMEnsemble."
            logging.error(msg)
            raise ValueError(msg)
        try:
            # Clean up unnecessary info; called purely for its side effect
            # (presumably clears cached NN states that should not be pickled).
            for gm in self.gm_models:
                gm._reset_nn_states()

            model_type = self.params.model_type
            state_dict = (
                [gm.rnn.state_dict() for gm in self.gm_models]
                if model_type == "rnn"
                else None
            )
            encoder_dict = (
                [gm.encoder.state_dict() for gm in self.gm_models]
                if model_type == "s2s"
                else None
            )
            decoder_dict = (
                [gm.decoder.state_dict() for gm in self.gm_models]
                if model_type == "s2s"
                else None
            )
            gmensemble_params = {
                attr: getattr(self, attr)
                for attr in (
                    "splits",
                    "overlap",
                    "replicate",
                    "multi",
                    "max_core",
                    "ensemble_type",
                )
            }
            info = {
                "state_dict": state_dict,
                "encoder_dict": encoder_dict,
                "decoder_dict": decoder_dict,
                "gmparam_string": self.params.to_string(),
                "gm_info": self.gm_info,
                "test_ids": self.test_ids,
                "gmensemble_params": gmensemble_params,
            }
            with open(file_name, "wb") as f:
                joblib.dump(info, f)
        except Exception as e:
            msg = f"Fail to save GMEnsemble to {file_name} with Exception {e}."
            logging.error(msg)
            # Chain the original error so the real root cause stays visible.
            raise ValueError(msg) from e
        else:
            logging.info(f"Successfully save GMEnsemble to {file_name}.")