|
11 | 11 |
|
12 | 12 | from fedotllm import prompts |
13 | 13 | from fedotllm.agents.automl.state import AutoMLAgentState |
14 | | -from fedotllm.agents.automl.structured import FedotConfig |
| 14 | +from fedotllm.agents.automl.structured import FedotConfig, RDKitConfig |
15 | 15 | from fedotllm.agents.automl.templates.load_template import ( |
16 | 16 | load_template, |
17 | 17 | render_template, |
|
32 | 32 | "predict_proba": "predict_proba(features=input_data)", |
33 | 33 | } |
34 | 34 |
|
# Maps a descriptor name to the Python expression, kept as source text, that
# computes it. The expressions reference the names ``mol`` and ``Descriptors``;
# presumably both are bound by the template the text is rendered into -- TODO
# confirm against the smiles_to_features template.
RDKIT_DESCRIPTORS_MAP = {
    name: "mol.GetNumAtoms()" if name == "NumAtoms" else f"Descriptors.{name}(mol)"
    for name in (
        # Mass, atom and electron counts
        "MolWt",
        "HeavyAtomMolWt",
        "HeavyAtomCount",
        "NumAtoms",  # the only entry computed via a method on ``mol`` itself
        "NumValenceElectrons",
        # Lipophilicity/Hydrophobicity
        "MolLogP",
        "MolMR",
        # Hydrogen Bonding
        "NumHDonors",
        "NumHAcceptors",
        # Topology and Connectivity
        "TPSA",
        "NumRotatableBonds",
        "RingCount",
        "NumAromaticRings",
        "NumAliphaticRings",
        "NumSaturatedRings",
        "NumHeteroatoms",
        "NumAmideBonds",
    )
}
| 60 | + |
35 | 61 |
|
36 | 62 | def init_state(state: AutoMLAgentState): |
37 | 63 | return Command( |
38 | 64 | update={ |
39 | 65 | "reflection": None, |
40 | 66 | "fedot_config": None, |
| 67 | + "rdkit_config": None, |
41 | 68 | "skeleton": None, |
42 | 69 | "raw_code": None, |
43 | 70 | "code": None, |
@@ -83,6 +110,20 @@ def generate_automl_config( |
83 | 110 |
|
84 | 111 | return Command(update={"fedot_config": fedot_config}) |
85 | 112 |
|
def generate_rdkit_config(
    state: AutoMLAgentState, inference: AIInference, dataset: Dataset
):
    """Request an RDKit configuration and record it in the agent state.

    Args:
        state: Agent state; only ``state["reflection"]`` is read here.
        inference: Client whose ``create`` method is called with the rendered
            prompt and ``response_model=RDKitConfig``.
        dataset: Not used by this function.

    Returns:
        A ``Command`` whose update sets ``"rdkit_config"`` to whatever
        ``inference.create`` returned.
    """
    logger.info("Running generate RDKit config")

    # Build the prompt first so the inference call below stays a one-liner.
    prompt = prompts.automl.generate_rdkit_configuration_prompt(
        reflection=state["reflection"],
    )
    generated_config = inference.create(prompt, response_model=RDKitConfig)

    return Command(update={"rdkit_config": generated_config})
| 126 | + |
86 | 127 |
|
87 | 128 | def select_skeleton( |
88 | 129 | state: AutoMLAgentState, app_config: AppConfig, dataset: Dataset, workspace: Path |
@@ -126,8 +167,13 @@ def insert_templates( |
126 | 167 | logger.info("Running insert templates") |
127 | 168 | code = state["raw_code"] |
128 | 169 | fedot_config = state["fedot_config"] |
| 170 | + rdkit_config = state["rdkit_config"] |
129 | 171 | predict_method = PREDICT_METHOD_MAP.get(fedot_config.predict_method) |
130 | 172 |
|
| 173 | + if rdkit_config is not None: |
| 174 | + rdkit_decriptor_lines = [f'"{item.value}": {RDKIT_DESCRIPTORS_MAP.get(item.value)}' for item in rdkit_config.descriptors] |
| 175 | + rdkit_decriptors_code = "\n,".join(rdkit_decriptor_lines) |
| 176 | + |
131 | 177 | predictor_init_kwargs = ( |
132 | 178 | { |
133 | 179 | "problem": str(fedot_config.problem), |
@@ -164,14 +210,26 @@ def insert_templates( |
164 | 210 | }, |
165 | 211 | } |
166 | 212 |
|
| 213 | + if rdkit_config is not None: |
| 214 | + smiles_to_features_params = {"descriptors": rdkit_decriptors_code} |
| 215 | + smiles_to_features_template = { |
| 216 | + app_config.automl.templates.smiles_to_features: {"params": smiles_to_features_params} |
| 217 | + } |
| 218 | + templates.update(smiles_to_features_template) |
| 219 | + |
167 | 220 | rendered_templates = [] |
168 | 221 | for template_name, fconfig in templates.items(): |
169 | 222 | template = load_template(template_name) |
170 | 223 | rendered = render_template(template=template, **fconfig["params"]) |
171 | 224 | rendered_templates.append(rendered) |
172 | 225 |
|
| 226 | + line_to_replace = "from automl import train_model, evaluate_model, automl_predict" |
| 227 | + |
| 228 | + if rdkit_config is not None: |
| 229 | + line_to_replace = "from automl import train_model, evaluate_model, automl_predict, smiles_to_features" |
| 230 | + |
173 | 231 | code = code.replace( |
174 | | - "from automl import train_model, evaluate_model, automl_predict", |
| 232 | + line_to_replace, |
175 | 233 | "\n".join(rendered_templates), |
176 | 234 | ) |
177 | 235 |
|
@@ -322,6 +380,7 @@ def test_submission_format(args: tuple) -> Observation: |
322 | 380 | msg=f"Submission file has wrong number of columns. Expected: {sample_df.shape[1]}, Got: {submission_df.shape[1]}", |
323 | 381 | ) |
324 | 382 |
|
| 383 | + |
325 | 384 | # LLM validation for deeper format checking |
326 | 385 | try: |
327 | 386 | submission_sample = submission_df.head(3).to_string( |
|
0 commit comments