Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
trixi-framework
GitHub Repository: trixi-framework/Trixi.jl
Path: blob/main/src/callbacks_step/save_solution.jl
5586 views
1
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
2
# Since these FMAs can increase the performance of many numerical algorithms,
3
# we need to opt-in explicitly.
4
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
5
@muladd begin
6
#! format: noindent
7
"""
    SaveSolutionCallback(; interval::Integer=0,
                           dt=nothing,
                           save_initial_solution=true,
                           save_final_solution=true,
                           output_directory="out",
                           solution_variables=cons2prim,
                           extra_node_variables=())

Save the current numerical solution in regular intervals. Either pass `interval` to save
every `interval` time steps or pass `dt` to save in intervals of `dt` in terms
of integration time by adding additional (shortened) time steps where necessary (note that
this may change the solution). Passing both a positive `interval` and a `dt` throws an
`ArgumentError`. If neither is given (i.e., `interval=0` and `dt=nothing`), no solution
output is written during or at the end of the time integration; only the initial solution
is saved (if `save_initial_solution` is `true`).

The remaining keyword arguments control
- `save_initial_solution`: whether the solution is saved when the callback is initialized,
- `save_final_solution`: whether the solution is also saved when the time integration
  finishes, even if that time does not coincide with a regular output,
- `output_directory`: the directory the mesh and solution files are written to; it is
  created (on the MPI root rank) if it does not exist.

`solution_variables` can be any callable that converts the conservative variables
at a single point to a set of solution variables. The first parameter passed
to `solution_variables` will be the set of conservative variables
and the second parameter is the equation struct.

Additional nodal variables such as vorticity or the Mach number can be saved by passing a
tuple of symbols to `extra_node_variables`, e.g., `extra_node_variables = (:vorticity, :mach)`.
In that case the function `get_node_variable` must be defined for each symbol in the tuple.
The expected signature of the function for (purely) hyperbolic equations is:
```julia
function get_node_variable(::Val{symbol}, u, mesh, equations, dg, cache)
    # Implementation goes here
end
```
and must return an array of size
`(ntuple(_ -> n_nodes, ndims(mesh))..., n_elements)`.

For purely parabolic equations, `cache_parabolic` must be added:
```julia
function get_node_variable(::Val{symbol}, u, mesh, equations, dg, cache,
                           cache_parabolic)
    # Implementation goes here
end
```

For hyperbolic-parabolic equations, `equations_parabolic` and `cache_parabolic` must be
added:
```julia
function get_node_variable(::Val{symbol}, u, mesh, equations, dg, cache,
                           equations_parabolic, cache_parabolic)
    # Implementation goes here
end
```
"""
struct SaveSolutionCallback{IntervalType, SolutionVariablesType}
    # Number of accepted time steps between outputs (when wrapped in a `DiscreteCallback`)
    # or the output time interval `dt` (when wrapped in a `PeriodicCallback`)
    interval_or_dt::IntervalType
    save_initial_solution::Bool
    save_final_solution::Bool
    output_directory::String
    # Callable mapping the conservative variables at a node to the saved variables
    solution_variables::SolutionVariablesType
    # Extra node variables keyed by name; the keyword constructor initializes the values to
    # `nothing`, and they are updated via `get_node_variables!` before each output
    node_variables::Dict{Symbol, Any}
end
62
63
# Compact one-line representation of a step-based SaveSolutionCallback, which is stored
# directly as the `affect!` of a `DiscreteCallback`.
function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:SaveSolutionCallback})
    @nospecialize cb # reduce precompilation time

    interval = cb.affect!.interval_or_dt
    print(io, "SaveSolutionCallback(interval=", interval, ")")
    return nothing
end
71
72
# Compact one-line representation of a time-based SaveSolutionCallback; here the callback
# is wrapped in a `PeriodicCallbackAffect`, which is the `affect!` of the `DiscreteCallback`.
function Base.show(io::IO,
                   cb::DiscreteCallback{<:Any,
                                        <:PeriodicCallbackAffect{<:SaveSolutionCallback}})
    @nospecialize cb # reduce precompilation time

    periodic_affect = cb.affect!
    output_dt = periodic_affect.affect!.interval_or_dt
    print(io, "SaveSolutionCallback(dt=", output_dt, ")")
    return nothing
end
81
82
# Multi-line summary box for a step-based SaveSolutionCallback. Falls back to the compact
# one-line form when `io` requests `:compact` output.
function Base.show(io::IO, ::MIME"text/plain",
                   cb::DiscreteCallback{<:Any, <:SaveSolutionCallback})
    @nospecialize cb # reduce precompilation time

    get(io, :compact, false) && return show(io, cb)

    callback = cb.affect!
    yes_no(flag) = flag ? "yes" : "no"
    setup = [
        "interval" => callback.interval_or_dt,
        "solution variables" => callback.solution_variables,
        "save initial solution" => yes_no(callback.save_initial_solution),
        "save final solution" => yes_no(callback.save_final_solution),
        "output directory" => abspath(normpath(callback.output_directory))
    ]
    return summary_box(io, "SaveSolutionCallback", setup)
