Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
trixi-framework
GitHub Repository: trixi-framework/Trixi.jl
Path: blob/main/src/semidiscretization/semidiscretization_coupled_p4est.jl
5586 views
1
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
2
# Since these FMAs can increase the performance of many numerical algorithms,
3
# we need to opt-in explicitly.
4
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
5
@muladd begin
6
#! format: noindent
7
8
"""
    SemidiscretizationCoupledP4est

Specialized semidiscretization routines for coupled problems using P4est mesh views.
This is analogous to the implementation for structured meshes.
[`semidiscretize`](@ref) will return an `ODEProblem` that synchronizes time steps between the semidiscretizations.
Each call of `rhs!` will call `rhs!` for each semidiscretization individually.
The semidiscretizations can be coupled by gluing meshes together using [`BoundaryConditionCoupledP4est`](@ref).

See also: [`SemidiscretizationCoupled`](@ref)

!!! warning "Experimental code"
    This is an experimental feature and can change any time.
"""
# NOTE(review): the third type parameter `EquationList` is not used in any field type
# (the constructor passes `typeof(performance_counter)` for it) — confirm whether it
# is needed for dispatch elsewhere or can be removed.
mutable struct SemidiscretizationCoupledP4est{Semis, Indices, EquationList} <:
               AbstractSemidiscretization
    # Tuple of the individual semidiscretizations, one per mesh view
    semis::Semis
    u_indices::Indices # u_ode[u_indices[i]] is the part of u_ode corresponding to semis[i]
    # Accumulates runtimes of `rhs!` calls for performance reporting
    performance_counter::PerformanceCounter
    # Cell IDs in the parent mesh (1:n_parent_cells)
    parent_cell_ids::Vector{Int}
    # For each parent cell, the corresponding cell ID within its mesh view
    view_cell_ids::Vector{Int}
    # For each parent cell, the index of the system/mesh view it belongs to
    mesh_ids::Vector{Int}
end
31
32
"""
    SemidiscretizationCoupledP4est(semis...)

Create a coupled semidiscretization that consists of the semidiscretizations passed as arguments.

All semidiscretizations must have the same spatial dimension; otherwise an
`ArgumentError` is thrown.
"""
function SemidiscretizationCoupledP4est(semis...)
    # Input validation belongs in an explicit check, not `@assert`
    # (asserts may be disabled at higher optimization levels).
    all(semi -> ndims(semi) == ndims(semis[1]), semis) ||
        throw(ArgumentError("All semidiscretizations must have the same dimension!"))

    # Number of coefficients for each semidiscretization
    n_coefficients = zeros(Int, length(semis))
    for i in eachindex(semis)
        _, equations, _, _ = mesh_equations_solver_cache(semis[i])
        n_coefficients[i] = ndofs(semis[i]) * nvariables(equations)
    end

    # Compute the range of coefficients associated with each semidiscretization.
    # A running offset avoids re-summing the prefix for every system (was O(n^2)).
    u_indices = Vector{UnitRange{Int}}(undef, length(semis))
    offset = 1
    for i in eachindex(semis)
        u_indices[i] = range(offset, length = n_coefficients[i])
        offset += n_coefficients[i]
    end

    # Create the correspondence between parent mesh cell IDs and view cell IDs.
    # NOTE(review): assumes the last dimension of `tree_node_coordinates` counts
    # the parent cells — confirm against the `P4estMesh` layout.
    parent_cell_ids = 1:size(semis[1].mesh.parent.tree_node_coordinates)[end]
    view_cell_ids = zeros(Int, length(parent_cell_ids))
    mesh_ids = zeros(Int, length(parent_cell_ids))
    for i in eachindex(semis)
        cell_ids = semis[i].mesh.cell_ids
        view_cell_ids[cell_ids] = parent_cell_id_to_view(parent_cell_ids[cell_ids],
                                                         semis[i].mesh)
        mesh_ids[cell_ids] .= i
    end

    performance_counter = PerformanceCounter()

    SemidiscretizationCoupledP4est{typeof(semis), typeof(u_indices),
                                   typeof(performance_counter)}(semis, u_indices,
                                                                performance_counter,
                                                                parent_cell_ids,
                                                                view_cell_ids,
                                                                mesh_ids)
