Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
trixi-framework
GitHub Repository: trixi-framework/Trixi.jl
Path: blob/main/src/callbacks_step/amr.jl
5586 views
1
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
2
# Since these FMAs can increase the performance of many numerical algorithms,
3
# we need to opt-in explicitly.
4
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
5
@muladd begin
6
#! format: noindent
7
8
"""
    AMRCallback(semi, controller [,adaptor=AdaptorAMR(semi)];
                interval,
                adapt_initial_condition=true,
                adapt_initial_condition_only_refine=true,
                dynamic_load_balancing=true)

Performs adaptive mesh refinement (AMR) every `interval` time steps
for a given semidiscretization `semi` using the chosen `controller`.

# Keywords
- `interval`: adapt the mesh every `interval` accepted time steps (but not after the
  final step); `interval = 0` disables adaptation during time integration.
- `adapt_initial_condition`: adapt the mesh to the initial condition when the callback
  is initialized.
- `adapt_initial_condition_only_refine`: restrict the initial adaptation to refinement.
- `dynamic_load_balancing`: repartition the mesh after it changed in MPI-parallel runs.
"""
struct AMRCallback{Controller, Adaptor, Cache}
    # yields one value per element: > 0 refine, < 0 coarsen, otherwise keep
    controller::Controller
    # adapt every `interval` accepted steps (`0`: no adaptation during time stepping)
    interval::Int
    # adapt the mesh to the initial condition during initialization
    adapt_initial_condition::Bool
    # restrict the initial adaptation to refinement only
    adapt_initial_condition_only_refine::Bool
    # repartition the mesh after changes when running with MPI
    dynamic_load_balancing::Bool
    # passed to the solver's `refine!`/`coarsen!`/`adapt!` to transfer solution data
    adaptor::Adaptor
    # reusable buffers (`to_refine`, `to_coarsen`) for cell id lists
    amr_cache::Cache
end
27
28
# Main constructor: validate `interval`, build the trigger condition, and wrap an
# `AMRCallback` in a `DiscreteCallback` whose `initialize` hook performs the optional
# adaptation to the initial condition. `semi` is part of the public signature but is
# not used here.
function AMRCallback(semi, controller, adaptor;
                     interval,
                     adapt_initial_condition = true,
                     adapt_initial_condition_only_refine = true,
                     dynamic_load_balancing = true)
    # check arguments
    is_valid_interval = interval isa Integer && interval >= 0
    is_valid_interval ||
        throw(ArgumentError("`interval` must be a non-negative integer (provided `interval = $interval`)"))

    # Trigger AMR every `interval` *accepted* steps, but not after the final step.
    # With error-based step size control some steps are rejected, so `integrator.iter`
    # (total steps) may exceed `integrator.stats.naccept` (accepted steps). Callbacks
    # are not activated after rejected steps, hence `naccept` is the relevant counter.
    condition = if iszero(interval)
        # Never triggered during time stepping; the `initialize` hook may still adapt
        # the mesh to the initial condition.
        (u, t, integrator) -> false
    else
        (u, t, integrator) -> begin
            naccept = integrator.stats.naccept
            return naccept % interval == 0 &&
                   !(naccept == 0 && integrator.iter > 0) &&
                   !isfinished(integrator)
        end
    end

    # Reusable buffers for the lists of cells to refine/coarsen
    amr_cache = (; to_refine = Int[], to_coarsen = Int[])

    CallbackT = AMRCallback{typeof(controller), typeof(adaptor), typeof(amr_cache)}
    amr_callback = CallbackT(controller,
                             interval,
                             adapt_initial_condition,
                             adapt_initial_condition_only_refine,
                             dynamic_load_balancing,
                             adaptor,
                             amr_cache)

    return DiscreteCallback(condition, amr_callback;
                            save_positions = (false, false),
                            initialize = initialize!)
end
69
70
# Convenience constructor: use the default adaptor for `semi` and forward all keywords.
AMRCallback(semi, controller; kwargs...) = AMRCallback(semi, controller,
                                                       AdaptorAMR(semi); kwargs...)
74
75
# Build an `AdaptorAMR` from the mesh and solver stored in the semidiscretization `semi`;
# keywords are forwarded unchanged.
function AdaptorAMR(semi; kwargs...)
    # only the mesh (1st) and the solver (3rd) are needed to set up the adaptor
    semi_mesh, _, semi_solver, _ = mesh_equations_solver_cache(semi)
    return AdaptorAMR(semi_mesh, semi_solver; kwargs...)
end
79
80
# TODO: Taal bikeshedding, implement a method with less information and the signature
81
# function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:AMRCallback})
82
# @nospecialize cb # reduce precompilation time
83
#
84
# amr_callback = cb.affect!
85
# print(io, "AMRCallback")
86
# end
87
# Multi-line summary of an AMR callback for the REPL; falls back to the compact
# representation when `io` requests `:compact` output.
function Base.show(io::IO, mime::MIME"text/plain",
                   cb::DiscreteCallback{<:Any, <:AMRCallback})
    @nospecialize cb # reduce precompilation time

    get(io, :compact, false) && return show(io, cb)

    amr_callback = cb.affect!
    controller = amr_callback.controller
    yes_no(flag) = flag ? "yes" : "no"

    summary_header(io, "AMRCallback")
    summary_line(io, "controller", nameof(typeof(controller)))
    show(increment_indent(io), mime, controller)
    summary_line(io, "interval", amr_callback.interval)
    summary_line(io, "adapt IC", yes_no(amr_callback.adapt_initial_condition))
    # the "only refine" setting is only meaningful if the IC is adapted at all
    if amr_callback.adapt_initial_condition
        summary_line(io, "│ only refine",
                     yes_no(amr_callback.adapt_initial_condition_only_refine))
    end
    return summary_footer(io)
end
110
111
# The function below is used to control the output depending on whether or not AMR is enabled.
"""
    uses_amr(callback)

Checks whether the provided callback or `CallbackSet` is an [`AMRCallback`](@ref)
or contains one.

Returns `true` for a `DiscreteCallback` wrapping an `AMRCallback`, and for a
`CallbackSet` if any of its discrete callbacks does; returns `false` otherwise
(including for a `CallbackSet` without discrete callbacks).
"""
uses_amr(cb) = false
uses_amr(cb::DiscreteCallback{<:Any, <:AMRCallback}) = true
# `any` returns `false` for an empty collection of discrete callbacks, whereas the
# previous `mapreduce(uses_amr, |, ...)` threw an `ArgumentError` on empty input.
uses_amr(callbacks::CallbackSet) = any(uses_amr, callbacks.discrete_callbacks)
125
126
# Element-wise output variables of an AMR callback are provided by its controller;
# the callback itself is passed along as the last positional argument.
function get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                amr_callback::AMRCallback; kwargs...)
    controller = amr_callback.controller
    return get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                  controller, amr_callback; kwargs...)
