77
88from easy_rec .python .loss .focal_loss import sigmoid_focal_loss_with_logits
99from easy_rec .python .loss .jrc_loss import jrc_loss
10- from easy_rec .python .loss .listwise_loss import listwise_distill_loss
11- from easy_rec .python .loss .listwise_loss import listwise_rank_loss
12- from easy_rec .python .loss .pairwise_loss import pairwise_focal_loss
13- from easy_rec .python .loss .pairwise_loss import pairwise_hinge_loss
14- from easy_rec .python .loss .pairwise_loss import pairwise_logistic_loss
15- from easy_rec .python .loss .pairwise_loss import pairwise_loss
1610from easy_rec .python .protos .loss_pb2 import LossType
1711
18- from easy_rec .python .loss .zero_inflated_lognormal import zero_inflated_lognormal_loss # NOQA
19-
20- from easy_rec .python .loss .f1_reweight_loss import f1_reweight_sigmoid_cross_entropy # NOQA
12+ from easy_rec .python .loss .f1_reweight_loss import ( # NOQA
13+ f1_reweight_sigmoid_cross_entropy ,)
14+ from easy_rec .python .loss .listwise_loss import ( # NOQA
15+ listwise_distill_loss , listwise_rank_loss ,
16+ )
17+ from easy_rec .python .loss .pairwise_loss import ( # NOQA
18+ pairwise_focal_loss , pairwise_hinge_loss , pairwise_logistic_loss ,
19+ pairwise_loss ,
20+ )
21+ from easy_rec .python .loss .zero_inflated_lognormal import ( # NOQA
22+ zero_inflated_lognormal_loss ,)
2123
2224if tf .__version__ >= '2.0' :
2325 tf = tf .compat .v1
@@ -36,8 +38,10 @@ def build(loss_type,
3638 return tf .losses .sigmoid_cross_entropy (
3739 label , logits = pred , weights = loss_weight , ** kwargs )
3840 else :
39- assert label .dtype in [tf .int32 , tf .int64 ], \
40- 'label.dtype must in [tf.int32, tf.int64] when use sparse_softmax_cross_entropy.'
41+ assert label .dtype in [
42+ tf .int32 ,
43+ tf .int64 ,
44+ ], 'label.dtype must in [tf.int32, tf.int64] when use sparse_softmax_cross_entropy.'
4145 return tf .losses .sparse_softmax_cross_entropy (
4246 labels = label , logits = pred , weights = loss_weight , ** kwargs )
4347 elif loss_type == LossType .CROSS_ENTROPY_LOSS :
@@ -50,7 +54,23 @@ def build(loss_type,
5054 return tf .losses .mean_squared_error (
5155 labels = label , predictions = pred , weights = loss_weight , ** kwargs )
5256 elif loss_type == LossType .ZILN_LOSS :
53- loss = zero_inflated_lognormal_loss (label , pred )
57+ if loss_param is None :
58+ loss = zero_inflated_lognormal_loss (label , pred )
59+ else :
60+ mu_reg = loss_param .mu_regularization
61+ sigma_reg = loss_param .sigma_regularization
62+ max_sigma = loss_param .max_sigma
63+ class_weight = loss_param .classification_weight
64+ reg_weight = loss_param .regression_weight
65+ loss = zero_inflated_lognormal_loss (
66+ label ,
67+ pred ,
68+ max_sigma = max_sigma ,
69+ mu_reg = mu_reg ,
70+ sigma_reg = sigma_reg ,
71+ class_weight = class_weight ,
72+ reg_weight = reg_weight ,
73+ )
5474 if np .isscalar (loss_weight ) and loss_weight != 1.0 :
5575 return loss * loss_weight
5676 return loss
@@ -219,9 +239,9 @@ def build_kd_loss(kds, prediction_dict, label_dict, feature_dict):
219239 """
220240 loss_dict = {}
221241 for kd in kds :
222- assert kd .pred_name in prediction_dict , \
223- 'invalid predict_name: %s available ones: %s' % (
224- kd . pred_name , ',' .join (prediction_dict .keys ()))
242+ assert kd .pred_name in prediction_dict , 'invalid predict_name: %s available ones: %s' % (
243+ kd . pred_name ,
244+ ',' .join (prediction_dict .keys ()))
225245
226246 loss_name = kd .loss_name
227247 if not loss_name :
@@ -232,8 +252,10 @@ def build_kd_loss(kds, prediction_dict, label_dict, feature_dict):
232252 if kd .HasField ('task_space_indicator_name' ) and kd .HasField (
233253 'task_space_indicator_value' ):
234254 in_task_space = tf .to_float (
235- tf .equal (feature_dict [kd .task_space_indicator_name ],
236- kd .task_space_indicator_value ))
255+ tf .equal (
256+ feature_dict [kd .task_space_indicator_name ],
257+ kd .task_space_indicator_value ,
258+ ))
237259 loss_weight = loss_weight * (
238260 kd .in_task_space_weight * in_task_space + kd .out_task_space_weight *
239261 (1 - in_task_space ))
0 commit comments