end
103
104
# Multi-line summary box for a time-based (`dt`) SaveSolutionCallback, which sits inside a
# `PeriodicCallbackAffect`. Falls back to the compact form for `:compact` output.
function Base.show(io::IO, ::MIME"text/plain",
                   cb::DiscreteCallback{<:Any,
                                        <:PeriodicCallbackAffect{<:SaveSolutionCallback}})
    @nospecialize cb # reduce precompilation time

    get(io, :compact, false) && return show(io, cb)

    callback = cb.affect!.affect!
    yes_no(flag) = flag ? "yes" : "no"
    setup = [
        "dt" => callback.interval_or_dt,
        "solution variables" => callback.solution_variables,
        "save initial solution" => yes_no(callback.save_initial_solution),
        "save final solution" => yes_no(callback.save_final_solution),
        "output directory" => abspath(normpath(callback.output_directory))
    ]
    return summary_box(io, "SaveSolutionCallback", setup)
end
126
127
# Keyword constructor: builds the SaveSolutionCallback and wraps it either in a
# `DiscreteCallback` (step-based output via `interval`) or in a `PeriodicCallback`
# (time-based output via `dt`). Throws an `ArgumentError` if both are requested.
function SaveSolutionCallback(; interval::Integer = 0,
                              dt = nothing,
                              save_initial_solution = true,
                              save_final_solution = true,
                              output_directory = "out",
                              solution_variables = cons2prim,
                              extra_node_variables = ())
    time_based = !isnothing(dt)
    if time_based && interval > 0
        throw(ArgumentError("You can either set the number of steps between output (using `interval`) or the time between outputs (using `dt`) but not both simultaneously"))
    end

    # Every requested extra node variable starts out without data
    node_variables = Dict{Symbol, Any}(name => nothing for name in extra_node_variables)
    solution_callback = SaveSolutionCallback(time_based ? dt : interval,
                                             save_initial_solution, save_final_solution,
                                             output_directory, solution_variables,
                                             node_variables)

    if time_based
        # Add a `tstop` every `dt` and let the periodic callback handle the final output
        return PeriodicCallback(solution_callback, dt,
                                save_positions = (false, false),
                                initialize = initialize_save_cb!,
                                final_affect = save_final_solution)
    end

    # Save every `interval` accepted time steps; the callback object serves as both the
    # condition and the affect!
    return DiscreteCallback(solution_callback, solution_callback,
                            save_positions = (false, false),
                            initialize = initialize_save_cb!)
end
166
167
# Callback initializer: peel off one wrapper layer and dispatch again. For a
# `DiscreteCallback` the `affect!` field is the SaveSolutionCallback itself; for a
# `PeriodicCallback` it is a `PeriodicCallbackAffect` whose own `affect!` holds it.
function initialize_save_cb!(cb, u, t, integrator)
    wrapped = cb.affect!
    return initialize_save_cb!(wrapped, u, t, integrator)
end
173
174
# Initialize the save callback: create the output directory (root rank only), write the
# mesh, and optionally save the initial solution right away.
function initialize_save_cb!(solution_callback::SaveSolutionCallback, u, t, integrator)
    if mpi_isroot()
        mkpath(solution_callback.output_directory)
    end

    # `integrator.p` holds the semidiscretization
    @trixi_timeit timer() "I/O" save_mesh(integrator.p,
                                          solution_callback.output_directory)

    solution_callback.save_initial_solution && solution_callback(integrator)

    return nothing
end
186
187
# Save mesh for a general semidiscretization (default).
# The mesh file is written only if `mesh.unsaved_changes` is set; in every case the
# current mesh file name is returned.
function save_mesh(semi::AbstractSemidiscretization, output_directory, timestep = 0)
    mesh, _, _, _ = mesh_equations_solver_cache(semi)

    mesh.unsaved_changes || return mesh.current_filename

    # The time step number is appended to the file name only for meshes that changed
    # during the simulation due to AMR; the first time step gets no suffix.
    mesh.current_filename = if timestep == 0
        save_mesh_file(mesh, output_directory)
    else
        save_mesh_file(mesh, output_directory, timestep)
    end
    mesh.unsaved_changes = false

    return mesh.current_filename
