# Path: blob/main/src/semidiscretization/semidiscretization_coupled_p4est.jl
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
# Since these FMAs can increase the performance of many numerical algorithms,
# we need to opt-in explicitly.
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
@muladd begin
#! format: noindent

"""
    SemidiscretizationCoupledP4est

Specialized semidiscretization routines for coupled problems using P4est mesh views.
This is analogous to the implementation for structured meshes.
[`semidiscretize`](@ref) will return an `ODEProblem` that synchronizes time steps between the semidiscretizations.
Each call of `rhs!` will call `rhs!` for each semidiscretization individually.
The semidiscretizations can be coupled by gluing meshes together using [`BoundaryConditionCoupled`](@ref).

See also: [`SemidiscretizationCoupled`](@ref)

!!! warning "Experimental code"
    This is an experimental feature and can change any time.
"""
mutable struct SemidiscretizationCoupledP4est{Semis, Indices, EquationList} <:
               AbstractSemidiscretization
    # NOTE(review): the third type parameter is named `EquationList`, but the
    # constructor below instantiates it with `typeof(performance_counter)`.
    # It is not used in any field type, so this is harmless, but the name is
    # misleading — consider renaming in a follow-up.
    semis::Semis
    u_indices::Indices # u_ode[u_indices[i]] is the part of u_ode corresponding to semis[i]
    performance_counter::PerformanceCounter
    parent_cell_ids::Vector{Int}
    view_cell_ids::Vector{Int}
    mesh_ids::Vector{Int}
end

"""
    SemidiscretizationCoupledP4est(semis...)

Create a coupled semidiscretization that consists of the semidiscretizations passed as arguments.
"""
function SemidiscretizationCoupledP4est(semis...)
    @assert all(semi -> ndims(semi) == ndims(semis[1]), semis) "All semidiscretizations must have the same dimension!"

    # Number of coefficients for each semidiscretization
    n_coefficients = zeros(Int, length(semis))
    for i in eachindex(semis)
        _, equations, _, _ = mesh_equations_solver_cache(semis[i])
        n_coefficients[i] = ndofs(semis[i]) * nvariables(equations)
    end

    # Compute range of coefficients associated with each semidiscretization
    u_indices = Vector{UnitRange{Int}}(undef, length(semis))
    for i in eachindex(semis)
        offset = sum(n_coefficients[1:(i - 1)]) + 1
        u_indices[i] = range(offset, length = n_coefficients[i])
    end

    # Create correspondence between parent mesh cell IDs and view cell IDs.
    parent_cell_ids = 1:size(semis[1].mesh.parent.tree_node_coordinates)[end]
    view_cell_ids = zeros(Int, length(parent_cell_ids))
    mesh_ids = zeros(Int, length(parent_cell_ids))
    for i in eachindex(semis)
        view_cell_ids[semis[i].mesh.cell_ids] = parent_cell_id_to_view(parent_cell_ids[semis[i].mesh.cell_ids],
                                                                       semis[i].mesh)
        # Record which mesh view owns each parent cell.
        mesh_ids[semis[i].mesh.cell_ids] .= i
    end

    performance_counter = PerformanceCounter()

    SemidiscretizationCoupledP4est{typeof(semis), typeof(u_indices),
                                   typeof(performance_counter)}(semis, u_indices,
                                                                performance_counter,
                                                                parent_cell_ids,
                                                                view_cell_ids,
                                                                mesh_ids)
end

