File tree Expand file tree Collapse file tree 1 file changed +33
-12
lines changed
Expand file tree Collapse file tree 1 file changed +33
-12
lines changed Original file line number Diff line number Diff line change @@ -58,22 +58,43 @@ def __call__(self, params):
5858
5959def _find_bfloat16_if_available ():
6060 # Try to get bfloat16 if available.
61- try :
62- from bfloat16 import bfloat16
63- return bfloat16
64- except ImportError :
65- pass
61+ dtype = None
6662
63+ # get it via numpy if available
6764 try :
68- from tensorflow import bfloat16
69- return bfloat16 .as_numpy_dtype
70- except ImportError :
65+ dtype = np .dtype ("bfloat16" )
66+ except TypeError :
7167 pass
7268
73- logging .warning (
74- "could not find `bfloat16` data type for numpy, "
75- + "please install either the package `bfloat16` or `tensorflow`"
76- )
69+ # otherwise, try ml_dtypes
70+ if dtype is None :
71+ try :
72+ from ml_dtypes import bfloat16
73+ dtype = bfloat16
74+ except ImportError :
75+ pass
76+
77+ # otherwise, try jax
78+ if dtype is None :
79+ try :
80+ from jax .numpy import bfloat16
81+ dtype = bfloat16
82+ except ImportError :
83+ pass
84+
85+ # otherwise, try tensorflow
86+ if dtype is None :
87+ try :
88+ from tensorflow import bfloat16
89+ dtype = bfloat16 .as_numpy_dtype
90+ except ImportError :
91+ pass
92+
93+ if dtype is None :
94+ logging .warning (
95+ "could not find `bfloat16` data type for numpy, "
96+ + "please install either the package `ml_dtypes`, `jax`, or `tensorflow`"
97+ )
7798
7899 return None
79100
You can’t perform that action at this time.
0 commit comments