Skip to content

Commit a0798c9

Browse files
committed
fix: fix kl_optimal when sigma_min is 0
1 parent 3e81246 commit a0798c9

File tree

1 file changed

+42
-18
lines changed

1 file changed

+42
-18
lines changed

denoiser.hpp

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -347,36 +347,60 @@ struct SmoothStepScheduler : SigmaScheduler {
347347
}
348348
};
349349

350-
// Implementation adapted from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
350+
// KL Optimal: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
351351
struct KLOptimalScheduler : SigmaScheduler {
352352
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
353353
std::vector<float> sigmas;
354354

355355
if (n == 0) {
356356
return sigmas;
357357
}
358-
if (n == 1) {
359-
sigmas.push_back(sigma_max);
360-
sigmas.push_back(0.0f);
361-
return sigmas;
362-
}
363358

364-
float alpha_min = std::atan(sigma_min);
365-
float alpha_max = std::atan(sigma_max);
359+
if (sigma_min <= 1e-6f) {
360+
/* sigma_min = 0:
361+
* implemented using
362+
* https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
363+
* as a reference.
364+
*/
365+
366+
float alpha_min = std::atan(sigma_min);
367+
float alpha_max = std::atan(sigma_max);
366368

367-
for (uint32_t i = 0; i < n; ++i) {
368-
// t goes from 0.0 to 1.0
369-
float t = static_cast<float>(i) / static_cast<float>(n - 1);
369+
for (uint32_t i = 0; i <= n; ++i) {
370370

371-
// Interpolate in the angle domain
372-
float angle = t * alpha_min + (1.0f - t) * alpha_max;
371+
float t = static_cast<float>(i) / static_cast<float>(n);
373372

374-
// Convert back to sigma
375-
sigmas.push_back(std::tan(angle));
376-
}
373+
float angle = t * alpha_min + (1.0f - t) * alpha_max;
377374

378-
// Append the final zero to sigma
379-
sigmas.push_back(0.0f);
375+
sigmas.push_back(std::tan(angle));
376+
}
377+
378+
} else {
379+
/* sigma_min != 0:
380+
* implemented using
381+
* https://github.com/comfyanonymous/ComfyUI/pull/6206
382+
* as a reference.
383+
*/
384+
385+
if (n == 1) {
386+
sigmas.push_back(sigma_max);
387+
sigmas.push_back(0.0f);
388+
return sigmas;
389+
}
390+
391+
float alpha_min = std::atan(sigma_min);
392+
float alpha_max = std::atan(sigma_max);
393+
394+
for (uint32_t i = 0; i < n; ++i) {
395+
396+
float t = static_cast<float>(i) / static_cast<float>(n - 1);
397+
398+
float angle = t * alpha_min + (1.0f - t) * alpha_max;
399+
sigmas.push_back(std::tan(angle));
400+
}
401+
402+
sigmas.push_back(0.0f);
403+
}
380404

381405
return sigmas;
382406
}

0 commit comments

Comments
 (0)