Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
trixi-framework
GitHub Repository: trixi-framework/Trixi.jl
Path: blob/main/ext/TrixiCUDAExt.jl
5582 views
1
# Package extension for adding CUDA-based features to Trixi.jl
2
module TrixiCUDAExt

using CUDA: CUDA, CuArray, CuDeviceArray, KernelAdaptor, @device_override
import Trixi

# Report `CuArray` (unparameterized) as the storage type for any concrete `CuArray` type.
# NOTE(review): presumably Trixi uses this to pick the array kind for newly allocated
# storage so that it stays on the GPU — confirm against `Trixi.storage_type`.
function Trixi.storage_type(::Type{<:CuArray})
    return CuArray
end

# When Trixi data is adapted with CUDA's `KernelAdaptor`, forward to the
# `CuDeviceArray` method below so that `vec` is handled as device memory.
# NOTE(review): `KernelAdaptor` is presumably the adaptor applied by `cudaconvert`
# when arguments are passed to a kernel — confirm.
function Trixi.unsafe_wrap_or_alloc(::KernelAdaptor, vec, size)
    return Trixi.unsafe_wrap_or_alloc(CuDeviceArray, vec, size)
end

# `vec` is already a `CuDeviceArray`, so return `reshape(vec, size)` rather than
# wrapping a raw pointer or allocating.
# NOTE(review): assumes `prod(size) == length(vec)` — confirm at call sites.
function Trixi.unsafe_wrap_or_alloc(::Type{<:CuDeviceArray}, vec::CuDeviceArray, size)
    return reshape(vec, size)
end

# Device-side overrides of `Trixi.log`, compiled in only when the Trixi log preference
# is "log_Trixi_NaN". Each override calls the CUDA libdevice natural-log function
# (`__nv_log` / `__nv_logf`) directly through `llvmcall`.
# NOTE(review): presumably needed because the host definition of `Trixi.log` does not
# lower on the GPU backend — confirm.
@static if Trixi._PREFERENCE_LOG == "log_Trixi_NaN"
    @device_override Trixi.log(x::Float64) = ccall("extern __nv_log", llvmcall, Cdouble,
                                                  (Cdouble,), x)
    @device_override Trixi.log(x::Float32) = ccall("extern __nv_logf", llvmcall, Cfloat,
                                                  (Cfloat,), x)
    # Float16: evaluate in single precision with the same `__nv_logf` call used for
    # Float32, then round the result back to Float16.
    @device_override function Trixi.log(x::Float16)
        y = ccall("extern __nv_logf", llvmcall, Cfloat, (Cfloat,), Float32(x))
        return Float16(y)
    end
end

end
28
29