// Source: OpenCV DNN module, modules/dnn/src/opencl/gemm_image.cl
// (image2d-based GEMM kernels for Intel subgroup-capable GPUs)
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
// Enable half-precision (fp16) arithmetic on devices exposing cl_khr_fp16.
#if defined(cl_khr_fp16)
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif
// Token-pasting helpers: TEMPLATE(name, type) expands to name_type, used to
// mint one kernel symbol per element type (e.g. gemm_32_1_NN_1_0_float).
#define CONCAT(A,B) A##_##B
#define TEMPLATE(name,type) CONCAT(name,type)

// alpha/beta kernel arguments always arrive from the host as float,
// regardless of the kernel's element type.
#define KERNEL_ARG_DTYPE float
// Compile-time element-type selector; the build defines TYPE to one of these.
#define TYPE_FLOAT 1
#define TYPE_HALF 2
// Map the generic Dtype*/as_Dtype* aliases onto the selected element type so
// the kernel bodies below can be written type-agnostically.
#if TYPE == TYPE_HALF
#define Dtype half
#define Dtype2 half2
#define Dtype4 half4
#define Dtype8 half8
#define Dtype16 half16

#define as_Dtype as_half
#define as_Dtype2 as_half2
#define as_Dtype4 as_half4
#define as_Dtype8 as_half8
#define as_Dtype16 as_half16
#else
#define Dtype float
#define Dtype2 float2
#define Dtype4 float4
#define Dtype8 float8
#define Dtype16 float16

#define as_Dtype as_float
#define as_Dtype2 as_float2
#define as_Dtype4 as_float4
#define as_Dtype8 as_float8
#define as_Dtype16 as_float16
#endif
// Intel subgroup operations (block reads, shuffles) are required below.
#if defined(cl_intel_subgroups)
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#endif

// Each work-group computes a TILE_M x TILE_N tile of the output, walking the
// shared K dimension in TILE_K-wide steps.
#define TILE_M 32
#define TILE_K 8
// common block to calculate (alpha * AxB + beta * C) and output to destination image.