end
73
74
# Multi-line REPL display of a coupled semidiscretization: header, one summary
# section per system, and the total DOF count.
function Base.show(io::IO, ::MIME"text/plain", semi::SemidiscretizationCoupledP4est)
    @nospecialize semi # reduce precompilation time

    if get(io, :compact, false)
        show(io, semi)
    else
        summary_header(io, "SemidiscretizationCoupledP4est")
        summary_line(io, "#spatial dimensions", ndims(semi.semis[1]))
        summary_line(io, "#systems", nsystems(semi))
        for sys in eachsystem(semi)
            summary_line(io, "system", sys)
            mesh, equations, solver, _ = mesh_equations_solver_cache(semi.semis[sys])
            io_indented = increment_indent(io)
            summary_line(io_indented, "mesh", nameof(typeof(mesh)))
            summary_line(io_indented, "equations", nameof(typeof(equations)))
            summary_line(io_indented, "initial condition",
                         semi.semis[sys].initial_condition)
            # Boundary conditions are deliberately omitted — printing all of them
            # could be too verbose.
            summary_line(io_indented, "source terms", semi.semis[sys].source_terms)
            summary_line(io_indented, "solver", nameof(typeof(solver)))
        end
        summary_line(io, "total #DOFs per field", ndofsglobal(semi))
        summary_footer(io)
    end
end
100
101
# Print a detailed summary of every coupled system: mesh, equations, and solver
# are each shown with their full multi-line representation.
function print_summary_semidiscretization(io::IO, semi::SemidiscretizationCoupledP4est)
    show(io, MIME"text/plain"(), semi)
    println(io, "\n")
    for sys in eachsystem(semi)
        mesh, equations, solver, _ = mesh_equations_solver_cache(semi.semis[sys])
        summary_header(io, "System #$sys")

        summary_line(io, "mesh", nameof(typeof(mesh)))
        show(increment_indent(io), MIME"text/plain"(), mesh)

        summary_line(io, "equations", nameof(typeof(equations)))
        show(increment_indent(io), MIME"text/plain"(), equations)

        summary_line(io, "solver", nameof(typeof(solver)))
        show(increment_indent(io), MIME"text/plain"(), solver)

        summary_footer(io)
        println(io, "\n")
    end
end
121
122
# Number of coupled systems
@inline nsystems(semi::SemidiscretizationCoupledP4est) = length(semi.semis)

# Iterator over all system indices, `1:nsystems(semi)`
@inline eachsystem(semi::SemidiscretizationCoupledP4est) = Base.OneTo(nsystems(semi))

# Common real type of all coupled systems
@inline Base.real(semi::SemidiscretizationCoupledP4est) = promote_type(real.(semi.semis)...)

# Total number of degrees of freedom, summed over all coupled systems
@inline function ndofs(semi::SemidiscretizationCoupledP4est)
    return sum(ndofs, semi.semis)
end

"""
    ndofsglobal(semi::SemidiscretizationCoupledP4est)

Return the global number of degrees of freedom associated with each scalar variable across all MPI ranks, and summed up over all coupled systems.
This is the same as [`ndofs`](@ref) for simulations running in serial or
parallelized via threads. It will in general be different for simulations
running in parallel with MPI.
"""
@inline function ndofsglobal(semi::SemidiscretizationCoupledP4est)
    return sum(ndofsglobal, semi.semis)
end
143
144
# Compute the initial coefficients of all coupled systems at time `t` and
# concatenate them into one flat global solution vector.
function compute_coefficients(t, semi::SemidiscretizationCoupledP4est)
    u_indices = semi.u_indices

    # The last index of the last range equals the total number of coefficients.
    u_ode = Vector{real(semi)}(undef, u_indices[end][end])

    # Distribute the partial solution vectors onto the global one.
    @threaded for sys in eachsystem(semi)
        # Delegates to `compute_coefficients` in `src/semidiscretization/semidiscretization.jl`
        u_ode[u_indices[sys]] .= compute_coefficients(t, semi.semis[sys])
    end

    return u_ode
