Skip to content
Open
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ Enzyme = "0.13.118"
EnzymeTestUtils = "0.2.5"
JET = "0.9, 0.10"
LinearAlgebra = "1"
Mooncake = "0.4.183"
Mooncake = "0.4.195"
ParallelTestRunner = "2"
Random = "1"
SafeTestsets = "0.1"
Expand Down
9 changes: 8 additions & 1 deletion ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester, svd_rank
using AMDGPU
using LinearAlgebra
using LinearAlgebra: BlasFloat
Expand Down Expand Up @@ -171,4 +171,11 @@ end
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) =
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))

function _sylvester(A::AnyROCMatrix, B::AnyROCMatrix, C::AnyROCMatrix)
hX = sylvester(collect(A), collect(B), collect(C))
return ROCArray(hX)
end

svd_rank(S::AnyROCVector, rank_atol) = findlast(s -> s ≥ rank_atol, S)

end
11 changes: 10 additions & 1 deletion ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester, svd_rank
using CUDA, CUDA.CUBLAS
using CUDA: i32
using LinearAlgebra
Expand Down Expand Up @@ -195,4 +195,13 @@ end
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) =
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))

function _sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
# https://github.com/JuliaGPU/CUDA.jl/issues/3021
# to add native sylvester to CUDA
hX = sylvester(collect(A), collect(B), collect(C))
return CuArray(hX)
end

svd_rank(S::AnyCuVector, rank_atol) = findlast(s -> s ≥ rank_atol, S)