end
131
132
# Initialization hook of the AMR callback.
#
# If `adapt_initial_condition` is set, the mesh is adapted to the initial condition
# repeatedly, re-evaluating the initial condition on the new mesh after every change,
# until the mesh settles or `max(10, max_level(controller))` passes are exceeded.
# Afterwards, the initial state integrals of an `AnalysisCallback` (if present) are
# recomputed on the adapted mesh. Returns `nothing`.
function initialize!(cb::DiscreteCallback{Condition, Affect!}, u, t,
                     integrator) where {Condition, Affect! <: AMRCallback}
    amr_callback = cb.affect!
    semi = integrator.p

    @trixi_timeit timer() "initial condition AMR" if amr_callback.adapt_initial_condition
        only_refine = amr_callback.adapt_initial_condition_only_refine

        # iterate until mesh does not change anymore
        has_changed = amr_callback(integrator, only_refine = only_refine)
        iterations = 1
        while has_changed
            compute_coefficients!(integrator.u, t, semi)
            u_modified!(integrator, true)
            has_changed = amr_callback(integrator, only_refine = only_refine)
            iterations += 1
            allowed_max_iterations = max(10, max_level(amr_callback.controller))
            # Only give up if the mesh is *still* changing. If the last pass left the
            # mesh unchanged, it has settled and the loop ends normally without warning.
            if has_changed && iterations > allowed_max_iterations
                @warn "AMR for initial condition did not settle within $(allowed_max_iterations) iterations!\n" *
                      "Consider adjusting thresholds or setting `adapt_initial_condition_only_refine`."
                # The last pass changed the mesh, so the data on it was only transferred
                # by the adaptor. Re-evaluate the initial condition on the final mesh, as
                # is done after every other mesh change above.
                compute_coefficients!(integrator.u, t, semi)
                u_modified!(integrator, true)
                break
            end
        end

        # Update initial state integrals of analysis callback if it exists
        # See https://github.com/trixi-framework/Trixi.jl/issues/2536 for more information.
        discrete_callbacks = integrator.opts.callback.discrete_callbacks
        index = findfirst(callback -> callback.affect! isa AnalysisCallback,
                          discrete_callbacks)
        if !isnothing(index)
            analysis_callback = discrete_callbacks[index].affect!
            analysis_callback.initial_state_integrals = integrate(integrator.u, semi)
        end
    end

    return nothing
end
170
171
# TODO: Taal remove?
172
# function (cb::DiscreteCallback{Condition,Affect!})(ode::ODEProblem) where {Condition, Affect!<:AMRCallback}
173
# amr_callback = cb.affect!
174
# semi = ode.p
175
176
# @trixi_timeit timer() "initial condition AMR" if amr_callback.adapt_initial_condition
177
# # iterate until mesh does not change anymore
178
# has_changed = true
179
# while has_changed
180
# has_changed = amr_callback(ode.u0, semi,
181
# only_refine=amr_callback.adapt_initial_condition_only_refine)
182
# compute_coefficients!(ode.u0, ode.tspan[1], semi)
183
# end
184
# end
185
186
# return nothing
187
# end
188
189
# Perform one AMR step on the current state of `integrator`. If the mesh changed, the
# integrator is resized to the new length of its state vector and notified that `u` was
# modified. Returns whether the mesh changed.
function (amr_callback::AMRCallback)(integrator; kwargs...)
    u_ode = integrator.u
    semi = integrator.p

    has_changed = @trixi_timeit timer() "AMR" begin
        changed = amr_callback(u_ode, semi, integrator.t, integrator.iter; kwargs...)
        if changed
            resize!(integrator, length(u_ode))
            u_modified!(integrator, true)
        end
        changed
    end

    return has_changed
end
204
205
# Unpack a hyperbolic semidiscretization and dispatch on its mesh/equations/solver/cache.
@inline function (amr_callback::AMRCallback)(u_ode::AbstractVector,
                                             semi::SemidiscretizationHyperbolic,
                                             t, iter;
                                             kwargs...)
    # `u_ode` is deliberately not `wrap_array`ed so that it can still be `resize!`d
    # during AMR, while the unpacked components drive dispatch.
    components = mesh_equations_solver_cache(semi)
    return amr_callback(u_ode, components..., semi, t, iter; kwargs...)
end
214
215
# Unpack a semidiscretization with a parabolic part and dispatch on its
# mesh/equations/solver/cache, additionally passing the parabolic cache.
@inline function (amr_callback::AMRCallback)(u_ode::AbstractVector,
                                             semi::Union{SemidiscretizationHyperbolicParabolic,
                                                         SemidiscretizationParabolic},
                                             t, iter;
                                             kwargs...)
    # `u_ode` is deliberately not `wrap_array`ed so that it can still be `resize!`d
    # during AMR, while the unpacked components drive dispatch.
    components = mesh_equations_solver_cache(semi)
    parabolic = semi.cache_parabolic
    return amr_callback(u_ode, components..., parabolic, semi, t, iter; kwargs...)
