easy_rec/python/protos/loss.proto (121 lines of code) (raw):
syntax = "proto2";
package protos;
// Identifiers of the loss functions that a `Loss` message can select through
// its `loss_type` field.
//
// The numeric values are part of the serialized format: never renumber or
// reuse them. Entries are kept in their original declaration order, which is
// why 17 appears ahead of 13.
enum LossType {
  // First declared value, hence the proto2 default for an unset LossType.
  CLASSIFICATION = 0;
  L2_LOSS = 1;
  SIGMOID_L2_LOSS = 2;
  // Cross-entropy loss, also known as log loss.
  CROSS_ENTROPY_LOSS = 3;
  SOFTMAX_CROSS_ENTROPY = 4;
  CIRCLE_LOSS = 5;
  MULTI_SIMILARITY_LOSS = 6;
  SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING = 7;
  PAIR_WISE_LOSS = 8;
  F1_REWEIGHTED_LOSS = 9;
  BINARY_FOCAL_LOSS = 10;
  PAIRWISE_FOCAL_LOSS = 11;
  PAIRWISE_LOGISTIC_LOSS = 12;
  // Declared out of numeric sequence; the value itself must stay 17.
  PAIRWISE_HINGE_LOSS = 17;
  JRC_LOSS = 13;
  ORDER_CALIBRATE_LOSS = 14;
  BINARY_CROSS_ENTROPY_LOSS = 15;
  KL_DIVERGENCE_LOSS = 16;
  LISTWISE_RANK_LOSS = 18;
  LISTWISE_DISTILL_LOSS = 19;
  ZILN_LOSS = 20;
}
// Configuration of a single loss term: the loss function to use, general
// settings, and at most one block of loss-specific parameters.
message Loss {
  // Loss function selected by this entry. Has no default and must be set.
  required LossType loss_type = 1;

  // Weight associated with this loss; 1.0 when omitted.
  optional float weight = 2 [default = 1.0];

  // Optional name for this loss.
  optional string loss_name = 3;

  // Defaults to false.
  // NOTE(review): presumably switches from the fixed `weight` to a learned
  // loss weight -- confirm against the code that consumes this message.
  optional bool learn_loss_weight = 4 [default = false];

  // Loss-specific parameters. As a oneof, at most one member is set at a
  // time; setting one clears any other.
  oneof loss_param {
    F1ReweighedLoss f1_reweighted_loss = 101;
    SoftmaxCrossEntropyWithNegativeMining softmax_loss = 102;
    CircleLoss circle_loss = 103;
    MultiSimilarityLoss multi_simi_loss = 104;
    BinaryFocalLoss binary_focal_loss = 105;
    PairwiseLoss pairwise_loss = 106;
    PairwiseFocalLoss pairwise_focal_loss = 107;
    PairwiseLogisticLoss pairwise_logistic_loss = 108;
    JRCLoss jrc_loss = 109;
    PairwiseHingeLoss pairwise_hinge_loss = 110;
    ListwiseRankLoss listwise_rank_loss = 111;
    ListwiseDistillLoss listwise_distill_loss = 112;
  }
}
// Parameters carried by `Loss.softmax_loss`.
message SoftmaxCrossEntropyWithNegativeMining {
  // Number of negative samples. Has no default, so it must be set explicitly.
  required uint32 num_negative_samples = 1;

  // The remaining fields used to be `required` while also declaring a
  // default, which meant the default could never apply to an initialized
  // message. They are `optional` so the declared default is used when the
  // field is omitted; field numbers and types are unchanged.

  // Defaults to 0.
  optional float margin = 2 [default = 0];

  // Defaults to 1.
  optional float gamma = 3 [default = 1];

  // Defaults to 1.
  optional float coefficient_of_support_vector = 4 [default = 1];
}
// Parameters carried by `Loss.circle_loss`.
//
// Both fields used to be `required` while also declaring a default, which
// meant the default could never apply to an initialized message. They are
// `optional` so the declared default is used when the field is omitted; field
// numbers and types are unchanged.
message CircleLoss {
  // Defaults to 0.25.
  optional float margin = 1 [default = 0.25];

  // Defaults to 32.
  optional float gamma = 2 [default = 32];
}
// Parameters carried by `Loss.multi_simi_loss`.
//
// All fields used to be `required` while also declaring a default, which
// meant the default could never apply to an initialized message. They are
// `optional` so the declared default is used when the field is omitted; field
// numbers and types are unchanged.
message MultiSimilarityLoss {
  // Defaults to 2.
  optional float alpha = 1 [default = 2];

  // Defaults to 50.
  optional float beta = 2 [default = 50];

  // Defaults to 1.
  optional float lamb = 3 [default = 1];

  // Defaults to 0.1.
  optional float eps = 4 [default = 0.1];
}
// Parameters carried by `Loss.f1_reweighted_loss`.
//
// Both fields used to be `required` while also declaring a default, which
// meant the default could never apply to an initialized message. They are
// `optional` so the declared default is used when the field is omitted; field
// numbers and types are unchanged.
message F1ReweighedLoss {
  // Defaults to 1.0.
  optional float f1_beta_square = 1 [default = 1.0];

  // Defaults to 0.
  optional float label_smoothing = 2 [default = 0];
}
// Parameters carried by `Loss.binary_focal_loss`.
message BinaryFocalLoss {
  // Defaults to 2.0. Previously `required` despite declaring a default, which
  // meant the default could never apply to an initialized message; now
  // `optional` so the default is used when the field is omitted.
  optional float gamma = 1 [default = 2.0];

  // No declared default; use has_alpha()/HasField to tell whether it was set.
  optional float alpha = 2;

  // Defaults to 1.0.
  optional float ohem_ratio = 3 [default = 1.0];

  // Defaults to 0.
  optional float label_smoothing = 4 [default = 0];
}
// Parameters carried by `Loss.pairwise_loss`.
message PairwiseLoss {
  // Defaults to 0. Previously `required` despite declaring a default, which
  // meant the default could never apply to an initialized message; now
  // `optional` so the default is used when the field is omitted.
  optional float margin = 1 [default = 0];

  // No declared default.
  // NOTE(review): presumably names the input that groups examples into
  // sessions -- confirm against the consumer.
  optional string session_name = 2;

  // Defaults to 1.0.
  optional float temperature = 3 [default = 1.0];
}
// Parameters carried by `Loss.pairwise_focal_loss`.
message PairwiseFocalLoss {
  // Defaults to 2.0. Previously `required` despite declaring a default, which
  // meant the default could never apply to an initialized message; now
  // `optional` so the default is used when the field is omitted.
  optional float gamma = 1 [default = 2.0];

  // No declared default; use has_alpha()/HasField to tell whether it was set.
  optional float alpha = 2;

  // Defaults to 1.0.
  optional float hinge_margin = 3 [default = 1.0];

  // No declared default.
  optional string session_name = 4;

  // Defaults to 1.0.
  optional float ohem_ratio = 5 [default = 1.0];

  // Defaults to 1.0.
  optional float temperature = 6 [default = 1.0];
}
// Parameters carried by `Loss.pairwise_logistic_loss`.
message PairwiseLogisticLoss {
  // Defaults to 1.0. Previously `required` despite declaring a default, which
  // meant the default could never apply to an initialized message; now
  // `optional` so the default is used when the field is omitted.
  optional float temperature = 1 [default = 1.0];

  // No declared default.
  optional string session_name = 2;

  // No declared default; use has_hinge_margin()/HasField to tell whether it
  // was set.
  optional float hinge_margin = 3;

  // Defaults to 1.0.
  optional float ohem_ratio = 4 [default = 1.0];

  // Defaults to false.
  optional bool use_label_margin = 5 [default = false];
}
// Parameters carried by `Loss.pairwise_hinge_loss`.
message PairwiseHingeLoss {
  // Defaults to 1.0. Previously `required` despite declaring a default, which
  // meant the default could never apply to an initialized message; now
  // `optional` so the default is used when the field is omitted.
  optional float temperature = 1 [default = 1.0];

  // No declared default.
  optional string session_name = 2;

  // Defaults to 1.0.
  optional float margin = 3 [default = 1.0];

  // Defaults to 1.0.
  optional float ohem_ratio = 4 [default = 1.0];

  // Defaults to true.
  optional bool label_is_logits = 5 [default = true];

  // Defaults to true.
  optional bool use_label_margin = 6 [default = true];

  // Defaults to false.
  optional bool use_exponent = 7 [default = false];
}
// Parameters carried by `Loss.jrc_loss`.
message JRCLoss {
  // Has no default, so it must be set explicitly.
  // NOTE(review): presumably names the input that groups examples into
  // sessions -- confirm against the consumer.
  required string session_name = 1;

  // Defaults to 0.5.
  optional float alpha = 2 [default = 0.5];

  // Defaults to true.
  optional bool same_label_loss = 3 [default = true];

  // Defaults to "fixed". Previously `required` despite declaring a default,
  // which meant the default could never apply to an initialized message; now
  // `optional` so the default is used when the field is omitted.
  optional string loss_weight_strategy = 4 [default = "fixed"];
}
// Parameters carried by `Loss.listwise_rank_loss`.
message ListwiseRankLoss {
  // Defaults to 1.0. Previously `required` despite declaring a default, which
  // meant the default could never apply to an initialized message; now
  // `optional` so the default is used when the field is omitted.
  optional float temperature = 1 [default = 1.0];

  // No declared default.
  optional string session_name = 2;

  // No declared default.
  // NOTE(review): the set of accepted function names is not visible in this
  // file -- confirm against the consumer.
  optional string transform_fn = 3;

  // Defaults to false.
  optional bool label_is_logits = 4 [default = false];

  // Defaults to false.
  optional bool scale_logits = 5 [default = false];
}
// Parameters carried by `Loss.listwise_distill_loss`.
message ListwiseDistillLoss {
  // Defaults to 1.0. Previously `required` despite declaring a default, which
  // meant the default could never apply to an initialized message; now
  // `optional` so the default is used when the field is omitted.
  optional float temperature = 1 [default = 1.0];

  // No declared default.
  optional string session_name = 2;

  // No declared default.
  // NOTE(review): the set of accepted function names is not visible in this
  // file -- confirm against the consumer.
  optional string transform_fn = 3;

  // Defaults to 512.0.
  optional float label_clip_max_value = 4 [default = 512.0];

  // Defaults to false.
  optional bool scale_logits = 5 [default = false];
}