Skip to content

Commit 64d5864

Browse files
committed
fix report_to
1 parent bd60890 commit 64d5864

File tree

6 files changed

+30
-18
lines changed

6 files changed

+30
-18
lines changed

.github/workflows/ruff.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ jobs:
77
- uses: actions/checkout@v4
88
- uses: astral-sh/ruff-action@v2
99
with:
10-
version: "0.8.4"
10+
version: "0.15.0"

src/autointent/_callbacks/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback, EmissionsTrackerCallback]}
1010

11-
REPORTERS_NAMES = Literal[tuple(REPORTERS.keys())] # type: ignore[valid-type]
11+
REPORTERS_NAMES = Literal[tuple(REPORTERS.keys()) + ("none",)] # type: ignore[valid-type]
1212

1313

1414
def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
@@ -25,6 +25,8 @@ def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
2525

2626
reporters_cb = []
2727
for reporter in reporters:
28+
if reporter == "none":
29+
continue
2830
if reporter not in REPORTERS:
2931
msg = f"Reporter {reporter} not supported. Supported reporters {','.join(REPORTERS)}"
3032
raise ValueError(msg)

src/autointent/_callbacks/tensorboard.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,11 @@ def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]
5353
"""
5454
module_run_name = f"{self.run_name}_{module_name}_{num}"
5555
log_dir = Path(self.dirpath) / module_run_name
56-
self.module_writer = self.writer(log_dir=log_dir) # type: ignore[no-untyped-call]
56+
self.module_writer = self.writer(log_dir=log_dir)
5757

58-
self.module_writer.add_text("module_info", f"Starting module {module_name}_{num}") # type: ignore[no-untyped-call]
58+
self.module_writer.add_text("module_info", f"Starting module {module_name}_{num}")
5959
for key, value in module_kwargs.items():
60-
self.module_writer.add_text(f"module_params/{key}", str(value)) # type: ignore[no-untyped-call]
60+
self.module_writer.add_text(f"module_params/{key}", str(value))
6161

6262
def log_value(self, **kwargs: dict[str, int | float | Any]) -> None:
6363
"""Logs scalar or text values.
@@ -69,7 +69,7 @@ def log_value(self, **kwargs: dict[str, int | float | Any]) -> None:
6969
if isinstance(value, int | float):
7070
self.module_writer.add_scalar(key, value)
7171
else:
72-
self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call]
72+
self.module_writer.add_text(key, str(value))
7373

7474
def log_metrics(self, metrics: dict[str, Any]) -> None:
7575
"""Logs training metrics.
@@ -79,9 +79,9 @@ def log_metrics(self, metrics: dict[str, Any]) -> None:
7979
"""
8080
for key, value in metrics.items():
8181
if isinstance(value, int | float):
82-
self.module_writer.add_scalar(key, value) # type: ignore[no-untyped-call]
82+
self.module_writer.add_scalar(key, value)
8383
else:
84-
self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call]
84+
self.module_writer.add_text(key, str(value))
8585

8686
def log_final_metrics(self, metrics: dict[str, Any]) -> None:
8787
"""Logs final metrics at the end of training.
@@ -97,13 +97,13 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None:
9797
raise RuntimeError(msg)
9898

9999
log_dir = Path(self.dirpath) / "final_metrics"
100-
self.module_writer = self.writer(log_dir=log_dir) # type: ignore[no-untyped-call]
100+
self.module_writer = self.writer(log_dir=log_dir)
101101

102102
for key, value in metrics.items():
103103
if isinstance(value, int | float):
104-
self.module_writer.add_scalar(key, value) # type: ignore[no-untyped-call]
104+
self.module_writer.add_scalar(key, value)
105105
else:
106-
self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call]
106+
self.module_writer.add_text(key, str(value))
107107

108108
def end_module(self) -> None:
109109
"""Ends the current module and closes the TensorBoard writer.
@@ -115,8 +115,8 @@ def end_module(self) -> None:
115115
msg = "start_run must be called before end_module."
116116
raise RuntimeError(msg)
117117

