// easy_rec/python/protos/easy_rec_model.proto (140 lines of code) (raw):

syntax = "proto2";
package protos;

import "easy_rec/python/protos/backbone.proto";
import "easy_rec/python/protos/fm.proto";
import "easy_rec/python/protos/deepfm.proto";
import "easy_rec/python/protos/wide_and_deep.proto";
import "easy_rec/python/protos/multi_tower.proto";
import "easy_rec/python/protos/dlrm.proto";
import "easy_rec/python/protos/feature_config.proto";
import "easy_rec/python/protos/dropoutnet.proto";
import "easy_rec/python/protos/dssm.proto";
import "easy_rec/python/protos/collaborative_metric_learning.proto";
import "easy_rec/python/protos/mmoe.proto";
import "easy_rec/python/protos/esmm.proto";
import "easy_rec/python/protos/dbmtl.proto";
import "easy_rec/python/protos/ple.proto";
import "easy_rec/python/protos/simple_multi_task.proto";
import "easy_rec/python/protos/dcn.proto";
import "easy_rec/python/protos/cmbf.proto";
import "easy_rec/python/protos/uniter.proto";
import "easy_rec/python/protos/autoint.proto";
import "easy_rec/python/protos/mind.proto";
import "easy_rec/python/protos/loss.proto";
import "easy_rec/python/protos/rocket_launching.proto";
import "easy_rec/python/protos/variational_dropout.proto";
import "easy_rec/python/protos/multi_tower_recall.proto";
import "easy_rec/python/protos/tower.proto";
import "easy_rec/python/protos/pdn.proto";
import "easy_rec/python/protos/dssm_senet.proto";
import "easy_rec/python/protos/simi.proto";
import "easy_rec/python/protos/dat.proto";

// Placeholder model with no parameters; used for input performance tests.
message DummyModel {
}

// Common parameters for a backbone network configuration.
message ModelParams {
  optional float l2_regularization = 1;
  // NOTE(review): presumably names of backbone outputs consumed by the
  // model; confirm against the backbone implementation.
  repeated string outputs = 2;
  repeated BayesTaskTower task_towers = 3;
  // Positions of the user / item tower within `outputs` (defaults: 0 and 1).
  // NOTE(review): index semantics inferred from the field names; confirm.
  optional int32 user_tower_idx_in_output = 4 [default = 0];
  optional int32 item_tower_idx_in_output = 5 [default = 1];
  // Similarity function; defaults to COSINE.
  optional Similarity simi_func = 6 [default = COSINE];
  optional float temperature = 7 [default = 1.0];
  optional bool scale_simi = 8 [default = false];
}

// Knowledge distillation configuration.
//
// Field numbers are not in declaration order; they are part of the wire
// contract and must not be renumbered.
message KD {
  optional string loss_name = 10;
  required string pred_name = 11;
  // Whether the prediction is logits; defaults to true.
  optional bool pred_is_logits = 12 [default = true];
  // For CROSS_ENTROPY_LOSS, the soft label must be logits instead of probs.
  required string soft_label_name = 21;
  // Whether the soft label is logits; defaults to true.
  optional bool label_is_logits = 22 [default = true];
  required LossType loss_type = 3;
  optional float loss_weight = 4 [default = 1.0];
  // Only used when loss_type is CROSS_ENTROPY_LOSS,
  // BINARY_CROSS_ENTROPY_LOSS or KL_DIVERGENCE_LOSS.
  optional float temperature = 5 [default = 1.0];
  // Field name indicating the sample space for this task.
  optional string task_space_indicator_name = 6;
  // Field value indicating the sample space for this task.
  optional string task_space_indicator_value = 7;
  // Loss weight for samples inside the task space.
  optional float in_task_space_weight = 8 [default = 1.0];
  // Loss weight for samples outside the task space.
  optional float out_task_space_weight = 9 [default = 1.0];

  // Loss-specific parameters; at most one may be set.
  oneof loss_param {
    F1ReweighedLoss f1_reweighted_loss = 101;
    SoftmaxCrossEntropyWithNegativeMining softmax_loss = 102;
    CircleLoss circle_loss = 103;
    MultiSimilarityLoss multi_simi_loss = 104;
    BinaryFocalLoss binary_focal_loss = 105;
    PairwiseLoss pairwise_loss = 106;
    PairwiseFocalLoss pairwise_focal_loss = 107;
    PairwiseLogisticLoss pairwise_logistic_loss = 108;
    JRCLoss jrc_loss = 109;
    PairwiseHingeLoss pairwise_hinge_loss = 110;
    ListwiseRankLoss listwise_rank_loss = 111;
    ListwiseDistillLoss listwise_distill_loss = 112;
  }
}

// Top-level model configuration.
message EasyRecModel {
  required string model_class = 1;

  // Just a name for the backbone config.
  optional string model_name = 99;

  // Effectively the input layers: each entry produces one group of features.
  repeated FeatureGroupConfig feature_groups = 2;

  // Model parameters; at most one may be set.
  // NOTE(review): tags appear grouped by range (1xx, 2xx, 3xx, 4xx),
  // presumably by model family; confirm before adding new entries.
  oneof model {
    ModelParams model_params = 100;
    DummyModel dummy = 101;
    WideAndDeep wide_and_deep = 102;
    DeepFM deepfm = 103;
    MultiTower multi_tower = 104;
    FM fm = 105;
    DCN dcn = 106;
    AutoInt autoint = 107;
    DLRM dlrm = 108;
    CMBF cmbf = 109;
    Uniter uniter = 110;

    MultiTowerRecall multi_tower_recall = 200;
    DSSM dssm = 201;
    MIND mind = 202;
    DropoutNet dropoutnet = 203;
    CoMetricLearningI2I metric_learning = 204;
    PDN pdn = 205;
    DSSM_SENet dssm_senet = 206;
    DAT dat = 207;

    MMoE mmoe = 301;
    ESMM esmm = 302;
    DBMTL dbmtl = 303;
    SimpleMultiTask simple_multi_task = 304;
    PLE ple = 305;

    RocketLaunching rocket_launching = 401;
  }

  repeated SeqAttGroupConfig seq_att_groups = 7;

  // Implemented in easy_rec/python/model/easy_rec_estimator:
  // adds regularization to all variables with "embedding_weights:"
  // in their name.
  optional float embedding_regularization = 8 [default = 0.0];

  optional LossType loss_type = 9 [default = CLASSIFICATION];
  optional uint32 num_class = 10 [default = 1];
  optional EVParams ev_params = 11;
  repeated KD kd = 12;

  // Variables matching any pattern in restore_filters are filtered;
  // common filters are Adam, Momentum, etc.
  repeated string restore_filters = 13;

  optional VariationalDropoutLayer variational_dropout = 14;
  repeated Loss losses = 15;

  // Strategy for weighting multiple losses.
  enum LossWeightStrategy {
    Fixed = 0;
    Uncertainty = 1;
    Random = 2;
  }
  // Declared optional so that the Fixed default actually applies when the
  // field is omitted; a default on a required field can never take effect.
  optional LossWeightStrategy loss_weight_strategy = 16 [default = Fixed];

  optional BackboneTower backbone = 17;

  // Label name used by rank_model to select one label among multiple labels.
  optional string label_name = 18;
}