Skip to content

Commit 8fc6bae

Browse files
committed
address review comments
1 parent 5e172fe commit 8fc6bae

File tree

2 files changed

+67
-18
lines changed

2 files changed

+67
-18
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ Diffractor = "=0.2.6"
6565
Enzyme = "0.13.39"
6666
EnzymeCore = "0.8.8"
6767
FastDifferentiation = "0.4.3"
68-
HyperHessians = "0.1"
68+
HyperHessians = "0.2"
6969
FiniteDiff = "2.27.0"
7070
FiniteDifferences = "0.12.31"
7171
ForwardDiff = "0.10.36,1"

DifferentiationInterface/ext/DifferentiationInterfaceHyperHessiansExt/DifferentiationInterfaceHyperHessiansExt.jl

Lines changed: 66 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ using HyperHessians:
99
Chunk,
1010
chunksize,
1111
pickchunksize,
12+
HyperDual,
1213
hessian,
1314
hessian!,
1415
hessian_gradient_value,
@@ -22,7 +23,6 @@ using HyperHessians:
2223
## Traits
2324
DI.check_available(::DI.AutoHyperHessians) = true
2425
DI.inplace_support(::DI.AutoHyperHessians) = DI.InPlaceSupported()
25-
DI.hvp_mode(::DI.AutoHyperHessians) = DI.ForwardOverForward()
2626
DI.mode(::DI.AutoHyperHessians) = ForwardMode()
2727

2828
chunk_from_backend(backend::DI.AutoHyperHessians, x) =
@@ -35,17 +35,52 @@ function DI.pick_batchsize(backend::DI.AutoHyperHessians, x::AbstractArray)
3535
return DI.BatchSizeSettings{B}(length(x))
3636
end
3737

38+
function DI.pick_batchsize(backend::DI.AutoHyperHessians, N::Integer)
39+
B = chunksize(chunk_from_backend(backend, N, Float64))
40+
return DI.BatchSizeSettings{B}(N)
41+
end
42+
43+
function _translate_toprep(::Type{T}, c::Union{DI.GeneralizedConstant, DI.ConstantOrCache}) where {T}
44+
return nothing
45+
end
46+
function _translate_toprep(::Type{T}, c::DI.Cache) where {T}
47+
return DI.recursive_similar(DI.unwrap(c), T)
48+
end
49+
50+
function translate_toprep(::Type{T}, contexts::NTuple{C, DI.Context}) where {T, C}
51+
new_contexts = map(contexts) do c
52+
_translate_toprep(T, c)
53+
end
54+
return new_contexts
55+
end
56+
57+
function _translate_prepared(c::Union{DI.GeneralizedConstant, DI.ConstantOrCache}, _pc)
58+
return DI.unwrap(c)
59+
end
60+
_translate_prepared(_c::DI.Cache, pc) = pc
61+
62+
function translate_prepared(
63+
contexts::NTuple{C, DI.Context}, prep_contexts::NTuple{C, Any}
64+
) where {C}
65+
new_contexts = map(contexts, prep_contexts) do c, pc
66+
_translate_prepared(c, pc)
67+
end
68+
return new_contexts
69+
end
70+
3871
## Second derivative (scalar input)
3972

40-
struct HyperHessiansSecondDerivativePrep{SIG} <: DI.SecondDerivativePrep{SIG}
73+
struct HyperHessiansSecondDerivativePrep{SIG, C} <: DI.SecondDerivativePrep{SIG}
4174
_sig::Val{SIG}
75+
contexts_prepared::C
4276
end
4377

4478
function DI.prepare_second_derivative_nokwarg(
4579
strict::Val, f, backend::DI.AutoHyperHessians, x::Number, contexts::Vararg{DI.Context, C}
4680
) where {C}
4781
_sig = DI.signature(f, backend, x, contexts...; strict)
48-
return HyperHessiansSecondDerivativePrep(_sig)
82+
contexts_prepared = translate_toprep(HyperDual{1, 1, typeof(x)}, contexts)
83+
return HyperHessiansSecondDerivativePrep(_sig, contexts_prepared)
4984
end
5085