// Per-type plumbing: which subgroup block-read intrinsic to use, how to cast
// for shuffles, which image-read function applies, the per-element byte size,
// the required subgroup width, and the N-tile a subgroup covers.
#if TYPE == TYPE_HALF
#define SUBGROUP_BLOCK_READ8( __image, __coord ) intel_sub_group_block_read_us8( __image, __coord )
#define SHUFFLE_TYPE2(val) as_ushort2(val)
#define SHUFFLE_TYPE8(val) as_ushort8(val)
#define READ_IMAGE(__image, __coord) read_imageh(__image, sampler, __coord)
#define SIZE_OF_ELEMENT sizeof(ushort)
#define SIMD_SIZE_GEMM 16
#define TILE_N 16
#else
#define SUBGROUP_BLOCK_READ8( __image, __coord ) intel_sub_group_block_read8( __image, __coord )
#define SHUFFLE_TYPE2(val) val
#define SHUFFLE_TYPE8(val) val
#define READ_IMAGE(__image, __coord) read_imagef(__image, sampler, __coord)
#define SIZE_OF_ELEMENT sizeof(uint)
#define SIMD_SIZE_GEMM 8
#define TILE_N 8
#endif
//#define USE_IMAGE_C
#ifdef USE_IMAGE_C
// C is an image2d: use subgroup block reads/writes (no bounds checks needed).
#if TYPE == TYPE_HALF
#define BLOCKC_READ8( _C, _coordC ) as_Dtype8( intel_sub_group_block_read_us8( _C, _coordC ) )
#define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write_us8( _C, _coordC, as_ushort8( _val ) )
#else
#define BLOCKC_READ8( _C, _coordC ) as_Dtype8( intel_sub_group_block_read8( _C, _coordC ) )
#define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write8( _C, _coordC, as_uint8( _val ) )
#endif
#define MATC_PARAMETER __read_only image2d_t C, __write_only image2d_t dst
#define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, C, dst, sizeof(uint))
#else
// C is a plain buffer: read/write an 8-row column strip with explicit M/N
// bounds checks. Out-of-range reads yield 0; out-of-range writes are skipped.
#define BLOCKC_READ8( _C, _coordC ) \
          (Dtype8) ( (_coordC.x + get_local_id(0) < N && _coordC.y < M) ? _C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] : 0, \
                     (_coordC.x + get_local_id(0) < N && _coordC.y + 1 < M) ? _C[ ( _coordC.y + 1 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
                     (_coordC.x + get_local_id(0) < N && _coordC.y + 2 < M) ? _C[ ( _coordC.y + 2 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
                     (_coordC.x + get_local_id(0) < N && _coordC.y + 3 < M) ? _C[ ( _coordC.y + 3 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
                     (_coordC.x + get_local_id(0) < N && _coordC.y + 4 < M) ? _C[ ( _coordC.y + 4 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
                     (_coordC.x + get_local_id(0) < N && _coordC.y + 5 < M) ? _C[ ( _coordC.y + 5 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
                     (_coordC.x + get_local_id(0) < N && _coordC.y + 6 < M) ? _C[ ( _coordC.y + 6 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
                     (_coordC.x + get_local_id(0) < N && _coordC.y + 7 < M) ? _C[ ( _coordC.y + 7 ) * ldc + _coordC.x + get_local_id(0) ] : 0)

#define BLOCKC_WRITE8( _C, _coordC, _val) do {\
                     if (_coordC.x + get_local_id(0) < N) { \
                       if (_coordC.y < M) \
                         _C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] = _val.s0; \
                       if (_coordC.y + 1 < M) \
                         _C[ ( _coordC.y + 1 )* ldc + _coordC.x + get_local_id(0) ] = _val.s1; \
                       if (_coordC.y + 2 < M) \
                         _C[ ( _coordC.y + 2 )* ldc + _coordC.x + get_local_id(0) ] = _val.s2; \
                       if (_coordC.y + 3 < M) \
                         _C[ ( _coordC.y + 3 )* ldc + _coordC.x + get_local_id(0) ] = _val.s3; \
                       if (_coordC.y + 4 < M) \
                         _C[ ( _coordC.y + 4 )* ldc + _coordC.x + get_local_id(0) ] = _val.s4; \
                       if (_coordC.y + 5 < M) \
                         _C[ ( _coordC.y + 5 )* ldc + _coordC.x + get_local_id(0) ] = _val.s5; \
                       if (_coordC.y + 6 < M) \
                         _C[ ( _coordC.y + 6 )* ldc + _coordC.x + get_local_id(0) ] = _val.s6; \
                       if (_coordC.y + 7 < M) \
                         _C[ ( _coordC.y + 7 )* ldc + _coordC.x + get_local_id(0) ] = _val.s7; \
                     }} while(0)
#define MATC_PARAMETER __global Dtype * C, const int offC, const int M, const int N, const int ldc
#define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, (C + offC), (C + offC), 1)
#endif
// Epilogue shared by all kernels: combine the accumulated product
// (blockAxB00..03) with C as alpha*AxB + beta*C and write the 32-row result
// strip to the destination. On non-first column blocks the partial result
// previously stored in C/dst is re-read and accumulated instead of scaled.
#define GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, _C, _dst, _C_step) \
    int2 coordDst = (int2)( ( group_x * TILE_N ) * _C_step, ( group_y * TILE_M ) ); \
    int2 coordC = coordDst; \
    Dtype8 blockC00; \
    Dtype8 blockC01; \
    Dtype8 blockC02; \
    Dtype8 blockC03; \
    if (BETA_NOT0) { \
        blockC00 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
        blockC01 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
        blockC02 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
        blockC03 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); \
        if (!ALPHA1) { \
            blockC00 = mad(blockAxB00, (Dtype8)alpha, blockC00); \
            blockC01 = mad(blockAxB01, (Dtype8)alpha, blockC01); \
            blockC02 = mad(blockAxB02, (Dtype8)alpha, blockC02); \
            blockC03 = mad(blockAxB03, (Dtype8)alpha, blockC03); \
        } else { \
            blockC00 += blockAxB00; \
            blockC01 += blockAxB01; \
            blockC02 += blockAxB02; \
            blockC03 += blockAxB03; \
        } \
    } else { \
        blockC00 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
        blockC01 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
        blockC02 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
        blockC03 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC ); \
        if (!ALPHA1) { \
            blockC00 = mad(blockAxB00, (Dtype8)alpha, blockC00); \
            blockC01 = mad(blockAxB01, (Dtype8)alpha, blockC01); \
            blockC02 = mad(blockAxB02, (Dtype8)alpha, blockC02); \
            blockC03 = mad(blockAxB03, (Dtype8)alpha, blockC03); \
        } else { \
            blockC00 += blockAxB00; \
            blockC01 += blockAxB01; \
            blockC02 += blockAxB02; \
            blockC03 += blockAxB03; \
        } \
    } \
    BLOCKC_WRITE8( _dst, coordDst, blockC00 ); coordDst.y += 8; \
    BLOCKC_WRITE8( _dst, coordDst, blockC01 ); coordDst.y += 8; \
    BLOCKC_WRITE8( _dst, coordDst, blockC02 ); coordDst.y += 8; \
    BLOCKC_WRITE8( _dst, coordDst, blockC03 );
// Extract column _col of an 8x8 register block: subgroup-shuffle lane _col's
// eight row values so every lane holds that column as a Dtype8.
#define TRANSPOSE_BLOCK_8( _block, _col ) \
        (Dtype8)( intel_sub_group_shuffle( _block.s0, _col ), \
                  intel_sub_group_shuffle( _block.s1, _col ), \
                  intel_sub_group_shuffle( _block.s2, _col ), \
                  intel_sub_group_shuffle( _block.s3, _col ), \
                  intel_sub_group_shuffle( _block.s4, _col ), \
                  intel_sub_group_shuffle( _block.s5, _col ), \
                  intel_sub_group_shuffle( _block.s6, _col ), \
                  intel_sub_group_shuffle( _block.s7, _col ) );
// A's column block multiply B's row block: accumulate the outer products of
// A's columns with B's rows into _result. The half variant consumes two B
// blocks (16 K-steps per call, matching SIMD width 16); the float variant
// consumes one (8 K-steps).
#if TYPE == TYPE_HALF
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB00, _blockB01 ) \
        { \
            const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \
            const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \
            const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \
            const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \
            const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \
            const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \
            const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \
            const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \
            const Dtype8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); \
            const Dtype8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); \
            const Dtype8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); \
            const Dtype8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); \
            const Dtype8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); \
            const Dtype8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); \
            const Dtype8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); \
            const Dtype8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); \
            _result = mad( (Dtype8)(_blockB00.s0), acol0, _result ); \
            _result = mad( (Dtype8)(_blockB00.s1), acol1, _result ); \
            _result = mad( (Dtype8)(_blockB00.s2), acol2, _result ); \
            _result = mad( (Dtype8)(_blockB00.s3), acol3, _result ); \
            _result = mad( (Dtype8)(_blockB00.s4), acol4, _result ); \
            _result = mad( (Dtype8)(_blockB00.s5), acol5, _result ); \
            _result = mad( (Dtype8)(_blockB00.s6), acol6, _result ); \
            _result = mad( (Dtype8)(_blockB00.s7), acol7, _result ); \
            _result = mad( (Dtype8)(_blockB01.s0), acol8, _result ); \
            _result = mad( (Dtype8)(_blockB01.s1), acol9, _result ); \
            _result = mad( (Dtype8)(_blockB01.s2), acola, _result ); \
            _result = mad( (Dtype8)(_blockB01.s3), acolb, _result ); \
            _result = mad( (Dtype8)(_blockB01.s4), acolc, _result ); \
            _result = mad( (Dtype8)(_blockB01.s5), acold, _result ); \
            _result = mad( (Dtype8)(_blockB01.s6), acole, _result ); \
            _result = mad( (Dtype8)(_blockB01.s7), acolf, _result ); \
        }
