@@ -837,7 +837,8 @@ namespace LLM {
837837
838838 struct ggml_tensor * forward (GGMLRunnerContext* ctx,
839839 struct ggml_tensor * x,
840- struct ggml_tensor * input_pos) {
840+ struct ggml_tensor * input_pos,
841+ struct ggml_tensor * attention_mask = nullptr ) {
841842 // x: [N, n_token, hidden_size]
842843 int64_t n_token = x->ne [1 ];
843844 int64_t N = x->ne [2 ];
@@ -880,7 +881,7 @@ namespace LLM {
880881 k = ggml_cont (ctx->ggml_ctx , ggml_ext_torch_permute (ctx->ggml_ctx , k, 0 , 2 , 1 , 3 )); // [N, num_kv_heads, n_token, head_dim]
881882 k = ggml_reshape_3d (ctx->ggml_ctx , k, k->ne [0 ], k->ne [1 ], k->ne [2 ] * k->ne [3 ]); // [N*num_kv_heads, n_token, head_dim]
882883
883- x = ggml_ext_attention_ext (ctx->ggml_ctx , ctx->backend , q, k, v, num_heads, nullptr , true , true , false ); // [N, n_token, hidden_size]
884+ x = ggml_ext_attention_ext (ctx->ggml_ctx , ctx->backend , q, k, v, num_heads, attention_mask , true , true , false ); // [N, n_token, hidden_size]
884885
885886 x = out_proj->forward (ctx, x); // [N, n_token, hidden_size]
886887 return x;
@@ -898,7 +899,8 @@ namespace LLM {
898899
899900 struct ggml_tensor * forward (GGMLRunnerContext* ctx,
900901 struct ggml_tensor * x,
901- struct ggml_tensor * input_pos) {
902+ struct ggml_tensor * input_pos,
903+ struct ggml_tensor * attention_mask = nullptr ) {
902904 // x: [N, n_token, hidden_size]
903905 auto self_attn = std::dynamic_pointer_cast<Attention>(blocks[" self_attn" ]);
904906 auto mlp = std::dynamic_pointer_cast<MLP>(blocks[" mlp" ]);
@@ -907,7 +909,7 @@ namespace LLM {
907909
908910 auto residual = x;
909911 x = input_layernorm->forward (ctx, x);
910- x = self_attn->forward (ctx, x, input_pos);
912+ x = self_attn->forward (ctx, x, input_pos, attention_mask );
911913 x = ggml_add_inplace (ctx->ggml_ctx , x, residual);
912914
913915 residual = x;
@@ -936,6 +938,7 @@ namespace LLM {
936938 struct ggml_tensor * forward (GGMLRunnerContext* ctx,
937939 struct ggml_tensor * input_ids,
938940 struct ggml_tensor * input_pos,
941+ struct ggml_tensor * attention_mask,
939942 std::vector<std::pair<int , ggml_tensor*>> image_embeds,
940943 std::set<int > out_layers) {
941944 // input_ids: [N, n_token]
@@ -990,7 +993,7 @@ namespace LLM {
990993 for (int i = 0 ; i < num_layers; i++) {
991994 auto block = std::dynamic_pointer_cast<TransformerBlock>(blocks[" layers." + std::to_string (i)]);
992995
993- x = block->forward (ctx, x, input_pos);
996+ x = block->forward (ctx, x, input_pos, attention_mask );
994997 if (out_layers.find (i + 1 ) != out_layers.end ()) {
995998 intermediate_outputs.push_back (x);
996999 }
@@ -1036,12 +1039,13 @@ namespace LLM {
10361039 struct ggml_tensor * forward (GGMLRunnerContext* ctx,
10371040 struct ggml_tensor * input_ids,
10381041 struct ggml_tensor * input_pos,
1042+ struct ggml_tensor * attention_mask,
10391043 std::vector<std::pair<int , ggml_tensor*>> image_embeds,
10401044 std::set<int > out_layers) {
10411045 // input_ids: [N, n_token]
10421046 auto model = std::dynamic_pointer_cast<TextModel>(blocks[" model" ]);
10431047
1044- auto x = model->forward (ctx, input_ids, input_pos, image_embeds, out_layers);
1048+ auto x = model->forward (ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
10451049 return x;
10461050 }
10471051
@@ -1157,9 +1161,10 @@ namespace LLM {
11571161 struct ggml_tensor * forward (GGMLRunnerContext* ctx,
11581162 struct ggml_tensor * input_ids,
11591163 struct ggml_tensor * input_pos,
1164+ struct ggml_tensor * attention_mask,
11601165 std::vector<std::pair<int , ggml_tensor*>> image_embeds,
11611166 std::set<int > out_layers) {
1162- auto hidden_states = model.forward (ctx, input_ids, input_pos, image_embeds, out_layers); // [N, n_token, hidden_size]
1167+ auto hidden_states = model.forward (ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers); // [N, n_token, hidden_size]
11631168 return hidden_states;
11641169 }
11651170
@@ -1174,11 +1179,13 @@ namespace LLM {
11741179 }
11751180
11761181 struct ggml_cgraph * build_graph (struct ggml_tensor * input_ids,
1182+ struct ggml_tensor * attention_mask,
11771183 std::vector<std::pair<int , ggml_tensor*>> image_embeds,
11781184 std::set<int > out_layers) {
11791185 struct ggml_cgraph * gf = ggml_new_graph (compute_ctx);
11801186
11811187 input_ids = to_backend (input_ids);
1188+ attention_mask = to_backend (attention_mask);
11821189
11831190 for (auto & image_embed : image_embeds) {
11841191 image_embed.second = to_backend (image_embed.second );
@@ -1207,7 +1214,7 @@ namespace LLM {
12071214
12081215 auto runner_ctx = get_context ();
12091216
1210- struct ggml_tensor * hidden_states = forward (&runner_ctx, input_ids, input_pos, image_embeds, out_layers);
1217+ struct ggml_tensor * hidden_states = forward (&runner_ctx, input_ids, input_pos, attention_mask, image_embeds, out_layers);
12111218
12121219 ggml_build_forward_expand (gf, hidden_states);
12131220
@@ -1216,12 +1223,13 @@ namespace LLM {
12161223
12171224 bool compute (const int n_threads,
12181225 struct ggml_tensor * input_ids,
1226+ struct ggml_tensor * attention_mask,
12191227 std::vector<std::pair<int , ggml_tensor*>> image_embeds,
12201228 std::set<int > out_layers,
12211229 ggml_tensor** output,
12221230 ggml_context* output_ctx = nullptr ) {
12231231 auto get_graph = [&]() -> struct ggml_cgraph * {
1224- return build_graph(input_ids, image_embeds, out_layers);
1232+ return build_graph(input_ids, attention_mask, image_embeds, out_layers);
12251233 };
12261234 return GGMLRunner::compute (get_graph, n_threads, true , output, output_ctx);
12271235 }
@@ -1525,7 +1533,7 @@ namespace LLM {
15251533 struct ggml_tensor * out = nullptr ;
15261534
15271535 int64_t t0 = ggml_time_ms ();
1528- model.compute (8 , input_ids, image_embeds, {}, &out, work_ctx);
1536+ model.compute (8 , input_ids, nullptr , image_embeds, {}, &out, work_ctx);
15291537 int64_t t1 = ggml_time_ms ();
15301538
15311539 print_ggml_tensor (out);
@@ -1565,7 +1573,7 @@ namespace LLM {
15651573 struct ggml_tensor * out = nullptr ;
15661574
15671575 int64_t t0 = ggml_time_ms ();
1568- model.compute (8 , input_ids, {}, {10 , 20 , 30 }, &out, work_ctx);
1576+ model.compute (8 , input_ids, nullptr , {}, {10 , 20 , 30 }, &out, work_ctx);
15691577 int64_t t1 = ggml_time_ms ();
15701578
15711579 print_ggml_tensor (out);
@@ -1588,7 +1596,7 @@ namespace LLM {
15881596 struct ggml_tensor * out = nullptr ;
15891597
15901598 int64_t t0 = ggml_time_ms ();
1591- model.compute (8 , input_ids, {}, {35 }, &out, work_ctx);
1599+ model.compute (8 , input_ids, nullptr , {}, {35 }, &out, work_ctx);
15921600 int64_t t1 = ggml_time_ms ();
15931601
15941602 print_ggml_tensor (out);
@@ -1611,7 +1619,7 @@ namespace LLM {
16111619 struct ggml_tensor * out = nullptr ;
16121620
16131621 int64_t t0 = ggml_time_ms ();
1614- model.compute (8 , input_ids, {}, {}, &out, work_ctx);
1622+ model.compute (8 , input_ids, nullptr , {}, {}, &out, work_ctx);
16151623 int64_t t1 = ggml_time_ms ();
16161624
16171625 print_ggml_tensor (out);
0 commit comments