118-
self.module_writer.add_text("module_info", "Ending module") # type: ignore[no-untyped-call]
119-
self.module_writer.close() # type: ignore[no-untyped-call]
118+
self.module_writer.add_text("module_info", "Ending module")
119+
self.module_writer.close()
120120

121121
def end_run(self) -> None:
122122
"""Ends the current run. This method is currently a placeholder."""

src/autointent/configs/_optimization.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""Configuration for the optimization process."""
22

33
from pathlib import Path
4+
from typing import Literal
45

5-
from pydantic import BaseModel, ConfigDict, Field, PositiveInt
6+
from pydantic import BaseModel, ConfigDict, Field, PositiveInt, field_validator
67

78
from autointent._callbacks import REPORTERS_NAMES
89
from autointent.custom_types import FloatFromZeroToOne, SamplerType, ValidationScheme
@@ -57,7 +58,7 @@ class LoggingConfig(BaseModel):
5758
clear_ram: bool = Field(False, description="Whether to clear the RAM after dumping the modules")
5859
"""Whether to clear the RAM after dumping the modules"""
5960
report_to: list[REPORTERS_NAMES] | None = Field( # type: ignore[valid-type]
60-
None, description="List of callbacks to report to. If None, no callbacks will be used"
61+
['none'], description="List of callbacks to report to. If None, no callbacks will be used"
6162
)
6263
log_interval_time: float = Field(
6364
0.1, description="Sampling interval for the system monitor in seconds for Wandb logger."
@@ -88,6 +89,13 @@ def get_run_name(self) -> str:
8889
self.run_name = get_run_name()
8990
return self.run_name
9091

92+
@field_validator("report_to")
@classmethod
def validate_report_to(cls, value: list[REPORTERS_NAMES] | None) -> list[REPORTERS_NAMES]:
    """Normalize ``report_to`` so that it is never ``None``.

    This hook does not check reporter names itself; it only replaces ``None``
    with the explicit ``["none"]`` sentinel. Any other value is returned as is.

    Args:
        value: The ``report_to`` value handed over by pydantic, possibly ``None``.

    Returns:
        ``["none"]`` when ``value`` is ``None``, otherwise ``value`` unchanged.
    """
    if value is None:
        # transformers v5 doesn't allow None for report_to, so use the "none" sentinel instead.
        return ["none"]
    return value
98+
9199

92100
class HPOConfig(BaseModel):
93101
"""Configuration for hyperparameter optimization using Optuna.

src/autointent/modules/scoring/_catboost/catboost_scorer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,9 @@ def fit(
218218
**self.get_extra_params(),
219219
)
220220
self._model.fit(
221-
dataset, labels, early_stopping_rounds=self.early_stopping_rounds if self.val_fraction is not None else None
221+
dataset,
222+
list(labels), # datasets >4 would pass `Column` instead of list, which causes error in CatBoostClassifier
223+
early_stopping_rounds=self.early_stopping_rounds if self.val_fraction is not None else None,
222224
)
223225

224226
def predict(self, utterances: list[str]) -> npt.NDArray[np.float64]:

tests/pipeline/test_optimization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_cv(dataset, task_type):
9898
context = pipeline_optimizer.fit(dataset, refit_after=True)
9999
context.dump()
100100

101-
assert len(pipeline_optimizer.logging_config.dump_dir.iterdir()) > 0
101+
assert len(list(pipeline_optimizer.logging_config.dump_dir.iterdir())) > 0
102102

103103

104104
@pytest.mark.parametrize(
@@ -161,7 +161,7 @@ def test_dump_modules(dataset, task_type):
161161
context = pipeline_optimizer.fit(dataset)
162162
context.dump()
163163

164-
assert pipeline_optimizer.logging_config.dump_dir.iterdir() > 0
164+
assert len(list(pipeline_optimizer.logging_config.dump_dir.iterdir())) > 0
165165

166166

167167
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)