Skip to content

Commit 34bad7b

Browse files
committed
Cleanup
2 parents 4a91dee + da671cd commit 34bad7b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+1371
-397
lines changed

docs/notebooks/03-elasticity.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@
328328
"\u001b[36mFile \u001b[39m\u001b[32mpetsc4py/PETSc/Log.pyx:188\u001b[39m, in \u001b[36mpetsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func\u001b[39m\u001b[34m()\u001b[39m\n",
329329
"\u001b[36mFile \u001b[39m\u001b[32mpetsc4py/PETSc/Log.pyx:189\u001b[39m, in \u001b[36mpetsc4py.PETSc.Log.EventDecorator.decorator.wrapped_func\u001b[39m\u001b[34m()\u001b[39m\n",
330330
"\u001b[36mFile \u001b[39m\u001b[32m~/src/firedrake-pyadjoint/firedrake/firedrake/adjoint_utils/variational_solver.py:108\u001b[39m, in \u001b[36mNonlinearVariationalSolverMixin._ad_annotate_solve.<locals>.wrapper\u001b[39m\u001b[34m(self, **kwargs)\u001b[39m\n\u001b[32m 105\u001b[39m tape.add_block(block)\n\u001b[32m 107\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m stop_annotating():\n\u001b[32m--> \u001b[39m\u001b[32m108\u001b[39m out = \u001b[43msolve\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 110\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m annotate:\n\u001b[32m 111\u001b[39m block.add_output(\u001b[38;5;28mself\u001b[39m._ad_problem._ad_u.create_block_variable())\n",
331-
"\u001b[36mFile \u001b[39m\u001b[32m~/src/firedrake-pyadjoint/firedrake/firedrake/variational_solver.py:361\u001b[39m, in \u001b[36mNonlinearVariationalSolver.solve\u001b[39m\u001b[34m(self, bounds)\u001b[39m\n\u001b[32m 359\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m problem.restrict:\n\u001b[32m 360\u001b[39m problem.u.interpolate(problem.u_restrict)\n\u001b[32m--> \u001b[39m\u001b[32m361\u001b[39m \u001b[43msolving_utils\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcheck_snes_convergence\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43msnes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 363\u001b[39m \u001b[38;5;66;03m# Grab the comm associated with the `_problem` and call PETSc's garbage cleanup routine\u001b[39;00m\n\u001b[32m 364\u001b[39m comm = \u001b[38;5;28mself\u001b[39m._problem.u_restrict.function_space().mesh()._comm\n",
331+
"\u001b[36mFile \u001b[39m\u001b[32m~/src/firedrake-pyadjoint/firedrake/firedrake/variational_solver.py:361\u001b[39m, in \u001b[36mNonlinearVariationalSolver.solve\u001b[39m\u001b[34m(self, bounds)\u001b[39m\n\u001b[32m 359\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m problem.restrict:\n\u001b[32m 360\u001b[39m problem.u.interpolate(problem.u_restrict)\n\u001b[32m--> \u001b[39m\u001b[32m361\u001b[39m \u001b[43msolving_utils\u001b[49m\u001b[43m.\u001b[49m\u001b[43mcheck_snes_convergence\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43msnes\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 363\u001b[39m \u001b[38;5;66;03m# Grab the comm associated with the `_problem` and call PETSc's garbage cleanup routine\u001b[39;00m\n\u001b[32m 364\u001b[39m comm = \u001b[38;5;28mself\u001b[39m._problem.u_restrict.function_space().mesh().comm\n",
332332
"\u001b[36mFile \u001b[39m\u001b[32m~/src/firedrake-pyadjoint/firedrake/firedrake/solving_utils.py:128\u001b[39m, in \u001b[36mcheck_snes_convergence\u001b[39m\u001b[34m(snes)\u001b[39m\n\u001b[32m 126\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m 127\u001b[39m msg = reason\n\u001b[32m--> \u001b[39m\u001b[32m128\u001b[39m \u001b[38;5;28;01mraise\u001b[39;00m ConvergenceError(\u001b[33mr\u001b[39m\u001b[33m\"\"\"\u001b[39m\u001b[33mNonlinear solve failed to converge after \u001b[39m\u001b[38;5;132;01m%d\u001b[39;00m\u001b[33m nonlinear iterations.\u001b[39m\n\u001b[32m 129\u001b[39m \u001b[33mReason:\u001b[39m\n\u001b[32m 130\u001b[39m \u001b[33m \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[33m\"\"\"\u001b[39m % (snes.getIterationNumber(), msg))\n",
333333
"\u001b[31mConvergenceError\u001b[39m: Nonlinear solve failed to converge after 0 nonlinear iterations.\nReason:\n DIVERGED_LINEAR_SOLVE"
334334
]