end
226
227
# `passive_args` is currently used for Euler with self-gravity to adapt the gravity solver
# passively without querying its indicator, based on the assumption that both solvers use
# the same mesh. That's a hack and should be improved in the future once we have more examples
# and a better understanding of such a coupling.
# `passive_args` is expected to be an iterable of `Tuple`s of the form
# `(p_u_ode, p_mesh, p_equations, p_dg, p_cache)`.
#
# AMR step on a `TreeMesh`: leaf cells with a positive controller value are refined, leaf
# cells with a negative value are coarsened (only if all siblings are marked as well).
# The solver data in `u_ode`/`cache` (and that of all `passive_args`) is adapted
# accordingly. Returns `true` if any cell was refined or coarsened, `false` otherwise.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::TreeMesh,
                                     equations, dg::DG, cache, semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    @unpack controller, adaptor = amr_callback

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    lambda = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg, cache,
                                                          t = t, iter = iter)

    if mpi_isparallel()
        # Collect lambda for all elements
        lambda_global = Vector{eltype(lambda)}(undef, nelementsglobal(mesh, dg, cache))
        # Use parent because n_elements_by_rank is an OffsetArray
        recvbuf = MPI.VBuffer(lambda_global, parent(cache.mpi_cache.n_elements_by_rank))
        MPI.Allgatherv!(lambda, recvbuf, mpi_comm())
        lambda = lambda_global
    end

    leaf_cell_ids = leaf_cells(mesh.tree)
    @boundscheck begin
        @assert axes(lambda)==axes(leaf_cell_ids) ("Indicator (axes = $(axes(lambda))) and leaf cell (axes = $(axes(leaf_cell_ids))) arrays have different axes")
    end

    @unpack to_refine, to_coarsen = amr_callback.amr_cache
    empty!(to_refine)
    empty!(to_coarsen)
    # Note: This assumes that the entries of `lambda` are sorted with ascending cell ids
    for element in eachindex(lambda)
        controller_value = lambda[element]
        if controller_value > 0
            push!(to_refine, leaf_cell_ids[element])
        elseif controller_value < 0
            push!(to_coarsen, leaf_cell_ids[element])
        end
    end

    @trixi_timeit timer() "refine" if !only_coarsen && !isempty(to_refine)
        # refine mesh
        refined_original_cells = @trixi_timeit timer() "mesh" refine!(mesh.tree,
                                                                      to_refine)

        # Find all indices of elements whose cell ids are in refined_original_cells
        elements_to_refine = findall(in(refined_original_cells),
                                     cache.elements.cell_ids)

        # refine solver
        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache, elements_to_refine)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" refine!(p_u_ode, adaptor, p_mesh,
                                                           p_equations, p_dg, p_cache,
                                                           elements_to_refine)
        end
    else
        # If there is nothing to refine, create empty array for later use
        refined_original_cells = Int[]
    end

    @trixi_timeit timer() "coarsen" if !only_refine && !isempty(to_coarsen)
        # Since the cells may have been shifted due to refinement, first we need to
        # translate the old cell ids to the new cell ids. `to_coarsen` is guaranteed to
        # be non-empty by the condition above, so no further check is needed.
        to_coarsen = original2refined(to_coarsen, refined_original_cells, mesh)

        # Next, determine the parent cells from which the fine cells are to be
        # removed, since these are needed for the coarsen! function. However, since
        # we only want to coarsen if *all* child cells are marked for coarsening,
        # we count the coarsening indicators for each parent cell and only coarsen
        # if all children are marked as such (i.e., where the count is 2^ndims). At
        # the same time, check if a cell is marked for coarsening even though it is
        # *not* a leaf cell -> this can only happen if it was refined due to 2:1
        # smoothing during the preceding refinement operation.
        parents_to_coarsen = zeros(Int, length(mesh.tree))
        for cell_id in to_coarsen
            # If cell has no parent, it cannot be coarsened
            has_parent(mesh.tree, cell_id) || continue

            # If cell is not leaf (anymore), it cannot be coarsened
            is_leaf(mesh.tree, cell_id) || continue

            # Increase count for parent cell
            parents_to_coarsen[mesh.tree.parent_ids[cell_id]] += 1
        end

        # Extract only those parent cells for which all children should be coarsened
        to_coarsen = findall(==(2^ndims(mesh)), parents_to_coarsen)

        # Finally, coarsen mesh
        coarsened_original_cells = @trixi_timeit timer() "mesh" coarsen!(mesh.tree,
                                                                         to_coarsen)

        # Convert coarsened parent cell ids to the list of child cell ids that have
        # been removed, since this is the information that is expected by the solver.
        # The children of a coarsened cell have the ids directly following its own id.
        n_children = n_children_per_cell(mesh.tree)
        removed_child_cells = zeros(Int, n_children * length(coarsened_original_cells))
        for (index, coarse_cell_id) in enumerate(coarsened_original_cells)
            for child in 1:n_children
                removed_child_cells[n_children * (index - 1) + child] = coarse_cell_id +
                                                                        child
            end
        end

        # Find all indices of elements whose cell ids are in removed_child_cells
        elements_to_remove = findall(in(removed_child_cells), cache.elements.cell_ids)

        # coarsen solver
        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache, elements_to_remove)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" coarsen!(p_u_ode, adaptor, p_mesh,
                                                            p_equations, p_dg, p_cache,
                                                            elements_to_remove)
        end
    else
        # If there is nothing to coarsen, create empty array for later use
        coarsened_original_cells = Int[]
    end

    # Store whether there were any cells coarsened or refined
    has_changed = !isempty(refined_original_cells) || !isempty(coarsened_original_cells)
    if has_changed # TODO: Taal decide, where shall we set this?
        # don't set it to has_changed since there can be changes from earlier calls
        mesh.unsaved_changes = true
    end

    # Dynamically balance computational load by first repartitioning the mesh and then redistributing the cells/elements
    if has_changed && mpi_isparallel() && amr_callback.dynamic_load_balancing
        @trixi_timeit timer() "dynamic load balancing" begin
            old_mpi_ranks_per_cell = copy(mesh.tree.mpi_ranks)

            partition!(mesh)

            rebalance_solver!(u_ode, mesh, equations, dg, cache, old_mpi_ranks_per_cell)
        end
    end

    # Return true if there were any cells coarsened or refined, otherwise false
    return has_changed