#else
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \
        { \
            const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \
            const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \
            const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \
            const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \
            const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \
            const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \
            const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \
            const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \
            _result = mad( (Dtype8)(_blockB.s0), acol0, _result ); \
            _result = mad( (Dtype8)(_blockB.s1), acol1, _result ); \
            _result = mad( (Dtype8)(_blockB.s2), acol2, _result ); \
            _result = mad( (Dtype8)(_blockB.s3), acol3, _result ); \
            _result = mad( (Dtype8)(_blockB.s4), acol4, _result ); \
            _result = mad( (Dtype8)(_blockB.s5), acol5, _result ); \
            _result = mad( (Dtype8)(_blockB.s6), acol6, _result ); \
            _result = mad( (Dtype8)(_blockB.s7), acol7, _result ); \
        }
#endif
// GEMM, both operands non-transposed (NN). One subgroup computes a 32 x TILE_N
// output tile; the do/while walks K in TILE_K steps until width0 is covered.
#if TYPE == TYPE_HALF
#define GEMM_NN(ALPHA1, BETA_NOT0) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM)))  \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \
    __read_only image2d_t A, \
    __read_only image2d_t B, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int width0, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0; \
    Dtype8 blockAxB01 = 0; \
    Dtype8 blockAxB02 = 0; \
    Dtype8 blockAxB03 = 0; \
    int2 coordA = (int2)( 0, group_y * TILE_M ); \
    int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); \
    do \
    { \
        int2 coordBTemp = coordB; \
        /* NOTE(review): blockB01 is read from the same coordBTemp as blockB00
           (coordB.y advances but coordBTemp does not) — verify against the
           upstream OpenCV source that this is intentional. */ \
        Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \
        Dtype8 blockB01 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; \
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, blockB01 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, blockB01 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, blockB01 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, blockB01 ); \
    } \
    while( coordB.y < width0 ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#else
#define GEMM_NN(ALPHA1, BETA_NOT0) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM)))  \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \
    __read_only image2d_t A, \
    __read_only image2d_t B, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int width0, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0.0f; \
    Dtype8 blockAxB01 = 0.0f; \
    Dtype8 blockAxB02 = 0.0f; \
    Dtype8 blockAxB03 = 0.0f; \
    int2 coordA = (int2)( 0, group_y * TILE_M ); \
    int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); \
    do \
    { \
        int2 coordBTemp = coordB; \
        Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; \
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \
    } \
    while( coordB.y < width0 ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#endif
// Instantiate the NN kernels for every alpha/beta specialization, then retire
// the helper macros so the TN/NT sections can redefine them.
GEMM_NN(1, 0) // ALPHA == 1, BETA == 0
GEMM_NN(1, 1) // ALPHA == 1, BETA != 0
GEMM_NN(0, 0) // ALPHA != 1, BETA == 0
GEMM_NN(0, 1) // ALPHA != 1, BETA != 0

#undef TRANSPOSE_BLOCK_8
#undef MULTIPLY_BLOCKS_8x8
#undef GEMM_NN
// replicate the first row to column block: gather eight consecutive lanes
// (starting at _col) of a scalar register into a Dtype8 column.
#define TRANSPOSE_BLOCK_8(_vec, _col) \
        (Dtype8)( intel_sub_group_shuffle(_vec, _col + 0), \
                  intel_sub_group_shuffle(_vec, _col + 1), \
                  intel_sub_group_shuffle(_vec, _col + 2), \
                  intel_sub_group_shuffle(_vec, _col + 3), \
                  intel_sub_group_shuffle(_vec, _col + 4), \
                  intel_sub_group_shuffle(_vec, _col + 5), \
                  intel_sub_group_shuffle(_vec, _col + 6), \
                  intel_sub_group_shuffle(_vec, _col + 7) )

// 8x8 block multiply for the TN case: A columns come from lane groups
// selected by _col (0 or 8 in the half path).
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) \
        { \
            _result = mad( (Dtype8)(_blockB.s0), TRANSPOSE_BLOCK_8(_blockA.s0, _col), _result ); \
            _result = mad( (Dtype8)(_blockB.s1), TRANSPOSE_BLOCK_8(_blockA.s1, _col), _result ); \
            _result = mad( (Dtype8)(_blockB.s2), TRANSPOSE_BLOCK_8(_blockA.s2, _col), _result ); \
            _result = mad( (Dtype8)(_blockB.s3), TRANSPOSE_BLOCK_8(_blockA.s3, _col), _result ); \
            _result = mad( (Dtype8)(_blockB.s4), TRANSPOSE_BLOCK_8(_blockA.s4, _col), _result ); \
            _result = mad( (Dtype8)(_blockB.s5), TRANSPOSE_BLOCK_8(_blockA.s5, _col), _result ); \
            _result = mad( (Dtype8)(_blockB.s6), TRANSPOSE_BLOCK_8(_blockA.s6, _col), _result ); \
            _result = mad( (Dtype8)(_blockB.s7), TRANSPOSE_BLOCK_8(_blockA.s7, _col), _result ); \
        }
