Skip to content

Commit 3e7ce7f

Browse files
leejetstduhpf
authored and committed
fix conditioner
1 parent 1e855c8 commit 3e7ce7f

File tree

3 files changed

+68
-23
lines changed

3 files changed

+68
-23
lines changed

conditioner.hpp

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,9 +1755,13 @@ struct LLMEmbedder : public Conditioner {
17551755
std::vector<std::pair<int, ggml_tensor*>> image_embeds;
17561756
std::pair<int, int> prompt_attn_range;
17571757
int prompt_template_encode_start_idx = 34;
1758+
int prompt_template_encode_end_idx = 0;
17581759
int max_length = 0;
17591760
bool spell_quotes = false;
17601761
std::set<int> out_layers;
1762+
std::vector<int> tokens;
1763+
std::vector<float> weights;
1764+
std::vector<float> mask;
17611765
if (llm->enable_vision && conditioner_params.ref_images.size() > 0) {
17621766
if (sd_version_is_longcat(version)) {
17631767
LOG_INFO("LongCatEditPipeline");
@@ -1937,8 +1941,8 @@ struct LLMEmbedder : public Conditioner {
19371941
prompt += "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n";
19381942
} else if (sd_version_is_longcat(version)) {
19391943
prompt_template_encode_start_idx = 36;
1940-
// prompt_template_encode_end_idx = 5;
1941-
max_length = 512;
1944+
max_length = 512 + prompt_template_encode_start_idx;
1945+
prompt_template_encode_end_idx = 5;
19421946
spell_quotes = true;
19431947

19441948
prompt = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n";
@@ -1947,7 +1951,24 @@ struct LLMEmbedder : public Conditioner {
19471951
prompt += conditioner_params.text;
19481952
prompt_attn_range.second = static_cast<int>(prompt.size());
19491953

1950-
prompt += "<|im_end|>\n<|im_start|>assistant\n";
1954+
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, 0, false, spell_quotes);
1955+
tokens = std::get<0>(tokens_and_weights);
1956+
weights = std::get<1>(tokens_and_weights);
1957+
1958+
mask.insert(mask.end(), tokens.size(), 1.f);
1959+
if (tokens.size() < max_length) {
1960+
mask.insert(mask.end(), max_length - tokens.size(), 0.f);
1961+
tokenizer->pad_tokens(tokens, weights, max_length, true);
1962+
}
1963+
1964+
std::string prompt_template_suffix = "<|im_end|>\n<|im_start|>assistant\n";
1965+
auto suffix_tokens = tokenizer->tokenize(prompt_template_suffix, nullptr);
1966+
1967+
LOG_DEBUG("%zd", tokens.size());
1968+
1969+
tokens.insert(tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
1970+
weights.insert(weights.end(), suffix_tokens.size(), 1.f);
1971+
mask.insert(mask.end(), suffix_tokens.size(), 1.f);
19511972
} else {
19521973
prompt_template_encode_start_idx = 34;
19531974

@@ -1960,17 +1981,33 @@ struct LLMEmbedder : public Conditioner {
19601981
prompt += "<|im_end|>\n<|im_start|>assistant\n";
19611982
}
19621983

1963-
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0, spell_quotes);
1964-
auto& tokens = std::get<0>(tokens_and_weights);
1965-
auto& weights = std::get<1>(tokens_and_weights);
1984+
if (tokens.empty()) {
1985+
auto tokens_and_weights = tokenize(prompt, prompt_attn_range, max_length, max_length > 0, spell_quotes);
1986+
tokens = std::get<0>(tokens_and_weights);
1987+
weights = std::get<1>(tokens_and_weights);
1988+
}
1989+
19661990

19671991
int64_t t0 = ggml_time_ms();
19681992
struct ggml_tensor* hidden_states = nullptr; // [N, n_token, 3584]
19691993

19701994
auto input_ids = vector_to_ggml_tensor_i32(work_ctx, tokens);
1995+
ggml_tensor* attention_mask = nullptr;
1996+
if (!mask.empty()) {
1997+
attention_mask = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, mask.size(), mask.size());
1998+
ggml_ext_tensor_iter(attention_mask, [&](ggml_tensor* attention_mask, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
1999+
float value = 0.f;
2000+
if (mask[i0] == 0.f || mask[i1] == 0.f) {
2001+
value = -INFINITY;
2002+
}
2003+
ggml_ext_tensor_set_f32(attention_mask, value, i0, i1, i2, i3);
2004+
});
2005+
print_ggml_tensor(attention_mask);
2006+
}
19712007

19722008
llm->compute(n_threads,
19732009
input_ids,
2010+
attention_mask,
19742011
image_embeds,
19752012
out_layers,
19762013
&hidden_states,
@@ -2008,18 +2045,18 @@ struct LLMEmbedder : public Conditioner {
20082045
ggml_tensor* new_hidden_states = ggml_new_tensor_3d(work_ctx,
20092046
GGML_TYPE_F32,
20102047
hidden_states->ne[0],
2011-
hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len,
2048+
hidden_states->ne[1] - prompt_template_encode_start_idx + zero_pad_len - prompt_template_encode_end_idx,
20122049
hidden_states->ne[2]);
20132050

20142051
ggml_ext_tensor_iter(new_hidden_states, [&](ggml_tensor* new_hidden_states, int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
20152052
float value = 0.f;
2016-
if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1]) {
2053+
if (i1 + prompt_template_encode_start_idx < hidden_states->ne[1] - prompt_template_encode_end_idx) {
20172054
value = ggml_ext_tensor_get_f32(hidden_states, i0, i1 + prompt_template_encode_start_idx, i2, i3);
20182055
}
20192056
ggml_ext_tensor_set_f32(new_hidden_states, value, i0, i1, i2, i3);
20202057
});
20212058

2022-
// print_ggml_tensor(new_hidden_states);
2059+
print_ggml_tensor(new_hidden_states, true);
20232060

20242061
int64_t t1 = ggml_time_ms();
20252062
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);

ggml_extend.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2207,7 +2207,7 @@ class Linear : public UnaryBlock {
22072207
bool bias = true,
22082208
bool force_f32 = false,
22092209
bool force_prec_f32 = false,
2210-
float scale = 1.f)
2210+
float scale = 1.f / 128.f)
22112211
: in_features(in_features),
22122212
out_features(out_features),
22132213
bias(bias),

llm.hpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,8 @@ namespace LLM {
837837

838838
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
839839
struct ggml_tensor* x,
840-
struct ggml_tensor* input_pos) {
840+
struct ggml_tensor* input_pos,
841+
struct ggml_tensor* attention_mask = nullptr) {
841842
// x: [N, n_token, hidden_size]
842843
int64_t n_token = x->ne[1];
843844
int64_t N = x->ne[2];
@@ -880,7 +881,7 @@ namespace LLM {
880881
k = ggml_cont(ctx->ggml_ctx, ggml_ext_torch_permute(ctx->ggml_ctx, k, 0, 2, 1, 3)); // [N, num_kv_heads, n_token, head_dim]
881882
k = ggml_reshape_3d(ctx->ggml_ctx, k, k->ne[0], k->ne[1], k->ne[2] * k->ne[3]); // [N*num_kv_heads, n_token, head_dim]
882883

883-
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, nullptr, true, true, false); // [N, n_token, hidden_size]
884+
x = ggml_ext_attention_ext(ctx->ggml_ctx, ctx->backend, q, k, v, num_heads, attention_mask, true, true, false); // [N, n_token, hidden_size]
884885

885886
x = out_proj->forward(ctx, x); // [N, n_token, hidden_size]
886887
return x;
@@ -898,7 +899,8 @@ namespace LLM {
898899

899900
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
900901
struct ggml_tensor* x,
901-
struct ggml_tensor* input_pos) {
902+
struct ggml_tensor* input_pos,
903+
struct ggml_tensor* attention_mask = nullptr) {
902904
// x: [N, n_token, hidden_size]
903905
auto self_attn = std::dynamic_pointer_cast<Attention>(blocks["self_attn"]);
904906
auto mlp = std::dynamic_pointer_cast<MLP>(blocks["mlp"]);
@@ -907,7 +909,7 @@ namespace LLM {
907909

908910
auto residual = x;
909911
x = input_layernorm->forward(ctx, x);
910-
x = self_attn->forward(ctx, x, input_pos);
912+
x = self_attn->forward(ctx, x, input_pos, attention_mask);
911913
x = ggml_add_inplace(ctx->ggml_ctx, x, residual);
912914

913915
residual = x;
@@ -936,6 +938,7 @@ namespace LLM {
936938
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
937939
struct ggml_tensor* input_ids,
938940
struct ggml_tensor* input_pos,
941+
struct ggml_tensor* attention_mask,
939942
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
940943
std::set<int> out_layers) {
941944
// input_ids: [N, n_token]
@@ -990,7 +993,7 @@ namespace LLM {
990993
for (int i = 0; i < num_layers; i++) {
991994
auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks["layers." + std::to_string(i)]);
992995

993-
x = block->forward(ctx, x, input_pos);
996+
x = block->forward(ctx, x, input_pos, attention_mask);
994997
if (out_layers.find(i + 1) != out_layers.end()) {
995998
intermediate_outputs.push_back(x);
996999
}
@@ -1036,12 +1039,13 @@ namespace LLM {
10361039
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
10371040
struct ggml_tensor* input_ids,
10381041
struct ggml_tensor* input_pos,
1042+
struct ggml_tensor* attention_mask,
10391043
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
10401044
std::set<int> out_layers) {
10411045
// input_ids: [N, n_token]
10421046
auto model = std::dynamic_pointer_cast<TextModel>(blocks["model"]);
10431047

1044-
auto x = model->forward(ctx, input_ids, input_pos, image_embeds, out_layers);
1048+
auto x = model->forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
10451049
return x;
10461050
}
10471051

@@ -1157,9 +1161,10 @@ namespace LLM {
11571161
struct ggml_tensor* forward(GGMLRunnerContext* ctx,
11581162
struct ggml_tensor* input_ids,
11591163
struct ggml_tensor* input_pos,
1164+
struct ggml_tensor* attention_mask,
11601165
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
11611166
std::set<int> out_layers) {
1162-
auto hidden_states = model.forward(ctx, input_ids, input_pos, image_embeds, out_layers); // [N, n_token, hidden_size]
1167+
auto hidden_states = model.forward(ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); // [N, n_token, hidden_size]
11631168
return hidden_states;
11641169
}
11651170

@@ -1174,11 +1179,13 @@ namespace LLM {
11741179
}
11751180

11761181
struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids,
1182+
struct ggml_tensor* attention_mask,
11771183
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
11781184
std::set<int> out_layers) {
11791185
struct ggml_cgraph* gf = ggml_new_graph(compute_ctx);
11801186

11811187
input_ids = to_backend(input_ids);
1188+
attention_mask = to_backend(attention_mask);
11821189

11831190
for (auto& image_embed : image_embeds) {
11841191
image_embed.second = to_backend(image_embed.second);
@@ -1207,7 +1214,7 @@ namespace LLM {
12071214

12081215
auto runner_ctx = get_context();
12091216

1210-
struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, image_embeds, out_layers);
1217+
struct ggml_tensor* hidden_states = forward(&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
12111218

12121219
ggml_build_forward_expand(gf, hidden_states);
12131220

@@ -1216,12 +1223,13 @@ namespace LLM {
12161223

12171224
bool compute(const int n_threads,
12181225
struct ggml_tensor* input_ids,
1226+
struct ggml_tensor* attention_mask,
12191227
std::vector<std::pair<int, ggml_tensor*>> image_embeds,
12201228
std::set<int> out_layers,
12211229
ggml_tensor** output,
12221230
ggml_context* output_ctx = nullptr) {
12231231
auto get_graph = [&]() -> struct ggml_cgraph* {
1224-
return build_graph(input_ids, image_embeds, out_layers);
1232+
return build_graph(input_ids, attention_mask, image_embeds, out_layers);
12251233
};
12261234
return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
12271235
}
@@ -1525,7 +1533,7 @@ namespace LLM {
15251533
struct ggml_tensor* out = nullptr;
15261534

15271535
int64_t t0 = ggml_time_ms();
1528-
model.compute(8, input_ids, image_embeds, {}, &out, work_ctx);
1536+
model.compute(8, input_ids, nullptr, image_embeds, {}, &out, work_ctx);
15291537
int64_t t1 = ggml_time_ms();
15301538

15311539
print_ggml_tensor(out);
@@ -1565,7 +1573,7 @@ namespace LLM {
15651573
struct ggml_tensor* out = nullptr;
15661574

15671575
int64_t t0 = ggml_time_ms();
1568-
model.compute(8, input_ids, {}, {10, 20, 30}, &out, work_ctx);
1576+
model.compute(8, input_ids, nullptr, {}, {10, 20, 30}, &out, work_ctx);
15691577
int64_t t1 = ggml_time_ms();
15701578

15711579
print_ggml_tensor(out);
@@ -1588,7 +1596,7 @@ namespace LLM {
15881596
struct ggml_tensor* out = nullptr;
15891597

15901598
int64_t t0 = ggml_time_ms();
1591-
model.compute(8, input_ids, {}, {35}, &out, work_ctx);
1599+
model.compute(8, input_ids, nullptr, {}, {35}, &out, work_ctx);
15921600
int64_t t1 = ggml_time_ms();
15931601

15941602
print_ggml_tensor(out);
@@ -1611,7 +1619,7 @@ namespace LLM {
16111619
struct ggml_tensor* out = nullptr;
16121620

16131621
int64_t t0 = ggml_time_ms();
1614-
model.compute(8, input_ids, {}, {}, &out, work_ctx);
1622+
model.compute(8, input_ids, nullptr, {}, {}, &out, work_ctx);
16151623
int64_t t1 = ggml_time_ms();
16161624

16171625
print_ggml_tensor(out);

0 commit comments

Comments (0)