GitHub Repository: trixi-framework/Trixi.jl
Path: blob/main/src/solvers/dgsem_t8code/dg_parallel.jl
@muladd begin
#! format: noindent

# This method is called when a `SemidiscretizationHyperbolic` is constructed.
# It constructs the basic `cache` used throughout the simulation to compute
# the RHS etc.
function create_cache(mesh::T8codeMeshParallel, equations::AbstractEquations, dg::DG,
                      ::Any,
                      ::Type{uEltype}) where {uEltype <: Real}
    # Make sure to balance and partition the forest before creating any
    # containers in case someone has tampered with the forest after creating
    # the mesh.
    balance!(mesh)
    partition!(mesh)

    count_required_surfaces!(mesh)

    elements = init_elements(mesh, equations, dg.basis, uEltype)
    mortars = init_mortars(mesh, equations, dg.basis, elements)
    interfaces = init_interfaces(mesh, equations, dg.basis, elements)
    boundaries = init_boundaries(mesh, equations, dg.basis, elements)

    mpi_mortars = init_mpi_mortars(mesh, equations, dg.basis, elements)
    mpi_interfaces = init_mpi_interfaces(mesh, equations, dg.basis, elements)

    # Temporary storage used only to set up the MPI cache below; the id and
    # rank arrays are emptied again once the cache has been initialized.
    mpi_mesh_info = (mpi_mortars = mpi_mortars,
                     mpi_interfaces = mpi_interfaces,
                     global_mortar_ids = fill(UInt128(0), nmpimortars(mpi_mortars)),
                     global_interface_ids = fill(UInt128(0),
                                                 nmpiinterfaces(mpi_interfaces)),
                     neighbor_ranks_mortar = Vector{Vector{Int}}(undef,
                                                                 nmpimortars(mpi_mortars)),
                     neighbor_ranks_interface = fill(-1,
                                                     nmpiinterfaces(mpi_interfaces)))

    fill_mesh_info!(mesh, interfaces, mortars, boundaries,
                    mesh.boundary_names; mpi_mesh_info = mpi_mesh_info)

    mpi_cache = init_mpi_cache(mesh, mpi_mesh_info, nvariables(equations), nnodes(dg),
                               uEltype)

    empty!(mpi_mesh_info.global_mortar_ids)
    empty!(mpi_mesh_info.global_interface_ids)
    empty!(mpi_mesh_info.neighbor_ranks_mortar)
    empty!(mpi_mesh_info.neighbor_ranks_interface)

    init_normal_directions!(mpi_mortars, dg.basis, elements)
    exchange_normal_directions!(mpi_mortars, mpi_cache, mesh, nnodes(dg))

    # Container cache
    cache = (; elements, interfaces, mpi_interfaces, boundaries, mortars,
             mpi_mortars, mpi_cache)

    # Add Volume-Integral cache
    cache = (; cache...,
             create_cache(mesh, equations, dg.volume_integral, dg, cache, uEltype)...)
    # Add Mortar cache
    cache = (; cache..., create_cache(mesh, equations, dg.mortar, uEltype)...)

    return cache
end
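# A minimal usage sketch (illustration only, not part of Trixi.jl; the mesh
# keyword arguments below are assumptions chosen for the example): when run
# under MPI, e.g. `mpiexec -n 4 julia --project elixir.jl`, constructing the
# semidiscretization below triggers the `create_cache` method above on every
# rank.
#
#     using Trixi
#
#     equations = LinearScalarAdvectionEquation2D((0.2, -0.7))
#     solver = DGSEM(polydeg = 3, surface_flux = flux_lax_friedrichs)
#     mesh = T8codeMesh((4, 4); polydeg = 3,
#                       coordinates_min = (-1.0, -1.0),
#                       coordinates_max = (1.0, 1.0),
#                       initial_refinement_level = 2)
#     semi = SemidiscretizationHyperbolic(mesh, equations,
#                                         initial_condition_convergence_test, solver)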

# Note: the parallel t8code solver reuses the `P4estMPICache` type from the
# p4est solver to store its MPI communication data.
function init_mpi_cache(mesh::T8codeMeshParallel, mpi_mesh_info, nvars, nnodes, uEltype)
    mpi_cache = P4estMPICache(uEltype)
    init_mpi_cache!(mpi_cache, mesh, mpi_mesh_info, nvars, nnodes, uEltype)
    return mpi_cache
end

function init_mpi_cache!(mpi_cache::P4estMPICache, mesh::T8codeMeshParallel,
                         mpi_mesh_info, nvars, nnodes, uEltype)
    mpi_neighbor_ranks, mpi_neighbor_interfaces, mpi_neighbor_mortars = init_mpi_neighbor_connectivity(mpi_mesh_info,
                                                                                                       mesh)

    mpi_send_buffers, mpi_recv_buffers, mpi_send_requests, mpi_recv_requests = init_mpi_data_structures(mpi_neighbor_interfaces,
                                                                                                        mpi_neighbor_mortars,
                                                                                                        ndims(mesh),
                                                                                                        nvars,
                                                                                                        nnodes,
                                                                                                        uEltype)
    n_elements_global = Int(t8_forest_get_global_num_elements(mesh.forest))
    n_elements_local = Int(t8_forest_get_local_num_elements(mesh.forest))

    n_elements_by_rank = Vector{Int}(undef, mpi_nranks())
    n_elements_by_rank[mpi_rank() + 1] = n_elements_local

    # In-place Allgather: each rank contributes its own entry and receives the
    # local element counts of all other ranks.
    MPI.Allgather!(MPI.UBuffer(n_elements_by_rank, 1), mpi_comm())

    n_elements_by_rank = OffsetArray(n_elements_by_rank, 0:(mpi_nranks() - 1))

    # Account for 1-based indexing in Julia.
    first_element_global_id = sum(n_elements_by_rank[0:(mpi_rank() - 1)]) + 1

    @assert n_elements_global==sum(n_elements_by_rank) "error in total number of elements"

    @pack! mpi_cache = mpi_neighbor_ranks, mpi_neighbor_interfaces,
                       mpi_neighbor_mortars,
                       mpi_send_buffers, mpi_recv_buffers,
                       mpi_send_requests, mpi_recv_requests,
                       n_elements_by_rank, n_elements_global,
                       first_element_global_id

    return mpi_cache
end
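# Worked example (hypothetical numbers, for illustration only): with three MPI
# ranks owning 4, 5, and 3 elements, the in-place `Allgather!` above yields
# `n_elements_by_rank = [4, 5, 3]` on every rank, re-indexed as an `OffsetArray`
# over ranks 0:2. Rank 2 then starts at global element id
# `sum(n_elements_by_rank[0:1]) + 1 = 4 + 5 + 1 = 10`, and the assertion checks
# that the per-rank counts sum to `n_elements_global = 12`.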

function init_mpi_neighbor_connectivity(mpi_mesh_info, mesh::T8codeMeshParallel)
    @unpack mpi_interfaces, mpi_mortars, global_interface_ids, neighbor_ranks_interface, global_mortar_ids, neighbor_ranks_mortar = mpi_mesh_info

    mpi_neighbor_ranks = vcat(neighbor_ranks_interface, neighbor_ranks_mortar...) |>
                         sort |> unique

    # Sort the MPI interfaces/mortars by their global ids so that the ordering
    # is consistent across neighboring ranks.
    p = sortperm(global_interface_ids)
    neighbor_ranks_interface .= neighbor_ranks_interface[p]
    interface_ids = collect(1:nmpiinterfaces(mpi_interfaces))[p]

    p = sortperm(global_mortar_ids)
    neighbor_ranks_mortar .= neighbor_ranks_mortar[p]
    mortar_ids = collect(1:nmpimortars(mpi_mortars))[p]

    # For each neighbor rank, init connectivity data structures
    mpi_neighbor_interfaces = Vector{Vector{Int}}(undef, length(mpi_neighbor_ranks))
    mpi_neighbor_mortars = Vector{Vector{Int}}(undef, length(mpi_neighbor_ranks))
    for (index, rank) in enumerate(mpi_neighbor_ranks)
        mpi_neighbor_interfaces[index] = interface_ids[findall(==(rank),
                                                               neighbor_ranks_interface)]
        mpi_neighbor_mortars[index] = mortar_ids[findall(x -> (rank in x),
                                                         neighbor_ranks_mortar)]
    end

    # Check that all interfaces were counted exactly once
    @assert mapreduce(length, +, mpi_neighbor_interfaces; init = 0) ==
            nmpiinterfaces(mpi_interfaces)

    return mpi_neighbor_ranks, mpi_neighbor_interfaces, mpi_neighbor_mortars
end
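# Worked example (hypothetical ids, for illustration only; the sort
# permutations are taken to be the identity): suppose this rank has three MPI
# interfaces with `neighbor_ranks_interface = [2, 0, 2]` and one MPI mortar
# with `neighbor_ranks_mortar = [[0, 2]]`. Then `mpi_neighbor_ranks == [0, 2]`,
# `mpi_neighbor_interfaces == [[2], [1, 3]]`, and
# `mpi_neighbor_mortars == [[1], [1]]`; the final assertion holds since
# 1 + 2 == 3 == nmpiinterfaces(mpi_interfaces).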
end # @muladd