Path: blob/main/src/solvers/dgsem_t8code/dg_parallel.jl
5590 views
@muladd begin
#! format: noindent

# This method is called when a `SemidiscretizationHyperbolic` is constructed.
# It constructs the basic `cache` used throughout the simulation to compute
# the RHS etc.
function create_cache(mesh::T8codeMeshParallel, equations::AbstractEquations, dg::DG,
                      ::Any,
                      ::Type{uEltype}) where {uEltype <: Real}
    # Make sure to balance and partition the forest before creating any
    # containers in case someone has tampered with forest after creating the
    # mesh.
    balance!(mesh)
    partition!(mesh)

    count_required_surfaces!(mesh)

    elements = init_elements(mesh, equations, dg.basis, uEltype)
    mortars = init_mortars(mesh, equations, dg.basis, elements)
    interfaces = init_interfaces(mesh, equations, dg.basis, elements)
    boundaries = init_boundaries(mesh, equations, dg.basis, elements)

    mpi_mortars = init_mpi_mortars(mesh, equations, dg.basis, elements)
    mpi_interfaces = init_mpi_interfaces(mesh, equations, dg.basis, elements)

    # Temporary bookkeeping used only while establishing the MPI neighbor
    # connectivity; the id/rank arrays are filled by `fill_mesh_info!` below
    # and emptied again once `init_mpi_cache` has consumed them.
    mpi_mesh_info = (mpi_mortars = mpi_mortars,
                     mpi_interfaces = mpi_interfaces,
                     global_mortar_ids = fill(UInt128(0), nmpimortars(mpi_mortars)),
                     global_interface_ids = fill(UInt128(0),
                                                 nmpiinterfaces(mpi_interfaces)),
                     neighbor_ranks_mortar = Vector{Vector{Int}}(undef,
                                                                 nmpimortars(mpi_mortars)),
                     neighbor_ranks_interface = fill(-1,
                                                     nmpiinterfaces(mpi_interfaces)))

    fill_mesh_info!(mesh, interfaces, mortars, boundaries,
                    mesh.boundary_names; mpi_mesh_info = mpi_mesh_info)

    mpi_cache = init_mpi_cache(mesh, mpi_mesh_info, nvariables(equations), nnodes(dg),
                               uEltype)

    # Release the temporary connectivity arrays; the relevant information now
    # lives inside `mpi_cache`.
    empty!(mpi_mesh_info.global_mortar_ids)
    empty!(mpi_mesh_info.global_interface_ids)
    empty!(mpi_mesh_info.neighbor_ranks_mortar)
    empty!(mpi_mesh_info.neighbor_ranks_interface)

    init_normal_directions!(mpi_mortars, dg.basis, elements)
    exchange_normal_directions!(mpi_mortars, mpi_cache, mesh, nnodes(dg))

    # Container cache
    cache = (; elements, interfaces, mpi_interfaces, boundaries, mortars,
             mpi_mortars, mpi_cache)

    # Add Volume-Integral cache
    cache = (; cache...,
             create_cache(mesh, equations, dg.volume_integral, dg, cache, uEltype)...)

    # Add Mortar cache
    cache = (; cache..., create_cache(mesh, equations, dg.mortar, uEltype)...)

    return cache
end

function init_mpi_cache(mesh::T8codeMeshParallel, mpi_mesh_info, nvars, nnodes,
                        uEltype)
    # The t8code solver reuses the `P4estMPICache` since the communicated data
    # layout is identical.
    mpi_cache = P4estMPICache(uEltype)
    init_mpi_cache!(mpi_cache, mesh, mpi_mesh_info, nvars, nnodes, uEltype)
    return mpi_cache
end