// GEMM with A transposed (TN). A is addressed column-major (coordA.x walks M,
// coordA.y walks K); the do/while covers K in TILE_K steps up to width0.
#if TYPE == TYPE_HALF
#define GEMM_TN(ALPHA1, BETA_NOT0) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM)))  \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \
    __read_only image2d_t A, \
    __read_only image2d_t B, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int width0, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0);\
    const int group_y = get_group_id(1);\
    Dtype8 blockAxB00 = 0;\
    Dtype8 blockAxB01 = 0;\
    Dtype8 blockAxB02 = 0;\
    Dtype8 blockAxB03 = 0;\
    int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 );\
    int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 );\
    do\
    {\
        int2 coordBTemp = coordB;\
        Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K;\
        int2 coordATemp = coordA;\
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT;\
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); \
    } \
    while( coordB.y < width0 ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#else
#define GEMM_TN(ALPHA1, BETA_NOT0) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM)))  \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \
    __read_only image2d_t A, \
    __read_only image2d_t B, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int width0, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0);\
    const int group_y = get_group_id(1);\
    Dtype8 blockAxB00 = 0.0f;\
    Dtype8 blockAxB01 = 0.0f;\
    Dtype8 blockAxB02 = 0.0f;\
    Dtype8 blockAxB03 = 0.0f;\
    int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 );\
    int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 );\
    do\
    {\
        int2 coordBTemp = coordB;\
        Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K;\
        int2 coordATemp = coordA;\
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\
        Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT;\
        Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, 0 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, 0 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, 0 ); \
    } \
    while( coordB.y < width0 ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#endif
// Instantiate the TN kernels, then retire the TN-specific helper macros.
GEMM_TN(1, 0) // ALPHA == 1, BETA == 0
GEMM_TN(1, 1) // ALPHA == 1, BETA != 0
GEMM_TN(0, 0) // ALPHA != 1, BETA == 0
GEMM_TN(0, 1) // ALPHA != 1, BETA != 0

#undef MULTIPLY_BLOCKS_8x8
#undef TRANSPOSE_BLOCK_8
#undef GEMM_TN
// The same column-extract helper as in the GEMM_NN section: broadcast lane
// _col's eight row values of an 8x8 register block across the subgroup.
#define TRANSPOSE_BLOCK_8( _block, _col ) \
        (Dtype8)( intel_sub_group_shuffle( _block.s0, _col), \
                  intel_sub_group_shuffle( _block.s1, _col), \
                  intel_sub_group_shuffle( _block.s2, _col), \
                  intel_sub_group_shuffle( _block.s3, _col), \
                  intel_sub_group_shuffle( _block.s4, _col), \
                  intel_sub_group_shuffle( _block.s5, _col), \
                  intel_sub_group_shuffle( _block.s6, _col), \
                  intel_sub_group_shuffle( _block.s7, _col) )
// 8x8 block multiply for the NT case: B arrives as a row vector (Dtype16 for
// half — 16 K-steps — or Dtype8 for float) rather than a block-read tile.
#if TYPE == TYPE_HALF
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \
        { \
            const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \
            const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \
            const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \
            const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \
            const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \
            const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \
            const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \
            const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \
            const Dtype8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); \
            const Dtype8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); \
            const Dtype8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); \
            const Dtype8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); \
            const Dtype8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); \
            const Dtype8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); \
            const Dtype8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); \
            const Dtype8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); \
            _result = mad( (Dtype8)_blockB.s0, acol0, _result ); \
            _result = mad( (Dtype8)_blockB.s1, acol1, _result ); \
            _result = mad( (Dtype8)_blockB.s2, acol2, _result ); \
            _result = mad( (Dtype8)_blockB.s3, acol3, _result ); \
            _result = mad( (Dtype8)_blockB.s4, acol4, _result ); \
            _result = mad( (Dtype8)_blockB.s5, acol5, _result ); \
            _result = mad( (Dtype8)_blockB.s6, acol6, _result ); \
            _result = mad( (Dtype8)_blockB.s7, acol7, _result ); \
            _result = mad( (Dtype8)_blockB.s8, acol8, _result ); \
            _result = mad( (Dtype8)_blockB.s9, acol9, _result ); \
            _result = mad( (Dtype8)_blockB.sa, acola, _result ); \
            _result = mad( (Dtype8)_blockB.sb, acolb, _result ); \
            _result = mad( (Dtype8)_blockB.sc, acolc, _result ); \
            _result = mad( (Dtype8)_blockB.sd, acold, _result ); \
            _result = mad( (Dtype8)_blockB.se, acole, _result ); \
            _result = mad( (Dtype8)_blockB.sf, acolf, _result ); \
        }
