@@ -9,6 +9,7 @@ using HyperHessians:
99 Chunk,
1010 chunksize,
1111 pickchunksize,
12+ HyperDual,
1213 hessian,
1314 hessian!,
1415 hessian_gradient_value,
@@ -22,7 +23,6 @@ using HyperHessians:
2223# # Traits
# Backend traits for `DI.AutoHyperHessians`: the backend reports itself as available,
# declares in-place support, and identifies its differentiation mode as forward mode.
2324DI. check_available (:: DI.AutoHyperHessians ) = true
2425DI. inplace_support (:: DI.AutoHyperHessians ) = DI. InPlaceSupported ()
# NOTE(review): the `hvp_mode` method below carries the diff's removal marker (`-`).
25- DI. hvp_mode (:: DI.AutoHyperHessians ) = DI. ForwardOverForward ()
2626DI. mode (:: DI.AutoHyperHessians ) = ForwardMode ()
2727
2828chunk_from_backend (backend:: DI.AutoHyperHessians , x) =
@@ -35,17 +35,52 @@ function DI.pick_batchsize(backend::DI.AutoHyperHessians, x::AbstractArray)
3535 return DI. BatchSizeSettings {B} (length (x))
3636end
3737
# Batch-size selection when only the input length `N` is known (no array at hand).
# The chunk is obtained from `chunk_from_backend` with `Float64` as the element type,
# and its size becomes the `B` type parameter of the returned `DI.BatchSizeSettings`.
function DI.pick_batchsize(backend::DI.AutoHyperHessians, N::Integer)
    chunk = chunk_from_backend(backend, N, Float64)
    return DI.BatchSizeSettings{chunksize(chunk)}(N)
end
42+
# Constant-like contexts get no preparation-time counterpart: the prepared slot is
# `nothing`, whatever the requested element type `T` is.
_translate_toprep(
    ::Type{T}, ::Union{DI.GeneralizedConstant, DI.ConstantOrCache}
) where {T} = nothing
# For a `DI.Cache` context, build its prepared counterpart by calling
# `DI.recursive_similar` on the unwrapped cache with the element type `T`.
function _translate_toprep(::Type{T}, c::DI.Cache) where {T}
    template = DI.unwrap(c)
    return DI.recursive_similar(template, T)
end
49+
# Apply `_translate_toprep` with element type `T` to every context, returning a tuple
# with one prepared entry per context (in the same order).
function translate_toprep(::Type{T}, contexts::NTuple{C, DI.Context}) where {T, C}
    return map(ctx -> _translate_toprep(T, ctx), contexts)
end
56+
# Constant-like contexts ignore the prepared slot and are simply unwrapped.
_translate_prepared(
    ctx::Union{DI.GeneralizedConstant, DI.ConstantOrCache}, ::Any
) = DI.unwrap(ctx)
# For a `DI.Cache` context, the value passed on is the prepared counterpart rather
# than the cache wrapped by the context itself.
function _translate_prepared(::DI.Cache, prepared)
    return prepared
end
61+
# Pair each context with its prepared counterpart (both tuples have length `C`) and
# resolve every pair through `_translate_prepared`, yielding the tuple of values that
# callers splat into `DI.fix_tail`.
function translate_prepared(
        contexts::NTuple{C, DI.Context}, prep_contexts::NTuple{C, Any}
    ) where {C}
    return map(_translate_prepared, contexts, prep_contexts)
