-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
109 lines (88 loc) · 3.1 KB
/
utils.py
File metadata and controls
109 lines (88 loc) · 3.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Utility functions for model evaluation and metrics
"""
import torch
import torch.nn as nn
from typing import List, Dict
import numpy as np
import sys
from tqdm import tqdm
def compute_accuracy(model: nn.Module, data_loader, device: str = 'cuda') -> float:
    """
    Compute token-level accuracy over all non-padding positions.

    Args:
        model: Model to evaluate; forward is expected to return (logits, _)
        data_loader: Iterable of batches with 'input_ids' and 'labels' tensors
        device: Device to run evaluation on

    Returns:
        Fraction of non-masked tokens predicted correctly (0.0 if none seen)
    """
    model.eval()
    matched = 0
    counted = 0
    progress = tqdm(
        data_loader,
        desc="Computing accuracy",
        mininterval=0.1,
        maxinterval=1.0,
        file=sys.stderr,  # Write to stderr to avoid buffering issues
        dynamic_ncols=True,  # Auto-adjust to terminal width
        disable=False,  # Explicitly enable
    )
    with torch.no_grad():
        for batch in progress:
            inputs = batch['input_ids'].to(device)
            targets = batch['labels'].to(device)
            logits, _ = model(inputs)
            preds = logits.argmax(dim=-1)
            # Positions labelled -100 are padding and excluded from the score.
            valid = targets.ne(-100)
            matched += (preds.eq(targets) & valid).sum().item()
            counted += valid.sum().item()
    # Guard against an empty loader / fully-masked data.
    return matched / counted if counted > 0 else 0.0
def compute_metrics(model: nn.Module, data_loader, device: str = 'cuda') -> Dict[str, float]:
    """
    Compute evaluation metrics: mean token loss, accuracy, and perplexity.

    Args:
        model: Model to evaluate; forward is expected to return (logits, _)
        data_loader: Iterable of batches with 'input_ids' and 'labels' tensors
        device: Device to run evaluation on

    Returns:
        Dictionary with:
            'loss': mean cross-entropy per non-padding token
            'accuracy': token-level accuracy over non-padding tokens
            'perplexity': exp(mean loss); inf only when no valid tokens seen
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total_tokens = 0
    # Sum-reduction so we can normalize by the true count of valid tokens
    # across the whole dataset rather than per batch.
    criterion = nn.CrossEntropyLoss(ignore_index=-100, reduction='sum')
    with torch.no_grad():
        for batch in tqdm(
            data_loader,
            desc="Computing metrics",
            mininterval=0.1,
            maxinterval=1.0,
            file=sys.stderr,  # Write to stderr to avoid buffering issues
            dynamic_ncols=True,  # Auto-adjust to terminal width
            disable=False,  # Explicitly enable
        ):
            input_ids = batch['input_ids'].to(device)
            labels = batch['labels'].to(device)
            logits, _ = model(input_ids)
            # Flatten to (batch * seq, vocab) for the criterion.
            logits = logits.view(-1, logits.size(-1))
            labels_flat = labels.view(-1)
            # Loss (summed over valid tokens; -100 positions are ignored)
            total_loss += criterion(logits, labels_flat).item()
            # Accuracy over the same valid-token mask
            predictions = torch.argmax(logits, dim=-1)
            mask = (labels_flat != -100)
            correct += ((predictions == labels_flat) * mask).sum().item()
            total_tokens += mask.sum().item()
    if total_tokens > 0:
        avg_loss = total_loss / total_tokens
        accuracy = correct / total_tokens
        # BUG FIX: the old code returned inf whenever avg_loss == 0, but a
        # zero mean cross-entropy means perfect predictions, whose perplexity
        # is exp(0) == 1. inf is reserved for the no-valid-tokens case.
        perplexity = float(np.exp(avg_loss))
    else:
        avg_loss = 0.0
        accuracy = 0.0
        perplexity = float('inf')
    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'perplexity': perplexity,
    }