# trixi-framework/Trixi.jl: src/time_integration/methods_SSP.jl
1
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
2
# Since these FMAs can increase the performance of many numerical algorithms,
3
# we need to opt-in explicitly.
4
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
5
@muladd begin
6
#! format: noindent
7
8
# Abstract base type for time integration schemes of explicit strong stability-preserving (SSP)
# Runge-Kutta (RK) methods. SSP methods are high-order time discretizations built from convex
# combinations of forward Euler steps. Hence they preserve the strong stability properties of
# the forward Euler method (e.g., the TVD property), provided the time step satisfies a
# correspondingly scaled step-size (CFL-type) restriction.
abstract type SimpleAlgorithmSSP <: AbstractTimeIntegrationAlgorithm end
11
12
"""
13
SimpleSSPRK33(; stage_callbacks=())
14
15
The third-order SSP Runge-Kutta method of Shu and Osher.
16
17
## References
18
19
- Shu, Osher (1988)
20
"Efficient Implementation of Essentially Non-oscillatory Shock-Capturing Schemes" (Eq. 2.18)
21
[DOI: 10.1016/0021-9991(88)90177-5](https://doi.org/10.1016/0021-9991(88)90177-5)
22
"""
23
struct SimpleSSPRK33{StageCallbacks} <: SimpleAlgorithmSSP
24
numerator_a::SVector{3, Float64}
25
numerator_b::SVector{3, Float64}
26
denominator::SVector{3, Float64}
27
c::SVector{3, Float64}
28
stage_callbacks::StageCallbacks
29
30
function SimpleSSPRK33(; stage_callbacks = ())
31
# Mathematically speaking, it is not necessary for the algorithm to split the factors
32
# into numerator and denominator. Otherwise, however, rounding errors of the order of
33
# the machine accuracy will occur, which will add up over time and thus endanger the
34
# conservation of the simulation.
35
# See also https://github.com/trixi-framework/Trixi.jl/pull/1640.
36
numerator_a = SVector(0.0, 3.0, 1.0) # a = numerator_a / denominator
37
numerator_b = SVector(1.0, 1.0, 2.0) # b = numerator_b / denominator
38
denominator = SVector(1.0, 4.0, 3.0)
39
c = SVector(0.0, 1.0, 1 / 2)
40
41
# Butcher tableau
42
# c | A
43
# 0 |
44
# 1 | 1
45
# 1/2 | 1/4 1/4
46
# --------------------
47
# b | 1/6 1/6 2/3
48
49
return new{typeof(stage_callbacks)}(numerator_a, numerator_b, denominator, c,
50
stage_callbacks)
51
end
52
end
53
54
# Minimal stand-in for the options of OrdinaryDiffEq.jl integrators, see
# https://github.com/SciML/OrdinaryDiffEq.jl/blob/0c2048a502101647ac35faabd80da8a5645beac7/src/integrators/type.jl#L1
# Only the fields accessed by Trixi and by the SSP integrator in this file are provided.
mutable struct SimpleIntegratorSSPOptions{Callback, TStops}
    callback::Callback   # callbacks; used in Trixi
    const adaptive::Bool # whether the algorithm is adaptive; always `false` for the SSP integrator
    dtmax::Float64       # not used in this file
    const maxiters::Int  # maximal number of time steps
    # Time stops (cf. https://diffeq.sciml.ai/v6.8/basics/common_solver_opts/#Output-Control-1),
    # stored in a heap. They are extended by `add_tstop!` and limit the time step in
    # `modify_dt_for_tstops!`.
    tstops::TStops
end

# Construct the options for a time span `tspan`. Remaining keyword arguments are accepted
# for compatibility with the DiffEq interface but ignored.
function SimpleIntegratorSSPOptions(callback, tspan; maxiters = typemax(Int), kwargs...)
    t_final = last(tspan)
    tstops_internal = BinaryHeap{eltype(tspan)}(FasterForward())

    # The final time makes sure that the time integration stops at the end of `tspan`
    push!(tstops_internal, t_final)

    # DiffEqCallbacks.jl only calls `add_tstop!(integrator, t)` if `tstops` contains a time
    # that is larger than `t`, see
    # https://github.com/SciML/DiffEqCallbacks.jl/blob/025dfe99029bd0f30a2e027582744528eb92cd24/src/iterative_and_periodic.jl#L92
    # NOTE(review): `2 * t_final` lies beyond the final time only if `t_final > 0`.
    push!(tstops_internal, 2 * t_final)

    adaptive = false
    dtmax = Inf
    return SimpleIntegratorSSPOptions{typeof(callback), typeof(tstops_internal)}(callback,
                                                                                  adaptive,
                                                                                  dtmax,
                                                                                  maxiters,
                                                                                  tstops_internal)
