Path: blob/main/src/solvers/dgsem_p4est/dg_parallel.jl

# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
# Since these FMAs can increase the performance of many numerical algorithms,
# we need to opt-in explicitly.
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
@muladd begin
#! format: noindent

mutable struct P4estMPICache{BufferType <: DenseVector,
                             VecInt <: DenseVector{<:Integer}}
    mpi_neighbor_ranks::Vector{Int}
    mpi_neighbor_interfaces::VecOfArrays{VecInt}
    mpi_neighbor_mortars::VecOfArrays{VecInt}
    mpi_send_buffers::VecOfArrays{BufferType}
    mpi_recv_buffers::VecOfArrays{BufferType}
    mpi_send_requests::Vector{MPI.Request}
    mpi_recv_requests::Vector{MPI.Request}
    n_elements_by_rank::OffsetArray{Int, 1, Array{Int, 1}}
    n_elements_global::Int
    first_element_global_id::Int
end

function P4estMPICache(uEltype)
    # MPI communication "just works" for bitstypes only
    if !isbitstype(uEltype)
        throw(ArgumentError("P4estMPICache only supports bitstypes, $uEltype is not a bitstype."))
    end

    mpi_neighbor_ranks = Vector{Int}(undef, 0)
    mpi_neighbor_interfaces = Vector{Vector{Int}}(undef, 0) |> VecOfArrays
    mpi_neighbor_mortars = Vector{Vector{Int}}(undef, 0) |> VecOfArrays
    mpi_send_buffers = Vector{Vector{uEltype}}(undef, 0) |> VecOfArrays
    mpi_recv_buffers = Vector{Vector{uEltype}}(undef, 0) |> VecOfArrays
    mpi_send_requests = Vector{MPI.Request}(undef, 0)
    mpi_recv_requests = Vector{MPI.Request}(undef, 0)
    n_elements_by_rank = OffsetArray(Vector{Int}(undef, 0), 0:-1)
    n_elements_global = 0
    first_element_global_id = 0

    return P4estMPICache{Vector{uEltype}, Vector{Int}}(mpi_neighbor_ranks,
                                                       mpi_neighbor_interfaces,
                                                       mpi_neighbor_mortars,
                                                       mpi_send_buffers,
                                                       mpi_recv_buffers,
                                                       mpi_send_requests,
                                                       mpi_recv_requests,
                                                       n_elements_by_rank,
                                                       n_elements_global,
                                                       first_element_global_id)
end

@inline Base.eltype(::P4estMPICache{BufferType}) where {BufferType} = eltype(BufferType)

# @eval due to @muladd
@eval Adapt.@adapt_structure(P4estMPICache)

##
# Note that the code in `start_mpi_send`/`finish_mpi_receive!` is sensitive to inference on (at least) Julia 1.10.
# Julia's inference is bi-stable, it can sometimes depend on what code has been looked at already, and
# the presence of an inference result in the cache can have an impact on the inference of code.
# In this case the `send_buffer[first:last] .= vec(cache.mpi_mortars.u[2, :, :, .., mortar])`
# can fail to be inferred due to heuristics if this function is not in the cache...
precompile(Base.reindex,
           (Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}},
                  Base.Slice{Base.OneTo{Int64}}, Int64}, Tuple{Int64, Int64, Int64}))
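
# Illustrative usage sketch (not part of the original code): the constructor above only
# allocates empty containers, e.g.
#   mpi_cache = P4estMPICache(Float64)   # Vector{Float64} send/receive buffers
#   eltype(mpi_cache) == Float64         # via the Base.eltype method defined above
# Their actual contents are filled in later by `init_mpi_cache!`.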

function start_mpi_send!(mpi_cache::P4estMPICache, mesh, equations, dg, cache)
    data_size = nvariables(equations) * nnodes(dg)^(ndims(mesh) - 1)
    n_small_elements = 2^(ndims(mesh) - 1)

    for rank in 1:length(mpi_cache.mpi_neighbor_ranks)
        send_buffer = mpi_cache.mpi_send_buffers[rank]

        for (index, interface) in enumerate(mpi_cache.mpi_neighbor_interfaces[rank])
            first = (index - 1) * data_size + 1
            last = (index - 1) * data_size + data_size
            local_side = cache.mpi_interfaces.local_sides[interface]
            @views send_buffer[first:last] .= vec(cache.mpi_interfaces.u[local_side, ..,
                                                                         interface])
        end

        # Set send_buffer corresponding to mortar data to NaN and overwrite the parts where local
        # data exists
        interfaces_data_size = length(mpi_cache.mpi_neighbor_interfaces[rank]) *
                               data_size
        mortars_data_size = length(mpi_cache.mpi_neighbor_mortars[rank]) *
                            n_small_elements * 2 * data_size
        # `NaN |> eltype(...)` ensures that the NaN's are of the appropriate floating point type
        send_buffer[(interfaces_data_size + 1):(interfaces_data_size + mortars_data_size)] .= NaN |>
                                                                                              eltype(mpi_cache)

        for (index, mortar) in enumerate(mpi_cache.mpi_neighbor_mortars[rank])
            index_base = interfaces_data_size +
                         (index - 1) * n_small_elements * 2 * data_size
            indices = buffer_mortar_indices(mesh, index_base, data_size)

            for position in cache.mpi_mortars.local_neighbor_positions[mortar]
                first, last = indices[position]
                if position > n_small_elements # large element
                    @views send_buffer[first:last] .= vec(cache.mpi_mortars.u[2, :, :, ..,
                                                                              mortar])
                else # small element
                    @views send_buffer[first:last] .= vec(cache.mpi_mortars.u[1, :,
                                                                              position, ..,
                                                                              mortar])
                end
            end
        end
    end

    # Start sending
    for (index, rank) in enumerate(mpi_cache.mpi_neighbor_ranks)
        mpi_cache.mpi_send_requests[index] = MPI.Isend(mpi_cache.mpi_send_buffers[index],
                                                       rank, mpi_rank(), mpi_comm())
    end

    return nothing
end

function start_mpi_receive!(mpi_cache::P4estMPICache)
    for (index, rank) in enumerate(mpi_cache.mpi_neighbor_ranks)
        mpi_cache.mpi_recv_requests[index] = MPI.Irecv!(mpi_cache.mpi_recv_buffers[index],
                                                        rank, rank, mpi_comm())
    end

    return nothing
end

function finish_mpi_send!(mpi_cache::P4estMPICache)
    return MPI.Waitall(mpi_cache.mpi_send_requests, MPI.Status)
end

function finish_mpi_receive!(mpi_cache::P4estMPICache, mesh, equations, dg, cache)
    data_size = nvariables(equations) * nnodes(dg)^(ndims(mesh) - 1)
    n_small_elements = 2^(ndims(mesh) - 1)
    n_positions = n_small_elements + 1

    # Start receiving and unpack received data until all communication is finished
    data = MPI.Waitany(mpi_cache.mpi_recv_requests)
    while data !== nothing
        recv_buffer = mpi_cache.mpi_recv_buffers[data]

        for (index, interface) in enumerate(mpi_cache.mpi_neighbor_interfaces[data])
            first = (index - 1) * data_size + 1
            last = (index - 1) * data_size + data_size

            if cache.mpi_interfaces.local_sides[interface] == 1 # local element on primary side
                @views vec(cache.mpi_interfaces.u[2, .., interface]) .= recv_buffer[first:last]
            else # local element at secondary side
                @views vec(cache.mpi_interfaces.u[1, .., interface]) .= recv_buffer[first:last]
            end
        end

        interfaces_data_size = length(mpi_cache.mpi_neighbor_interfaces[data]) *
                               data_size
        for (index, mortar) in enumerate(mpi_cache.mpi_neighbor_mortars[data])
            index_base = interfaces_data_size +
                         (index - 1) * n_small_elements * 2 * data_size
            indices = buffer_mortar_indices(mesh, index_base, data_size)

            for position in 1:n_positions
                # Skip if received data for `position` is NaN as no real data has been sent for the
                # corresponding element
                if isnan(recv_buffer[Base.first(indices[position])])
                    continue
                end

                first, last = indices[position]
                if position == n_positions # large element
                    @views vec(cache.mpi_mortars.u[2, :, :, .., mortar]) .= recv_buffer[first:last]
                else # small element
                    @views vec(cache.mpi_mortars.u[1, :, position, .., mortar]) .= recv_buffer[first:last]
                end
            end
        end

        data = MPI.Waitany(mpi_cache.mpi_recv_requests)
    end

    return nothing
end
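
# Illustrative sizing example for the buffers used above (hypothetical numbers, not taken
# from the original code): for a 2D mesh with `nvariables(equations) == 4` and
# `nnodes(dg) == 4`,
#   data_size        = 4 * 4^(2 - 1) = 16   # entries per MPI interface
#   n_small_elements = 2^(2 - 1)     = 2
# so each MPI mortar occupies `n_small_elements * 2 * data_size = 64` entries. Per neighbor
# rank, the buffer holds all interface blocks first, followed by the mortar blocks, which are
# NaN-initialized and only overwritten where data is locally available.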

# Return a tuple `indices` where indices[position] is a `(first, last)` tuple for accessing the
# data corresponding to the `position` part of a mortar in an MPI buffer. The mortar data must begin
# at `index_base`+1 in the MPI buffer. `data_size` is the data size associated with each small
# position (i.e. position 1 or 2). The data corresponding to the large side (i.e. position 3) has
# size `2 * data_size`.
@inline function buffer_mortar_indices(mesh::Union{P4estMeshParallel{2},
                                                   T8codeMeshParallel{2}}, index_base,
                                       data_size)
    return (
            # first, last for local element in position 1 (small element)
            (index_base + 1,
             index_base + 1 * data_size),
            # first, last for local element in position 2 (small element)
            (index_base + 1 * data_size + 1,
             index_base + 2 * data_size),
            # first, last for local element in position 3 (large element)
            (index_base + 2 * data_size + 1,
             index_base + 4 * data_size))
end

# Return a tuple `indices` where indices[position] is a `(first, last)` tuple for accessing the
# data corresponding to the `position` part of a mortar in an MPI buffer. The mortar data must begin
# at `index_base`+1 in the MPI buffer. `data_size` is the data size associated with each small
# position (i.e. position 1 to 4). The data corresponding to the large side (i.e. position 5) has
# size `4 * data_size`.
@inline function buffer_mortar_indices(mesh::Union{P4estMeshParallel{3},
                                                   T8codeMeshParallel{3}}, index_base,
                                       data_size)
    return (
            # first, last for local element in position 1 (small element)
            (index_base + 1,
             index_base + 1 * data_size),
            # first, last for local element in position 2 (small element)
            (index_base + 1 * data_size + 1,
             index_base + 2 * data_size),
            # first, last for local element in position 3 (small element)
            (index_base + 2 * data_size + 1,
             index_base + 3 * data_size),
            # first, last for local element in position 4 (small element)
            (index_base + 3 * data_size + 1,
             index_base + 4 * data_size),
            # first, last for local element in position 5 (large element)
            (index_base + 4 * data_size + 1,
             index_base + 8 * data_size))
end
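
# Worked example for the 2D method above (illustrative values, not from the original code):
# with `index_base = 0` and `data_size = 16`, it returns
#   ((1, 16),   # small element in position 1
#    (17, 32),  # small element in position 2
#    (33, 64))  # large element in position 3, twice the size of a small element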

# This method is called when a SemidiscretizationHyperbolic is constructed.
# It constructs the basic `cache` used throughout the simulation to compute
# the RHS etc.
function create_cache(mesh::P4estMeshParallel, equations::AbstractEquations, dg::DG,
                      ::Any, ::Type{uEltype}) where {uEltype <: Real}
    # Make sure to balance and partition the p4est and create a new ghost layer before creating any
    # containers in case someone has tampered with the p4est after creating the mesh
    balance!(mesh)
    partition!(mesh)
    update_ghost_layer!(mesh)

    elements = init_elements(mesh, equations, dg.basis, uEltype)

    mpi_interfaces = init_mpi_interfaces(mesh, equations, dg.basis, elements)
    mpi_mortars = init_mpi_mortars(mesh, equations, dg.basis, elements)
    mpi_cache = init_mpi_cache(mesh, mpi_interfaces, mpi_mortars,
                               nvariables(equations), nnodes(dg), uEltype)

    exchange_normal_directions!(mpi_mortars, mpi_cache, mesh, nnodes(dg))

    interfaces = init_interfaces(mesh, equations, dg.basis, elements)
    boundaries = init_boundaries(mesh, equations, dg.basis, elements)
    mortars = init_mortars(mesh, equations, dg.basis, elements)

    # Container cache
    cache = (; elements, interfaces, mpi_interfaces, boundaries, mortars,
             mpi_mortars, mpi_cache)

    # Add Volume-Integral cache
    cache = (; cache...,
             create_cache(mesh, equations, dg.volume_integral, dg, cache, uEltype)...)
    # Add Mortar cache
    cache = (; cache..., create_cache(mesh, equations, dg.mortar, uEltype)...)

    return cache
end

function init_mpi_cache(mesh::P4estMeshParallel, mpi_interfaces, mpi_mortars, nvars,
                        nnodes, uEltype)
    mpi_cache = P4estMPICache(uEltype)
    init_mpi_cache!(mpi_cache, mesh, mpi_interfaces, mpi_mortars, nvars, nnodes,
                    uEltype)

    return mpi_cache
end

function init_mpi_cache!(mpi_cache::P4estMPICache, mesh::P4estMeshParallel,
                         mpi_interfaces, mpi_mortars, nvars, n_nodes, uEltype)
    mpi_neighbor_ranks, _mpi_neighbor_interfaces, _mpi_neighbor_mortars = init_mpi_neighbor_connectivity(mpi_interfaces,
                                                                                                         mpi_mortars,
                                                                                                         mesh)

    _mpi_send_buffers, _mpi_recv_buffers, mpi_send_requests, mpi_recv_requests = init_mpi_data_structures(_mpi_neighbor_interfaces,
                                                                                                          _mpi_neighbor_mortars,
                                                                                                          ndims(mesh),
                                                                                                          nvars,
                                                                                                          n_nodes,
                                                                                                          uEltype)

    # Determine local and total number of elements
    n_elements_global = Int(mesh.p4est.global_num_quadrants[])
    n_elements_by_rank = vcat(Int.(unsafe_wrap(Array, mesh.p4est.global_first_quadrant,
                                               mpi_nranks())),
                              n_elements_global) |> diff # diff sufficient due to 0-based quad indices
    n_elements_by_rank = OffsetArray(n_elements_by_rank, 0:(mpi_nranks() - 1))
    # Account for 1-based indexing in Julia
    first_element_global_id = Int(mesh.p4est.global_first_quadrant[mpi_rank() + 1]) + 1
    @assert n_elements_global==sum(n_elements_by_rank) "error in total number of elements"

    mpi_neighbor_interfaces = VecOfArrays(_mpi_neighbor_interfaces)
    mpi_neighbor_mortars = VecOfArrays(_mpi_neighbor_mortars)
    mpi_send_buffers = VecOfArrays(_mpi_send_buffers)
    mpi_recv_buffers = VecOfArrays(_mpi_recv_buffers)

    # TODO reuse existing structures
    @pack! mpi_cache = mpi_neighbor_ranks, mpi_neighbor_interfaces,
                       mpi_neighbor_mortars,
                       mpi_send_buffers, mpi_recv_buffers,
                       mpi_send_requests, mpi_recv_requests,
                       n_elements_by_rank, n_elements_global,
                       first_element_global_id
end
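
# Illustrative example for the element bookkeeping in `init_mpi_cache!` above (hypothetical
# numbers): with 3 ranks, `global_first_quadrant == [0, 10, 25]`, and 40 global quadrants,
#   n_elements_by_rank      == OffsetArray([10, 15, 15], 0:2)
#   first_element_global_id == 11 on rank 1 (i.e. global_first_quadrant[2] + 1)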

function init_mpi_neighbor_connectivity(mpi_interfaces, mpi_mortars,
                                        mesh::P4estMeshParallel)
    # Let p4est iterate over all interfaces and call init_neighbor_rank_connectivity_iter_face
    # to collect connectivity information
    iter_face_c = cfunction(init_neighbor_rank_connectivity_iter_face, Val(ndims(mesh)))
    user_data = InitNeighborRankConnectivityIterFaceUserData(mpi_interfaces,
                                                             mpi_mortars, mesh)

    iterate_p4est(mesh.p4est, user_data; ghost_layer = mesh.ghost,
                  iter_face_c = iter_face_c)

    # Build proper connectivity data structures from information gathered by iterating over p4est
    @unpack global_interface_ids, neighbor_ranks_interface, global_mortar_ids, neighbor_ranks_mortar = user_data

    mpi_neighbor_ranks = vcat(neighbor_ranks_interface, neighbor_ranks_mortar...) |>
                         sort |> unique

    p = sortperm(global_interface_ids)
    neighbor_ranks_interface .= neighbor_ranks_interface[p]
    interface_ids = collect(1:nmpiinterfaces(mpi_interfaces))[p]

    p = sortperm(global_mortar_ids)
    neighbor_ranks_mortar .= neighbor_ranks_mortar[p]
    mortar_ids = collect(1:nmpimortars(mpi_mortars))[p]

    # For each neighbor rank, init connectivity data structures
    mpi_neighbor_interfaces = Vector{Vector{Int}}(undef, length(mpi_neighbor_ranks))
    mpi_neighbor_mortars = Vector{Vector{Int}}(undef, length(mpi_neighbor_ranks))
    for (index, rank) in enumerate(mpi_neighbor_ranks)
        mpi_neighbor_interfaces[index] = interface_ids[findall(==(rank),
                                                               neighbor_ranks_interface)]
        mpi_neighbor_mortars[index] = mortar_ids[findall(x -> (rank in x),
                                                         neighbor_ranks_mortar)]
    end

    # Check that all interfaces were counted exactly once
    @assert mapreduce(length, +, mpi_neighbor_interfaces; init = 0) ==
            nmpiinterfaces(mpi_interfaces)

    return mpi_neighbor_ranks, mpi_neighbor_interfaces, mpi_neighbor_mortars
end
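
# Illustrative example for the connectivity built above (hypothetical values): with
#   neighbor_ranks_interface = [2, 0, 2]   # one owning rank per MPI interface
#   neighbor_ranks_mortar    = [[0, 3]]    # one vector of ranks per MPI mortar
# the expression `vcat(neighbor_ranks_interface, neighbor_ranks_mortar...) |> sort |> unique`
# yields `mpi_neighbor_ranks == [0, 2, 3]`, and interfaces/mortars are then grouped by the
# rank(s) they are shared with.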

mutable struct InitNeighborRankConnectivityIterFaceUserData{MPIInterfaces, MPIMortars,
                                                            Mesh}
    interfaces::MPIInterfaces
    interface_id::Int
    global_interface_ids::Vector{Int}
    neighbor_ranks_interface::Vector{Int}
    mortars::MPIMortars
    mortar_id::Int
    global_mortar_ids::Vector{Int}
    neighbor_ranks_mortar::Vector{Vector{Int}}
    mesh::Mesh
end

function InitNeighborRankConnectivityIterFaceUserData(mpi_interfaces, mpi_mortars, mesh)
    global_interface_ids = fill(-1, nmpiinterfaces(mpi_interfaces))
    neighbor_ranks_interface = fill(-1, nmpiinterfaces(mpi_interfaces))
    global_mortar_ids = fill(-1, nmpimortars(mpi_mortars))
    neighbor_ranks_mortar = Vector{Vector{Int}}(undef, nmpimortars(mpi_mortars))

    return InitNeighborRankConnectivityIterFaceUserData{typeof(mpi_interfaces),
                                                        typeof(mpi_mortars),
                                                        typeof(mesh)}(mpi_interfaces, 1,
                                                                      global_interface_ids,
                                                                      neighbor_ranks_interface,
                                                                      mpi_mortars, 1,
                                                                      global_mortar_ids,
                                                                      neighbor_ranks_mortar,
                                                                      mesh)
end

function init_neighbor_rank_connectivity_iter_face(info, user_data)
    data = unsafe_pointer_to_objref(Ptr{InitNeighborRankConnectivityIterFaceUserData}(user_data))

    # Function barrier because the unpacked user_data above is not type-stable
    return init_neighbor_rank_connectivity_iter_face_inner(info, data)
end

# 2D
function cfunction(::typeof(init_neighbor_rank_connectivity_iter_face), ::Val{2})
    @cfunction(init_neighbor_rank_connectivity_iter_face, Cvoid,
               (Ptr{p4est_iter_face_info_t}, Ptr{Cvoid}))
end
# 3D
function cfunction(::typeof(init_neighbor_rank_connectivity_iter_face), ::Val{3})
    @cfunction(init_neighbor_rank_connectivity_iter_face, Cvoid,
               (Ptr{p8est_iter_face_info_t}, Ptr{Cvoid}))
end

# Function barrier for type stability
function init_neighbor_rank_connectivity_iter_face_inner(info, user_data)
    @unpack interfaces, interface_id, global_interface_ids, neighbor_ranks_interface,
            mortars, mortar_id, global_mortar_ids, neighbor_ranks_mortar, mesh = user_data

    info_pw = PointerWrapper(info)
    # Get the global interface/mortar ids and neighbor rank if current face belongs to an MPI
    # interface/mortar
    if info_pw.sides.elem_count[] == 2 # MPI interfaces/mortars have two neighboring elements
        # Extract surface data
        sides_pw = (load_pointerwrapper_side(info_pw, 1),
                    load_pointerwrapper_side(info_pw, 2))

        if sides_pw[1].is_hanging[] == false && sides_pw[2].is_hanging[] == false # No hanging nodes for MPI interfaces
            if sides_pw[1].is.full.is_ghost[] == true
                remote_side = 1
                local_side = 2
            elseif sides_pw[2].is.full.is_ghost[] == true
                remote_side = 2
                local_side = 1
            else # both sides are on this rank -> skip since it's a regular interface
                return nothing
            end

            # Sanity check, current face should belong to current MPI interface
            local_tree_pw = load_pointerwrapper_tree(mesh.p4est,
                                                     sides_pw[local_side].treeid[] + 1) # one-based indexing
            local_quad_id = local_tree_pw.quadrants_offset[] +
                            sides_pw[local_side].is.full.quadid[]
            @assert interfaces.local_neighbor_ids[interface_id] == local_quad_id + 1 # one-based indexing

            # Get neighbor ID from ghost layer
            proc_offsets = unsafe_wrap(Array, info_pw.ghost_layer.proc_offsets,
                                       mpi_nranks() + 1)
            ghost_id = sides_pw[remote_side].is.full.quadid[] # indexes the ghost layer, 0-based
            neighbor_rank = findfirst(r -> proc_offsets[r] <= ghost_id < proc_offsets[r + 1],
                                      1:mpi_nranks()) - 1 # MPI ranks are 0-based
            neighbor_ranks_interface[interface_id] = neighbor_rank

            # Global interface id is the globally unique quadrant id of the quadrant on the primary
            # side (1) multiplied by the number of faces per quadrant plus face
            if local_side == 1
                offset = mesh.p4est.global_first_quadrant[mpi_rank() + 1] # one-based indexing
                primary_quad_id = offset + local_quad_id
            else
                offset = mesh.p4est.global_first_quadrant[neighbor_rank + 1] # one-based indexing
                primary_quad_id = offset + sides_pw[1].is.full.quad.p.piggy3.local_num[]
            end
            global_interface_id = 2 * ndims(mesh) * primary_quad_id + sides_pw[1].face[]
            global_interface_ids[interface_id] = global_interface_id

            user_data.interface_id += 1
        else # hanging node
            if sides_pw[1].is_hanging[] == true
                hanging_side = 1
                full_side = 2
            else
                hanging_side = 2
                full_side = 1
            end
            # Verify before accessing is.full / is.hanging
            @assert sides_pw[hanging_side].is_hanging[] == true &&
                    sides_pw[full_side].is_hanging[] == false

            # If all quadrants are locally available, this is a regular mortar -> skip
            if sides_pw[full_side].is.full.is_ghost[] == false &&
               all(sides_pw[hanging_side].is.hanging.is_ghost[] .== false)
                return nothing
            end

            trees_pw = (load_pointerwrapper_tree(mesh.p4est, sides_pw[1].treeid[] + 1),
                        load_pointerwrapper_tree(mesh.p4est, sides_pw[2].treeid[] + 1))

            # Find small quads that are remote and determine which rank owns them
            remote_small_quad_positions = findall(sides_pw[hanging_side].is.hanging.is_ghost[] .==
                                                  true)
            proc_offsets = unsafe_wrap(Array, info_pw.ghost_layer.proc_offsets,
                                       mpi_nranks() + 1)
            # indices of small remote quads inside the ghost layer, 0-based
            ghost_ids = map(pos -> sides_pw[hanging_side].is.hanging.quadid[][pos],
                            remote_small_quad_positions)
            neighbor_ranks = map(ghost_ids) do ghost_id
                return findfirst(r -> proc_offsets[r] <= ghost_id < proc_offsets[r + 1],
                                 1:mpi_nranks()) - 1 # MPI ranks are 0-based
            end
            # Determine global quad id of large element to determine global MPI mortar id
            # Furthermore, if large element is ghost, add its owner rank to neighbor_ranks
            if sides_pw[full_side].is.full.is_ghost[] == true
                ghost_id = sides_pw[full_side].is.full.quadid[]
                large_quad_owner_rank = findfirst(r -> proc_offsets[r] <= ghost_id <
                                                       proc_offsets[r + 1],
                                                  1:mpi_nranks()) - 1 # MPI ranks are 0-based
                push!(neighbor_ranks, large_quad_owner_rank)

                offset = mesh.p4est.global_first_quadrant[large_quad_owner_rank + 1] # one-based indexing
                large_quad_id = offset +
                                sides_pw[full_side].is.full.quad.p.piggy3.local_num[]
            else
                offset = mesh.p4est.global_first_quadrant[mpi_rank() + 1] # one-based indexing
                large_quad_id = offset + trees_pw[full_side].quadrants_offset[] +
                                sides_pw[full_side].is.full.quadid[]
            end
            neighbor_ranks_mortar[mortar_id] = neighbor_ranks
            # Global mortar id is the globally unique quadrant id of the large quadrant multiplied by the
            # number of faces per quadrant plus face
            global_mortar_ids[mortar_id] = 2 * ndims(mesh) * large_quad_id +
                                           sides_pw[full_side].face[]

            user_data.mortar_id += 1
        end
    end

    return nothing
end
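
# Illustrative example for the global ids computed above (hypothetical values): in 2D there
# are `2 * ndims(mesh) == 4` faces per quadrant, so a primary/large quadrant with global id 7
# and face 2 yields a global interface/mortar id of `4 * 7 + 2 == 30`. Both ranks sharing the
# face compute the same value, which makes the ordering globally consistent.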

# Exchange normal directions of small elements of the MPI mortars. They are needed on all involved
# MPI ranks to calculate the mortar fluxes.
function exchange_normal_directions!(mpi_mortars, mpi_cache,
                                     mesh::Union{P4estMeshParallel, T8codeMeshParallel},
                                     n_nodes)
    RealT = real(mesh)
    n_dims = ndims(mesh)
    @unpack mpi_neighbor_mortars, mpi_neighbor_ranks = mpi_cache
    n_small_elements = 2^(n_dims - 1)
    data_size = n_nodes^(n_dims - 1) * n_dims

    # Create buffers and requests
    send_buffers = Vector{Vector{RealT}}(undef, length(mpi_neighbor_mortars))
    recv_buffers = Vector{Vector{RealT}}(undef, length(mpi_neighbor_mortars))
    for index in 1:length(mpi_neighbor_mortars)
        send_buffers[index] = Vector{RealT}(undef,
                                            length(mpi_neighbor_mortars[index]) *
                                            n_small_elements * data_size)
        send_buffers[index] .= NaN |> RealT
        recv_buffers[index] = Vector{RealT}(undef,
                                            length(mpi_neighbor_mortars[index]) *
                                            n_small_elements * data_size)
        recv_buffers[index] .= NaN |> RealT
    end
    send_requests = Vector{MPI.Request}(undef, length(mpi_neighbor_mortars))
    recv_requests = Vector{MPI.Request}(undef, length(mpi_neighbor_mortars))

    # Fill send buffers
    for rank in 1:length(mpi_neighbor_ranks)
        send_buffer = send_buffers[rank]

        for (index, mortar) in enumerate(mpi_neighbor_mortars[rank])
            index_base = (index - 1) * n_small_elements * data_size
            indices = buffer_mortar_indices(mesh, index_base, data_size)
            for position in mpi_mortars.local_neighbor_positions[mortar]
                if position <= n_small_elements # element is small
                    first, last = indices[position]
                    @views send_buffer[first:last] .= vec(mpi_mortars.normal_directions[:, ..,
                                                                                        position,
                                                                                        mortar])
                end
            end
        end
    end

    # Start data exchange
    for (index, rank) in enumerate(mpi_neighbor_ranks)
        send_requests[index] = MPI.Isend(send_buffers[index], rank, mpi_rank(),
                                         mpi_comm())
        recv_requests[index] = MPI.Irecv!(recv_buffers[index], rank, rank, mpi_comm())
    end

    # Unpack data from receive buffers
    data = MPI.Waitany(recv_requests)
    while data !== nothing
        recv_buffer = recv_buffers[data]

        for (index, mortar) in enumerate(mpi_neighbor_mortars[data])
            index_base = (index - 1) * n_small_elements * data_size
            indices = buffer_mortar_indices(mesh, index_base, data_size)
            for position in 1:n_small_elements
                # Skip if received data for `position` is NaN as no real data has been sent for the
                # corresponding element
                if isnan(recv_buffer[Base.first(indices[position])])
                    continue
                end

                first, last = indices[position]
                @views vec(mpi_mortars.normal_directions[:, .., position, mortar]) .= recv_buffer[first:last]
            end
        end

        data = MPI.Waitany(recv_requests)
    end

    # Wait for communication to finish
    MPI.Waitall(send_requests, MPI.Status)

    return nothing
end

# Get normal direction of MPI mortar
@inline function get_normal_direction(mpi_mortars::P4estMPIMortarContainer, indices...)
    return SVector(ntuple(@inline(dim -> mpi_mortars.normal_directions[dim, indices...]),
                          Val(ndims(mpi_mortars))))
end

include("dg_2d_parallel.jl")
include("dg_3d_parallel.jl")
include("dg_2d_parabolic_parallel.jl")
end # muladd