@@ -748,7 +748,7 @@ namespace Flux {
748748 int nerf_depth = 4 ;
749749 int nerf_max_freqs = 8 ;
750750 bool use_x0 = false ;
751- bool use_patch_size_32 = false ;
751+ bool fake_patch_size_x2 = false ;
752752 };
753753
754754 struct FluxParams {
@@ -786,7 +786,10 @@ namespace Flux {
786786 Flux (FluxParams params)
787787 : params(params) {
788788 if (params.version == VERSION_CHROMA_RADIANCE) {
789- std::pair<int , int > kernel_size = {16 , 16 };
789+ std::pair<int , int > kernel_size = {params.patch_size , params.patch_size };
790+ if (params.chroma_radiance_params .fake_patch_size_x2 ){
791+ kernel_size = {params.patch_size /2 , params.patch_size /2 };
792+ }
790793 std::pair<int , int > stride = kernel_size;
791794
792795 blocks[" img_in_patch" ] = std::make_shared<Conv2d>(params.in_channels ,
@@ -1082,7 +1085,7 @@ namespace Flux {
10821085 auto img = pad_to_patch_size (ctx, x);
10831086 auto orig_img = img;
10841087
1085- if (params.chroma_radiance_params .use_patch_size_32 ) {
1088+ if (params.chroma_radiance_params .fake_patch_size_x2 ) {
10861089 // It's supposed to be using GGML_SCALE_MODE_NEAREST, but this seems more stable
10871090 // Maybe the implementation of nearest-neighbor interpolation in ggml behaves differently than the one in PyTorch?
10881091 // img = F.interpolate(img, size=(H//2, W//2), mode="nearest")
@@ -1303,7 +1306,8 @@ namespace Flux {
13031306 flux_params.ref_index_scale = 10 .f ;
13041307 flux_params.use_mlp_silu_act = true ;
13051308 }
1306- int64_t head_dim = 0 ;
1309+ int64_t head_dim = 0 ;
1310+ int64_t actual_radiance_patch_size = 16 ;
13071311 for (auto pair : tensor_storage_map) {
13081312 std::string tensor_name = pair.first ;
13091313 if (!starts_with (tensor_name, prefix))
@@ -1316,9 +1320,12 @@ namespace Flux {
13161320 flux_params.chroma_radiance_params .use_x0 = true ;
13171321 }
13181322 if (tensor_name.find (" __32x32__" ) != std::string::npos) {
1319- LOG_DEBUG (" using patch size 32 prediction" );
1320- flux_params.chroma_radiance_params .use_patch_size_32 = true ;
1321- flux_params.patch_size = 32 ;
1323+ LOG_DEBUG (" using patch size 32" );
1324+ flux_params.patch_size = 32 ;
1325+ }
1326+ if (tensor_name.find (" img_in_patch.weight" ) != std::string::npos) {
1327+ actual_radiance_patch_size = pair.second .ne [0 ];
1328+ LOG_DEBUG (" actual radiance patch size: %d" , actual_radiance_patch_size);
13221329 }
13231330 if (tensor_name.find (" distilled_guidance_layer.in_proj.weight" ) != std::string::npos) {
13241331 // Chroma
@@ -1351,6 +1358,11 @@ namespace Flux {
13511358 head_dim = pair.second .ne [0 ];
13521359 }
13531360 }
1361+ if (actual_radiance_patch_size != flux_params.patch_size ) {
1362+ GGML_ASSERT (flux_params.patch_size == 2 * actual_radiance_patch_size);
1363+ LOG_DEBUG (" using fake x2 patch size" );
1364+ flux_params.chroma_radiance_params .fake_patch_size_x2 = true ;
1365+ }
13541366
13551367 flux_params.num_heads = static_cast <int >(flux_params.hidden_size / head_dim);
13561368
0 commit comments