end
381
382
# AMR step on a `TreeMesh` for semidiscretizations with a parabolic part. This works
# like the purely hyperbolic `TreeMesh` method but additionally adapts `cache_parabolic`.
# Passive solvers are not supported. MPI-parallel runs are rejected with an error, since
# this path has not been verified for MPI yet; consequently there is no gathering of
# `lambda` and no dynamic load balancing here.
# Returns `true` if any cell was refined or coarsened, `false` otherwise.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::TreeMesh,
                                     equations, dg::DG,
                                     cache, cache_parabolic,
                                     semi::Union{SemidiscretizationHyperbolicParabolic,
                                                 SemidiscretizationParabolic},
                                     t, iter;
                                     only_refine = false, only_coarsen = false)
    @unpack controller, adaptor = amr_callback

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    lambda = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg, cache,
                                                          t = t, iter = iter)

    if mpi_isparallel()
        # Any code after this error would be unreachable; see the purely hyperbolic
        # `TreeMesh` method for how `lambda` is gathered and the load is rebalanced.
        error("MPI has not been verified yet for parabolic AMR")
    end

    leaf_cell_ids = leaf_cells(mesh.tree)
    @boundscheck begin
        @assert axes(lambda)==axes(leaf_cell_ids) ("Indicator (axes = $(axes(lambda))) and leaf cell (axes = $(axes(leaf_cell_ids))) arrays have different axes")
    end

    @unpack to_refine, to_coarsen = amr_callback.amr_cache
    empty!(to_refine)
    empty!(to_coarsen)
    # Note: This assumes that the entries of `lambda` are sorted with ascending cell ids
    for element in eachindex(lambda)
        controller_value = lambda[element]
        if controller_value > 0
            push!(to_refine, leaf_cell_ids[element])
        elseif controller_value < 0
            push!(to_coarsen, leaf_cell_ids[element])
        end
    end

    @trixi_timeit timer() "refine" if !only_coarsen && !isempty(to_refine)
        # refine mesh
        refined_original_cells = @trixi_timeit timer() "mesh" refine!(mesh.tree,
                                                                      to_refine)

        # Find all indices of elements whose cell ids are in refined_original_cells
        # Note: This assumes same indices for hyperbolic and parabolic part.
        elements_to_refine = findall(in(refined_original_cells),
                                     cache.elements.cell_ids)

        # refine solver
        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache, cache_parabolic,
                                               elements_to_refine)
    else
        # If there is nothing to refine, create empty array for later use
        refined_original_cells = Int[]
    end

    @trixi_timeit timer() "coarsen" if !only_refine && !isempty(to_coarsen)
        # Since the cells may have been shifted due to refinement, first we need to
        # translate the old cell ids to the new cell ids. `to_coarsen` is guaranteed to
        # be non-empty by the condition above, so no further check is needed.
        to_coarsen = original2refined(to_coarsen, refined_original_cells, mesh)

        # Next, determine the parent cells from which the fine cells are to be
        # removed, since these are needed for the coarsen! function. However, since
        # we only want to coarsen if *all* child cells are marked for coarsening,
        # we count the coarsening indicators for each parent cell and only coarsen
        # if all children are marked as such (i.e., where the count is 2^ndims). At
        # the same time, check if a cell is marked for coarsening even though it is
        # *not* a leaf cell -> this can only happen if it was refined due to 2:1
        # smoothing during the preceding refinement operation.
        parents_to_coarsen = zeros(Int, length(mesh.tree))
        for cell_id in to_coarsen
            # If cell has no parent, it cannot be coarsened
            has_parent(mesh.tree, cell_id) || continue

            # If cell is not leaf (anymore), it cannot be coarsened
            is_leaf(mesh.tree, cell_id) || continue

            # Increase count for parent cell
            parents_to_coarsen[mesh.tree.parent_ids[cell_id]] += 1
        end

        # Extract only those parent cells for which all children should be coarsened
        to_coarsen = findall(==(2^ndims(mesh)), parents_to_coarsen)

        # Finally, coarsen mesh
        coarsened_original_cells = @trixi_timeit timer() "mesh" coarsen!(mesh.tree,
                                                                         to_coarsen)

        # Convert coarsened parent cell ids to the list of child cell ids that have
        # been removed, since this is the information that is expected by the solver.
        # The children of a coarsened cell have the ids directly following its own id.
        n_children = n_children_per_cell(mesh.tree)
        removed_child_cells = zeros(Int, n_children * length(coarsened_original_cells))
        for (index, coarse_cell_id) in enumerate(coarsened_original_cells)
            for child in 1:n_children
                removed_child_cells[n_children * (index - 1) + child] = coarse_cell_id +
                                                                        child
            end
        end

        # Find all indices of elements whose cell ids are in removed_child_cells
        # Note: This assumes same indices for hyperbolic and parabolic part.
        elements_to_remove = findall(in(removed_child_cells), cache.elements.cell_ids)

        # coarsen solver
        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache, cache_parabolic,
                                                elements_to_remove)
    else
        # If there is nothing to coarsen, create empty array for later use
        coarsened_original_cells = Int[]
    end

    # Store whether there were any cells coarsened or refined
    has_changed = !isempty(refined_original_cells) || !isempty(coarsened_original_cells)
    if has_changed # TODO: Taal decide, where shall we set this?
        # don't set it to has_changed since there can be changes from earlier calls
        mesh.unsaved_changes = true
    end

    # No dynamic load balancing here: MPI-parallel runs were rejected above.

    # Return true if there were any cells coarsened or refined, otherwise false
    return has_changed
end
530
531
# Volume-iterator callback for `iterate_p4est`: copy the controller value of the visited
# quadrant from `user_data` into the second slot of the quadrant's own user data.
# `user_data` points to the `lambda` vector, read as `Int`s indexed by global quad ID
# (the caller asserts `lambda isa Vector{Int}`). Called through the C pointers built by
# `cfunction` for this function.
function copy_to_quad_iter_volume(info, user_data)
    info_pw = PointerWrapper(info)

    # Global quad ID = quadrant offset of the owning tree + local quad ID.
    # The global trees array is loaded with one-based indexing, hence the `+ 1`.
    tree_pw = load_pointerwrapper_tree(info_pw.p4est, info_pw.treeid[] + 1)
    global_quad_id = tree_pw.quadrants_offset[] + info_pw.quadid[]

    # Look up the controller value; `lambda` is accessed with one-based indexing
    lambda_pw = PointerWrapper(Int, user_data)
    value = lambda_pw[global_quad_id + 1]

    # The quadrant's user data is laid out as [global quad ID, controller_value]
    quad_user_data_pw = PointerWrapper(Int, info_pw.quad.p.user_data[])
    quad_user_data_pw[2] = value

    return nothing
