Compute distillation loss over hard and soft targets #7
```python
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
# Soft loss (distillation)
teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
```
Shouldn't there be two different temperatures for `student_log_probs` (a high temperature for the soft targets and 1 for the hard targets)?
From Hinton et al., page 3:

> [...] we found that a better way is to simply use a weighted average of two different objective functions. The first objective function is the cross entropy with the soft targets and this cross entropy is computed using the same high temperature in the softmax of the distilled model as was used for generating the soft targets from the cumbersome model. The second objective function is the cross entropy with the correct labels. This is computed using exactly the same logits in softmax of the distilled model but at a temperature of 1.
My bad, just realized you're already doing that; I didn't notice the `F.kl_div(student_log_probs, ..., log_target=True)` for the soft loss vs. `F.cross_entropy(student_logits_flat, ...)` for the hard loss.
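For reference, a minimal sketch of the combined objective being discussed, following the Hinton et al. recipe: KL divergence between temperature-softened distributions for the soft loss, plain cross entropy at temperature 1 for the hard loss, weighted by alpha. The function name, default values, and the `T**2` rescaling convention are assumptions for illustration, not the PR's exact code:

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      temperature=2.0, alpha=0.9):
    """Weighted average of soft (KL at temperature T) and hard (CE at T=1) losses.

    Hypothetical sketch of the Hinton et al. objective; names are illustrative.
    """
    # Soft targets: both distributions softened with the same high temperature.
    teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    # T**2 rescaling keeps soft-gradient magnitudes comparable as T changes.
    soft_loss = F.kl_div(
        student_log_probs, teacher_log_probs,
        reduction="batchmean", log_target=True,
    ) * (temperature ** 2)
    # Hard targets: cross entropy on the same logits at temperature 1.
    hard_loss = F.cross_entropy(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss
```

With `alpha=1.0` the hard-target CE term drops out entirely, and with `alpha=0.0` the function reduces to ordinary cross entropy on the labels.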
🔥
I guess the alternative to using the CLI is to call the Python code directly from the notebook? Do you have a preference?
@yonromai I pushed some more stuff here:
One small change sneaked into the logic: if alpha is 1.0 (i.e. the hard-target CE loss is ignored), we do not scale by …
Follow the *Distilling the Knowledge in a Neural Network* distillation loss, incorporating both hard and soft targets, with a "considerably lower weight" (alpha) on the hard-target CE (actual weight TBD). Additionally, allow setting max epochs via the CLI. I don't know if we want to keep the CLI interfaces; if we do, we should probably switch to Typer or something similar.
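If we do go the Typer route, the `--max-epochs` option could look something like this minimal sketch (the command name and help text are placeholders, not a proposal for the actual interface):

```python
import typer

app = typer.Typer()

@app.command()
def train(
    max_epochs: int = typer.Option(10, help="Maximum number of training epochs"),
):
    """Hypothetical training entry point; wiring to the trainer is omitted."""
    typer.echo(f"Training for up to {max_epochs} epochs")

if __name__ == "__main__":
    app()
```

Typer derives the `--max-epochs` flag and its type validation from the annotated default, so the argparse-style boilerplate goes away.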