Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
trixi-framework
GitHub Repository: trixi-framework/Trixi.jl
Path: blob/main/ext/TrixiAMDGPUExt.jl
5582 views
1
# Package extension for adding AMDGPU-based features to Trixi.jl
2
module TrixiAMDGPUExt

using AMDGPU: AMDGPU, ROCArray, ROCDeviceArray
import AMDGPU.Device: @device_override
import AMDGPU.Runtime: Adaptor
import Trixi

# Drop all type parameters: every `ROCArray` subtype reports the bare `ROCArray` as its
# storage type.
Trixi.storage_type(::Type{<:ROCArray}) = ROCArray

# When invoked with AMDGPU's `Adaptor`, defer to the `ROCDeviceArray` method, passing
# the buffer and target dimensions through unchanged.
Trixi.unsafe_wrap_or_alloc(::Adaptor, data, dims) =
    Trixi.unsafe_wrap_or_alloc(ROCDeviceArray, data, dims)

# The buffer already is a `ROCDeviceArray`, so no wrapping or allocation is done here:
# it is simply reshaped to `dims`.
# NOTE(review): assumes `reshape` of a device array aliases the same memory — confirm.
Trixi.unsafe_wrap_or_alloc(::Type{<:ROCDeviceArray}, data::ROCDeviceArray, dims) =
    reshape(data, dims)

# Only under the "log_Trixi_NaN" preference: give `Trixi.log` a device-side definition
# that calls the `__ocml_log_*` intrinsics directly via `llvmcall`.
# NOTE(review): presumably these return NaN for negative inputs, matching the preference
# name — confirm against the OCML documentation.
@static if Trixi._PREFERENCE_LOG == "log_Trixi_NaN"
    # Double precision (`Cdouble` is an alias of `Float64`).
    @device_override Trixi.log(x::Float64) =
        ccall("extern __ocml_log_f64", llvmcall, Cdouble, (Cdouble,), x)

    # Single precision (`Cfloat` is an alias of `Float32`).
    @device_override Trixi.log(x::Float32) =
        ccall("extern __ocml_log_f32", llvmcall, Cfloat, (Cfloat,), x)

    # TODO: Trixi.log(x::Float16)
end

end # module TrixiAMDGPUExt
32
33