end
21 changes: 21 additions & 0 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ for eig in (:eig, :eigh)
eig_t! = Symbol(eig, "_trunc!")
eig_t_pb = Symbol(eig, "_trunc_pullback")
_make_eig_t_pb = Symbol("_make_", eig_t_pb)
eig_t_ne! = Symbol(eig, "_trunc_no_error!")
eig_t_ne_pb = Symbol(eig, "_trunc_no_error_pullback")
_make_eig_t_ne_pb = Symbol("_make_", eig_t_ne_pb)
eig_v = Symbol(eig, "_vals")
eig_v! = Symbol(eig_v, "!")
eig_v_pb = Symbol(eig_v, "_pullback")
Expand Down Expand Up @@ -136,6 +139,24 @@ for eig in (:eig, :eigh)
end
return $eig_t_pb
end
function ChainRulesCore.rrule(::typeof($eig_t_ne!), A, DV, alg::TruncatedAlgorithm)
Ac = copy_input($eig_f, A)
DV = $(eig_f!)(Ac, DV, alg.alg)
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
return DV′, $(_make_eig_t_ne_pb)(A, DV, ind)
end
function $(_make_eig_t_ne_pb)(A, DV, ind)
function $eig_t_ne_pb(ΔDV)
ΔA = zero(A)
ΔD, ΔV = ΔDV
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind)
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
function $eig_t_ne_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
end
return $eig_t_ne_pb
end
function ChainRulesCore.rrule(::typeof($eig_v!), A, D, alg)
DV = $eig_f(A, alg)
function $eig_v_pb(ΔD)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt
using Mooncake
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
using MatrixAlgebraKit
using MatrixAlgebraKit: inv_safe, diagview, copy_input
using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
Expand All @@ -18,14 +18,16 @@ Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.N
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA))
dAc = Mooncake.zero_tangent(Ac)
Ac_dAc = Mooncake.zero_fcodual(Ac)
dAc = Mooncake.tangent(Ac_dAc)
function copy_input_pb(::NoRData)
Mooncake.increment!!(Mooncake.tangent(A_dA), dAc)
return NoRData(), NoRData(), NoRData()
end
return CoDual(Ac, dAc), copy_input_pb
return Ac_dAc, copy_input_pb
end

Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(initialize_output), Any, Any, Any}
# two-argument in-place factorizations like LQ, QR, EIG
for (f!, f, pb, adj) in (
(:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),
Expand Down
1 change: 1 addition & 0 deletions src/common/defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ default_pullback_degeneracy_atol(A) = eps(norm(A, Inf))^(3 / 4)
Default tolerance for deciding what values should be considered equal to 0.
"""
default_pullback_rank_atol(A) = eps(norm(A, Inf))^(3 / 4)
default_pullback_rank_atol(A::Diagonal) = default_pullback_rank_atol(diagview(A))

"""
default_hermitian_tol(A)
Expand Down
3 changes: 3 additions & 0 deletions src/common/pullbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ function iszerotangent end

iszerotangent(::Any) = false
iszerotangent(::Nothing) = true

# fallback
_sylvester(A, B, C) = LinearAlgebra.sylvester(A, B, C)
24 changes: 15 additions & 9 deletions src/pullbacks/eig.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
function check_eig_cotangents(
D, VᴴΔV;
degeneracy_atol::Real = default_pullback_rank_atol(D),
gauge_atol::Real = default_pullback_gauge_atol(VᴴΔV)
)
mask = abs.(transpose(D) .- D) .< degeneracy_atol
Δgauge = norm(view(VᴴΔV, mask))
Δgauge ≤ gauge_atol ||
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return
end

"""
eig_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
Expand Down Expand Up @@ -41,10 +53,7 @@ function eig_pullback!(
length(indV) == pV || throw(DimensionMismatch())
mul!(view(VᴴΔV, :, indV), V', ΔV)

mask = abs.(transpose(D) .- D) .< degeneracy_atol
Δgauge = norm(view(VᴴΔV, mask), Inf)
Δgauge ≤ gauge_atol ||
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)

VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))

Expand Down Expand Up @@ -129,10 +138,7 @@ function eig_trunc_pullback!(
if !iszerotangent(ΔV)
(n, p) == size(ΔV) || throw(DimensionMismatch())
VᴴΔV = V' * ΔV
mask = abs.(transpose(D) .- D) .< degeneracy_atol
Δgauge = norm(view(VᴴΔV, mask), Inf)
Δgauge ≤ gauge_atol ||
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
check_eig_cotangents(D, VᴴΔV; degeneracy_atol, gauge_atol)

ΔVperp = ΔV - V * inv(G) * VᴴΔV
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
Expand All @@ -150,7 +156,7 @@ function eig_trunc_pullback!(
# add contribution from orthogonal complement
PA = A - (A * V) / V
Y = mul!(ΔVperp, PA', Z, 1, 1)
X = sylvester(PA', -Dmat', Y)
X = _sylvester(PA', -Dmat', Y)
Z .+= X

if eltype(ΔA) <: Real
Expand Down
24 changes: 15 additions & 9 deletions src/pullbacks/eigh.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
function check_eigh_cotangents(
D, aVᴴΔV;
degeneracy_atol::Real = default_pullback_rank_atol(D),
gauge_atol::Real = default_pullback_gauge_atol(aVᴴΔV)
)
mask = abs.(D' .- D) .< degeneracy_atol
Δgauge = norm(view(aVᴴΔV, mask))
Δgauge ≤ gauge_atol ||
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return
end

"""
eigh_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
Expand Down Expand Up @@ -42,10 +54,7 @@ function eigh_pullback!(
mul!(view(VᴴΔV, :, indV), V', ΔV)
aVᴴΔV = project_antihermitian(VᴴΔV) # can't use in-place or recycling doesn't work

mask = abs.(D' .- D) .< degeneracy_atol
Δgauge = norm(view(aVᴴΔV, mask))
Δgauge ≤ gauge_atol ||
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol)

aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)

Expand Down Expand Up @@ -120,10 +129,7 @@ function eigh_trunc_pullback!(
VᴴΔV = V' * ΔV
aVᴴΔV = project_antihermitian!(VᴴΔV)

mask = abs.(D' .- D) .< degeneracy_atol
Δgauge = norm(view(aVᴴΔV, mask))
Δgauge ≤ gauge_atol ||
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol)

aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)

Expand All @@ -138,7 +144,7 @@ function eigh_trunc_pullback!(
# add contribution from orthogonal complement
W = qr_null(V)
WᴴΔV = W' * ΔV
X = sylvester(W' * A * W, -Dmat, WᴴΔV)
X = _sylvester(W' * A * W, -Dmat, WᴴΔV)
Z = mul!(Z, W, X, 1, 1)

# put everything together: symmetrize for hermitian case
Expand Down
80 changes: 49 additions & 31 deletions src/pullbacks/lq.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,41 @@
function check_lq_cotangents(
L, Q, ΔL, ΔQ, minmn::Int, p::Int;
gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
)
if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
# in this case the number Householder reflections will
# change upon small variations, and all of the remaining
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
Δgauge = max(Δgauge, norm(ΔQ2))
end
if !iszerotangent(ΔL)
ΔL22 = view(ΔL, (p + 1):size(L, 1), (p + 1):minmn)
Δgauge = max(Δgauge, norm(ΔL22))
end
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end
return
end

function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(Q1))
# in the case where A is full rank, but there are more columns in Q than in A
# (the case of `lq_full`), there is gauge-invariant information in the
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
# matrix. As the number of Householder reflections is in fixed in the full rank
# case, Q is expected to rotate smoothly (we might even be able to predict) also
# how the full Q2 will change, but this we omit for now, and we consider
# Q2' * ΔQ2 as a gauge dependent quantity.
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return
end

"""
lq_pullback!(
ΔA, A, LQ, ΔLQ;
Expand Down Expand Up @@ -36,23 +74,7 @@ function lq_pullback!(
ΔA1 = view(ΔA, 1:p, :)
ΔA2 = view(ΔA, (p + 1):m, :)

if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
# in this case the number Householder reflections will
# change upon small variations, and all of the remaining
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
end
if !iszerotangent(ΔL)
ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn)
Δgauge = max(Δgauge, norm(ΔL22, Inf))
end
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end
check_lq_cotangents(L, Q, ΔL, ΔQ, minmn, p; gauge_atol)

ΔQ̃ = zero!(similar(Q, (p, n)))
if !iszerotangent(ΔQ)
Expand All @@ -61,17 +83,8 @@ function lq_pullback!(
if p < size(Q, 1)
Q2 = view(Q, (p + 1):size(Q, 1), :)
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
# in the case where A is full rank, but there are more columns in Q than in A
# (the case of `qr_full`), there is gauge-invariant information in the
# projection of ΔQ2 onto the column space of Q1, by virtue of Q being a unitary
# matrix. As the number of Householder reflections is in fixed in the full rank
# case, Q is expected to rotate smoothly (we might even be able to predict) also
# how the full Q2 will change, but this we omit for now, and we consider
# Q2' * ΔQ2 as a gauge dependent quantity.
ΔQ2Q1ᴴ = ΔQ2 * Q1'
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol)
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
end
end
Expand Down Expand Up @@ -102,6 +115,14 @@ function lq_pullback!(
return ΔA
end

function check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ))
aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ')
Δgauge = norm(aNᴴΔN)
Δgauge ≤ gauge_atol ||
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
return
end

"""
lq_null_pullback!(
ΔA::AbstractMatrix, A, Nᴴ, ΔNᴴ;
Expand All @@ -118,10 +139,7 @@ function lq_null_pullback!(
gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)
)
if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0
aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ')
Δgauge = norm(aNᴴΔN)
Δgauge ≤ gauge_atol ||
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol)
L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here?
X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ')
ΔA = mul!(ΔA, X, Nᴴ, -1, 1)
Expand Down
4 changes: 2 additions & 2 deletions src/pullbacks/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
M = zero(P)
!iszerotangent(ΔW) && mul!(M, W', ΔW, 1, 1)
!iszerotangent(ΔP) && mul!(M, ΔP, P, -1, 1)
C = sylvester(P, P, M' - M)
C = _sylvester(P, P, M' - M)
C .+= ΔP
ΔA = mul!(ΔA, W, C, 1, 1)
if !iszerotangent(ΔW)
Expand Down Expand Up @@ -46,7 +46,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
M = zero(P)
!iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1)
!iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1)
C = sylvester(P, P, M' - M)
C = _sylvester(P, P, M' - M)
C .+= ΔP
ΔA = mul!(ΔA, C, Wᴴ, 1, 1)
if !iszerotangent(ΔWᴴ)
Expand Down
Loading
Loading