docs/source/parallelism.rst

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,5 @@ different simulations on the two halves we would write.
9696
To access the communicator a mesh was created on, we can use the
9797
``mesh.comm`` property, or the function ``mesh.mpi_comm``.
9898

99-
.. warning::
100-
Do not use the internal ``mesh._comm`` attribute for communication.
101-
This communicator is for internal Firedrake MPI communication only.
102-
10399
.. _MPI: http://mpi-forum.org/
104100
.. _STREAMS: http://www.cs.virginia.edu/stream/

firedrake/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# TODO RELEASE
99
# PETSC_SUPPORTED_VERSIONS = ">=3.25"
1010

11+
1112
def init_petsc():
1213
import os
1314
import sys

firedrake/adjoint/__init__.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,6 @@
1010
import pyadjoint
1111
__version__ = pyadjoint.__version__
1212

13-
import sys
14-
if 'backend' not in sys.modules:
15-
import firedrake
16-
sys.modules['backend'] = firedrake
17-
else:
18-
raise ImportError("'backend' module already exists?")
19-
2013
from pyadjoint.tape import Tape, set_working_tape, get_working_tape, \
2114
pause_annotation, continue_annotation, \
2215
stop_annotating, annotate_tape # noqa F401

firedrake/adjoint_utils/function.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from functools import wraps
2+
from pyop2.mpi import temp_internal_comm
23
import ufl
34
from ufl.domain import extract_unique_domain
45
from pyadjoint.overloaded_type import create_overloaded_object, FloatingType
@@ -280,8 +281,9 @@ def _ad_assign_numpy(dst, src, offset):
280281
m_a_local = src[offset + range_begin:offset + range_end]
281282
if dst.function_space().ufl_element().family() == "Real":
282283
# Real space keeps a redundant copy of the data on every rank
283-
comm = dst.function_space().mesh()._comm
284-
dst.dat.data_wo[...] = comm.bcast(m_a_local, root=0)
284+
comm = dst.function_space().mesh().comm
285+
with temp_internal_comm(comm) as icomm:
286+
dst.dat.data_wo[...] = icomm.bcast(m_a_local, root=0)
285287
else:
286288
dst.dat.data_wo[...] = m_a_local.reshape(dst.dat.data_wo.shape)
287289
offset += dst.dat.dataset.layout_vec.size

firedrake/assemble.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
615615
if rank > 2:
616616
raise ValueError("Cannot assemble an Interpolate with more than two arguments")
617617
interpolator = get_interpolator(expr)
618-
return interpolator.assemble(tensor=tensor, bcs=bcs)
618+
return interpolator.assemble(tensor=tensor, bcs=bcs, mat_type=self._mat_type, sub_mat_type=self._sub_mat_type)
619619
elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)):
620620
return tensor.assign(expr)
621621
elif tensor and isinstance(expr, ufl.ZeroBaseForm):
@@ -855,6 +855,15 @@ def restructure_base_form(expr, visited=None):
855855
if isinstance(expr, ufl.FormSum) and all(ufl.duals.is_dual(a.function_space()) for a in expr.arguments()):
856856
# Return ufl.Sum if we are assembling a FormSum with Coarguments (a primal expression)
857857
return sum(w*c for w, c in zip(expr.weights(), expr.components()))
858+
859+
# If F: V3 x V2 -> R, then
860+
# Interpolate(TestFunction(V1), F) <=> Action(Interpolate(TestFunction(V1), TrialFunction(V2.dual())), F).
861+
# The result is a two-form V3 x V1 -> R.
862+
if isinstance(expr, ufl.Interpolate) and isinstance(expr.argument_slots()[0], ufl.form.Form):
863+
form, operand = expr.argument_slots()
864+
vstar = firedrake.Argument(form.arguments()[0].function_space().dual(), 1)
865+
expr = expr._ufl_expr_reconstruct_(operand, v=vstar)
866+
return ufl.action(expr, form)
858867
return expr
859868

