Skip to content

Commit c10281a

Browse files
authored
Merge pull request #1 from m-bone/AmbiguousAtomControl
Ambiguous atom control
2 parents 58ec7d2 + eae56ae commit c10281a

File tree

5 files changed

+192
-75
lines changed

5 files changed

+192
-75
lines changed

AtomMapping.py

Lines changed: 93 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,88 @@
11
import os
2+
import math
23
import numpy as np
4+
from natsort import natsorted
35
from sklearn.metrics import mean_absolute_error
46

57
from BondDistanceMatrix import bond_distance_matrix
68
from MappingFunctions import element_atomID_dict, element_validation, get_atomIDs
79

8-
# File search constants and user inputs
9-
DATA_DIR = os.getcwd() + '/Test_Cases'
10-
PRE_FILE_NAME = 'new_start_molecule.data'
11-
POST_FILE_NAME = 'new_post_rx1_molecule.data'
12-
ELEMENT_BY_TYPE = ['H', 'H', 'C', 'C', 'N', 'O', 'O', 'O']
13-
PRE_BONDING_ATOMS = ['28', '62']
14-
POST_BONDING_ATOMS = ['32', '15']
15-
POST_MAJOR_MOVED_ATOMS = ['33']
16-
POST_MINOR_MOVED_ATOMS = ['16']
17-
_WEIGHT_COEFF = 0.0
18-
19-
# Get elements of atom IDs for pre and post molecules
20-
preElements = element_atomID_dict(DATA_DIR, PRE_FILE_NAME, ELEMENT_BY_TYPE)
21-
postElements = element_atomID_dict(DATA_DIR, POST_FILE_NAME, ELEMENT_BY_TYPE)
22-
23-
# Get atomIDs - using existing function for now
24-
preAtomIDs = get_atomIDs(DATA_DIR, PRE_FILE_NAME)
25-
postAtomIDs = get_atomIDs(DATA_DIR, POST_FILE_NAME)
26-
27-
# Calculate bond distance matrices for pre and post molecule
28-
preBondDistMat = bond_distance_matrix(DATA_DIR, PRE_FILE_NAME, PRE_BONDING_ATOMS)
29-
postBondDistMat = bond_distance_matrix(DATA_DIR, POST_FILE_NAME, POST_BONDING_ATOMS)
30-
31-
# Set value for hydrogen that moves to epoxide ring to zero - this will be automated / user supplied info in the future
32-
for atom in POST_MAJOR_MOVED_ATOMS:
33-
atomIndex = postAtomIDs.index(atom)
34-
for index, atomRow in enumerate(postBondDistMat):
35-
postBondDistMat[index][atomIndex] = 0.0
36-
37-
# Sample weight matrix - lower weight for atoms with significant movement
38-
sampleWeights = np.ones(len(postAtomIDs))
39-
for atom in POST_MINOR_MOVED_ATOMS:
40-
atomIndex = postAtomIDs.index(atom)
41-
sampleWeights[atomIndex] = _WEIGHT_COEFF
42-
43-
mappedIDList = []
44-
for searchIndex, searchRow in enumerate(preBondDistMat):
45-
# Shortcircuit this search if search atom is a bonding atom
46-
if preAtomIDs[searchIndex] in PRE_BONDING_ATOMS:
47-
bondingIndex = PRE_BONDING_ATOMS.index(preAtomIDs[searchIndex])
48-
bondingPostAtomID = POST_BONDING_ATOMS[bondingIndex]
49-
mappedIDList.append([preAtomIDs[searchIndex], bondingPostAtomID])
50-
51-
else:
52-
# Sort search row arrays from smallest to largest
53-
searchRowIndex = np.argsort(searchRow)
54-
searchRowSorted = np.take_along_axis(searchRow, searchRowIndex, axis=0)
10+
11+
def atom_mapping(DATA_DIR, PRE_FILE_NAME, POST_FILE_NAME, ELEMENT_BY_TYPE, PRE_BONDING_ATOMS, POST_BONDING_ATOMS, POST_MAJOR_MOVED_ATOMS, POST_MINOR_MOVED_ATOMS):
12+
_WEIGHT_COEFF = 0.0
13+
# Get elements of atom IDs for pre and post molecules
14+
preElements = element_atomID_dict(DATA_DIR, PRE_FILE_NAME, ELEMENT_BY_TYPE)
15+
postElements = element_atomID_dict(DATA_DIR, POST_FILE_NAME, ELEMENT_BY_TYPE)
16+
17+
# Get atomIDs
18+
preAtomIDs = get_atomIDs(DATA_DIR, PRE_FILE_NAME)
19+
postAtomIDs = get_atomIDs(DATA_DIR, POST_FILE_NAME)
20+
21+
# Calculate bond distance matrices for pre and post molecule
22+
preBondDistMat = bond_distance_matrix(DATA_DIR, PRE_FILE_NAME, PRE_BONDING_ATOMS, powerBonds=False)
23+
postBondDistMat = bond_distance_matrix(DATA_DIR, POST_FILE_NAME, POST_BONDING_ATOMS, powerBonds=False)
24+
25+
# Set value for hydrogen that moves to epoxide ring to zero - this will be automated / user supplied info in the future
26+
for atom in POST_MAJOR_MOVED_ATOMS:
27+
atomIndex = postAtomIDs.index(atom)
28+
for index, _ in enumerate(postBondDistMat):
29+
postBondDistMat[index][atomIndex] = 0.0
30+
31+
# Sample weight matrix - lower weight for atoms with significant movement
32+
sampleWeights = np.ones(len(postAtomIDs))
33+
for atom in POST_MINOR_MOVED_ATOMS:
34+
atomIndex = postAtomIDs.index(atom)
35+
sampleWeights[atomIndex] = _WEIGHT_COEFF
36+
37+
mappedIDList = []
38+
mappedPostAtomsIndex = []
39+
for searchIndex, searchRow in enumerate(preBondDistMat):
40+
# Shortcircuit this search if search atom is a bonding atom
41+
if preAtomIDs[searchIndex] in PRE_BONDING_ATOMS:
42+
bondingIndex = PRE_BONDING_ATOMS.index(preAtomIDs[searchIndex])
43+
bondingPostAtomID = POST_BONDING_ATOMS[bondingIndex]
44+
mappedIDList.append([preAtomIDs[searchIndex], bondingPostAtomID])
45+
46+
else:
47+
# Sort search row arrays from smallest to largest
48+
searchRowIndex = np.argsort(searchRow)
49+
searchRowSorted = np.take_along_axis(searchRow, searchRowIndex, axis=0)
50+
51+
distDifference = []
52+
for row in postBondDistMat:
53+
# Sort row arrays from smallest to largest
54+
rowIndex = np.argsort(row)
55+
rowSorted = np.take_along_axis(row, rowIndex, axis=0)
56+
57+
# Sort sample weight matrix the same as row
58+
sampleWeightsSorted = np.take_along_axis(sampleWeights, rowIndex, axis=0)
59+
60+
# MAE
61+
finalVal = mean_absolute_error(searchRowSorted, rowSorted, sample_weight=sampleWeightsSorted)
62+
63+
# Append - abs to get smallest value closest to zero
64+
distDifference.append(abs(finalVal))
65+
66+
mappedPreAtomID, mappedPostAtomID, postAtomIDIndex = element_validation(preAtomIDs[searchIndex], postAtomIDs, distDifference, preElements, postElements, POST_BONDING_ATOMS)
67+
68+
mappedIDList.append([mappedPreAtomID, mappedPostAtomID])
69+
mappedPostAtomsIndex.append(postAtomIDIndex)
70+
71+
# Ambiguous Atom Group Processing
72+
# Gather all the pairs
73+
mappedPostAtomIDs = [val[1] for val in mappedIDList]
74+
repeatedPostIDSet = natsorted(set([val for val in mappedPostAtomIDs if mappedPostAtomIDs.count(val) > 1]))
75+
repeatedIndexes = [postAtomIDs.index(ID) for ID in repeatedPostIDSet]
76+
77+
ambiguousGroupPairs = []
78+
# Loop through all post atoms to find similar
79+
for index in repeatedIndexes:
80+
matchArray = postBondDistMat[index]
5581

