@@ -347,36 +347,60 @@ struct SmoothStepScheduler : SigmaScheduler {
347347 }
348348};
349349
350- // Implementation adapted from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
350+ // KL Optimal: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
351351struct KLOptimalScheduler : SigmaScheduler {
352352 std::vector<float > get_sigmas (uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
353353 std::vector<float > sigmas;
354354
355355 if (n == 0 ) {
356356 return sigmas;
357357 }
358- if (n == 1 ) {
359- sigmas.push_back (sigma_max);
360- sigmas.push_back (0 .0f );
361- return sigmas;
362- }
363358
364- float alpha_min = std::atan (sigma_min);
365- float alpha_max = std::atan (sigma_max);
359+ if (sigma_min <= 1e-6f ) {
360+ /* sigma_min = 0:
361+ * implemented using
362+ * https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
363+ * as a reference.
364+ */
365+
366+ float alpha_min = std::atan (sigma_min);
367+ float alpha_max = std::atan (sigma_max);
366368
367- for (uint32_t i = 0 ; i < n; ++i) {
368- // t goes from 0.0 to 1.0
369- float t = static_cast <float >(i) / static_cast <float >(n - 1 );
369+ for (uint32_t i = 0 ; i <= n; ++i) {
370370
371- // Interpolate in the angle domain
372- float angle = t * alpha_min + (1 .0f - t) * alpha_max;
371+ float t = static_cast <float >(i) / static_cast <float >(n);
373372
374- // Convert back to sigma
375- sigmas.push_back (std::tan (angle));
376- }
373+ float angle = t * alpha_min + (1 .0f - t) * alpha_max;
377374
378- // Append the final zero to sigma
379- sigmas.push_back (0 .0f );
375+ sigmas.push_back (std::tan (angle));
376+ }
377+
378+ } else {
379+ /* sigma_min != 0:
380+ * implemented using
381+ * https://github.com/comfyanonymous/ComfyUI/pull/6206
382+ * as a reference.
383+ */
384+
385+ if (n == 1 ) {
386+ sigmas.push_back (sigma_max);
387+ sigmas.push_back (0 .0f );
388+ return sigmas;
389+ }
390+
391+ float alpha_min = std::atan (sigma_min);
392+ float alpha_max = std::atan (sigma_max);
393+
394+ for (uint32_t i = 0 ; i < n; ++i) {
395+
396+ float t = static_cast <float >(i) / static_cast <float >(n - 1 );
397+
398+ float angle = t * alpha_min + (1 .0f - t) * alpha_max;
399+ sigmas.push_back (std::tan (angle));
400+ }
401+
402+ sigmas.push_back (0 .0f );
403+ }
380404
381405 return sigmas;
382406 }
0 commit comments