diff --git a/src/qutip_jax/properties.py b/src/qutip_jax/properties.py index 33dac84..d369e0c 100644 --- a/src/qutip_jax/properties.py +++ b/src/qutip_jax/properties.py @@ -1,3 +1,4 @@ +import jax import jax.numpy as jnp from .jaxarray import JaxArray from .jaxdia import JaxDia, clean_dia @@ -39,28 +40,50 @@ def _is_zero(vec, tol): def _is_conj(vec1, vec2, tol): return jnp.allclose(vec1, vec2.conj(), atol=tol, rtol=0) +@partial(jit, static_argnames=["shape", "tol"]) +def _isherm_dia_jit(offsets, data, shape, tol): + num_rows, num_cols = shape + num_diags = offsets.shape[0] + + # Return index of partner diagonal: + def find_partner_idx(target_offset): + mask = (offsets == -target_offset) + idx = jnp.argmax(mask) + return jnp.where(mask[idx], idx, -1) + + def check_one_diagonal(i): + offset = offsets[i] + diag_data = data[i] + + col_indices = jnp.arange(num_cols) + row_indices = col_indices - offset + + mask = (row_indices >= 0) & (row_indices < num_rows) + + partner_idx = find_partner_idx(offset) + + def handle_no_partner(): + return jnp.all(jnp.where(mask, jnp.abs(diag_data) <= tol, True)) + + def handle_with_partner(): + partner_diag_data = data[partner_idx] + partner_aligned = jnp.take(partner_diag_data, row_indices, mode='clip') + diff = jnp.abs(diag_data - partner_aligned.conj()) + return jnp.all(jnp.where(mask, diff <= tol, True)) + + return jnp.where(partner_idx == -1, handle_no_partner(), handle_with_partner()) + + results = jax.vmap(check_one_diagonal)(jnp.arange(num_diags)) + return jnp.all(results) def isherm_jaxdia(matrix, tol=None): if matrix.shape[0] != matrix.shape[1]: return False - tol = tol or qutip.settings.core["atol"] - done = [] - for offset, data in zip(matrix.offsets, matrix.data): - if offset in done: - continue - start = max(0, offset) - end = min(matrix.shape[1], matrix.shape[0] + offset) - if -offset not in matrix.offsets: - if not _is_zero(data[start:end], tol): - return False - else: - idx = matrix.offsets.index(-offset) - done.append(-offset) - st = max(0, -offset) - et = min(matrix.shape[1], matrix.shape[0] - offset) - if not _is_conj(data[start:end], matrix.data[idx, st:et], tol): - return False - return True + if tol is None: + tol = qutip.settings.core["atol"] + + offset_array = jnp.array(matrix.offsets) + return _isherm_dia_jit(offset_array, matrix.data, matrix.shape, tol) @jit @@ -68,20 +91,33 @@ def isdiag_jaxarray(matrix): mat_abs = jnp.abs(matrix._jxa) return jnp.trace(mat_abs) == jnp.sum(mat_abs) +@partial(jit, static_argnames=["shape"]) +def _isdiag_dia_jit(offsets, data, shape): + num_rows, num_cols = shape + num_diags = offsets.shape[0] + + def check_one_diagonal(i): + offset = offsets[i] + diag_data = data[i] + + col_indices = jnp.arange(num_cols) + row_indices = col_indices - offset + mask = (row_indices >= 0) & (row_indices < num_rows) + + all_zero = jnp.all(jnp.where(mask, diag_data == 0, True)) + skip = (offset == 0) + + return jnp.where(skip, True, all_zero) + + results = jax.vmap(check_one_diagonal)(jnp.arange(num_diags)) + return jnp.all(results) def isdiag_jaxdia(matrix): - if matrix.num_diags == 0 or ( - matrix.num_diags == 1 and matrix.offsets[0] == 0 - ): + if matrix.num_diags == 0 or (matrix.num_diags == 1 and matrix.offsets[0] == 0): return True - for offset, data in zip(matrix.offsets, matrix.data): - if offset == 0: - continue - start = max(0, offset) - end = min(matrix.shape[1], matrix.shape[0] + offset) - if not jnp.all(data[start:end] == 0): - return False - return True + + offset_array = jnp.array(matrix.offsets) + return _isdiag_dia_jit(offset_array, matrix.data, matrix.shape) def iszero_jaxarray(matrix, tol=None):