tzrec/protos/tower.proto (141 lines of code) (raw):
syntax = "proto2";
package tzrec.protos;
import "tzrec/protos/module.proto";
import "tzrec/protos/loss.proto";
import "tzrec/protos/metric.proto";
import "tzrec/protos/seq_encoder.proto";
// Basic tower: one input feature group plus an MLP config.
message Tower {
  // Name of the input feature group.
  required string input = 1;
  // MLP config for this tower.
  required MLP mlp = 2;
}
// Match tower configured with one input feature group and an HSTU encoder.
message HSTUMatchTower {
  // Name of the input feature group.
  required string input = 1;
  // HSTU encoder config.
  required HSTUEncoder hstu_encoder = 2;
}
// DIN tower: one input feature group plus the target-attention MLP.
message DINTower {
  // Name of the input feature group.
  required string input = 1;
  // MLP config used to compute the target attention score.
  required MLP attn_mlp = 2;
}
// Per-task tower config: label, metrics, losses and loss weighting for one task.
message TaskTower {
  // Name of this task tower.
  required string tower_name = 1;
  // Name of the label used as this task's target.
  required string label_name = 2;
  // Metrics evaluated for this task.
  repeated MetricConfig metrics = 3;
  // Losses applied to this task.
  repeated LossConfig losses = 4;
  // Number of classes, used by multi-class classification losses.
  optional uint32 num_class = 5 [default = 1];
  // Optional MLP specific to this task.
  optional MLP mlp = 6;
  // Weight of this task's loss during training.
  optional float weight = 7 [default = 1.0];
  // Name of the per-sample weight for this task.
  optional string sample_weight_name = 8;
  // Name of the label indicating whether a sample lies in this task's sample
  // space.
  optional string task_space_indicator_label = 9;
  // Loss weight for samples inside the task space.
  optional float in_task_space_weight = 10 [default = 1.0];
  // Loss weight for samples outside the task space.
  optional float out_task_space_weight = 11 [default = 1.0];
}
// Task tower config that can additionally depend on other (related) towers.
message BayesTaskTower {
  // Name of this task tower.
  required string tower_name = 1;
  // Name of the label for this task; when unset, labels are taken from
  // label_fields by order.
  optional string label_name = 2;
  // Metrics evaluated for this task.
  repeated MetricConfig metrics = 3;
  // Losses applied to this task.
  repeated LossConfig losses = 4;
  // Number of classes, used by multi-class classification losses.
  optional uint32 num_class = 5 [default = 1];
  // Optional MLP specific to this task.
  optional MLP mlp = 6;
  // Weight of this task's loss during training.
  optional float weight = 7 [default = 1.0];
  // Names of the towers this task is related to.
  repeated string relation_tower_names = 8;
  // MLP config for the relation part of the tower.
  optional MLP relation_mlp = 9;
  // Name of the per-sample weight for this task.
  optional string sample_weight_name = 10;
  // Name of the label indicating whether a sample lies in this task's sample
  // space.
  optional string task_space_indicator_label = 11;
  // Loss weight for samples inside the task space.
  optional float in_task_space_weight = 12 [default = 1.0];
  // Loss weight for samples outside the task space.
  optional float out_task_space_weight = 13 [default = 1.0];
}
// Task tower config that can take input from other (intervention) towers.
message InterventionTaskTower {
  // Name of this task tower.
  required string tower_name = 1;
  // Name of the label for this task; when unset, labels are taken from
  // label_fields by order.
  optional string label_name = 2;
  // Metrics evaluated for this task.
  repeated MetricConfig metrics = 3;
  // Losses applied to this task.
  repeated LossConfig losses = 4;
  // Number of classes, used by multi-class classification losses.
  optional uint32 num_class = 5 [default = 1];
  // Optional MLP specific to this task.
  optional MLP mlp = 6;
  // Weight of this task's loss during training.
  optional float weight = 7 [default = 1.0];
  // Names of the intervention towers.
  repeated string intervention_tower_names = 8;
  // Low-rank dimension (required).
  // NOTE(review): presumably the rank of a low-rank projection used by the
  // intervention module — confirm against the model code.
  required uint32 low_rank_dim = 9;
  // Dropout ratio.
  optional float dropout_ratio = 10 [default = 0.1];
  // Name of the label indicating whether a sample lies in this task's sample
  // space.
  optional string task_space_indicator_label = 11;
  // Loss weight for samples inside the task space.
  optional float in_task_space_weight = 12 [default = 1.0];
  // Loss weight for samples outside the task space.
  optional float out_task_space_weight = 13 [default = 1.0];
}
// DIN tower variant configured with multiple time windows.
message MultiWindowDINTower {
  // Length of each time window.
  repeated uint32 windows_len = 1;
  // MLP config used to compute the target attention score.
  required MLP attn_mlp = 2;
}
// DAT tower: an input feature group, an augmented feature group and an MLP.
message DATTower {
  // Name of the input feature group.
  required string input = 1;
  // Name of the augmented feature group.
  required string augment_input = 2;
  // MLP config for this tower.
  required MLP mlp = 3;
}
// MIND user tower: user features, user history sequence and capsule routing.
message MINDUserTower {
  // How the user representation and the history-sequence representation are
  // combined (selected by user_seq_combine).
  enum UserSeqCombineMethod {
    CONCAT = 0;
    SUM = 1;
  }
  // Name of the user feature group.
  required string input = 1;
  // Name of the user history feature group.
  required string history_input = 2;
  // MLP config for the user feature group.
  required MLP user_mlp = 3;
  // Optional MLP config for the history sequence.
  optional MLP hist_seq_mlp = 4;
  // Combine method for user and history sequence; defaults to SUM.
  optional UserSeqCombineMethod user_seq_combine = 5 [default = SUM];
  // B2I capsule config.
  required B2ICapsule capsule_config = 6;
  // MLP config applied on the concatenation for the user interest vectors.
  required MLP concat_mlp = 7;
}
// MIND item tower: one item feature group plus an MLP config.
message MINDItemTower {
  // Name of the item feature group.
  required string input = 1;
  // MLP config for this tower.
  required MLP mlp = 2;
}