Commit a36c37d

Feat: Add new unifined postprocess class to handle DeePTB model. (#301)
* new_post_v0
* new post v0.2
* feat: add support for built-in polynomial checkpoints
  Enhance build_model to handle 'poly2' and 'poly4' as special checkpoint names, automatically resolving them to their corresponding base model files in the dftb directory. This simplifies model initialization for common polynomial baselines.
* feat: add GUI detection and refactor calculator interface
  Add is_gui_available() function to detect matplotlib GUI display availability in various environments including Jupyter notebooks and different operating systems. Refactor HamiltonianCalculator from Protocol to ABC with proper abstract methods for better type safety and inheritance. Update DeePTBAdapter to inherit from the new abstract base class.
* feat: Implement unified post-processing for band structure and DOS, including new examples and tests.
* feat: add get_hk method to HamiltonianCalculator interface
  Add new abstract method get_hk() to calculate H(k) and S(k) matrices at specified k-points. Rename get_hamiltonian_blocks() to get_hr() for clarity. Implement get_hk() in DeePTBAdapter with support for custom k-point injection and proper tensor handling. Remove redundant get_hamiltonian_at_k() method from BandAccessor. Add data property to TBSystem for atomic data access.
* feat: add eigenstates calculation and DOS analysis capabilities
  - Add abstract get_eigenstates method to HamiltonianCalculator interface
  - Implement eigenstates calculation in DeePTBAdapter using Eigh solver
  - Add DosData class for structured DOS results storage
  - Add DOS plotting functionality with matplotlib integration
  - Refactor solver initialization to support both Eigenvalues and Eigh solvers
  - Remove unused logging imports and update error handling
* refactor: improve DOS calculation interface and configuration
  Rename set_dos to set_dos_config for clarity in DOS configuration setup. Add reuse parameter to get_dos and get_bands methods to control recalculation. Improve error messages and add dos_data property for easier access to computed DOS results. Refactor DOS initialization to separate kpoints setup from configuration.
* refactor(test): replace mock tests with integration tests using real data
  Replace mock-based unit tests with integration tests that use actual silicon example data. Add pytest fixtures for efficient system initialization, test band structure and DOS calculations with real model, and include plotting functionality tests. This provides more realistic validation of the unified postprocessing module.
* feat: add unified post-processing tutorial notebook
  Add a comprehensive Jupyter notebook demonstrating the new TBSystem class for unified post-processing in DeePTB. The tutorial covers initialization, band structure calculation, DOS computation, and visualization using the centralized interface that manages both atomic structure and Hamiltonian model.
* feat: ignore band structure and DOS plot files in .gitignore
* fix: ensure dtype consistency in tensor product operations
  Fix dtype mismatches in tensor product calculations by explicitly setting dtype to match input tensors. Updates J_full_small initialization and _Jd tensor conversion in batch_wigner_D function, and enforces double precision in test cases to prevent numerical precision issues.
* feat: add band_data initialization and simplify tensor type check
  - Initialize _band_data attribute in BandAccessor.__init__ for future use
  - Remove NestedTensor type check in test_postprocess_unified.py to simplify validation
  - Ensure consistent tensor type handling in Hamiltonian calculations
* feat: add PythTB export tutorial notebook
  Add comprehensive tutorial demonstrating how to export DeePTB models to PythTB format. The notebook covers loading pre-trained models, calculating band structures with both DeePTB and PythTB, and comparing results to ensure compatibility. This enables users to leverage external tools for post-processing trained tight-binding models.
* feat: implement Fermi level calculation and integration
  Add proper Fermi level calculation functionality to replace hardcoded values. The band structure and DOS properties now use the system's Fermi level when available, falling back to 0.0 when not set. Includes a new utility function for calculating Fermi energy from eigenvalues using various smearing methods and temperature parameters. Also improves DOS plotting with inward tick direction.
* feat: add export functionality for TBSystem to third-party formats
  Add ExportAccessor class to
* feat: support dict input for export interfaces
  Add support for dictionary input in ToWannier90 and ToPythTB export classes. Skip data loading when input is already a dictionary to avoid redundant processing. Update ExportAccessor to use explicit parameter names for better clarity.
* feat: add PythTB-Wannier postprocessing example and rename existing notebook
  - Add new_postprocess_pythtb_wannier.ipynb demonstrating PythTB and Wannier90 integration
  - Rename new_postprocess.ipynb to new_postprocess_band_dos.ipynb for clarity
  - New example includes band structure calculation with k-path configuration and Fermi level determination
* feat: add Wannier90 integration and export functionality
  Add new cells to demonstrate Wannier90 integration with PythTB, including export functionality and model loading. Update execution counts and fix reference from tbsys.band_data to tbsys.band for consistency.
* test: update export tests to match implementation changes
  Updated test assertions in test_export_unified.py to align with the actual implementation of export methods. Fixed parameter names and added missing arguments like overlap, e_fermi, and filename parameters to ensure tests accurately validate the export functionality.
* refactor: simplify HR2HK onsite block construction and SOC handling
  Simplify the HR2HK class by removing redundant overlap-specific code paths and consolidating onsite block construction. The SOC handling is now restricted to Hamiltonian-only cases with clearer documentation about current limitations. This refactoring improves code readability and maintainability while preserving the same functionality.
* feat(nn): add gauge convention support to HR2HK module
  Add support for two gauge conventions in HR2HK:
  - Wannier90 Gauge (gauge=False) using cell shift vectors
  - Physical/Periodic Gauge (gauge=True) using edge vectors
  The gauge parameter controls phase factor calculation in k-space transformation. When derivative=True, automatically enables gauge=True mode. Added comprehensive test suite for both conventions.
* add new key for AtomicDataDict:
  HAMILTONIAN_DERIV_KEY = "hamiltonian_derivative" # dH(k)/dk
  OVERLAP_DERIV_KEY = "overlap_derivative" # dS(k)/dk
* feat(nn): add Hamiltonian k-derivative computation to HR2HK
  Add support for computing dH/dk derivatives in the HR2HK module. The implementation calculates the k-space derivatives of the Hamiltonian matrix using the analytical formula dH/dk = -i * R * H_R * exp(-i k·R), where R is the real-space hopping vector. The derivatives are computed for all three Cartesian directions and stored in the output data dictionary. This feature is controlled by the derivative parameter and uses the new out_derivative_field parameter to specify the output key. The gauge parameter default is changed to False to maintain consistency with the Wannier90 convention.
* feat: add derivative support to get_hk method in DeePTBAdapter
  Enhance the get_hk method to support Hamiltonian and overlap matrix derivatives by adding a new with_derivative parameter. This allows computing derivatives of H(k) and S(k) with respect to k-points, which is useful for band structure analysis and property calculations. The implementation uses the HR2HK module for both regular and derivative computations, maintaining consistency with the existing codebase structure.
* refactor: restructure derivative handling in DeePTBAdapter
  Restructured the derivative handling logic in the DeePTBAdapter class to improve code clarity and maintainability. The changes include moving derivative-related operations to a dedicated section and using .get() for safer access to optional derivative keys. This refactoring maintains the same functionality while making the code flow more logical and reducing potential errors when overlap derivatives are not present.
* correct v0
* correct v0
* jit
* v1
* feat: remove commented derivative code in HR2HK
  Clean up HR2HK module by removing commented-out derivative computation code that was no longer needed. The active derivative computation logic remains intact and functional.
* refactor: rename optical to optical_conductivity and update class name
  Renamed optical.py to optical_conductivity.py for clarity and updated OpticalAccessor to ACAccessor. Modified TBSystem to use accond property for accessing optical conductivity functionality, improving code organization and naming consistency.
* feat: improve optical conductivity calculation with complex Lorentzian
  Enhance optical conductivity computation by implementing complex Lorentzian broadening to capture both real and imaginary parts. Update accumulation methods (loop and JIT) to use complex form 1/(E - ω + iη) for physical accuracy. Add comprehensive documentation for parameters and return values. Remove redundant vectorized method for cleaner code structure.
* feat: add spin degeneracy parameter to optical conductivity
  Add g_s parameter to control spin degeneracy factor in optical conductivity calculations. Replace hardcoded spin factor of 2.0 with configurable g_s parameter (default 2.0). Update calculation to use 2πg_s/volume factor. Change default method from 'vectorized' to 'loop' for consistency.
* test: add optical conductivity test for silicon model
  Add comprehensive unit test for optical conductivity calculations using a silicon tight-binding model. The test validates Fermi level calculation, optical conductivity computation with both JIT and loop methods, and consistency between different broadening functions (Lorentzian and Gaussian). Includes reference value comparisons and physical property validation.
* fix: improve broadening validation and example notebook
  - Fix inconsistent indentation in optical_conductivity.py ACAccessor
  - Add explicit validation for broadening type with clear error message
  - Update example notebook with non-orthogonal model demonstration
  - Reset execution counts and add missing import warnings
* update(example): refresh plot output in optical_cond.ipynb
  The notebook's execution counts and matplotlib plot outputs have been updated. This reflects re-running the code cells to generate the latest visualization, ensuring the saved notebook state is consistent with its code.
* test: relax tolerance for optical conductivity tests
  Adjusted assertion tolerances from 1e-5 to 1e-4 in optical conductivity unit tests to accommodate numerical precision differences while maintaining test validity.
* feat: add overlap override support for DeePTB calculator
  Enable override of overlap matrices in DeePTB models by adding override_overlap parameter to DeePTBAdapter and TBSystem. This allows users to provide custom overlap files even when the original model includes overlap support, improving flexibility for tight-binding calculations.
* refactor: rename TBSystems example notebooks for clarity
  - Rename new_postprocess_band_dos.ipynb to tbsys_band_dos.ipynb
  - Rename new_postprocess_pythtb_wannier.ipynb to tbsys_to_pythtb_wannier.ipynb
  - Update kpath configuration and output filename in band DOS example
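The commit message mentions a utility that computes the Fermi energy from eigenvalues with a smearing function and a temperature parameter. That utility is not part of the diff shown on this page, so the following is only a minimal sketch of the general idea, assuming Fermi-Dirac smearing, equally weighted k-points, and bisection on the chemical potential. The names `fermi_dirac` and `estimate_fermi_level` are hypothetical and are not DeePTB functions.

```python
import math


def fermi_dirac(energy, mu, kT):
    """Fermi-Dirac occupation, written so that exp() never overflows."""
    x = (energy - mu) / kT
    if x >= 0:
        e = math.exp(-x)
        return e / (1.0 + e)
    return 1.0 / (1.0 + math.exp(x))


def estimate_fermi_level(eigenvalues, n_electrons, kT=0.025,
                         spin_degeneracy=2.0, tol=1e-10):
    """Bisect for the chemical potential that yields n_electrons per cell.

    eigenvalues: list of per-k-point lists of band energies
                 (all k-points are assumed to carry equal weight).
    """
    n_k = len(eigenvalues)

    def electron_count(mu):
        total = sum(fermi_dirac(e, mu, kT) for bands in eigenvalues for e in bands)
        return spin_degeneracy * total / n_k

    # Bracket the solution well below the lowest and above the highest band.
    lo = min(min(bands) for bands in eigenvalues) - 50.0 * kT
    hi = max(max(bands) for bands in eigenvalues) + 50.0 * kT
    while hi - lo > tol:
        mid = 0.5 * (lo + hi)
        if electron_count(mid) < n_electrons:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)
```

The electron count is monotonic in the chemical potential, which is why plain bisection suffices. For a gapped system at low temperature the count is almost flat inside the gap, so the returned value is only meaningful up to the gap width.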
1 parent 7e55619 commit a36c37d

27 files changed, with 5529 additions and 79 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,7 @@
 debug_**
 test*.ipynb
+*_bands.png
+*_dos.png
 examples/**/*centres.xyz
 examples/**/*.win
 **/processed*/*

dptb/data/_keys.py

Lines changed: 3 additions & 0 deletions
@@ -32,8 +32,11 @@
 KPOINT_KEY = "kpoint"

 HAMILTONIAN_KEY = "hamiltonian"
+HAMILTONIAN_DERIV_KEY = "hamiltonian_derivative" # dH(k)/dk

 OVERLAP_KEY = "overlap"
+OVERLAP_DERIV_KEY = "overlap_derivative" # dS(k)/dk
+
 # [n_batch, 3] bool tensor
 PBC_KEY: Final[str] = "pbc"
 # [n_atom, 1] long tensor
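The two new keys name optional entries of the data dictionary, and the commit message notes that the adapter reads optional derivative keys with `.get()`. A minimal sketch of that access pattern, using a plain `dict` in place of the real data object; the helper `extract_derivatives` is hypothetical:

```python
HAMILTONIAN_DERIV_KEY = "hamiltonian_derivative"  # dH(k)/dk
OVERLAP_DERIV_KEY = "overlap_derivative"          # dS(k)/dk


def extract_derivatives(data):
    """Return (dH/dk, dS/dk); dS/dk is None when no overlap derivative is stored."""
    dhk = data[HAMILTONIAN_DERIV_KEY]   # expected once derivatives were requested
    dsk = data.get(OVERLAP_DERIV_KEY)   # optional, e.g. absent for orthogonal models
    return dhk, dsk
```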

dptb/nn/build.py

Lines changed: 5 additions & 1 deletion
@@ -5,7 +5,7 @@
 import torch
 from dptb.utils.tools import j_must_have, j_loader
 import copy
-
+import os
 log = logging.getLogger(__name__)

 def build_model(
@@ -43,6 +43,10 @@ def build_model(

     # load the model_options and common_options from checkpoint if not provided
     if not from_scratch:
+        if checkpoint in ['poly2', 'poly4']:
+            modelname = f'base_{checkpoint}.pth'
+            checkpoint = os.path.join(os.path.dirname(__file__), 'dftb', modelname)
+
         if checkpoint.split(".")[-1] == "json":
             ckptconfig = j_loader(checkpoint)
         else:
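The alias handling in this hunk can be read in isolation as a small path-resolution step. The sketch below restates it as a standalone function so the behaviour is easy to test. `resolve_checkpoint` and its `package_dir` argument are hypothetical names; in the commit the logic is inlined in `build_model` and uses `os.path.dirname(__file__)`.

```python
import os

# Names treated as built-in baselines in the hunk above.
BUILTIN_CHECKPOINTS = ("poly2", "poly4")


def resolve_checkpoint(checkpoint, package_dir):
    """Map a built-in alias to <package_dir>/dftb/base_<alias>.pth.

    Any other value is returned unchanged and treated as a user-supplied path.
    """
    if checkpoint in BUILTIN_CHECKPOINTS:
        return os.path.join(package_dir, "dftb", f"base_{checkpoint}.pth")
    return checkpoint
```

Because the alias is rewritten before the extension check, the later `checkpoint.split(".")[-1] == "json"` test sees a `.pth` path and takes the non-JSON branch.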

dptb/nn/hr2hk.py

Lines changed: 83 additions & 54 deletions
@@ -8,6 +8,8 @@


 class HR2HK(torch.nn.Module):
+    # this is actually a general FFT from real space hamiltonian/overlap to kspace hamiltonian/overlap
+    # the more correct name should be HSR2HSK. But to keep consistent with previous naming convention, we still use HR2HK here.
     def __init__(
             self,
             basis: Dict[str, Union[str, list]]=None,
@@ -18,9 +20,18 @@ def __init__(
             overlap: bool = False,
             dtype: Union[str, torch.dtype] = torch.float32,
             device: Union[str, torch.device] = torch.device("cpu"),
+            derivative:bool = False,
+            out_derivative_field: str = AtomicDataDict.HAMILTONIAN_DERIV_KEY,
+            gauge: bool = False
             ):
+        # gauge: False -> Tight-binding Convention I: Wannier90 Gauge
+        # gauge: True -> Tight-binding Convention II: "Physical Gauge"/"Periodic Gauge"
         super(HR2HK, self).__init__()
-
+
+        if derivative:
+            gauge = True
+        self.gauge = gauge
+        self.derivative = derivative
         if isinstance(dtype, str):
             dtype = getattr(torch, dtype)
         self.dtype = dtype
@@ -44,15 +55,17 @@ def __init__(
         self.edge_field = edge_field
         self.node_field = node_field
         self.out_field = out_field
-
-
-
+        self.out_derivative_field = out_derivative_field

     def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:

         # construct bond wise hamiltonian block from obital pair wise node/edge features
         # we assume the edge feature have the similar format as the node feature, which is reduced from orbitals index oj-oi with j>i

+        # Ensure edge_vectors are computed if using gauge mode
+        if self.gauge:
+            data = AtomicDataDict.with_edge_vectors(data, with_lengths=True)
+
         orbpair_hopping = data[self.edge_field]
         orbpair_onsite = data.get(self.node_field)
         bondwise_hopping = torch.zeros((len(orbpair_hopping), self.idp.full_basis_norb, self.idp.full_basis_norb), dtype=self.dtype, device=self.device)
@@ -67,15 +80,12 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
         soc = data.get(AtomicDataDict.NODE_SOC_SWITCH_KEY, False)
         if isinstance(soc, torch.Tensor):
             soc = soc.all()
-        if soc:
-            # if self.overlap:
-            # print("Overlap for SOC is realized by kronecker product.")
-
+        if soc:
+            # this soc only support sktb.
             orbpair_soc = data[AtomicDataDict.NODE_SOC_KEY]
             soc_upup_block = torch.zeros((len(data[AtomicDataDict.ATOM_TYPE_KEY]), self.idp.full_basis_norb, self.idp.full_basis_norb), dtype=self.ctype, device=self.device)
             soc_updn_block = torch.zeros((len(data[AtomicDataDict.ATOM_TYPE_KEY]), self.idp.full_basis_norb, self.idp.full_basis_norb), dtype=self.ctype, device=self.device)

-
         ist = 0
         for i,iorb in enumerate(self.idp.full_basis):
             jst = 0
@@ -92,45 +102,53 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:

                 if i <= j:
                     bondwise_hopping[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_hopping[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)
+                    onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)

-
-                # constructing onsite blocks
-                if self.overlap:
-                    # if iorb == jorb:
-                    # onsite_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = factor * torch.eye(2*li+1, dtype=self.dtype, device=self.device).reshape(1, 2*li+1, 2*lj+1).repeat(onsite_block.shape[0], 1, 1)
-                    if i <= j:
-                        onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)
-
-                    if soc and i == j:
-                        soc_updn_tmp = orbpair_soc[:, self.idp.orbpair_soc_maps[orbpair]].reshape(-1, 2*li+1, 2*(2*lj+1))
-                        # j==i -> 2*lj+1 == 2*li+1
-                        soc_upup_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1, :2*lj+1]
-                        soc_updn_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1, 2*lj+1:]
-                else:
-                    if i <= j:
-                        onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)
-
-                    if soc and i==j:
+                    if soc and i==j and not self.overlap:
+                        # For now, The SOC part is only added to Hamiltonian, not overlap matrix.
+                        # For now, The SOC only has onsite contribution.
                         soc_updn_tmp = orbpair_soc[:,self.idp.orbpair_soc_maps[orbpair]].reshape(-1, 2*li+1, 2*(2*lj+1))
                         soc_upup_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1,:2*lj+1]
                         soc_updn_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1,2*lj+1:]
+
+                # constructing onsite blocks
+                #if self.overlap:
+                # # if iorb == jorb:
+                # # onsite_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = factor * torch.eye(2*li+1, dtype=self.dtype, device=self.device).reshape(1, 2*li+1, 2*lj+1).repeat(onsite_block.shape[0], 1, 1)
+                # if i <= j:
+                # onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)
+                # if soc and i == j:
+                # soc_updn_tmp = orbpair_soc[:, self.idp.orbpair_soc_maps[orbpair]].reshape(-1, 2*li+1, 2*(2*lj+1))
+                # # j==i -> 2*lj+1 == 2*li+1
+                # soc_upup_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1, :2*lj+1]
+                # soc_updn_block[:, ist:ist+2*li+1, jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1, 2*lj+1:]
+                #else:
+                # if i <= j:
+                # onsite_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = factor * orbpair_onsite[:,self.idp.orbpair_maps[orbpair]].reshape(-1, 2*li+1, 2*lj+1)
+                #
+                # if soc and i==j:
+                # soc_updn_tmp = orbpair_soc[:,self.idp.orbpair_soc_maps[orbpair]].reshape(-1, 2*li+1, 2*(2*lj+1))
+                # soc_upup_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1,:2*lj+1]
+                # soc_updn_block[:,ist:ist+2*li+1,jst:jst+2*lj+1] = soc_updn_tmp[:, :2*li+1,2*lj+1:]

                 jst += 2*lj+1
             ist += 2*li+1
         self.onsite_block = onsite_block
         self.bondwise_hopping = bondwise_hopping
-        if soc:
-            # save the existing ones first
+        if soc and not self.overlap:
+            # store for later use
+            # for now, soc only contribute to Hamiltonain, thus for overlap not store soc parts.
             self.soc_upup_block = soc_upup_block
             self.soc_updn_block = soc_updn_block

         # R2K procedure can be done for all kpoint at once.
         all_norb = self.idp.atom_norb[data[AtomicDataDict.ATOM_TYPE_KEY]].sum()
         block = torch.zeros(kpoints.shape[0], all_norb, all_norb, dtype=self.ctype, device=self.device)
-        # block = torch.complex(block, torch.zeros_like(block))
-        # if data[AtomicDataDict.NODE_SOC_SWITCH_KEY].all():
-        # block_uu = torch.zeros(data[AtomicDataDict.KPOINT_KEY].shape[0], all_norb, all_norb, dtype=self.ctype, device=self.device)
-        # block_ud = torch.zeros(data[AtomicDataDict.KPOINT_KEY].shape[0], all_norb, all_norb, dtype=self.ctype, device=self.device)
+
+        # Initialize derivative blocks if needed: dH/dk = [dH/dkx, dH/dky, dH/dkz]
+        if self.derivative:
+            dblock = torch.zeros(kpoints.shape[0], all_norb, all_norb, 3, dtype=self.ctype, device=self.device)
+
         atom_id_to_indices = {}
         ist = 0
         for i, oblock in enumerate(onsite_block):
@@ -139,21 +157,7 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
             block[:,ist:ist+masked_oblock.shape[0],ist:ist+masked_oblock.shape[1]] = masked_oblock.squeeze(0)
             atom_id_to_indices[i] = slice(ist, ist+masked_oblock.shape[0])
             ist += masked_oblock.shape[0]
-
-        # if data[AtomicDataDict.NODE_SOC_SWITCH_KEY].all():
-        # ist = 0
-        # for i, soc_block in enumerate(soc_upup_block):
-        # mask = self.idp.mask_to_basis[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[i]]
-        # masked_soc_block = soc_block[mask][:,mask]
-        # block_uu[:,ist:ist+masked_soc_block.shape[0],ist:ist+masked_soc_block.shape[1]] = masked_soc_block.squeeze(0)
-        # ist += masked_soc_block.shape[0]
-        # ist = 0
-        # for i, soc_block in enumerate(soc_updn_block):
-        # mask = self.idp.mask_to_basis[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[i]]
-        # masked_soc_block = soc_block[mask][:,mask]
-        # block_ud[:,ist:ist+masked_soc_block.shape[0],ist:ist+masked_soc_block.shape[1]] = masked_soc_block.squeeze(0)
-        # ist += masked_soc_block.shape[0]
-
+
         for i, hblock in enumerate(bondwise_hopping):
             iatom = data[AtomicDataDict.EDGE_INDEX_KEY][0][i]
             jatom = data[AtomicDataDict.EDGE_INDEX_KEY][1][i]
@@ -163,12 +167,37 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
             jmask = self.idp.mask_to_basis[data[AtomicDataDict.ATOM_TYPE_KEY].flatten()[jatom]]
             masked_hblock = hblock[imask][:,jmask]

-            block[:,iatom_indices,jatom_indices] += masked_hblock.squeeze(0).type_as(block) * \
-                torch.exp(-1j * 2 * torch.pi * (kpoints @ data[AtomicDataDict.EDGE_CELL_SHIFT_KEY][i])).reshape(-1,1,1)
+            if self.gauge:
+                # phase factor according to convention II
+                # k and R are in fractional coordinates, need to convert to cartesian
+                edge_vec = data[AtomicDataDict.EDGE_VECTORS_KEY][i] # Cartesian coordinates
+                phase_factor = torch.exp(-1j * 2 * torch.pi * (
+                    kpoints @ data[AtomicDataDict.CELL_KEY].inverse().T @ edge_vec)).reshape(-1,1,1)
+                # Compute derivative: dH/dk_alpha = -i * R_alpha * H_R * exp(-i k·R)
+                # where R is edge_vec in Cartesian coordinates
+                if self.derivative:
+                    # derivative_factor shape: [n_kpoints, 1, 1, 3]
+                    # - i * R * exp(-i k·R) = -i * R * phase_factor
+                    derivative_factor = (-1.0j * edge_vec).reshape(1, 1, 1, 3) * phase_factor.unsqueeze(-1)
+            else:
+                phase_factor = torch.exp(-1j * 2 * torch.pi * (
+                    kpoints @ data[AtomicDataDict.EDGE_CELL_SHIFT_KEY][i])).reshape(-1,1,1)
+
+            block[:,iatom_indices,jatom_indices] += masked_hblock.squeeze(0).type_as(block) * phase_factor
+
+            if self.derivative and self.gauge:
+                # Add derivative contribution
+                dblock[:,iatom_indices,jatom_indices,:] += masked_hblock.squeeze(0).type_as(dblock).unsqueeze(-1) * derivative_factor

         block = block + block.transpose(1,2).conj()
         block = block.contiguous()

+        # Hermitianize derivative blocks: dH/dk should also be Hermitian
+        if self.derivative:
+            for alpha in range(3):
+                dblock[:,:,:,alpha] = dblock[:,:,:,alpha] + dblock[:,:,:,alpha].transpose(1,2).conj()
+            dblock = dblock.contiguous()
+
         if soc:
             if self.overlap:
                 # ========== S_soc = S ⊗ I₂ : N×N S(k) to 2N×2N kronecker product ==========
@@ -182,10 +211,6 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
                 data[self.out_field] = S_soc
             else:
                 HK_SOC = torch.zeros(kpoints.shape[0], 2*all_norb, 2*all_norb, dtype=self.ctype, device=self.device)
-                #HK_SOC[:,:all_norb,:all_norb] = block + block_uu
-                #HK_SOC[:,:all_norb,all_norb:] = block_ud
-                #HK_SOC[:,all_norb:,:all_norb] = block_ud.conj()
-                #HK_SOC[:,all_norb:,all_norb:] = block + block_uu.conj()
                 ist = 0
                 assert len(soc_upup_block) == len(soc_updn_block)
                 for i in range(len(soc_upup_block)):
@@ -207,6 +232,10 @@ def forward(self, data: AtomicDataDict.Type) -> AtomicDataDict.Type:
                 data[self.out_field] = HK_SOC
         else:
             data[self.out_field] = block
+
+        # Store derivative if computed
+        if self.derivative:
+            data[self.out_derivative_field] = dblock

         return data
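The two phase conventions and the analytic k-derivative introduced in this file can be illustrated on a toy model. The sketch below is not the HR2HK implementation: it assumes a hypothetical one-dimensional, two-orbital chain (the constants `A`, `TAU`, `ONSITE`, `HOPPINGS` are invented for illustration) and uses plain Python complex numbers instead of torch tensors. It follows the same sign convention as the diff, exp(-i 2π k·R), forces the bond-vector gauge when a derivative is requested, and hermitianizes both H(k) and dH/dk.

```python
import cmath
import math

A = 2.0                  # lattice constant (illustrative)
TAU = [0.0, 0.5]         # fractional positions of the two orbitals
ONSITE = [0.3, -0.3]     # onsite energies
# (i, j, cell_shift, amplitude): bond from orbital i in the home cell
# to orbital j in the cell displaced by `cell_shift` lattice vectors.
HOPPINGS = [(0, 1, 0, -1.0), (0, 1, -1, -0.6)]


def hk(k_frac, gauge=False, derivative=False):
    """Return (H(k), dH/dk or None) for the toy chain at fractional k."""
    if derivative:
        gauge = True  # the derivative is defined in the bond-vector gauge
    h = [[0j, 0j], [0j, 0j]]
    dh = [[0j, 0j], [0j, 0j]]
    for i, j, shift, t in HOPPINGS:
        # gauge=False: phase from the cell shift only (Wannier90-style).
        # gauge=True: phase from the full bond vector, including the basis offset.
        r_frac = shift + TAU[j] - TAU[i] if gauge else float(shift)
        phase = cmath.exp(-2j * math.pi * k_frac * r_frac)
        h[i][j] += t * phase
        if derivative:
            # d/dk_cart of exp(-i k_cart * r_cart) is -i * r_cart * exp(...)
            dh[i][j] += -1j * (r_frac * A) * t * phase
    # Each bond is stored once; add the conjugate transpose to restore Hermiticity.
    h = [[h[a][b] + h[b][a].conjugate() for b in range(2)] for a in range(2)]
    dh = [[dh[a][b] + dh[b][a].conjugate() for b in range(2)] for a in range(2)]
    for a in range(2):
        h[a][a] += ONSITE[a]
    return h, (dh if derivative else None)


def eigenvalues_2x2(h):
    """Closed-form eigenvalues of a 2x2 Hermitian matrix."""
    mean = 0.5 * (h[0][0].real + h[1][1].real)
    half_diff = 0.5 * (h[0][0].real - h[1][1].real)
    radius = math.sqrt(half_diff ** 2 + abs(h[0][1]) ** 2)
    return (mean - radius, mean + radius)
```

In this toy model the two gauges differ only by a k-dependent unitary transformation, so the band energies agree, while the off-diagonal matrix elements (and therefore the derivative) do not. The derivative here is taken with respect to the Cartesian wave number k_cart = 2π k_frac / A, which is what the factor `-i * R` with a Cartesian bond vector corresponds to.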

dptb/nn/tensor_product.py

Lines changed: 2 additions & 2 deletions
@@ -74,10 +74,10 @@ def batch_wigner_D(l_max, alpha, beta, gamma, _Jd):
     D_total = sum(dims)

     # Construct block-diagonal J matrix
-    J_full_small = torch.zeros(D_total, D_total, device=device)
+    J_full_small = torch.zeros(D_total, D_total, device=device, dtype=alpha.dtype)
     for l in range(l_max + 1):
         start = offsets[l]
-        J_full_small[start:start+2*l+1, start:start+2*l+1] = _Jd[l]
+        J_full_small[start:start+2*l+1, start:start+2*l+1] = _Jd[l].to(dtype=alpha.dtype)

     J_full = J_full_small.unsqueeze(0).expand(N, -1, -1)
     angle_stack = torch.cat([alpha, beta, gamma], dim=0)

dptb/postprocess/common.py

Lines changed: 60 additions & 1 deletion
@@ -6,9 +6,10 @@
 from typing import Union, Optional
 from copy import deepcopy
 from ase.io import read
-
+import sys
 from dptb.data import AtomicData, AtomicDataDict, block_to_feature
 from dptb.utils.argcheck import get_cutoffs_from_model_options
+import matplotlib.pyplot as plt

 log = logging.getLogger(__name__)

@@ -110,3 +111,61 @@ def load_data_for_model(
     # Actually, ElecStruCal.get_data does NOT run self.model(data). It runs self.model.idp(data).
     # self.get_eigs runs self.model(data).
     return data_obj
+
+def is_gui_available():
+    """
+    Detect if GUI display is available for matplotlib.
+
+    Returns:
+        bool: True if GUI is available, False otherwise
+    """
+    try:
+        # Check if we're in a Jupyter notebook environment
+        if 'ipykernel' in sys.modules or 'IPython' in sys.modules:
+            # In Jupyter, we can typically show plots
+            return True
+
+        # Check DISPLAY environment variable (Unix-like systems)
+        if sys.platform.startswith('linux') or sys.platform.startswith('darwin'):
+            display = os.environ.get('DISPLAY')
+            if display is None:
+                return False
+
+        # Try to get the current matplotlib backend
+        backend = plt.get_backend().lower()
+
+        # Non-interactive backends
+        non_gui_backends = ['agg', 'pdf', 'ps', 'svg', 'cairo', 'gdk', 'template']
+        if any(non_gui in backend for non_gui in non_gui_backends):
+            return False
+
+        # Try to create a test figure to see if it works
+        # This is a more robust check
+        try:
+            import matplotlib
+            # Save current backend
+            current_backend = matplotlib.get_backend()
+
+            # Try to use a GUI backend if not already
+            if 'agg' in backend.lower():
+                # Try common GUI backends
+                for test_backend in ['TkAgg', 'Qt5Agg', 'Qt4Agg', 'WXAgg']:
+                    try:
+                        matplotlib.use(test_backend, force=True)
+                        test_fig = plt.figure()
+                        plt.close(test_fig)
+                        matplotlib.use(current_backend, force=True)
+                        return True
+                    except:
+                        continue
+                return False
+            else:
+                # Current backend seems to be GUI-based
+                return True
+
+        except Exception:
+            return False
+
+    except Exception:
+        # If any error occurs, assume no GUI is available
+        return False