Skip to content

Commit 1327e79

Browse files
committed
WIP: saddle point pc impl
1 parent 2f625bf commit 1327e79

File tree

1 file changed

+24
-5
lines changed

1 file changed

+24
-5
lines changed

firedrake/adjoint/fourdvar_reduced_functional.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
stop_annotating, get_working_tape, set_working_tape
33
from pyadjoint.enlisting import Enlist
44
from firedrake.function import Function
5-
from firedrake.ensemble import EnsembleFunction
5+
from firedrake.ensemble import EnsembleFunction, EnsembleFunctionSpace
66
from firedrake import assemble, inner, dx, Constant
77
from firedrake.adjoint.composite_reduced_functional import (
88
CompositeReducedFunctional, intermediate_options)
@@ -153,8 +153,8 @@ def __init__(self, control: Control,
153153
self.background = control.control.subfunctions[0]._ad_copy()
154154
_rename(self.background, "Background")
155155

156-
self.control_space = control.function_space()
157-
ensemble = self.control_space.ensemble
156+
self.solution_space = control.function_space()
157+
ensemble = self.solution_space.ensemble
158158
self.ensemble = ensemble
159159
self.trank = ensemble.ensemble_comm.rank if ensemble else 0
160160
self.nchunks = ensemble.ensemble_comm.size if ensemble else 1
@@ -175,6 +175,12 @@ def __init__(self, control: Control,
175175

176176
self.stages = [] # The record of each observation stage
177177

178+
self.observation_rfs = []
179+
self.observation_norms = []
180+
181+
self.model_rfs = []
182+
self.model_norms = []
183+
178184
# first rank sets up functionals for background initial observations
179185
if self.trank == 0:
180186

@@ -210,6 +216,9 @@ def __init__(self, control: Control,
210216
observation_covariance,
211217
control_name="obs_err_vec_0_copy")
212218

219+
self.observation_rfs.append(self.initial_observation_error)
220+
self.observation_norms.append(self.initial_observation_norm)
221+
213222
# compose initial observation reduced functionals to evaluate both together
214223
self.initial_observation_rf = CompositeReducedFunctional(
215224
self.initial_observation_error, self.initial_observation_norm)
@@ -392,12 +401,12 @@ def derivative(self, adj_input: float = 1.0, options: dict = {}):
392401

393402
# create the derivative in the right primal or dual space
394403
if is_primal(sderiv0[0]):
395-
derivative_space = self.control_space
404+
derivative_space = self.solution_space
396405
else:
397406
if not is_dual(sderiv0[0]):
398407
raise ValueError(
399408
"Do not know how to handle stage derivative which is not primal or dual")
400-
derivative_space = self.control_space.dual()
409+
derivative_space = self.solution_space.dual()
401410
derivatives = EnsembleFunction(derivative_space)
402411

403412
derivatives.zero()
@@ -631,6 +640,12 @@ def recording_stages(self, sequential=True, nstages=None, **stage_kwargs):
631640
# let the user record the local stages
632641
yield stage_sequence
633642

643+
for stage in self.stages:
644+
self.observation_rfs.append(stage.observation_error)
645+
self.observation_norms.append(stage.observation_norm)
646+
self.model_rfs.append(stage.forward_model)
647+
self.model_norms.append(stage.model_norm)
648+
634649
# send the state forward
635650
with stop_annotating():
636651
state = self.stages[-1].controls[1].control
@@ -646,6 +661,10 @@ def recording_stages(self, sequential=True, nstages=None, **stage_kwargs):
646661
# values of the initial timeseris
647662
self.control.assign(self._cbuf)
648663

664+
self.observation_space = EnsembleFunctionSpace(
665+
[Jo.functional.function_space() for Jo in self.observation_rfs],
666+
self.ensemble)
667+
649668
else: # strong constraint
650669

651670
yield ObservationStageSequence(

0 commit comments

Comments
 (0)