1919
2020import h5py
2121import numpy as np
22+ import pyscf
2223from chemcoord import Cartesian
2324from numba import prange # type: ignore[attr-defined]
2425from numba .typed import List
26+ from packaging .version import parse as parse_version
2527from pyscf import dft , gto , scf
2628from pyscf .ao2mo .addons import restore
2729from pyscf .df .addons import make_auxmol
30+ from pyscf .df .incore import aux_e2
2831from pyscf .gto import Mole
2932from pyscf .gto .moleintor import getints
3033from pyscf .pbc import dft as pbc_dft
34+ from pyscf .pbc .df .incore import aux_e2 as pbc_aux_e2
3135from pyscf .pbc .df .incore import make_auxcell
3236from pyscf .pbc .gto import Cell
33- from pyscf .pbc .scf import KRHF
37+ from pyscf .pbc .scf . khf import KRHF
3438from pyscf .pbc .tools import super_cell
3539from scipy .linalg import cholesky
3640from scipy .special import roots_hermite
8286
8387def _aux_e2 ( # type: ignore[no-untyped-def]
8488 mol : Mole ,
85- auxmol_or_auxbasis : Mole | str ,
89+ auxmol_or_auxbasis : gto . MoleBase | str ,
8690 intor : str = "int3c2e" ,
8791 aosym : str = "s1" ,
8892 comp : int | None = None ,
@@ -94,8 +98,8 @@ def _aux_e2( # type: ignore[no-untyped-def]
9498
9599 Fixes a bug in the original implementation :func:`pyscf.df.incore.aux_e2`
96100 that does not accept all valid slices.
97- Replace with the original, as soon as https://github.com/pyscf/pyscf/pull/2734
98- is merged in the stable release .
101+ This function has been fixed in pyscf >= 2.9.0 and is kept for backwards
102+ compatibility with older pyscf versions .
99103 """
100104 if isinstance (auxmol_or_auxbasis , gto .MoleBase ):
101105 auxmol = auxmol_or_auxbasis
@@ -123,6 +127,37 @@ def _aux_e2( # type: ignore[no-untyped-def]
123127 )
124128
125129
130+ def aux_e2_wrapper (
131+ mol : _T_chemsystem ,
132+ auxmol_or_auxbasis : _T_chemsystem | gto .MoleBase | str ,
133+ intor : str = "int3c2e" ,
134+ shls_slice : tuple [int , int , int , int , int , int ] | list [int ] | None = None ,
135+ ) -> Tensor3D [np .float64 ]:
136+ if isinstance (mol , Cell ):
137+ return pbc_aux_e2 (
138+ mol ,
139+ auxmol_or_auxbasis ,
140+ intor = intor ,
141+ shls_slice = shls_slice ,
142+ )
143+ elif parse_version (pyscf .__version__ ) < parse_version ("2.9.0" ):
144+ # use fixed version of aux_e2 for older pyscf versions
145+ return _aux_e2 (
146+ mol ,
147+ auxmol_or_auxbasis ,
148+ intor = intor ,
149+ shls_slice = shls_slice ,
150+ )
151+ else :
152+ # from pyscf.df.incore
153+ return aux_e2 (
154+ mol ,
155+ auxmol_or_auxbasis ,
156+ intor = intor ,
157+ shls_slice = shls_slice ,
158+ )
159+
160+
126161_T_old_key = TypeVar ("_T_old_key" , bound = Hashable )
127162_T_new_key = TypeVar ("_T_new_key" , bound = Hashable )
128163
@@ -244,7 +279,7 @@ def _get_AO_per_AO(
244279
245280
246281def conversions_AO_shell (
247- mol : Mole ,
282+ mol : _T_chemsystem ,
248283) -> tuple [dict [ShellIdx , list [AOIdx ]], dict [AOIdx , ShellIdx ]]:
249284 """Return dictionaries that for a shell index return the corresponding AO indices
250285 and for an AO index return the corresponding shell index.
@@ -411,7 +446,7 @@ def _traverse_reachable(
411446
412447
413448def get_sparse_P_mu_nu (
414- mol : Mole ,
449+ mol : _T_chemsystem ,
415450 auxmol : Mole ,
416451 exch_reachable : Mapping [AOIdx , Sequence [AOIdx ]],
417452) -> SemiSparseSym3DTensor :
@@ -462,7 +497,7 @@ def to_shell_reachable_by_shell(
462497 for i_shell , reachable in shell_reachable_by_shell .items ():
463498 for start_block , stop_block in get_blocks (reachable ):
464499 integrals = np .asarray ( # type: ignore[call-overload]
465- _aux_e2 (
500+ aux_e2_wrapper (
466501 mol ,
467502 auxmol ,
468503 intor = "int3c2e" ,
@@ -608,12 +643,12 @@ def _run_sparse_df_driver(
608643 set_log_level (logging .getLogger ().getEffectiveLevel ())
609644
610645 is_periodic : Final [bool ] = isinstance (mf , KRHF )
611- mol : Final [ _T_chemsystem ] = mf . mol # KRHF also has mol as a Cell
612- auxmol : Final [ Mole ] = (
613- make_auxcell (mol , auxbasis = auxbasis )
614- if is_periodic
615- else make_auxmol ( mol , auxbasis = auxbasis )
616- )
646+ if is_periodic :
647+ mol = mf . cell
648+ auxmol = make_auxcell (mol , auxbasis = auxbasis )
649+ else :
650+ mol = mf . mol
651+ auxmol = make_auxmol ( mol , auxbasis = auxbasis )
617652
618653 S_abs : Final [Matrix [np .floating ]] = approx_S_abs (mol )
619654
0 commit comments