function Base.show(io::IO, ::MIME"text/plain", semi::SemidiscretizationCoupledP4est)
    @nospecialize semi # reduce precompilation time

    if get(io, :compact, false)
        show(io, semi)
    else
        summary_header(io, "SemidiscretizationCoupledP4est")
        summary_line(io, "#spatial dimensions", ndims(semi.semis[1]))
        summary_line(io, "#systems", nsystems(semi))
        for i in eachsystem(semi)
            summary_line(io, "system", i)
            mesh, equations, solver, _ = mesh_equations_solver_cache(semi.semis[i])
            summary_line(increment_indent(io), "mesh", mesh |> typeof |> nameof)
            summary_line(increment_indent(io), "equations",
                         equations |> typeof |> nameof)
            summary_line(increment_indent(io), "initial condition",
                         semi.semis[i].initial_condition)
            # no boundary conditions since that could be too much
            summary_line(increment_indent(io), "source terms",
                         semi.semis[i].source_terms)
            summary_line(increment_indent(io), "solver", solver |> typeof |> nameof)
        end
        summary_line(io, "total #DOFs per field", ndofsglobal(semi))
        summary_footer(io)
    end
end

function print_summary_semidiscretization(io::IO, semi::SemidiscretizationCoupledP4est)
    show(io, MIME"text/plain"(), semi)
    println(io, "\n")
    for i in eachsystem(semi)
        mesh, equations, solver, _ = mesh_equations_solver_cache(semi.semis[i])
        summary_header(io, "System #$i")

        summary_line(io, "mesh", mesh |> typeof |> nameof)
        show(increment_indent(io), MIME"text/plain"(), mesh)

        summary_line(io, "equations", equations |> typeof |> nameof)
        show(increment_indent(io), MIME"text/plain"(), equations)

        summary_line(io, "solver", solver |> typeof |> nameof)
        show(increment_indent(io), MIME"text/plain"(), solver)

        summary_footer(io)
        println(io, "\n")
    end
end

@inline nsystems(semi::SemidiscretizationCoupledP4est) = length(semi.semis)

@inline eachsystem(semi::SemidiscretizationCoupledP4est) = Base.OneTo(nsystems(semi))

@inline Base.real(semi::SemidiscretizationCoupledP4est) = promote_type(real.(semi.semis)...)

@inline function ndofs(semi::SemidiscretizationCoupledP4est)
    return sum(ndofs, semi.semis)
end

"""
    ndofsglobal(semi::SemidiscretizationCoupledP4est)

Return the global number of degrees of freedom associated with each scalar variable across all MPI ranks, and summed up over all coupled systems.
This is the same as [`ndofs`](@ref) for simulations running in serial or
parallelized via threads. It will in general be different for simulations
running in parallel with MPI.
"""
@inline function ndofsglobal(semi::SemidiscretizationCoupledP4est)
    return sum(ndofsglobal, semi.semis)
end

function compute_coefficients(t, semi::SemidiscretizationCoupledP4est)
    @unpack u_indices = semi

    u_ode = Vector{real(semi)}(undef, u_indices[end][end])

    # Distribute the partial solution vectors onto the global one.
    @threaded for i in eachsystem(semi)
        # Call `compute_coefficients` in `src/semidiscretization/semidiscretization.jl`
        u_ode[u_indices[i]] .= compute_coefficients(t, semi.semis[i])
    end

    return u_ode
end

@inline function get_system_u_ode(u_ode, index, semi::SemidiscretizationCoupledP4est)
    return @view u_ode[semi.u_indices[index]]
end

# RHS call for the coupled system.
function rhs!(du_ode, u_ode, semi::SemidiscretizationCoupledP4est, t)
    time_start = time_ns()

    n_nodes = length(semi.semis[1].mesh.parent.nodes)
    # Reformat the parent solutions vector.
    u_ode_reformatted = Vector{real(semi)}(undef, ndofs(semi))
    u_ode_reformatted_reshape = reshape(u_ode_reformatted,
                                        (n_nodes,
                                         n_nodes,
                                         length(semi.mesh_ids)))
    # Extract the parent solution vector from the local solutions.
    foreach_enumerate(semi.semis) do (i, semi_)
        system_ode = get_system_u_ode(u_ode, i, semi)
        # Integer division (`÷`) instead of `Int(x / y)`: the element count is
        # exact, and this avoids a Float64 round-trip (consistent with the
        # reshape in the boundary condition functor below).
        system_ode_reshape = reshape(system_ode,
                                     (n_nodes, n_nodes,
                                      length(system_ode) ÷
                                      n_nodes^ndims(semi_.mesh)))
        u_ode_reformatted_reshape[:, :, semi.mesh_ids .== i] .= system_ode_reshape
    end

    # Call rhs! for each semidiscretization
    foreach_enumerate(semi.semis) do (i, semi_)
        u_loc = get_system_u_ode(u_ode, i, semi)
        du_loc = get_system_u_ode(du_ode, i, semi)
        rhs!(du_loc, u_loc, u_ode_reformatted, semi, semi_, t)
    end

    runtime = time_ns() - time_start
    put!(semi.performance_counter, runtime)

    return nothing
