Skip to content

Commit f454181

Browse files
committed
feat(cli): add --model-type parameter for manual version override
Adds a new --model-type CLI parameter that allows users to manually specify the model version instead of relying on auto-detection.

This is useful when:
- Auto-detection fails or is ambiguous
- Testing model behavior with different version settings
- Working with modified/custom models

Usage:
  --model-type sdxl   # Force SDXL version
  --model-type sd1    # Force SD 1.x version
  --model-type flux   # Force FLUX version

Supported values (matched case-insensitively): sd1, sd1_inpaint, sd2, sd2_inpaint, sdxl, sdxl_inpaint, sdxl_pix2pix, flux, sd3, svd. An unrecognized value prints a warning and falls back to auto-detection.

Implementation:
- Added version_override field to sd_ctx_params_t struct
- Added model_type string parameter to SDContextParams
- Added string-to-enum conversion in to_sd_ctx_params_t()
- Updated model loading to check for manual override before auto-detection
- Auto-detection still works when --model-type is not specified

Testing:
- Tested manual override with --model-type sdxl (works)
- Tested auto-detection without parameter (still works)
- Tested with SD 1.5 model and --model-type sd1 (works)

Files changed:
- stable-diffusion.h: Added version_override field to sd_ctx_params_t
- stable-diffusion.cpp: Added version override logic and initialization
- examples/common/common.hpp: Added CLI parameter and string-to-enum conversion
1 parent f6ae111 commit f454181

File tree

3 files changed

+57
-9
lines changed

3 files changed

+57
-9
lines changed

examples/common/common.hpp

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11

2+
#include <algorithm>
23
#include <filesystem>
34
#include <iostream>
45
#include <map>
@@ -18,6 +19,7 @@ namespace fs = std::filesystem;
1819
#endif // _WIN32
1920

2021
#include "stable-diffusion.h"
22+
#include "model.h" // For SDVersion enum
2123

