Skip to content

Commit 75c0b7f

Browse files
committed
feat: add support for Segmind-Vega model
1 parent 7010bb4 commit 75c0b7f

File tree

4 files changed: +15 −7 lines changed

model.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,7 +1038,7 @@ SDVersion ModelLoader::get_sd_version() {
10381038
int64_t patch_embedding_channels = 0;
10391039
bool has_img_emb = false;
10401040
bool has_middle_block_1 = false;
1041-
bool has_output_block_71 = false;
1041+
bool has_output_block_311 = false;
10421042

10431043
for (auto& [name, tensor_storage] : tensor_storage_map) {
10441044
if (!(is_xl)) {
@@ -1095,8 +1095,8 @@ SDVersion ModelLoader::get_sd_version() {
10951095
tensor_storage.name.find("unet.mid_block.resnets.1.") != std::string::npos) {
10961096
has_middle_block_1 = true;
10971097
}
1098-
if (tensor_storage.name.find("model.diffusion_model.output_blocks.7.1") != std::string::npos) {
1099-
has_output_block_71 = true;
1098+
if (tensor_storage.name.find("model.diffusion_model.output_blocks.3.1.transformer_blocks.1") != std::string::npos) {
1099+
has_output_block_311 = true;
11001100
}
11011101
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
11021102
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
@@ -1133,6 +1133,9 @@ SDVersion ModelLoader::get_sd_version() {
11331133
return VERSION_SDXL_PIX2PIX;
11341134
}
11351135
if (!has_middle_block_1) {
1136+
if (!has_output_block_311) {
1137+
return VERSION_SDXL_VEGA;
1138+
}
11361139
return VERSION_SDXL_SSD1B;
11371140
}
11381141
return VERSION_SDXL;
@@ -1159,7 +1162,7 @@ SDVersion ModelLoader::get_sd_version() {
11591162
return VERSION_SD1_PIX2PIX;
11601163
}
11611164
if (!has_middle_block_1) {
1162-
if (!has_output_block_71) {
1165+
if (!has_output_block_311) {
11631166
return VERSION_SDXS;
11641167
}
11651168
return VERSION_SD1_TINY_UNET;

model.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ enum SDVersion {
3232
VERSION_SDXL,
3333
VERSION_SDXL_INPAINT,
3434
VERSION_SDXL_PIX2PIX,
35+
VERSION_SDXL_VEGA,
3536
VERSION_SDXL_SSD1B,
3637
VERSION_SVD,
3738
VERSION_SD3,
@@ -65,7 +66,7 @@ static inline bool sd_version_is_sd2(SDVersion version) {
6566
}
6667

6768
static inline bool sd_version_is_sdxl(SDVersion version) {
68-
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B) {
69+
if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX || version == VERSION_SDXL_SSD1B || version == VERSION_SDXL_VEGA) {
6970
return true;
7071
}
7172
return false;

stable-diffusion.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ const char* model_version_to_str[] = {
3535
"SDXL",
3636
"SDXL Inpaint",
3737
"SDXL Instruct-Pix2Pix",
38+
"SDXL (Vega)",
3839
"SDXL (SSD1B)",
3940
"SVD",
4041
"SD3.x",

unet.hpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,9 @@ class UnetModelBlock : public GGMLBlock {
201201
num_head_channels = 64;
202202
num_heads = -1;
203203
use_linear_projection = true;
204+
if (version == VERSION_SDXL_VEGA) {
205+
transformer_depth = {1, 1, 2};
206+
}
204207
} else if (version == VERSION_SVD) {
205208
in_channels = 8;
206209
out_channels = 4;
@@ -319,7 +322,7 @@ class UnetModelBlock : public GGMLBlock {
319322
}
320323
if (!tiny_unet) {
321324
blocks["middle_block.0"] = std::shared_ptr<GGMLBlock>(get_resblock(ch, time_embed_dim, ch));
322-
if (version != VERSION_SDXL_SSD1B) {
325+
if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
323326
blocks["middle_block.1"] = std::shared_ptr<GGMLBlock>(get_attention_layer(ch,
324327
n_head,
325328
d_head,
@@ -520,7 +523,7 @@ class UnetModelBlock : public GGMLBlock {
520523
// middle_block
521524
if (!tiny_unet) {
522525
h = resblock_forward("middle_block.0", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
523-
if (version != VERSION_SDXL_SSD1B) {
526+
if (version != VERSION_SDXL_SSD1B && version != VERSION_SDXL_VEGA) {
524527
h = attention_layer_forward("middle_block.1", ctx, h, context, num_video_frames); // [N, 4*model_channels, h/8, w/8]
525528
h = resblock_forward("middle_block.2", ctx, h, emb, num_video_frames); // [N, 4*model_channels, h/8, w/8]
526529
}

0 commit comments

Comments
 (0)