end
70+
3871# # Second derivative (scalar input)
3972
# Preparation object for the scalar-input second-derivative operators.
# `_sig` carries the call signature as a `Val`; `contexts_prepared` holds the tuple
# built by `translate_toprep` at preparation time.
40- struct HyperHessiansSecondDerivativePrep{SIG} <: DI.SecondDerivativePrep{SIG}
73+ struct HyperHessiansSecondDerivativePrep{SIG, C } <: DI.SecondDerivativePrep{SIG}
4174 _sig:: Val{SIG}
75+ contexts_prepared:: C
4276end
4377
# Prepare the second-derivative operators for a scalar input `x`.
# Computes the call signature with `DI.signature` and translates the contexts with
# `translate_toprep`, then stores both in a `HyperHessiansSecondDerivativePrep`.
4478function DI. prepare_second_derivative_nokwarg (
4579 strict:: Val , f, backend:: DI.AutoHyperHessians , x:: Number , contexts:: Vararg{DI.Context, C}
4680 ) where {C}
4781 _sig = DI. signature (f, backend, x, contexts... ; strict)
48- return HyperHessiansSecondDerivativePrep (_sig)
# NOTE(review): `HyperDual{1, 1, typeof(x)}` is presumably the dual type HyperHessians
# uses for scalar inputs — confirm against the HyperHessians implementation.
82+ contexts_prepared = translate_toprep (HyperDual{1 , 1 , typeof (x)}, contexts)
83+ return HyperHessiansSecondDerivativePrep (_sig, contexts_prepared)
4984end
5085
5186function DI. second_derivative (
@@ -56,7 +91,8 @@ function DI.second_derivative(
5691 contexts:: Vararg{DI.Context, C} ,
5792 ) where {C}
5893 DI. check_prep (f, prep, backend, x, contexts... )
59- fc = DI. fix_tail (f, map (DI. unwrap, contexts)... )
94+ contexts_prepared = translate_prepared (contexts, prep. contexts_prepared)
95+ fc = DI. fix_tail (f, contexts_prepared... )
6096 return hessian (fc, x)
6197end
6298
@@ -81,7 +117,8 @@ function DI.value_derivative_and_second_derivative(
81117 contexts:: Vararg{DI.Context, C} ,
82118 ) where {C}
83119 DI. check_prep (f, prep, backend, x, contexts... )
84- fc = DI. fix_tail (f, map (DI. unwrap, contexts)... )
120+ contexts_prepared = translate_prepared (contexts, prep. contexts_prepared)
121+ fc = DI. fix_tail (f, contexts_prepared... )
85122 res = hessian_gradient_value (fc, x)
86123 return res. value, res. gradient, res. hessian
87124end
@@ -104,14 +141,16 @@ end
104141
105142# # Preparation structs
106143
# Preparation object for the Hessian operators.
# `_sig` carries the call signature as a `Val`, `cfg` is the `HessianConfig` built at
# preparation time, and `contexts_prepared` is the tuple built by `translate_toprep`
# using the element type of `cfg.duals`.
107- struct HyperHessiansHessianPrep{SIG, C} <: DI.HessianPrep{SIG}
144+ struct HyperHessiansHessianPrep{SIG, C, CP } <: DI.HessianPrep{SIG}
108145 _sig:: Val{SIG}
109146 cfg:: C
147+ contexts_prepared:: CP
110148end
111149
# Preparation object for the Hessian-vector-product operators.
# `_sig` carries the call signature as a `Val`, `cfg` is the `HVPConfig` built at
# preparation time, and `contexts_prepared` is the tuple built by `translate_toprep`
# using the element type of `cfg.duals`.
112- struct HyperHessiansHVPPrep{SIG, C} <: DI.HVPPrep{SIG}
150+ struct HyperHessiansHVPPrep{SIG, C, CP } <: DI.HVPPrep{SIG}
113151 _sig:: Val{SIG}
114152 cfg:: C
153+ contexts_prepared:: CP
115154end
116155
117156# # Hessian
@@ -121,7 +160,8 @@ function DI.prepare_hessian_nokwarg(
121160 ) where {C}
122161 _sig = DI. signature (f, backend, x, contexts... ; strict)
123162 cfg = HessianConfig (x, chunk_from_backend (backend, x))
124- return HyperHessiansHessianPrep (_sig, cfg)
163+ contexts_prepared = translate_toprep (eltype (cfg. duals), contexts)
164+ return HyperHessiansHessianPrep (_sig, cfg, contexts_prepared)
125165end
126166
127167function DI. hessian (
@@ -132,7 +172,8 @@ function DI.hessian(
132172 contexts:: Vararg{DI.Context, C} ,
133173 ) where {C}
134174 DI. check_prep (f, prep, backend, x, contexts... )
135- fc = DI. fix_tail (f, map (DI. unwrap, contexts)... )
175+ contexts_prepared = translate_prepared (contexts, prep. contexts_prepared)
176+ fc = DI. fix_tail (f, contexts_prepared... )
136177 return hessian (fc, x, prep. cfg)
137178end
138179
@@ -145,7 +186,8 @@ function DI.hessian!(
145186 contexts:: Vararg{DI.Context, C} ,
146187 ) where {C}
147188 DI. check_prep (f, prep, backend, x, contexts... )
148- fc = DI. fix_tail (f, map (DI. unwrap, contexts)... )
189+ contexts_prepared = translate_prepared (contexts, prep. contexts_prepared)
190+ fc = DI. fix_tail (f, contexts_prepared... )
149191 return hessian! (hess, fc, x, prep. cfg)
150192end
151193
@@ -157,7 +199,8 @@ function DI.value_gradient_and_hessian(
157199 contexts:: Vararg{DI.Context, C} ,
158200 ) where {C}
159201 DI. check_prep (f, prep, backend, x, contexts... )
160- fc = DI. fix_tail (f, map (DI. unwrap, contexts)... )
202+ contexts_prepared = translate_prepared (contexts, prep. contexts_prepared)
203+ fc = DI. fix_tail (f, contexts_prepared... )
161204 res = hessian_gradient_value (fc, x, prep. cfg)
162205 return res. value, res. gradient, res. hessian
163206end
@@ -172,7 +215,8 @@ function DI.value_gradient_and_hessian!(
172215 contexts:: Vararg{DI.Context, C} ,
173216 ) where {C}
174217 DI. check_prep (f, prep, backend, x, contexts... )
175- fc = DI. fix_tail (f, map (DI. unwrap, contexts)... )
218+ contexts_prepared = translate_prepared (contexts, prep. contexts_prepared)
219+ fc = DI. fix_tail (f, contexts_prepared... )
176220 val = hessian_gradient_value! (hess, grad, fc, x, prep. cfg)
177221 return val, grad, hess
178222end
@@ -184,7 +228,8 @@ function DI.prepare_hvp_nokwarg(
184228 ) where {C}
185229 _sig = DI. signature (f, backend, x, tx, contexts... ; strict)
186230 cfg = HVPConfig (x, tx, chunk_from_backend (backend, x))
187- return HyperHessiansHVPPrep (_sig, cfg)
231+ contexts_prepared = translate_toprep (eltype (cfg. duals), contexts)
232+ return HyperHessiansHVPPrep (_sig, cfg, contexts_prepared)
188233end
189234
190235function DI. prepare_hvp_same_point (
@@ -208,7 +253,8 @@ function DI.hvp(
208253 contexts:: Vararg{DI.Context, C} ,
209254 ) where {C}
210255 DI. check_prep (f, prep, backend, x, tx, contexts... )
211- fc = DI. fix_tail (f, map (DI. unwrap, contexts)... )
256+ contexts_prepared = translate_prepared (contexts, prep. contexts_prepared)
257+ fc = DI. fix_tail (f, contexts_prepared... )
212258 return hvp (fc, x, tx, prep. cfg)
213259end
214260
@@ -222,7 +268,8 @@ function DI.hvp!(
222268 contexts:: Vararg{DI.Context, C} ,
223269 ) where {C}
224270 DI. check_prep (f, prep, backend, x, tx, contexts... )
225- fc = DI. fix_tail (f, map (DI. unwrap, contexts)... )
271+ contexts_prepared = translate_prepared (contexts, prep. contexts_prepared)
272+ fc = DI. fix_tail (f, contexts_prepared... )
226273 return hvp! (tg, fc, x, tx, prep. cfg)
227274end
228275
@@ -235,7 +282,8 @@ function DI.gradient_and_hvp(
235282 contexts:: Vararg{DI.Context, C} ,
236283 ) where {C}
237284 DI. check_prep (f, prep, backend, x, tx, contexts... )
238- fc = DI. fix_tail (f, map (DI. unwrap, contexts)... )
285+ contexts_prepared = translate_prepared (contexts, prep. contexts_prepared)
286+ fc = DI. fix_tail (f, contexts_prepared... )
239287 res = hvp_gradient_value (fc, x, tx, prep. cfg)
240288 return res. gradient, res. hvp
241289end
@@ -251,7 +299,8 @@ function DI.gradient_and_hvp!(
251299 contexts:: Vararg{DI.Context, C} ,
252300 ) where {C}
253301 DI. check_prep (f, prep, backend, x, tx, contexts... )
254- fc = DI. fix_tail (f, map (DI. unwrap, contexts)... )
302+ contexts_prepared = translate_prepared (contexts, prep. contexts_prepared)
303+ fc = DI. fix_tail (f, contexts_prepared... )
255304 hvp_gradient_value! (tg, grad, fc, x, tx, prep. cfg)
256305 return grad, tg
257306end
0 commit comments