Skip to content

Commit a025df6

Browse files
committed
Address aux_e2 for pbc and newer pyscf
1 parent a51cd86 commit a025df6

File tree

2 files changed

+49
-13
lines changed

2 files changed

+49
-13
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ dependencies = [
2727
"ordered-set",
2828
"libdmet @ git+https://github.com/gkclab/libdmet_preview.git",
2929
"chemcoord @ git+https://github.com/mcocdawc/chemcoord.git",
30+
"packaging",
3031
"pathos",
3132
]
3233

src/quemb/molbe/eri_sparse_DF.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,22 @@
1919

2020
import h5py
2121
import numpy as np
22+
import pyscf
2223
from chemcoord import Cartesian
2324
from numba import prange # type: ignore[attr-defined]
2425
from numba.typed import List
26+
from packaging.version import parse as parse_version
2527
from pyscf import dft, gto, scf
2628
from pyscf.ao2mo.addons import restore
2729
from pyscf.df.addons import make_auxmol
30+
from pyscf.df.incore import aux_e2
2831
from pyscf.gto import Mole
2932
from pyscf.gto.moleintor import getints
3033
from pyscf.pbc import dft as pbc_dft
34+
from pyscf.pbc.df.incore import aux_e2 as pbc_aux_e2
3135
from pyscf.pbc.df.incore import make_auxcell
3236
from pyscf.pbc.gto import Cell
33-
from pyscf.pbc.scf import KRHF
37+
from pyscf.pbc.scf.khf import KRHF
3438
from pyscf.pbc.tools import super_cell
3539
from scipy.linalg import cholesky
3640
from scipy.special import roots_hermite
@@ -82,7 +86,7 @@
8286

8387
def _aux_e2( # type: ignore[no-untyped-def]
8488
mol: Mole,
85-
auxmol_or_auxbasis: Mole | str,
89+
auxmol_or_auxbasis: gto.MoleBase | str,
8690
intor: str = "int3c2e",
8791
aosym: str = "s1",
8892
comp: int | None = None,
@@ -94,8 +98,8 @@ def _aux_e2( # type: ignore[no-untyped-def]
9498
9599
Fixes a bug in the original implementation :func:`pyscf.df.incore.aux_e2`
96100
that does not accept all valid slices.
97-
Replace with the original, as soon as https://github.com/pyscf/pyscf/pull/2734
98-
is merged in the stable release.
101+
This function has been fixed in pyscf >= 2.9.0 and is kept for backwards
102+
compatibility with older pyscf versions.
99103
"""
100104
if isinstance(auxmol_or_auxbasis, gto.MoleBase):
101105
auxmol = auxmol_or_auxbasis
@@ -123,6 +127,37 @@ def _aux_e2( # type: ignore[no-untyped-def]
123127
)
124128

125129

130+
def aux_e2_wrapper(
131+
mol: _T_chemsystem,
132+
auxmol_or_auxbasis: _T_chemsystem | gto.MoleBase | str,
133+
intor: str = "int3c2e",
134+
shls_slice: tuple[int, int, int, int, int, int] | list[int] | None = None,
135+
) -> Tensor3D[np.float64]:
136+
if isinstance(mol, Cell):
137+
return pbc_aux_e2(
138+
mol,
139+
auxmol_or_auxbasis,
140+
intor=intor,
141+
shls_slice=shls_slice,
142+
)
143+
elif parse_version(pyscf.__version__) < parse_version("2.9.0"):
144+
# use fixed version of aux_e2 for older pyscf versions
145+
return _aux_e2(
146+
mol,
147+
auxmol_or_auxbasis,
148+
intor=intor,
149+
shls_slice=shls_slice,
150+
)
151+
else:
152+
# from pyscf.df.incore
153+
return aux_e2(
154+
mol,
155+
auxmol_or_auxbasis,
156+
intor=intor,
157+
shls_slice=shls_slice,
158+
)
159+
160+
126161
_T_old_key = TypeVar("_T_old_key", bound=Hashable)
127162
_T_new_key = TypeVar("_T_new_key", bound=Hashable)
128163

@@ -244,7 +279,7 @@ def _get_AO_per_AO(
244279

245280

246281
def conversions_AO_shell(
247-
mol: Mole,
282+
mol: _T_chemsystem,
248283
) -> tuple[dict[ShellIdx, list[AOIdx]], dict[AOIdx, ShellIdx]]:
249284
"""Return dictionaries that for a shell index return the corresponding AO indices
250285
and for an AO index return the corresponding shell index.
@@ -411,7 +446,7 @@ def _traverse_reachable(
411446

412447

413448
def get_sparse_P_mu_nu(
414-
mol: Mole,
449+
mol: _T_chemsystem,
415450
auxmol: Mole,
416451
exch_reachable: Mapping[AOIdx, Sequence[AOIdx]],
417452
) -> SemiSparseSym3DTensor:
@@ -462,7 +497,7 @@ def to_shell_reachable_by_shell(
462497
for i_shell, reachable in shell_reachable_by_shell.items():
463498
for start_block, stop_block in get_blocks(reachable):
464499
integrals = np.asarray( # type: ignore[call-overload]
465-
_aux_e2(
500+
aux_e2_wrapper(
466501
mol,
467502
auxmol,
468503
intor="int3c2e",
@@ -608,12 +643,12 @@ def _run_sparse_df_driver(
608643
set_log_level(logging.getLogger().getEffectiveLevel())
609644

610645
is_periodic: Final[bool] = isinstance(mf, KRHF)
611-
mol: Final[_T_chemsystem] = mf.mol # KRHF also has mol as a Cell
612-
auxmol: Final[Mole] = (
613-
make_auxcell(mol, auxbasis=auxbasis)
614-
if is_periodic
615-
else make_auxmol(mol, auxbasis=auxbasis)
616-
)
646+
if is_periodic:
647+
mol = mf.cell
648+
auxmol = make_auxcell(mol, auxbasis=auxbasis)
649+
else:
650+
mol = mf.mol
651+
auxmol = make_auxmol(mol, auxbasis=auxbasis)
617652

618653
S_abs: Final[Matrix[np.floating]] = approx_S_abs(mol)
619654

0 commit comments

Comments
 (0)