82+
# Sort search row arrays from smallest to largest
83+
searchRowIndex = np.argsort(matchArray)
84+
searchRowSorted = np.take_along_axis(matchArray, searchRowIndex, axis=0)
85+
5686
distDifference = []
5787
for row in postBondDistMat:
5888
# Sort row arrays from smallest to largest
@@ -68,26 +98,23 @@
6898
# Append - abs to get smallest value closest to zero
6999
distDifference.append(abs(finalVal))
70100

71-
mappedPreAtomID, mappedPostAtomID = element_validation(preAtomIDs[searchIndex], postAtomIDs, distDifference, preElements, postElements)
72-
73-
mappedIDList.append([mappedPreAtomID, mappedPostAtomID])
74-
101+
# Set repeatedIndex value to nan as it will always be zero
102+
distDifference[index] = math.nan
103+
_, smallestIndex = min((val, idx) for (idx, val) in enumerate(distDifference))
75104

105+
ambiguousGroupPairs.append([postAtomIDs[index], postAtomIDs[smallestIndex]])
106+
# print(f'Atom {postAtomIDs[index]} is paired to atom {postAtomIDs[smallestIndex]}')
76107

77-
# Print test report
78-
for mappedPair in mappedIDList:
79-
print(f'Atom {mappedPair[0]} is mapped to atom {mappedPair[1]}')
108+
# Update mappedIDList based on the ambiguousGroupPairs values
109+
# Interestingly, mappedIDList can be updated with the iterator, but ambiguousGroupPairs needs to be deleted with the index value
110+
for mappedID in mappedIDList:
111+
if mappedID[1] in repeatedPostIDSet: # If mappedPostAtomID is one that is repeated
112+
for index, groupPair in enumerate(ambiguousGroupPairs):
113+
if groupPair[0] == mappedID[1]: # If groupPair is a matching PostAtomID
114+
mappedID[1] = groupPair[1]
115+
del ambiguousGroupPairs[index]
116+
break
80117