860869
@staticmethod
@@ -1159,13 +1168,13 @@ def __init__(self, form, form_compiler_parameters=None):
11591168

11601169
def allocate(self):
11611170
# Getting the comm attribute of a form isn't straightforward
1162-
# form.ufl_domains()[0]._comm seems the most robust method
1171+
# form.ufl_domains()[0].comm seems the most robust method
11631172
# revisit in a refactor
11641173
return op2.Global(
11651174
1,
11661175
[0.0],
11671176
dtype=utils.ScalarType,
1168-
comm=self._form.ufl_domains()[0]._comm
1177+
comm=self._form.ufl_domains()[0].comm
11691178
)
11701179

11711180
def _apply_bc(self, tensor, bc, u=None):

firedrake/checkpointing.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from petsc4py.PETSc import ViewerHDF5
44
import finat.ufl
55
from pyop2 import op2
6-
from pyop2.mpi import COMM_WORLD, internal_comm, MPI
6+
from pyop2.mpi import COMM_WORLD, MPI
77
from petsctools import OptionsManager
88
from firedrake.cython import hdf5interface as h5i
99
from firedrake.cython import dmcommon
@@ -104,7 +104,6 @@ def __init__(self, basename, single_file=True,
104104
warnings.warn("DumbCheckpoint class will soon be deprecated; use CheckpointFile class instead.",
105105
DeprecationWarning)
106106
self.comm = comm or COMM_WORLD
107-
self._comm = internal_comm(self.comm, self)
108107
self.mode = mode
109108

110109
self._single = single_file
@@ -195,7 +194,7 @@ def new_file(self, name=None):
195194
if mode == FILE_UPDATE and not exists:
196195
mode = FILE_CREATE
197196
self._vwr = PETSc.ViewerHDF5().create(name, mode=mode,
198-
comm=self._comm)
197+
comm=self.comm)
199198
if self.mode == FILE_READ:
200199
nprocs = self.read_attribute("/", "nprocs")
201200
if nprocs != self.comm.size:
@@ -379,7 +378,6 @@ def __init__(self, filename, file_mode, comm=None):
379378
warnings.warn("HDF5File class will soon be deprecated; use CheckpointFile class instead.",
380379
DeprecationWarning)
381380
self.comm = comm or COMM_WORLD
382-
self._comm = internal_comm(self.comm, self)
383381

384382
self._filename = filename
385383
self._mode = file_mode
@@ -397,7 +395,7 @@ def __init__(self, filename, file_mode, comm=None):
397395

398396
# Try to use MPI
399397
try:
400-
self._h5file = h5py.File(filename, file_mode, driver="mpio", comm=self._comm)
398+
self._h5file = h5py.File(filename, file_mode, driver="mpio", comm=self.comm)
401399
except NameError: # the error you get if h5py isn't compiled against parallel HDF5
402400
raise RuntimeError("h5py *must* be installed with MPI support")
403401