end
554
555
# specialized callback which includes the `cache_parabolic` argument
#
# AMR step on a `P4estMesh` for semidiscretizations with a parabolic cache. The controller
# values are first copied into the quadrants' user data (presumably consumed by
# `refine!(mesh)`/`coarsen!(mesh)` — confirm), then the mesh is refined and coarsened
# and the solver data (including `cache_parabolic` and all `passive_args`) is adapted.
# If the mesh changed on any rank, the load is optionally rebalanced (MPI) and the
# boundary conditions are reinitialized. Returns whether the mesh changed on any rank.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::P4estMesh,
                                     equations, dg::DG, cache, cache_parabolic,
                                     semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    controller = amr_callback.controller
    adaptor = amr_callback.adaptor

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    lambda = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg, cache,
                                                          t = t, iter = iter)

    @boundscheck begin
        @assert axes(lambda)==(Base.OneTo(ncells(mesh)),) ("Indicator array (axes = $(axes(lambda))) and mesh cells (axes = $(Base.OneTo(ncells(mesh)))) have different axes")
    end

    # Hand every quadrant its controller value via `copy_to_quad_iter_volume`
    iter_volume_c = cfunction(copy_to_quad_iter_volume, Val(ndims(mesh)))
    # `copy_to_quad_iter_volume` (defined above) reads the pointer to `lambda` as
    # `Ptr{Int}`, so the element type must match exactly
    @assert lambda isa Vector{Int}
    iterate_p4est(mesh.p4est, lambda; iter_volume_c = iter_volume_c)

    refined_original_cells = @trixi_timeit timer() "refine" if only_coarsen
        # refinement disabled: report no refined cells
        Int[]
    else
        refined_cells = @trixi_timeit timer() "mesh" refine!(mesh)
        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache, cache_parabolic,
                                               refined_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" refine!(p_u_ode, adaptor, p_mesh,
                                                           p_equations, p_dg, p_cache,
                                                           refined_cells)
        end
        refined_cells
    end

    coarsened_original_cells = @trixi_timeit timer() "coarsen" if only_refine
        # coarsening disabled: report no coarsened cells
        Int[]
    else
        coarsened_cells = @trixi_timeit timer() "mesh" coarsen!(mesh)
        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache, cache_parabolic,
                                                coarsened_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" coarsen!(p_u_ode, adaptor, p_mesh,
                                                            p_equations, p_dg, p_cache,
                                                            coarsened_cells)
        end
        coarsened_cells
    end

    # Did the mesh change locally? Then agree on the answer across all ranks.
    has_changed = !(isempty(refined_original_cells) && isempty(coarsened_original_cells))
    if mpi_isparallel()
        has_changed = MPI.Allreduce!(Ref(has_changed), |, mpi_comm())[]
    end

    # Nothing changed anywhere: no bookkeeping required
    has_changed || return has_changed

    # don't set it to has_changed since there can be changes from earlier calls
    mesh.unsaved_changes = true

    if mpi_isparallel() && amr_callback.dynamic_load_balancing
        @trixi_timeit timer() "dynamic load balancing" begin
            # snapshot the partition before repartitioning the mesh
            old_global_first_quadrant = copy(unsafe_wrap(Array,
                                                         unsafe_load(mesh.p4est).global_first_quadrant,
                                                         mpi_nranks() + 1))
            partition!(mesh)
            rebalance_solver!(u_ode, mesh, equations, dg, cache,
                              old_global_first_quadrant)
            resize!(cache_parabolic.parabolic_container, equations, dg, cache)
        end
    end

    reinitialize_boundaries!(semi.boundary_conditions, cache)
    # parabolic boundary conditions, if stored, must be reinitialized as well
    if hasproperty(semi, :boundary_conditions_parabolic)
        reinitialize_boundaries!(semi.boundary_conditions_parabolic, cache)
    end

    return has_changed
end
653
654
# C-callable pointers to `copy_to_quad_iter_volume`, one per spatial dimension, since
# the volume-iterator info type differs between p4est (2D) and p8est (3D).
# 2D
cfunction(::typeof(copy_to_quad_iter_volume), ::Val{2}) =
    @cfunction(copy_to_quad_iter_volume, Cvoid,
               (Ptr{p4est_iter_volume_info_t}, Ptr{Cvoid}))
# 3D
cfunction(::typeof(copy_to_quad_iter_volume), ::Val{3}) =
    @cfunction(copy_to_quad_iter_volume, Cvoid,
               (Ptr{p8est_iter_volume_info_t}, Ptr{Cvoid}))
664
665
# AMR step on a `P4estMesh` (no parabolic cache). The controller values are first copied
# into the quadrants' user data (presumably consumed by `refine!(mesh)`/`coarsen!(mesh)`
# — confirm), then the mesh is refined and coarsened and the solver data (and that of
# all `passive_args`) is adapted. If the mesh changed on any rank, the load is optionally
# rebalanced (MPI) and the boundary conditions are reinitialized.
# Returns whether the mesh changed on any rank.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::P4estMesh,
                                     equations, dg::DG, cache, semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    controller = amr_callback.controller
    adaptor = amr_callback.adaptor

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    lambda = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg, cache,
                                                          t = t, iter = iter)

    @boundscheck begin
        @assert axes(lambda)==(Base.OneTo(ncells(mesh)),) ("Indicator array (axes = $(axes(lambda))) and mesh cells (axes = $(Base.OneTo(ncells(mesh)))) have different axes")
    end

    # Hand every quadrant its controller value via `copy_to_quad_iter_volume`
    iter_volume_c = cfunction(copy_to_quad_iter_volume, Val(ndims(mesh)))
    # `copy_to_quad_iter_volume` reads the pointer to `lambda` as `Ptr{Int}`,
    # so the element type must match exactly
    @assert lambda isa Vector{Int}
    iterate_p4est(mesh.p4est, lambda; iter_volume_c = iter_volume_c)

    refined_original_cells = @trixi_timeit timer() "refine" if only_coarsen
        # refinement disabled: report no refined cells
        Int[]
    else
        refined_cells = @trixi_timeit timer() "mesh" refine!(mesh)
        @trixi_timeit timer() "solver" refine!(u_ode, adaptor, mesh, equations, dg,
                                               cache, refined_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" refine!(p_u_ode, adaptor, p_mesh,
                                                           p_equations, p_dg, p_cache,
                                                           refined_cells)
        end
        refined_cells
    end

    coarsened_original_cells = @trixi_timeit timer() "coarsen" if only_refine
        # coarsening disabled: report no coarsened cells
        Int[]
    else
        coarsened_cells = @trixi_timeit timer() "mesh" coarsen!(mesh)
        @trixi_timeit timer() "solver" coarsen!(u_ode, adaptor, mesh, equations, dg,
                                                cache, coarsened_cells)
        for (p_u_ode, p_mesh, p_equations, p_dg, p_cache) in passive_args
            @trixi_timeit timer() "passive solver" coarsen!(p_u_ode, adaptor, p_mesh,
                                                            p_equations, p_dg, p_cache,
                                                            coarsened_cells)
        end
        coarsened_cells
    end

    # Did the mesh change locally? Then agree on the answer across all ranks.
    has_changed = !(isempty(refined_original_cells) && isempty(coarsened_original_cells))
    if mpi_isparallel()
        has_changed = MPI.Allreduce!(Ref(has_changed), |, mpi_comm())[]
    end

    # Nothing changed anywhere: no bookkeeping required
    has_changed || return has_changed

    # don't set it to has_changed since there can be changes from earlier calls
    mesh.unsaved_changes = true

    if mpi_isparallel() && amr_callback.dynamic_load_balancing
        @trixi_timeit timer() "dynamic load balancing" begin
            # snapshot the partition before repartitioning the mesh
            old_global_first_quadrant = copy(unsafe_wrap(Array,
                                                         unsafe_load(mesh.p4est).global_first_quadrant,
                                                         mpi_nranks() + 1))
            partition!(mesh)
            rebalance_solver!(u_ode, mesh, equations, dg, cache,
                              old_global_first_quadrant)
        end
    end

    reinitialize_boundaries!(semi.boundary_conditions, cache)

    return has_changed