81-
correctPostAtomIDs = [['38'], ['39'], ['35'], ['41', '42'], ['42', '41'], ['32'], ['16'], ['5', '36'], ['36', '5'], ['37'], ['6', '9'], ['4'], ['1', '3'], ['3', '1'], ['9', '6'], ['17', '23'], ['23', '17'], ['15'], ['33', '34'], ['34', '33']]
82-
totalAtoms = len(correctPostAtomIDs)
83-
correctAtoms = 0
84-
incorrectPreAtomsList = []
85-
for index, atom in enumerate(mappedIDList):
86-
if atom[1] in correctPostAtomIDs[index]:
87-
correctAtoms += 1
88-
else:
89-
incorrectPreAtomsList.append(atom[0])
118+
return mappedIDList
119+
# This needs to include a check to make sure it's not updating the atom to an ID already assigned - might be best to have an unassigned-postAtomIDList
90120

91-
print(f'Test Results: Weight coeff is {_WEIGHT_COEFF}')
92-
print(f'Correct atoms: {correctAtoms}. Accuracy: {round(correctAtoms / totalAtoms * 100, 1)}%')
93-
print(f'Incorrect premolecule atomIDs: {incorrectPreAtomsList}')

BondDistanceMatrix.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def get_bond_path(atomList, bonds):
5151

5252
return bondIDList
5353

54-
def calc_path_distance(bondList, bondDict):
54+
def calc_path_distance(bondList, bondDict, powers):
5555
# If bondList is empty return zero
5656
if len(bondList) == 0:
5757
return 0.0
@@ -60,6 +60,9 @@ def calc_path_distance(bondList, bondDict):
6060
for bondID in bondList:
6161
bondDistList.append(bondDict[bondID])
6262

63+
if powers:
64+
bondDistList = [bond ** (index + 1) for index, bond in enumerate(bondDistList)]
65+
6366
bondDistMultiple = reduce((lambda x, y: x * y), bondDistList)
6467
return bondDistMultiple
6568

@@ -139,7 +142,7 @@ def breadth_first_search(graph, start, target):
139142

140143
return path
141144

142-
def bond_distance_matrix(directory, fileName, bondingAtoms):
145+
def bond_distance_matrix(directory, fileName, bondingAtoms, powerBonds=False):
143146
os.chdir(directory)
144147

