@@ -16,14 +16,65 @@ Base.propertynames(d::AffineTransform{N}) where {N} = N
1616@inline Base. inv (f:: AffineTransform{(:ω,)} ) = AffineTransform ((σ = f. ω,))
1717@inline Base. inv (f:: AffineTransform{(:μ,)} ) = AffineTransform ((μ = - f. μ,))
1818
19+ # `size(f) == (m,n)` means `f : ℝⁿ → ℝᵐ`
20+ Base. size (f:: AffineTransform{(:μ,:σ)} ) = size (f. σ)
21+ Base. size (f:: AffineTransform{(:μ,:ω)} ) = size (f. ω)
22+ Base. size (f:: AffineTransform{(:σ,)} ) = size (f. σ)
23+ Base. size (f:: AffineTransform{(:ω,)} ) = size (f. ω)
24+
25+ function Base. size (f:: AffineTransform{(:μ,)} )
26+ (n,) = size (f. μ)
27+ return (n,n)
28+ end
29+
30+ Base. size (f:: AffineTransform , n:: Int ) = @inbounds size (f)[n]
31+
1932(f:: AffineTransform{(:μ,)} )(x) = x + f. μ
2033(f:: AffineTransform{(:σ,)} )(x) = f. σ * x
2134(f:: AffineTransform{(:ω,)} )(x) = f. ω \ x
2235(f:: AffineTransform{(:μ,:σ)} )(x) = f. σ * x + f. μ
2336(f:: AffineTransform{(:μ,:ω)} )(x) = f. ω \ x + f. μ
2437
38+ @inline function apply! (x, f:: AffineTransform{(:μ,)} , z)
39+ x .= z .+ f. μ
40+ return x
41+ end
42+
43+ @inline function apply! (x, f:: AffineTransform{(:σ,)} , z)
44+ mul! (x, f. σ, z)
45+ return x
46+ end
47+
48+ @inline function apply! (x, f:: AffineTransform{(:ω,), Tuple{F}} , z) where {F<: Factorization }
49+ ldiv! (x, f. ω, z)
50+ return x
51+ end
52+
53+ @inline function apply! (x, f:: AffineTransform{(:ω,)} , z)
54+ ldiv! (x, factorize (f. ω), z)
55+ return x
56+ end
57+
58+ @inline function apply! (x, f:: AffineTransform{(:μ,:σ)} , z)
59+ apply! (x, AffineTransform ((σ = f. σ,)))
60+ apply! (x, AffineTransform ((μ = f. μ,)))
61+ return x
62+ end
63+
64+ @inline function apply! (x, f:: AffineTransform{(:μ,:ω)} , z)
65+ apply! (x, AffineTransform ((ω = f. ω,)))
66+ apply! (x, AffineTransform ((μ = f. μ,)))
67+ return x
68+ end
69+
70+ function logjac (x:: AbstractMatrix )
71+ (m,n) = size (x)
72+ m == n && return first (logabsdet (x))
2573
26- logjac (x:: AbstractMatrix ) = first (logabsdet (x))
74+ # Equivalent to sum(log, svdvals(x)), but much faster
75+ m > n && return first (logabsdet (x' * x)) / 2
76+ return first (logabsdet (x * x' )) / 2
77+ end
2778
2879logjac (x:: Number ) = log (abs (x))
2980
@@ -41,6 +92,12 @@ logjac(f::AffineTransform{(:μ,)}) = 0.0
4192 parent:: M
4293end
4394
95+ function testvalue (d:: Affine )
96+ f = getfield (d, :f )
97+ z = testvalue (parent (d))
98+ return f (z)
99+ end
100+
44101Affine (nt:: NamedTuple , μ:: AbstractMeasure ) = affine (nt, μ)
45102
46103Affine (nt:: NamedTuple ) = affine (nt)
@@ -57,17 +114,19 @@ function paramnames(::Type{A}) where {N,M, A<:Affine{N,M}}
57114 tuple (union (N, paramnames (M))... )
58115end
59116
60- Base. propertynames (d:: Affine{N} ) where {N} = N ∪ (:parent ,)
117+ Base. propertynames (d:: Affine{N} ) where {N} = N ∪ (:parent ,:f )
61118
62119@inline function Base. getproperty (d:: Affine , s:: Symbol )
63120 if s === :parent
64121 return getfield (d, :parent )
122+ elseif s === :f
123+ return getfield (d, :f )
65124 else
66125 return getproperty (getfield (d, :f ), s)
67126 end
68127end
69128
70- Base. size (d) = size (d. μ)
129+ Base. size (d:: Affine ) = size (d. μ)
71130Base. size (d:: Affine{(:σ,)} ) = (size (d. σ, 1 ),)
72131Base. size (d:: Affine{(:ω,)} ) = (size (d. ω, 2 ),)
73132
@@ -78,14 +137,19 @@ logdensity(d::Affine{(:μ,:σ)}, x) = logdensity(d.parent, d.σ \ (x - d.μ))
78137logdensity (d:: Affine{(:μ,:ω)} , x) = logdensity (d. parent, d. ω * (x - d. μ))
79138
80139# logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.σ \ (x - d.μ))
81- function logdensity (d:: Affine{(:μ,:σ), Tuple{AbstractVector, AbstractMatrix}} , x)
140+ @inline function logdensity (d:: Affine{(:μ,:σ), Tuple{AbstractVector, AbstractMatrix}} , x)
82141 z = x - d. μ
83- ldiv! (d. σ, z)
142+ σ = d. σ
143+ if σ isa Factorization
144+ ldiv! (σ, z)
145+ else
146+ ldiv! (factorize (σ), z)
147+ end
84148 logdensity (d. parent, z)
85149end
86150
87151# logdensity(d::Affine{(:μ,:ω)}, x) = logdensity(d.parent, d.ω * (x - d.μ))
88- function logdensity (d:: Affine{(:μ,:ω), Tuple{AbstractVector, AbstractMatrix}} , x)
152+ @inline function logdensity (d:: Affine{(:μ,:ω), Tuple{AbstractVector, AbstractMatrix}} , x)
89153 z = x - d. μ
90154 lmul! (d. ω, z)
91155 logdensity (d. parent, z)
103167
104168logjac (d:: Affine ) = logjac (getfield (d, :f ))
105169
106-
107- function Base. rand (rng:: Random.AbstractRNG , :: Type{T} , d:: Affine ) where {T}
108- z = rand (rng, T, parent (d))
170+ function Random. rand! (rng:: Random.AbstractRNG , d:: Affine , x:: AbstractVector{T} , z= Vector {T} (undef, size (getfield (d,:f ),2 ))) where {T}
171+ rand! (rng, parent (d), z)
109172 f = getfield (d, :f )
110- return f (z)
173+ apply! (x, f, z)
174+ return x
111175end
176+
177+
178+ # function Base.rand(rng::Random.AbstractRNG, ::Type{T}, d::Affine) where {T}
179+ # f = getfield(d, :f)
180+ # z = rand(rng, T, parent(d))
181+ # apply!(x, f, z)
182+ # return z
183+ # end
0 commit comments