Skip to content

Commit 2e1c7c5

Browse files
authored
BUG: Fix symmetry in supermeshing (#4764)
* BUG: Fix symmetry in supermeshing * Fix cross mesh interpolation for symmetry=True
1 parent d382643 commit 2e1c7c5

File tree

4 files changed

+61
-28
lines changed

4 files changed

+61
-28
lines changed

firedrake/interpolation.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,12 @@ def __init__(
472472
raise NotImplementedError("freeze_expr not implemented")
473473
if bcs:
474474
raise NotImplementedError("bcs not implemented")
475-
if V.ufl_element().mapping() != "identity":
475+
476+
# TODO check V.finat_element.is_lagrange() once https://github.com/firedrakeproject/fiat/pull/200 is released
477+
target_element = V.ufl_element()
478+
if not ((isinstance(target_element, finat.ufl.MixedElement)
479+
and all(sub.mapping() == "identity" for sub in target_element.sub_elements))
480+
or target_element.mapping() == "identity"):
476481
# Identity mapping between reference cell and physical coordinates
477482
# implies point evaluation nodes. A more general version would
478483
# require finding the global coordinates of all quadrature points
@@ -551,7 +556,8 @@ def __init__(
551556
elif len(shape) == 1:
552557
fs_type = partial(firedrake.VectorFunctionSpace, dim=shape[0])
553558
else:
554-
fs_type = partial(firedrake.TensorFunctionSpace, shape=shape)
559+
symmetry = V_dest.ufl_element().symmetry()
560+
fs_type = partial(firedrake.TensorFunctionSpace, shape=shape, symmetry=symmetry)
555561
P0DG_vom = fs_type(self.vom_dest_node_coords_in_src_mesh, "DG", 0)
556562
self.point_eval_interpolate = Interpolate(self.expr_renumbered, P0DG_vom)
557563
# The parallel decomposition of the nodes of V_dest in the DESTINATION

firedrake/supermeshing.py

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,18 @@
2020
from pyop2.compilation import load
2121
from pyop2.mpi import COMM_SELF
2222
from pyop2.utils import get_petsc_dir
23+
from collections import defaultdict
2324

2425

2526
__all__ = ["assemble_mixed_mass_matrix", "intersection_finder"]
2627

2728

29+
# TODO replace with KAIJ (we require petsc4py wrappers)
2830
class BlockMatrix(object):
29-
def __init__(self, mat, dimension):
31+
def __init__(self, mat, dimension, block_scale=None):
3032
self.mat = mat
3133
self.dimension = dimension
34+
self.block_scale = block_scale
3235

3336
def mult(self, mat, x, y):
3437
sizes = self.mat.getSizes()
@@ -41,6 +44,8 @@ def mult(self, mat, x, y):
4144
xi = PETSc.Vec().createWithArray(xa, size=sizes[1], comm=x.comm)
4245
yi = PETSc.Vec().createWithArray(ya, size=sizes[0], comm=y.comm)
4346
self.mat.mult(xi, yi)
47+
if self.block_scale is not None:
48+
yi.scale(self.block_scale[i])
4449
y.array[start::stride] = yi.array_r
4550

4651
def multTranspose(self, mat, x, y):
@@ -54,6 +59,8 @@ def multTranspose(self, mat, x, y):
5459
xi = PETSc.Vec().createWithArray(xa, size=sizes[0], comm=x.comm)
5560
yi = PETSc.Vec().createWithArray(ya, size=sizes[1], comm=y.comm)
5661
self.mat.multTranspose(xi, yi)
62+
if self.block_scale is not None:
63+
yi.scale(self.block_scale[i])
5764
y.array[start::stride] = yi.array_r
5865

5966

@@ -68,14 +75,6 @@ def assemble_mixed_mass_matrix(V_A, V_B):
6875
if len(V_A) > 1 or len(V_B) > 1:
6976
raise NotImplementedError("Sorry, only implemented for non-mixed spaces")
7077

71-
if V_A.ufl_element().mapping() != "identity" or V_B.ufl_element().mapping() != "identity":
72-
msg = """
73-
Sorry, only implemented for affine maps for now. To do non-affine, we'd need to
74-
import much more of the assembly engine of UFL/TSFC/etc to do the assembly on
75-
each supermesh cell.
76-
"""
77-
raise NotImplementedError(msg)
78-
7978
mesh_A = V_A.mesh()
8079
mesh_B = V_B.mesh()
8180

@@ -116,15 +115,39 @@ def likely(cell_A):
116115
def likely(cell_A):
117116
return cell_map[cell_A]
118117

119-
assert V_A.value_size == V_B.value_size
120-
orig_value_size = V_A.value_size
121-
if V_A.value_size > 1:
118+
assert V_A.block_size == V_B.block_size
119+
orig_block_size = V_A.block_size
120+
121+
# To deal with symmetry, each block of the mass matrix must be rescaled by the multiplicity
122+
if V_A.ufl_element().mapping() == "symmetries":
123+
symmetry = V_A.ufl_element().symmetry()
124+
assert V_B.ufl_element().mapping() == "symmetries"
125+
assert V_B.ufl_element().symmetry() == symmetry
126+
127+
multiplicity = defaultdict(int)
128+
for idx in numpy.ndindex(V_A.value_shape):
129+
idx = symmetry.get(idx, idx)
130+
multiplicity[idx] += 1
131+
132+
block_scale = tuple(scale for idx, scale in multiplicity.items())
133+
else:
134+
block_scale = None
135+
136+
if V_A.block_size > 1:
122137
V_A = firedrake.FunctionSpace(mesh_A, V_A.ufl_element().sub_elements[0])
123-
if V_B.value_size > 1:
138+
if V_B.block_size > 1:
124139
V_B = firedrake.FunctionSpace(mesh_B, V_B.ufl_element().sub_elements[0])
125140

126-
assert V_A.value_size == 1
127-
assert V_B.value_size == 1
141+
if V_A.ufl_element().mapping() != "identity" or V_B.ufl_element().mapping() != "identity":
142+
msg = """
143+
Sorry, only implemented for affine maps for now. To do non-affine, we'd need to
144+
import much more of the assembly engine of UFL/TSFC/etc to do the assembly on
145+
each supermesh cell.
146+
"""
147+
raise NotImplementedError(msg)
148+
149+
assert V_A.block_size == 1
150+
assert V_B.block_size == 1
128151

129152
preallocator = PETSc.Mat().create(comm=mesh_A._comm)
130153
preallocator.setType(PETSc.Mat.Type.PREALLOCATOR)
@@ -155,7 +178,7 @@ def likely(cell_A):
155178
onnz = numpy.repeat(onnz, cset.cdim)
156179
preallocator.destroy()
157180

158-
assert V_A.value_size == V_B.value_size
181+
assert V_A.block_size == V_B.block_size
159182
rdim = V_B.dof_dset.cdim
160183
cdim = V_A.dof_dset.cdim
161184

@@ -445,16 +468,16 @@ def likely(cell_A):
445468
lib.restype = ctypes.c_int
446469

447470
ammm(V_A, V_B, likely, node_locations_A, node_locations_B, M_SS, ctypes.addressof(lib), mat)
448-
if orig_value_size == 1:
471+
if orig_block_size == 1:
449472
return mat
450473
else:
451474
(lrows, grows), (lcols, gcols) = mat.getSizes()
452-
lrows *= orig_value_size
453-
grows *= orig_value_size
454-
lcols *= orig_value_size
455-
gcols *= orig_value_size
475+
lrows *= orig_block_size
476+
grows *= orig_block_size
477+
lcols *= orig_block_size
478+
gcols *= orig_block_size
456479
size = ((lrows, grows), (lcols, gcols))
457-
context = BlockMatrix(mat, orig_value_size)
480+
context = BlockMatrix(mat, orig_block_size, block_scale=block_scale)
458481
blockmat = PETSc.Mat().createPython(size, context=context, comm=mat.comm)
459482
blockmat.setUp()
460483
return blockmat

tests/firedrake/regression/test_interpolate_cross_mesh.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,11 +396,12 @@ def test_exact_refinement():
396396
)
397397

398398

399-
def test_interpolate_unitsquare_tfs_shape():
399+
@pytest.mark.parametrize("shape,symmetry", [((1, 2, 3), None), ((3, 3), True)])
400+
def test_interpolate_unitsquare_tfs_shape(shape, symmetry):
400401
m_src = UnitSquareMesh(2, 3)
401402
m_dest = UnitSquareMesh(3, 5, quadrilateral=True)
402-
V_src = TensorFunctionSpace(m_src, "CG", 3, shape=(1, 2, 3))
403-
V_dest = TensorFunctionSpace(m_dest, "CG", 4, shape=(1, 2, 3))
403+
V_src = TensorFunctionSpace(m_src, "CG", 3, shape=shape, symmetry=symmetry)
404+
V_dest = TensorFunctionSpace(m_dest, "CG", 4, shape=shape, symmetry=symmetry)
404405
f_src = Function(V_src)
405406
assemble(interpolate(f_src, V_dest))
406407

tests/firedrake/supermesh/test_galerkin_projection.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from firedrake.petsc import DEFAULT_DIRECT_SOLVER_PARAMETERS
33
from firedrake.supermeshing import *
44
from itertools import product
5+
from functools import partial
56
import numpy
67
import pytest
78

@@ -14,14 +15,16 @@ def mesh(request):
1415
return UnitCubeMesh(3, 2, 1)
1516

1617

17-
@pytest.fixture(params=["scalar", "vector", pytest.param("tensor", marks=pytest.mark.skip(reason="Prolongation fails for tensors"))])
18+
@pytest.fixture(params=["scalar", "vector", "tensor", "symmetric"])
1819
def shapify(request):
1920
if request.param == "scalar":
2021
return lambda x: x
2122
elif request.param == "vector":
2223
return VectorElement
2324
elif request.param == "tensor":
2425
return TensorElement
26+
elif request.param == "symmetric":
27+
return partial(TensorElement, symmetry=True)
2528
else:
2629
raise RuntimeError
2730

0 commit comments

Comments
 (0)