end
157
158
# Non-allocating view into the part of the global solution vector that belongs
# to system `index`.
@inline function get_system_u_ode(u_ode, index, semi::SemidiscretizationCoupledP4est)
    return view(u_ode, semi.u_indices[index])
end
161
162
# RHS call for the coupled system: first assemble the parent solution vector
# from the per-system solutions, then call `rhs!` for each system, passing the
# parent solution along so coupled boundary conditions can read neighbor data.
function rhs!(du_ode, u_ode, semi::SemidiscretizationCoupledP4est, t)
    time_start = time_ns()

    n_nodes = length(semi.semis[1].mesh.parent.nodes)
    # Reformat the parent solution vector.
    # NOTE(review): this reshape hard-codes a 2D `(n_nodes, n_nodes, n_cells)`
    # layout and appears to fold the variable count into the element dimension —
    # confirm this only supports 2D, scalar-equation mesh views before extending.
    u_ode_reformatted = Vector{real(semi)}(undef, ndofs(semi))
    u_ode_reformatted_reshape = reshape(u_ode_reformatted,
                                        (n_nodes,
                                         n_nodes,
                                         length(semi.mesh_ids)))
    # Extract the parent solution vector from the local solutions.
    foreach_enumerate(semi.semis) do (i, semi_)
        system_ode = get_system_u_ode(u_ode, i, semi)
        # Integer division instead of `Int(x / y)`: avoids the float round-trip
        # and the `InexactError` it could raise on floating-point roundoff.
        n_elements = length(system_ode) ÷ n_nodes^ndims(semi_.mesh)
        system_ode_reshape = reshape(system_ode,
                                     (n_nodes, n_nodes, n_elements))
        u_ode_reformatted_reshape[:, :, semi.mesh_ids .== i] .= system_ode_reshape
    end

    # Call rhs! for each semidiscretization
    foreach_enumerate(semi.semis) do (i, semi_)
        u_loc = get_system_u_ode(u_ode, i, semi)
        du_loc = get_system_u_ode(du_ode, i, semi)
        rhs!(du_loc, u_loc, u_ode_reformatted, semi, semi_, t)
    end

    runtime = time_ns() - time_start
    put!(semi.performance_counter, runtime)

    return nothing
end
195
196
# RHS call for a single system of the coupled semidiscretization.
# The parent solution `u_parent` is passed along so that the coupled boundary
# conditions can exchange the correct boundary values with other mesh views.
function rhs!(du_ode, u_ode, u_parent, semis,
              semi::SemidiscretizationHyperbolic, t)
    mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
    boundary_conditions = semi.boundary_conditions
    source_terms = semi.source_terms

    u = wrap_array(u_ode, mesh, equations, solver, cache)
    du = wrap_array(du_ode, mesh, equations, solver, cache)

    time_start = time_ns()
    @trixi_timeit timer() "rhs!" rhs!(du, u, t, u_parent, semis, mesh, equations,
                                      boundary_conditions, source_terms, solver, cache)
    runtime = time_ns() - time_start
    put!(semi.performance_counter, runtime)

    return nothing
end
214
215
################################################################################
216
### AnalysisCallback
217
################################################################################
218
219
"""
    AnalysisCallbackCoupledP4est(semi, callbacks...)

Combine multiple analysis callbacks for coupled simulations with a
[`SemidiscretizationCoupledP4est`](@ref). For each coupled system, an individual
[`AnalysisCallback`](@ref) **must** be created and passed to the `AnalysisCallbackCoupledP4est` **in
order**, i.e., in the same sequence as the individual semidiscretizations are stored in the
coupled semidiscretization.

!!! warning "Experimental code"
    This is an experimental feature and can change any time.
"""
struct AnalysisCallbackCoupledP4est{CB}
    # Tuple of per-system `AnalysisCallback`s, in the same order as the systems
    callbacks::CB
