// easy_rec/python/protos/easy_rec_model.proto
syntax = "proto2";
package protos;
import "easy_rec/python/protos/backbone.proto";
import "easy_rec/python/protos/fm.proto";
import "easy_rec/python/protos/deepfm.proto";
import "easy_rec/python/protos/wide_and_deep.proto";
import "easy_rec/python/protos/multi_tower.proto";
import "easy_rec/python/protos/dlrm.proto";
import "easy_rec/python/protos/feature_config.proto";
import "easy_rec/python/protos/dropoutnet.proto";
import "easy_rec/python/protos/dssm.proto";
import "easy_rec/python/protos/collaborative_metric_learning.proto";
import "easy_rec/python/protos/mmoe.proto";
import "easy_rec/python/protos/esmm.proto";
import "easy_rec/python/protos/dbmtl.proto";
import "easy_rec/python/protos/ple.proto";
import "easy_rec/python/protos/simple_multi_task.proto";
import "easy_rec/python/protos/dcn.proto";
import "easy_rec/python/protos/cmbf.proto";
import "easy_rec/python/protos/uniter.proto";
import "easy_rec/python/protos/autoint.proto";
import "easy_rec/python/protos/mind.proto";
import "easy_rec/python/protos/loss.proto";
import "easy_rec/python/protos/rocket_launching.proto";
import "easy_rec/python/protos/variational_dropout.proto";
import "easy_rec/python/protos/multi_tower_recall.proto";
import "easy_rec/python/protos/tower.proto";
import "easy_rec/python/protos/pdn.proto";
import "easy_rec/python/protos/dssm_senet.proto";
import "easy_rec/python/protos/simi.proto";
import "easy_rec/python/protos/dat.proto";
// Empty placeholder model with no parameters of its own; used for input
// performance tests.
message DummyModel {}
// Common parameters used to configure a backbone network.
message ModelParams {
  // L2 regularization setting (no default).
  optional float l2_regularization = 1;
  // Output names of the backbone.
  repeated string outputs = 2;
  // Task towers; BayesTaskTower is declared in an imported proto file.
  repeated BayesTaskTower task_towers = 3;
  // Position of the user tower among the outputs (default 0).
  // NOTE(review): presumably an index into `outputs` -- confirm in model code.
  optional int32 user_tower_idx_in_output = 4 [default = 0];
  // Position of the item tower among the outputs (default 1).
  // NOTE(review): presumably an index into `outputs` -- confirm in model code.
  optional int32 item_tower_idx_in_output = 5 [default = 1];
  // Similarity function to use (default COSINE); the Similarity enum is
  // presumably declared in simi.proto.
  optional Similarity simi_func = 6 [default = COSINE];
  // Temperature setting (default 1.0).
  optional float temperature = 7 [default = 1.0];
  // Whether to scale the similarity (default false).
  optional bool scale_simi = 8 [default = false];
}
// Knowledge-distillation (KD) configuration: one distillation loss between a
// model prediction and a soft label.
message KD {
  // Name for this distillation loss (optional).
  optional string loss_name = 10;
  // Name of the prediction to distill.
  required string pred_name = 11;
  // Whether the prediction is given as logits (default true).
  optional bool pred_is_logits = 12 [default = true];
  // Name of the soft label. With CROSS_ENTROPY_LOSS the soft label must be
  // logits, not probabilities.
  required string soft_label_name = 21;
  // Whether the soft label is given as logits (default true).
  optional bool label_is_logits = 22 [default = true];
  // Loss type used for this distillation term.
  required LossType loss_type = 3;
  // Weight of this loss (default 1.0).
  optional float loss_weight = 4 [default = 1.0];
  // Temperature (default 1.0); only applies when loss_type is
  // CROSS_ENTROPY_LOSS, BINARY_CROSS_ENTROPY_LOSS or KL_DIVERGENCE_LOSS.
  optional float temperature = 5 [default = 1.0];
  // Name of the field that indicates the sample space for this task.
  optional string task_space_indicator_name = 6;
  // Value of that field that indicates the sample space for this task.
  optional string task_space_indicator_value = 7;
  // Loss weight for samples inside the task space (default 1.0).
  optional float in_task_space_weight = 8 [default = 1.0];
  // Loss weight for samples outside the task space (default 1.0).
  optional float out_task_space_weight = 9 [default = 1.0];
  // Loss-specific parameters; at most one of these may be set.
  // NOTE(review): presumably the chosen entry should match loss_type.
  oneof loss_param {
    F1ReweighedLoss f1_reweighted_loss = 101;
    SoftmaxCrossEntropyWithNegativeMining softmax_loss = 102;
    CircleLoss circle_loss = 103;
    MultiSimilarityLoss multi_simi_loss = 104;
    BinaryFocalLoss binary_focal_loss = 105;
    PairwiseLoss pairwise_loss = 106;
    PairwiseFocalLoss pairwise_focal_loss = 107;
    PairwiseLogisticLoss pairwise_logistic_loss = 108;
    JRCLoss jrc_loss = 109;
    PairwiseHingeLoss pairwise_hinge_loss = 110;
    ListwiseRankLoss listwise_rank_loss = 111;
    ListwiseDistillLoss listwise_distill_loss = 112;
  }
}
// Top-level model configuration: the model class to build, its input feature
// groups, exactly one model-specific parameter message, and loss /
// regularization settings.
message EasyRecModel {
  // Name of the model class to build.
  required string model_class = 1;
  // Just a name for the backbone config.
  optional string model_name = 99;
  // The actual input layers; each layer produces one group of features.
  repeated FeatureGroupConfig feature_groups = 2;

  // Model parameters; at most one of these may be set.
  oneof model {
    ModelParams model_params = 100;
    DummyModel dummy = 101;
    WideAndDeep wide_and_deep = 102;
    DeepFM deepfm = 103;
    MultiTower multi_tower = 104;
    FM fm = 105;
    DCN dcn = 106;
    AutoInt autoint = 107;
    DLRM dlrm = 108;
    CMBF cmbf = 109;
    Uniter uniter = 110;
    MultiTowerRecall multi_tower_recall = 200;
    DSSM dssm = 201;
    MIND mind = 202;
    DropoutNet dropoutnet = 203;
    CoMetricLearningI2I metric_learning = 204;
    PDN pdn = 205;
    DSSM_SENet dssm_senet = 206;
    DAT dat = 207;
    MMoE mmoe = 301;
    ESMM esmm = 302;
    DBMTL dbmtl = 303;
    SimpleMultiTask simple_multi_task = 304;
    PLE ple = 305;
    RocketLaunching rocket_launching = 401;
  }

  repeated SeqAttGroupConfig seq_att_groups = 7;
  // Implemented in easy_rec/python/model/easy_rec_estimator: adds
  // regularization to all variables whose name contains "embedding_weights:".
  optional float embedding_regularization = 8 [default = 0.0];
  optional LossType loss_type = 9 [default = CLASSIFICATION];
  optional uint32 num_class = 10 [default = 1];
  optional EVParams ev_params = 11;
  // Knowledge-distillation terms.
  repeated KD kd = 12;
  // Variables matching any pattern in restore_filters are filtered out
  // (presumably when restoring from a checkpoint -- confirm).
  // Common filters are Adam, Momentum, etc.
  repeated string restore_filters = 13;
  optional VariationalDropoutLayer variational_dropout = 14;
  repeated Loss losses = 15;

  // How the weights of multiple losses are determined.
  enum LossWeightStrategy {
    Fixed = 0;
    Uncertainty = 1;
    Random = 2;
  }
  // Declared `optional` rather than `required`. The field has a default, and
  // in proto2 a missing `required` field makes the whole EasyRecModel fail its
  // initialization check, which breaks serialization of configs that omit it.
  // Changing required -> optional keeps the same field number and wire type.
  optional LossWeightStrategy loss_weight_strategy = 16 [default = Fixed];
  optional BackboneTower backbone = 17;
  // Label name that rank models use to select one label when there are
  // multiple labels.
  optional string label_name = 18;
}