def _feature_dim_dropout_ratio()

in tzrec/tools/feature_selection.py [0:0]


    def _feature_dim_dropout_ratio(self) -> Dict[str, Dict[str, float]]:
        """Get dropout ratio of feature_groups feature-wise.

        Restores only the ``group_variational_dropouts`` entries of the model
        from the latest checkpoint under ``self._model_dir`` (model built from
        the pipeline config at ``self._config_path``), then reads every
        feature's dropout ratio as ``sigmoid(feature_p)``.

        Returns:
            Mapping ``group_name -> {feature_name: dropout_ratio}``.

        Raises:
            ValueError: if no checkpoint is found under ``self._model_dir``, or
                if the model contains no ``group_variational_dropouts`` module.
        """
        # Resolve the checkpoint first so a missing checkpoint fails fast,
        # before paying for config parsing and model construction.
        checkpoint_path, _ = checkpoint_util.latest_checkpoint(self._model_dir)
        if not checkpoint_path:
            raise ValueError("checkpoint path should be specified.")

        pipeline_config = config_util.load_pipeline_config(self._config_path)
        data_config = pipeline_config.data_config
        # Build feature
        features = _create_features(list(pipeline_config.feature_configs), data_config)
        model = _create_model(
            pipeline_config.model_config,
            features,
            list(data_config.label_fields),
        )
        model = ScriptWrapper(model)

        model_ckpt_path = os.path.join(checkpoint_path, "model")
        logger.info(
            f"Restoring model feature dropout ratio from {model_ckpt_path}..."
        )
        # Restrict the restore to the variational-dropout entries; nothing else
        # is read back below.
        # NOTE(review): relies on `load` restoring into these tensors in place
        # (the values are read back through the modules, not this dict).
        dropout_state_dict = {
            key: tensor
            for key, tensor in model.state_dict().items()
            if "group_variational_dropouts" in key
        }
        load(
            dropout_state_dict,
            checkpoint_id=model_ckpt_path,
        )

        group_feature_importance: Dict[str, Dict[str, float]] = {}
        for name, sub_model in model.named_modules():
            if name.split(".")[-1] != "group_variational_dropouts":
                continue
            for variational_dropout in sub_model.values():
                # detach(): only the numeric values are needed, so avoid
                # building an autograd graph on the trainable parameter.
                ratios = variational_dropout.feature_p.detach().sigmoid().tolist()
                feature_names = variational_dropout.features_dimension.keys()
                # NOTE(review): zip silently truncates if feature_p and
                # features_dimension ever disagree in length.
                group_feature_importance[variational_dropout.group_name] = dict(
                    zip(feature_names, ratios)
                )

        if not group_feature_importance:
            raise ValueError(
                "you not configure variational dropout "
                "or no group can be variational dropout."
            )
        return group_feature_importance