Skip to content

Commit 46d1075

Browse files
committed
tests
1 parent ec1395b commit 46d1075

File tree

2 files changed

+12
-26
lines changed

2 files changed

+12
-26
lines changed

tests/analyze/error_analysis/test_pipeline.py

Lines changed: 2 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ def summarize(
1616

1717

1818
class MockEpisodeSummarizer:
19-
def summarize(self, exp_result: ExpResult, step_analysis: list[str]) -> str:
19+
def __call__(self, exp_result: ExpResult) -> str:
2020
return f"Agent did actions {', '.join(step.action for step in exp_result.steps_info if step.action)}"
2121

2222

@@ -33,8 +33,6 @@ def pipeline() -> ErrorAnalysisPipeline:
3333
exp_dir=exp_dir,
3434
filter=None,
3535
episode_summarizer=MockEpisodeSummarizer(),
36-
step_summarizer=MockStepSummarizer(),
37-
analyzer=MockAnalyzer(),
3836
)
3937

4038

@@ -49,30 +47,10 @@ def test_yield_with_filter(pipeline: ErrorAnalysisPipeline):
4947
pipeline.filter = None
5048

5149

52-
def test_analyze_step(pipeline: ErrorAnalysisPipeline):
53-
exp_result = next(pipeline.filter_exp_results())
54-
step_analysis = pipeline.analyze_step(exp_result)
55-
56-
assert len(exp_result.steps_info) == len(step_analysis) + 1
57-
assert step_analysis[0] == f"Agent took action {exp_result.steps_info[0].action} at step 0"
58-
59-
60-
def test_analyze_episode(pipeline: ErrorAnalysisPipeline):
61-
exp_result = next(pipeline.filter_exp_results())
62-
step_analysis = pipeline.analyze_step(exp_result)
63-
episode_analysis = pipeline.analyze_episode(exp_result, step_analysis)
64-
65-
for step_info in exp_result.steps_info:
66-
if step_info.action:
67-
assert step_info.action in episode_analysis
68-
69-
7050
def test_save_analysis(pipeline: ErrorAnalysisPipeline):
7151
exp_result = next(pipeline.filter_exp_results())
72-
step_analysis = pipeline.analyze_step(exp_result)
73-
episode_analysis = pipeline.analyze_episode(exp_result, step_analysis)
74-
error_analysis = pipeline.analyze_errors(exp_result, episode_analysis, step_analysis)
7552

53+
error_analysis = pipeline.episode_summarizer(exp_result)
7654
pipeline.save_analysis(exp_result, error_analysis, exists_ok=False)
7755

7856
assert (exp_result.exp_dir / "error_analysis.json").exists()

tests/analyze/error_analysis/test_summarizer.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,10 +13,18 @@ def exp_results() -> list[ExpResult]:
1313
return list(yield_all_exp_results(exp_dir))
1414

1515

16+
@pytest.mark.pricy
1617
def test_change_summarizer(exp_results: list[ExpResult]):
17-
summarizer = ChangeSummarizer(llm=lambda x: x)
18+
summarizer = ChangeSummarizer(llm=lambda x: {"content": x})
1819
step = exp_results[0].steps_info[0]
1920
next_step = exp_results[0].steps_info[1]
2021
past_summaries = []
2122
summary = summarizer.summarize(step, next_step, past_summaries)
22-
assert isinstance(summary, str)
23+
assert isinstance(summary, dict)
24+
25+
26+
if __name__ == "__main__":
27+
exp_res = list(
28+
yield_all_exp_results(Path(__file__).parent.parent.parent / "data/error_analysis")
29+
)
30+
test_change_summarizer(exp_res)

0 commit comments

Comments
 (0)