[BUG] Cannot infer frequency when id_column has string[python] dtype #440

@manuel-munoz-aguirre

Bug report checklist

  • I provided code that demonstrates a minimal reproducible example.
  • I confirmed bug exists on the latest mainline of Chronos via source install.

Describe the bug
When the id_column of the input dataframe has dtype string[python] (pandas' StringDtype), frequency inference can fail if the series have unequal lengths:

  File "[...]/.venv/lib/python3.11/site-packages/chronos/df_utils.py", line 153, in validate_freq
    raise ValueError(f"Could not infer frequency for series {series_id}")
ValueError: Could not infer frequency for series 004a615c8d5d2a30724d12b62e3a3e3debab04ea5275df13664e7e0ac53feac4

The likely cause is that series lengths are extracted with .value_counts(sort=False), and the order of the resulting index seems to differ between dtype('O') and string[python]. The validation slices the timestamps by a running offset, so it relies on the lengths being in row order. When the order does not match, a slice can span two series and frequency inference fails:

# Get series lengths
series_lengths = df[id_column].value_counts(sort=False).to_list()

def validate_freq(timestamps: pd.DatetimeIndex, series_id: str):
    freq = pd.infer_freq(timestamps)
    if not freq:
        raise ValueError(f"Could not infer frequency for series {series_id}")
    return freq

# Validate each series
all_freqs = []
start_idx = 0
timestamp_index = pd.DatetimeIndex(df[timestamp_column])
for length in series_lengths:
    if length < 3:
        series_id = df[id_column].iloc[start_idx]
        raise ValueError(
            f"Every time series must have at least 3 data points, found {length=} for series {series_id}"
        )
    timestamps = timestamp_index[start_idx : start_idx + length]
    series_id = df[id_column].iloc[start_idx]
    all_freqs.append(validate_freq(timestamps, series_id))
    start_idx += length
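One way the lengths could be made independent of the id dtype is sketched below. This is only an illustration, not the Chronos implementation. `series_lengths_in_row_order` and `make_frame` are hypothetical names, and the sketch assumes the rows of each series are contiguous, which the offset-based slicing above already requires. It relies on `pd.factorize` numbering ids in order of first appearance.

```python
import numpy as np
import pandas as pd


def series_lengths_in_row_order(df: pd.DataFrame, id_column: str) -> list[int]:
    """Per-series lengths, ordered by first appearance of each id in df.

    Hypothetical helper, not part of Chronos. Assumes the rows of each
    series are contiguous and that the id column has no missing values.
    """
    # pd.factorize assigns codes in order of first appearance (no sorting).
    codes, _ = pd.factorize(df[id_column])
    # bincount(codes)[k] is the number of rows of the k-th distinct id.
    return np.bincount(codes).tolist()


def make_frame(id_dtype: str) -> pd.DataFrame:
    """Small frame with three weekly series of unequal length (toy data)."""
    lengths = {"b": 3, "a": 5, "c": 4}
    frames = [
        pd.DataFrame(
            {
                "unique_id": name,
                "ds": pd.date_range("2023-01-02", periods=n, freq="W-MON"),
            }
        )
        for name, n in lengths.items()
    ]
    df = pd.concat(frames, ignore_index=True)
    df["unique_id"] = df["unique_id"].astype(id_dtype)
    return df


for id_dtype in ("object", "string[python]"):
    df = make_frame(id_dtype)
    # Row-order lengths from the sketch, independent of the id dtype.
    print(id_dtype, series_lengths_in_row_order(df, "unique_id"))
    # Lengths as currently computed; the order may depend on the pandas version.
    print(id_dtype, df["unique_id"].value_counts(sort=False).to_list())
```

Printing both lists side by side for each dtype also gives a quick check of whether the installed pandas version shows the order mismatch.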

Expected behavior
Frequency inference should behave the same whether the id_column is dtype('O') or string[python]. Even if pandas' StringDtype is considered experimental, users might accidentally run into this issue.

To reproduce

import torch
import pandas as pd
import numpy as np
import hashlib
from chronos import BaseChronosPipeline

pipeline = BaseChronosPipeline.from_pretrained(
    "amazon/chronos-2",
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)

def generate_test_data(num_series=750, num_periods=180, period_variation=150, seed=1, string_python=False): 
    np.random.seed(seed)
    
    # Calculate the end date based on maximum periods
    end_date = pd.date_range(start='2023-01-02', periods=num_periods + period_variation, freq='W-MON')[-1]

    data = []
    for i in range(num_series):
        unique_id = hashlib.sha256(f"series_{i}".encode()).hexdigest()
        periods = np.random.randint(num_periods - period_variation, num_periods + period_variation + 1)
        dates = pd.date_range(end=end_date, periods=periods, freq='W-MON')
        targets = np.random.uniform(100, 1000, size=periods)
        series_df = pd.DataFrame({
            'unique_id': unique_id,
            'ds': dates,
            'target': targets
        })
        data.append(series_df)
    train_y = pd.concat(data, ignore_index=True)

    if string_python:
        train_y['unique_id'] = train_y['unique_id'].astype('string[python]')

    return train_y

# failure with id column as dtype('string[python]')
train_y = generate_test_data(string_python=True)

# id column as dtype('string[python]'), equal length series -> ok
# train_y = generate_test_data(string_python=True, period_variation=0)

# id column as dtype('O') -> ok
# train_y = generate_test_data(string_python=False)

args = {
    'df': train_y,
    'future_df': None,
    'prediction_length': 30,
    'quantile_levels': [0.5],
    'id_column': 'unique_id',
    'timestamp_column': 'ds',
    'target': 'target'
}

yhat = pipeline.predict_df(**args)
print(yhat.shape)
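Until this is fixed, a possible workaround is to cast the id column back to the plain object dtype before calling predict_df, since the dtype('O') case works in the reproduction above. The helper name `with_object_ids` is hypothetical and the small frame is toy data. This is a sketch, not something verified against every pandas version.

```python
import pandas as pd


def with_object_ids(df: pd.DataFrame, id_column: str) -> pd.DataFrame:
    """Return a copy of df whose id column uses the plain object dtype.

    Hypothetical helper, not part of Chronos.
    """
    out = df.copy()
    out[id_column] = out[id_column].astype(object)
    return out


# Toy frame with a string[python] id column.
frame = pd.DataFrame(
    {
        "unique_id": pd.array(["a", "a", "b"], dtype="string[python]"),
        "target": [1.0, 2.0, 3.0],
    }
)
converted = with_object_ids(frame, "unique_id")
print(frame["unique_id"].dtype, "->", converted["unique_id"].dtype)
```

In the reproduction script this would amount to setting `args['df'] = with_object_ids(train_y, 'unique_id')` before the `predict_df` call.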

Environment description
Python version: 3.11
CUDA version: 12.2
PyTorch version: 2.3.1