end
234
235
# Convenience constructor for the coupled callback that gets called directly from the elixirs.
# Wraps the per-system callbacks in a `DiscreteCallback` whose condition fires
# whenever any subsidiary callback's condition fires.
function AnalysisCallbackCoupledP4est(semi_coupled, callbacks...)
    length(callbacks) == nsystems(semi_coupled) ||
        error("an AnalysisCallbackCoupledP4est requires one AnalysisCallback for each semidiscretization")

    analysis_callback_coupled = AnalysisCallbackCoupledP4est{typeof(callbacks)}(callbacks)

    # This callback is triggered if any of its subsidiary callbacks' condition is triggered
    condition = (u, t, integrator) -> any(cb -> cb.condition(u, t, integrator),
                                          callbacks)

    DiscreteCallback(condition, analysis_callback_coupled,
                     save_positions = (false, false),
                     initialize = initialize!)
end
252
253
# Used for error checks and EOC analysis: compute the L2/Linf error norms of
# every coupled system at the final time and concatenate them into flat vectors.
# (Removed dead locals `n_vars_upto_semi`, `error_indices`, and
# `length_error_array`, which were computed but never used.)
function (cb::DiscreteCallback{Condition, Affect!})(sol) where {Condition,
                                                                Affect! <:
                                                                AnalysisCallbackCoupledP4est
                                                                }
    semi_coupled = sol.prob.p
    u_ode_coupled = sol.u[end]
    @unpack callbacks = cb.affect!

    uEltype = real(semi_coupled)
    l2_error_collection = uEltype[]
    linf_error_collection = uEltype[]
    for i in eachsystem(semi_coupled)
        analysis_callback = callbacks[i].affect!
        @unpack analyzer = analysis_callback
        cache_analysis = analysis_callback.cache

        semi = semi_coupled.semis[i]
        u_ode = get_system_u_ode(u_ode_coupled, i, semi_coupled)

        l2_error,
        linf_error = calc_error_norms(u_ode, sol.t[end], analyzer, semi,
                                      cache_analysis)
        # Per-variable errors of this system are appended to the flat collections
        append!(l2_error_collection, l2_error)
        append!(linf_error_collection, linf_error)
    end

    return (; l2 = l2_error_collection, linf = linf_error_collection)
end
287
288
################################################################################
289
### SaveSolutionCallback
290
################################################################################
291
292
# Save the meshes of a coupled semidiscretization, which contains multiple
# meshes internally; each mesh file is tagged with its system number.
function save_mesh(semi::SemidiscretizationCoupledP4est, output_directory, timestep = 0)
    for sys in eachsystem(semi)
        mesh = first(mesh_equations_solver_cache(semi.semis[sys]))

        # Only write the mesh file if the mesh changed since the last save
        if mesh.unsaved_changes
            mesh.current_filename = save_mesh_file(mesh, output_directory;
                                                   system = string(sys),
                                                   timestep = timestep)
            mesh.unsaved_changes = false
        end
    end
    return nothing
end
306
307
# Save one solution file per coupled system, slicing the system's part out of
# the global solution vector.
@inline function save_solution_file(semi::SemidiscretizationCoupledP4est, u_ode,
                                    solution_callback,
                                    integrator)
    for sys in eachsystem(semi)
        save_solution_file(semi.semis[sys],
                           get_system_u_ode(u_ode, sys, semi),
                           solution_callback, integrator,
                           system = sys)
    end
    return nothing
end
319
320
################################################################################
321
### StepsizeCallback
322
################################################################################
323
324
# In case of a coupled system, use the minimum time step over all systems.
# Case of a constant `cfl_number`.
function calculate_dt(u_ode, t, cfl_hyperbolic, cfl_parabolic,
                      semi::SemidiscretizationCoupledP4est)
    return minimum(calculate_dt(get_system_u_ode(u_ode, sys, semi), t,
                                cfl_hyperbolic, cfl_parabolic, semi.semis[sys])
                   for sys in eachsystem(semi))
