Description
Bug report checklist
- I provided code that demonstrates a minimal reproducible example.
- I confirmed bug exists on the latest mainline of Chronos via source install.
Describe the bug
When the input dataframe uses string[python] (pandas' StringDtype) as the dtype of the id_column, frequency inference can fail if the series are of unequal length:
File "[...]/.venv/lib/python3.11/site-packages/chronos/df_utils.py", line 153, in validate_freq
raise ValueError(f"Could not infer frequency for series {series_id}")
ValueError: Could not infer frequency for series 004a615c8d5d2a30724d12b62e3a3e3debab04ea5275df13664e7e0ac53feac4
It seems like this is because the series lengths are extracted through .value_counts(), whose index order differs between dtype('O') and string[python]; since the validation relies on positional indexing with running offsets, the order mismatch makes it fail:
chronos-forecasting/src/chronos/df_utils.py
Lines 147 to 169 in fd53338
# Get series lengths
series_lengths = df[id_column].value_counts(sort=False).to_list()

def validate_freq(timestamps: pd.DatetimeIndex, series_id: str):
    freq = pd.infer_freq(timestamps)
    if not freq:
        raise ValueError(f"Could not infer frequency for series {series_id}")
    return freq

# Validate each series
all_freqs = []
start_idx = 0
timestamp_index = pd.DatetimeIndex(df[timestamp_column])
for length in series_lengths:
    if length < 3:
        series_id = df[id_column].iloc[start_idx]
        raise ValueError(
            f"Every time series must have at least 3 data points, found {length=} for series {series_id}"
        )
    timestamps = timestamp_index[start_idx : start_idx + length]
    series_id = df[id_column].iloc[start_idx]
    all_freqs.append(validate_freq(timestamps, series_id))
    start_idx += length
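A minimal check (not part of the repro, just an illustration) for comparing the index order returned by .value_counts(sort=False) under the two dtypes; if the orders differ for unequal-length groups, the start_idx offsets above no longer line up with the dataframe's row order:

import pandas as pd

# Toy id column with contiguous groups of unequal length, as in the repro below
ids_object = pd.Series(["a", "a", "a", "b", "b", "c", "c", "c", "c"])
ids_string = ids_object.astype("string[python]")

# Compare the index order of the counts for both dtypes; a difference here would
# explain why the positional slicing in validate_freq picks the wrong timestamps
print(ids_object.value_counts(sort=False).index.to_list())
print(ids_string.value_counts(sort=False).index.to_list())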
Expected behavior
Inferring the frequency should work the same whether the id_column has dtype('O') or string[python]. Even if pandas' StringDtype is still considered experimental, users might accidentally run into this issue.
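A possible fix sketch (an assumption about one workable direction, not necessarily the maintainers' preferred approach) would be to compute the per-series lengths in the order the rows actually appear, so the running start_idx offsets stay aligned with the dataframe regardless of the id dtype:

# Hypothetical sketch: lengths in order of first appearance instead of
# value_counts() order, keeping the positional offsets valid for any id dtype
series_lengths = df.groupby(id_column, sort=False, observed=True).size().to_list()

On the user side, casting the id column to object before calling predict_df (e.g. train_y['unique_id'] = train_y['unique_id'].astype(object)) sidesteps the issue, consistent with the dtype('O') case working in the repro below.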
To reproduce
import torch
import pandas as pd
import numpy as np
import hashlib
from chronos import BaseChronosPipeline

pipeline = BaseChronosPipeline.from_pretrained(
    "amazon/chronos-2",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)

def generate_test_data(num_series=750, num_periods=180, period_variation=150, seed=1, string_python=False):
    np.random.seed(seed)
    # Calculate the end date based on the maximum number of periods
    end_date = pd.date_range(start='2023-01-02', periods=num_periods + period_variation, freq='W-MON')[-1]
    data = []
    for i in range(num_series):
        unique_id = hashlib.sha256(f"series_{i}".encode()).hexdigest()
        periods = np.random.randint(num_periods - period_variation, num_periods + period_variation + 1)
        dates = pd.date_range(end=end_date, periods=periods, freq='W-MON')
        targets = np.random.uniform(100, 1000, size=periods)
        series_df = pd.DataFrame({
            'unique_id': unique_id,
            'ds': dates,
            'target': targets
        })
        data.append(series_df)
    train_y = pd.concat(data, ignore_index=True)
    if string_python:
        train_y['unique_id'] = train_y['unique_id'].astype('string[python]')
    return train_y

# failure with id column as dtype('string[python]')
train_y = generate_test_data(string_python=True)
# ## id column as dtype('string[python]'), equal-length series -> ok
# train_y = generate_test_data(string_python=True, period_variation=0)
# ## id column as dtype('O') -> ok
# train_y = generate_test_data(string_python=False)

args = {
    'df': train_y,
    'future_df': None,
    'prediction_length': 30,
    'quantile_levels': [0.5],
    'id_column': 'unique_id',
    'timestamp_column': 'ds',
    'target': 'target'
}
yhat = pipeline.predict_df(**args)
print(yhat.shape)

Environment description
Python version: 3.11
CUDA version: 12.2
PyTorch version: 2.3.1