// tzrec/protos/tower.proto (141 lines of code) (raw)

syntax = "proto2";

package tzrec.protos;

import "tzrec/protos/module.proto";
import "tzrec/protos/loss.proto";
import "tzrec/protos/metric.proto";
import "tzrec/protos/seq_encoder.proto";

// Tower configuration messages.
//
// MLP, HSTUEncoder, B2ICapsule, MetricConfig and LossConfig are not defined
// in this file; they come from the imports above.

// Basic tower: an input feature group paired with an MLP config.
message Tower {
  // Input feature group name.
  required string input = 1;
  // MLP config.
  required MLP mlp = 2;
}

// Match tower built on an HSTU encoder.
message HSTUMatchTower {
  // Input feature group name.
  required string input = 1;
  // HSTU encoder config.
  required HSTUEncoder hstu_encoder = 2;
}

// DIN tower: an input feature group paired with a target-attention MLP.
message DINTower {
  // Input feature group name.
  required string input = 1;
  // MLP config for the target attention score.
  required MLP attn_mlp = 2;
}

// Per-task tower config: task name, label, metrics, losses and loss
// weighting options.
message TaskTower {
  // Task name for the task tower.
  required string tower_name = 1;
  // Label for the task.
  required string label_name = 2;
  // Metrics for the task.
  repeated MetricConfig metrics = 3;
  // Losses for the task.
  repeated LossConfig losses = 4;
  // Number of classes for multi-class classification loss.
  optional uint32 num_class = 5 [default = 1];
  // Task-specific MLP.
  optional MLP mlp = 6;
  // Training loss weight.
  optional float weight = 7 [default = 1.0];
  // Sample weight name for the task.
  optional string sample_weight_name = 8;
  // Label name indicating the sample space for the task tower.
  optional string task_space_indicator_label = 9;
  // Loss weight for samples inside the task space.
  optional float in_task_space_weight = 10 [default = 1.0];
  // Loss weight for samples outside the task space.
  optional float out_task_space_weight = 11 [default = 1.0];
}

// Task tower config that additionally names related towers and an optional
// relation MLP.
message BayesTaskTower {
  // Task name for the task tower.
  required string tower_name = 1;
  // Label for the task; default is label_fields by order.
  optional string label_name = 2;
  // Metrics for the task.
  repeated MetricConfig metrics = 3;
  // Losses for the task.
  repeated LossConfig losses = 4;
  // Number of classes for multi-class classification loss.
  optional uint32 num_class = 5 [default = 1];
  // Task-specific MLP.
  optional MLP mlp = 6;
  // Training loss weight.
  optional float weight = 7 [default = 1.0];
  // Related tower names.
  repeated string relation_tower_names = 8;
  // Relation MLP.
  optional MLP relation_mlp = 9;
  // Sample weight name for the task.
  optional string sample_weight_name = 10;
  // Label name indicating the sample space for the task tower.
  optional string task_space_indicator_label = 11;
  // Loss weight for samples inside the task space.
  optional float in_task_space_weight = 12 [default = 1.0];
  // Loss weight for samples outside the task space.
  optional float out_task_space_weight = 13 [default = 1.0];
}

// Task tower config that additionally names intervention towers and carries
// a low-rank dimension and dropout ratio.
//
// NOTE(review): unlike TaskTower and BayesTaskTower, this message declares no
// sample_weight_name field -- confirm whether that is intentional.
message InterventionTaskTower {
  // Task name for the task tower.
  required string tower_name = 1;
  // Label for the task; default is label_fields by order.
  optional string label_name = 2;
  // Metrics for the task.
  repeated MetricConfig metrics = 3;
  // Losses for the task.
  repeated LossConfig losses = 4;
  // Number of classes for multi-class classification loss.
  optional uint32 num_class = 5 [default = 1];
  // Task-specific MLP.
  optional MLP mlp = 6;
  // Training loss weight.
  optional float weight = 7 [default = 1.0];
  // Intervention tower names.
  repeated string intervention_tower_names = 8;
  // Low-rank dimension.
  required uint32 low_rank_dim = 9;
  // Dropout ratio.
  optional float dropout_ratio = 10 [default = 0.1];
  // Label name indicating the sample space for the task tower.
  optional string task_space_indicator_label = 11;
  // Loss weight for samples inside the task space.
  optional float in_task_space_weight = 12 [default = 1.0];
  // Loss weight for samples outside the task space.
  optional float out_task_space_weight = 13 [default = 1.0];
}

// DIN tower config over multiple time windows.
message MultiWindowDINTower {
  // Time window lengths.
  // NOTE(review): the unit of each length is not stated here -- confirm
  // against the consuming model code.
  repeated uint32 windows_len = 1;
  // MLP config for the target attention score.
  required MLP attn_mlp = 2;
}

// DAT tower: an input feature group, an augmented feature group and an MLP.
message DATTower {
  // Input feature group name.
  required string input = 1;
  // Augmented feature group name.
  required string augment_input = 2;
  // MLP config.
  required MLP mlp = 3;
}

// MIND user-side tower config.
message MINDUserTower {
  // Options for the user_seq_combine field.
  // NOTE(review): presumably selects how the user representation and the
  // history sequence representation are merged -- confirm against the model.
  enum UserSeqCombineMethod {
    CONCAT = 0;
    SUM = 1;
  }

  // User feature group name.
  required string input = 1;
  // User history group name.
  required string history_input = 2;
  // User MLP config.
  // NOTE(review): presumably applied to the `input` group -- confirm.
  required MLP user_mlp = 3;
  // History sequence MLP config (optional).
  // NOTE(review): presumably applied to the `history_input` group -- confirm.
  optional MLP hist_seq_mlp = 4;
  // Combine method; defaults to SUM.
  optional UserSeqCombineMethod user_seq_combine = 5 [default = SUM];
  // Capsule config.
  required B2ICapsule capsule_config = 6;
  // Concat MLP config for the user interests vector.
  required MLP concat_mlp = 7;
}

// MIND item-side tower config.
message MINDItemTower {
  // Item feature group name.
  required string input = 1;
  // MLP config.
  required MLP mlp = 2;
}