tzrec/protos/model.proto (59 lines of code) (raw):

syntax = "proto2";

package tzrec.protos;

import "tzrec/protos/models/rank_model.proto";
import "tzrec/protos/models/multi_task_rank.proto";
import "tzrec/protos/models/match_model.proto";
import "tzrec/protos/models/general_rank_model.proto";
import "tzrec/protos/loss.proto";
import "tzrec/protos/metric.proto";
import "tzrec/protos/seq_encoder.proto";
import "tzrec/protos/module.proto";

// Kind of a feature group; referenced by FeatureGroupConfig.group_type.
enum FeatureGroupType {
  DEEP = 0;
  WIDE = 1;
  SEQUENCE = 2;
}

// A named list of feature names. Used as an element of
// FeatureGroupConfig.sequence_groups.
message SeqGroupConfig {
  // Name of this sequence group.
  optional string group_name = 1;
  // Names of the features that belong to this sequence group.
  repeated string feature_names = 2;
}

// Describes one feature group of a model: its name, its member features,
// its kind, and any nested sequence groups / sequence encoders.
message FeatureGroupConfig {
  // Name of the feature group. Must be set.
  required string group_name = 1;
  // Names of the features that belong to this group.
  repeated string feature_names = 2;
  // Kind of the group. Must be set.
  // NOTE(review): the field is `required` yet also declares a default (DEEP);
  // a fully initialized message always sets it explicitly, so the default is
  // presumably only observable on partially built messages. Left unchanged to
  // keep validation behavior identical for existing configs.
  required FeatureGroupType group_type = 3 [default = DEEP];
  // Sequence groups nested inside this feature group.
  repeated SeqGroupConfig sequence_groups = 4;
  // Sequence encoder configurations attached to this feature group.
  // SeqEncoderConfig is not defined in this file; presumably it comes from
  // seq_encoder.proto (TODO confirm).
  repeated SeqEncoderConfig sequence_encoders = 5;
}

// Kernel selector; referenced by ModelConfig.kernel.
enum Kernel {
  TRITON = 0;
  PYTORCH = 1;
  CUDA = 2;
}

// Top-level model configuration: feature groups, the concrete model choice,
// and the losses / metrics attached to it.
message ModelConfig {
  // Feature groups available to the model.
  repeated FeatureGroupConfig feature_groups = 1;

  // Concrete model selection; at most one member can be set at a time.
  // The message types below are not defined in this file. Field numbers are
  // allocated in blocks of 100; the block-to-file mapping is presumably
  // 1xx -> rank_model.proto, 2xx -> multi_task_rank.proto,
  // 3xx -> match_model.proto (TODO confirm against the imported files).
  oneof model {
    // 100-199
    DLRM dlrm = 100;
    DeepFM deepfm = 101;
    MultiTower multi_tower = 102;
    MultiTowerDIN multi_tower_din = 103;
    MaskNet mask_net = 104;

    // 200-299
    SimpleMultiTask simple_multi_task = 200;
    MMoE mmoe = 201;
    DBMTL dbmtl = 202;
    PLE ple = 203;
    DC2VR dc2vr = 204;

    // 300-399
    DSSM dssm = 301;
    DSSMV2 dssm_v2 = 302;
    DAT dat = 303;
    HSTUMatch hstu_match = 304;
    MIND mind = 305;

    // 400-499
    TDM tdm = 400;

    // 500-599
    MultiTowerDINTRT multi_tower_din_trt = 500;

    // 600-699
    RocketLaunching rocket_launching = 600;
  }

  // Number of classes; defaults to 1 when unset.
  optional uint32 num_class = 2 [default = 1];
  // Loss configurations for the model.
  repeated LossConfig losses = 3;
  // Metric configurations for the model.
  repeated MetricConfig metrics = 4;
  // Optional variational dropout configuration.
  optional VariationalDropout variational_dropout = 11;
  // Kernel selection; defaults to PYTORCH when unset.
  optional Kernel kernel = 12 [default = PYTORCH];
}