Path: blob/main/src/callbacks_step/save_solution.jl
5586 views
# By default, Julia/LLVM does not use fused multiply-add operations (FMAs).
# Since these FMAs can increase the performance of many numerical algorithms,
# we need to opt-in explicitly.
# See https://ranocha.de/blog/Optimizing_EC_Trixi for further details.
@muladd begin
#! format: noindent

"""
    SaveSolutionCallback(; interval::Integer=0,
                         dt=nothing,
                         save_initial_solution=true,
                         save_final_solution=true,
                         output_directory="out",
                         solution_variables=cons2prim,
                         extra_node_variables=())

Save the current numerical solution in regular intervals. Either pass `interval` to save
every `interval` time steps or pass `dt` to save in intervals of `dt` in terms
of integration time by adding additional (shortened) time steps where necessary (note that this may change the solution).
Passing both a positive `interval` and a `dt` at the same time throws an `ArgumentError`.
`solution_variables` can be any callable that converts the conservative variables
at a single point to a set of solution variables. The first parameter passed
to `solution_variables` will be the set of conservative variables
and the second parameter is the equation struct.

Additional nodal variables such as vorticity or the Mach number can be saved by passing a tuple of symbols
to `extra_node_variables`, e.g., `extra_node_variables = (:vorticity, :mach)`.
In that case the function `get_node_variable` must be defined for each symbol in the tuple.
The expected signature of the function for (purely) hyperbolic equations is:
```julia
function get_node_variable(::Val{symbol}, u, mesh, equations, dg, cache)
    # Implementation goes here
end
```
and must return an array of dimension
`(ntuple(_ -> n_nodes, ndims(mesh))..., n_elements)`.

For purely parabolic equations, `cache_parabolic` must be added:
```julia
function get_node_variable(::Val{symbol}, u, mesh, equations, dg, cache,
                           cache_parabolic)
    # Implementation goes here
end
```

For hyperbolic-parabolic equations, `equations_parabolic` and `cache_parabolic` must be
added:
```julia
function get_node_variable(::Val{symbol}, u, mesh, equations, dg, cache,
                           equations_parabolic, cache_parabolic)
    # Implementation goes here
end
```
"""
struct SaveSolutionCallback{IntervalType, SolutionVariablesType}
    # Either the number of accepted time steps between two outputs (`interval`)
    # or the simulation time between two outputs (`dt`), see the constructor
    interval_or_dt::IntervalType
    save_initial_solution::Bool
    save_final_solution::Bool
    output_directory::String
    solution_variables::SolutionVariablesType
    # One entry per symbol passed as `extra_node_variables`; the constructor
    # initializes every value with `nothing`
    node_variables::Dict{Symbol, Any}
end

# Assemble the key/value pairs displayed by the `text/plain` `show` methods below.
# `trigger_label` names the first entry ("interval" or "dt"), depending on how the
# callback is triggered.
function save_solution_summary_setup(save_solution_callback::SaveSolutionCallback,
                                     trigger_label)
    return [
        trigger_label => save_solution_callback.interval_or_dt,
        "solution variables" => save_solution_callback.solution_variables,
        "save initial solution" => save_solution_callback.save_initial_solution ?
                                   "yes" : "no",
        "save final solution" => save_solution_callback.save_final_solution ?
                                 "yes" : "no",
        "output directory" => abspath(normpath(save_solution_callback.output_directory))
    ]
end

# Compact display for the step-based variant (wrapped in a `DiscreteCallback`)
function Base.show(io::IO, cb::DiscreteCallback{<:Any, <:SaveSolutionCallback})
    @nospecialize cb # reduce precompilation time

    save_solution_callback = cb.affect!
    print(io, "SaveSolutionCallback(interval=", save_solution_callback.interval_or_dt,
          ")")
    return nothing
end

# Compact display for the time-based variant (wrapped in a `PeriodicCallback`)
function Base.show(io::IO,
                   cb::DiscreteCallback{<:Any,
                                        <:PeriodicCallbackAffect{<:SaveSolutionCallback}})
    @nospecialize cb # reduce precompilation time

    save_solution_callback = cb.affect!.affect!
    print(io, "SaveSolutionCallback(dt=", save_solution_callback.interval_or_dt, ")")
    return nothing
