diff --git a/cpp/demo/custom_kernel/main.cpp b/cpp/demo/custom_kernel/main.cpp index 6ddcfacabfb..7f701fd4c21 100644 --- a/cpp/demo/custom_kernel/main.cpp +++ b/cpp/demo/custom_kernel/main.cpp @@ -141,9 +141,13 @@ double assemble_matrix1(const mesh::Geometry& g, const fem::DofMap& dofmap, common::Timer timer("Assembler1 lambda (matrix)"); md::mdspan> x( g.x().data(), g.x().size() / 3, 3); - fem::impl::assemble_cells_matrix( - A.mat_add_values(), g.dofmap(), x, cells, {dofmap.map(), 1, cells}, ident, - {dofmap.map(), 1, cells}, ident, {}, {}, kernel, {}, {}, {}, {}); + + std::vector cdofs_b(3 * g.dofmap().extent(1)); + std::vector Ab(dofmap.map().extent(1) * dofmap.map().extent(1)); + fem::impl::assemble_cells_matrix(A.mat_add_values(), g.dofmap(), x, cells, + {dofmap.map(), 1, cells}, ident, + {dofmap.map(), 1, cells}, ident, {}, {}, + kernel, {}, {}, {}, {}, Ab, cdofs_b); A.scatter_rev(); return A.squared_norm(); } diff --git a/cpp/dolfinx/fem/assemble_matrix_impl.h b/cpp/dolfinx/fem/assemble_matrix_impl.h index 3b6148f9e97..42f8c43150f 100644 --- a/cpp/dolfinx/fem/assemble_matrix_impl.h +++ b/cpp/dolfinx/fem/assemble_matrix_impl.h @@ -60,6 +60,12 @@ using mdspan2_t = md::mdspan>; /// function mesh. /// @param cell_info1 Cell permutation information for the trial /// function mesh. +/// @param Ab Buffer for local element matrix. Size must be at least +/// `(bs0 * num_dofs0) * (bs1 * num_dofs1)`, where `bs0 * num_dofs0` is +/// the number of rows and `bs1 * num_dofs1` is the number of columns in +/// local element matrix. +/// @param cdofs_b Buffer for local element geometry. Size must be at +/// least `3 * x_dofmap.extent(1)`. 
template void assemble_cells_matrix( la::MatSet auto mat_set, mdspan2_t x_dofmap, @@ -74,7 +80,8 @@ void assemble_cells_matrix( std::span bc1, FEkernel auto kernel, md::mdspan> coeffs, std::span constants, std::span cell_info0, - std::span cell_info1) + std::span cell_info1, std::span Ab, + std::span> cdofs_b) { if (cells.empty()) return; @@ -83,12 +90,14 @@ void assemble_cells_matrix( const auto [dmap1, bs1, cells1] = dofmap1; // Iterate over active cells - const int num_dofs0 = dmap0.extent(1); - const int num_dofs1 = dmap1.extent(1); - const int ndim0 = bs0 * num_dofs0; - const int ndim1 = bs1 * num_dofs1; - std::vector Ae(ndim0 * ndim1); - std::vector> cdofs(3 * x_dofmap.extent(1)); + std::size_t num_dofs0 = dmap0.extent(1); + std::size_t num_dofs1 = dmap1.extent(1); + std::size_t ndim0 = bs0 * num_dofs0; + std::size_t ndim1 = bs1 * num_dofs1; + + assert(Ab.size() >= ndim0 * ndim1); + assert(cdofs_b.size() >= 3 * x_dofmap.extent(1)); + auto Ae = Ab.first(ndim0 * ndim1); // Iterate over active cells assert(cells0.size() == cells.size()); @@ -104,11 +113,11 @@ void assemble_cells_matrix( // Get cell coordinates/geometry auto x_dofs = md::submdspan(x_dofmap, cell, md::full_extent); for (std::size_t i = 0; i < x_dofs.size(); ++i) - std::copy_n(&x(x_dofs[i], 0), 3, std::next(cdofs.begin(), 3 * i)); + std::copy_n(&x(x_dofs[i], 0), 3, std::next(cdofs_b.begin(), 3 * i)); // Tabulate tensor std::ranges::fill(Ae, 0); - kernel(Ae.data(), &coeffs(c, 0), constants.data(), cdofs.data(), nullptr, + kernel(Ae.data(), &coeffs(c, 0), constants.data(), cdofs_b.data(), nullptr, nullptr, nullptr); // Compute A = P_0 \tilde{A} P_1^T (dof transformation) @@ -121,7 +130,7 @@ void assemble_cells_matrix( if (!bc0.empty()) { - for (int i = 0; i < num_dofs0; ++i) + for (std::size_t i = 0; i < num_dofs0; ++i) { for (int k = 0; k < bs0; ++k) { @@ -137,15 +146,15 @@ void assemble_cells_matrix( if (!bc1.empty()) { - for (int j = 0; j < num_dofs1; ++j) + for (std::size_t j = 0; j < num_dofs1; 
++j) { for (int k = 0; k < bs1; ++k) { if (bc1[bs1 * dofs1[j] + k]) { // Zero column bs1 * j + k - const int col = bs1 * j + k; - for (int row = 0; row < ndim0; ++row) + std::size_t col = bs1 * j + k; + for (std::size_t row = 0; row < ndim0; ++row) Ae[row * ndim1 + col] = 0; } } @@ -198,6 +207,12 @@ void assemble_cells_matrix( /// function mesh. /// @param[in] perms Entity permutation integer. Empty if entity /// permutations are not required. +/// @param Ab Buffer for local element matrix. Size must be at least +/// `(bs0 * num_dofs0) * (bs1 * num_dofs1)`, where `bs0 * num_dofs0` is +/// the number of rows and `bs1 * num_dofs1` is the number of columns in +/// local element matrix. +/// @param cdofs_b Buffer for local element geometry. Size must be at +/// least `3 * x_dofmap.extent(1)`. template void assemble_entities( la::MatSet auto mat_set, mdspan2_t x_dofmap, @@ -221,7 +236,8 @@ void assemble_entities( md::mdspan> coeffs, std::span constants, std::span cell_info0, std::span cell_info1, - md::mdspan> perms) + md::mdspan> perms, + std::span Ab, std::span> cdofs_b) { if (entities.empty()) return; @@ -229,16 +245,16 @@ void assemble_entities( const auto [dmap0, bs0, entities0] = dofmap0; const auto [dmap1, bs1, entities1] = dofmap1; - // Data structures used in assembly - std::vector> cdofs(3 * x_dofmap.extent(1)); - const int num_dofs0 = dmap0.extent(1); - const int num_dofs1 = dmap1.extent(1); - const int ndim0 = bs0 * num_dofs0; - const int ndim1 = bs1 * num_dofs1; - std::vector Ae(ndim0 * ndim1); - assert(entities0.size() == entities.size()); - assert(entities1.size() == entities.size()); - for (std::size_t f = 0; f < entities.extent(0); ++f) + std::size_t num_dofs0 = dmap0.extent(1); + std::size_t num_dofs1 = dmap1.extent(1); + std::size_t ndim0 = bs0 * num_dofs0; + std::size_t ndim1 = bs1 * num_dofs1; + assert(entities0.size() == entities.size()); + assert(entities1.size() == entities.size()); + assert(Ab.size() >= ndim0 * ndim1); + assert(cdofs_b.size() >= 3 * 
x_dofmap.extent(1)); + auto Ae = Ab.first(ndim0 * ndim1); + for (std::size_t f = 0; f < entities.extent(0); ++f) { // Cell in the integration domain, local entity index relative to the // integration domain cell, and cells in the test and trial function @@ -251,15 +267,15 @@ void assemble_entities( // Get cell coordinates/geometry auto x_dofs = md::submdspan(x_dofmap, cell, md::full_extent); for (std::size_t i = 0; i < x_dofs.size(); ++i) - std::copy_n(&x(x_dofs[i], 0), 3, std::next(cdofs.begin(), 3 * i)); + std::copy_n(&x(x_dofs[i], 0), 3, std::next(cdofs_b.begin(), 3 * i)); // Permutations std::uint8_t perm = perms.empty() ? 0 : perms(cell, local_entity); // Tabulate tensor std::ranges::fill(Ae, 0); - kernel(Ae.data(), &coeffs(f, 0), constants.data(), cdofs.data(), - &local_entity, &perm, nullptr); + kernel(Ae.data(), &coeffs(f, 0), constants.data(), cdofs_b.data(), + &local_entity, &perm, nullptr); P0(Ae, cell_info0, cell0, ndim1); P1T(Ae, cell_info1, cell1, ndim0); @@ -268,7 +284,7 @@ void assemble_entities( std::span dofs1(dmap1.data_handle() + cell1 * num_dofs1, num_dofs1); if (!bc0.empty()) { - for (int i = 0; i < num_dofs0; ++i) + for (std::size_t i = 0; i < num_dofs0; ++i) { for (int k = 0; k < bs0; ++k) { @@ -283,15 +299,15 @@ void assemble_entities( } if (!bc1.empty()) { - for (int j = 0; j < num_dofs1; ++j) + for (std::size_t j = 0; j < num_dofs1; ++j) { for (int k = 0; k < bs1; ++k) { if (bc1[bs1 * dofs1[j] + k]) { // Zero column bs1 * j + k - const int col = bs1 * j + k; - for (int row = 0; row < ndim0; ++row) + std::size_t col = bs1 * j + k; + for (std::size_t row = 0; row < ndim0; ++row) Ae[row * ndim1 + col] = 0; } } @@ -338,6 +354,14 @@ void assemble_entities( /// function mesh. /// @param[in] perms Facet permutation integer. Empty if facet /// permutations are not required. +/// @param Ab Buffer for local element matrix. 
Size must be at least `4 +/// * (bs0 * num_dofs0) * (bs1 * num_dofs1)`, where `bs0 * num_dofs0` is +/// the number of rows and `bs1 * num_dofs1` is the number of columns in +/// local element matrix. +/// @param cdofs_b Buffer for local element geometry. Size must be at +/// least `2 * 3 * x_dofmap.extent(1)`. +/// @param dofs_b Buffer for degrees-of-freedom. Size must be at least +/// `2 * dmap0.map().extent(1) + 2 * dmap1.map().extent(1)`. template void assemble_interior_facets( la::MatSet auto mat_set, mdspan2_t x_dofmap, @@ -363,7 +387,9 @@ void assemble_interior_facets( coeffs, std::span constants, std::span cell_info0, std::span cell_info1, - md::mdspan> perms) + md::mdspan> perms, + std::span Ab, std::span> cdofs_b, + std::span dofs_b) { if (facets.empty()) return; @@ -372,23 +398,24 @@ void assemble_interior_facets( const auto [dmap1, bs1, facets1] = dofmap1; // Data structures used in assembly - using X = scalar_value_t; - std::vector cdofs(2 * x_dofmap.extent(1) * 3); - std::span cdofs0(cdofs.data(), x_dofmap.extent(1) * 3); - std::span cdofs1(cdofs.data() + x_dofmap.extent(1) * 3, - x_dofmap.extent(1) * 3); - - const std::size_t dmap0_size = dmap0.map().extent(1); - const std::size_t dmap1_size = dmap1.map().extent(1); - const int num_rows = bs0 * 2 * dmap0_size; - const int num_cols = bs1 * 2 * dmap1_size; - - // Temporaries for joint dofmaps - std::vector Ae(num_rows * num_cols), be(num_rows); - std::vector dmapjoint0(2 * dmap0_size); - std::vector dmapjoint1(2 * dmap1_size); + assert(cdofs_b.size() >= 2 * 3 * x_dofmap.extent(1)); + auto cdofs0 = cdofs_b.first(3 * x_dofmap.extent(1)); + auto cdofs1 = cdofs_b.subspan(cdofs0.size(), cdofs0.size()); + + std::size_t dmap0_size = dmap0.map().extent(1); + std::size_t dmap1_size = dmap1.map().extent(1); + std::size_t num_rows = bs0 * 2 * dmap0_size; + std::size_t num_cols = bs1 * 2 * dmap1_size; + + // Dofmap data structures + assert(dofs_b.size() >= (2 * dmap0_size) + (2 * dmap1_size)); + auto dmapjoint0 = 
dofs_b.first(2 * dmap0_size); + auto dmapjoint1 = dofs_b.last(2 * dmap1_size); + assert(facets0.size() == facets.size()); assert(facets1.size() == facets.size()); + assert(Ab.size() >= num_rows * num_cols); + auto Ae = Ab.first(num_rows * num_cols); for (std::size_t f = 0; f < facets.extent(0); ++f) { // Cells in integration domain, test function domain and trial @@ -439,7 +466,7 @@ void assemble_interior_facets( ? std::array{0, 0} : std::array{perms(cells[0], local_facet[0]), perms(cells[1], local_facet[1])}; - kernel(Ae.data(), &coeffs(f, 0, 0), constants.data(), cdofs.data(), + kernel(Ae.data(), &coeffs(f, 0, 0), constants.data(), cdofs_b.data(), local_facet.data(), perm.data(), nullptr); // Local element layout is a 2x2 block matrix with structure @@ -464,7 +491,7 @@ void assemble_interior_facets( if (cells1[1] >= 0) { - for (int row = 0; row < num_rows; ++row) + for (std::size_t row = 0; row < num_rows; ++row) { // DOFs for dmap1 and cell1 are not stored contiguously in the // block matrix, so each row needs a separate span access @@ -499,7 +526,7 @@ void assemble_interior_facets( if (bc1[bs1 * dmapjoint1[j] + k]) { // Zero column bs1 * j + k - for (int m = 0; m < num_rows; ++m) + for (std::size_t m = 0; m < num_rows; ++m) Ae[m * num_cols + bs1 * j + k] = 0; } } @@ -570,11 +597,20 @@ void assemble_matrix( = a.function_spaces().at(1)->dofmaps(cell_type_idx); assert(dofmap0); assert(dofmap1); - auto dofs0 = dofmap0->map(); + md::mdspan> dofs0 + = dofmap0->map(); const int bs0 = dofmap0->bs(); - auto dofs1 = dofmap1->map(); + md::mdspan> dofs1 + = dofmap1->map(); const int bs1 = dofmap1->bs(); + std::vector Ab((2 * bs0 * dofs0.extent(1)) + * (2 * bs1 * dofs1.extent(1))); + std::vector> cdofs_b(2 * 3 * x_dofmap.extent(1)); + std::size_t dmap0_size = dofmap0->map().extent(1); + std::size_t dmap1_size = dofmap1->map().extent(1); + std::vector dmap_b((2 * dmap0_size) + (2 * dmap1_size)); + auto element0 = a.function_spaces().at(0)->elements(cell_type_idx); 
assert(element0); auto element1 = a.function_spaces().at(1)->elements(cell_type_idx); @@ -610,7 +646,7 @@ void assemble_matrix( mat_set, x_dofmap, x, cells, {dofs0, bs0, cells0}, P0, {dofs1, bs1, cells1}, P1T, bc0, bc1, fn, md::mdspan(coeffs.data(), cells.size(), cstride), constants, - cell_info0, cell_info1); + cell_info0, cell_info1, std::span(Ab), std::span(cdofs_b)); } md::mdspan> facet_perms; @@ -622,8 +658,40 @@ void assemble_matrix( mesh->topology_mutable()->create_entity_permutations(); const std::vector& p = mesh->topology()->get_facet_permutations(); - facet_perms = md::mdspan(p.data(), p.size() / num_facets_per_cell, - num_facets_per_cell); + facet_perms = md::mdspan(p.data(), p.size() / num_facets_per_cell, + num_facets_per_cell); + } + + for (int i = 0; + i < a.num_integrals(IntegralType::exterior_facet, cell_type_idx); ++i) + { + if (num_cell_types > 1) + { + throw std::runtime_error("Exterior facet integrals with mixed " + "topology aren't supported yet"); + } + + using mdspanx2_t + = md::mdspan>; + + auto fn = a.kernel(IntegralType::exterior_facet, i, 0); + assert(fn); + auto& [coeffs, cstride] + = coefficients.at({IntegralType::exterior_facet, i}); + + std::span f = a.domain(IntegralType::exterior_facet, i, 0); + mdspanx2_t facets(f.data(), f.size() / 2, 2); + std::span f0 = a.domain_arg(IntegralType::exterior_facet, 0, i, 0); + mdspanx2_t facets0(f0.data(), f0.size() / 2, 2); + std::span f1 = a.domain_arg(IntegralType::exterior_facet, 1, i, 0); + mdspanx2_t facets1(f1.data(), f1.size() / 2, 2); + assert((facets.size() / 2) * cstride == coeffs.size()); + impl::assemble_entities( + mat_set, x_dofmap, x, facets, {dofs0, bs0, facets0}, P0, + {dofs1, bs1, facets1}, P1T, bc0, bc1, fn, + md::mdspan(coeffs.data(), facets.extent(0), cstride), constants, + cell_info0, cell_info1, facet_perms, std::span(Ab), std::span(cdofs_b)); } for (int i = 0; @@ -661,47 +729,8 @@ void assemble_matrix( mdspanx22_t(facets1.data(), facets1.size() / 4, 2, 2)}, P1T, bc0, bc1, fn, 
mdspanx2x_t(coeffs.data(), facets.size() / 4, 2, cstride), constants, - cell_info0, cell_info1, facet_perms); - } - - for (auto itg_type : {fem::IntegralType::exterior_facet, - fem::IntegralType::vertex, fem::IntegralType::ridge}) - { - md::mdspan> perms - = itg_type == fem::IntegralType::exterior_facet - ? facet_perms - : md::mdspan>{}; - - for (int i = 0; i < a.num_integrals(itg_type, cell_type_idx); ++i) - { - if (num_cell_types > 1) - { - throw std::runtime_error("Exterior facet integrals with mixed " - "topology aren't supported yet"); - } - - using mdspanx2_t - = md::mdspan>; - - auto fn = a.kernel(itg_type, i, 0); - assert(fn); - auto& [coeffs, cstride] = coefficients.at({itg_type, i}); - - std::span e = a.domain(itg_type, i, 0); - mdspanx2_t entities(e.data(), e.size() / 2, 2); - std::span e0 = a.domain_arg(itg_type, 0, i, 0); - mdspanx2_t entities0(e0.data(), e0.size() / 2, 2); - std::span e1 = a.domain_arg(itg_type, 1, i, 0); - mdspanx2_t entities1(e1.data(), e1.size() / 2, 2); - assert((entities.size() / 2) * cstride == coeffs.size()); - impl::assemble_entities( - mat_set, x_dofmap, x, entities, {dofs0, bs0, entities0}, P0, - {dofs1, bs1, entities1}, P1T, bc0, bc1, fn, - md::mdspan(coeffs.data(), entities.extent(0), cstride), constants, - cell_info0, cell_info1, perms); - } + cell_info0, cell_info1, facet_perms, std::span(Ab), std::span(cdofs_b), + dmap_b); } } }