
Compute distillation loss over hard and soft targets #7

Merged
ravwojdyla merged 6 commits into main from rav-fixup-distill-loss on Jun 19, 2025

Conversation

Contributor

@ravwojdyla ravwojdyla commented Jun 18, 2025

Following the distillation loss from Distilling the Knowledge in a Neural Network, incorporate both hard and soft targets, with a "considerably lower weight" (alpha) on the hard-target CE (actual weight TBD).

Additionally, allow setting the max number of epochs via the CLI. I don't know if we want to keep the CLI interfaces; if we do, we should probably switch to Typer or something similar.
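For reference, a minimal, framework-free sketch of the combined objective described above. All names and default values here are illustrative, not the PR's actual code, and I'm assuming the convention (consistent with the later comments) that alpha weighs the soft term, so alpha = 1.0 drops the hard-target CE entirely:

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a plain list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def distillation_loss(student_logits, teacher_logits, hard_label,
                      temperature=4.0, alpha=0.9):
    """Weighted average of the soft (distillation) and hard objectives.

    alpha weighs the soft term; (1 - alpha) puts a "considerably lower
    weight" on the hard-target cross entropy.
    """
    # Soft term: KL(teacher || student), both at the same high temperature.
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    soft = sum(pt * (math.log(pt) - math.log(ps))
               for pt, ps in zip(p_t, p_s))
    # Hard term: cross entropy with the true label, student logits at T=1.
    hard = -math.log(softmax(student_logits, 1.0)[hard_label])
    # temperature**2 keeps the soft-term gradient magnitudes comparable
    # to the hard term (Hinton et al., Sec. 2).
    return alpha * (temperature ** 2) * soft + (1.0 - alpha) * hard
```

When the student's logits equal the teacher's, the KL term vanishes and only the hard-target CE contributes.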

@ravwojdyla ravwojdyla force-pushed the rav-fixup-distill-loss branch from f31ff96 to e62efce Compare June 18, 2025 08:54
@ravwojdyla ravwojdyla requested a review from yonromai June 18, 2025 08:56
# Soft loss (distillation)
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
teacher_log_probs = F.log_softmax(teacher_logits / temperature, dim=-1)
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
Contributor

@yonromai yonromai Jun 18, 2025


Shouldn't there be two different temperatures for student_log_probs (the high temperature for the soft targets and temperature 1 for the hard targets)?

From Hinton, page 3:

[...] we found that a better way is to simply use a weighted average of two different objective functions. The first objective function is the cross entropy with the soft targets and this cross entropy is computed using the same high temperature in the softmax of the distilled model as was used for generating the soft targets from the cumbersome model. The second objective function is the cross entropy with the correct labels. This is computed using exactly the same logits in softmax of the distilled model but at a temperature of 1.
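For intuition, here is what the same logits look like at the two temperatures the quote describes (a plain-Python illustration; the values are made up):

```python
import math

def softmax(logits, temperature):
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    s = sum(exps)
    return [e / s for e in exps]

logits = [4.0, 1.0, 0.0]       # one set of student logits
sharp = softmax(logits, 1.0)   # temperature 1: used for the hard-label CE
soft = softmax(logits, 8.0)    # high temperature: used for the soft targets
# The higher temperature flattens the distribution, preserving the
# relative information carried by the smaller logits.
```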

Contributor


My bad, just realized you're already doing that; I didn't notice the F.kl_div(student_log_probs, ..., log_target=True) for the soft term vs. F.cross_entropy(student_logits_flat, ...) on the raw (temperature 1) logits for the hard term.

@yonromai
Contributor

Following the distillation loss from Distilling the Knowledge in a Neural Network, incorporate both hard and soft targets, with a "considerably lower weight" (alpha) on the hard-target CE (actual weight TBD).

🔥
Perhaps alpha should be a CLI arg?

Additionally, allow setting the max number of epochs via the CLI. I don't know if we want to keep the CLI interfaces; if we do, we should probably switch to Typer or something similar.

I guess the alternative to using the CLI is to call the Python code directly from the notebook? Do you have a preference?
(+1 to switching to Typer if we keep using the CLI)

Contributor

@yonromai yonromai left a comment


🔥🔥

Base automatically changed from rav-set-limit to main June 18, 2025 16:06
@ravwojdyla ravwojdyla force-pushed the rav-fixup-distill-loss branch from aec73f2 to 53501e6 Compare June 19, 2025 00:45
@ravwojdyla
Contributor Author

ravwojdyla commented Jun 19, 2025

@yonromai I pushed some more stuff here:

  • expose the distillation alpha parameter in the CLI
  • add distillation loss unit tests
    • includes the tests we discussed: $$temp \to \infty$$ and $$temp \to 0$$
  • run unit tests in the CI

One small change sneaked into the logic: if alpha is 1.0 (i.e. the hard-target CE loss is ignored), we do not scale by $$temp^2$$.
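A hedged sketch of what the limiting-case tests can check (the helper names are mine, not the PR's actual test code): as $$temp \to \infty$$ both softened distributions flatten toward uniform and the KL term vanishes, and as $$temp \to 0$$ both collapse to one-hot distributions, so the KL also goes to zero when the student's and teacher's argmaxes agree:

```python
import math

def softmax(logits, temperature):
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    s = sum(exps)
    return [e / s for e in exps]

def kl(p, q):
    """KL(p || q); terms where p_i == 0 contribute nothing."""
    return sum(pi * (math.log(pi) - math.log(qi))
               for pi, qi in zip(p, q) if pi > 0)

teacher = [3.0, 1.0, -1.0]
student = [2.0, 1.5, -0.5]  # same argmax as the teacher

# temp -> infinity: both distributions approach uniform, so KL -> 0
hot = kl(softmax(teacher, 1e6), softmax(student, 1e6))

# temp -> 0: both collapse to one-hot on their argmax; KL -> 0 here
# because the argmaxes agree (it would diverge if they differed)
cold = kl(softmax(teacher, 1e-3), softmax(student, 1e-3))
```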

@ravwojdyla ravwojdyla merged commit 5298c17 into main Jun 19, 2025
2 checks passed
@ravwojdyla ravwojdyla deleted the rav-fixup-distill-loss branch June 19, 2025 00:53