end

# RHS call for the local system.
# Here we require the data from u_parent for each semidiscretization in order
# to exchange the correct boundary values.
function rhs!(du_ode, u_ode, u_parent, semis,
              semi::SemidiscretizationHyperbolic, t)
    @unpack mesh, equations, boundary_conditions, source_terms, solver, cache = semi

    u = wrap_array(u_ode, mesh, equations, solver, cache)
    du = wrap_array(du_ode, mesh, equations, solver, cache)

    time_start = time_ns()
    @trixi_timeit timer() "rhs!" rhs!(du, u, t, u_parent, semis, mesh, equations,
                                      boundary_conditions, source_terms, solver, cache)
    runtime = time_ns() - time_start
    put!(semi.performance_counter, runtime)

    return nothing
end

################################################################################
### AnalysisCallback
################################################################################

"""
    AnalysisCallbackCoupledP4est(semi, callbacks...)

Combine multiple analysis callbacks for coupled simulations with a
[`SemidiscretizationCoupled`](@ref). For each coupled system, an individual
[`AnalysisCallback`](@ref) **must** be created and passed to the `AnalysisCallbackCoupledP4est` **in
order**, i.e., in the same sequence as the individual semidiscretizations are stored in the
`SemidiscretizationCoupled`.

!!! warning "Experimental code"
    This is an experimental feature and can change any time.
"""
struct AnalysisCallbackCoupledP4est{CB}
    callbacks::CB
end

# Convenience constructor for the coupled callback that gets called directly from the elixirs
function AnalysisCallbackCoupledP4est(semi_coupled, callbacks...)
    if length(callbacks) != nsystems(semi_coupled)
        error("an AnalysisCallbackCoupledP4est requires one AnalysisCallback for each semidiscretization")
    end

    analysis_callback_coupled = AnalysisCallbackCoupledP4est{typeof(callbacks)}(callbacks)

    # This callback is triggered if any of its subsidiary callbacks' condition is triggered
    condition = (u, t, integrator) -> any(callbacks) do callback
        callback.condition(u, t, integrator)
    end

    DiscreteCallback(condition, analysis_callback_coupled,
                     save_positions = (false, false),
                     initialize = initialize!)
end

# used for error checks and EOC analysis
function (cb::DiscreteCallback{Condition, Affect!})(sol) where {Condition,
                                                               Affect! <:
                                                               AnalysisCallbackCoupledP4est
                                                               }
    semi_coupled = sol.prob.p
    u_ode_coupled = sol.u[end]
    @unpack callbacks = cb.affect!

    # Collect the error norms of all systems into flat vectors, in system order.
    # (Previously computed per-system variable offsets were never used and have
    # been removed.)
    uEltype = real(semi_coupled)
    l2_error_collection = uEltype[]
    linf_error_collection = uEltype[]
    for i in eachsystem(semi_coupled)
        analysis_callback = callbacks[i].affect!
        @unpack analyzer = analysis_callback
        cache_analysis = analysis_callback.cache

        semi = semi_coupled.semis[i]
        u_ode = get_system_u_ode(u_ode_coupled, i, semi_coupled)

        l2_error,
        linf_error = calc_error_norms(u_ode, sol.t[end], analyzer, semi,
                                      cache_analysis)
        append!(l2_error_collection, l2_error)
        append!(linf_error_collection, linf_error)
    end

    return (; l2 = l2_error_collection, linf = linf_error_collection)