5186
function DI.second_derivative(
@@ -56,7 +91,8 @@ function DI.second_derivative(
5691
contexts::Vararg{DI.Context, C},
5792
) where {C}
5893
DI.check_prep(f, prep, backend, x, contexts...)
59-
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
94+
contexts_prepared = translate_prepared(contexts, prep.contexts_prepared)
95+
fc = DI.fix_tail(f, contexts_prepared...)
6096
return hessian(fc, x)
6197
end
6298

@@ -81,7 +117,8 @@ function DI.value_derivative_and_second_derivative(
81117
contexts::Vararg{DI.Context, C},
82118
) where {C}
83119
DI.check_prep(f, prep, backend, x, contexts...)
84-
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
120+
contexts_prepared = translate_prepared(contexts, prep.contexts_prepared)
121+
fc = DI.fix_tail(f, contexts_prepared...)
85122
res = hessian_gradient_value(fc, x)
86123
return res.value, res.gradient, res.hessian
87124
end
@@ -104,14 +141,16 @@ end
104141

105142
## Preparation structs
106143

107-
struct HyperHessiansHessianPrep{SIG, C} <: DI.HessianPrep{SIG}
144+
struct HyperHessiansHessianPrep{SIG, C, CP} <: DI.HessianPrep{SIG}
108145
_sig::Val{SIG}
109146
cfg::C
147+
contexts_prepared::CP
110148
end
111149

112-
struct HyperHessiansHVPPrep{SIG, C} <: DI.HVPPrep{SIG}
150+
struct HyperHessiansHVPPrep{SIG, C, CP} <: DI.HVPPrep{SIG}
113151
_sig::Val{SIG}
114152
cfg::C
153+
contexts_prepared::CP
115154
end
116155

117156
## Hessian
@@ -121,7 +160,8 @@ function DI.prepare_hessian_nokwarg(
121160
) where {C}
122161
_sig = DI.signature(f, backend, x, contexts...; strict)
123162
cfg = HessianConfig(x, chunk_from_backend(backend, x))
124-
return HyperHessiansHessianPrep(_sig, cfg)
163+
contexts_prepared = translate_toprep(eltype(cfg.duals), contexts)
164+
return HyperHessiansHessianPrep(_sig, cfg, contexts_prepared)
125165
end
126166

127167
function DI.hessian(
@@ -132,7 +172,8 @@ function DI.hessian(
132172
contexts::Vararg{DI.Context, C},
133173
) where {C}
134174
DI.check_prep(f, prep, backend, x, contexts...)
135-
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
175+
contexts_prepared = translate_prepared(contexts, prep.contexts_prepared)
176+
fc = DI.fix_tail(f, contexts_prepared...)
136177
return hessian(fc, x, prep.cfg)
137178
end
138179

@@ -145,7 +186,8 @@ function DI.hessian!(
145186
contexts::Vararg{DI.Context, C},
146187
) where {C}
147188
DI.check_prep(f, prep, backend, x, contexts...)
148-
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
189+
contexts_prepared = translate_prepared(contexts, prep.contexts_prepared)
190+
fc = DI.fix_tail(f, contexts_prepared...)
149191
return hessian!(hess, fc, x, prep.cfg)
150192
end
151193

@@ -157,7 +199,8 @@ function DI.value_gradient_and_hessian(
157199
contexts::Vararg{DI.Context, C},
158200
) where {C}
159201
DI.check_prep(f, prep, backend, x, contexts...)
160-
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
202+
contexts_prepared = translate_prepared(contexts, prep.contexts_prepared)
203+
fc = DI.fix_tail(f, contexts_prepared...)
161204
res = hessian_gradient_value(fc, x, prep.cfg)
162205
return res.value, res.gradient, res.hessian
163206
end
@@ -172,7 +215,8 @@ function DI.value_gradient_and_hessian!(
172215
contexts::Vararg{DI.Context, C},
173216
) where {C}
174217
DI.check_prep(f, prep, backend, x, contexts...)
175-
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
218+
contexts_prepared = translate_prepared(contexts, prep.contexts_prepared)
219+
fc = DI.fix_tail(f, contexts_prepared...)
176220
val = hessian_gradient_value!(hess, grad, fc, x, prep.cfg)
177221
return val, grad, hess
178222
end
@@ -184,7 +228,8 @@ function DI.prepare_hvp_nokwarg(
184228
) where {C}
185229
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
186230
cfg = HVPConfig(x, tx, chunk_from_backend(backend, x))
187-
return HyperHessiansHVPPrep(_sig, cfg)
231+
contexts_prepared = translate_toprep(eltype(cfg.duals), contexts)
232+
return HyperHessiansHVPPrep(_sig, cfg, contexts_prepared)
188233
end
189234

190235
function DI.prepare_hvp_same_point(
@@ -208,7 +253,8 @@ function DI.hvp(
208253
contexts::Vararg{DI.Context, C},
209254
) where {C}
210255
DI.check_prep(f, prep, backend, x, tx, contexts...)
211-
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
256+
contexts_prepared = translate_prepared(contexts, prep.contexts_prepared)
257+
fc = DI.fix_tail(f, contexts_prepared...)
212258
return hvp(fc, x, tx, prep.cfg)
213259
end
214260

@@ -222,7 +268,8 @@ function DI.hvp!(
222268
contexts::Vararg{DI.Context, C},
223269
) where {C}
224270
DI.check_prep(f, prep, backend, x, tx, contexts...)
225-
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
271+
contexts_prepared = translate_prepared(contexts, prep.contexts_prepared)
272+
fc = DI.fix_tail(f, contexts_prepared...)
226273
return hvp!(tg, fc, x, tx, prep.cfg)
227274
end
228275

@@ -235,7 +282,8 @@ function DI.gradient_and_hvp(
235282
contexts::Vararg{DI.Context, C},
236283
) where {C}
237284
DI.check_prep(f, prep, backend, x, tx, contexts...)
238-
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
285+
contexts_prepared = translate_prepared(contexts, prep.contexts_prepared)
286+
fc = DI.fix_tail(f, contexts_prepared...)
239287
res = hvp_gradient_value(fc, x, tx, prep.cfg)
240288
return res.gradient, res.hvp
241289
end
@@ -251,7 +299,8 @@ function DI.gradient_and_hvp!(
251299
contexts::Vararg{DI.Context, C},
252300
) where {C}
253301
DI.check_prep(f, prep, backend, x, tx, contexts...)
254-
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
302+
contexts_prepared = translate_prepared(contexts, prep.contexts_prepared)
303+
fc = DI.fix_tail(f, contexts_prepared...)
255304
hvp_gradient_value!(tg, grad, fc, x, tx, prep.cfg)
256305
return grad, tg
257306
end

0 commit comments

Comments
 (0)