-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
209 lines (170 loc) · 8.28 KB
/
main.py
File metadata and controls
209 lines (170 loc) · 8.28 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import argparse
import logging
import os
import sys
from dotenv import load_dotenv
from src.nli_checker import NLIChecker
from src.pmid_extractor import PMIDExtractor
from src.node_normalization import NodeNormalizationClient
# Load environment variables from a local .env file (NCBI/UMLS credentials,
# AVAILABLE_NLI_MODELS) so they are visible via os.environ below.
load_dotenv()
# Configure root logging once for the whole process: timestamped INFO-level
# records shared by this script and the src.* modules it imports.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger; inherits the handler installed by basicConfig above.
logger = logging.getLogger(__name__)
def main():
    """Check whether a set of PMIDs supports a biomedical research triple.

    Command-line entry point.  Parses and validates arguments, resolves the
    subject/object CURIEs to lists of equivalent names via node
    normalization, fetches abstracts for the requested PMIDs, and runs an
    NLI model over every (subject name, predicate, object name) hypothesis
    against each abstract.  Per-PMID results are printed to stdout.

    Exits with status 1 on missing NCBI credentials, invalid numeric
    parameters, or unresolvable CURIEs.
    """
    args = _build_parser().parse_args()

    # Fail fast on bad numeric flags BEFORE any network work.  (These checks
    # previously ran after CURIE resolution, wasting remote API calls when a
    # flag was out of range.)
    _validate_args(args)

    subject_curie, predicate, obj_curie = args.triple_curie
    pmids = args.pmids

    # NCBI credentials are required by the E-utilities abstract fetcher.
    ncbi_email = os.environ.get("NCBI_EMAIL")
    ncbi_api_key = os.environ.get("NCBI_API_KEY", "")
    if not ncbi_email:
        logger.error("NCBI_EMAIL not found in environment variables. Please set it in .env file.")
        sys.exit(1)

    # Initialize node normalization client honoring the --no_* flags.
    logger.info("Initializing node normalization...")
    node_normalizer = _build_node_normalizer(args)

    # Resolve each CURIE to equivalent human-readable names.
    logger.info(f"Resolving subject CURIE: {subject_curie}")
    subject_names = node_normalizer.get_equivalent_names(curie=subject_curie)
    logger.info(f"Resolving object CURIE: {obj_curie}")
    object_names = node_normalizer.get_equivalent_names(curie=obj_curie)

    if not subject_names:
        logger.error(f"Could not resolve subject CURIE: {subject_curie}")
        sys.exit(1)
    if not object_names:
        logger.error(f"Could not resolve object CURIE: {obj_curie}")
        sys.exit(1)

    # Cap names per entity so the hypothesis grid (subjects x objects)
    # stays tractable.
    subject_names_limited = subject_names[:args.max_names]
    object_names_limited = object_names[:args.max_names]
    logger.info(f"Resolved {len(subject_names)} subject names (using top {len(subject_names_limited)})")
    logger.info(f"Resolved {len(object_names)} object names (using top {len(object_names_limited)})")
    logger.info(f"Will test {len(subject_names_limited) * len(object_names_limited)} name combinations")
    logger.info(f"Batch size: {args.batch_size}")
    logger.info(f"Checking triple: ({subject_curie}, {predicate}, {obj_curie})")
    logger.info(f"Against PMIDs: {pmids}")

    # Initialize the abstract fetcher and the NLI model wrapper.
    extractor = PMIDExtractor(email=ncbi_email, api_key=ncbi_api_key)
    checker = NLIChecker(
        model_name=args.nli_model,
        use_gpu=args.use_gpu,
        gpu_id=args.gpu,
        threshold=args.threshold,
        batch_size=args.batch_size
    )

    # Fetch abstracts for all requested PMIDs in one call.
    logger.info("Fetching abstracts...")
    abstracts_data = extractor.extract_abstracts(pmids)

    results = []
    for pmid, data in abstracts_data.items():
        # Record (and skip) PMIDs that errored or have no abstract text.
        if data.error:
            logger.warning(f"Could not retrieve PMID {pmid}: {data.error}")
            results.append({
                "pmid": pmid,
                "supported": False,
                "reason": f"Error: {data.error}"
            })
            continue
        if not data.abstract:
            logger.warning(f"No abstract text for PMID {pmid}")
            results.append({
                "pmid": pmid,
                "supported": False,
                "reason": "No abstract text"
            })
            continue

        # Test every (subject name, predicate, object name) hypothesis
        # against this abstract; the checker reports the best match.
        logger.info(f"Checking PMID {pmid}...")
        check_result = checker.check_support(
            (subject_names_limited, predicate, object_names_limited),
            data.abstract
        )

        matched_hypothesis = check_result["hypothesis"]
        results.append({
            "pmid": pmid,
            "subject_curie": subject_curie,
            "predicate": predicate,
            "object_curie": obj_curie,
            "matched_hypothesis": matched_hypothesis,
            "supported": check_result["supported"],
            "confidence": check_result["confidence"],
            "evidence": check_result["evidence"],
            "threshold": check_result["threshold"]
        })

        # Human-readable per-PMID report on stdout.
        status = "SUPPORTED" if check_result["supported"] else "NOT SUPPORTED"
        print(f"\n=== Result for PMID {pmid} ===")
        print(f"Triple (CURIEs): ({subject_curie}, {predicate}, {obj_curie})")
        print(f"Best Match: {matched_hypothesis}")
        print(f"Status: {status} (Confidence: {check_result['confidence']:.4f}, Threshold: {check_result['threshold']:.2f})")
        if check_result["evidence"]:
            print(f"Evidence: {check_result['evidence']}")
        else:
            print("No supporting evidence found.")
        print("============================\n")


