Skip to content

Commit 88ec9d3

Browse files
authored
feat: add scale_rope support (#1121)
1 parent 60abda5 commit 88ec9d3

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

rope.hpp

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,14 +91,23 @@ namespace Rope {
9191
int axes_dim_num,
9292
int index = 0,
9393
int h_offset = 0,
94-
int w_offset = 0) {
94+
int w_offset = 0,
95+
bool scale_rope = false) {
9596
int h_len = (h + (patch_size / 2)) / patch_size;
9697
int w_len = (w + (patch_size / 2)) / patch_size;
9798

9899
std::vector<std::vector<float>> img_ids(h_len * w_len, std::vector<float>(axes_dim_num, 0.0));
99100

100-
std::vector<float> row_ids = linspace<float>(h_offset, h_len - 1 + h_offset, h_len);
101-
std::vector<float> col_ids = linspace<float>(w_offset, w_len - 1 + w_offset, w_len);
101+
int h_start = h_offset;
102+
int w_start = w_offset;
103+
104+
if (scale_rope) {
105+
h_start -= h_len / 2;
106+
w_start -= w_len / 2;
107+
}
108+
109+
std::vector<float> row_ids = linspace<float>(h_start, h_start + h_len - 1, h_len);
110+
std::vector<float> col_ids = linspace<float>(w_start, w_start + w_len - 1, w_len);
102111

103112
for (int i = 0; i < h_len; ++i) {
104113
for (int j = 0; j < w_len; ++j) {
@@ -171,7 +180,8 @@ namespace Rope {
171180
int axes_dim_num,
172181
const std::vector<ggml_tensor*>& ref_latents,
173182
bool increase_ref_index,
174-
float ref_index_scale) {
183+
float ref_index_scale,
184+
bool scale_rope) {
175185
std::vector<std::vector<float>> ids;
176186
uint64_t curr_h_offset = 0;
177187
uint64_t curr_w_offset = 0;
@@ -185,6 +195,7 @@ namespace Rope {
185195
} else {
186196
h_offset = curr_h_offset;
187197
}
198+
scale_rope = false;
188199
}
189200

190201
auto ref_ids = gen_flux_img_ids(ref->ne[1],
@@ -194,7 +205,8 @@ namespace Rope {
194205
axes_dim_num,
195206
static_cast<int>(index * ref_index_scale),
196207
h_offset,
197-
w_offset);
208+
w_offset,
209+
scale_rope);
198210
ids = concat_ids(ids, ref_ids, bs);
199211

200212
if (increase_ref_index) {
@@ -222,7 +234,7 @@ namespace Rope {
222234

223235
auto ids = concat_ids(txt_ids, img_ids, bs);
224236
if (ref_latents.size() > 0) {
225-
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale);
237+
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, ref_index_scale, false);
226238
ids = concat_ids(ids, refs_ids, bs);
227239
}
228240
return ids;
@@ -271,10 +283,10 @@ namespace Rope {
271283
}
272284
}
273285
int axes_dim_num = 3;
274-
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num);
286+
auto img_ids = gen_flux_img_ids(h, w, patch_size, bs, axes_dim_num, 0, 0, 0, true);
275287
auto ids = concat_ids(txt_ids_repeated, img_ids, bs);
276288
if (ref_latents.size() > 0) {
277-
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f);
289+
auto refs_ids = gen_refs_ids(patch_size, bs, axes_dim_num, ref_latents, increase_ref_index, 1.f, true);
278290
ids = concat_ids(ids, refs_ids, bs);
279291
}
280292
return ids;

0 commit comments

Comments
 (0)