2224
#define STB_IMAGE_IMPLEMENTATION
2325
#define STB_IMAGE_STATIC
@@ -443,6 +445,7 @@ struct SDContextParams {
443445
std::string control_net_path;
444446
std::string embedding_dir;
445447
std::string photo_maker_path;
448+
std::string model_type; // Manual model version override (sd1, sd2, sdxl, flux, etc.)
446449
sd_type_t wtype = SD_TYPE_COUNT;
447450
std::string tensor_type_rules;
448451
std::string lora_model_dir = ".";
@@ -487,6 +490,10 @@ struct SDContextParams {
487490
"--model",
488491
"path to full model",
489492
&model_path},
493+
{"",
494+
"--model-type",
495+
"force model type (sd1, sd2, sdxl, flux, sdxl_inpaint, etc). Auto-detect if not specified.",
496+
&model_type},
490497
{"",
491498
"--clip_l",
492499
"path to the clip-l text encoder", &clip_l_path},
@@ -944,6 +951,38 @@ struct SDContextParams {
944951
embedding_vec.emplace_back(item);
945952
}
946953

954+
// Parse model_type string to SDVersion enum
955+
int version_override = VERSION_COUNT; // Auto-detect by default
956+
if (!model_type.empty()) {
957+
std::string mt = model_type;
958+
// Convert to lowercase for case-insensitive matching
959+
std::transform(mt.begin(), mt.end(), mt.begin(), ::tolower);
960+
961+
if (mt == "sd1" || mt == "sd1.5" || mt == "sd1.x") {
962+
version_override = VERSION_SD1;
963+
} else if (mt == "sd1_inpaint") {
964+
version_override = VERSION_SD1_INPAINT;
965+
} else if (mt == "sd2" || mt == "sd2.0" || mt == "sd2.1" || mt == "sd2.x") {
966+
version_override = VERSION_SD2;
967+
} else if (mt == "sd2_inpaint") {
968+
version_override = VERSION_SD2_INPAINT;
969+
} else if (mt == "sdxl" || mt == "sdxl1.0") {
970+
version_override = VERSION_SDXL;
971+
} else if (mt == "sdxl_inpaint") {
972+
version_override = VERSION_SDXL_INPAINT;
973+
} else if (mt == "sdxl_pix2pix") {
974+
version_override = VERSION_SDXL_PIX2PIX;
975+
} else if (mt == "flux" || mt == "flux1") {
976+
version_override = VERSION_FLUX;
977+
} else if (mt == "sd3" || mt == "sd3.5") {
978+
version_override = VERSION_SD3;
979+
} else if (mt == "svd") {
980+
version_override = VERSION_SVD;
981+
} else {
982+
fprintf(stderr, "Warning: Unknown model type '%s', using auto-detect\n", model_type.c_str());
983+
}
984+
}
985+
947986
sd_ctx_params_t sd_ctx_params = {
948987
model_path.c_str(),
949988
clip_l_path.c_str(),
@@ -969,6 +1008,7 @@ struct SDContextParams {
9691008
sampler_rng_type,
9701009
prediction,
9711010
lora_apply_mode,
1011+
version_override, // Add version_override parameter
9721012
offload_params_to_cpu,
9731013
enable_mmap,
9741014
clip_on_cpu,

include/stable-diffusion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ typedef struct {
184184
enum rng_type_t sampler_rng_type;
185185
enum prediction_t prediction;
186186
enum lora_apply_mode_t lora_apply_mode;
187+
int version_override; // SDVersion enum value, VERSION_COUNT = auto-detect
187188
bool offload_params_to_cpu;
188189
bool enable_mmap;
189190
bool keep_clip_on_cpu;

src/stable-diffusion.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -326,16 +326,22 @@ class StableDiffusionGGML {
326326

327327
model_loader.convert_tensors_name();
328328

329-
// SDXL_FIX: Don't overwrite if already detected as SDXL in earlier component
330-
SDVersion detected_version = model_loader.get_sd_version();
331-
if (version != VERSION_SDXL && version != VERSION_SDXL_INPAINT && version != VERSION_SDXL_PIX2PIX) {
332-
version = detected_version;
329+
// Check for manual version override first
330+
if (sd_ctx_params->version_override != VERSION_COUNT) {
331+
version = (SDVersion)sd_ctx_params->version_override;
332+
LOG_INFO("Version overridden to: %s", model_version_to_str[version]);
333333
} else {
334-
LOG_INFO("Keeping previous SDXL version, detected version: %s", model_version_to_str[detected_version]);
335-
}
336-
if (version == VERSION_COUNT) {
337-
LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path));
338-
return false;
334+
// Auto-detect version - don't overwrite if already detected as SDXL in earlier component
335+
SDVersion detected_version = model_loader.get_sd_version();
336+
if (version != VERSION_SDXL && version != VERSION_SDXL_INPAINT && version != VERSION_SDXL_PIX2PIX) {
337+
version = detected_version;
338+
} else {
339+
LOG_INFO("Keeping previous SDXL version, detected version: %s", model_version_to_str[detected_version]);
340+
}
341+
if (version == VERSION_COUNT) {
342+
LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path));
343+
return false;
344+
}
339345
}
340346

341347
auto& tensor_storage_map = model_loader.get_tensor_storage_map();
@@ -2925,6 +2931,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
29252931
sd_ctx_params->sampler_rng_type = RNG_TYPE_COUNT;
29262932
sd_ctx_params->prediction = PREDICTION_COUNT;
29272933
sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
2934+
sd_ctx_params->version_override = VERSION_COUNT; // Auto-detect
29282935
sd_ctx_params->offload_params_to_cpu = false;
29292936
sd_ctx_params->enable_mmap = false;
29302937
sd_ctx_params->keep_clip_on_cpu = false;

0 commit comments

Comments
 (0)