tzrec/protos/module.proto (67 lines of code) (raw):

syntax = "proto2";

package tzrec.protos;

// Configuration of a multi-layer perceptron (MLP).
message MLP {
  // Hidden units for each layer.
  repeated uint32 hidden_units = 1;
  // Ratio of dropout.
  // NOTE(review): presumably one entry per hidden layer -- confirm against
  // the code that consumes this message.
  repeated float dropout_ratio = 2;
  // Activation function, given as a string.
  optional string activation = 3 [default = 'nn.ReLU'];
  // Whether to use batch normalization.
  optional bool use_bn = 4 [default = false];
  // Whether to use a bias term.
  optional bool bias = 5 [default = true];
}

// Configuration of an extraction network made of per-task experts and
// shared experts.
message ExtractionNetwork {
  // Name of this network.
  required string network_name = 1;
  // Number of experts per task.
  required uint32 expert_num_per_task = 2;
  // Number of shared experts.
  optional uint32 share_num = 3;
  // MLP network of the per-task experts.
  required MLP task_expert_net = 4;
  // MLP network of the shared experts.
  optional MLP share_expert_net = 5;
}

// Configuration of variational dropout.
message VariationalDropout {
  // Regularization coefficient lambda.
  optional float regularization_lambda = 1 [default = 0.01];
  // Variational dropout dimension.
  // NOTE(review): presumably true applies dropout per embedding dimension
  // instead of per feature -- confirm against the consuming module.
  optional bool embedding_wise_variational_dropout = 2 [default = false];
}

// Configuration of a B2I capsule layer with dynamic routing.
// NOTE(review): "B2I" presumably stands for behavior-to-interest -- confirm.
message B2ICapsule {
  // Max number of high capsules. Default: 5.
  optional uint32 max_k = 1 [default = 5];
  // Max behaviour sequence length.
  required uint32 max_seq_len = 2;
  // High capsule embedding vector dimension.
  required uint32 high_dim = 3;
  // Number of dynamic routing iterations. Default: 3.
  optional uint32 num_iters = 4 [default = 3];
  // Routing logits scale. Default: 20.
  optional float routing_logits_scale = 5 [default = 20];
  // Routing logits initial stddev. Default: 1.
  optional float routing_logits_stddev = 6 [default = 1];
  // Squash power. Default: 1.
  optional float squash_pow = 7 [default = 1];
  // Whether to use a constant capsule number. Default: false.
  optional bool const_caps_num = 8 [default = false];
}

// Configuration of a single mask block.
message MaskBlock {
  // The ratio between the aggregation dim and the masked input dim.
  optional float reduction_ratio = 1 [default = 1.0];
  // The dim of the aggregation layer.
  // NOTE(review): how this interacts with reduction_ratio when both are set
  // is not visible here -- check the consuming module.
  optional uint32 aggregation_dim = 2;
  // The dim of the hidden ffn layer.
  required uint32 hidden_dim = 3;
}

// Configuration of a MaskNet module: a group of mask blocks with an MLP on
// top of them.
message MaskNetModule {
  // Number of mask blocks.
  required uint32 n_mask_blocks = 1;
  // Mask block configuration.
  required MaskBlock mask_block = 2;
  // MLP layer on top of the mask blocks.
  required MLP top_mlp = 3;
  // Whether to use parallel (true) or serial (false) mask blocks.
  optional bool use_parallel = 4 [default = true];
}