in nasrec/base_searcher.py [0:0]
def vec_to_block_config(self, vec, block_id):
"""convert a controller vector to block_config
"""
# Split the vector and convert each part into its id representation.
block_type_id = (
vec["block_type"].numpy()[0]
if isinstance(vec["block_type"], torch.Tensor)
else vec["block_type"]
)
input_dense = vec["dense_feat"]
input_sparse = vec["sparse_feat"]
skip_connection = vec["skip_connect"]
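# With INPUT_GROUP macro search, dense/sparse inputs are selected as whole
# groups (the [-1] sentinel below appears to select every feature in the
# group); otherwise each feature carries its own 0/1 selection bit.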
if (
self.controller_option.macro_space_type
== config.MacroSearchSpaceType.INPUT_GROUP
):
input_dense_id = [-1] if input_dense == 1 else []
input_sparse_id = [-1] if input_sparse == 1 else []
else:
input_dense_id = [i for i, e in enumerate(input_dense) if e == 1]
input_sparse_id = [i for i, e in enumerate(input_sparse) if e == 1]
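# Only blocks that precede this one (i + 1 < block_id) are valid skip sources.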
skip_connection_id = [
i + 1 for i, e in enumerate(skip_connection) if e == 1 and i + 1 < block_id
]
dense_as_sparse = (
config.FeatureProcessingType.IDASP in self.feature_processing_type
)
# construct input config
# original input features
input_feat_config = [
b_config.FeatSelectionConfig(
block_id=0, dense=input_dense_id, sparse=input_sparse_id
)
]
# input from other blocks' outputs
input_feat_config += [
b_config.FeatSelectionConfig(block_id=b_id, dense=[-1], sparse=[-1])
for b_id in skip_connection_id
]
comm_embed_dim = self.sparse_feature_options.embed_dim
block_type = self.int_type_dict[block_type_id]
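# Dispatch on the decoded block type and build the corresponding config.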
if block_type == b_config.ExtendedBlockType.CROSSNET:
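# CrossNet depth is fixed to a single layer here, and the cross features
# reuse the block's own input selection.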
block_config = b_config.BlockConfig(
crossnet_block=b_config.CrossNetBlockConfig(
name="CrossNetBlocks",
block_id=block_id,
num_of_layers=1,
input_feat_config=input_feat_config,
cross_feat_config=input_feat_config,
)
)
elif block_type == b_config.ExtendedBlockType.ATTENTION:
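# Use the sampled micro-space choices when present, otherwise fall back
# to fixed defaults (2 heads, 1 layer, embed dim 10, no dropout).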
head, layer, emb, drop = (
(
self.micro_attention_option.num_of_heads[vec["attention"]["head"]],
self.micro_attention_option.num_of_layers[
vec["attention"]["layer"]
],
self.micro_attention_option.att_embed_dim[vec["attention"]["emb"]],
self.micro_attention_option.dropout_prob[vec["attention"]["drop"]],
)
if "attention" in vec
else (2, 1, 10, 0.0)
)
block_config = b_config.BlockConfig(
attention_block=b_config.AttentionBlockConfig(
name="AttentionBlock",
block_id=block_id,
input_feat_config=input_feat_config,
emb_config=b_config.EmbedBlockType(
comm_embed_dim=comm_embed_dim, dense_as_sparse=dense_as_sparse
),
att_embed_dim=emb,
num_of_heads=head,
num_of_layers=layer,
dropout_prob=drop,
use_res=True,
batchnorm=False,
)
)
elif block_type == b_config.ExtendedBlockType.CIN:
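# CIN arc: the sampled layer width repeated for the sampled depth,
# defaulting to a single 128-wide layer.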
arc = (
[self.micro_cin_option.arc[vec["cin"]["width"]]]
* self.micro_cin_option.num_of_layers[vec["cin"]["depth"]]
if "cin" in vec
else [128]
)
block_config = b_config.BlockConfig(
cin_block=b_config.CINBlockConfig(
name="CINBlock",
block_id=block_id,
emb_config=b_config.EmbedBlockType(
comm_embed_dim=comm_embed_dim, dense_as_sparse=dense_as_sparse
),
arc=arc,
split_half=True,
input_feat_config=input_feat_config,
)
)
elif block_type == b_config.ExtendedBlockType.MLP_DENSE:
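# Single-layer MLP on the dense path; width comes from the micro space,
# defaulting to 128.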
arc = (
self.micro_mlp_option.arc[vec["mlp_dense"]]
if "mlp_dense" in vec
else 128
)
block_config = b_config.BlockConfig(
mlp_block=b_config.MLPBlockConfig(
name="MLPBlock",
block_id=block_id,
arc=[arc],
type=b_config.BlockType(dense=b_config.DenseBlockType()),
input_feat_config=input_feat_config,
)
)
elif block_type == b_config.ExtendedBlockType.MLP_EMB:
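# Same micro arc choice as MLP_DENSE, but applied to embedded features.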
arc = self.micro_mlp_option.arc[vec["mlp_emb"]] if "mlp_emb" in vec else 128
block_config = b_config.BlockConfig(
mlp_block=b_config.MLPBlockConfig(
name="MLPBlock",
block_id=block_id,
arc=[arc],
type=b_config.BlockType(
emb=b_config.EmbedBlockType(
comm_embed_dim=comm_embed_dim,
dense_as_sparse=dense_as_sparse,
)
),
input_feat_config=input_feat_config,
)
)
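# The remaining FM / DotProcessor / Cat blocks take no micro hyperparameters;
# each only differs in whether it consumes dense or embedded features.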
elif block_type == b_config.ExtendedBlockType.FM_DENSE:
block_config = b_config.BlockConfig(
fm_block=b_config.FMBlockConfig(
name="FMBlock",
block_id=block_id,
type=b_config.BlockType(dense=b_config.DenseBlockType()),
input_feat_config=input_feat_config,
)
)
elif block_type == b_config.ExtendedBlockType.FM_EMB:
block_config = b_config.BlockConfig(
fm_block=b_config.FMBlockConfig(
name="FMBlock",
block_id=block_id,
type=b_config.BlockType(
emb=b_config.EmbedBlockType(
comm_embed_dim=comm_embed_dim,
dense_as_sparse=dense_as_sparse,
)
),
input_feat_config=input_feat_config,
)
)
elif block_type == b_config.ExtendedBlockType.DOTPROCESSOR_DENSE:
block_config = b_config.BlockConfig(
dotprocessor_block=b_config.DotProcessorBlockConfig(
name="DotProcessorBlock",
block_id=block_id,
type=b_config.BlockType(dense=b_config.DenseBlockType()),
input_feat_config=input_feat_config,
)
)
elif block_type == b_config.ExtendedBlockType.DOTPROCESSOR_EMB:
block_config = b_config.BlockConfig(
dotprocessor_block=b_config.DotProcessorBlockConfig(
name="DotProcessorBlock",
block_id=block_id,
type=b_config.BlockType(
emb=b_config.EmbedBlockType(
comm_embed_dim=comm_embed_dim,
dense_as_sparse=dense_as_sparse,
)
),
input_feat_config=input_feat_config,
)
)
elif block_type == b_config.ExtendedBlockType.CAT_DENSE:
block_config = b_config.BlockConfig(
cat_block=b_config.CatBlockConfig(
name="CatBlock",
block_id=block_id,
type=b_config.BlockType(dense=b_config.DenseBlockType()),
input_feat_config=input_feat_config,
)
)
elif block_type == b_config.ExtendedBlockType.CAT_EMB:
block_config = b_config.BlockConfig(
cat_block=b_config.CatBlockConfig(
name="CatBlock",
block_id=block_id,
type=b_config.BlockType(
emb=b_config.EmbedBlockType(
comm_embed_dim=comm_embed_dim,
dense_as_sparse=dense_as_sparse,
)
),
input_feat_config=input_feat_config,
)
)
else:
raise ValueError(f"Unsupported block type: {block_type}")
return block_config
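
# Illustrative usage (a sketch, not part of the repo): the exact vector layout
# depends on the controller, but the keys consumed above look like this.
# Here `searcher` is a hypothetical instance of this searcher class.
#
#     vec = {
#         "block_type": torch.tensor([3]),  # index into self.int_type_dict
#         "dense_feat": [1, 0, 1],          # per-feature selection bits
#         "sparse_feat": [0, 1],
#         "skip_connect": [1],              # candidate connection to block 1
#         "mlp_dense": 0,                   # index into micro_mlp_option.arc
#     }
#     block_config = searcher.vec_to_block_config(vec, block_id=2)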