end
335
336
################################################################################
337
### Boundary conditions
338
################################################################################
339
340
"""
    BoundaryConditionCoupledP4est(coupling_converter)

Boundary condition struct where the user can specify the coupling converter function.

# Arguments
- `coupling_converter::CouplingConverter`: function to call for converting the solution
                                           state of one system to the other system
"""
# NOTE(review): the struct is `mutable` although the single field is never
# reassigned in this file — confirm whether mutability is required elsewhere,
# otherwise an immutable struct would suffice. The inner constructor is
# equivalent to the default one.
mutable struct BoundaryConditionCoupledP4est{CouplingConverter}
    coupling_converter::CouplingConverter

    function BoundaryConditionCoupledP4est(coupling_converter)
        new{typeof(coupling_converter)}(coupling_converter)
    end
end
356
357
"""
Extract the boundary values from the neighboring element.
This requires values from other mesh views.
This currently only works for Cartesian meshes.
"""
function (boundary_condition::BoundaryConditionCoupledP4est)(u_inner, mesh, equations,
                                                             cache,
                                                             i_index, j_index,
                                                             element_index,
                                                             normal_direction,
                                                             surface_flux_function,
                                                             direction,
                                                             u_ode_coupled)
    n_nodes = length(mesh.parent.nodes)
    # Using a projection onto e_x, -e_x, e_y, -e_y to determine which way our
    # boundary interface points. Knowing this, we then find the cell index in
    # the global (parent) space of the neighboring cell.
    # `sum(normal_direction .* (1.0, 0.0))` is just normal_direction[1] etc.,
    # so the outer branch picks the dominant axis of the face normal.
    if abs(sum(normal_direction .* (1.0, 0.0))) >
       abs(sum(normal_direction .* (0.0, 1.0)))
        # x-dominant normal: boundary is an x_pos or x_neg face.
        # The boolean masks are combined via `.*` (Bool*Bool == Bool, i.e. AND);
        # `findfirst` locates the boundary entry of this element on that face.
        if sum(normal_direction .* (1.0, 0.0)) >
           sum(normal_direction .* (-1.0, 0.0))
            cell_index_parent = cache.neighbor_ids_parent[findfirst((cache.boundaries.name .==
                                                                     :x_pos) .*
                                                                    (cache.boundaries.neighbor_ids .==
                                                                     element_index))]
        else
            cell_index_parent = cache.neighbor_ids_parent[findfirst((cache.boundaries.name .==
                                                                     :x_neg) .*
                                                                    (cache.boundaries.neighbor_ids .==
                                                                     element_index))]
        end
        i_index_g = i_index
        # Make sure we do not leave the domain: map the boundary node of this
        # element to the matching node on the opposite face of the neighbor cell.
        if i_index == n_nodes
            i_index_g = 1
        elseif i_index == 1
            i_index_g = n_nodes
        end
        j_index_g = j_index
    else
        # y-dominant normal: boundary is a y_pos or y_neg face.
        if sum(normal_direction .* (0.0, 1.0)) > sum(normal_direction .* (0.0, -1.0))
            cell_index_parent = cache.neighbor_ids_parent[findfirst((cache.boundaries.name .==
                                                                     :y_pos) .*
                                                                    (cache.boundaries.neighbor_ids .==
                                                                     element_index))]
        else
            cell_index_parent = cache.neighbor_ids_parent[findfirst((cache.boundaries.name .==
                                                                     :y_neg) .*
                                                                    (cache.boundaries.neighbor_ids .==
                                                                     element_index))]
        end
        j_index_g = j_index
        # Make sure we do not leave the domain (see x-branch above).
        if j_index == n_nodes
            j_index_g = 1
        elseif j_index == 1
            j_index_g = n_nodes
        end
        i_index_g = i_index
    end
    # Perform integer division to get the right shape of the array.
    # NOTE(review): hard-codes a 2D `(n_nodes, n_nodes, n_cells)` layout of the
    # parent solution vector — confirm against the layout built in `rhs!`.
    u_parent_reshape = reshape(u_ode_coupled,
                               (n_nodes, n_nodes,
                                length(u_ode_coupled) ÷ n_nodes^ndims(mesh.parent)))
    # NOTE(review): wrapping a single entry in `SVector` yields a 1-element state,
    # i.e. this presumably only supports scalar equations — verify before reuse.
    u_boundary = SVector(u_parent_reshape[i_index_g, j_index_g, cell_index_parent])

    # The face normal is passed as the "orientation" argument of the flux.
    orientation = normal_direction

    # Calculate boundary flux
    flux = surface_flux_function(u_inner, u_boundary, orientation, equations)

    return flux
