Skip to content

Commit 5855b1a

Browse files
Merge pull request #330 from KernelTuner/fix-bfloat16-dtype
Replace bfloat16 dtype from `bfloat16` package by one from `ml_dtypes` package
2 parents c2f990e + 4cb3f05 commit 5855b1a

File tree

1 file changed

+33
-12
lines changed

1 file changed

+33
-12
lines changed

kernel_tuner/accuracy.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,22 +58,43 @@ def __call__(self, params):
5858

5959
def _find_bfloat16_if_available():
6060
# Try to get bfloat16 if available.
61-
try:
62-
from bfloat16 import bfloat16
63-
return bfloat16
64-
except ImportError:
65-
pass
61+
dtype = None
6662

63+
# get it via numpy if available
6764
try:
68-
from tensorflow import bfloat16
69-
return bfloat16.as_numpy_dtype
70-
except ImportError:
65+
dtype = np.dtype("bfloat16")
66+
except TypeError:
7167
pass
7268

73-
logging.warning(
74-
"could not find `bfloat16` data type for numpy, "
75-
+ "please install either the package `bfloat16` or `tensorflow`"
76-
)
69+
# otherwise, try ml_dtypes
70+
if dtype is None:
71+
try:
72+
from ml_dtypes import bfloat16
73+
dtype = bfloat16
74+
except ImportError:
75+
pass
76+
77+
# otherwise, try jax
78+
if dtype is None:
79+
try:
80+
from jax.numpy import bfloat16
81+
dtype = bfloat16
82+
except ImportError:
83+
pass
84+
85+
# otherwise, try tensorflow
86+
if dtype is None:
87+
try:
88+
from tensorflow import bfloat16
89+
dtype = bfloat16.as_numpy_dtype
90+
except ImportError:
91+
pass
92+
93+
if dtype is None:
94+
logging.warning(
95+
"could not find `bfloat16` data type for numpy, "
96+
+ "please install either the package `ml_dtypes`, `jax`, or `tensorflow`"
97+
)
7798

7899
return None
79100

0 commit comments

Comments
 (0)