end
754
755
# AMR step on a `T8codeMesh`. The controller's indicators (positive: refine,
# negative: coarsen) are handed to `trixi_t8_adapt!`, and the solver data is adapted if
# the mesh changed on any rank. Afterwards the load is optionally rebalanced (MPI) and
# the boundary conditions are reinitialized. Note that with `only_refine`/`only_coarsen`
# the indicator array returned by the controller is modified in place.
# Returns whether the mesh changed on any rank.
function (amr_callback::AMRCallback)(u_ode::AbstractVector, mesh::T8codeMesh,
                                     equations, dg::DG, cache, semi,
                                     t, iter;
                                     only_refine = false, only_coarsen = false,
                                     passive_args = ())
    controller = amr_callback.controller
    adaptor = amr_callback.adaptor

    u = wrap_array(u_ode, mesh, equations, dg, cache)
    indicators = @trixi_timeit timer() "indicator" controller(u, mesh, equations, dg,
                                                              cache, t = t, iter = iter)

    # Suppress the disallowed kind of adaptation by zeroing the respective indicators
    only_coarsen && (indicators[indicators .> 0] .= 0)
    only_refine && (indicators[indicators .< 0] .= 0)

    @boundscheck begin
        @assert axes(indicators)==(Base.OneTo(ncells(mesh)),) ("Indicator array (axes = $(axes(indicators))) and mesh cells (axes = $(Base.OneTo(ncells(mesh)))) have different axes")
    end

    has_changed = @trixi_timeit timer() "adapt" begin
        difference = @trixi_timeit timer() "mesh" trixi_t8_adapt!(mesh, indicators)

        # Any non-zero entry means an element was refined or coarsened locally;
        # then agree on the answer across all ranks.
        changed = any(!iszero, difference)
        if mpi_isparallel()
            changed = MPI.Allreduce!(Ref(changed), |, mpi_comm())[]
        end

        if changed
            @trixi_timeit timer() "solver" adapt!(u_ode, adaptor, mesh, equations, dg,
                                                  cache, difference)
        end
        changed
    end

    if has_changed
        if mpi_isparallel() && amr_callback.dynamic_load_balancing
            @trixi_timeit timer() "dynamic load balancing" begin
                # snapshot the partition before repartitioning the mesh
                old_global_first_element_ids = get_global_first_element_ids(mesh)
                partition!(mesh)
                rebalance_solver!(u_ode, mesh, equations, dg, cache,
                                  old_global_first_element_ids)
            end
        end

        reinitialize_boundaries!(semi.boundary_conditions, cache)
    end

    # keep earlier unsaved changes flagged
    mesh.unsaved_changes |= has_changed

    return has_changed
end
815
816
# A sorted boundary types container is rebuilt via `initialize!`, because the boundaries
# may have changed after the mesh was adapted.
reinitialize_boundaries!(boundary_conditions::UnstructuredSortedBoundaryTypes, cache) =
    initialize!(boundary_conditions, cache)

# Any other boundary condition representation needs no update and is returned as is.
reinitialize_boundaries!(boundary_conditions, cache) = boundary_conditions
825
826
# After refining cells, shift original cell ids to match new locations.
#
# Every cell following a refined cell is shifted by `2^ndims(mesh)` ids, so an original
# cell id `c` moves by `2^ndims(mesh)` for each refined original cell with an id `< c`.
# Returns the shifted ids in the same order as `original_cell_ids` (an empty input yields
# an empty result).
#
# Note: Assumes sorted lists of original and refined cell ids!
# Note: `mesh` is only required to extract ndims
function original2refined(original_cell_ids, refined_original_cells, mesh)
    # Sanity check
    @assert issorted(original_cell_ids) "`original_cell_ids` not sorted"
    @assert issorted(refined_original_cells) "`refined_original_cells` not sorted"

    n_children = 2^ndims(mesh)

    # For sorted `refined_original_cells`, `searchsortedfirst(..., c) - 1` counts the
    # refined cells with an id smaller than `c`. This gives the same result as shifting
    # all subsequent ids once per refined cell, but in O(k log m) instead of O(n * m).
    return [cell_id + n_children * (searchsortedfirst(refined_original_cells, cell_id) - 1)
            for cell_id in original_cell_ids]
end
851
852
# Common supertype of the AMR controllers defined in this file
# (`ControllerThreeLevel`, `ControllerThreeLevelCombined`). Generic helpers such as
# `max_level` and `create_cache` dispatch on it.
abstract type AbstractController end
853
854
"""
855
ControllerThreeLevel(semi, indicator; base_level=1,
856
med_level=base_level, med_threshold=0.0,
857
max_level=base_level, max_threshold=1.0)
858
859
An AMR controller based on three levels (in descending order of precedence):
860
- set the target level to `max_level` if `indicator > max_threshold`
861
- set the target level to `med_level` if `indicator > med_threshold`;
862
if `med_level < 0`, set the target level to the current level
863
- set the target level to `base_level` otherwise
864
"""
865
struct ControllerThreeLevel{RealT <: Real, Indicator, Cache} <: AbstractController
866
base_level::Int
867
med_level::Int
868
max_level::Int
869
med_threshold::RealT
870
max_threshold::RealT
871
indicator::Indicator
872
cache::Cache
873
end
874
875
# Keyword constructor: both thresholds are promoted to a common real type, which
# becomes the `RealT` parameter; the controller cache is created from `semi`.
function ControllerThreeLevel(semi, indicator; base_level = 1,
                              med_level = base_level, med_threshold = 0.0,
                              max_level = base_level, max_threshold = 1.0)
    thresholds = promote(med_threshold, max_threshold)
    RealT = typeof(thresholds[2])
    controller_cache = create_cache(ControllerThreeLevel, semi)

    ControllerT = ControllerThreeLevel{RealT, typeof(indicator), typeof(controller_cache)}
    return ControllerT(base_level, med_level, max_level,
                       thresholds[1], thresholds[2],
                       indicator, controller_cache)