end
430
431
# Compute the boundary fluxes for one boundary-condition type on a 2D P4est
# mesh view, forwarding the parent solution so coupled BCs can read neighbors.
function calc_boundary_flux!(cache, t, boundary_condition::BC, boundary_indexing,
                             mesh::P4estMeshView{2},
                             equations, surface_integral, dg::DG, u_parent) where {BC}
    boundaries = cache.boundaries
    surface_flux_values = cache.elements.surface_flux_values
    index_range = eachnode(dg)

    @threaded for local_index in eachindex(boundary_indexing)
        # Use the local index to get the global boundary index from the pre-sorted list
        boundary = boundary_indexing[local_index]

        # Adjacent element and the orientation of the boundary face on it
        element = boundaries.neighbor_ids[boundary]
        node_indices = boundaries.node_indices[boundary]
        direction = indices2direction(node_indices)

        # Starting node and stride along the boundary face in each index direction
        i_node_start, i_node_step = index_to_start_step_2d(node_indices[1], index_range)
        j_node_start, j_node_step = index_to_start_step_2d(node_indices[2], index_range)

        i_node = i_node_start
        j_node = j_node_start
        # Walk along the face node by node, computing and storing the flux
        for node in eachnode(dg)
            calc_boundary_flux!(surface_flux_values, t, boundary_condition,
                                mesh, have_nonconservative_terms(equations),
                                equations, surface_integral, dg, cache,
                                i_node, j_node,
                                node, direction, element, boundary,
                                u_parent)

            i_node += i_node_step
            j_node += j_node_step
        end
    end
    return nothing
end
467
468
# Iterate over tuples of boundary condition types and associated indices
# in a type-stable way using "lispy tuple programming": peel off the head of
# each tuple, process it, then recurse on the tails.
function calc_boundary_flux_by_type!(cache, t, BCs::NTuple{N, Any},
                                     BC_indices::NTuple{N, Vector{Int}},
                                     mesh::P4estMeshView,
                                     equations, surface_integral, dg::DG,
                                     u_parent) where {N}
    # Head: the boundary condition type and index vector handled in this call
    bc_head = first(BCs)
    bc_indices_head = first(BC_indices)
    # Tail: everything left for the recursive calls
    bc_tail = Base.tail(BCs)
    bc_indices_tail = Base.tail(BC_indices)

    # Process the first boundary condition type
    calc_boundary_flux!(cache, t, bc_head, bc_indices_head,
                        mesh, equations, surface_integral, dg, u_parent)

    # Recursively call this method with the unprocessed boundary types
    calc_boundary_flux_by_type!(cache, t, bc_tail, bc_indices_tail,
                                mesh, equations, surface_integral, dg, u_parent)

    return nothing
end
493
494
# Terminate the type-stable recursion over the boundary-condition tuples:
# with empty tuples there is nothing left to process.
function calc_boundary_flux_by_type!(cache, t, BCs::Tuple{}, BC_indices::Tuple{},
                                     mesh::P4estMeshView,
                                     equations, surface_integral, dg::DG, u_parent)
    return nothing
end
500
end # @muladd
501
502