def _build_parser():
    """Construct the CLI argument parser (flags, defaults, and help text)."""
    parser = argparse.ArgumentParser(description="Check if PMIDs support a research triple using NLI.")
    # Triple arguments (CURIE-based)
    parser.add_argument('--triple_curie', nargs=3,
                        help="Triple CURIEs: Subject_CURIE Predicate Object_CURIE (e.g., NCBIGene:6495 affects UMLS:C0596290)",
                        required=True)
    # PMID arguments
    parser.add_argument('--pmids', nargs='+', help="List of PMIDs to check", required=True)
    # Model choices come from the environment so deployments can restrict them.
    available_models = os.environ.get("AVAILABLE_NLI_MODELS", "bart,deberta,roberta,sapbert").split(',')
    parser.add_argument('--nli_model',
                        help=f"NLI model: {', '.join(available_models)}",
                        default=available_models[0],
                        choices=available_models)
    # GPU Configuration
    parser.add_argument('--use_gpu', action='store_true', default=False,
                        help="Use GPU for inference (default: False, use CPU)")
    parser.add_argument('--gpu', type=str, default="0",
                        help="Which GPU to use, e.g., '0', '1' (default: '0')")
    # Threshold Configuration
    parser.add_argument('--threshold', type=float, default=0.5,
                        help="Confidence threshold for entailment (0.0-1.0, default: 0.5). Higher = stricter.")
    # Performance Configuration
    parser.add_argument('--max_names', type=int, default=5,
                        help="Maximum number of equivalent names to test per entity (default: 5). More = slower but more robust.")
    parser.add_argument('--batch_size', type=int, default=32,
                        help="Batch size for parallel processing (default: 32). Larger = faster but more GPU memory.")
    # API Configuration (enabled by default, use flags to disable)
    parser.add_argument('--no_umls', action='store_true', default=False,
                        help="Disable UMLS API integration (enabled by default). UMLS provides additional name synonyms.")
    parser.add_argument('--no_hgnc', action='store_true', default=False,
                        help="Disable HGNC API integration (enabled by default). HGNC provides gene name resolution.")
    return parser


def _validate_args(args):
    """Validate numeric CLI parameters; log and exit(1) on the first failure."""
    if not 0.0 <= args.threshold <= 1.0:
        logger.error(f"Threshold must be between 0.0 and 1.0, got: {args.threshold}")
        sys.exit(1)
    if args.max_names < 1:
        logger.error(f"max_names must be at least 1, got: {args.max_names}")
        sys.exit(1)
    if args.batch_size < 1:
        logger.error(f"batch_size must be at least 1, got: {args.batch_size}")
        sys.exit(1)


def _build_node_normalizer(args):
    """Create a NodeNormalizationClient honoring --no_umls/--no_hgnc.

    UMLS is only effectively enabled when an API key is present in the
    environment; the resulting configuration is logged so a run can be
    reconstructed from its log output.
    """
    umls_api_key = os.environ.get("UMLS_API_KEY", "")
    use_umls = not args.no_umls  # Enabled by default, disabled if --no_umls is specified
    use_hgnc = not args.no_hgnc  # Enabled by default, disabled if --no_hgnc is specified

    if use_umls:
        if umls_api_key:
            logger.info("UMLS integration: ENABLED (with API key)")
        else:
            logger.warning("UMLS integration: ENABLED but no API key found in .env (will be disabled)")
            use_umls = False
    else:
        logger.info("UMLS integration: DISABLED (via --no_umls)")

    if use_hgnc:
        logger.info("HGNC integration: ENABLED")
    else:
        logger.info("HGNC integration: DISABLED (via --no_hgnc)")

    return NodeNormalizationClient(
        use_umls=use_umls and bool(umls_api_key),
        use_hgnc=use_hgnc
    )
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()