Skip to content

Commit 57773f2

Browse files
authored
Merge pull request #78 from JuBiotech/type-fixes
Fix various type hints
2 parents a35a2d0 + f1274fb commit 57773f2

File tree

3 files changed

+33
-24
lines changed

3 files changed

+33
-24
lines changed

calibr8/core.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,17 @@
99
import json
1010
import logging
1111
import os
12-
import typing
1312
import warnings
1413
from pathlib import Path
15-
from typing import Callable, Optional, Sequence, Tuple, Union
14+
from typing import Callable, DefaultDict, List, Optional, Sequence, Tuple, Union
1615

1716
import numpy
1817
import scipy
1918

2019
from . import utils
2120
from .utils import DistributionType, pm
2221

23-
__version__ = "7.1.1"
22+
__version__ = "7.1.2"
2423
_log = logging.getLogger("calibr8")
2524

2625

@@ -170,7 +169,7 @@ def _interval_prob(x_cdf: numpy.ndarray, cdf: numpy.ndarray, a: float, b: float)
170169
return cdf[ib] - cdf[ia]
171170

172171

173-
def _get_eti(x_cdf: numpy.ndarray, cdf: numpy.ndarray, ci_prob: float) -> typing.Tuple[float, float]:
172+
def _get_eti(x_cdf: numpy.ndarray, cdf: numpy.ndarray, ci_prob: float) -> Tuple[float, float]:
174173
"""Find the equal tailed interval (ETI) corresponding to a certain credible interval probability level.
175174
176175
Parameters
@@ -203,8 +202,8 @@ def _get_hdi(
203202
guess_lower: float,
204203
guess_upper: float,
205204
*,
206-
history: typing.Optional[typing.DefaultDict[str, typing.List]] = None,
207-
) -> typing.Tuple[float, float]:
205+
history: Optional[DefaultDict[str, List]] = None,
206+
) -> Tuple[float, float]:
208207
"""Find the highest density interval (HDI) corresponding to a certain credible interval probability level.
209208
210209
Parameters
@@ -600,7 +599,7 @@ def likelihood(self, *, y, x, theta=None, scan_x: bool = False):
600599
return numpy.exp([self.loglikelihood(y=y, x=xi, theta=theta) for xi in x])
601600
return numpy.exp(self.loglikelihood(y=y, x=x, theta=theta))
602601

603-
def objective(self, independent, dependent, minimize=True) -> typing.Callable:
602+
def objective(self, independent, dependent, minimize=True) -> Callable:
604603
"""Creates an objective function for fitting to data.
605604
606605
Parameters
@@ -628,7 +627,7 @@ def objective(x):
628627

629628
return objective
630629

631-
def save(self, filepath: Union[Path, os.PathLike]):
630+
def save(self, filepath: Union[str, Path, os.PathLike]):
632631
"""Save key properties of the calibration model to a JSON file.
633632
634633
Parameters
@@ -654,7 +653,7 @@ def save(self, filepath: Union[Path, os.PathLike]):
654653
return
655654

656655
@classmethod
657-
def load(cls, filepath: Union[Path, os.PathLike]):
656+
def load(cls, filepath: Union[str, Path, os.PathLike]):
658657
"""Instantiates a model from a JSON file of key properties.
659658
660659
Parameters

calibr8/optimization.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
likelihood estimation of calibration model parameters.
44
"""
55
import logging
6-
import typing
7-
from typing import Any, Mapping, Optional, Sequence, Tuple, Union
6+
from typing import Any, Literal, Mapping, Optional, Sequence, Tuple, Union
87

98
import numpy
109
import scipy.optimize
@@ -14,7 +13,11 @@
1413
_log = logging.getLogger("calibr8.optimization")
1514

1615

