def _init_base_searcher_params()

in nasrec/base_searcher.py


    def _init_base_searcher_params(self):
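        """Initialize the search-space bookkeeping shared by searchers: the
        per-decision token counts (num_tokens), the ordered token names
        (token_names), their running total (num_total_tokens), and the
        int <-> block-type lookup tables."""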

        # get micro search space configurations
        self._set_micro_space_from_config()

        # constrain the search space
        if (
            self.controller_option.macro_space_type
            == config.MacroSearchSpaceType.INPUT_GROUP
        ):
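            # INPUT_GROUP collapses the input choice to a single dense group and
            # a single sparse group, so only one feature token is searched per
            # input type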
            self.num_dense_feat = 1
            self.num_sparse_feat = 1

        # length of the DAG to be searched (excluding the final classifier layer)
        self.num_blocks = self.controller_option.max_num_block
        # deduplicated block types to be searched (note: set iteration order is
        # arbitrary, so the int <-> type mappings below are not guaranteed to be
        # stable across runs)
        self.block_types = list(set(self.controller_option.block_types))
        self.num_block_type = len(self.block_types)
        if self.num_block_type == 0:
            raise ValueError("At least one block type must be provided to search over.")

        # construct dictionaries to map between int and block types
        self.type_int_dict = {t: i for i, t in enumerate(self.block_types)}
        self.int_type_dict = {i: t for i, t in enumerate(self.block_types)}

        # all tokens to be searched
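        # (each entry maps a searchable decision to its number of candidate
        # values; skip_connect offers one candidate input edge per block)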
        self.num_tokens = {
            "block_type": self.num_block_type,
            "dense_feat": self.num_dense_feat,
            "sparse_feat": self.num_sparse_feat,
            "skip_connect": self.num_blocks,
        }
        self.token_names = ["block_type", "dense_feat", "sparse_feat", "skip_connect"]
        if (
            self.controller_option.macro_space_type
            == config.MacroSearchSpaceType.INPUT_ELASTIC_PRIOR
        ):
            # constrain the search space with smooth learnable priors
            self.num_tokens["elastic_prior"] = 2
            self.token_names.append("elastic_prior")

        self.num_total_tokens = sum(self.num_tokens.values())

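        # micro search spaces contribute extra tokens below: num_tokens[name]
        # records how many candidate values a knob has, while num_total_tokens
        # grows by one decision per (non-empty) knob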
        if config.MicroSearchSpaceType.MICRO_MLP in self.micro_space_types:
            if (
                b_config.ExtendedBlockType.MLP_DENSE
                in self.controller_option.block_types
            ):
                self.num_tokens["mlp_dense"] = len(self.micro_mlp_option.arc)
                self.token_names.append("mlp_dense")
                self.num_total_tokens += 1
            if b_config.ExtendedBlockType.MLP_EMB in self.controller_option.block_types:
                self.num_tokens["mlp_emb"] = len(self.micro_mlp_option.arc)
                self.token_names.append("mlp_emb")
                self.num_total_tokens += 1

        if config.MicroSearchSpaceType.MICRO_CIN in self.micro_space_types:
            if b_config.ExtendedBlockType.CIN in self.controller_option.block_types:
                self.num_tokens["cin"] = len(self.micro_cin_option.arc) + len(
                    self.micro_cin_option.num_of_layers
                )
                self.token_names.append("cin")
                self.num_total_tokens += 1 if len(self.micro_cin_option.arc) > 0 else 0
                self.num_total_tokens += (
                    1 if len(self.micro_cin_option.num_of_layers) > 0 else 0
                )

        if config.MicroSearchSpaceType.MICRO_ATTENTION in self.micro_space_types:
            if (
                b_config.ExtendedBlockType.ATTENTION
                in self.controller_option.block_types
            ):
                self.att_num_tokens = {
                    "head": len(self.micro_attention_option.num_of_heads),
                    "layer": len(self.micro_attention_option.num_of_layers),
                    "emb": len(self.micro_attention_option.att_embed_dim),
                    "drop": len(self.micro_attention_option.dropout_prob),
                }
                self.num_tokens["attention"] = sum(
                    v for _, v in self.att_num_tokens.items()
                )

                self.token_names.append("attention")
                self.num_total_tokens += sum(
                    1 for v in self.att_num_tokens.values() if v != 0
                )
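
For intuition, here is a minimal standalone sketch of the bookkeeping this
method produces. The concrete values and the plain variables standing in for
the controller/config objects are hypothetical, not the real nasrec configs:

    block_types = ["mlp_dense", "attention"]  # toy deduplicated block types
    num_blocks = 3                            # controller_option.max_num_block

    type_int_dict = {t: i for i, t in enumerate(block_types)}  # {"mlp_dense": 0, "attention": 1}
    int_type_dict = {i: t for i, t in enumerate(block_types)}  # {0: "mlp_dense", 1: "attention"}

    num_tokens = {
        "block_type": len(block_types),  # 2 candidate block types
        "dense_feat": 1,                 # collapsed to one group under INPUT_GROUP
        "sparse_feat": 1,
        "skip_connect": num_blocks,      # one candidate input edge per block
    }
    token_names = list(num_tokens)
    num_total_tokens = sum(num_tokens.values())  # 2 + 1 + 1 + 3 = 7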