Skip to content

Commit 1d3949c

Browse files
authored
feat: add multiseries support (#64)
2 parents 97a3ae1 + a810b2f commit 1d3949c

File tree

3 files changed

+64
-46
lines changed

3 files changed

+64
-46
lines changed

tests/test_agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,19 +22,19 @@ def _response_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse
2222
def test_forecast_returns_expected_output(query):
2323
df = generate_series(n_series=1, freq="D", min_length=30)
2424
expected_output = {
25-
"tsfeatures_results": ["mean: 0.5"],
2625
"tsfeatures_analysis": "ok",
2726
"selected_model": "ZeroModel",
2827
"model_details": "details",
29-
"cross_validation_results": ["ZeroModel: 0.1"],
3028
"model_comparison": "cmp",
3129
"is_better_than_seasonal_naive": True,
3230
"reason_for_selection": "reason",
33-
"forecast": ["2025-01-01: 1.0"],
3431
"forecast_analysis": "analysis",
3532
"user_query_response": query,
3633
}
3734
tc = TimeCopilot(llm=build_stub_llm(expected_output))
35+
tc.fcst_df = None
36+
tc.eval_df = None
37+
tc.features_df = None
3838
result = tc.forecast(df=df, h=2, freq="D", seasonality=7, query=query)
3939

4040
assert result.output == ForecastAgentOutput(**expected_output)

tests/test_live.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,29 +3,37 @@
33
Keeping it separate from the other tests because it incurs costs and requires a live LLM.
44
"""
55

6+
import logfire
67
import pytest
8+
from dotenv import load_dotenv
79
from utilsforecast.data import generate_series
810

911
from timecopilot import TimeCopilot
1012

13+
load_dotenv()
14+
logfire.configure(send_to_logfire="if-token-present")
15+
logfire.instrument_pydantic_ai()
16+
1117

1218
@pytest.mark.live
13-
def test_forecast_returns_expected_output():
19+
@pytest.mark.parametrize("n_series", [1, 2])
20+
def test_forecast_returns_expected_output(n_series):
21+
h = 2
1422
df = generate_series(
15-
n_series=1,
23+
n_series=n_series,
1624
freq="D",
1725
min_length=30,
1826
static_as_categorical=False,
1927
)
20-
forecasting_agent = TimeCopilot(
28+
tc = TimeCopilot(
2129
llm="openai:gpt-4o-mini",
2230
retries=3,
2331
)
24-
result = forecasting_agent.forecast(
32+
result = tc.forecast(
2533
df=df,
26-
query="Please forecast the series with a horizon of 2 and frequency D.",
34+
query=f"Please forecast the series with a horizon of {h} and frequency D.",
2735
)
28-
assert len(result.output.forecast) == 2
36+
assert len(result.fcst_df) == n_series * h
2937
assert result.output.is_better_than_seasonal_naive
3038
assert result.output.forecast_analysis is not None
3139
assert result.output.reason_for_selection is not None

timecopilot/agent.py

Lines changed: 47 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,6 @@
8787
class ForecastAgentOutput(BaseModel):
8888
"""The output of the forecasting agent."""
8989

90-
tsfeatures_results: list[str] = Field(
91-
description=(
92-
"The time series features that were considered as a list of strings of "
93-
"feature names and their values separated by colons."
94-
)
95-
)
9690
tsfeatures_analysis: str = Field(
9791
description=(
9892
"Analysis of what the time series features reveal about the data "
@@ -108,12 +102,6 @@ class ForecastAgentOutput(BaseModel):
108102
"strengths, and typical use cases."
109103
)
110104
)
111-
cross_validation_results: list[str] = Field(
112-
description=(
113-
"The cross-validation results as a string of model names "
114-
"and their scores separated by colons."
115-
)
116-
)
117105
model_comparison: str = Field(
118106
description=(
119107
"Detailed comparison of model performances, explaining why certain "
@@ -126,12 +114,6 @@ class ForecastAgentOutput(BaseModel):
126114
reason_for_selection: str = Field(
127115
description="Explanation for why the selected model was chosen"
128116
)
129-
forecast: list[str] = Field(
130-
description=(
131-
"The forecasted values for the time series as a list of strings of "
132-
"periods and their values separated by colons."
133-
)
134-
)
135117
forecast_analysis: str = Field(
136118
description=(
137119
"Detailed interpretation of the forecast, including trends, patterns, "
@@ -324,7 +306,7 @@ def __init__(
324306
325307
3. Final Model Selection and Forecasting:
326308
- Choose the best performing model with clear justification
327-
- Generate and analyze the forecast
309+
- Generate the forecast using just the selected model
328310
- Interpret trends and patterns in the forecast
329311
- Discuss reliability and potential uncertainties
330312
- Address any specific aspects from the user's prompt
@@ -364,10 +346,17 @@ def __init__(
364346
)
365347

366348
@self.forecasting_agent.system_prompt
367-
async def add_time_series(ctx: RunContext[ExperimentDataset]) -> str:
349+
async def add_time_series(
350+
ctx: RunContext[ExperimentDataset],
351+
) -> str:
352+
df_agg = ctx.deps.df.groupby("unique_id").agg(list)
368353
output = (
369-
f"The time series is: {ctx.deps.df['y'].tolist()}, "
370-
f"the date column is: {ctx.deps.df['ds'].tolist()}"
354+
"these are the time series in json format where the key is the "
355+
"identifier of the time series and the values is also a json "
356+
"of two elements: "
357+
"the first element is the date column and the second element is the "
358+
"value column."
359+
f"{df_agg.to_json(orient='index')}"
371360
)
372361
return output
373362

@@ -384,15 +373,27 @@ async def tsfeatures_tool(
384373
f"{', '.join(TSFEATURES.keys())}"
385374
)
386375
callable_features.append(TSFEATURES[feature])
387-
features_df = _get_feats(
388-
index=ctx.deps.df["unique_id"].iloc[0],
389-
ts=ctx.deps.df,
390-
features=callable_features,
391-
freq=ctx.deps.seasonality,
392-
)
393-
return ",".join(
394-
[f"{col}: {features_df[col].iloc[0]}" for col in features_df.columns]
376+
features_df: pd.DataFrame | None = None
377+
for uid in ctx.deps.df["unique_id"].unique():
378+
features_df_uid = _get_feats(
379+
index=uid,
380+
ts=ctx.deps.df,
381+
features=callable_features,
382+
freq=ctx.deps.seasonality,
383+
)
384+
if features_df is None:
385+
features_df = features_df_uid
386+
else:
387+
features_df = pd.concat([features_df, features_df_uid])
388+
features_df = features_df.rename_axis("unique_id") # type: ignore
389+
self.features_df = features_df
390+
output = (
391+
"these are the time series features in json format where the key is "
392+
"the identifier of the time series and the values is also a json of "
393+
"feature names and their values."
394+
f"{features_df.to_json(orient='index')}"
395395
)
396+
return output
396397

397398
@self.forecasting_agent.tool
398399
async def cross_validation_tool(
@@ -429,6 +430,7 @@ async def cross_validation_tool(
429430
["metric"],
430431
as_index=False,
431432
).mean(numeric_only=True)
433+
self.eval_df = eval_df
432434
return ", ".join(
433435
[
434436
f"{model.alias}: {eval_df[model.alias].iloc[0]}"
@@ -437,19 +439,25 @@ async def cross_validation_tool(
437439
)
438440

439441
@self.forecasting_agent.tool
440-
async def forecast_tool(ctx: RunContext[ExperimentDataset], model: str) -> str:
442+
async def forecast_tool(
443+
ctx: RunContext[ExperimentDataset],
444+
model: str,
445+
) -> str:
441446
callable_model = MODELS[model]
442447
fcst_df = callable_model.forecast(
443448
df=ctx.deps.df,
444449
h=ctx.deps.h,
445450
freq=ctx.deps.freq,
446451
)
447-
output = ",".join(
448-
[
449-
f"{row['ds'].strftime('%Y-%m-%d')}: {row[model]}"
450-
for _, row in fcst_df.iterrows()
451-
]
452+
df_agg = fcst_df.groupby("unique_id").agg(list)
453+
output = (
454+
"these are the forecasted values in json format where the key is the "
455+
"identifier of the time series and the values is also a json of two "
456+
"elements: the first element is the date column and the second "
457+
"element is the value column."
458+
f"{df_agg.to_json(orient='index')}"
452459
)
460+
self.fcst_df = fcst_df
453461
return output
454462

455463
@self.forecasting_agent.output_validator
@@ -519,5 +527,7 @@ def forecast(
519527
user_prompt=query,
520528
deps=dataset,
521529
)
522-
530+
result.fcst_df = self.fcst_df
531+
result.eval_df = self.eval_df
532+
result.features_df = self.features_df
523533
return result

0 commit comments

Comments
 (0)