end

# Multi-line display for the step-based variant
function Base.show(io::IO, ::MIME"text/plain",
                   cb::DiscreteCallback{<:Any, <:SaveSolutionCallback})
    @nospecialize cb # reduce precompilation time

    if get(io, :compact, false)
        show(io, cb)
    else
        setup = save_solution_summary_setup(cb.affect!, "interval")
        summary_box(io, "SaveSolutionCallback", setup)
    end
    return nothing
end

# Multi-line display for the time-based variant
function Base.show(io::IO, ::MIME"text/plain",
                   cb::DiscreteCallback{<:Any,
                                        <:PeriodicCallbackAffect{<:SaveSolutionCallback}})
    @nospecialize cb # reduce precompilation time

    if get(io, :compact, false)
        show(io, cb)
    else
        setup = save_solution_summary_setup(cb.affect!.affect!, "dt")
        summary_box(io, "SaveSolutionCallback", setup)
    end
    return nothing
end

function SaveSolutionCallback(; interval::Integer = 0,
                              dt = nothing,
                              save_initial_solution = true,
                              save_final_solution = true,
                              output_directory = "out",
                              solution_variables = cons2prim,
                              extra_node_variables = ())
    if !isnothing(dt) && interval > 0
        throw(ArgumentError("You can either set the number of steps between output (using `interval`) or the time between outputs (using `dt`) but not both simultaneously"))
    end

    # Step-based output stores `interval`, time-based output stores `dt`
    interval_or_dt = isnothing(dt) ? interval : dt

    # The values are placeholders; only the keys are known at construction time
    node_variables = Dict{Symbol, Any}(var => nothing for var in extra_node_variables)
    solution_callback = SaveSolutionCallback(interval_or_dt,
                                             save_initial_solution, save_final_solution,
                                             output_directory, solution_variables,
                                             node_variables)

    # Expected most frequent behavior comes first
    if isnothing(dt)
        # Save every `interval` (accepted) time steps
        # The first one is the condition, the second the affect!
        return DiscreteCallback(solution_callback, solution_callback,
                                save_positions = (false, false),
                                initialize = initialize_save_cb!)
    else
        # Add a `tstop` every `dt`, and save the final solution.
        return PeriodicCallback(solution_callback, dt,
                                save_positions = (false, false),
                                initialize = initialize_save_cb!,
                                final_affect = save_final_solution)
    end
end

function initialize_save_cb!(cb, u, t, integrator)
    # The SaveSolutionCallback is either cb.affect! (with DiscreteCallback)
    # or cb.affect!.affect! (with PeriodicCallback).
    # Let recursive dispatch handle this.
    return initialize_save_cb!(cb.affect!, u, t, integrator)
end

# Create the output directory, write the mesh file, and optionally write the
# initial solution.
function initialize_save_cb!(solution_callback::SaveSolutionCallback, u, t, integrator)
    # Only the MPI root creates the directory
    mpi_isroot() && mkpath(solution_callback.output_directory)

    semi = integrator.p
    @trixi_timeit timer() "I/O" save_mesh(semi, solution_callback.output_directory)

    if solution_callback.save_initial_solution
        solution_callback(integrator)
    end

    return nothing
end

# Save mesh for a general semidiscretization (default)
function save_mesh(semi::AbstractSemidiscretization, output_directory, timestep = 0)
    mesh, _, _, _ = mesh_equations_solver_cache(semi)

    if mesh.unsaved_changes
        # We only append the time step number to the mesh file name if it has
        # changed during the simulation due to AMR. We do not append it for
        # the first time step.
        if timestep == 0
            mesh.current_filename = save_mesh_file(mesh, output_directory)
        else
            mesh.current_filename = save_mesh_file(mesh, output_directory, timestep)
        end
        mesh.unsaved_changes = false
    end
    return mesh.current_filename
end

