Skip to content

Commit d94e42e

Browse files
bug: try-except wrappers with TabPFN imports (#1403)
* Added try-except wrappers to TabPFN imports * Automated autopep8 fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 6cff9ed commit d94e42e

File tree

2 files changed

+37
-11
lines changed

2 files changed

+37
-11
lines changed

fedot/core/operations/evaluation/tabpfn.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
1+
import logging
12
from typing import Optional
23

34
from fedot.core.data.data import InputData, OutputData
45
from fedot.core.operations.evaluation.evaluation_interfaces import EvaluationStrategy
5-
from fedot.core.operations.evaluation.operation_implementations.models.tabpfn import \
6-
FedotTabPFNClassificationImplementation, FedotTabPFNRegressionImplementation
76
from fedot.core.operations.operation_parameters import OperationParameters
87
from fedot.core.repository.tasks import TaskTypesEnum
98
from fedot.utilities.random import ImplementationRandomStateHandler
109

10+
try:
11+
from fedot.core.operations.evaluation.operation_implementations.models.tabpfn import \
12+
FedotTabPFNClassificationImplementation, FedotTabPFNRegressionImplementation
13+
_TABPFN_AVAILABLE = True
14+
except ModuleNotFoundError:
15+
FedotTabPFNClassificationImplementation = None
16+
FedotTabPFNRegressionImplementation = None
17+
_ERROR_MESSAGE = (""
18+
"TabPFN is required but not installed. "
19+
"Install with `pip install fedot[extra]` or `pip install tabpfn`."
20+
""
21+
)
22+
_TABPFN_AVAILABLE = False
23+
logging.log(100, _ERROR_MESSAGE)
24+
1125

1226
class TabPFNStrategy(EvaluationStrategy):
1327
_operations_by_types = {
@@ -23,6 +37,9 @@ def __init__(self, operation_type: str, params: Optional[OperationParameters] =
2337
self.max_features = params.get('max_features', 500) if params else 500
2438

2539
def fit(self, train_data: InputData):
40+
if not _TABPFN_AVAILABLE:
41+
raise ModuleNotFoundError(_ERROR_MESSAGE)
42+
2643
check_data_size(
2744
data=train_data,
2845
device=self.device,
@@ -48,6 +65,9 @@ def __init__(self, operation_type: str, params: Optional[OperationParameters] =
4865
super().__init__(operation_type, params)
4966

5067
def predict(self, trained_operation, predict_data: InputData) -> OutputData:
68+
if not _TABPFN_AVAILABLE:
69+
raise ModuleNotFoundError(_ERROR_MESSAGE)
70+
5171
if self.output_mode == 'labels':
5272
output = trained_operation.predict(predict_data)
5373
elif self.output_mode in ['probs', 'full_probs', 'default']:
@@ -69,6 +89,9 @@ def __init__(self, operation_type: str, params: Optional[OperationParameters] =
6989
super().__init__(operation_type, params)
7090

7191
def predict(self, trained_operation, predict_data: InputData) -> OutputData:
92+
if not _TABPFN_AVAILABLE:
93+
raise ModuleNotFoundError(_ERROR_MESSAGE)
94+
7295
return trained_operation.predict(predict_data)
7396

7497

test/integration/models/test_strategy.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,19 @@ def run_tabpfn(
8585
test_data: pd.DataFrame,
8686
task: str,
8787
):
88-
pipeline = PipelineBuilder().add_node(model_name).build()
89-
pipeline.fit(train_data)
90-
predicted_output = pipeline.predict(test_data, output_mode='labels')
91-
if task == 'classification':
92-
metric = roc_auc(test_data.target, predicted_output.predict)
93-
else:
94-
metric = r2_score(test_data.target, predicted_output.predict)
88+
try:
89+
pipeline = PipelineBuilder().add_node(model_name).build()
90+
pipeline.fit(train_data)
91+
predicted_output = pipeline.predict(test_data, output_mode='labels')
92+
if task == 'classification':
93+
metric = roc_auc(test_data.target, predicted_output.predict)
94+
else:
95+
metric = r2_score(test_data.target, predicted_output.predict)
9596

96-
assert isinstance(pipeline, Pipeline)
97-
assert metric > 0.5
97+
assert isinstance(pipeline, Pipeline)
98+
assert metric > 0.5
99+
except ModuleNotFoundError:
100+
pass
98101

99102

100103
def test_tabpfn_classification_operation():

0 commit comments

Comments
 (0)