-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
141 lines (106 loc) · 3.84 KB
/
model.py
File metadata and controls
141 lines (106 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# ============================================================
# AdminDoc-X — CLEAN NER Pipeline for API Prediction
# LayoutLMv3 (fine-tuned) + OCR + clean token reconstruction
# ============================================================
import os
import json
import cv2
import torch
import numpy as np
from PIL import Image
import pytesseract
import re
from transformers import (
LayoutLMv3ForTokenClassification,
LayoutLMv3ImageProcessor,
LayoutLMv3TokenizerFast,
LayoutLMv3Processor,
)
# ============================================================
# CONFIG
# ============================================================
# Windows Tesseract binary location; pytesseract shells out to this executable.
pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
# Directory containing the fine-tuned LayoutLMv3 checkpoint (config + weights + tokenizer files).
MODEL_PATH = "models/layoutlmv3_trained"
# Prefer GPU when available; all tensors below are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", device)
# ============================================================
# LABELS (auto-loaded from fine-tuned model)
# ============================================================
# Tokenizer is loaded from the fine-tuned checkpoint so its vocab matches training.
tokenizer = LayoutLMv3TokenizerFast.from_pretrained(MODEL_PATH)
# apply_ocr=True: the processor runs Tesseract itself to get words + boxes from the image.
image_processor = LayoutLMv3ImageProcessor(apply_ocr=True)
processor = LayoutLMv3Processor(image_processor, tokenizer)
model = LayoutLMv3ForTokenClassification.from_pretrained(MODEL_PATH).to(device)
# Mapping from class index -> BIO label string, taken from the trained model config.
id2label = model.config.id2label
# ============================================================
# CLEAN TOKEN FUNCTION
# ============================================================
# Special tokens from both BERT-style and RoBERTa-style vocabularies.
# LayoutLMv3 uses a RoBERTa tokenizer, so <s>, </s>, <pad> actually appear;
# the original code only filtered the BERT-style [PAD]/[CLS]/[SEP] names.
_SPECIAL_TOKENS = {
    "[PAD]", "[CLS]", "[SEP]", "[UNK]", "[MASK]",
    "<pad>", "<s>", "</s>", "<unk>", "<mask>",
}


def clean_token(token):
    """
    Normalize a subword token for display.

    Strips WordPiece ("##") and BPE ("Ġ") prefixes and surrounding
    whitespace, and drops tokenizer special tokens.

    Args:
        token: Raw token string from the tokenizer (may be None).

    Returns:
        The cleaned token string, or None if the token is a special
        token or empty after cleaning.
    """
    if token is None or token in _SPECIAL_TOKENS:
        return None
    # Strip subword markers as *prefixes* only — replace() would also
    # delete a literal "##" occurring inside the token text.
    if token.startswith("##"):  # BERT WordPiece continuation marker
        token = token[2:]
    if token.startswith("Ġ"):  # RoBERTa BPE word-start marker
        token = token[1:]
    token = token.strip()
    if not token:
        return None
    return token
# ============================================================
# IMAGE PREPROCESSING
# ============================================================
def preprocess_image(img_path):
    """
    Preprocess a document image to improve OCR quality.

    Loads the image in grayscale, binarizes it with Gaussian adaptive
    thresholding, upscales it 2x, and writes the result to a temporary
    PNG file.

    Args:
        img_path: Path to the input image file.

    Returns:
        Path to the preprocessed temporary image ("temp_clean.png";
        overwritten on every call).

    Raises:
        FileNotFoundError: If the image cannot be read.
    """
    img = cv2.imread(img_path, 0)  # 0 = load as grayscale
    if img is None:
        # cv2.imread silently returns None on a missing/corrupt file;
        # fail fast instead of crashing inside adaptiveThreshold.
        raise FileNotFoundError(f"Cannot read image: {img_path}")
    # Adaptive thresholding handles unevenly lit scans better than a
    # single global threshold.
    img = cv2.adaptiveThreshold(
        img, 255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        11, 2,
    )
    # 2x cubic upscale helps Tesseract with small font sizes.
    img = cv2.resize(img, None, fx=2, fy=2, interpolation=cv2.INTER_CUBIC)
    temp_path = "temp_clean.png"
    cv2.imwrite(temp_path, img)
    return temp_path
# ============================================================
# ENTITY EXTRACTION
# ============================================================
def extract_entities(img_path):
    """
    Run OCR + LayoutLMv3 token classification on a document image and
    decode the BIO-tagged predictions into a {label: text} dict.

    Args:
        img_path: Path to the input document image.

    Returns:
        dict mapping entity label (without the B-/I- prefix) to the
        space-joined token text. Always contains "DOC_TYPE" (default
        "autre"). A repeated entity label overwrites the earlier one.
    """
    img_clean = preprocess_image(img_path)
    image = Image.open(img_clean).convert("RGB")

    # apply_ocr=True on the processor: Tesseract extracts words + boxes here.
    encoding = processor(
        image,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
    )
    encoding = {k: v.to(device) for k, v in encoding.items()}

    with torch.no_grad():
        outputs = model(**encoding)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)[0].cpu().tolist()
    tokens = tokenizer.convert_ids_to_tokens(encoding["input_ids"][0])

    entities = {}
    current_label = None
    buffer = []

    def _flush():
        # Persist the entity currently being built, if any, and reset.
        # NOTE: a repeated label overwrites the earlier occurrence
        # (original behavior, kept intentionally).
        nonlocal current_label, buffer
        if current_label and buffer:
            entities[current_label] = " ".join(buffer)
        current_label = None
        buffer = []

    # =============== Extract entities with BIO logic ===============
    for token, pred in zip(tokens, predictions):
        token = clean_token(token)
        if token is None:
            continue

        label = id2label[pred]
        if label.startswith("B-"):
            _flush()  # close any previous entity
            current_label = label[2:]  # strip "B-"
            buffer = [token]
        elif label.startswith("I-") and current_label == label[2:]:
            buffer.append(token)
        else:
            # "O" or a mismatched I- tag ends the current entity.
            # BUGFIX: previously the entity stayed open here, so later
            # stray I- tokens of the same type were appended across gaps.
            _flush()

    _flush()  # save the last open entity

    # Default document type; classification is not implemented here.
    entities["DOC_TYPE"] = "autre"
    return entities