def vec_to_block_config()

in nasrec/base_searcher.py [0:0]


    def vec_to_block_config(self, vec, block_id):
        """Convert a controller vector into a ``b_config.BlockConfig``.

        Args:
            vec: mapping sampled by the controller. Keys read here:
                ``"block_type"`` (a type id, possibly wrapped in a
                ``torch.Tensor``), ``"dense_feat"``, ``"sparse_feat"``,
                ``"skip_connect"`` and, optionally, the micro-search keys
                ``"attention"``, ``"cin"``, ``"mlp_dense"`` and ``"mlp_emb"``.
                When a micro-search key is absent, a hard-coded default
                hyper-parameter is used for that block type.
            block_id: id of the block being built. Skip connections are kept
                only to blocks whose id is smaller than ``block_id``.

        Returns:
            A ``b_config.BlockConfig`` with exactly one block field populated,
            chosen by ``self.int_type_dict[block_type_id]``.

        Raises:
            ValueError: if the block type is not one of the types handled here.
        """
        raw_type = vec["block_type"]
        # The type id may arrive as a tensor; take its first element as a
        # Python scalar. ``.item()`` also works for 0-d, CUDA and
        # grad-tracking tensors, where ``.numpy()[0]`` would fail.
        block_type_id = (
            raw_type.reshape(-1)[0].item()
            if isinstance(raw_type, torch.Tensor)
            else raw_type
        )
        input_dense = vec["dense_feat"]
        input_sparse = vec["sparse_feat"]
        skip_connection = vec["skip_connect"]

        if (
            self.controller_option.macro_space_type
            == config.MacroSearchSpaceType.INPUT_GROUP
        ):
            # Input-group space: a single flag per feature kind; [-1] is used
            # as the "whole group" selector (presumably "all features" — confirm
            # against FeatSelectionConfig consumers).
            input_dense_id = [-1] if input_dense == 1 else []
            input_sparse_id = [-1] if input_sparse == 1 else []
        else:
            # Per-feature space: one 0/1 entry per feature index.
            input_dense_id = [i for i, e in enumerate(input_dense) if e == 1]
            input_sparse_id = [i for i, e in enumerate(input_sparse) if e == 1]
        # Entry i of the skip mask refers to block i + 1; only blocks built
        # before this one (id < block_id) may feed into it.
        skip_connection_id = [
            i + 1 for i, e in enumerate(skip_connection) if e == 1 and i + 1 < block_id
        ]

        dense_as_sparse = (
            config.FeatureProcessingType.IDASP in self.feature_processing_type
        )

        # Construct the input config: the original input features (block 0)
        # followed by the outputs of the selected earlier blocks.
        input_feat_config = [
            b_config.FeatSelectionConfig(
                block_id=0, dense=input_dense_id, sparse=input_sparse_id
            )
        ]
        input_feat_config += [
            b_config.FeatSelectionConfig(block_id=src_block, dense=[-1], sparse=[-1])
            for src_block in skip_connection_id
        ]

        comm_embed_dim = self.sparse_feature_options.embed_dim

        def _emb_config():
            # A fresh EmbedBlockType per call, as each branch built its own.
            return b_config.EmbedBlockType(
                comm_embed_dim=comm_embed_dim, dense_as_sparse=dense_as_sparse
            )

        def _dense_block_type():
            # BlockType selecting the dense variant of a block.
            return b_config.BlockType(dense=b_config.DenseBlockType())

        def _emb_block_type():
            # BlockType selecting the embedding variant of a block.
            return b_config.BlockType(emb=_emb_config())

        block_type = self.int_type_dict[block_type_id]
        if block_type == b_config.ExtendedBlockType.CROSSNET:
            block_config = b_config.BlockConfig(
                crossnet_block=b_config.CrossNetBlockConfig(
                    name="CrossNetBlocks",
                    block_id=block_id,
                    num_of_layers=1,
                    input_feat_config=input_feat_config,
                    cross_feat_config=input_feat_config,
                )
            )
        elif block_type == b_config.ExtendedBlockType.ATTENTION:
            # Micro choices are indices into the option lists; defaults apply
            # when the controller did not sample attention hyper-parameters.
            if "attention" in vec:
                att = vec["attention"]
                head = self.micro_attention_option.num_of_heads[att["head"]]
                layer = self.micro_attention_option.num_of_layers[att["layer"]]
                emb = self.micro_attention_option.att_embed_dim[att["emb"]]
                drop = self.micro_attention_option.dropout_prob[att["drop"]]
            else:
                head, layer, emb, drop = 2, 1, 10, 0.0
            block_config = b_config.BlockConfig(
                attention_block=b_config.AttentionBlockConfig(
                    name="AttentionBlock",
                    block_id=block_id,
                    input_feat_config=input_feat_config,
                    emb_config=_emb_config(),
                    att_embed_dim=emb,
                    num_of_heads=head,
                    num_of_layers=layer,
                    dropout_prob=drop,
                    use_res=True,
                    batchnorm=False,
                )
            )
        elif block_type == b_config.ExtendedBlockType.CIN:
            # Same width repeated for the sampled number of layers.
            arc = (
                [self.micro_cin_option.arc[vec["cin"]["width"]]]
                * self.micro_cin_option.num_of_layers[vec["cin"]["depth"]]
                if "cin" in vec
                else [128]
            )
            block_config = b_config.BlockConfig(
                cin_block=b_config.CINBlockConfig(
                    name="CINBlock",
                    block_id=block_id,
                    emb_config=_emb_config(),
                    arc=arc,
                    split_half=True,
                    input_feat_config=input_feat_config,
                )
            )
        elif block_type == b_config.ExtendedBlockType.MLP_DENSE:
            arc = (
                self.micro_mlp_option.arc[vec["mlp_dense"]]
                if "mlp_dense" in vec
                else 128
            )
            block_config = b_config.BlockConfig(
                mlp_block=b_config.MLPBlockConfig(
                    name="MLPBlock",
                    block_id=block_id,
                    arc=[arc],
                    type=_dense_block_type(),
                    input_feat_config=input_feat_config,
                )
            )
        elif block_type == b_config.ExtendedBlockType.MLP_EMB:
            arc = self.micro_mlp_option.arc[vec["mlp_emb"]] if "mlp_emb" in vec else 128
            block_config = b_config.BlockConfig(
                mlp_block=b_config.MLPBlockConfig(
                    name="MLPBlock",
                    block_id=block_id,
                    arc=[arc],
                    type=_emb_block_type(),
                    input_feat_config=input_feat_config,
                )
            )
        elif block_type == b_config.ExtendedBlockType.FM_DENSE:
            block_config = b_config.BlockConfig(
                fm_block=b_config.FMBlockConfig(
                    name="FMBlock",
                    block_id=block_id,
                    type=_dense_block_type(),
                    input_feat_config=input_feat_config,
                )
            )
        elif block_type == b_config.ExtendedBlockType.FM_EMB:
            block_config = b_config.BlockConfig(
                fm_block=b_config.FMBlockConfig(
                    name="FMBlock",
                    block_id=block_id,
                    type=_emb_block_type(),
                    input_feat_config=input_feat_config,
                )
            )
        elif block_type == b_config.ExtendedBlockType.DOTPROCESSOR_DENSE:
            block_config = b_config.BlockConfig(
                dotprocessor_block=b_config.DotProcessorBlockConfig(
                    name="DotProcessorBlock",
                    block_id=block_id,
                    type=_dense_block_type(),
                    input_feat_config=input_feat_config,
                )
            )
        elif block_type == b_config.ExtendedBlockType.DOTPROCESSOR_EMB:
            block_config = b_config.BlockConfig(
                dotprocessor_block=b_config.DotProcessorBlockConfig(
                    name="DotProcessorBlock",
                    block_id=block_id,
                    type=_emb_block_type(),
                    input_feat_config=input_feat_config,
                )
            )
        elif block_type == b_config.ExtendedBlockType.CAT_DENSE:
            block_config = b_config.BlockConfig(
                cat_block=b_config.CatBlockConfig(
                    name="CatBlock",
                    block_id=block_id,
                    type=_dense_block_type(),
                    input_feat_config=input_feat_config,
                )
            )
        elif block_type == b_config.ExtendedBlockType.CAT_EMB:
            block_config = b_config.BlockConfig(
                cat_block=b_config.CatBlockConfig(
                    name="CatBlock",
                    block_id=block_id,
                    type=_emb_block_type(),
                    input_feat_config=input_feat_config,
                )
            )
        else:
            # Without this, an unhandled type left block_config unbound and
            # the return below failed with an opaque UnboundLocalError.
            raise ValueError(
                f"Unsupported block type {block_type!r} (id={block_type_id!r})"
            )
        return block_config