Skip to content

Commit e9a5592

Browse files
authored
feat: use timecopilot forecaster class in agent (#65)
2 parents 1d3949c + 61bec46 commit e9a5592

File tree

3 files changed

+65
-17
lines changed

3 files changed

+65
-17
lines changed

tests/test_forecaster.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def test_forecaster_forecast(models, freq, h):
2222
df = generate_series(n_series=n_uids, freq=freq, min_length=30)
2323
forecaster = TimeCopilotForecaster(models=models)
2424
fcst_df = forecaster.forecast(df=df, h=h, freq=freq)
25+
assert len(fcst_df.columns) == 2 + len(models)
2526
assert len(fcst_df) == h * n_uids
2627
for model in models:
2728
assert model.alias in fcst_df.columns
@@ -45,10 +46,40 @@ def test_forecaster_cross_validation(models, freq, h, n_windows, step_size):
4546
n_windows=n_windows,
4647
step_size=step_size,
4748
)
49+
assert len(fcst_df.columns) == 4 + len(models)
4850
uids = df["unique_id"].unique()
4951
for uid in uids: # noqa: B007
5052
fcst_df_uid = fcst_df.query("unique_id == @uid")
5153
assert fcst_df_uid["cutoff"].nunique() == n_windows
5254
assert len(fcst_df_uid) == n_windows * h
5355
for model in models:
5456
assert model.alias in fcst_df.columns
57+
58+
59+
def test_forecaster_forecast_with_level(models):
60+
n_uids = 3
61+
level = [80, 90]
62+
df = generate_series(n_series=n_uids, freq="D", min_length=30)
63+
forecaster = TimeCopilotForecaster(models=models)
64+
fcst_df = forecaster.forecast(df=df, h=2, freq="D", level=level)
65+
assert len(fcst_df) == 2 * n_uids
66+
assert len(fcst_df.columns) == 2 + len(models) * (1 + 2 * len(level))
67+
for model in models:
68+
assert model.alias in fcst_df.columns
69+
for lv in level:
70+
assert f"{model.alias}-lo-{lv}" in fcst_df.columns
71+
assert f"{model.alias}-hi-{lv}" in fcst_df.columns
72+
73+
74+
def test_forecaster_forecast_with_quantiles(models):
75+
n_uids = 3
76+
quantiles = [0.1, 0.9]
77+
df = generate_series(n_series=n_uids, freq="D", min_length=30)
78+
forecaster = TimeCopilotForecaster(models=models)
79+
fcst_df = forecaster.forecast(df=df, h=2, freq="D", quantiles=quantiles)
80+
assert len(fcst_df) == 2 * n_uids
81+
assert len(fcst_df.columns) == 2 + len(models) * (1 + len(quantiles))
82+
for model in models:
83+
assert model.alias in fcst_df.columns
84+
for q in quantiles:
85+
assert f"{model.alias}-q-{int(100 * q)}" in fcst_df.columns

timecopilot/agent.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
)
3131
from tsfeatures.tsfeatures import _get_feats
3232

33+
from .forecaster import TimeCopilotForecaster
3334
from .models.benchmarks import (
3435
ADIDA,
3536
IMAPA,
@@ -400,7 +401,6 @@ async def cross_validation_tool(
400401
ctx: RunContext[ExperimentDataset],
401402
models: list[str],
402403
) -> str:
403-
models_fcst_cv = None
404404
callable_models = []
405405
for str_model in models:
406406
if str_model not in MODELS:
@@ -409,21 +409,14 @@ async def cross_validation_tool(
409409
f"{', '.join(MODELS.keys())}"
410410
)
411411
callable_models.append(MODELS[str_model])
412-
for model in callable_models:
413-
fcst_cv = model.cross_validation(
414-
df=ctx.deps.df,
415-
h=ctx.deps.h,
416-
freq=ctx.deps.freq,
417-
)
418-
if models_fcst_cv is None:
419-
models_fcst_cv = fcst_cv
420-
else:
421-
models_fcst_cv = models_fcst_cv.merge(
422-
fcst_cv.drop(columns=["y"]),
423-
on=["unique_id", "cutoff", "ds"],
424-
)
412+
forecaster = TimeCopilotForecaster(models=callable_models)
413+
fcst_cv = forecaster.cross_validation(
414+
df=ctx.deps.df,
415+
h=ctx.deps.h,
416+
freq=ctx.deps.freq,
417+
)
425418
eval_df = ctx.deps.evaluate_forecast_df(
426-
forecast_df=models_fcst_cv,
419+
forecast_df=fcst_cv,
427420
models=[model.alias for model in callable_models],
428421
)
429422
eval_df = eval_df.groupby(
@@ -444,7 +437,8 @@ async def forecast_tool(
444437
model: str,
445438
) -> str:
446439
callable_model = MODELS[model]
447-
fcst_df = callable_model.forecast(
440+
forecaster = TimeCopilotForecaster(models=[callable_model])
441+
fcst_df = forecaster.forecast(
448442
df=ctx.deps.df,
449443
h=ctx.deps.h,
450444
freq=ctx.deps.freq,

timecopilot/forecaster.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,29 @@ def _call_models(
1414
df: pd.DataFrame,
1515
h: int,
1616
freq: str,
17+
level: list[int] | None = None,
18+
quantiles: list[float] | None = None,
1719
**kwargs,
1820
) -> pd.DataFrame:
1921
res_df: pd.DataFrame | None = None
2022
for model in self.models:
21-
res_df_model = getattr(model, attr)(df=df, h=h, freq=freq, **kwargs)
23+
res_df_model = getattr(model, attr)(
24+
df=df,
25+
h=h,
26+
freq=freq,
27+
level=level,
28+
quantiles=quantiles,
29+
**kwargs,
30+
)
2231
if res_df is None:
2332
res_df = res_df_model
2433
else:
34+
if "y" in res_df_model:
35+
# drop y to avoid duplicate columns
36+
# y was added by the previous condition
37+
# to cross validation
38+
# (the initial model)
39+
res_df_model = res_df_model.drop(columns=["y"])
2540
res_df = res_df.merge(
2641
res_df_model,
2742
on=merge_on,
@@ -33,13 +48,17 @@ def forecast(
3348
df: pd.DataFrame,
3449
h: int,
3550
freq: str,
51+
level: list[int] | None = None,
52+
quantiles: list[float] | None = None,
3653
) -> pd.DataFrame:
3754
return self._call_models(
3855
"forecast",
3956
merge_on=["unique_id", "ds"],
4057
df=df,
4158
h=h,
4259
freq=freq,
60+
level=level,
61+
quantiles=quantiles,
4362
)
4463

4564
def cross_validation(
@@ -49,6 +68,8 @@ def cross_validation(
4968
freq: str,
5069
n_windows: int = 1,
5170
step_size: int | None = None,
71+
level: list[int] | None = None,
72+
quantiles: list[float] | None = None,
5273
) -> pd.DataFrame:
5374
return self._call_models(
5475
"cross_validation",
@@ -58,4 +79,6 @@ def cross_validation(
5879
freq=freq,
5980
n_windows=n_windows,
6081
step_size=step_size,
82+
level=level,
83+
quantiles=quantiles,
6184
)

0 commit comments

Comments
 (0)