end
76
77
# Minimal stand-in for OrdinaryDiffEq.jl integrators, see
# https://github.com/SciML/OrdinaryDiffEq.jl/blob/0c2048a502101647ac35faabd80da8a5645beac7/src/integrators/type.jl#L77
# It implements those parts of the integrator interface described at
# https://diffeq.sciml.ai/v6.8/basics/integrator/#Handing-Integrators-1
# that are used in Trixi.
mutable struct SimpleIntegratorSSP{RealT <: Real, uType,
                                   Params, Sol, F, Alg,
                                   Options} <: AbstractTimeIntegrator
    u::uType           # current solution
    du::uType          # storage for the right-hand side evaluation
    u_tmp::uType       # solution at the beginning of the current step (also handed out by `get_tmp_cache`)
    t::RealT           # current time
    tdir::RealT        # DIRection of time integration: sign of `last(tspan) - first(tspan)`
    dt::RealT          # current time step
    dtcache::RealT     # manually set time step; `dt` may be reduced below it to hit tstops
    iter::Int          # number of completed time steps (iterations)
    p::Params          # will be the semidiscretization from Trixi
    sol::Sol           # faked; `init` stores the named tuple `(prob = ode,)` here
    const f::F         # `rhs!` of the semidiscretization, called as `f(du, u, p, t)`
    const alg::Alg     # the SSP algorithm, e.g., `SimpleSSPRK33`
    opts::Options      # integrator options, e.g., a `SimpleIntegratorSSPOptions`
    finalstep::Bool    # added for convenience; set by `terminate!` to stop the main loop
    const dtchangeable::Bool   # consulted by `modify_dt_for_tstops!`
    const force_stepfail::Bool # consulted by `modify_dt_for_tstops!`
end
101
102
"""
103
add_tstop!(integrator::SimpleIntegratorSSP, t)
104
Add a time stop during the time integration process.
105
This function is called after the periodic SaveSolutionCallback to specify the next stop to save the solution.
106
"""
107
function add_tstop!(integrator::SimpleIntegratorSSP, t)
108
integrator.tdir * (t - integrator.t) < zero(integrator.t) &&
109
error("Tried to add a tstop that is behind the current time. This is strictly forbidden")
110
# We need to remove the first entry of tstops when a new entry is added.
111
# Otherwise, the simulation gets stuck at the previous tstop and dt is adjusted to zero.
112
if length(integrator.opts.tstops) > 1
113
pop!(integrator.opts.tstops)
114
end
115
return push!(integrator.opts.tstops, integrator.tdir * t)
116
end
117
118
# Return `true` if at least one time stop is stored in `integrator.opts.tstops`.
function has_tstop(integrator::SimpleIntegratorSSP)
    return !isempty(integrator.opts.tstops)
end
119
# Return the next time stop, i.e., the first (top) entry of `integrator.opts.tstops`.
function first_tstop(integrator::SimpleIntegratorSSP)
    return first(integrator.opts.tstops)
end
120
121
# Set up a `SimpleIntegratorSSP` for `ode`, using the SSP algorithm `alg` with time step `dt`,
# and initialize the standard callbacks as well as the stage callbacks of `alg`.
# Remaining keyword arguments are forwarded to `SimpleIntegratorSSPOptions`.
function init(ode::ODEProblem, alg::SimpleAlgorithmSSP;
              dt, callback::Union{CallbackSet, Nothing} = nothing, kwargs...)
    tspan = ode.tspan
    u = copy(ode.u0)
    t = first(tspan)
    tdir = sign(tspan[end] - tspan[1])
    opts = SimpleIntegratorSSPOptions(callback, tspan; kwargs...)

    # `dt` is used both as the current and as the cached time step, the step counter
    # starts at zero, and the flags are (finalstep, dtchangeable, force_stepfail)
    integrator = SimpleIntegratorSSP(u, similar(u), similar(u), t, tdir, dt, dt, 0, ode.p,
                                     (prob = ode,), ode.f, alg, opts,
                                     false, true, false)

    # Standard callbacks
    initialize_callbacks!(callback, integrator)

    # `SimpleAlgorithmSSP`s may additionally carry stage callbacks
    foreach(stage_callback -> init_callback(stage_callback, integrator.p),
            alg.stage_callbacks)

    return integrator
end
145
146
# Perform time steps until `integrator.finalstep` is set, finalize all callbacks, and
# return the initial and final times and states as a `TimeIntegratorSolution`.
function solve!(integrator::SimpleIntegratorSSP)
    prob = integrator.sol.prob

    integrator.finalstep = false
    @trixi_timeit timer() "main loop" while !integrator.finalstep
        step!(integrator)
    end

    # Clear all remaining tstops. This cannot be done in
    # `terminate!(integrator::SimpleIntegratorSSP)` since
    # `DiffEqCallbacks.PeriodicCallbackAffect` would then raise an error.
    extract_all!(integrator.opts.tstops)

    foreach(stage_callback -> finalize_callback(stage_callback, integrator.p),
            integrator.alg.stage_callbacks)
    finalize_callbacks(integrator)

    t_start = first(prob.tspan)
    return TimeIntegratorSolution((t_start, integrator.t), (prob.u0, integrator.u), prob)
