11import os
2+ import math
23import numpy as np
4+ from natsort import natsorted
35from sklearn .metrics import mean_absolute_error
46
57from BondDistanceMatrix import bond_distance_matrix
68from MappingFunctions import element_atomID_dict , element_validation , get_atomIDs
79
8- # File search constants and user inputs
9- DATA_DIR = os .getcwd () + '/Test_Cases'
10- PRE_FILE_NAME = 'new_start_molecule.data'
11- POST_FILE_NAME = 'new_post_rx1_molecule.data'
12- ELEMENT_BY_TYPE = ['H' , 'H' , 'C' , 'C' , 'N' , 'O' , 'O' , 'O' ]
13- PRE_BONDING_ATOMS = ['28' , '62' ]
14- POST_BONDING_ATOMS = ['32' , '15' ]
15- POST_MAJOR_MOVED_ATOMS = ['33' ]
16- POST_MINOR_MOVED_ATOMS = ['16' ]
17- _WEIGHT_COEFF = 0.0
18-
19- # Get elements of atom IDs for pre and post molecules
20- preElements = element_atomID_dict (DATA_DIR , PRE_FILE_NAME , ELEMENT_BY_TYPE )
21- postElements = element_atomID_dict (DATA_DIR , POST_FILE_NAME , ELEMENT_BY_TYPE )
22-
23- # Get atomIDs - using existing function for now
24- preAtomIDs = get_atomIDs (DATA_DIR , PRE_FILE_NAME )
25- postAtomIDs = get_atomIDs (DATA_DIR , POST_FILE_NAME )
26-
27- # Calculate bond distance matrices for pre and post molecule
28- preBondDistMat = bond_distance_matrix (DATA_DIR , PRE_FILE_NAME , PRE_BONDING_ATOMS )
29- postBondDistMat = bond_distance_matrix (DATA_DIR , POST_FILE_NAME , POST_BONDING_ATOMS )
30-
31- # Set value for hydrogen that moves to epoxide ring to zero - this will be automated / user supplied info in the future
32- for atom in POST_MAJOR_MOVED_ATOMS :
33- atomIndex = postAtomIDs .index (atom )
34- for index , atomRow in enumerate (postBondDistMat ):
35- postBondDistMat [index ][atomIndex ] = 0.0
36-
37- # Sample weight matrix - lower weight for atoms with significant movement
38- sampleWeights = np .ones (len (postAtomIDs ))
39- for atom in POST_MINOR_MOVED_ATOMS :
40- atomIndex = postAtomIDs .index (atom )
41- sampleWeights [atomIndex ] = _WEIGHT_COEFF
42-
43- mappedIDList = []
44- for searchIndex , searchRow in enumerate (preBondDistMat ):
45- # Shortcircuit this search if search atom is a bonding atom
46- if preAtomIDs [searchIndex ] in PRE_BONDING_ATOMS :
47- bondingIndex = PRE_BONDING_ATOMS .index (preAtomIDs [searchIndex ])
48- bondingPostAtomID = POST_BONDING_ATOMS [bondingIndex ]
49- mappedIDList .append ([preAtomIDs [searchIndex ], bondingPostAtomID ])
50-
51- else :
52- # Sort search row arrays from smallest to largest
53- searchRowIndex = np .argsort (searchRow )
54- searchRowSorted = np .take_along_axis (searchRow , searchRowIndex , axis = 0 )
10+
11+ def atom_mapping (DATA_DIR , PRE_FILE_NAME , POST_FILE_NAME , ELEMENT_BY_TYPE , PRE_BONDING_ATOMS , POST_BONDING_ATOMS , POST_MAJOR_MOVED_ATOMS , POST_MINOR_MOVED_ATOMS ):
12+ _WEIGHT_COEFF = 0.0
13+ # Get elements of atom IDs for pre and post molecules
14+ preElements = element_atomID_dict (DATA_DIR , PRE_FILE_NAME , ELEMENT_BY_TYPE )
15+ postElements = element_atomID_dict (DATA_DIR , POST_FILE_NAME , ELEMENT_BY_TYPE )
16+
17+ # Get atomIDs
18+ preAtomIDs = get_atomIDs (DATA_DIR , PRE_FILE_NAME )
19+ postAtomIDs = get_atomIDs (DATA_DIR , POST_FILE_NAME )
20+
21+ # Calculate bond distance matrices for pre and post molecule
22+ preBondDistMat = bond_distance_matrix (DATA_DIR , PRE_FILE_NAME , PRE_BONDING_ATOMS , powerBonds = False )
23+ postBondDistMat = bond_distance_matrix (DATA_DIR , POST_FILE_NAME , POST_BONDING_ATOMS , powerBonds = False )
24+
25+ # Set value for hydrogen that moves to epoxide ring to zero - this will be automated / user supplied info in the future
26+ for atom in POST_MAJOR_MOVED_ATOMS :
27+ atomIndex = postAtomIDs .index (atom )
28+ for index , _ in enumerate (postBondDistMat ):
29+ postBondDistMat [index ][atomIndex ] = 0.0
30+
31+ # Sample weight matrix - lower weight for atoms with significant movement
32+ sampleWeights = np .ones (len (postAtomIDs ))
33+ for atom in POST_MINOR_MOVED_ATOMS :
34+ atomIndex = postAtomIDs .index (atom )
35+ sampleWeights [atomIndex ] = _WEIGHT_COEFF
36+
37+ mappedIDList = []
38+ mappedPostAtomsIndex = []
39+ for searchIndex , searchRow in enumerate (preBondDistMat ):
40+ # Shortcircuit this search if search atom is a bonding atom
41+ if preAtomIDs [searchIndex ] in PRE_BONDING_ATOMS :
42+ bondingIndex = PRE_BONDING_ATOMS .index (preAtomIDs [searchIndex ])
43+ bondingPostAtomID = POST_BONDING_ATOMS [bondingIndex ]
44+ mappedIDList .append ([preAtomIDs [searchIndex ], bondingPostAtomID ])
45+
46+ else :
47+ # Sort search row arrays from smallest to largest
48+ searchRowIndex = np .argsort (searchRow )
49+ searchRowSorted = np .take_along_axis (searchRow , searchRowIndex , axis = 0 )
50+
51+ distDifference = []
52+ for row in postBondDistMat :
53+ # Sort row arrays from smallest to largest
54+ rowIndex = np .argsort (row )
55+ rowSorted = np .take_along_axis (row , rowIndex , axis = 0 )
56+
57+ # Sort sample weight matrix the same as row
58+ sampleWeightsSorted = np .take_along_axis (sampleWeights , rowIndex , axis = 0 )
59+
60+ # MAE
61+ finalVal = mean_absolute_error (searchRowSorted , rowSorted , sample_weight = sampleWeightsSorted )
62+
63+ # Append - abs to get smallest value closest to zero
64+ distDifference .append (abs (finalVal ))
65+
66+ mappedPreAtomID , mappedPostAtomID , postAtomIDIndex = element_validation (preAtomIDs [searchIndex ], postAtomIDs , distDifference , preElements , postElements , POST_BONDING_ATOMS )
67+
68+ mappedIDList .append ([mappedPreAtomID , mappedPostAtomID ])
69+ mappedPostAtomsIndex .append (postAtomIDIndex )
70+
71+ # Ambiguous Atom Group Processing
72+ # Gather all the pairs
73+ mappedPostAtomIDs = [val [1 ] for val in mappedIDList ]
74+ repeatedPostIDSet = natsorted (set ([val for val in mappedPostAtomIDs if mappedPostAtomIDs .count (val ) > 1 ]))
75+ repeatedIndexes = [postAtomIDs .index (ID ) for ID in repeatedPostIDSet ]
76+
77+ ambiguousGroupPairs = []
78+ # Loop through all post atoms to find similar
79+ for index in repeatedIndexes :
80+ matchArray = postBondDistMat [index ]
5581
82+ # Sort search row arrays from smallest to largest
83+ searchRowIndex = np .argsort (matchArray )
84+ searchRowSorted = np .take_along_axis (matchArray , searchRowIndex , axis = 0 )
85+
5686 distDifference = []
5787 for row in postBondDistMat :
5888 # Sort row arrays from smallest to largest
6898 # Append - abs to get smallest value closest to zero
6999 distDifference .append (abs (finalVal ))
70100
71- mappedPreAtomID , mappedPostAtomID = element_validation (preAtomIDs [searchIndex ], postAtomIDs , distDifference , preElements , postElements )
72-
73- mappedIDList .append ([mappedPreAtomID , mappedPostAtomID ])
74-
101+ # Set repeatedIndex value to nan as it will always be zero
102+ distDifference [index ] = math .nan
103+ _ , smallestIndex = min ((val , idx ) for (idx , val ) in enumerate (distDifference ))
75104
105+ ambiguousGroupPairs .append ([postAtomIDs [index ], postAtomIDs [smallestIndex ]])
106+ # print(f'Atom {postAtomIDs[index]} is paired to atom {postAtomIDs[smallestIndex]}')
76107
77- # Print test report
78- for mappedPair in mappedIDList :
79- print (f'Atom { mappedPair [0 ]} is mapped to atom { mappedPair [1 ]} ' )
108+ # Update mappedIDList based on the ambiguousGroupPairs values
109+ # Interestingly, mappedIDList can be updated with the iterator, but ambiguousGroupPairs needs to be deleted with the index value
110+ for mappedID in mappedIDList :
111+ if mappedID [1 ] in repeatedPostIDSet : # If mappedPostAtomID is one that is repeated
112+ for index , groupPair in enumerate (ambiguousGroupPairs ):
113+ if groupPair [0 ] == mappedID [1 ]: # If groupPair is a matching PostAtomID
114+ mappedID [1 ] = groupPair [1 ]
115+ del ambiguousGroupPairs [index ]
116+ break
80117
81- correctPostAtomIDs = [['38' ], ['39' ], ['35' ], ['41' , '42' ], ['42' , '41' ], ['32' ], ['16' ], ['5' , '36' ], ['36' , '5' ], ['37' ], ['6' , '9' ], ['4' ], ['1' , '3' ], ['3' , '1' ], ['9' , '6' ], ['17' , '23' ], ['23' , '17' ], ['15' ], ['33' , '34' ], ['34' , '33' ]]
82- totalAtoms = len (correctPostAtomIDs )
83- correctAtoms = 0
84- incorrectPreAtomsList = []
85- for index , atom in enumerate (mappedIDList ):
86- if atom [1 ] in correctPostAtomIDs [index ]:
87- correctAtoms += 1
88- else :
89- incorrectPreAtomsList .append (atom [0 ])
118+ return mappedIDList
119+ # This needs to include a check to make sure it's not updating the atom to an ID already assigned - might be best to have an unassigned-postAtomIDList
90120
91- print (f'Test Results: Weight coeff is { _WEIGHT_COEFF } ' )
92- print (f'Correct atoms: { correctAtoms } . Accuracy: { round (correctAtoms / totalAtoms * 100 , 1 )} %' )
93- print (f'Incorrect premolecule atomIDs: { incorrectPreAtomsList } ' )
0 commit comments