Skip to content

Commit 3711ea0

Browse files
Curo code refactoring
1 parent 634c1b1 commit 3711ea0

File tree

6 files changed

+766
-38
lines changed

6 files changed

+766
-38
lines changed

curo/calculator.py

Lines changed: 71 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Defines the Calculator class, the entry point for building cash flow profiles and solving unknown values or rates.
33
"""
44

5-
from dataclasses import dataclass
65
from typing import List, Optional
76
import pandas as pd
87
from scipy.optimize import root_scalar
@@ -13,7 +12,6 @@
1312
from curo.exceptions import UnsolvableError, ValidationError
1413
from curo.utils import roll_date, gauss_round, to_timestamp
1514

16-
@dataclass
1715
class Calculator:
1816
"""
1917
The Calculator class provides the entry point for solving unknown values and/or
@@ -29,7 +27,7 @@ class Calculator:
2927
Attributes:
3028
precision (int): The rounding precision for cash flow values.
3129
profile (pd.DataFrame or None): The DataFrame containing cash flow data.
32-
series (List[Series]): The list of cash flow series (empty if not provided).
30+
_series (List[Series]): Private list of provided cash flow series.
3331
_is_bespoke_profile (bool): Used internally to identify profile source.
3432
3533
Raises:
@@ -39,15 +37,13 @@ class Calculator:
3937
Calculator: A new Calculator instance configured with the specified precision
4038
and optional profile.
4139
"""
42-
precision: int = 2
43-
profile: Optional[pd.DataFrame] = None
44-
series: List[Series] = None
45-
46-
def __post_init__(self):
47-
if not 0 <= self.precision <= 4:
40+
def __init__(self, precision: int = 2, profile: Optional[pd.DataFrame] = None):
41+
if not 0 <= precision <= 4:
4842
raise ValidationError("Precision must be between 0 and 4")
49-
self.series = self.series or []
50-
self._is_bespoke_profile = self.profile is not None
43+
self.precision = precision
44+
self.profile = profile
45+
self._series = []
46+
self._is_bespoke_profile = profile is not None
5147

5248
def add(self, series: Series) -> None:
5349
"""
@@ -75,7 +71,7 @@ def add(self, series: Series) -> None:
7571
if series.amount is not None:
7672
# Coerce series monetary value to specified precision
7773
series.amount = gauss_round(series.amount, self.precision)
78-
self.series.append(series)
74+
self._series.append(series)
7975

