Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions pytorch_pfn_extras/onnx/_globals.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import pytorch_pfn_extras
import torch
from typing import Optional
from packaging import version


import torch.onnx._globals
GLOBALS = torch.onnx._globals.GLOBALS

torch_version = version.parse(torch.__version__.split("+")[0])
if version.parse("2.9.0") <= torch_version:
from torch.onnx._internal.torchscript_exporter import _globals
else:
from torch.onnx import _globals
GLOBALS = _globals.GLOBALS
10 changes: 8 additions & 2 deletions pytorch_pfn_extras/onnx/symbolic_registry.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import pytorch_pfn_extras
import torch
from typing import cast, Any, Callable, Tuple, Union

import torch.onnx._internal.registration as reg
import torch.onnx.utils
from packaging import version

torch_version = version.parse(torch.__version__.split("+")[0])
if version.parse("2.9.0") <= torch_version:
import torch.onnx._internal.torchscript_exporter.registration as reg
else:
import torch.onnx._internal.registration as reg

def is_registered_op(opname: str, domain: str, version: int) -> Any:
return reg.registry.is_registered_op(f"{domain}::{opname}", version)
Expand Down