tzrec/protos/models/multi_task_rank.proto (47 lines of code) (raw):
syntax = "proto2";
package tzrec.protos;
import "tzrec/protos/module.proto";
import "tzrec/protos/tower.proto";
// Multi-task model config whose only setting is the list of task towers.
message SimpleMultiTask {
  // Task towers of the model; the field is repeated, so any number
  // (including zero) may be given.
  repeated TaskTower task_towers = 1;
}
// Config for the MMoE multi-task model.
message MMoE {
  // MLP definition of the mmoe expert module. Must always be set.
  required MLP expert_mlp = 1;
  // MLP definition of the mmoe gate module. May be omitted.
  optional MLP gate_mlp = 2;
  // Number of mmoe experts; resolves to 3 when not set.
  // Declared optional (matching DBMTL and DC2VR) because a default value can
  // only take effect on a field that is allowed to be absent.
  optional uint32 num_expert = 3 [default = 3];
  // Task towers of the model.
  repeated TaskTower task_towers = 4;
}
// Config for the DBMTL multi-task model. Every singular field is optional.
message DBMTL {
  // MLP definition of the shared bottom layer.
  optional MLP bottom_mlp = 1;
  // MLP definition of the mmoe expert layer.
  optional MLP expert_mlp = 2;
  // MLP definition of the mmoe gate module.
  optional MLP gate_mlp = 3;
  // Number of mmoe experts; resolves to 3 when not set.
  optional uint32 num_expert = 4 [default = 3];
  // Bayes task towers of the model.
  repeated BayesTaskTower task_towers = 5;
}
// Config for the DC2VR multi-task model. Every singular field is optional.
message DC2VR {
  // MLP definition of the shared bottom layer.
  optional MLP bottom_mlp = 1;
  // MLP definition of the mmoe expert layer.
  optional MLP expert_mlp = 2;
  // MLP definition of the mmoe gate module.
  optional MLP gate_mlp = 3;
  // Number of mmoe experts; resolves to 3 when not set.
  optional uint32 num_expert = 4 [default = 3];
  // Task towers of the model, each an InterventionTaskTower.
  repeated InterventionTaskTower task_towers = 5;
}
// Config for the PLE multi-task model.
message PLE {
  // Extraction networks of the model, given as a list.
  repeated ExtractionNetwork extraction_networks = 1;
  // Task towers of the model.
  repeated TaskTower task_towers = 2;
}