-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathevaluate_ere.py
More file actions
24 lines (18 loc) · 773 Bytes
/
evaluate_ere.py
File metadata and controls
24 lines (18 loc) · 773 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from run_ere import ERERunner
import sys
import torch
def evaluate(config_name, gpu_id, saved_suffix, config_file="configs/ere.conf"):
    """Load a saved ERE model checkpoint and evaluate it on the test split.

    Args:
        config_name: Name of the configuration to use from ``config_file``
            (e.g. ``"t5_base"``).
        gpu_id: GPU id forwarded to ``ERERunner``.
        saved_suffix: Suffix identifying the saved checkpoint to load
            (e.g. ``"Aug14_19-53-06_85000"``).
        config_file: Path of the config file passed to ``ERERunner``.
            Defaults to ``"configs/ere.conf"``, the previously hard-coded path.

    Returns:
        Whatever ``ERERunner.evaluate`` returns for the test split. The
        original version discarded this value; returning it lets callers use
        the evaluation result programmatically.
    """
    runner = ERERunner(
        config_file=config_file,
        config_name=config_name,
        gpu_id=gpu_id,
    )
    # Load the saved weights for evaluation only (no training continuation);
    # the second value returned by initialize_model is not needed here.
    model, _ = runner.initialize_model(saved_suffix, continue_training=False)

    # Only the test split is evaluated; the train/dev tensors are ignored.
    # Explicit 3-way unpacking keeps the original's strict arity check.
    _, _, examples_test = runner.data.get_tensor_examples()
    stored_info = runner.data.get_stored_info()

    # NOTE(review): the meaning of the positional ``0`` (possibly a step or
    # epoch index) is not visible from this file — confirm against
    # ERERunner.evaluate's signature.
    return runner.evaluate(model, examples_test, stored_info, 0, predict=True)
def parse_args(argv=None):
    """Parse this script's positional command-line arguments.

    The CLI is unchanged from the original ``sys.argv`` indexing:
    ``config_name saved_suffix gpu_id``. On missing or malformed arguments
    argparse prints a usage message and exits with status 2 instead of
    crashing with an IndexError/ValueError traceback.

    Args:
        argv: Argument list to parse, excluding the program name. When
            ``None`` (the default), ``sys.argv[1:]`` is used.

    Returns:
        argparse.Namespace with attributes ``config_name`` (str),
        ``saved_suffix`` (str) and ``gpu_id`` (int).
    """
    # Local import keeps the script's top-level import block unchanged.
    import argparse

    parser = argparse.ArgumentParser(
        description="Evaluate a saved ERE model checkpoint on the test split.",
    )
    parser.add_argument(
        "config_name",
        help="configuration name to use from configs/ere.conf, e.g. t5_base",
    )
    parser.add_argument(
        "saved_suffix",
        help="suffix of the saved checkpoint to load, e.g. Aug14_19-53-06_85000",
    )
    parser.add_argument(
        "gpu_id",
        type=int,
        help="GPU id passed to ERERunner",
    )
    return parser.parse_args(argv)


# Example:
#   CUDA_VISIBLE_DEVICES=0 python evaluate_ere.py t5_base Aug14_19-53-06_85000 0
if __name__ == '__main__':
    args = parse_args()
    evaluate(args.config_name, args.gpu_id, args.saved_suffix)