145148
# Load molecule file
@@ -180,7 +183,7 @@ def bond_distance_matrix(directory, fileName, bondingAtoms):
180183
for otherAtom in atomIDs:
181184
atomPath = breadth_first_search(moleculeGraph, startAtom, otherAtom)
182185
bondPath = get_bond_path(atomPath, bonds)
183-
pathDistance = calc_path_distance(bondPath, bondLengthDict)
186+
pathDistance = calc_path_distance(bondPath, bondLengthDict, powerBonds)
184187
atomBondDistanceList.append(pathDistance)
185188

186189
totalBondDistanceList.append(atomBondDistanceList)

DetailedTesting.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from AtomMapping import atom_mapping
2+
3+
class Reaction:
4+
def __init__(self, directory, preFileName, postFileName, elementByType, preBondingAtoms, postBondingAtoms, postMajorMovedAtoms, postMinorMovedAtoms):
5+
self.mappedIDList = atom_mapping(directory, preFileName, postFileName, elementByType, preBondingAtoms, postBondingAtoms, postMajorMovedAtoms, postMinorMovedAtoms)
6+
7+
def test_report(self, correctPostAtomIDs, reactionName):
8+
print(f'\n\nReaction: {reactionName}')
9+
# Print test report
10+
for mappedPair in self.mappedIDList:
11+
print(f'Atom {mappedPair[0]} is mapped to atom {mappedPair[1]}')
12+
13+
14+
totalAtoms = len(correctPostAtomIDs)
15+
correctAtoms = 0
16+
incorrectPreAtomsList = []
17+
for index, atom in enumerate(self.mappedIDList):
18+
if atom[1] in correctPostAtomIDs[index]:
19+
correctAtoms += 1
20+
else:
21+
incorrectPreAtomsList.append(atom[0])
22+
23+
mappedPostAtomsList = [val[1] for val in self.mappedIDList]
24+
repeatedPostIDs = [val for val in mappedPostAtomsList if mappedPostAtomsList.count(val) > 1]
25+
26+
print(f'Total atoms: {totalAtoms}. Correct atoms: {correctAtoms}. Accuracy: {round(correctAtoms / totalAtoms * 100, 1)}%')
27+
print(f'Incorrect premolecule atomIDs: {incorrectPreAtomsList}')
28+
print(f'Repeated Atoms: {repeatedPostIDs}, Count: {len(repeatedPostIDs)}')
29+
30+
# DGEBA-DETDA
31+
dgebaDetda = Reaction('/home/matt/Documents/Oct20-Dec20/Bonding_Test/DGEBA_DETDA/Reaction', 'new_start_molecule.data', 'new_post_rx1_molecule.data', ['H', 'H', 'C', 'C', 'N', 'O', 'O', 'O'],
32+
['28', '62'], ['32', '15'], ['33'], ['16'])
33+
correctDgebaDetda = [['38'], ['39'], ['35'], ['41', '42'], ['42', '41'], ['32'], ['16'], ['5', '36'], ['36', '5'], ['37'], ['6', '9'], ['4'], ['1', '3'], ['3', '1'], ['9', '6'], ['17', '23'], ['23', '17'], ['15'], ['33', '34'], ['34', '33']]
34+
dgebaDetda.test_report(correctDgebaDetda, 'DGEBA-DETDA')
35+
36+
# Ethyl Ethanoate
37+
ethylEthanoate = Reaction('/home/matt/Documents/Oct20-Dec20/Bonding_Test/Ethyl_Ethanoate/Reaction', 'pre-molecule.data', 'post-molecule.data', ['H', 'H', 'C', 'C', 'O', 'O', 'O', 'O'], ['6', '11'], ['7', '2'], [], [])
38+
correctEthylEthanoate = [['9'], ['8'], ['12', '13', '14'], ['13', '12', '14'], ['14', '12', '13'], ['7'], ['10', '11'], ['11', '10'], ['17', '16'], ['1'], ['2'], ['3', '4', '5'], ['4', '3', '5'], ['5', '3', '4'], ['15'], ['16', '17'], ['6']]
39+
ethylEthanoate.test_report(correctEthylEthanoate, 'Ethyl Ethanoate')
40+
41+
42+
# Nothing reasonable got given 13, too many 14 including some across the molecule boundary
43+
# 15 given 2 should be impossible for multiple reasons - 15 is O, 2 is C and 2 is a bonding atom
44+
45+
# Validation idea
46+
# Search for ambiguous groups in the post molecule by comparing post atom to all post atoms
47+
# I can find this easily and it can confirm if something should be an ambiguous group
48+
# Could cause issues if the BPDM manages to split two things that should be pairs - this check may find things that the BPDM doesn't
49+
# Tool would help explin why BPDM works in some cases but less in others
50+
# I can also predict how many ambiguous groups in my pre and post molecule with this method
51+
# This could check if I have as many as I expect and it may help identify atoms that have moved - can use ambiguous pairs as a useful tool
52+
# Can I use ambiguous pairs that don't exist before but do after and visa versa to identify moved atoms

