-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Open
Labels
Description
OpenVINO Version
2025.4.0-20398-7a975177ff4-releases/2025/4 (Python API)
Operating System
Ubuntu 20.04 (LTS)
Device used for inference
CPU
Framework
None
Model used
No response
Issue description
When a minimal graph of `uint8` arithmetic operations (Sub → Mul → Add) is compiled and run on CPU, its output differs significantly from PyTorch's.
- PyTorch (oracle): performs standard wrap-around (modular) arithmetic for `uint8`, e.g. an underflow wraps around to the top of the range.
- OpenVINO (CPU): the results appear to be saturated or incorrectly clamped. The output contains almost exclusively `0`s and `255`s, whereas the expected (PyTorch) output is spread across the whole `uint8` range.
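This suggests that the CPU plugin applies saturation (clamping) logic to this sequence of operations instead of handling integer overflow/underflow with the wrap-around semantics PyTorch uses (and that the exported ONNX model presumably expects). Saturation would map every negative intermediate result to 0 and every result above 255 to 255, which is exactly the 0/255 pattern seen in the OpenVINO output.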
Step-by-step reproduction
- Install the OpenVINO 2025.4.0 Python package.
- Save the reproduction script below as `repro_bug.py`.
- Run `python repro_bug.py`.
- Observe the failure: `Mismatched elements: 96 / 96 (100.0%)`.
import torch
import numpy as np
import openvino.runtime as ov
import onnx
import os
# ================= Configuration =================
# Where the exported model is written and which OpenVINO device compiles it.
ONNX_PATH = "bug_repro.onnx"
DEVICE = "CPU"
OPSET_VERSION = 14
INPUT_DTYPE = np.uint8

# uint8 outputs must agree bit-for-bit: zero relative and zero absolute tolerance.
RTOL = 0
ATOL = 0

# Seed both random generators the script draws from: NumPy produces the model
# inputs and PyTorch produces the constant buffer inside BugModel.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
# ================= Model Definition =================
class BugModel(torch.nn.Module):
    """Minimal three-op graph (Sub -> Mul -> Add) sensitive to integer overflow.

    ``forward`` computes ``p + p`` with ``p = v3_0 * (v0_0 - self.v6_0)``,
    staying in the inputs' dtype; for uint8 tensors PyTorch wraps each
    intermediate result around modulo 256.
    """

    def __init__(self):
        super().__init__()
        # Constant (16, 6) uint8 operand, registered as a buffer so it is
        # exported with the model. randint's upper bound is exclusive, so the
        # values lie in [0, 254].
        constant = torch.randint(0, 255, (16, 6), dtype=torch.uint8)
        self.register_buffer("v6_0", constant)

    def forward(self, v0_0, v3_0):
        # Every step below can underflow/overflow when the dtype is uint8.
        difference = v0_0 - self.v6_0
        product = v3_0 * difference
        return product + product
# ================= Reporting Function =================
def print_mismatch_report(actual, desired, rtol, atol, location_name="output"):
    """Print a numpy.testing-style report of where *actual* differs from *desired*.

    An element counts as mismatched when
    ``|actual - desired| > atol + rtol * |desired|``.

    Args:
        actual: Array-like result under test (in this script, the OpenVINO output).
        desired: Array-like reference result (in this script, the PyTorch output).
        rtol: Relative tolerance.
        atol: Absolute tolerance.
        location_name: Name of the compared tensor, shown in the report header.

    Returns:
        None; the report is printed to stdout.
    """
    actual = np.asarray(actual)
    desired = np.asarray(desired)

    # Report differing shapes explicitly instead of letting broadcasting either
    # raise or silently compare the wrong elements.
    if actual.shape != desired.shape:
        print(
            f"Shape mismatch at {location_name}: "
            f"actual {actual.shape} vs desired {desired.shape}"
        )
        return

    # Work in float64: integer arithmetic would wrap, and np.abs of a signed
    # minimum (e.g. int8 -128) overflows back to a negative value, which would
    # yield a negative tolerance and spurious mismatches.
    actual_f = actual.astype(np.float64)
    desired_f = desired.astype(np.float64)
    diff = np.abs(actual_f - desired_f)
    abs_desired = np.abs(desired_f)
    tol = atol + rtol * abs_desired

    mismatch_count = int(np.count_nonzero(diff > tol))
    if mismatch_count == 0:
        print(f"✅ Results match within tolerance (rtol={rtol}, atol={atol}).")
        return

    total_count = desired.size
    mismatch_percent = (mismatch_count / total_count) * 100

    # Like numpy.testing, take the relative difference only where the reference
    # is nonzero; dividing by zero would otherwise report "inf".
    nonzero = abs_desired != 0
    if np.any(nonzero):
        max_rel_diff = np.max(diff[nonzero] / abs_desired[nonzero])
    else:
        max_rel_diff = np.inf

    print(f"Not equal to tolerance rtol={rtol}, atol={atol}")
    print(f"openvino (cpu opt: True) != torch[cpu] eager at {location_name}")
    print(f"Mismatched elements: {mismatch_count} / {total_count} ({mismatch_percent:.1f}%)")
    print(f"Max absolute difference: {np.max(diff)}")
    print(f"Max relative difference: {max_rel_diff}")
    # Scoped print options so the caller's global NumPy settings stay untouched.
    with np.printoptions(threshold=1000, edgeitems=3, precision=8, suppress=True):
        print(f" x: {actual!r}")
        print(f" y: {desired!r}")
# ================= Main Process =================
def run_reproduction():
    """Compare PyTorch eager output with OpenVINO's output for ``BugModel``.

    Draws random inputs, evaluates the model in PyTorch (the reference),
    exports it to ONNX at ``ONNX_PATH``, compiles and runs the exported model
    with OpenVINO on ``DEVICE``, and prints a mismatch report. A failure in
    PyTorch or OpenVINO is printed and ends the run early. Once the export
    step has been reached, the ONNX file is removed on every exit path.
    """
    print(f"[Info] Dtype: {INPUT_DTYPE}")
    # 0-d first operand and 1-element second operand; both broadcast against
    # the (16, 6) buffer inside BugModel.
    # NOTE: randint's upper bound is exclusive, so 255 itself is never drawn.
    v0_0_np = np.random.randint(0, 255, size=()).astype(INPUT_DTYPE)
    v3_0_np = np.random.randint(0, 255, size=(1,)).astype(INPUT_DTYPE)
    v0_0_torch = torch.tensor(v0_0_np)
    v3_0_torch = torch.tensor(v3_0_np)

    model = BugModel()
    model.eval()
    try:
        with torch.no_grad():
            y_torch = model(v0_0_torch, v3_0_torch)
    except Exception as e:
        print(f"❌ PyTorch failed: {e}")
        return

    output_name = "v2_0"
    try:
        # Export ONNX
        torch.onnx.export(
            model,
            (v0_0_torch, v3_0_torch),
            ONNX_PATH,
            input_names=["v0_0", "v3_0"],
            output_names=[output_name],
            opset_version=OPSET_VERSION,
        )

        # Run OpenVINO
        try:
            core = ov.Core()
            ov_model = core.read_model(ONNX_PATH)
            compiled_model = core.compile_model(ov_model, DEVICE)
            infer_request = compiled_model.create_infer_request()
            y_ov = infer_request.infer({
                "v0_0": v0_0_np,
                "v3_0": v3_0_np,
            })[compiled_model.output(0)]
        except Exception as e:
            print(f"❌ OpenVINO failed: {e}")
            return

        # Compare
        print_mismatch_report(
            actual=y_ov,
            desired=y_torch.numpy(),
            rtol=RTOL,
            atol=ATOL,
            location_name=output_name,
        )
    finally:
        # Cleanup lives in `finally` so the exported file is also removed when
        # OpenVINO fails (early return) or an exception escapes export/report.
        if os.path.exists(ONNX_PATH):
            os.remove(ONNX_PATH)
if __name__ == "__main__":
    run_reproduction()
# Relevant log output
(fuzzer) xwoven@DESKTOP-75U6BRL:~/Rosetta-Fuzzer$ python /home/xwoven/1/test2.py
/home/xwoven/miniconda3/envs/fuzzer/lib/python3.10/site-packages/openvino/runtime/__init__.py:10: DeprecationWarning: The `openvino.runtime` module is deprecated and will be removed in the 2026.0 release. Please replace `openvino.runtime` with `openvino`.
warnings.warn(
[Info] Dtype: <class 'numpy.uint8'>
Not equal to tolerance rtol=0, atol=0
openvino (cpu opt: True) != torch[cpu] eager at v2_0
Mismatched elements: 96 / 96 (100.0%)
Max absolute difference: 246.0
Max relative difference: 17.214284896850586
x: array([[ 0, 0, 0, 0, 0, 0],
[255, 0, 255, 255, 255, 0],
[ 0, 255, 0, 255, 0, 255],
...
[255, 0, 255, 255, 255, 0],
[255, 255, 0, 255, 0, 255],
[ 0, 255, 0, 255, 255, 255]], dtype=uint8)
y: array([[ 12, 26, 134, 90, 112, 214],
[192, 76, 198, 128, 130, 60],
[244, 26, 130, 230, 220, 168],
...
[152, 132, 14, 80, 78, 156],
[236, 196, 84, 150, 220, 118],
[ 34, 216, 206, 38, 98, 196]], dtype=uint8)
Issue submission checklist
- I'm reporting an issue. It's not a question.
- I checked the problem with the documentation, FAQ, open issues, Stack Overflow, etc., and have not found a solution.
- There is reproducer code and related data files such as images, videos, models, etc.
Reactions are currently unavailable