22 stop_annotating , get_working_tape , set_working_tape
33from pyadjoint .enlisting import Enlist
44from firedrake .function import Function
5- from firedrake .ensemble import EnsembleFunction
5+ from firedrake .ensemble import EnsembleFunction , EnsembleFunctionSpace
66from firedrake import assemble , inner , dx , Constant
77from firedrake .adjoint .composite_reduced_functional import (
88 CompositeReducedFunctional , intermediate_options )
@@ -153,8 +153,8 @@ def __init__(self, control: Control,
153153 self .background = control .control .subfunctions [0 ]._ad_copy ()
154154 _rename (self .background , "Background" )
155155
156- self .control_space = control .function_space ()
157- ensemble = self .control_space .ensemble
156+ self .solution_space = control .function_space ()
157+ ensemble = self .solution_space .ensemble
158158 self .ensemble = ensemble
159159 self .trank = ensemble .ensemble_comm .rank if ensemble else 0
160160 self .nchunks = ensemble .ensemble_comm .size if ensemble else 1
@@ -175,6 +175,12 @@ def __init__(self, control: Control,
175175
176176 self .stages = [] # The record of each observation stage
177177
178+ self .observation_rfs = []
179+ self .observation_norms = []
180+
181+ self .model_rfs = []
182+ self .model_norms = []
183+
178184 # first rank sets up functionals for background initial observations
179185 if self .trank == 0 :
180186
@@ -210,6 +216,9 @@ def __init__(self, control: Control,
210216 observation_covariance ,
211217 control_name = "obs_err_vec_0_copy" )
212218
219+ self .observation_rfs .append (self .initial_observation_error )
220+ self .observation_norms .append (self .initial_observation_norm )
221+
213222 # compose initial observation reduced functionals to evaluate both together
214223 self .initial_observation_rf = CompositeReducedFunctional (
215224 self .initial_observation_error , self .initial_observation_norm )
@@ -392,12 +401,12 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}):
392401
393402 # create the derivative in the right primal or dual space
394403 if is_primal (sderiv0 [0 ]):
395- derivative_space = self .control_space
404+ derivative_space = self .solution_space
396405 else :
397406 if not is_dual (sderiv0 [0 ]):
398407 raise ValueError (
399408 "Do not know how to handle stage derivative which is not primal or dual" )
400- derivative_space = self .control_space .dual ()
409+ derivative_space = self .solution_space .dual ()
401410 derivatives = EnsembleFunction (derivative_space )
402411
403412 derivatives .zero ()
@@ -631,6 +640,12 @@ def recording_stages(self, sequential=True, nstages=None, **stage_kwargs):
631640 # let the user record the local stages
632641 yield stage_sequence
633642
643+ for stage in self .stages :
644+ self .observation_rfs .append (stage .observation_error )
645+ self .observation_norms .append (stage .observation_norm )
646+ self .model_rfs .append (stage .forward_model )
647+ self .model_norms .append (stage .model_norm )
648+
634649 # send the state forward
635650 with stop_annotating ():
636651 state = self .stages [- 1 ].controls [1 ].control
@@ -646,6 +661,10 @@ def recording_stages(self, sequential=True, nstages=None, **stage_kwargs):
646661 # values of the initial timeseris
647662 self .control .assign (self ._cbuf )
648663
664+ self .observation_space = EnsembleFunctionSpace (
665+ [Jo .functional .function_space () for Jo in self .observation_rfs ],
666+ self .ensemble )
667+
649668 else : # strong constraint
650669
651670 yield ObservationStageSequence (
0 commit comments