Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 65 additions & 29 deletions src/qutip_jax/properties.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax
import jax.numpy as jnp
from .jaxarray import JaxArray
from .jaxdia import JaxDia, clean_dia
Expand Down Expand Up @@ -39,49 +40,84 @@ def _is_zero(vec, tol):
def _is_conj(vec1, vec2, tol):
return jnp.allclose(vec1, vec2.conj(), atol=tol, rtol=0)

@partial(jit, static_argnames=["shape", "tol"])
def _isherm_dia_jit(offsets, data, shape, tol):
num_rows, num_cols = shape
num_diags = offsets.shape[0]

# Return index of partner diagonal:
def find_partner_idx(target_offset):
mask = (offsets == -target_offset)
idx = jnp.argmax(mask)
return jnp.where(mask[idx], idx, -1)

def check_one_diagonal(i):
offset = offsets[i]
diag_data = data[i]

col_indices = jnp.arange(num_cols)
row_indices = col_indices - offset

mask = (row_indices >= 0) & (row_indices < num_rows)

partner_idx = find_partner_idx(offset)

def handle_no_partner():
return jnp.all(jnp.where(mask, jnp.abs(diag_data) <= tol, True))

def handle_with_partner():
partner_diag_data = data[partner_idx]
partner_aligned = jnp.take(partner_diag_data, row_indices, mode='clip')
diff = jnp.abs(diag_data - partner_aligned.conj())
return jnp.all(jnp.where(mask, diff <= tol, True))

return jnp.where(partner_idx == -1, handle_no_partner(), handle_with_partner())

results = jax.vmap(check_one_diagonal)(jnp.arange(num_diags))
return jnp.all(results)

def isherm_jaxdia(matrix, tol=None):
if matrix.shape[0] != matrix.shape[1]:
return False
tol = tol or qutip.settings.core["atol"]
done = []
for offset, data in zip(matrix.offsets, matrix.data):
if offset in done:
continue
start = max(0, offset)
end = min(matrix.shape[1], matrix.shape[0] + offset)
if -offset not in matrix.offsets:
if not _is_zero(data[start:end], tol):
return False
else:
idx = matrix.offsets.index(-offset)
done.append(-offset)
st = max(0, -offset)
et = min(matrix.shape[1], matrix.shape[0] - offset)
if not _is_conj(data[start:end], matrix.data[idx, st:et], tol):
return False
return True
if tol is None:
tol = qutip.settings.core["atol"]

offset_array = jnp.array(matrix.offsets)
return _isherm_dia_jit(offset_array, matrix.data, matrix.shape, tol)


@jit
def isdiag_jaxarray(matrix):
mat_abs = jnp.abs(matrix._jxa)
return jnp.trace(mat_abs) == jnp.sum(mat_abs)

@partial(jit, static_argnames=["shape"])
def _isdiag_dia_jit(offsets, data, shape):
num_rows, num_cols = shape
num_diags = offsets.shape[0]

def check_one_diagonal(i):
offset = offsets[i]
diag_data = data[i]

col_indices = jnp.arange(num_cols)
row_indices = col_indices - offset
mask = (row_indices >= 0) & (row_indices < num_rows)

all_zero = jnp.all(jnp.where(mask, diag_data == 0, True))
skip = (offset == 0)

return jnp.where(skip, True, all_zero)

results = jax.vmap(check_one_diagonal)(jnp.arange(num_diags))
return jnp.all(results)

def isdiag_jaxdia(matrix):
if matrix.num_diags == 0 or (
matrix.num_diags == 1 and matrix.offsets[0] == 0
):
if matrix.num_diags == 0 or (matrix.num_diags == 1 and matrix.offsets[0] == 0):
return True
for offset, data in zip(matrix.offsets, matrix.data):
if offset == 0:
continue
start = max(0, offset)
end = min(matrix.shape[1], matrix.shape[0] + offset)
if not jnp.all(data[start:end] == 0):
return False
return True

offset_array = jnp.array(matrix.offsets)
return _isdiag_dia_jit(offset_array, matrix.data, matrix.shape)


def iszero_jaxarray(matrix, tol=None):
Expand Down
Loading