#else
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \
        { \
            const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \
            const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \
            const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \
            const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \
            const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \
            const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \
            const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \
            const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \
            _result = mad( (Dtype8)_blockB.s0, acol0, _result ); \
            _result = mad( (Dtype8)_blockB.s1, acol1, _result ); \
            _result = mad( (Dtype8)_blockB.s2, acol2, _result ); \
            _result = mad( (Dtype8)_blockB.s3, acol3, _result ); \
            _result = mad( (Dtype8)_blockB.s4, acol4, _result ); \
            _result = mad( (Dtype8)_blockB.s5, acol5, _result ); \
            _result = mad( (Dtype8)_blockB.s6, acol6, _result ); \
            _result = mad( (Dtype8)_blockB.s7, acol7, _result ); \
        }
#endif
// GEMM with B transposed (NT). B's storage format is abstracted behind
// BLOCKB_READ8/MATB_PARAMETER (image VEC4, buffer, or scalar image variants
// instantiated below); the loop walks K until padded_k/VECSIZE is covered.
#if TYPE == TYPE_HALF
#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM)))  \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \
    __read_only image2d_t A, \
    MATB_PARAMETER, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int padded_k, \
    int k, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0; \
    Dtype8 blockAxB01 = 0; \
    Dtype8 blockAxB02 = 0; \
    Dtype8 blockAxB03 = 0; \
    int2 coordA = (int2)( 0, group_y * TILE_M ); \
    int2 coordB = (int2)( 0, ( group_x * TILE_N )); \
    const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \
    do \
    { \
        Dtype16 blockB00; \
        BLOCKB_READ8(blockB00, B, coordB); \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; \
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \
    } \
    while( coordB.x < padded_k / VECSIZE ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#else
#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM)))  \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \
    __read_only image2d_t A, \
    MATB_PARAMETER, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int padded_k, \
    int k, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0.0f; \
    Dtype8 blockAxB01 = 0.0f; \
    Dtype8 blockAxB02 = 0.0f; \
    Dtype8 blockAxB03 = 0.0f; \
    int2 coordA = (int2)( 0, group_y * TILE_M ); \
    int2 coordB = (int2)( 0, ( group_x * TILE_N )); \
    const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \
    do \
    { \
        Dtype8 blockB00; \
        BLOCKB_READ8(blockB00, B, coordB); \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; \
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \
    } \
    while( coordB.x < padded_k / VECSIZE ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#endif
// B reader for the VEC4 NT variant: B is an image2d holding 4-wide texels,
// so one row of K values is fetched as consecutive texels per work-item.
#if TYPE == TYPE_HALF
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        _blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s89ab = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.scdef = READ_IMAGE(_B, _coordBTemp); _coordB.x += 4;
#else
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        _blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordB.x += 2;
#endif

#define MATB_PARAMETER __read_only image2d_t B
// Instantiate the image/VEC4 NT kernels, then drop the B-reader macros so the
// buffer variant can redefine them.
GEMM_NT(1, 0, VEC4, 4) // ALPHA == 1, BETA == 0
GEMM_NT(1, 1, VEC4, 4) // ALPHA == 1, BETA != 0
GEMM_NT(0, 0, VEC4, 4) // ALPHA != 1, BETA == 0
GEMM_NT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0
#undef BLOCKB_READ8
#undef MATB_PARAMETER
// B reader for the BUFFER NT variant: B is a plain global buffer; one row of
// K values is fetched with a vload8 (reinterpreted for half, where 8 floats
// carry 16 half values).
#if TYPE == TYPE_HALF
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); \
        _blockb = as_Dtype16(as_ushort16(vload8(0, B_read))); \
        _coordB.x += TILE_K * 2;
#else
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        const __global Dtype *B_read = (__global Dtype *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); \
        _blockb = vload8(0, B_read); \
        _coordB.x += TILE_K;
#endif

#define MATB_PARAMETER __global Dtype *B, int offB, int ldb
// Instantiate the buffer NT kernels, then drop the B-reader macros so the
// scalar-image variant can redefine them.
GEMM_NT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0
GEMM_NT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0
GEMM_NT(0, 0, BUFFER, 1) // ALPHA != 1, BETA == 0
GEMM_NT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0
#undef BLOCKB_READ8
#undef MATB_PARAMETER
675
/* NT path, SCALAR image reads: one element (temp.s0) per READ_IMAGE.
   Used when B's width is not a multiple of the vector size.  The half
   branch gathers 16 scalars, the float branch 8. */
#if TYPE == TYPE_HALF
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        Dtype4 temp; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s0 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s1 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s2 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s3 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s4 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s5 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s6 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s7 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s8 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s9 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.sa = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.sb = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.sc = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.sd = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.se = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.sf = temp.s0; \
        _coordB.x += 16;
#else
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        Dtype4 temp; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s0 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s1 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s2 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s3 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s4 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s5 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s6 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s7 = temp.s0; \
        _coordB.x += 8;
#endif

#define MATB_PARAMETER __read_only image2d_t B

GEMM_NT(1, 0, SCALAR, 1) // ALPHA == 1, BETA == 0
GEMM_NT(1, 1, SCALAR, 1) // ALPHA == 1, BETA != 0
GEMM_NT(0, 0, SCALAR, 1) // ALPHA != 1, BETA == 0
GEMM_NT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0
#undef BLOCKB_READ8
#undef MATB_PARAMETER

/* End of the NT (A not transposed, B transposed) family. */
#undef MULTIPLY_BLOCKS_8x8
#undef TRANSPOSE_BLOCK_8
#undef GEMM_NT
// The same as GEMM_TN.
/* Broadcast 8 consecutive sub-group lanes (_col .. _col+7) of _vec into a
   Dtype8, i.e. transpose one 8-wide column strip of the A block held in
   registers across the sub-group. */
#define TRANSPOSE_BLOCK_8(_vec, _col) \
        (Dtype8)( intel_sub_group_shuffle(_vec, _col + 0), \
                  intel_sub_group_shuffle(_vec, _col + 1), \
                  intel_sub_group_shuffle(_vec, _col + 2), \
                  intel_sub_group_shuffle(_vec, _col + 3), \
                  intel_sub_group_shuffle(_vec, _col + 4), \
                  intel_sub_group_shuffle(_vec, _col + 5), \
                  intel_sub_group_shuffle(_vec, _col + 6), \
                  intel_sub_group_shuffle(_vec, _col + 7) );

/* Accumulate an 8x8 product into _result: transpose the 8 rows of _blockA
   starting at lane _col, then fold each against the matching scalar of
   _blockB with mad(). */
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) \
        { \
            const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA.s0, _col ); \
            const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA.s1, _col ); \
            const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA.s2, _col ); \
            const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA.s3, _col ); \
            const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA.s4, _col ); \
            const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA.s5, _col ); \
            const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA.s6, _col ); \
            const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA.s7, _col ); \
            _result = mad( (Dtype8)_blockB.s0, acol0, _result ); \
            _result = mad( (Dtype8)_blockB.s1, acol1, _result ); \
            _result = mad( (Dtype8)_blockB.s2, acol2, _result ); \
            _result = mad( (Dtype8)_blockB.s3, acol3, _result ); \
            _result = mad( (Dtype8)_blockB.s4, acol4, _result ); \
            _result = mad( (Dtype8)_blockB.s5, acol5, _result ); \
            _result = mad( (Dtype8)_blockB.s6, acol6, _result ); \
            _result = mad( (Dtype8)_blockB.s7, acol7, _result ); \
        }