end
204
205
# Save mesh for a DGMultiMesh, which requires passing the `basis` as an argument to
# `save_mesh_file`. The mesh file is written only if `mesh.unsaved_changes` is set; in
# every case the current mesh file name is returned.
function save_mesh(semi::Union{SemidiscretizationHyperbolic{<:DGMultiMesh},
                               SemidiscretizationHyperbolicParabolic{<:DGMultiMesh}},
                   output_directory, timestep = 0)
    mesh, _, solver, _ = mesh_equations_solver_cache(semi)

    if mesh.unsaved_changes
        # We only append the time step number to the mesh file name if it has
        # changed during the simulation due to AMR. We do not append it for
        # the first time step.
        # Write the unpacked `mesh`, i.e., the same object whose `current_filename` and
        # `unsaved_changes` are updated here (consistent with the generic method).
        if timestep == 0
            mesh.current_filename = save_mesh_file(mesh, solver.basis, output_directory)
        else
            mesh.current_filename = save_mesh_file(mesh, solver.basis, output_directory,
                                                   timestep)
        end
        mesh.unsaved_changes = false
    end
    return mesh.current_filename
end
227
228
# Condition of the `DiscreteCallback`: decide whether the solution is saved after the
# current step.
function (solution_callback::SaveSolutionCallback)(u, t, integrator)
    interval = solution_callback.interval_or_dt

    # A non-positive interval disables step-based output (including the final one)
    interval > 0 || return false

    # With error-based step size control some steps may be rejected, so that
    # `integrator.iter >= integrator.stats.naccept`. Callbacks are not activated after a
    # rejected step, hence only the number of accepted steps is relevant here.
    integrator.stats.naccept % interval == 0 && return true

    return solution_callback.save_final_solution && isfinished(integrator)
end
240
241
# Affect of the callback: runs whenever the condition functor returns `true` (or
# periodically when wrapped in a `PeriodicCallback`) and writes the mesh (if it has
# unsaved changes) and the current solution.
function (solution_callback::SaveSolutionCallback)(integrator)
    u_ode = integrator.u
    semi = integrator.p
    n_accepted = integrator.stats.naccept
    output_directory = solution_callback.output_directory

    @trixi_timeit timer() "I/O" begin
        # High-level functions dispatch on the semidiscretization type; the accepted-step
        # count is passed so that meshes changed by AMR get a step-specific file name.
        @trixi_timeit timer() "save mesh" save_mesh(semi, output_directory, n_accepted)
        save_solution_file(semi, u_ode, solution_callback, integrator)
    end

    # Saving does not change the state, so possible FSAL stages need not be re-evaluated
    u_modified!(integrator, false)
    return nothing
end
259
260
# Gather element variables (from the semidiscretization and from all callbacks in the
# integrator's `CallbackSet`) and node variables, then forward everything to the
# lower-level `save_solution_file` method.
@inline function save_solution_file(semi::AbstractSemidiscretization, u_ode,
                                    solution_callback,
                                    integrator; system = "")
    t = integrator.t
    dt = integrator.dt
    iter = integrator.stats.naccept

    element_variables = Dict{Symbol, Any}()
    @trixi_timeit timer() "get element variables" begin
        get_element_variables!(element_variables, u_ode, semi)

        callbacks = integrator.opts.callback
        if callbacks isa CallbackSet
            add_callback_variables! = cb -> get_element_variables!(element_variables,
                                                                   u_ode, semi, cb;
                                                                   t = t, iter = iter)
            # Continuous callbacks first, then discrete ones
            foreach(add_callback_variables!, callbacks.continuous_callbacks)
            foreach(add_callback_variables!, callbacks.discrete_callbacks)
        end
    end

    node_variables = solution_callback.node_variables
    @trixi_timeit timer() "get node variables" get_node_variables!(node_variables,
                                                                   u_ode, semi)

    @trixi_timeit timer() "save solution" save_solution_file(u_ode, t, dt, iter, semi,
                                                             solution_callback,
                                                             element_variables,
                                                             node_variables,
                                                             system = system)

    return nothing
end
293
294
# Unpack the semidiscretization, wrap the flat ODE state `u_ode` into the solver's native
# array layout and forward to the mesh/solver-specific `save_solution_file` method.
@inline function save_solution_file(u_ode, t, dt, iter,
                                    semi::AbstractSemidiscretization, solution_callback,
                                    element_variables = Dict{Symbol, Any}(),
                                    node_variables = Dict{Symbol, Any}();
                                    system = "")
    # TODO GPU: output is currently written on the CPU, so `u_ode` is copied into a host
    # `Array` whenever `trixi_backend` reports a (non-`nothing`) backend.
    u_host = isnothing(trixi_backend(u_ode)) ? u_ode : Array(u_ode)

    mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
    u = wrap_array_native(u_host, mesh, equations, solver, cache)
    save_solution_file(u, t, dt, iter, mesh, equations, solver, cache,
                       solution_callback, element_variables, node_variables;
                       system = system)

    return nothing
end
313
314
# TODO: Taal refactor, move save_mesh_file?
315
# function save_mesh_file(mesh::TreeMesh, output_directory, timestep=-1) in io/io.jl
316
317
include("save_solution_dg.jl")
318
end # @muladd
319
320