@@ -91,14 +91,23 @@ namespace Rope {
9191 int axes_dim_num,
9292 int index = 0 ,
9393 int h_offset = 0 ,
94- int w_offset = 0 ) {
94+ int w_offset = 0 ,
95+ bool scale_rope = false ) {
9596 int h_len = (h + (patch_size / 2 )) / patch_size;
9697 int w_len = (w + (patch_size / 2 )) / patch_size;
9798
9899 std::vector<std::vector<float >> img_ids (h_len * w_len, std::vector<float >(axes_dim_num, 0.0 ));
99100
100- std::vector<float > row_ids = linspace<float >(h_offset, h_len - 1 + h_offset, h_len);
101- std::vector<float > col_ids = linspace<float >(w_offset, w_len - 1 + w_offset, w_len);
101+ int h_start = h_offset;
102+ int w_start = w_offset;
103+
104+ if (scale_rope) {
105+ h_start -= h_len / 2 ;
106+ w_start -= w_len / 2 ;
107+ }
108+
109+ std::vector<float > row_ids = linspace<float >(h_start, h_start + h_len - 1 , h_len);
110+ std::vector<float > col_ids = linspace<float >(w_start, w_start + w_len - 1 , w_len);
102111
103112 for (int i = 0 ; i < h_len; ++i) {
104113 for (int j = 0 ; j < w_len; ++j) {
@@ -171,7 +180,8 @@ namespace Rope {
171180 int axes_dim_num,
172181 const std::vector<ggml_tensor*>& ref_latents,
173182 bool increase_ref_index,
174- float ref_index_scale) {
183+ float ref_index_scale,
184+ bool scale_rope) {
175185 std::vector<std::vector<float >> ids;
176186 uint64_t curr_h_offset = 0 ;
177187 uint64_t curr_w_offset = 0 ;
@@ -185,6 +195,7 @@ namespace Rope {
185195 } else {
186196 h_offset = curr_h_offset;
187197 }
198+ scale_rope = false ;
188199 }
189200
190201 auto ref_ids = gen_flux_img_ids (ref->ne [1 ],
@@ -194,7 +205,8 @@ namespace Rope {
194205 axes_dim_num,
195206 static_cast <int >(index * ref_index_scale),
196207 h_offset,
197- w_offset);
208+ w_offset,
209+ scale_rope);
198210 ids = concat_ids (ids, ref_ids, bs);
199211
200212 if (increase_ref_index) {
@@ -222,7 +234,7 @@ namespace Rope {
222234
223235 auto ids = concat_ids (txt_ids, img_ids, bs);
224236 if (ref_latents.size () > 0 ) {
225- auto refs_ids = gen_refs_ids (patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale);
237+ auto refs_ids = gen_refs_ids (patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale, false );
226238 ids = concat_ids (ids, refs_ids, bs);
227239 }
228240 return ids;
@@ -271,10 +283,10 @@ namespace Rope {
271283 }
272284 }
273285 int axes_dim_num = 3 ;
274- auto img_ids = gen_flux_img_ids (h, w, patch_size, bs, axes_dim_num);
286+ auto img_ids = gen_flux_img_ids (h, w, patch_size, bs, axes_dim_num, 0 , 0 , 0 , true );
275287 auto ids = concat_ids (txt_ids_repeated, img_ids, bs);
276288 if (ref_latents.size () > 0 ) {
277- auto refs_ids = gen_refs_ids (patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1 .f );
289+ auto refs_ids = gen_refs_ids (patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1 .f , true );
278290 ids = concat_ids (ids, refs_ids, bs);
279291 }
280292 return ids;
0 commit comments