end

################################################################################
### SaveSolutionCallback
################################################################################

# Save mesh for a coupled semidiscretization, which contains multiple meshes internally
function save_mesh(semi::SemidiscretizationCoupledP4est, output_directory, timestep = 0)
    for i in eachsystem(semi)
        mesh, _, _, _ = mesh_equations_solver_cache(semi.semis[i])

        if mesh.unsaved_changes
            mesh.current_filename = save_mesh_file(mesh, output_directory;
                                                   system = string(i),
                                                   timestep = timestep)
            mesh.unsaved_changes = false
        end
    end
    return nothing
end

@inline function save_solution_file(semi::SemidiscretizationCoupledP4est, u_ode,
                                    solution_callback,
                                    integrator)
    @unpack semis = semi

    for i in eachsystem(semi)
        u_ode_slice = get_system_u_ode(u_ode, i, semi)
        save_solution_file(semis[i], u_ode_slice, solution_callback, integrator,
                           system = i)
    end
    return nothing
end

################################################################################
### StepsizeCallback
################################################################################

# In case of coupled system, use minimum timestep over all systems
# Case for constant `cfl_number`.
function calculate_dt(u_ode, t, cfl_hyperbolic, cfl_parabolic,
                      semi::SemidiscretizationCoupledP4est)
    dt = minimum(eachsystem(semi)) do i
        u_ode_slice = get_system_u_ode(u_ode, i, semi)
        calculate_dt(u_ode_slice, t, cfl_hyperbolic, cfl_parabolic, semi.semis[i])
    end

    return dt
end

################################################################################
### Boundary conditions
################################################################################

"""
    BoundaryConditionCoupledP4est(coupling_converter)

Boundary condition struct where the user can specify the coupling converter function.

# Arguments
- `coupling_converter::CouplingConverter`: function to call for converting the solution
  state of one system to the other system
"""
mutable struct BoundaryConditionCoupledP4est{CouplingConverter}
    coupling_converter::CouplingConverter

    function BoundaryConditionCoupledP4est(coupling_converter)
        new{typeof(coupling_converter)}(coupling_converter)
    end
end

"""
Extract the boundary values from the neighboring element.
This requires values from other mesh views.
This currently only works for Cartesian meshes.
"""
function (boundary_condition::BoundaryConditionCoupledP4est)(u_inner, mesh, equations,
                                                             cache,
                                                             i_index, j_index,
                                                             element_index,
                                                             normal_direction,
                                                             surface_flux_function,
                                                             direction,
                                                             u_ode_coupled)
    n_nodes = length(mesh.parent.nodes)
    # Using a projection onto e_x, -e_x, e_y, -e_y to determine which way our boundary interfaces points to.
    # Knowing this, we then find the cell index in the global (parent) space of the neighboring cell.
    if abs(sum(normal_direction .* (1.0, 0.0))) >
       abs(sum(normal_direction .* (0.0, 1.0)))
        if sum(normal_direction .* (1.0, 0.0)) >
           sum(normal_direction .* (-1.0, 0.0))
            cell_index_parent = cache.neighbor_ids_parent[findfirst((cache.boundaries.name .==
                                                                     :x_pos) .*
                                                                    (cache.boundaries.neighbor_ids .==
                                                                     element_index))]
        else
            cell_index_parent = cache.neighbor_ids_parent[findfirst((cache.boundaries.name .==
                                                                     :x_neg) .*
                                                                    (cache.boundaries.neighbor_ids .==
                                                                     element_index))]
        end
        i_index_g = i_index
        # Make sure we do not leave the domain: mirror the boundary node onto
        # the opposite face of the neighboring parent cell.
        if i_index == n_nodes
            i_index_g = 1
        elseif i_index == 1
            i_index_g = n_nodes
        end
        j_index_g = j_index
    else
        if sum(normal_direction .* (0.0, 1.0)) > sum(normal_direction .* (0.0, -1.0))
            cell_index_parent = cache.neighbor_ids_parent[findfirst((cache.boundaries.name .==
                                                                     :y_pos) .*
                                                                    (cache.boundaries.neighbor_ids .==
                                                                     element_index))]
        else
            cell_index_parent = cache.neighbor_ids_parent[findfirst((cache.boundaries.name .==
                                                                     :y_neg) .*
                                                                    (cache.boundaries.neighbor_ids .==
                                                                     element_index))]
        end
        j_index_g = j_index
        # Make sure we do not leave the domain.
        if j_index == n_nodes
            j_index_g = 1
        elseif j_index == 1
            j_index_g = n_nodes
        end
        i_index_g = i_index
    end
    # Perform integer division to get the right shape of the array.
    u_parent_reshape = reshape(u_ode_coupled,
                               (n_nodes, n_nodes,
                                length(u_ode_coupled) ÷ n_nodes^ndims(mesh.parent)))
    # NOTE(review): wrapping a single entry in an SVector assumes scalar
    # equations (one variable per node) — confirm before using with systems.
    u_boundary = SVector(u_parent_reshape[i_index_g, j_index_g, cell_index_parent])

    orientation = normal_direction

    # Calculate boundary flux
    flux = surface_flux_function(u_inner, u_boundary, orientation, equations)

    return flux
