Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Tetragramm
GitHub Repository: Tetragramm/opencv
Path: blob/master/modules/core/src/intel_gpu_gemm.inl.hpp
16337 views
/*
 * Copyright 2015-2017 Philippe Tillet
 * Copyright (c) 2017, Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining
 * a copy of this software and associated documentation files
 * (the "Software"), to deal in the Software without restriction,
 * including without limitation the rights to use, copy, modify, merge,
 * publish, distribute, sublicense, and/or sell copies of the Software,
 * and to permit persons to whom the Software is furnished to do so,
 * subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 */

#ifdef HAVE_OPENCL
26
27
#include <sstream>
28
#include "precomp.hpp"
29
#include "opencl_kernels_core.hpp"
30
#include "opencv2/core/opencl/runtime/opencl_clamdblas.hpp"
31
#include "opencv2/core/opencl/runtime/opencl_core.hpp"
32
33
namespace cv
34
{
35
36
static bool intel_gpu_gemm(
37
UMat A, Size sizeA,
38
UMat B, Size sizeB,
39
UMat D, Size sizeD,
40
double alpha, double beta,
41
bool atrans, bool btrans)
42
{
43
CV_UNUSED(sizeB);
44
45
int M = sizeD.height, N = sizeD.width, K = ((atrans)? sizeA.height : sizeA.width);
46
47
std::string kernelName;
48
bool ret = true;
49
50
size_t lx = 8, ly = 4;
51
size_t dx = 4, dy = 8;
52
53
if(!atrans && !btrans)
54
{
55
56
if (M % 32 == 0 && N % 32 == 0 && K % 16 == 0)
57
{
58
kernelName = "intelblas_gemm_buffer_NN_sp";
59
}
60
else
61
{
62
kernelName = "intelblas_gemm_buffer_NN";
63
}
64
}
65
else if(atrans && !btrans)
66
{
67
kernelName = "intelblas_gemm_buffer_TN";
68
}
69
else if(!atrans && btrans)
70
{
71
kernelName = "intelblas_gemm_buffer_NT";
72
ly = 16;
73
dx = 1;
74
}
75
else
76
{
77
kernelName = "intelblas_gemm_buffer_TT";
78
}
79
80
const size_t gx = (size_t)(N + dx - 1) / dx;
81
const size_t gy = (size_t)(M + dy - 1) / dy;
82
83
size_t local[] = {lx, ly, 1};
84
size_t global[] = {(gx + lx - 1) / lx * lx, (gy + ly - 1) / ly * ly, 1};
85
86
int stride = (M * N < 1024 * 1024) ? 10000000 : 256;
87
88
ocl::Queue q;
89
String errmsg;
90
const ocl::Program program = ocl::Context::getDefault().getProg(ocl::core::intel_gemm_oclsrc, "", errmsg);
91
92
if(!atrans && btrans)
93
{
94
ocl::Kernel k(kernelName.c_str(), program);
95
if (k.empty())
96
{
97
return false;
98
}
99
100
k.args(ocl::KernelArg::PtrReadOnly(A),
101
(int) (A.offset / sizeof(float)),
102
ocl::KernelArg::PtrReadOnly(B),
103
(int) (B.offset / sizeof(float)),
104
ocl::KernelArg::PtrWriteOnly(D),
105
(int) (D.offset / sizeof(float)),
106
M, N, K,
107
(float)alpha,
108
(float)beta,
109
(int)(A.step / sizeof(float)),
110
(int)(B.step / sizeof(float)),
111
(int)(D.step / sizeof(float))
112
);
113
114
ret = k.run(2, global, local, false, q);
115
}
116
else
117
{
118
for(int start_index = 0; start_index < K; start_index += stride)
119
{
120
ocl::Kernel k(kernelName.c_str(), program);
121
k.args(ocl::KernelArg::PtrReadOnly(A),
122
(int) (A.offset / sizeof(float)),
123
ocl::KernelArg::PtrReadOnly(B),
124
(int) (B.offset / sizeof(float)),
125
ocl::KernelArg::PtrWriteOnly(D),
126
(int) (D.offset / sizeof(float)),
127
M, N, K,
128
(float)alpha,
129
(float)beta,
130
(int)(A.step / sizeof(float)),
131
(int)(B.step / sizeof(float)),
132
(int)(D.step / sizeof(float)),
133
(int) start_index, // 14 start_index
134
stride);
135
136
ret = k.run(2, global, local, false, q);
137
if (!ret) return ret;
138
}
139
}
140
141
return ret;
142
}
143
144
} // namespace cv
145
146
#endif
147
148