end
889
890
# Highest refinement level a controller may request (its `max_level` field).
function max_level(controller::AbstractController)
    return getproperty(controller, :max_level)
end
891
892
# Build a controller cache from a semidiscretization: unpack its components and
# forward them to the type-specific `create_cache` method.
function create_cache(controller_type::Type{<:AbstractController}, semi)
    semi_components = mesh_equations_solver_cache(semi)
    return create_cache(controller_type, semi_components...)
end
895
896
# Compact one-line representation listing the indicator and all level/threshold settings.
function Base.show(io::IO, controller::ControllerThreeLevel)
    @nospecialize controller # reduce precompilation time

    print(io, "ControllerThreeLevel(", controller.indicator,
          ", base_level=", controller.base_level,
          ", med_level=", controller.med_level,
          ", max_level=", controller.max_level,
          ", med_threshold=", controller.med_threshold,
          ", max_threshold=", controller.max_threshold,
          ")")
    return nothing
end
909
910
# Multi-line summary box; falls back to the compact form when `io` requests it.
function Base.show(io::IO, mime::MIME"text/plain", controller::ControllerThreeLevel)
    @nospecialize controller # reduce precompilation time

    get(io, :compact, false) && return show(io, controller)

    indicator = controller.indicator
    summary_header(io, "ControllerThreeLevel")
    summary_line(io, "indicator", nameof(typeof(indicator)))
    show(increment_indent(io), mime, indicator)
    for (label, value) in (("base_level", controller.base_level),
                           ("med_level", controller.med_level),
                           ("max_level", controller.max_level),
                           ("med_threshold", controller.med_threshold),
                           ("max_threshold", controller.max_threshold))
        summary_line(io, label, value)
    end
    return summary_footer(io)
end
927
928
# Export the controller's indicator values for IO. The indicator is evaluated first
# so the values it caches are up to date, then the indicator-specific method copies them.
function get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                controller::ControllerThreeLevel,
                                amr_callback::AMRCallback;
                                kwargs...)
    indicator = controller.indicator
    indicator(u, mesh, equations, solver, cache; kwargs...)
    return get_element_variables!(element_variables, indicator, amr_callback)
end
936
937
# Store the indicator's cached `alpha` values under the `:indicator_amr` key.
function get_element_variables!(element_variables, indicator::AbstractIndicator,
                                ::AMRCallback)
    alpha = indicator.cache.alpha
    setindex!(element_variables, alpha, :indicator_amr)
    return nothing
end
942
943
# Refinement level of every element, looked up through the element-to-cell id mapping
# of the tree.
current_element_levels(mesh::TreeMesh, solver, cache) =
    mesh.tree.levels[cache.elements.cell_ids[eachelement(solver, cache)]]
948
949
# p4est/p8est volume-iteration callback: writes the level of the visited quadrant
# into the `Int` array that `user_data` points to, at the quadrant's one-based
# element index.
function extract_levels_iter_volume(info, user_data)
    info_wrapper = PointerWrapper(info)

    # p4est tree ids are zero-based; the global trees array is indexed one-based
    tree = load_pointerwrapper_tree(info_wrapper.p4est, info_wrapper.treeid[] + 1)

    # Global (zero-based) quadrant id = tree offset + local quadrant id;
    # +1 converts it to the one-based Julia element id
    element_id = tree.quadrants_offset[] + info_wrapper.quadid[] + 1

    levels = PointerWrapper(Int, user_data)
    levels[element_id] = info_wrapper.quad.level[]

    return nothing
end
969
970
# 2D: C function pointer for `extract_levels_iter_volume` with the p4est volume-iterator signature
cfunction(::typeof(extract_levels_iter_volume), ::Val{2}) =
    @cfunction(extract_levels_iter_volume, Cvoid,
               (Ptr{p4est_iter_volume_info_t}, Ptr{Cvoid}))
975
# 3D: C function pointer for `extract_levels_iter_volume` with the p8est volume-iterator signature
cfunction(::typeof(extract_levels_iter_volume), ::Val{3}) =
    @cfunction(extract_levels_iter_volume, Cvoid,
               (Ptr{p8est_iter_volume_info_t}, Ptr{Cvoid}))
980
981
# Refinement level of every element of a `P4estMesh`, collected by iterating over
# all quadrants with the `extract_levels_iter_volume` callback.
function current_element_levels(mesh::P4estMesh, solver, cache)
    n_elements = nelements(solver, cache)
    levels = Vector{Int}(undef, n_elements)

    # `levels` is presumably forwarded to the callback as `user_data`, which fills it in place
    volume_callback = cfunction(extract_levels_iter_volume, Val(ndims(mesh)))
    iterate_p4est(mesh.p4est, levels; iter_volume_c = volume_callback)

    return levels
end
989
990
# Refinement level of every local element of a `T8codeMesh`, queried from its forest.
current_element_levels(mesh::T8codeMesh, solver, cache) =
    trixi_t8_get_local_element_levels(mesh.forest)
993
994
# TODO: Taal refactor, merge the two loops of ControllerThreeLevel and IndicatorLöhner etc.?
#       But that would remove the simplest possibility to write that stuff to a file...
#       We could of course implement some additional logic and workarounds, but is it worth the effort?
#
# Evaluate the controller: for each element, fill `controller.cache.controller_value`
# with +1 (refine), -1 (coarsen) or 0 (keep), and return that buffer.
function (controller::ControllerThreeLevel)(u::AbstractArray{<:Any},
                                            mesh, equations, dg::DG, cache;
                                            kwargs...)
    @unpack controller_value = controller.cache
    resize!(controller_value, nelements(dg, cache))

    alpha = controller.indicator(u, mesh, equations, dg, cache; kwargs...)
    current_levels = current_element_levels(mesh, dg, cache)

    @threaded for element in eachelement(dg, cache)
        level = current_levels[element]
        indicator_value = alpha[element]

        # Desired level by precedence max > med > base; a non-positive `med_level`
        # (e.g. -1) means "stay at the current level"
        if indicator_value > controller.max_threshold
            target = controller.max_level
        elseif indicator_value > controller.med_threshold
            target = controller.med_level > 0 ? controller.med_level : level
        else
            target = controller.base_level
        end

        # cmp yields 1 if the target is finer (refine!), -1 if coarser (coarsen!),
        # 0 if we're good
        controller_value[element] = cmp(target, level)
    end

    return controller_value