end

function calc_boundary_flux!(cache, t, boundary_condition::BC, boundary_indexing,
                             mesh::P4estMeshView{2},
                             equations, surface_integral, dg::DG, u_parent) where {BC}
    @unpack boundaries = cache
    @unpack surface_flux_values = cache.elements
    index_range = eachnode(dg)

    @threaded for local_index in eachindex(boundary_indexing)
        # Use the local index to get the global boundary index from the pre-sorted list
        boundary = boundary_indexing[local_index]

        # Get information on the adjacent element, compute the surface fluxes,
        # and store them
        element = boundaries.neighbor_ids[boundary]
        node_indices = boundaries.node_indices[boundary]
        direction = indices2direction(node_indices)

        i_node_start, i_node_step = index_to_start_step_2d(node_indices[1], index_range)
        j_node_start, j_node_step = index_to_start_step_2d(node_indices[2], index_range)

        i_node = i_node_start
        j_node = j_node_start
        for node in eachnode(dg)
            calc_boundary_flux!(surface_flux_values, t, boundary_condition,
                                mesh, have_nonconservative_terms(equations),
                                equations, surface_integral, dg, cache,
                                i_node, j_node,
                                node, direction, element, boundary,
                                u_parent)

            i_node += i_node_step
            j_node += j_node_step
        end
    end
    return nothing
end

# Iterate over tuples of boundary condition types and associated indices
# in a type-stable way using "lispy tuple programming".
function calc_boundary_flux_by_type!(cache, t, BCs::NTuple{N, Any},
                                     BC_indices::NTuple{N, Vector{Int}},
                                     mesh::P4estMeshView,
                                     equations, surface_integral, dg::DG,
                                     u_parent) where {N}
    # Extract the boundary condition type and index vector
    boundary_condition = first(BCs)
    boundary_condition_indices = first(BC_indices)
    # Extract the remaining types and indices to be processed later
    remaining_boundary_conditions = Base.tail(BCs)
    remaining_boundary_condition_indices = Base.tail(BC_indices)

    # process the first boundary condition type
    calc_boundary_flux!(cache, t, boundary_condition, boundary_condition_indices,
                        mesh, equations, surface_integral, dg, u_parent)

    # recursively call this method with the unprocessed boundary types
    calc_boundary_flux_by_type!(cache, t, remaining_boundary_conditions,
                                remaining_boundary_condition_indices,
                                mesh, equations, surface_integral, dg, u_parent)

    return nothing
end

# terminate the type-stable iteration over tuples
function calc_boundary_flux_by_type!(cache, t, BCs::Tuple{}, BC_indices::Tuple{},
                                     mesh::P4estMeshView,
                                     equations, surface_integral, dg::DG, u_parent)
    return nothing
end
end # @muladd