Skip to content

Commit ad9ac8d

Browse files
committed
generate.
1 parent acfa871 commit ad9ac8d

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

utils/generate_model_tests.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
# Other testers
7373
("SingleFileTesterMixin", "single_file"),
7474
("IPAdapterTesterMixin", "ip_adapter"),
75+
("ContextParallelAttentionBackendsTesterMixin", "cp_attn"),
7576
]
7677

7778

@@ -229,7 +230,14 @@ def determine_testers(model_info: dict, include_optional: list[str], imports: se
229230

230231
for tester, flag in OPTIONAL_TESTERS:
231232
if flag in include_optional:
232-
if tester not in testers:
233+
if tester == "ContextParallelAttentionBackendsTesterMixin":
234+
if (
235+
"cp_attn" in include_optional
236+
and "_cp_plan" in model_info["attributes"]
237+
and model_info["attributes"]["_cp_plan"] is not None
238+
):
239+
testers.append(tester)
240+
elif tester not in testers:
233241
testers.append(tester)
234242

235243
return testers
@@ -530,6 +538,7 @@ def main():
530538
"faster_cache",
531539
"single_file",
532540
"ip_adapter",
541+
"cp_attn",
533542
"all",
534543
],
535544
help="Optional testers to include",

0 commit comments

Comments
 (0)