/* GEMM_TT: generator for the A-transposed, B-transposed kernels
   (TILE_M=32 rows of C per work-group, 1 column block).  Parameters:
     ALPHA1    - 1 when alpha == 1 (skips the scale in GEMM_OUTPUT)
     BETA_NOT0 - 1 when beta != 0 (C is read-modify-write)
     VECSCALAR - name tag: VEC4 / BUFFER / SCALAR B-read flavour
     VECSIZE   - elements consumed per coordB.x step, used by the loop bound
   A is always an image; B's type comes from MATB_PARAMETER, and the body
   of BLOCKB_READ8 must match it.  The half variant packs two rows per
   image texel, hence the 16-element x-step and the 8-column shuffles. */
#if TYPE == TYPE_HALF
#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \
    __read_only image2d_t A, \
    MATB_PARAMETER, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int padded_k, \
    int k, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0; \
    Dtype8 blockAxB01 = 0; \
    Dtype8 blockAxB02 = 0; \
    Dtype8 blockAxB03 = 0; \
    int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); \
    int2 coordB = (int2)( 0, ( group_x * TILE_N )); \
    const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \
    do \
    { \
        Dtype8 blockB00; \
        BLOCKB_READ8(blockB00, B, coordB); \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT;\
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K;\
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); \
    } \
    while( coordB.x < padded_k / VECSIZE ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0);\
}
#else
#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \
    __read_only image2d_t A, \
    MATB_PARAMETER, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int padded_k, \
    int k, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0.0f; \
    Dtype8 blockAxB01 = 0.0f; \
    Dtype8 blockAxB02 = 0.0f; \
    Dtype8 blockAxB03 = 0.0f; \
    int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); \
    int2 coordB = (int2)( 0, ( group_x * TILE_N )); \
    const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \
    do \
    { \
        Dtype8 blockB00; \
        BLOCKB_READ8(blockB00, B, coordB); \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \
        Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \
        Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; \
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00 , blockB00, 0 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01 , blockB00, 0 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02 , blockB00, 0 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03 , blockB00, 0 ); \
    } \
    while( coordB.x < padded_k / VECSIZE ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0);\
}
#endif
/* TT path, VEC4 image reads: two 4-wide image fetches fill the 8-element
   B block; _coordB.x advances by 2 texels. */
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        _blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordB.x += 2;

