Skip to content

Commit 44f42bc

Browse files
committed
feat: support new chroma radiance "x0_x32_proto"
1 parent 9565c7f commit 44f42bc

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

flux.hpp

Lines changed: 19 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -748,7 +748,7 @@ namespace Flux {
748748
int nerf_depth = 4;
749749
int nerf_max_freqs = 8;
750750
bool use_x0 = false;
751-
bool use_patch_size_32 = false;
751+
bool fake_patch_size_x2 = false;
752752
};
753753

754754
struct FluxParams {
@@ -786,7 +786,10 @@ namespace Flux {
786786
Flux(FluxParams params)
787787
: params(params) {
788788
if (params.version == VERSION_CHROMA_RADIANCE) {
789-
std::pair<int, int> kernel_size = {16, 16};
789+
std::pair<int, int> kernel_size = {params.patch_size, params.patch_size};
790+
if(params.chroma_radiance_params.fake_patch_size_x2){
791+
kernel_size = {params.patch_size/2, params.patch_size/2};
792+
}
790793
std::pair<int, int> stride = kernel_size;
791794

792795
blocks["img_in_patch"] = std::make_shared<Conv2d>(params.in_channels,
@@ -1082,7 +1085,7 @@ namespace Flux {
10821085
auto img = pad_to_patch_size(ctx, x);
10831086
auto orig_img = img;
10841087

1085-
if (params.chroma_radiance_params.use_patch_size_32) {
1088+
if (params.chroma_radiance_params.fake_patch_size_x2) {
10861089
// It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
10871090
// Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
10881091
// img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
@@ -1303,7 +1306,8 @@ namespace Flux {
13031306
flux_params.ref_index_scale = 10.f;
13041307
flux_params.use_mlp_silu_act = true;
13051308
}
1306-
int64_t head_dim = 0;
1309+
int64_t head_dim = 0;
1310+
int64_t actual_radiance_patch_size = 16;
13071311
for (auto pair : tensor_storage_map) {
13081312
std::string tensor_name = pair.first;
13091313
if (!starts_with(tensor_name, prefix))
@@ -1316,9 +1320,12 @@ namespace Flux {
13161320
flux_params.chroma_radiance_params.use_x0 = true;
13171321
}
13181322
if (tensor_name.find("__32x32__") != std::string::npos) {
1319-
LOG_DEBUG("using patch size 32 prediction");
1320-
flux_params.chroma_radiance_params.use_patch_size_32 = true;
1321-
flux_params.patch_size = 32;
1323+
LOG_DEBUG("using patch size 32");
1324+
flux_params.patch_size = 32;
1325+
}
1326+
if (tensor_name.find("img_in_patch.weight") != std::string::npos) {
1327+
actual_radiance_patch_size = pair.second.ne[0];
1328+
LOG_DEBUG("actual radiance patch size: %d", actual_radiance_patch_size);
13221329
}
13231330
if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) {
13241331
// Chroma
@@ -1351,6 +1358,11 @@ namespace Flux {
13511358
head_dim = pair.second.ne[0];
13521359
}
13531360
}
1361+
if (actual_radiance_patch_size != flux_params.patch_size) {
1362+
GGML_ASSERT(flux_params.patch_size == 2 * actual_radiance_patch_size);
1363+
LOG_DEBUG("using fake x2 patch size");
1364+
flux_params.chroma_radiance_params.fake_patch_size_x2 = true;
1365+
}
13541366

13551367
flux_params.num_heads = static_cast<int>(flux_params.hidden_size / head_dim);
13561368

0 commit comments

Comments
 (0)