8076
def solve_value(
8177
self,
@@ -242,7 +238,7 @@ def _build_profile(self, start_date: Optional[pd.Timestamp] = None) -> pd.DataFr
242238
advance_start_date = start_date
243239
payment_start_date = start_date
244240
charge_start_date = start_date
245-
for s in self.series:
241+
for s in self._series:
246242
if isinstance(s, SeriesAdvance):
247243
if s.post_date_from is None and s.mode == Mode.ARREAR:
248244
advance_start_date = roll_date(
@@ -314,7 +310,7 @@ def _build_profile(self, start_date: Optional[pd.Timestamp] = None) -> pd.DataFr
314310

315311
if not cash_flows_list:
316312
# Handle empty case
317-
return pd.DataFrame(columns=[
313+
profile = pd.DataFrame(columns=[
318314
Column.POST_DATE.value,
319315
Column.VALUE_DATE.value,
320316
Column.AMOUNT.value,
@@ -325,6 +321,17 @@ def _build_profile(self, start_date: Optional[pd.Timestamp] = None) -> pd.DataFr
325321
Column.IS_INTEREST_CAPITALISED.value,
326322
Column.IS_CHARGE.value
327323
])
324+
return profile.astype({
325+
Column.POST_DATE.value: 'datetime64[ns, UTC]',
326+
Column.VALUE_DATE.value: 'datetime64[ns, UTC]',
327+
Column.AMOUNT.value: 'float64',
328+
Column.IS_KNOWN.value: 'bool',
329+
Column.WEIGHTING.value: 'float64',
330+
Column.LABEL.value: 'object',
331+
Column.MODE.value: 'object',
332+
Column.IS_INTEREST_CAPITALISED.value: 'object',
333+
Column.IS_CHARGE.value: 'bool'
334+
})
328335

329336
profile = pd.concat(cash_flows_list, ignore_index=True)
330337
# Ensure consistent dtypes
@@ -564,11 +571,6 @@ def _validate_profile(
564571
raise ValidationError(
565572
"Unknowns must be either all advances or all payments, not both"
566573
)
567-
# Check weighting > 0 for unknowns
568-
if unknowns[Column.WEIGHTING.value].le(0).any():
569-
raise ValidationError(
570-
"Unknown values must have positive weighting in SOLVE_VALUE mode"
571-
)
572574
elif mode == ValidationMode.SOLVE_RATE:
573575
if not unknowns.empty:
574576
raise ValidationError("All values must be known in SOLVE_RATE mode")
@@ -585,32 +587,64 @@ def _assign_factors(self, cash_flows: pd.DataFrame, day_count: Convention) -> pd
585587
586588
Args:
587589
cash_flows (pd.DataFrame): DataFrame containing cash flow data with
588-
CashFlowColumn.POST_DATE.
590+
CashFlowColumn.POST_DATE and CashFlowColumn.VALUE_DATE.
589591
day_count (Convention): The day count convention for computing time intervals.
590592
591593
Returns:
592594
pd.DataFrame: A copy of the input DataFrame with an added 'factor' column
593-
containing computed time factors.
595+
containing computed DayCountFactor objects.
594596
595597
Notes:
596-
- Factors are computed from the first `post_date` (if DayCountOrigin.DRAWDOWN)
597-
or between consecutive `post_date` values (otherwise).
598-
- The 'factor' column contains objects specific to the day count convention.
598+
- Factors are computed from the first advance's post_date or value_date (if
599+
DayCountOrigin.DRAWDOWN) or between consecutive dates (if DayCountOrigin.NEIGHBOUR).
600+
- For charges, if include_non_financing_flows is False, the factor is computed
601+
from the cash flow's date to itself (a zero-length period); the same applies to any cash flow dated on or before the drawdown date.
602+
- The date used (post_date or value_date) depends on day_count.use_post_dates.
599603
"""
600604
cash_flows = cash_flows.copy()
601-
if day_count.day_count_origin == DayCountOrigin.DRAWDOWN:
602-
start_date = cash_flows['post_date'].iloc[0]
603-
cash_flows['factor'] = cash_flows['post_date'].apply(
604-
lambda d: day_count.compute_factor(start_date, d)
605-
)
606-
else:
607-
cash_flows['prev_date'] = cash_flows['post_date'].shift(1)
608-
cash_flows.loc[0, 'prev_date'] = cash_flows.loc[0, 'post_date']
609-
cash_flows['factor'] = cash_flows.apply(
610-
lambda row: day_count.compute_factor(row['prev_date'], row['post_date']),
611-
axis=1
612-
)
613-
cash_flows = cash_flows.drop(columns=['prev_date'])
605+
date_column = (
606+
Column.POST_DATE.value
607+
if day_count.use_post_dates
608+
else Column.VALUE_DATE.value
609+
)
610+
611+
# Initialize the factor column as object type to store DayCountFactor
612+
cash_flows['factor'] = None
613+
614+
# Find the first advance's date for DRAWDOWN origin
615+
advances = cash_flows[
616+
(~cash_flows[Column.IS_CHARGE.value]) &
617+
(cash_flows[Column.IS_INTEREST_CAPITALISED.value].isna())
618+
]
619+
if advances.empty:
620+
raise ValidationError("At least one advance required for factor assignment")
621+
drawdown_date = advances[date_column].min()
622+
623+
# For NEIGHBOUR origin, track the previous date
624+
neighbour_date = drawdown_date
625+
626+
for idx in cash_flows.index:
627+
cash_flow_date = cash_flows.loc[idx, date_column]
628+
is_charge = cash_flows.loc[idx, Column.IS_CHARGE.value]
629+
630+
# Handle charges when non-financing flows are excluded
631+
if is_charge and not day_count.include_non_financing_flows:
632+
factor = day_count.compute_factor(cash_flow_date, cash_flow_date)
633+
cash_flows.at[idx, 'factor'] = factor
634+
continue
635+
636+
# Handle cash flows predating or equal to drawdown_date
637+
if cash_flow_date <= drawdown_date:
638+
factor = day_count.compute_factor(cash_flow_date, cash_flow_date)
639+
else:
640+
if day_count.day_count_origin == DayCountOrigin.DRAWDOWN:
641+
factor = day_count.compute_factor(drawdown_date, cash_flow_date)
642+
else: # NEIGHBOUR
643+
factor = day_count.compute_factor(neighbour_date, cash_flow_date)
644+
neighbour_date = cash_flow_date
645+
646+
cash_flows.at[idx, 'factor'] = factor
647+
614648
return cash_flows
615649

616650
def _calculate_nfv(self, cash_flows: pd.DataFrame, day_count: Convention, rate: float) -> float:

tests/test_calculator.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,86 @@
1-
# general tests - constructor, add method
1+
# test_calculator.py
2+
# pylint: disable=C0114,C0116,W0212,W0621
3+
# - C0114: missing-module-docstring
4+
# - C0116: missing-function-docstring
5+
# - W0212: protected-access
6+
# - W0621: redefined-outer-name
7+
8+
import pytest
9+
import pandas as pd
10+
from curo.calculator import Calculator
11+
from curo.series import SeriesAdvance, SeriesPayment, SeriesCharge
12+
from curo.enums import CashFlowColumn as Column
13+
from curo.exceptions import ValidationError
14+
15+
@pytest.fixture
16+
def valid_series():
17+
return SeriesAdvance(
18+
number_of=1,
19+
amount=1000.0,
20+
post_date_from=pd.Timestamp("2025-01-01", tz="UTC"),
21+
label="Loan advance"
22+
)
23+
24+
@pytest.fixture
25+
def valid_profile():
26+
return pd.DataFrame({
27+
Column.POST_DATE.value: [pd.Timestamp("2025-01-01", tz="UTC")],
28+
Column.VALUE_DATE.value: [pd.Timestamp("2025-01-01", tz="UTC")],
29+
Column.AMOUNT.value: [1000.0],
30+
Column.IS_KNOWN.value: [True],
31+
Column.WEIGHTING.value: [1.0],
32+
Column.LABEL.value: ["Loan advance"],
33+
Column.MODE.value: ["advance"],
34+
Column.IS_INTEREST_CAPITALISED.value: [None],
35+
Column.IS_CHARGE.value: [False]
36+
})
37+
38+
def test_constructor_valid_precision():
39+
calc = Calculator(precision=2)
40+
assert calc.precision == 2
41+
assert calc.profile is None
42+
assert not calc._series
43+
assert calc._is_bespoke_profile is False
44+
45+
def test_constructor_invalid_precision():
46+
with pytest.raises(ValidationError, match="Precision must be between 0 and 4"):
47+
Calculator(precision=5)
48+
with pytest.raises(ValidationError, match="Precision must be between 0 and 4"):
49+
Calculator(precision=-1)
50+
51+
def test_constructor_with_profile(valid_profile):
52+
calc = Calculator(precision=3, profile=valid_profile)
53+
assert calc.precision == 3
54+
assert calc.profile.equals(valid_profile)
55+
assert not calc._series
56+
assert calc._is_bespoke_profile is True
57+
58+
def test_add_valid_series(valid_series):
59+
calc = Calculator(precision=2)
60+
calc.add(valid_series)
61+
assert len(calc._series) == 1
62+
assert calc._series[0] == valid_series
63+
assert calc._series[0].amount == 1000.0 # Amount unchanged
64+
65+
def test_add_series_with_rounding():
66+
calc = Calculator(precision=1)
67+
series = SeriesPayment(
68+
number_of=1,
69+
amount=123.456,
70+
post_date_from=pd.Timestamp("2025-01-01", tz="UTC"),
71+
label="Loan repayment"
72+
)
73+
calc.add(series)
74+
assert len(calc._series) == 1
75+
assert calc._series[0].amount == 123.5 # Rounded to 1 decimal place
76+
77+
def test_add_with_bespoke_profile(valid_profile):
78+
calc = Calculator(precision=2, profile=valid_profile)
79+
series = SeriesCharge(
80+
number_of=1,
81+
amount=50.0,
82+
post_date_from=pd.Timestamp("2025-01-01", tz="UTC"),
83+
label="Arrangement fee"
84+
)
85+
with pytest.raises(ValidationError, match="Cannot add series with a bespoke profile"):
86+
calc.add(series)

0 commit comments

Comments
 (0)