22Docstring for calculator
33"""
44
5- from dataclasses import dataclass
65from typing import List , Optional
76import pandas as pd
87from scipy .optimize import root_scalar
1312from curo .exceptions import UnsolvableError , ValidationError
1413from curo .utils import roll_date , gauss_round , to_timestamp
1514
16- @dataclass
1715class Calculator :
1816 """
1917 The Calculator class provides the entry point for solving unknown values and/or
@@ -29,7 +27,7 @@ class Calculator:
2927 Attributes:
3028 precision (int): The rounding precision for cash flow values.
3129 profile (pd.DataFrame or None): The DataFrame containing cash flow data.
32- series (List[Series]): The list of cash flow series (empty if not provided) .
30+ _series (List[Series]): Private list of provided cash flow series.
3331 _is_bespoke_profile (bool): Used internally to identify profile source.
3432
3533 Raises:
@@ -39,15 +37,13 @@ class Calculator:
3937 Calculator: A new Calculator instance configured with the specified precision
4038 and optional profile.
4139 """
42- precision : int = 2
43- profile : Optional [pd .DataFrame ] = None
44- series : List [Series ] = None
45-
46- def __post_init__ (self ):
47- if not 0 <= self .precision <= 4 :
40+ def __init__ (self , precision : int = 2 , profile : Optional [pd .DataFrame ] = None ):
41+ if not 0 <= precision <= 4 :
4842 raise ValidationError ("Precision must be between 0 and 4" )
49- self .series = self .series or []
50- self ._is_bespoke_profile = self .profile is not None
43+ self .precision = precision
44+ self .profile = profile
45+ self ._series = []
46+ self ._is_bespoke_profile = profile is not None
5147
5248 def add (self , series : Series ) -> None :
5349 """
@@ -75,7 +71,7 @@ def add(self, series: Series) -> None:
7571 if series .amount is not None :
7672 # Coerce series monetary value to specified precision
7773 series .amount = gauss_round (series .amount , self .precision )
78- self .series .append (series )
74+ self ._series .append (series )
7975
8076 def solve_value (
8177 self ,
@@ -242,7 +238,7 @@ def _build_profile(self, start_date: Optional[pd.Timestamp] = None) -> pd.DataFr
242238 advance_start_date = start_date
243239 payment_start_date = start_date
244240 charge_start_date = start_date
245- for s in self .series :
241+ for s in self ._series :
246242 if isinstance (s , SeriesAdvance ):
247243 if s .post_date_from is None and s .mode == Mode .ARREAR :
248244 advance_start_date = roll_date (
@@ -314,7 +310,7 @@ def _build_profile(self, start_date: Optional[pd.Timestamp] = None) -> pd.DataFr
314310
315311 if not cash_flows_list :
316312 # Handle empty case
317- return pd .DataFrame (columns = [
313+ profile = pd .DataFrame (columns = [
318314 Column .POST_DATE .value ,
319315 Column .VALUE_DATE .value ,
320316 Column .AMOUNT .value ,
@@ -325,6 +321,17 @@ def _build_profile(self, start_date: Optional[pd.Timestamp] = None) -> pd.DataFr
325321 Column .IS_INTEREST_CAPITALISED .value ,
326322 Column .IS_CHARGE .value
327323 ])
324+ return profile .astype ({
325+ Column .POST_DATE .value : 'datetime64[ns, UTC]' ,
326+ Column .VALUE_DATE .value : 'datetime64[ns, UTC]' ,
327+ Column .AMOUNT .value : 'float64' ,
328+ Column .IS_KNOWN .value : 'bool' ,
329+ Column .WEIGHTING .value : 'float64' ,
330+ Column .LABEL .value : 'object' ,
331+ Column .MODE .value : 'object' ,
332+ Column .IS_INTEREST_CAPITALISED .value : 'object' ,
333+ Column .IS_CHARGE .value : 'bool'
334+ })
328335
329336 profile = pd .concat (cash_flows_list , ignore_index = True )
330337 # Ensure consistent dtypes
@@ -564,11 +571,6 @@ def _validate_profile(
564571 raise ValidationError (
565572 "Unknowns must be either all advances or all payments, not both"
566573 )
567- # Check weighting > 0 for unknowns
568- if unknowns [Column .WEIGHTING .value ].le (0 ).any ():
569- raise ValidationError (
570- "Unknown values must have positive weighting in SOLVE_VALUE mode"
571- )
572574 elif mode == ValidationMode .SOLVE_RATE :
573575 if not unknowns .empty :
574576 raise ValidationError ("All values must be known in SOLVE_RATE mode" )
@@ -585,32 +587,64 @@ def _assign_factors(self, cash_flows: pd.DataFrame, day_count: Convention) -> pd
585587
586588 Args:
587589 cash_flows (pd.DataFrame): DataFrame containing cash flow data with
588- CashFlowColumn.POST_DATE.
590+ CashFlowColumn.POST_DATE and CashFlowColumn.VALUE_DATE .
589591 day_count (Convention): The day count convention for computing time intervals.
590592
591593 Returns:
592594 pd.DataFrame: A copy of the input DataFrame with an added 'factor' column
593- containing computed time factors .
595+ containing computed DayCountFactor objects .
594596
595597 Notes:
596- - Factors are computed from the first `post_date` (if DayCountOrigin.DRAWDOWN)
597- or between consecutive `post_date` values (otherwise).
598- - The 'factor' column contains objects specific to the day count convention.
598+ - Factors are computed from the first advance's post_date or value_date (if
599+ DayCountOrigin.DRAWDOWN) or between consecutive dates (if DayCountOrigin.NEIGHBOUR).
600+ - For charges, if include_non_financing_flows is False, the factor is computed
601+ between the same date (zero period).
602+ - The date used (post_date or value_date) depends on day_count.use_post_dates.
599603 """
600604 cash_flows = cash_flows .copy ()
601- if day_count .day_count_origin == DayCountOrigin .DRAWDOWN :
602- start_date = cash_flows ['post_date' ].iloc [0 ]
603- cash_flows ['factor' ] = cash_flows ['post_date' ].apply (
604- lambda d : day_count .compute_factor (start_date , d )
605- )
606- else :
607- cash_flows ['prev_date' ] = cash_flows ['post_date' ].shift (1 )
608- cash_flows .loc [0 , 'prev_date' ] = cash_flows .loc [0 , 'post_date' ]
609- cash_flows ['factor' ] = cash_flows .apply (
610- lambda row : day_count .compute_factor (row ['prev_date' ], row ['post_date' ]),
611- axis = 1
612- )
613- cash_flows = cash_flows .drop (columns = ['prev_date' ])
605+ date_column = (
606+ Column .POST_DATE .value
607+ if day_count .use_post_dates
608+ else Column .VALUE_DATE .value
609+ )
610+
611+ # Initialize the factor column as object type to store DayCountFactor
612+ cash_flows ['factor' ] = None
613+
614+ # Find the first advance's date for DRAWDOWN origin
615+ advances = cash_flows [
616+ (~ cash_flows [Column .IS_CHARGE .value ]) &
617+ (cash_flows [Column .IS_INTEREST_CAPITALISED .value ].isna ())
618+ ]
619+ if advances .empty :
620+ raise ValidationError ("At least one advance required for factor assignment" )
621+ drawdown_date = advances [date_column ].min ()
622+
623+ # For NEIGHBOUR origin, track the previous date
624+ neighbour_date = drawdown_date
625+
626+ for idx in cash_flows .index :
627+ cash_flow_date = cash_flows .loc [idx , date_column ]
628+ is_charge = cash_flows .loc [idx , Column .IS_CHARGE .value ]
629+
630+ # Handle charges when non-financing flows are excluded
631+ if is_charge and not day_count .include_non_financing_flows :
632+ factor = day_count .compute_factor (cash_flow_date , cash_flow_date )
633+ cash_flows .at [idx , 'factor' ] = factor
634+ continue
635+
636+ # Handle cash flows predating or equal to drawdown_date
637+ if cash_flow_date <= drawdown_date :
638+ factor = day_count .compute_factor (cash_flow_date , cash_flow_date )
639+ else :
640+ if day_count .day_count_origin == DayCountOrigin .DRAWDOWN :
641+ factor = day_count .compute_factor (drawdown_date , cash_flow_date )
642+ else : # NEIGHBOUR
643+ factor = day_count .compute_factor (neighbour_date , cash_flow_date )
644+ neighbour_date = cash_flow_date
645+
646+ cash_flows .at [idx , 'factor' ] = factor
647+
614648 return cash_flows
615649
616650 def _calculate_nfv (self , cash_flows : pd .DataFrame , day_count : Convention , rate : float ) -> float :
0 commit comments