# Save mesh for a DGMultiMesh, which requires passing the `basis` as an argument to
# save_mesh_file
function save_mesh(semi::Union{SemidiscretizationHyperbolic{<:DGMultiMesh},
                               SemidiscretizationHyperbolicParabolic{<:DGMultiMesh}},
                   output_directory, timestep = 0)
    mesh, _, solver, _ = mesh_equations_solver_cache(semi)

    if mesh.unsaved_changes
        # We only append the time step number to the mesh file name if it has
        # changed during the simulation due to AMR. We do not append it for
        # the first time step.
        if timestep == 0
            mesh.current_filename = save_mesh_file(mesh, solver.basis,
                                                   output_directory)
        else
            mesh.current_filename = save_mesh_file(mesh, solver.basis,
                                                   output_directory, timestep)
        end
        mesh.unsaved_changes = false
    end
    return mesh.current_filename
end

# this method is called to determine whether the callback should be activated
function (solution_callback::SaveSolutionCallback)(u, t, integrator)
    @unpack interval_or_dt, save_final_solution = solution_callback

    # With error-based step size control, some steps can be rejected. Thus,
    #   `integrator.iter >= integrator.stats.naccept`
    #    (total #steps)       (#accepted steps)
    # We need to check the number of accepted steps since callbacks are not
    # activated after a rejected step.
    return interval_or_dt > 0 && (integrator.stats.naccept % interval_or_dt == 0 ||
            (save_final_solution && isfinished(integrator)))
end

# this method is called when the callback is activated
function (solution_callback::SaveSolutionCallback)(integrator)
    u_ode = integrator.u
    semi = integrator.p
    iter = integrator.stats.naccept

    @trixi_timeit timer() "I/O" begin
        # Call high-level functions that dispatch on semidiscretization type
        @trixi_timeit timer() "save mesh" save_mesh(semi,
                                                    solution_callback.output_directory,
                                                    iter)
        save_solution_file(semi, u_ode, solution_callback, integrator)
    end

    # avoid re-evaluating possible FSAL stages
    u_modified!(integrator, false)
    return nothing
end

# Collect element and node variables and forward everything to the
# lower-level `save_solution_file` method below.
@inline function save_solution_file(semi::AbstractSemidiscretization, u_ode,
                                    solution_callback,
                                    integrator; system = "")
    @unpack t, dt = integrator
    iter = integrator.stats.naccept

    element_variables = Dict{Symbol, Any}()
    @trixi_timeit timer() "get element variables" begin
        get_element_variables!(element_variables, u_ode, semi)
        callbacks = integrator.opts.callback
        if callbacks isa CallbackSet
            # Every callback in the set is passed to `get_element_variables!`
            foreach(callbacks.continuous_callbacks) do cb
                return get_element_variables!(element_variables, u_ode, semi, cb;
                                              t = t, iter = iter)
            end
            foreach(callbacks.discrete_callbacks) do cb
                return get_element_variables!(element_variables, u_ode, semi, cb;
                                              t = t, iter = iter)
            end
        end
    end

    @trixi_timeit timer() "get node variables" get_node_variables!(solution_callback.node_variables,
                                                                   u_ode, semi)

    @trixi_timeit timer() "save solution" save_solution_file(u_ode, t, dt, iter, semi,
                                                             solution_callback,
                                                             element_variables,
                                                             solution_callback.node_variables,
                                                             system = system)

    return nothing
end

# Wrap the ODE solution vector into the solver-specific array and dispatch to the
# mesh/solver-specific `save_solution_file` implementation.
@inline function save_solution_file(u_ode, t, dt, iter,
                                    semi::AbstractSemidiscretization, solution_callback,
                                    element_variables = Dict{Symbol, Any}(),
                                    node_variables = Dict{Symbol, Any}();
                                    system = "")
    # TODO GPU currently on CPU
    backend = trixi_backend(u_ode)
    if backend !== nothing
        u_ode = Array(u_ode)
    end
    mesh, equations, solver, cache = mesh_equations_solver_cache(semi)
    u = wrap_array_native(u_ode, mesh, equations, solver, cache)
    save_solution_file(u, t, dt, iter, mesh, equations, solver, cache,
                       solution_callback,
                       element_variables,
                       node_variables; system = system)

    return nothing
end

# TODO: Taal refactor, move save_mesh_file?
# function save_mesh_file(mesh::TreeMesh, output_directory, timestep=-1) in io/io.jl

include("save_solution_dg.jl")
end # @muladd