#define MATB_PARAMETER __read_only image2d_t B

GEMM_TT(1, 0, VEC4, 4) // ALPHA == 1, BETA == 0
GEMM_TT(1, 1, VEC4, 4) // ALPHA == 1, BETA != 0
GEMM_TT(0, 0, VEC4, 4) // ALPHA != 1, BETA == 0
GEMM_TT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0
#undef BLOCKB_READ8
#undef MATB_PARAMETER
/* TT path, BUFFER reads.  NOTE(review): the row stride here is the kernel
   argument `k` (B is transposed, so its leading dimension is the depth),
   while MATB_PARAMETER still declares `ldb` — `ldb` is accepted but unused
   by this flavour; confirm against the host-side launch code. */
#if TYPE == TYPE_HALF
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); \
        _blockb = as_Dtype8(as_ushort8(vload4(0, B_read))); \
        _coordB.x += TILE_K;
#else
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        const __global Dtype *B_read = (__global Dtype *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); \
        _blockb = vload8(0, B_read); \
        _coordB.x += TILE_K;
#endif

#define MATB_PARAMETER __global Dtype *B, int offB, int ldb

GEMM_TT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0
GEMM_TT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0
GEMM_TT(0, 0, BUFFER, 1) // ALPHA != 1, BETA == 0
GEMM_TT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0
#undef BLOCKB_READ8
#undef MATB_PARAMETER
/* TT path, SCALAR image reads: one element per fetch.
   Fixed: the original body hard-coded `B` instead of using the `_B`
   parameter (the NT SCALAR variant uses `_B`).  Behaviour is unchanged
   because the macro is only ever invoked with `B`, but the parameter is
   now honoured. */
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        Dtype4 temp; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s0 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s1 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s2 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s3 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s4 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s5 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s6 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s7 = temp.s0; \
        _coordB.x += 8;

