1+ import logging
12from typing import Optional
23
34from fedot .core .data .data import InputData , OutputData
45from fedot .core .operations .evaluation .evaluation_interfaces import EvaluationStrategy
5- from fedot .core .operations .evaluation .operation_implementations .models .tabpfn import \
6- FedotTabPFNClassificationImplementation , FedotTabPFNRegressionImplementation
76from fedot .core .operations .operation_parameters import OperationParameters
87from fedot .core .repository .tasks import TaskTypesEnum
98from fedot .utilities .random import ImplementationRandomStateHandler
109
10+ try :
11+ from fedot .core .operations .evaluation .operation_implementations .models .tabpfn import \
12+ FedotTabPFNClassificationImplementation , FedotTabPFNRegressionImplementation
13+ _TABPFN_AVAILABLE = True
14+ except ModuleNotFoundError :
15+ FedotTabPFNClassificationImplementation = None
16+ FedotTabPFNRegressionImplementation = None
17+ _ERROR_MESSAGE = (""
18+ "TabPFN is required but not installed. "
19+ "Install with `pip install fedot[extra]` or `pip install tabpfn`."
20+ ""
21+ )
22+ _TABPFN_AVAILABLE = False
23+ logging .log (100 , _ERROR_MESSAGE )
24+
1125
1226class TabPFNStrategy (EvaluationStrategy ):
1327 _operations_by_types = {
@@ -23,6 +37,9 @@ def __init__(self, operation_type: str, params: Optional[OperationParameters] =
2337 self .max_features = params .get ('max_features' , 500 ) if params else 500
2438
2539 def fit (self , train_data : InputData ):
40+ if not _TABPFN_AVAILABLE :
41+ raise ModuleNotFoundError (_ERROR_MESSAGE )
42+
2643 check_data_size (
2744 data = train_data ,
2845 device = self .device ,
@@ -48,6 +65,9 @@ def __init__(self, operation_type: str, params: Optional[OperationParameters] =
4865 super ().__init__ (operation_type , params )
4966
5067 def predict (self , trained_operation , predict_data : InputData ) -> OutputData :
68+ if not _TABPFN_AVAILABLE :
69+ raise ModuleNotFoundError (_ERROR_MESSAGE )
70+
5171 if self .output_mode == 'labels' :
5272 output = trained_operation .predict (predict_data )
5373 elif self .output_mode in ['probs' , 'full_probs' , 'default' ]:
@@ -69,6 +89,9 @@ def __init__(self, operation_type: str, params: Optional[OperationParameters] =
6989 super ().__init__ (operation_type , params )
7090
7191 def predict (self , trained_operation , predict_data : InputData ) -> OutputData :
92+ if not _TABPFN_AVAILABLE :
93+ raise ModuleNotFoundError (_ERROR_MESSAGE )
94+
7295 return trained_operation .predict (predict_data )
7396
7497
0 commit comments