end
168
169
# Advance the solution by one time step of the SSP method `integrator.alg`, written in
# Shu-Osher form: every stage is a forward Euler step followed by a convex combination with
# the solution from the beginning of the time step. Afterwards, the step callbacks are run.
function step!(integrator::SimpleIntegratorSSP)
    prob = integrator.sol.prob
    alg = integrator.alg
    t_end = last(prob.tspan)
    callbacks = integrator.opts.callback

    @assert !integrator.finalstep
    isnan(integrator.dt) && error("time step size `dt` is NaN")

    # Shorten `dt` if necessary to hit the next tstop and the final time
    modify_dt_for_tstops!(integrator)
    limit_dt!(integrator, t_end)

    # Save the solution from the beginning of the step for the convex combinations
    @. integrator.u_tmp = integrator.u
    for (stage, c_stage) in enumerate(alg.c)
        t_stage = integrator.t + integrator.dt * c_stage

        # Right-hand side at the stage time
        integrator.f(integrator.du, integrator.u, integrator.p, t_stage)

        # Forward Euler step
        @. integrator.u = integrator.u + integrator.dt * integrator.du

        # Stage callbacks act on the forward Euler update
        for stage_callback in alg.stage_callbacks
            stage_callback(integrator.u, integrator, stage)
        end

        # Convex combination with the solution from the beginning of the step
        @. integrator.u = (alg.numerator_a[stage] * integrator.u_tmp +
                           alg.numerator_b[stage] * integrator.u) /
                          alg.denominator[stage]
    end

    integrator.iter += 1
    integrator.t += integrator.dt

    @trixi_timeit timer() "Step-Callbacks" handle_callbacks!(callbacks, integrator)

    check_max_iter!(integrator)

    return nothing
end
211
212
# Return a cache in which the right-hand side can be stored. The integrator's `u_tmp`
# array is reused for this; it is overwritten at the beginning of every `step!`.
function get_tmp_cache(integrator::SimpleIntegratorSSP)
    return (integrator.u_tmp,)
end
214
215
# Some DiffEq algorithms (e.g., FSAL ones) need to be informed when a callback has
# modified `u`. The SSP integrator does not use this information, so the flag is ignored.
function u_modified!(integrator::SimpleIntegratorSSP, ::Bool)
    return false
end
217
218
# Request the end of the time integration: the main loop in `solve!` exits once the
# current step has returned.
function terminate!(integrator::SimpleIntegratorSSP)
    integrator.finalstep = true
    return nothing
end
224
225
"""
226
modify_dt_for_tstops!(integrator::SimpleIntegratorSSP)
227
Modify the time-step size to match the time stops specified in integrator.opts.tstops.
228
To avoid adding OrdinaryDiffEq to Trixi's dependencies, this routine is a copy of
229
https://github.com/SciML/OrdinaryDiffEq.jl/blob/d76335281c540ee5a6d1bd8bb634713e004f62ee/src/integrators/integrator_utils.jl#L38-L54
230
"""
231
function modify_dt_for_tstops!(integrator::SimpleIntegratorSSP)
232
if has_tstop(integrator)
233
tdir_t = integrator.tdir * integrator.t
234
tdir_tstop = first_tstop(integrator)
235
if integrator.opts.adaptive
236
integrator.dt = integrator.tdir *
237
min(abs(integrator.dt), abs(tdir_tstop - tdir_t)) # step! to the end
238
elseif iszero(integrator.dtcache) && integrator.dtchangeable
239
integrator.dt = integrator.tdir * abs(tdir_tstop - tdir_t)
240
elseif integrator.dtchangeable && !integrator.force_stepfail
241
# always try to step! with dtcache, but lower if a tstop
242
# however, if force_stepfail then don't set to dtcache, and no tstop worry
243
integrator.dt = integrator.tdir *
244
min(abs(integrator.dtcache), abs(tdir_tstop - tdir_t)) # step! to the end
245
end
246
end
247
248
return nothing
249
end
250
251
# Resize all solution-sized arrays of the integrator to `new_size` (used for AMR).
function Base.resize!(integrator::SimpleIntegratorSSP, new_size)
    for storage in (integrator.u, integrator.du, integrator.u_tmp)
        resize!(storage, new_size)
    end

    return nothing
end
259
end # @muladd
260
261