module TrixiCUDAExt
using CUDA: CUDA, CuArray, CuDeviceArray, KernelAdaptor, @device_override
import Trixi
# Report `CuArray` as the storage type for every `CuArray` type, regardless of
# the concrete type parameters of the array type that was passed in.
Trixi.storage_type(::Type{<:CuArray}) = CuArray
# A `KernelAdaptor` request is re-dispatched with `CuDeviceArray` as the target
# type; `vec` and `size` are handed through unchanged.
Trixi.unsafe_wrap_or_alloc(::KernelAdaptor, vec, size) =
    Trixi.unsafe_wrap_or_alloc(CuDeviceArray, vec, size)
# `vec` is already a `CuDeviceArray`, so this method neither wraps nor
# allocates anything itself: the result is simply `vec` reshaped to `size`.
Trixi.unsafe_wrap_or_alloc(::Type{<:CuDeviceArray}, vec::CuDeviceArray, size) =
    reshape(vec, size)
# Device-side overrides of `Trixi.log`. They are only defined when
# `Trixi._PREFERENCE_LOG` equals "log_Trixi_NaN"; the check is wrapped in
# `@static`, so it is decided when this code is loaded, not at run time.
# NOTE(review): `__nv_log` / `__nv_logf` match the names of NVIDIA libdevice's
# natural-logarithm functions -- confirm against the libdevice documentation.
@static if Trixi._PREFERENCE_LOG == "log_Trixi_NaN"
    # Double precision: call the external symbol `__nv_log`.
    @device_override function Trixi.log(x::Float64)
        return ccall("extern __nv_log", llvmcall, Cdouble, (Cdouble,), x)
    end

    # Single precision: call the external symbol `__nv_logf`.
    @device_override function Trixi.log(x::Float32)
        return ccall("extern __nv_logf", llvmcall, Cfloat, (Cfloat,), x)
    end
end
end