#define MATB_PARAMETER __read_only image2d_t B

GEMM_TT(1, 0, SCALAR, 1) // ALPHA == 1, BETA == 0
GEMM_TT(1, 1, SCALAR, 1) // ALPHA == 1, BETA != 0
GEMM_TT(0, 0, SCALAR, 1) // ALPHA != 1, BETA == 0
GEMM_TT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0
#undef BLOCKB_READ8
#undef MATB_PARAMETER
/* End of the TT family; tear down every helper macro so later sections of
   the file (or a re-include) start clean. */
#undef MULTIPLY_BLOCKS_8x8
#undef TRANSPOSE_BLOCK_8
#undef GEMM_TT

#undef TILE_M
#undef TILE_K
#undef TILE_N
#undef SUBGROUP_BLOCK_READ8
#undef READ_IMAGE
#undef SIZE_OF_ELEMENT
/* Copy buffer A (row-major, leading dimension ldA, starting at offA) into
   image ImA, replicating each scalar across the texel's 4 channels.
   One work-item per destination pixel (gidx, gidy).
   NOTE(review): unlike the no_transpose variant below, there is no
   width/height bounds check here — presumably the host launches an exact
   width x height grid for this kernel; confirm against the caller. */
__kernel void TEMPLATE(gemm_buffer_copy_image_transpose, Dtype)(
    __global Dtype* A,
    __write_only image2d_t ImA,
    int offA,
    int width,
    int height,
    int ldA)
{
    const int gidx = get_global_id(0);
    const int gidy = get_global_id(1);
    int2 coord_dst = (int2)(gidx, gidy);
    __global Dtype* A_off = A + offA;
    Dtype srcA = A_off[gidy * ldA + gidx];
#if TYPE == TYPE_HALF
    write_imageh(ImA, coord_dst, (Dtype4)srcA);
#else
    write_imagef(ImA, coord_dst, (Dtype4)srcA);
#endif
}
/* Copy buffer A into image ImA without transposition, zero-padding any
   pixel outside the width x height source rectangle.  The non-half branch
   stores the element's raw bytes via a uchar4 reinterpret into a uint
   image, preserving the bit pattern exactly. */
__kernel void TEMPLATE(gemm_buffer_copy_image_no_transpose, Dtype)(
    __global Dtype* A,
    __write_only image2d_t ImA,
    int offA,
    int width,
    int height,
    int ldA)
{
    const int gidx = get_global_id(0);
    const int gidy = get_global_id(1);
    int2 coord_dst = (int2)(gidx, gidy);
#if TYPE == TYPE_HALF
    if (gidx >= width || gidy >= height) {
        write_imageh(ImA, coord_dst, 0);
        return;
    }
    __global Dtype* A_off = A + offA;
    write_imageh(ImA, coord_dst, A_off[gidy * ldA + gidx]);
#else
    if (gidx >= width || gidy >= height) {
        write_imageui(ImA, coord_dst, (uint4)0);
        return;
    }
    __global Dtype* A_off = A + offA;
    /* Bit-exact copy: reinterpret the element's bytes, then widen. */
    uint4 srcA = convert_uint4(as_uchar4(A_off[gidy * ldA + gidx]));
    write_imageui(ImA, coord_dst, srcA);
#endif
}