MappingFunctions.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,24 @@ def element_atomID_dict(directory, fileName, elementsByType):
4040

4141
return elementIDDict
4242

43-
def element_validation(preAtomID, postAtomIDList, differenceList, preElementDict, postElementDict):
43+
def element_validation(preAtomID, postAtomIDList, differenceList, preElementDict, postElementDict, postBondingAtoms):
4444
# Make a copy of unchanged differenceList
4545
originalDifferenceList = differenceList.copy()
4646

47-
checkElement = 1
48-
4947
# Find lowest difference post atom ID that is the same element as the pre atom ID
48+
checkElement = 1
5049
while checkElement:
5150
# Find smallest value and corresponding index
5251
val, idx = min((val, idx) for (idx, val) in enumerate(differenceList))
5352
# Find the smallest value's index in the original list
5453
originalIndex = originalDifferenceList.index(val)
5554

56-
# If elements are the same return the pre and post atom IDs
57-
if preElementDict[preAtomID] == postElementDict[postAtomIDList[originalIndex]]:
58-
return preAtomID, postAtomIDList[originalIndex]
55+
if postAtomIDList[originalIndex] in postBondingAtoms:
56+
# If chosen ID is one of the bondingAtoms, it's wrong so can be removed
57+
del differenceList[idx]
58+
elif preElementDict[preAtomID] == postElementDict[postAtomIDList[originalIndex]]:
59+
# If elements are the same return the pre and post atom IDs
60+
return preAtomID, postAtomIDList[originalIndex], originalIndex
5961
else:
6062
# If the elements are different delete the smallest value by index and try again
6163
del differenceList[idx]

test_AtomMapping.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from AtomMapping import atom_mapping
2+
3+
def validation_function(mappedIDList, correctPostAtomIDs):
4+
# Calculate accuracy
5+
totalAtoms = len(correctPostAtomIDs)
6+
correctAtoms = 0
7+
incorrectPreAtomsList = []
8+
for index, atom in enumerate(mappedIDList):
9+
if atom[1] in correctPostAtomIDs[index]:
10+
correctAtoms += 1
11+
else:
12+
incorrectPreAtomsList.append(atom[0])
13+
14+
accuracy = round(correctAtoms / totalAtoms * 100, 1)
15+
16+
# Calculate multiple assignment atoms
17+
mappedPostAtomsList = [val[1] for val in mappedIDList]
18+
repeatedPostIDs = [val for val in mappedPostAtomsList if mappedPostAtomsList.count(val) > 1]
19+
countRepeatedPostIDs = len(repeatedPostIDs)
20+
21+
return accuracy, countRepeatedPostIDs
22+
23+
def test_dgeba_detda():
24+
mappedIDList = atom_mapping('/home/matt/Documents/Oct20-Dec20/Bonding_Test/DGEBA_DETDA/Reaction', 'new_start_molecule.data', 'new_post_rx1_molecule.data', ['H', 'H', 'C', 'C', 'N', 'O', 'O', 'O'],
25+
['28', '62'], ['32', '15'], ['33'], ['16'])
26+
correctPostAtomIDs = [['38'], ['39'], ['35'], ['41', '42'], ['42', '41'], ['32'], ['16'], ['5', '36'], ['36', '5'], ['37'], ['6', '9'], ['4'], ['1', '3'], ['3', '1'], ['9', '6'], ['17', '23'], ['23', '17'], ['15'], ['33', '34'], ['34', '33']]
27+
acc, repeatCount = validation_function(mappedIDList, correctPostAtomIDs)
28+
29+
# Check accuracy and number of repeated IDs are as expect
30+
checkValues = [acc, repeatCount]
31+
expected = [95, 2]
32+
33+
assert checkValues == expected

0 commit comments

Comments
 (0)