Skip to content

Commit 9565c7f

Browse files
authored
add support for flux2 klein (#1193)
* add support for flux2 klein 4b * add support for flux2 klein 9b * use attention_mask in Flux.2 klein LLMEmbedder * update docs
1 parent fbce16e commit 9565c7f

15 files changed

+197
-42
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ API and command-line option may change frequently.***
4343
- SDXL, [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo)
4444
- [Some SD1.x and SDXL distilled models](./docs/distilled_sd.md)
4545
- [SD3/SD3.5](./docs/sd3.md)
46-
- [FlUX.1-dev/FlUX.1-schnell](./docs/flux.md)
47-
- [FLUX.2-dev](./docs/flux2.md)
46+
- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
47+
- [FLUX.2-dev/FLUX.2-klein](./docs/flux2.md)
4848
- [Chroma](./docs/chroma.md)
4949
- [Chroma1-Radiance](./docs/chroma_radiance.md)
5050
- [Qwen Image](./docs/qwen_image.md)
@@ -127,8 +127,8 @@ If you want to improve performance or reduce VRAM/RAM usage, please refer to [pe
127127

128128
- [SD1.x/SD2.x/SDXL](./docs/sd.md)
129129
- [SD3/SD3.5](./docs/sd3.md)
130-
- [FlUX.1-dev/FlUX.1-schnell](./docs/flux.md)
131-
- [FLUX.2-dev](./docs/flux2.md)
130+
- [FLUX.1-dev/FLUX.1-schnell](./docs/flux.md)
131+
- [FLUX.2-dev/FLUX.2-klein](./docs/flux2.md)
132132
- [FLUX.1-Kontext-dev](./docs/kontext.md)
133133
- [Chroma](./docs/chroma.md)
134134
- [🔥Qwen Image](./docs/qwen_image.md)
510 KB
Loading

assets/flux2/flux2-klein-4b.png

455 KB
Loading
511 KB
Loading

assets/flux2/flux2-klein-9b.png

491 KB
Loading
464 KB
Loading
552 KB
Loading

conditioner.hpp

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1614,9 +1614,9 @@ struct LLMEmbedder : public Conditioner {
16141614
bool enable_vision = false)
16151615
: version(version) {
16161616
LLM::LLMArch arch = LLM::LLMArch::QWEN2_5_VL;
1617-
if (sd_version_is_flux2(version)) {
1617+
if (version == VERSION_FLUX2) {
16181618
arch = LLM::LLMArch::MISTRAL_SMALL_3_2;
1619-
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE) {
1619+
} else if (sd_version_is_z_image(version) || version == VERSION_OVIS_IMAGE || version == VERSION_FLUX2_KLEIN) {
16201620
arch = LLM::LLMArch::QWEN3;
16211621
}
16221622
if (arch == LLM::LLMArch::MISTRAL_SMALL_3_2) {
@@ -1708,6 +1708,9 @@ struct LLMEmbedder : public Conditioner {
17081708
int prompt_template_encode_start_idx = 34;
17091709
int max_length = 0;
17101710
std::set<int> out_layers;
1711+
std::vector<int> tokens;
1712+
std::vector<float> weights;
1713+
std::vector<float> mask;
17111714
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
17121715
LOG_INFO("QwenImageEditPlusPipeline");
17131716
prompt_template_encode_start_idx = 64;
@@ -1771,7 +1774,7 @@ struct LLMEmbedder : public Conditioner {
17711774
prompt_attn_range.second = static_cast<int>(prompt.size());
17721775

17731776
prompt += "<|im_end|>\n<|im_start|>assistant\n";
1774-
} else if (sd_version_is_flux2(version)) {
1777+
} else if (version == VERSION_FLUX2) {
17751778
prompt_template_encode_start_idx = 0;
17761779
out_layers = {10, 20, 30};
17771780

@@ -1793,17 +1796,28 @@ struct LLMEmbedder : public Conditioner {
17931796
prompt_attn_range.second = static_cast<int>(prompt.size());
17941797

17951798
prompt += "<|im_end|>\n<|im_start|>assistant\n";
1796-
} else if (sd_version_is_flux2(version)) {
1799+
} else if (version == VERSION_FLUX2_KLEIN) {
17971800
prompt_template_encode_start_idx = 0;
1798-
out_layers = {10, 20, 30};
1801+
max_length = 512;
1802+
out_layers = {9, 18, 27};
17991803

1800-
prompt = "[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]";
1804+
prompt = "<|im_start|>user\n";
18011805

18021806
prompt_attn_range.first = static_cast<int>(prompt.size());
18031807
prompt += conditioner_params.text;
18041808
prompt_attn_range.second = static_cast<int>(prompt.size());
18051809

1806-
prompt += "[/INST]";
1810+
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
1811+
1812+
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false);
1813+
tokens = std::get<0>(tokens_and_weights);
1814+
weights = std::get<1>(tokens_and_weights);
1815+
1816+
mask.insert(mask.end(), tokens.size(), 1.f);
1817+
if (tokens.size() < max_length) {
1818+
mask.insert(mask.end(), max_length - tokens.size(), 0.f);
1819+
tokenizer->pad_tokens(tokens, weights, max_length, true);
1820+
}
18071821
} else if (version == VERSION_OVIS_IMAGE) {
18081822
prompt_template_encode_start_idx = 28;
18091823
max_length = prompt_template_encode_start_idx + 256;
@@ -1827,17 +1841,34 @@ struct LLMEmbedder : public Conditioner {
18271841
prompt += "<|im_end|>\n<|im_start|>assistant\n";
18281842
}
18291843

1830-
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
1831-
auto& tokens = std::get<0>(tokens_and_weights);
1832-
auto& weights = std::get<1>(tokens_and_weights);
1844+
if (tokens.empty()) {
1845+
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0);
1846+
tokens = std::get<0>(tokens_and_weights);
1847+
weights = std::get<1>(tokens_and_weights);
1848+
}
18331849

18341850
int64_t t0 = ggml_time_ms();
18351851
struct ggml_tensor* hidden_states = nullptr; // [N, n_token, 3584]
18361852

18371853
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
18381854

1855+
ggml_tensor* attention_mask = nullptr;
1856+
if (!mask.empty()) {
1857+
attention_mask = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, mask.size(), mask.size());
1858+
ggml_ext_tensor_iter(attention_mask, [&](ggml_tensor* attention_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
1859+
float value = 0.f;
1860+
if (mask[i0] == 0.f) {
1861+
value = -INFINITY;
1862+
} else if (i0 > i1) {
1863+
value = -INFINITY;
1864+
}
1865+
ggml_ext_tensor_set_f32(attention_mask, value, i0, i1, i2, i3);
1866+
});
1867+
}
1868+
18391869
llm->compute(n_threads,
18401870
input_ids,
1871+
attention_mask,
18411872
image_embeds,
18421873
out_layers,
18431874
&hidden_states,
@@ -1861,7 +1892,7 @@ struct LLMEmbedder : public Conditioner {
18611892
GGML_ASSERT(hidden_states->ne[1] > prompt_template_encode_start_idx);
18621893

18631894
int64_t min_length = 0;
1864-
if (sd_version_is_flux2(version)) {
1895+
if (version == VERSION_FLUX2) {
18651896
min_length = 512;
18661897
}
18671898

docs/flux2.md

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# How to Use
22

3-
## Download weights
3+
## Flux.2-dev
4+
5+
### Download weights
46

57
- Download FLUX.2-dev
68
- gguf: https://huggingface.co/city96/FLUX.2-dev-gguf/tree/main
@@ -9,13 +11,82 @@
911
- Download Mistral-Small-3.2-24B-Instruct-2506-GGUF
1012
- gguf: https://huggingface.co/unsloth/Mistral-Small-3.2-24B-Instruct-2506-GGUF/tree/main
1113

12-
## Examples
14+
### Examples
1315

1416
```
1517
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux2-dev-Q4_K_S.gguf --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\Mistral-Small-3.2-24B-Instruct-2506-Q4_K_M.gguf -r .\kontext_input.png -p "change 'flux.cpp' to 'flux2-dev.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu
1618
```
1719

1820
<img alt="flux2 example" src="../assets/flux2/example.png" />
1921

22+
## Flux.2 klein 4B / Flux.2 klein base 4B
23+
24+
### Download weights
25+
26+
- Download FLUX.2-klein-4B
27+
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-4B
28+
- gguf: https://huggingface.co/leejet/FLUX.2-klein-4B-GGUF/tree/main
29+
- Download FLUX.2-klein-base-4B
30+
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-base-4B
31+
- gguf: https://huggingface.co/leejet/FLUX.2-klein-base-4B-GGUF/tree/main
32+
- Download vae
33+
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
34+
- Download Qwen3 4B
35+
- safetensors: https://huggingface.co/Comfy-Org/flux2-klein-4B/tree/main/split_files/text_encoders
36+
- gguf: https://huggingface.co/unsloth/Qwen3-4B-GGUF/tree/main
37+
38+
### Examples
39+
40+
```
41+
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 4 -v --offload-to-cpu --diffusion-fa
42+
```
43+
44+
<img alt="flux2-klein-4b" src="../assets/flux2/flux2-klein-4b.png" />
45+
46+
```
47+
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -r .\kontext_input.png -p "change 'flux.cpp' to 'klein.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu --steps 4
48+
```
49+
50+
<img alt="flux2-klein-4b-edit" src="../assets/flux2/flux2-klein-4b-edit.png" />
51+
52+
```
53+
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-base-4b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_4b.safetensors -p "a lovely cat" --cfg-scale 4.0 --steps 20 -v --offload-to-cpu --diffusion-fa
54+
```
55+
56+
<img alt="flux2-klein-base-4b" src="../assets/flux2/flux2-klein-base-4b.png" />
57+
58+
## Flux.2 klein 9B / Flux.2 klein base 9B
59+
60+
### Download weights
2061

62+
- Download FLUX.2-klein-9B
63+
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-9B
64+
- gguf: https://huggingface.co/leejet/FLUX.2-klein-9B-GGUF/tree/main
65+
- Download FLUX.2-klein-base-9B
66+
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-klein-base-9B
67+
- gguf: https://huggingface.co/leejet/FLUX.2-klein-base-9B-GGUF/tree/main
68+
- Download vae
69+
- safetensors: https://huggingface.co/black-forest-labs/FLUX.2-dev/tree/main
70+
- Download Qwen3 8B
71+
- safetensors: https://huggingface.co/Comfy-Org/flux2-klein-9B/tree/main/split_files/text_encoders
72+
- gguf: https://huggingface.co/unsloth/Qwen3-8B-GGUF/tree/main
73+
74+
### Examples
75+
76+
```
77+
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -p "a lovely cat" --cfg-scale 1.0 --steps 4 -v --offload-to-cpu --diffusion-fa
78+
```
79+
80+
<img alt="flux2-klein-9b" src="../assets/flux2/flux2-klein-9b.png" />
81+
82+
```
83+
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -r .\kontext_input.png -p "change 'flux.cpp' to 'klein.cpp'" --cfg-scale 1.0 --sampling-method euler -v --diffusion-fa --offload-to-cpu --steps 4
84+
```
85+
86+
<img alt="flux2-klein-9b-edit" src="../assets/flux2/flux2-klein-9b-edit.png" />
87+
88+
```
89+
.\bin\Release\sd-cli.exe --diffusion-model ..\..\ComfyUI\models\diffusion_models\flux-2-klein-base-9b.safetensors --vae ..\..\ComfyUI\models\vae\flux2_ae.safetensors --llm ..\..\ComfyUI\models\text_encoders\qwen_3_8b.safetensors -p "a lovely cat" --cfg-scale 4.0 --steps 20 -v --offload-to-cpu --diffusion-fa
90+
```
2191

92+
<img alt="flux2-klein-base-9b" src="../assets/flux2/flux2-klein-base-9b.png" />

flux.hpp

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,13 +1288,9 @@ namespace Flux {
12881288
} else if (version == VERSION_OVIS_IMAGE) {
12891289
flux_params.semantic_txt_norm = true;
12901290
flux_params.use_yak_mlp = true;
1291-
flux_params.context_in_dim = 2048;
12921291
flux_params.vec_in_dim = 0;
12931292
} else if (sd_version_is_flux2(version)) {
1294-
flux_params.context_in_dim = 15360;
12951293
flux_params.in_channels = 128;
1296-
flux_params.hidden_size = 6144;
1297-
flux_params.num_heads = 48;
12981294
flux_params.patch_size = 1;
12991295
flux_params.out_channels = 128;
13001296
flux_params.mlp_ratio = 3.f;
@@ -1307,12 +1303,12 @@ namespace Flux {
13071303
flux_params.ref_index_scale = 10.f;
13081304
flux_params.use_mlp_silu_act = true;
13091305
}
1306+
int64_t head_dim = 0;
13101307
for (auto pair : tensor_storage_map) {
13111308
std::string tensor_name = pair.first;
13121309
if (!starts_with(tensor_name, prefix))
13131310
continue;
13141311
if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) {
1315-
// not schnell
13161312
flux_params.guidance_embed = true;
13171313
}
13181314
if (tensor_name.find("__x0__") != std::string::npos) {
@@ -1344,13 +1340,30 @@ namespace Flux {
13441340
flux_params.depth_single_blocks = block_depth + 1;
13451341
}
13461342
}
1343+
if (ends_with(tensor_name, "txt_in.weight")) {
1344+
flux_params.context_in_dim = pair.second.ne[0];
1345+
flux_params.hidden_size = pair.second.ne[1];
1346+
}
1347+
if (ends_with(tensor_name, "single_blocks.0.norm.key_norm.scale")) {
1348+
head_dim = pair.second.ne[0];
1349+
}
1350+
if (ends_with(tensor_name, "double_blocks.0.txt_attn.norm.key_norm.scale")) {
1351+
head_dim = pair.second.ne[0];
1352+
}
13471353
}
13481354

1349-
LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks);
1355+
flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
1356+
1357+
LOG_INFO("flux: depth = %d, depth_single_blocks = %d, guidance_embed = %s, context_in_dim = %" PRId64
1358+
", hidden_size = %" PRId64 ", num_heads = %d",
1359+
flux_params.depth,
1360+
flux_params.depth_single_blocks,
1361+
flux_params.guidance_embed ? "true" : "false",
1362+
flux_params.context_in_dim,
1363+
flux_params.hidden_size,
1364+
flux_params.num_heads);
13501365
if (flux_params.is_chroma) {
13511366
LOG_INFO("Using pruned modulation (Chroma)");
1352-
} else if (!flux_params.guidance_embed) {
1353-
LOG_INFO("Flux guidance is disabled (Schnell mode)");
13541367
}
13551368

13561369
flux = Flux(flux_params);

0 commit comments

Comments
 (0)