tzrec/protos/module.proto (67 lines of code) (raw):
syntax = "proto2";
package tzrec.protos;
// Configuration of a multi-layer perceptron (MLP).
message MLP {
  // Number of hidden units in each layer (one entry per layer).
  repeated uint32 hidden_units = 1;

  // Dropout ratio(s).
  // NOTE(review): presumably one ratio per hidden layer; confirm how the
  // consumer handles a length mismatch with hidden_units.
  repeated float dropout_ratio = 2;

  // Activation function, given by name. Defaults to 'nn.ReLU'.
  optional string activation = 3 [default = 'nn.ReLU'];

  // Whether to use batch normalization. Disabled by default.
  optional bool use_bn = 4 [default = false];

  // Whether to use a bias term. Enabled by default.
  optional bool bias = 5 [default = true];
}
// Configuration of an extraction network made of per-task experts and,
// optionally, shared experts.
message ExtractionNetwork {
  // Name of this network. Must be set.
  required string network_name = 1;

  // How many experts each task gets. Must be set.
  required uint32 expert_num_per_task = 2;

  // How many shared experts there are.
  optional uint32 share_num = 3;

  // MLP configuration used for the per-task experts. Must be set.
  required MLP task_expert_net = 4;

  // MLP configuration used for the shared experts.
  optional MLP share_expert_net = 5;
}
// Configuration of variational dropout.
message VariationalDropout {
  // Regularization coefficient (lambda). Defaults to 0.01.
  optional float regularization_lambda = 1 [default = 0.01];

  // Selects the dimension along which variational dropout is applied.
  // Defaults to false.
  // NOTE(review): the name suggests true means embedding-wise dropout;
  // confirm what granularity false selects in the consumer.
  optional bool embedding_wise_variational_dropout = 2 [default = false];
}
// Configuration of a B2I capsule layer with dynamic routing.
message B2ICapsule {
  // Upper bound on the number of high capsules. Defaults to 5.
  optional uint32 max_k = 1 [default = 5];

  // Maximum length of the behaviour sequence. Must be set.
  required uint32 max_seq_len = 2;

  // Dimension of each high-capsule embedding vector. Must be set.
  required uint32 high_dim = 3;

  // Number of dynamic routing iterations. Defaults to 3.
  optional uint32 num_iters = 4 [default = 3];

  // Scale applied to the routing logits. Defaults to 20.
  optional float routing_logits_scale = 5 [default = 20];

  // Standard deviation used to initialize the routing logits. Defaults to 1.
  optional float routing_logits_stddev = 6 [default = 1];

  // Power used by the squash function. Defaults to 1.
  optional float squash_pow = 7 [default = 1];

  // Whether the number of capsules is held constant. Defaults to false.
  optional bool const_caps_num = 8 [default = false];
}
// Configuration of a single mask block.
message MaskBlock {
  // Ratio of the aggregation dimension to the masked input dimension.
  // Defaults to 1.0.
  // NOTE(review): how this interacts with aggregation_dim when both are set
  // is not visible here; confirm in the consumer.
  optional float reduction_ratio = 1 [default = 1.0];

  // Dimension of the aggregation layer.
  optional uint32 aggregation_dim = 2;

  // Dimension of the hidden feed-forward (ffn) layer. Must be set.
  required uint32 hidden_dim = 3;
}
// Configuration of a MaskNet module: a set of mask blocks with an MLP on top.
message MaskNetModule {
  // How many mask blocks to build. Must be set.
  required uint32 n_mask_blocks = 1;

  // Configuration of the mask block. Must be set.
  required MaskBlock mask_block = 2;

  // MLP placed on top of the mask blocks. Must be set.
  required MLP top_mlp = 3;

  // Whether the mask blocks are arranged in parallel (true) or serially
  // (false). Defaults to true.
  optional bool use_parallel = 4 [default = true];
}