17-
def _mask_and_warn_inf_or_nan(x: numpy.ndarray, y: numpy.ndarray, on: typing.Optional[str] = None):
16+
def _mask_and_warn_inf_or_nan(
17+
x: Union[Sequence[float], numpy.ndarray],
18+
y: Union[Sequence[float], numpy.ndarray],
19+
on: Optional[Literal["x", "y"]] = None,
20+
) -> Tuple[numpy.ndarray, numpy.ndarray]:
1821
"""Filters `x` and `y` such that only finite elements remain.
1922
2023
Parameters
@@ -31,6 +34,8 @@ def _mask_and_warn_inf_or_nan(x: numpy.ndarray, y: numpy.ndarray, on: typing.Opt
3134
x : array
3235
y : array
3336
"""
37+
x = numpy.asarray(x)
38+
y = numpy.asarray(y)
3439
xdims = numpy.ndim(x)
3540
if xdims == 1:
3641
mask_x = numpy.isfinite(x)
@@ -82,8 +87,8 @@ def _warn_hit_bounds(theta, bounds, theta_names) -> bool:
8287
def fit_scipy(
8388
model: core.CalibrationModel,
8489
*,
85-
independent: numpy.ndarray,
86-
dependent: numpy.ndarray,
90+
independent: Union[Sequence[float], numpy.ndarray],
91+
dependent: Union[Sequence[float], numpy.ndarray],
8792
theta_guess: Union[Sequence[float], numpy.ndarray],
8893
theta_bounds: Sequence[Tuple[float, float]],
8994
minimize_kwargs: Optional[Mapping[str, Any]] = None,
@@ -154,8 +159,8 @@ def fit_scipy(
154159
def fit_scipy_global(
155160
model: core.CalibrationModel,
156161
*,
157-
independent: numpy.ndarray,
158-
dependent: numpy.ndarray,
162+
independent: Union[Sequence[float], numpy.ndarray],
163+
dependent: Union[Sequence[float], numpy.ndarray],
159164
theta_bounds: list,
160165
method: Optional[str] = None,
161166
maxiter: int = 5000,

calibr8/utils.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
imports, timestamp parsing and plotting.
44
"""
55
import datetime
6-
import typing
76
import warnings
87
from collections.abc import Iterable
9-
from typing import Literal, Optional, Sequence, Tuple
8+
from typing import List, Literal, Optional, Sequence, Tuple
109

1110
import matplotlib
1211
import numpy
@@ -52,7 +51,7 @@ def __getattr__(self, attr):
5251
pm = ImportWarner("pymc")
5352

5453

55-
def parse_datetime(s: typing.Optional[str]) -> typing.Optional[datetime.datetime]:
54+
def parse_datetime(s: Optional[str]) -> Optional[datetime.datetime]:
5655
"""Parses a timezone-aware datetime formatted like 2020-08-05T13:37:00Z.
5756
5857
Returns
@@ -65,7 +64,7 @@ def parse_datetime(s: typing.Optional[str]) -> typing.Optional[datetime.datetime
6564
return datetime.datetime.strptime(s.replace("Z", "+0000"), "%Y-%m-%dT%H:%M:%S%z")
6665

6766

68-
def format_datetime(dt: typing.Optional[datetime.datetime]) -> typing.Optional[str]:
67+
def format_datetime(dt: Optional[datetime.datetime]) -> Optional[str]:
6968
"""Formats a datetime like 2020-08-05T13:37:00Z.
7069
7170
Returns
@@ -176,7 +175,9 @@ def plot_norm_band(ax, independent, mu, scale):
176175
return artists
177176

178177

179-
def plot_t_band(ax, independent, mu, scale, df, *, residual_type: typing.Optional[str] = None):
178+
def plot_t_band(
179+
ax, independent, mu, scale, df, *, residual_type: Optional[Literal["absolute", "relative"]] = None
180+
):
180181
"""Helper function for plotting the 68, 90 and 95 % likelihood-bands of a t-distribution.
181182
182183
Parameters
@@ -241,7 +242,9 @@ def plot_t_band(ax, independent, mu, scale, df, *, residual_type: typing.Optiona
241242
return artists
242243

243244

244-
def plot_continuous_band(ax, independent, model, residual_type: typing.Optional[str] = None):
245+
def plot_continuous_band(
246+
ax, independent, model, residual_type: Optional[Literal["absolute", "relative"]] = None
247+
):
245248
"""Helper function for plotting the 68, 90 and 95 % likelihood-bands of a univariate distribution.
246249
247250
Parameters
@@ -364,9 +367,9 @@ def plot_model(
364367
*,
365368
fig: Optional[matplotlib.figure.Figure] = None,
366369
axs: Optional[Sequence[matplotlib.axes.Axes]] = None,
367-
residual_type="absolute",
370+
residual_type: Literal["absolute", "relative"] = "absolute",
368371
band_xlim: Tuple[Optional[float], Optional[float]] = (None, None),
369-
):
372+
) -> Tuple[matplotlib.figure.Figure, List[matplotlib.axes.Axes]]:
370373
"""Makes a plot of the model with its data.
371374
372375
Parameters
@@ -416,6 +419,8 @@ def plot_model(
416419
axs.append(fig.add_subplot(gs1[0, 1], sharey=axs[0]))
417420
pyplot.setp(axs[1].get_yticklabels(), visible=False)
418421
axs.append(fig.add_subplot(gs2[0, 2]))
422+
else:
423+
axs = list(axs)
419424

420425
# ======= Left =======
421426
# Untransformed, outer range

0 commit comments

Comments
 (0)