function init_mpi_cache!(mpi_cache::P4estMPICache, mesh::T8codeMeshParallel,
                         mpi_mesh_info, nvars, nnodes, uEltype)
    mpi_neighbor_ranks, mpi_neighbor_interfaces, mpi_neighbor_mortars = init_mpi_neighbor_connectivity(mpi_mesh_info,
                                                                                                       mesh)

    mpi_send_buffers, mpi_recv_buffers, mpi_send_requests, mpi_recv_requests = init_mpi_data_structures(mpi_neighbor_interfaces,
                                                                                                        mpi_neighbor_mortars,
                                                                                                        ndims(mesh),
                                                                                                        nvars,
                                                                                                        nnodes,
                                                                                                        uEltype)

    n_elements_global = Int(t8_forest_get_global_num_elements(mesh.forest))
    n_elements_local = Int(t8_forest_get_local_num_elements(mesh.forest))

    # Gather the local element counts of all ranks; the in-place Allgather!
    # fills every slot of `n_elements_by_rank` from the respective rank.
    n_elements_by_rank = Vector{Int}(undef, mpi_nranks())
    n_elements_by_rank[mpi_rank() + 1] = n_elements_local

    MPI.Allgather!(MPI.UBuffer(n_elements_by_rank, 1), mpi_comm())

    # Re-index by 0-based MPI rank for convenient per-rank lookup.
    n_elements_by_rank = OffsetArray(n_elements_by_rank, 0:(mpi_nranks() - 1))

    # Account for 1-based indexing in Julia.
    first_element_global_id = sum(n_elements_by_rank[0:(mpi_rank() - 1)]) + 1

    @assert n_elements_global==sum(n_elements_by_rank) "error in total number of elements"

    @pack! mpi_cache = mpi_neighbor_ranks, mpi_neighbor_interfaces,
                       mpi_neighbor_mortars,
                       mpi_send_buffers, mpi_recv_buffers,
                       mpi_send_requests, mpi_recv_requests,
                       n_elements_by_rank, n_elements_global,
                       first_element_global_id

    return mpi_cache
end

function init_mpi_neighbor_connectivity(mpi_mesh_info, mesh::T8codeMeshParallel)
    @unpack mpi_interfaces, mpi_mortars, global_interface_ids, neighbor_ranks_interface, global_mortar_ids, neighbor_ranks_mortar = mpi_mesh_info

    # Sorted, deduplicated list of all ranks this rank shares an interface or
    # mortar with. Note: mortar neighbor ranks are vectors, hence the splat.
    mpi_neighbor_ranks = vcat(neighbor_ranks_interface, neighbor_ranks_mortar...) |>
                         sort |> unique

    # Sort interfaces by their global ids so that both adjacent ranks agree on
    # the ordering within the communication buffers.
    p = sortperm(global_interface_ids)

    neighbor_ranks_interface .= neighbor_ranks_interface[p]
    interface_ids = collect(1:nmpiinterfaces(mpi_interfaces))[p]

    # Same for the mortars.
    p = sortperm(global_mortar_ids)
    neighbor_ranks_mortar .= neighbor_ranks_mortar[p]
    mortar_ids = collect(1:nmpimortars(mpi_mortars))[p]

    # For each neighbor rank, init connectivity data structures
    mpi_neighbor_interfaces = Vector{Vector{Int}}(undef, length(mpi_neighbor_ranks))
    mpi_neighbor_mortars = Vector{Vector{Int}}(undef, length(mpi_neighbor_ranks))
    for (index, rank) in enumerate(mpi_neighbor_ranks)
        mpi_neighbor_interfaces[index] = interface_ids[findall(==(rank),
                                                               neighbor_ranks_interface)]
        mpi_neighbor_mortars[index] = mortar_ids[findall(x -> (rank in x),
                                                         neighbor_ranks_mortar)]
    end

    # Check that all interfaces were counted exactly once
    @assert mapreduce(length, +, mpi_neighbor_interfaces; init = 0) ==
            nmpiinterfaces(mpi_interfaces)

    return mpi_neighbor_ranks, mpi_neighbor_interfaces, mpi_neighbor_mortars
end
end # @muladd