tzrec/protos/model.proto (59 lines of code) (raw):
syntax = "proto2";
package tzrec.protos;
import "tzrec/protos/models/rank_model.proto";
import "tzrec/protos/models/multi_task_rank.proto";
import "tzrec/protos/models/match_model.proto";
import "tzrec/protos/models/general_rank_model.proto";
import "tzrec/protos/loss.proto";
import "tzrec/protos/metric.proto";
import "tzrec/protos/seq_encoder.proto";
import "tzrec/protos/module.proto";
// Kind of a feature group, referenced by FeatureGroupConfig.group_type.
//
// Value names are not prefixed with the enum name, so they are visible
// directly at package scope (tzrec.protos.DEEP, ...).
// NOTE(review): the per-value meanings below are inferred from the names
// only - confirm against the code that consumes them.
enum FeatureGroupType {
  // First declared value; this is what an unset field of this type reads as.
  DEEP = 0;

  // Presumably the "wide" side of a wide-and-deep style model.
  WIDE = 1;

  // Presumably a group made of sequence features.
  SEQUENCE = 2;
}
// A named list of feature names. Embedded in FeatureGroupConfig through its
// repeated `sequence_groups` field.
message SeqGroupConfig {
  // Name of this sequence group. Optional: may be left unset.
  optional string group_name = 1;

  // Names of the features that belong to this sequence group.
  repeated string feature_names = 2;
}
// One feature group: a named list of features plus the group's type and any
// nested sequence groups / sequence encoders.
message FeatureGroupConfig {
  // Name of the group. Required: a message without it is uninitialized.
  required string group_name = 1;

  // Names of the features that belong to this group.
  repeated string feature_names = 2;

  // Type of the group.
  // NOTE(review): `required` combined with `[default = DEEP]` is redundant -
  // an initialized message always carries an explicit value, and DEEP is
  // already the enum's first value. Confirm whether `optional` was intended
  // before touching the label; relaxing `required` changes validation.
  required FeatureGroupType group_type = 3 [default = DEEP];

  // Sequence sub-groups nested inside this group.
  repeated SeqGroupConfig sequence_groups = 4;

  // Sequence encoder settings for this group. SeqEncoderConfig is not
  // declared in this file (presumably tzrec/protos/seq_encoder.proto).
  repeated SeqEncoderConfig sequence_encoders = 5;
}
// Kernel selector, referenced by ModelConfig.kernel.
//
// Value names are not prefixed with the enum name, so they are visible
// directly at package scope (tzrec.protos.TRITON, ...).
// NOTE(review): presumably chooses which operator implementation backs the
// model - confirm against the code that reads ModelConfig.kernel.
enum Kernel {
  // First declared value. Note that ModelConfig.kernel overrides the implicit
  // default with an explicit `[default = PYTORCH]`.
  TRITON = 0;

  PYTORCH = 1;

  CUDA = 2;
}
// Top-level model configuration: feature groups, exactly one model variant,
// and the loss / metric / auxiliary settings that accompany it.
message ModelConfig {
  // Feature group definitions for this model.
  repeated FeatureGroupConfig feature_groups = 1;

  // The concrete model. As a oneof, at most one member is set at a time;
  // setting one clears any other.
  //
  // None of the member message types are declared in this file (presumably
  // they come from the imported tzrec/protos/models/*.proto files).
  // Field numbers are handed out in blocks of 100.
  // NOTE(review): the blocks look like one per model family - confirm the
  // family of each block against the imported files before adding a member.
  oneof model {
    // 100 block.
    DLRM dlrm = 100;
    DeepFM deepfm = 101;
    MultiTower multi_tower = 102;
    MultiTowerDIN multi_tower_din = 103;
    MaskNet mask_net = 104;

    // 200 block.
    SimpleMultiTask simple_multi_task = 200;
    MMoE mmoe = 201;
    DBMTL dbmtl = 202;
    PLE ple = 203;
    DC2VR dc2vr = 204;

    // 300 block.
    DSSM dssm = 301;
    DSSMV2 dssm_v2 = 302;
    DAT dat = 303;
    HSTUMatch hstu_match = 304;
    MIND mind = 305;

    // 400 block.
    TDM tdm = 400;

    // 500 block.
    MultiTowerDINTRT multi_tower_din_trt = 500;

    // 600 block.
    RocketLaunching rocket_launching = 600;
  }

  // Class count; reads as 1 when unset.
  // NOTE(review): presumably the number of prediction classes - confirm.
  optional uint32 num_class = 2 [default = 1];

  // Loss settings. LossConfig is not declared in this file (presumably
  // tzrec/protos/loss.proto).
  repeated LossConfig losses = 3;

  // Metric settings. MetricConfig is not declared in this file (presumably
  // tzrec/protos/metric.proto).
  repeated MetricConfig metrics = 4;

  // Optional variational-dropout settings. VariationalDropout is not declared
  // in this file.
  optional VariationalDropout variational_dropout = 11;

  // Kernel selection; reads as PYTORCH when unset.
  optional Kernel kernel = 12 [default = PYTORCH];
}