end
1035
1036
"""
1037
ControllerThreeLevelCombined(semi, indicator_primary, indicator_secondary;
1038
base_level=1,
1039
med_level=base_level, med_threshold=0.0,
1040
max_level=base_level, max_threshold=1.0,
1041
max_threshold_secondary=1.0)
1042
1043
An AMR controller based on three levels (in descending order of precedence):
1044
- set the target level to `max_level` if `indicator_primary > max_threshold`
1045
- set the target level to `med_level` if `indicator_primary > med_threshold`;
1046
if `med_level < 0`, set the target level to the current level
1047
- set the target level to `base_level` otherwise
1048
If `indicator_secondary >= max_threshold_secondary`,
1049
set the target level to `max_level`.
1050
"""
1051
struct ControllerThreeLevelCombined{RealT <: Real, IndicatorPrimary, IndicatorSecondary,
1052
Cache} <: AbstractController
1053
base_level::Int
1054
med_level::Int
1055
max_level::Int
1056
med_threshold::RealT
1057
max_threshold::RealT
1058
max_threshold_secondary::RealT
1059
indicator_primary::IndicatorPrimary
1060
indicator_secondary::IndicatorSecondary
1061
cache::Cache
1062
end
1063
1064
# Keyword constructor: all three thresholds are promoted to a common real type, which
# becomes the `RealT` parameter; the controller cache is created from `semi`.
function ControllerThreeLevelCombined(semi, indicator_primary, indicator_secondary;
                                      base_level = 1,
                                      med_level = base_level, med_threshold = 0.0,
                                      max_level = base_level, max_threshold = 1.0,
                                      max_threshold_secondary = 1.0)
    thresholds = promote(med_threshold, max_threshold, max_threshold_secondary)
    RealT = typeof(thresholds[2])
    controller_cache = create_cache(ControllerThreeLevelCombined, semi)

    ControllerT = ControllerThreeLevelCombined{RealT,
                                               typeof(indicator_primary),
                                               typeof(indicator_secondary),
                                               typeof(controller_cache)}
    return ControllerT(base_level, med_level, max_level,
                       thresholds[1], thresholds[2], thresholds[3],
                       indicator_primary, indicator_secondary,
                       controller_cache)
end
1085
1086
# Compact one-line representation listing both indicators and all level/threshold
# settings. `max_threshold` is included so the compact form covers the same fields as
# the text/plain summary (it was previously missing).
function Base.show(io::IO, controller::ControllerThreeLevelCombined)
    @nospecialize controller # reduce precompilation time

    print(io, "ControllerThreeLevelCombined(")
    print(io, controller.indicator_primary)
    print(io, ", ", controller.indicator_secondary)
    print(io, ", base_level=", controller.base_level)
    print(io, ", med_level=", controller.med_level)
    print(io, ", max_level=", controller.max_level)
    print(io, ", med_threshold=", controller.med_threshold)
    print(io, ", max_threshold=", controller.max_threshold)
    print(io, ", max_threshold_secondary=", controller.max_threshold_secondary)
    print(io, ")")
    return nothing
end
1100
1101
# Multi-line summary box showing both indicators and all settings; falls back to
# the compact form when `io` requests it.
function Base.show(io::IO, mime::MIME"text/plain",
                   controller::ControllerThreeLevelCombined)
    @nospecialize controller # reduce precompilation time

    get(io, :compact, false) && return show(io, controller)

    summary_header(io, "ControllerThreeLevelCombined")
    for (label, indicator) in (("primary indicator", controller.indicator_primary),
                               ("secondary indicator", controller.indicator_secondary))
        summary_line(io, label, nameof(typeof(indicator)))
        show(increment_indent(io), mime, indicator)
    end
    for (label, value) in (("base_level", controller.base_level),
                           ("med_level", controller.med_level),
                           ("max_level", controller.max_level),
                           ("med_threshold", controller.med_threshold),
                           ("max_threshold", controller.max_threshold),
                           ("max_threshold_secondary",
                            controller.max_threshold_secondary))
        summary_line(io, label, value)
    end
    return summary_footer(io)
end
1124
1125
# Export the primary indicator's values for IO. The primary indicator is evaluated
# first so its cached values are up to date; the secondary indicator is not exported.
function get_element_variables!(element_variables, u, mesh, equations, solver, cache,
                                controller::ControllerThreeLevelCombined,
                                amr_callback::AMRCallback;
                                kwargs...)
    primary = controller.indicator_primary
    primary(u, mesh, equations, solver, cache; kwargs...)
    return get_element_variables!(element_variables, primary, amr_callback)
end
1134
1135
# Evaluate the combined controller: for each element, fill
# `controller.cache.controller_value` with +1 (refine), -1 (coarsen) or 0 (keep), and
# return that buffer. The primary indicator selects a level as in `ControllerThreeLevel`;
# a large enough secondary indicator value then forces `max_level`.
function (controller::ControllerThreeLevelCombined)(u::AbstractArray{<:Any},
                                                    mesh, equations, dg::DG, cache;
                                                    kwargs...)
    @unpack controller_value = controller.cache
    resize!(controller_value, nelements(dg, cache))

    alpha = controller.indicator_primary(u, mesh, equations, dg, cache; kwargs...)
    alpha_secondary = controller.indicator_secondary(u, mesh, equations, dg, cache)

    current_levels = current_element_levels(mesh, dg, cache)

    @threaded for element in eachelement(dg, cache)
        level = current_levels[element]
        primary_value = alpha[element]

        # Desired level from the primary indicator by precedence max > med > base;
        # a non-positive `med_level` (e.g. -1) means "stay at the current level"
        if primary_value > controller.max_threshold
            target = controller.max_level
        elseif primary_value > controller.med_threshold
            target = controller.med_level > 0 ? controller.med_level : level
        else
            target = controller.base_level
        end

        # The secondary indicator alone can demand maximal refinement
        if alpha_secondary[element] >= controller.max_threshold_secondary
            target = controller.max_level
        end

        # cmp yields 1 if the target is finer (refine!), -1 if coarser (coarsen!),
        # 0 if we're good
        controller_value[element] = cmp(target, level)
    end

    return controller_value
end
1179
1180
include("amr_dg.jl")
1181
include("amr_dg1d.jl")
1182
include("amr_dg2d.jl")
1183
include("amr_dg3d.jl")
1184
end # @muladd
1185
1186