@@ -527,10 +525,9 @@ def __init__(self, filename, mode, comm=COMM_WORLD):
527525
self.viewer = ViewerHDF5()
528526
self.filename = filename
529527
self.comm = comm
530-
self._comm = internal_comm(comm, self)
531528
r"""The neme of the checkpoint file."""
532-
self.viewer.create(filename, mode=mode, comm=self._comm)
533-
self.commkey = self._comm.py2f()
529+
self.viewer.create(filename, mode=mode, comm=self.comm)
530+
self.commkey = self.comm.py2f()
534531
assert self.commkey != MPI.COMM_NULL.py2f()
535532
self._function_spaces = {}
536533
self._function_load_utils = {}
@@ -597,7 +594,7 @@ def save_mesh(self, mesh, distribution_name=None, permutation_name=None):
597594
layers_tV = impl.FunctionSpace(base_tmesh, element)
598595
self._save_function_space_topology(layers_tV)
599596
# Note that _cell_numbering coincides with DG0 section, so we can use tmesh.layers directly.
600-
layers_iset = PETSc.IS().createGeneral(tmesh.layers[:tmesh.cell_set.size, :], comm=tmesh._comm)
597+
layers_iset = PETSc.IS().createGeneral(tmesh.layers[:tmesh.cell_set.size, :], comm=tmesh.comm)
601598
layers_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"]))
602599
self.viewer.pushGroup(path)
603600
layers_iset.view(self.viewer)
@@ -661,7 +658,7 @@ def save_mesh(self, mesh, distribution_name=None, permutation_name=None):
661658
reflected = o_r_map[tmesh.entity_orientations[:tmesh.cell_set.size, -1]]
662659
reflected_indices = (reflected == 1)
663660
canonical_cell_orientations[reflected_indices] = 1 - canonical_cell_orientations[reflected_indices]
664-
cell_orientations_iset = PETSc.IS().createGeneral(canonical_cell_orientations, comm=tmesh._comm)
661+
cell_orientations_iset = PETSc.IS().createGeneral(canonical_cell_orientations, comm=tmesh.comm)
665662
cell_orientations_iset.setName("_".join([PREFIX_IMMERSED, "cell_orientations_iset"]))
666663
self.viewer.pushGroup(path)
667664
cell_orientations_iset.view(self.viewer)
@@ -1065,7 +1062,7 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
10651062
_, _, lsf = self._function_load_utils[base_tmesh_key + sd_key]
10661063
nroots, _, _ = lsf.getGraph()
10671064
layers_a = np.empty(nroots, dtype=utils.IntType)
1068-
layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self._comm)
1065+
layers_a_iset = PETSc.IS().createGeneral(layers_a, comm=self.comm)
10691066
layers_a_iset.setName("_".join([PREFIX_EXTRUDED, "layers_iset"]))
10701067
self.viewer.pushGroup(path)
10711068
layers_a_iset.load(self.viewer)
@@ -1128,7 +1125,7 @@ def load_mesh(self, name=DEFAULT_MESH_NAME, reorder=None, distribution_parameter
11281125
_, _, lsf = self._function_load_utils[tmesh_key + sd_key]
11291126
nroots, _, _ = lsf.getGraph()
11301127
cell_orientations_a = np.empty(nroots, dtype=utils.IntType)
1131-
cell_orientations_a_iset = PETSc.IS().createGeneral(cell_orientations_a, comm=self._comm)
1128+
cell_orientations_a_iset = PETSc.IS().createGeneral(cell_orientations_a, comm=self.comm)
11321129
cell_orientations_a_iset.setName("_".join([PREFIX_IMMERSED, "cell_orientations_iset"]))
11331130
self.viewer.pushGroup(path)
11341131
cell_orientations_a_iset.load(self.viewer)
@@ -1165,7 +1162,7 @@ def _load_mesh_topology(self, tmesh_name, reorder, distribution_parameters):
11651162
_distribution_name, = self.h5pyfile[path].keys()
11661163
path = self._path_to_distribution(tmesh_name, _distribution_name)
11671164
_comm_size = self.get_attr(path, "comm_size")
1168-
if _comm_size == self._comm.size and \
1165+
if _comm_size == self.comm.size and \
11691166
distribution_parameters is None and reorder is None:
11701167
load_distribution_permutation = True
11711168
if load_distribution_permutation:
@@ -1185,7 +1182,7 @@ def _load_mesh_topology(self, tmesh_name, reorder, distribution_parameters):
11851182
permutation_name = None
11861183
perm_is = None
11871184
plex = PETSc.DMPlex()
1188-
plex.create(comm=self._comm)
1185+
plex.create(comm=self.comm)
11891186
plex.setName(tmesh_name)
11901187
# Check format
11911188
path = os.path.join(self._path_to_topology(tmesh_name), "topology")
@@ -1203,15 +1200,15 @@ def _load_mesh_topology(self, tmesh_name, reorder, distribution_parameters):
12031200
plex.removeLabel("pyop2_ghost")
12041201
if load_distribution_permutation:
12051202
chart_size = np.empty(1, dtype=utils.IntType)
1206-
chart_sizes_iset = PETSc.IS().createGeneral(chart_size, comm=self._comm)
1203+
chart_sizes_iset = PETSc.IS().createGeneral(chart_size, comm=self.comm)
12071204
chart_sizes_iset.setName("chart_sizes")
12081205
path = self._path_to_distribution(tmesh_name, distribution_name)
12091206
self.viewer.pushGroup(path)
12101207
chart_sizes_iset.load(self.viewer)
12111208
self.viewer.popGroup()
12121209
chart_size = chart_sizes_iset.getIndices().item()
12131210
perm = np.empty(chart_size, dtype=utils.IntType)
1214-
perm_is = PETSc.IS().createGeneral(perm, comm=self._comm)
1211+
perm_is = PETSc.IS().createGeneral(perm, comm=self.comm)
12151212
path = self._path_to_permutation(tmesh_name, distribution_name, permutation_name)
12161213
self.viewer.pushGroup(path)
12171214
perm_is.setName("permutation")
@@ -1276,10 +1273,10 @@ def _load_function_space_topology(self, tmesh, element):
12761273
sd_key = self._get_shared_data_key_for_checkpointing(tmesh, element)
12771274
if tmesh_key + sd_key not in self._function_load_utils:
12781275
topology_dm = tmesh.topology_dm
1279-
dm = PETSc.DMShell().create(comm=tmesh._comm)
1276+
dm = PETSc.DMShell().create(comm=tmesh.comm)
12801277
dm.setName(self._get_dm_name_for_checkpointing(tmesh, element))
12811278
dm.setPointSF(topology_dm.getPointSF())
1282-
section = PETSc.Section().create(comm=tmesh._comm)
1279+
section = PETSc.Section().create(comm=tmesh.comm)
12831280
section.setPermutation(tmesh._dm_renumbering)
12841281
dm.setSection(section)
12851282
base_tmesh = tmesh._base_mesh if isinstance(tmesh, ExtrudedMeshTopology) else tmesh
@@ -1446,7 +1443,7 @@ def _get_dm_for_checkpointing(self, tV):
14461443
nodes_per_entity, real_tensorproduct, block_size = sd_key
14471444
global_numbering, _ = tV.mesh().create_section(nodes_per_entity, real_tensorproduct, block_size=block_size)
14481445
topology_dm = tV.mesh().topology_dm
1449-
dm = PETSc.DMShell().create(tV.mesh()._comm)
1446+
dm = PETSc.DMShell().create(tV.mesh().comm)
14501447
dm.setPointSF(topology_dm.getPointSF())
14511448
dm.setSection(global_numbering)
14521449
else:

firedrake/cofunction.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import ufl
44

55
from ufl.form import BaseForm
6-
from pyop2 import op2, mpi
6+
from pyop2 import op2
77
from pyadjoint.tape import stop_annotating, annotate_tape, get_working_tape
88
from finat.ufl import MixedElement
99
import firedrake.assemble
@@ -65,18 +65,16 @@ def __init__(self, function_space, val=None, name=None, dtype=ScalarType,
6565

6666
# User comm
6767
self.comm = V.comm
68-
# Internal comm
69-
self._comm = mpi.internal_comm(V.comm, self)
7068
self._function_space = V
71-
self.uid = utils._new_uid(self._comm)
69+
self.uid = utils._new_uid(self.comm)
7270
self._name = name or 'cofunction_%d' % self.uid
7371
self._label = "a cofunction"
7472

7573
if isinstance(val, Cofunction):
7674
val = val.dat
7775

7876
if isinstance(val, (op2.Dat, op2.DatView, op2.MixedDat, op2.Global)):
79-
assert val.comm == self._comm
77+
assert val.comm == self.comm
8078
self.dat = val
8179
else:
8280
self.dat = function_space.make_dat(val, dtype, self.name())

firedrake/cython/dmcommon.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1263,7 +1263,7 @@ def create_section(mesh, nodes_per_entity, on_base=False, block_size=1, boundary
12631263
nodes = sum(nodes_per_entity[:, i]*(mesh.layers - i) for i in range(2)).reshape(dimension + 1, -1)
12641264
else:
12651265
nodes = nodes_per_entity.reshape(dimension + 1, -1)
1266-
section = PETSc.Section().create(comm=mesh._comm)
1266+
section = PETSc.Section().create(comm=mesh.comm)
12671267
get_chart(dm.dm, &pStart, &pEnd)
12681268
section.setChart(pStart, pEnd)
12691269

firedrake/dmhooks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ def create_subdm(dm, fields, *args, **kwargs):
393393
# Index set mapping from W into subspace.
394394
iset = PETSc.IS().createGeneral(numpy.concatenate([W._ises[f].indices
395395
for f in fields]),
396-
comm=W._comm)
396+
comm=W.comm)
397397
if ctx is not None:
398398
ctx, = ctx.split([fields])
399399
add_hook(parent, setup=partial(push_appctx, subspace.dm, ctx),

0 commit comments

Comments
 (0)