/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                           License Agreement
//                For Open Source Computer Vision Library
//
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
42
/*
 * Element-type and tiling configuration for the GEMM buffer kernels.
 *
 * The file is compiled once per element type: the host defines TYPE as
 * TYPE_FLOAT or TYPE_HALF, and the Dtype*/ /* as_Dtype* macros below expand
 * to the matching OpenCL vector types and bit-reinterpret casts.
 */

#if defined(cl_khr_fp16)
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif

/* TEMPLATE(name, type) expands to name_type, e.g. gemm_buffer_NN_float. */
#define CONCAT(A,B) A##_##B
#define TEMPLATE(name,type) CONCAT(name,type)

/* Scalar alpha/beta kernel arguments are always passed as float,
 * then cast to Dtype inside the kernels. */
#define KERNEL_ARG_DTYPE float
#define TYPE_FLOAT 1
#define TYPE_HALF 2

#if TYPE == TYPE_HALF
#define Dtype half
#define Dtype2 half2
#define Dtype4 half4
#define Dtype8 half8
#define Dtype16 half16

#define as_Dtype as_half
#define as_Dtype2 as_half2
#define as_Dtype4 as_half4
#define as_Dtype8 as_half8
#define as_Dtype16 as_half16
#else
#define Dtype float
#define Dtype2 float2
#define Dtype4 float4
#define Dtype8 float8
#define Dtype16 float16

#define as_Dtype as_float
#define as_Dtype2 as_float2
#define as_Dtype4 as_float4
#define as_Dtype8 as_float8
#define as_Dtype16 as_float16
#endif

/* Half values are reinterpreted as ushort before intel_sub_group_shuffle
 * and converted back afterwards; float values are shuffled directly. */
#if TYPE == TYPE_HALF
#define SHUFFLE_TYPE2(val) as_ushort2(val)
#define SHUFFLE_TYPE8(val) as_ushort8(val)
#define SIMD_SIZE_GEMM 16
#else
#define SHUFFLE_TYPE2(val) val
#define SHUFFLE_TYPE8(val) val
#define SIMD_SIZE_GEMM 8
#endif

#if defined(cl_intel_subgroups)
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#endif

/* Tile shape for gemm_buffer_NN: each work-item accumulates a
 * TILE_M x VEC_SIZE output patch; the half path doubles TILE_K/TILE_N
 * to match its wider (SIMD-16) subgroup. */
#define VEC_SIZE 4
#define LWG_HEIGHT 4
#define TILE_M 8
#if TYPE == TYPE_HALF
#define TILE_K 32
#define TILE_N 64
#else
#define TILE_K 16
#define TILE_N 32
#endif
/*
 * C = alpha * A * B + beta * C for row-major buffers (both inputs
 * non-transposed). A is M x K (src0), B is K x N (src1), C is M x N (dst);
 * off0/off1/offd are element offsets into the respective buffers.
 *
 * The K dimension is processed in slices of at most 256 elements per launch:
 * start_index selects the slice. On the first slice (start_index == 0) the
 * accumulators are seeded with beta * C; on later slices the partial sums
 * already stored in C are loaded unscaled and extended.
 *
 * Each work-item produces a TILE_M x VEC_SIZE (8 x 4) patch of C. Rows that
 * fall past M are redirected (via `border`) onto row 0 of the group's strip
 * so out-of-range reads stay inside the buffer; the final store phase
 * re-checks M/N bounds before writing.
 */
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, LWG_HEIGHT, 1)))
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM)))
__kernel void TEMPLATE(gemm_buffer_NN, Dtype)(
    const __global Dtype *src0, int off0,
    const __global Dtype *src1, int off1,
    __global Dtype *dst, int offd,
    int M,
    int N,
    int K,
    KERNEL_ARG_DTYPE alpha_in,
    KERNEL_ARG_DTYPE beta_in,
    int start_index)
{
    const Dtype alpha = (Dtype)alpha_in;
    const Dtype beta = (Dtype)beta_in;
    const int group_x = get_group_id(0);
    const int group_y = get_group_id(1);
    const int local_x = get_local_id(0);
    const int local_y = get_local_id(1);
    const int global_x = get_global_id(0);
    const int global_y = get_global_id(1);

    Dtype4 brow;
    Dtype2 arow0, arow1, arow2, arow3, arow4, arow5, arow6, arow7;

    /* Top-left element of this work-item's output patch. */
    __global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;

    /* Each lane reads a disjoint TILE_K/SIMD_SIZE_GEMM-wide piece of the A rows. */
    const __global Dtype *src0_read = src0 + local_x * (TILE_K / SIMD_SIZE_GEMM) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + start_index + off0;

    const __global Dtype *src1_read0 = src1 + local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1;

    /* Offset that maps an out-of-range row index back to row 0 of the strip. */
    int border = -(group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M);

    int row0 = mad24(global_y, TILE_M, 0) < M ? 0 : border;
    int row1 = mad24(global_y, TILE_M, 1) < M ? 1 : border;
    int row2 = mad24(global_y, TILE_M, 2) < M ? 2 : border;
    int row3 = mad24(global_y, TILE_M, 3) < M ? 3 : border;
    int row4 = mad24(global_y, TILE_M, 4) < M ? 4 : border;
    int row5 = mad24(global_y, TILE_M, 5) < M ? 5 : border;
    int row6 = mad24(global_y, TILE_M, 6) < M ? 6 : border;
    int row7 = mad24(global_y, TILE_M, 7) < M ? 7 : border;

    /* First K-slice: seed with beta * C.  Later slices: resume from the
     * partial sums previously written to C. */
    Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);
    Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : beta * vload4(0, dst_write0 + 1 * N);
    Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);
    Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);
    Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);
    Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);
    Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);
    Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);

    /* Process at most 256 columns of K in this launch. */
    int end_index = min(start_index + 256, K);
    int w = start_index;
    while( w + TILE_K <= end_index ) {
        /* Each lane holds 2 consecutive A elements per row, pre-scaled by alpha;
         * the subgroup shuffle below broadcasts them lane by lane. */
        arow0 = alpha * vload2(0, src0_read + row0 * K);
        arow1 = alpha * vload2(0, src0_read + row1 * K);
        arow2 = alpha * vload2(0, src0_read + row2 * K);
        arow3 = alpha * vload2(0, src0_read + row3 * K);
        arow4 = alpha * vload2(0, src0_read + row4 * K);
        arow5 = alpha * vload2(0, src0_read + row5 * K);
        arow6 = alpha * vload2(0, src0_read + row6 * K);
        arow7 = alpha * vload2(0, src0_read + row7 * K);

/* One rank-1 update: broadcast element `suffix` of lane `index`'s A pair to
 * the whole subgroup and multiply it into the current B row. */
#define MM_DOT_PRODUCT( index, suffix ) \
        brow = vload4(0, src1_read0);  src1_read0 += N; \
        dot00 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); \
        dot01 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); \
        dot02 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); \
        dot03 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); \
        dot04 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); \
        dot05 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); \
        dot06 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); \
        dot07 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 );

        /* TILE_K rank-1 updates: 2 per lane, SIMD_SIZE_GEMM lanes. */
        MM_DOT_PRODUCT(0, 0);
        MM_DOT_PRODUCT(0, 1);
        MM_DOT_PRODUCT(1, 0);
        MM_DOT_PRODUCT(1, 1);
        MM_DOT_PRODUCT(2, 0);
        MM_DOT_PRODUCT(2, 1);
        MM_DOT_PRODUCT(3, 0);
        MM_DOT_PRODUCT(3, 1);
        MM_DOT_PRODUCT(4, 0);
        MM_DOT_PRODUCT(4, 1);
        MM_DOT_PRODUCT(5, 0);
        MM_DOT_PRODUCT(5, 1);
        MM_DOT_PRODUCT(6, 0);
        MM_DOT_PRODUCT(6, 1);
        MM_DOT_PRODUCT(7, 0);
        MM_DOT_PRODUCT(7, 1);
#if TYPE == TYPE_HALF
        /* Half build runs a SIMD-16 subgroup, so lanes 8..15 contribute too. */
        MM_DOT_PRODUCT(8, 0);
        MM_DOT_PRODUCT(8, 1);
        MM_DOT_PRODUCT(9, 0);
        MM_DOT_PRODUCT(9, 1);
        MM_DOT_PRODUCT(10, 0);
        MM_DOT_PRODUCT(10, 1);
        MM_DOT_PRODUCT(11, 0);
        MM_DOT_PRODUCT(11, 1);
        MM_DOT_PRODUCT(12, 0);
        MM_DOT_PRODUCT(12, 1);
        MM_DOT_PRODUCT(13, 0);
        MM_DOT_PRODUCT(13, 1);
        MM_DOT_PRODUCT(14, 0);
        MM_DOT_PRODUCT(14, 1);
        MM_DOT_PRODUCT(15, 0);
        MM_DOT_PRODUCT(15, 1);
#endif
#undef MM_DOT_PRODUCT

        src0_read += TILE_K;
        w += TILE_K;
    }

    /* Remainder: fewer than TILE_K columns of K left in this slice.
     * Out-of-range A elements are zeroed so they cannot contribute. */
    if(w < end_index) {
        arow0.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row0 * K)[0] : 0.0f;
        arow0.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row0 * K)[1] : 0.0f;
        arow1.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row1 * K)[0] : 0.0f;
        arow1.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row1 * K)[1] : 0.0f;
        arow2.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row2 * K)[0] : 0.0f;
        arow2.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row2 * K)[1] : 0.0f;
        arow3.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row3 * K)[0] : 0.0f;
        arow3.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row3 * K)[1] : 0.0f;
        arow4.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row4 * K)[0] : 0.0f;
        arow4.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row4 * K)[1] : 0.0f;
        arow5.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row5 * K)[0] : 0.0f;
        arow5.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row5 * K)[1] : 0.0f;
        arow6.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row6 * K)[0] : 0.0f;
        arow6.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row6 * K)[1] : 0.0f;
        arow7.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row7 * K)[0] : 0.0f;
        arow7.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row7 * K)[1] : 0.0f;

/* Same rank-1 update as above, but B rows past K are replaced with zeros
 * and `w` advances one K-column per invocation. */
#define MM_DOT_PRODUCT( index, suffix ) \
        brow = (w < K) ? vload4(0, src1_read0) : (Dtype4)0.0f;  src1_read0 += N; w++; \
        dot00 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); \
        dot01 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); \
        dot02 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); \
        dot03 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); \
        dot04 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); \
        dot05 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); \
        dot06 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); \
        dot07 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 );

        MM_DOT_PRODUCT(0, 0);
        MM_DOT_PRODUCT(0, 1);
        MM_DOT_PRODUCT(1, 0);
        MM_DOT_PRODUCT(1, 1);
        MM_DOT_PRODUCT(2, 0);
        MM_DOT_PRODUCT(2, 1);
        MM_DOT_PRODUCT(3, 0);
        MM_DOT_PRODUCT(3, 1);
        MM_DOT_PRODUCT(4, 0);
        MM_DOT_PRODUCT(4, 1);
        MM_DOT_PRODUCT(5, 0);
        MM_DOT_PRODUCT(5, 1);
        MM_DOT_PRODUCT(6, 0);
        MM_DOT_PRODUCT(6, 1);
        MM_DOT_PRODUCT(7, 0);
        MM_DOT_PRODUCT(7, 1);
#if TYPE == TYPE_HALF
        MM_DOT_PRODUCT(8, 0);
        MM_DOT_PRODUCT(8, 1);
        MM_DOT_PRODUCT(9, 0);
        MM_DOT_PRODUCT(9, 1);
        MM_DOT_PRODUCT(10, 0);
        MM_DOT_PRODUCT(10, 1);
        MM_DOT_PRODUCT(11, 0);
        MM_DOT_PRODUCT(11, 1);
        MM_DOT_PRODUCT(12, 0);
        MM_DOT_PRODUCT(12, 1);
        MM_DOT_PRODUCT(13, 0);
        MM_DOT_PRODUCT(13, 1);
        MM_DOT_PRODUCT(14, 0);
        MM_DOT_PRODUCT(14, 1);
        MM_DOT_PRODUCT(15, 0);
        MM_DOT_PRODUCT(15, 1);
#endif
#undef MM_DOT_PRODUCT
    }

    /* Store phase: pick the widest store (4/3/2/1 columns) that stays
     * inside N, and stop at the first row that falls past M. */
    if(global_x * 4 < N && global_y * 8 < M) {
        if(mad24(global_x, 4, 3) < N) {
            /* All 4 columns in range: full vstore4 per row. */
            vstore4(dot00, 0, dst_write0); dst_write0 += N;
            if(mad24(global_y, 8, 1) < M) { vstore4(dot01, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 2) < M) { vstore4(dot02, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 3) < M) { vstore4(dot03, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 4) < M) { vstore4(dot04, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 5) < M) { vstore4(dot05, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 6) < M) { vstore4(dot06, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 7) < M) { vstore4(dot07, 0, dst_write0); }
        } else if(mad24(global_x, 4, 2) < N) {
            /* Only 3 columns in range: vstore2 + a scalar store per row. */
            vstore2(dot00.xy, 0, dst_write0);
            dst_write0[2] = dot00.z;
            dst_write0 += N;
            if(mad24(global_y, 8, 1) < M) {
                vstore2(dot01.xy, 0, dst_write0);
                dst_write0[2] = dot01.z;
                dst_write0 += N;
            } else
                return;
            if(mad24(global_y, 8, 2) < M) {
                vstore2(dot02.xy, 0, dst_write0);
                dst_write0[2] = dot02.z;
                dst_write0 += N;
            } else
                return;
            if(mad24(global_y, 8, 3) < M) {
                vstore2(dot03.xy, 0, dst_write0);
                dst_write0[2] = dot03.z;
                dst_write0 += N;
            } else
                return;
            if(mad24(global_y, 8, 4) < M) {
                vstore2(dot04.xy, 0, dst_write0);
                dst_write0[2] = dot04.z;
                dst_write0 += N;
            } else
                return;
            if(mad24(global_y, 8, 5) < M) {
                vstore2(dot05.xy, 0, dst_write0);
                dst_write0[2] = dot05.z;
                dst_write0 += N;
            } else
                return;
            if(mad24(global_y, 8, 6) < M) {
                vstore2(dot06.xy, 0, dst_write0);
                dst_write0[2] = dot06.z;
                dst_write0 += N;
            } else
                return;
            if(mad24(global_y, 8, 7) < M) {
                vstore2(dot07.xy, 0, dst_write0);
                dst_write0[2] = dot07.z;
            }
        } else if(mad24(global_x, 4, 1) < N) {
            /* Only 2 columns in range. */
            vstore2(dot00.xy, 0, dst_write0); dst_write0 += N;
            if(mad24(global_y, 8, 1) < M) { vstore2(dot01.xy, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 2) < M) { vstore2(dot02.xy, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 3) < M) { vstore2(dot03.xy, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 4) < M) { vstore2(dot04.xy, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 5) < M) { vstore2(dot05.xy, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 6) < M) { vstore2(dot06.xy, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 7) < M) { vstore2(dot07.xy, 0, dst_write0); }
        } else {
            /* Single column in range: scalar stores. */
            dst_write0[0] = dot00.x; dst_write0 += N;
            if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 2) < M) { dst_write0[0] = dot02.x; dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 3) < M) { dst_write0[0] = dot03.x; dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 4) < M) { dst_write0[0] = dot04.x; dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 5) < M) { dst_write0[0] = dot05.x; dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 6) < M) { dst_write0[0] = dot06.x; dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 7) < M) { dst_write0[0] = dot07.x; }
        }
    }
}
#undef VEC_SIZE
#undef LWG_HEIGHT
#undef TILE_M
#undef TILE_K
#undef TILE_N

/* Re-tile for gemm_buffer_NT: one output element per work-item (VEC_SIZE 1),
 * 8x8 tiles, with B rows staged through a SLM_BLOCK-wide local-memory strip.
 * The half path halves LWG_HEIGHT and doubles TILE_K. */
#define VEC_SIZE 1
#define TILE_M 8
#define TILE_N 8
#define SLM_BLOCK 128

#if TYPE == TYPE_HALF
#define LWG_HEIGHT 2
#define TILE_K 64
#else
#define LWG_HEIGHT 4
#define TILE_K 32
#endif
#if TYPE == TYPE_HALF
/*
 * C = alpha * A * B^T + beta * C, half-precision variant.
 * A is M x K (src0), B is N x K (src1, accessed row-wise so the product is
 * with B transposed), C is M x N (dst).
 *
 * B rows are staged through local memory in SLM_BLOCK-wide strips. Half data
 * is moved as float4 reinterpreted loads/stores (8 halves at a time); note
 * this assumes K is suitably aligned for the float-typed accesses -
 * TODO(review): confirm the host enforces that.
 *
 * Each work-item accumulates 8 partial Dtype8 sums; a subgroup-shuffle
 * reduction then combines the 8 lanes before the scalar writeback.
 */
__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))
__attribute__((intel_reqd_sub_group_size(8)))
__kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
    const __global Dtype *src0, int off0,
    const __global Dtype *src1, int off1,
    __global Dtype *dst, int offd,
    int M,
    int N,
    int K,
    KERNEL_ARG_DTYPE alpha_in,
    KERNEL_ARG_DTYPE beta_in)
{
    const Dtype alpha = (Dtype)alpha_in;
    const Dtype beta = (Dtype)beta_in;
    const int group_x = get_group_id(0);
    const int group_y = get_group_id(1);
    const int local_x = get_local_id(0);
    const int local_y = get_local_id(1);
    const int global_x = get_global_id(0);
    const int global_y = get_global_id(1);

    Dtype8 dot00 = 0.f;
    Dtype8 dot01 = 0.f;
    Dtype8 dot02 = 0.f;
    Dtype8 dot03 = 0.f;
    Dtype8 dot04 = 0.f;
    Dtype8 dot05 = 0.f;
    Dtype8 dot06 = 0.f;
    Dtype8 dot07 = 0.f;

    Dtype8 brow0;
    Dtype8 brow1;
    Dtype8 brow2;
    Dtype8 brow3;
    Dtype8 brow4;
    Dtype8 brow5;
    Dtype8 brow6;
    Dtype8 brow7;

    __global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;

    const __global Dtype *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0;

    const __global Dtype *src1_read0 = src1 + (group_x * TILE_N) * K + off1;

    /* 8 rows of B staged SLM_BLOCK elements at a time. */
    __local Dtype slm_brow[8 * SLM_BLOCK];
    __local Dtype* slm_brow0;

    /* Each work-item copies 8 halves (one float4) per B row into local memory. */
    int local_index = mad24(local_y, 8, local_x) * 8;
    int w;
    for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) {
        barrier(CLK_LOCAL_MEM_FENCE);
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(0, K, local_index))), 0, (__local float *)(slm_brow + mad24(0, SLM_BLOCK, local_index)));
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(1, K, local_index))), 0, (__local float *)(slm_brow + mad24(1, SLM_BLOCK, local_index)));
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(2, K, local_index))), 0, (__local float *)(slm_brow + mad24(2, SLM_BLOCK, local_index)));
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(3, K, local_index))), 0, (__local float *)(slm_brow + mad24(3, SLM_BLOCK, local_index)));
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(4, K, local_index))), 0, (__local float *)(slm_brow + mad24(4, SLM_BLOCK, local_index)));
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(5, K, local_index))), 0, (__local float *)(slm_brow + mad24(5, SLM_BLOCK, local_index)));
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(6, K, local_index))), 0, (__local float *)(slm_brow + mad24(6, SLM_BLOCK, local_index)));
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(7, K, local_index))), 0, (__local float *)(slm_brow + mad24(7, SLM_BLOCK, local_index)));
        barrier(CLK_LOCAL_MEM_FENCE);

        /* Each lane of the SIMD-8 subgroup works on its own TILE_K/8 slice. */
        slm_brow0 = slm_brow + local_x * (TILE_K / 8);
        w = b_tile;
        int end_w = min(b_tile + SLM_BLOCK, K);
        while( w + TILE_K <= end_w ) {
            Dtype8 arow;

            brow0 = as_half8(vload4(0, (__local float *)(slm_brow0 + 0 * SLM_BLOCK)));
            brow1 = as_half8(vload4(0, (__local float *)(slm_brow0 + 1 * SLM_BLOCK)));
            brow2 = as_half8(vload4(0, (__local float *)(slm_brow0 + 2 * SLM_BLOCK)));
            brow3 = as_half8(vload4(0, (__local float *)(slm_brow0 + 3 * SLM_BLOCK)));
            brow4 = as_half8(vload4(0, (__local float *)(slm_brow0 + 4 * SLM_BLOCK)));
            brow5 = as_half8(vload4(0, (__local float *)(slm_brow0 + 5 * SLM_BLOCK)));
            brow6 = as_half8(vload4(0, (__local float *)(slm_brow0 + 6 * SLM_BLOCK)));
            brow7 = as_half8(vload4(0, (__local float *)(slm_brow0 + 7 * SLM_BLOCK)));

/* Accumulate 8 K-elements of A row `_row` against the 8 staged B rows. */
#define MM_DOT_PRODUCT( _row, _dot ) \
            arow = as_half8(vload4(0, (__global float *)(src0_read + _row * K))); \
            _dot = mad( (Dtype8)(arow.s0), (Dtype8)(brow0.s0, brow1.s0, brow2.s0, brow3.s0, brow4.s0, brow5.s0, brow6.s0, brow7.s0), _dot ); \
            _dot = mad( (Dtype8)(arow.s1), (Dtype8)(brow0.s1, brow1.s1, brow2.s1, brow3.s1, brow4.s1, brow5.s1, brow6.s1, brow7.s1), _dot ); \
            _dot = mad( (Dtype8)(arow.s2), (Dtype8)(brow0.s2, brow1.s2, brow2.s2, brow3.s2, brow4.s2, brow5.s2, brow6.s2, brow7.s2), _dot ); \
            _dot = mad( (Dtype8)(arow.s3), (Dtype8)(brow0.s3, brow1.s3, brow2.s3, brow3.s3, brow4.s3, brow5.s3, brow6.s3, brow7.s3), _dot ); \
            _dot = mad( (Dtype8)(arow.s4), (Dtype8)(brow0.s4, brow1.s4, brow2.s4, brow3.s4, brow4.s4, brow5.s4, brow6.s4, brow7.s4), _dot ); \
            _dot = mad( (Dtype8)(arow.s5), (Dtype8)(brow0.s5, brow1.s5, brow2.s5, brow3.s5, brow4.s5, brow5.s5, brow6.s5, brow7.s5), _dot ); \
            _dot = mad( (Dtype8)(arow.s6), (Dtype8)(brow0.s6, brow1.s6, brow2.s6, brow3.s6, brow4.s6, brow5.s6, brow6.s6, brow7.s6), _dot ); \
            _dot = mad( (Dtype8)(arow.s7), (Dtype8)(brow0.s7, brow1.s7, brow2.s7, brow3.s7, brow4.s7, brow5.s7, brow6.s7, brow7.s7), _dot );

            MM_DOT_PRODUCT( 0, dot00 );
            MM_DOT_PRODUCT( 1, dot01 );
            MM_DOT_PRODUCT( 2, dot02 );
            MM_DOT_PRODUCT( 3, dot03 );
            MM_DOT_PRODUCT( 4, dot04 );
            MM_DOT_PRODUCT( 5, dot05 );
            MM_DOT_PRODUCT( 6, dot06 );
            MM_DOT_PRODUCT( 7, dot07 );
#undef MM_DOT_PRODUCT

            src0_read += TILE_K;
            slm_brow0 += TILE_K;
            w += TILE_K;
        }
        src1_read0 += SLM_BLOCK;
    }

    /* K remainder: mask elements past K to zero before accumulating. */
    if(w < K) {
        Dtype8 arow;

#define READ_BROW(_brow, _row) \
        _brow = as_half8(vload4(0, (__local float *)(slm_brow0 + _row * SLM_BLOCK))); \
        _brow.s0 = (mad24(local_x, 8, w) < K) ? _brow.s0 : 0.0f; \
        _brow.s1 = (mad24(local_x, 8, w + 1) < K) ? _brow.s1 : 0.0f; \
        _brow.s2 = (mad24(local_x, 8, w + 2) < K) ? _brow.s2 : 0.0f; \
        _brow.s3 = (mad24(local_x, 8, w + 3) < K) ? _brow.s3 : 0.0f; \
        _brow.s4 = (mad24(local_x, 8, w + 4) < K) ? _brow.s4 : 0.0f; \
        _brow.s5 = (mad24(local_x, 8, w + 5) < K) ? _brow.s5 : 0.0f; \
        _brow.s6 = (mad24(local_x, 8, w + 6) < K) ? _brow.s6 : 0.0f; \
        _brow.s7 = (mad24(local_x, 8, w + 7) < K) ? _brow.s7 : 0.0f;

        READ_BROW(brow0, 0);
        READ_BROW(brow1, 1);
        READ_BROW(brow2, 2);
        READ_BROW(brow3, 3);
        READ_BROW(brow4, 4);
        READ_BROW(brow5, 5);
        READ_BROW(brow6, 6);
        READ_BROW(brow7, 7);

#undef READ_BROW

#define MM_DOT_PRODUCT( _row, _dot ) \
        arow = as_half8(vload4(0, (__global float *)(src0_read + _row * K))); \
        arow.s0 = (mad24(local_x, 8, w) < K) ? arow.s0 : 0.0f; \
        arow.s1 = (mad24(local_x, 8, w + 1) < K) ? arow.s1 : 0.0f; \
        arow.s2 = (mad24(local_x, 8, w + 2) < K) ? arow.s2 : 0.0f; \
        arow.s3 = (mad24(local_x, 8, w + 3) < K) ? arow.s3 : 0.0f; \
        arow.s4 = (mad24(local_x, 8, w + 4) < K) ? arow.s4 : 0.0f; \
        arow.s5 = (mad24(local_x, 8, w + 5) < K) ? arow.s5 : 0.0f; \
        arow.s6 = (mad24(local_x, 8, w + 6) < K) ? arow.s6 : 0.0f; \
        arow.s7 = (mad24(local_x, 8, w + 7) < K) ? arow.s7 : 0.0f; \
        _dot = mad( (Dtype8)(arow.s0), (Dtype8)(brow0.s0, brow1.s0, brow2.s0, brow3.s0, brow4.s0, brow5.s0, brow6.s0, brow7.s0), _dot ); \
        _dot = mad( (Dtype8)(arow.s1), (Dtype8)(brow0.s1, brow1.s1, brow2.s1, brow3.s1, brow4.s1, brow5.s1, brow6.s1, brow7.s1), _dot ); \
        _dot = mad( (Dtype8)(arow.s2), (Dtype8)(brow0.s2, brow1.s2, brow2.s2, brow3.s2, brow4.s2, brow5.s2, brow6.s2, brow7.s2), _dot ); \
        _dot = mad( (Dtype8)(arow.s3), (Dtype8)(brow0.s3, brow1.s3, brow2.s3, brow3.s3, brow4.s3, brow5.s3, brow6.s3, brow7.s3), _dot ); \
        _dot = mad( (Dtype8)(arow.s4), (Dtype8)(brow0.s4, brow1.s4, brow2.s4, brow3.s4, brow4.s4, brow5.s4, brow6.s4, brow7.s4), _dot ); \
        _dot = mad( (Dtype8)(arow.s5), (Dtype8)(brow0.s5, brow1.s5, brow2.s5, brow3.s5, brow4.s5, brow5.s5, brow6.s5, brow7.s5), _dot ); \
        _dot = mad( (Dtype8)(arow.s6), (Dtype8)(brow0.s6, brow1.s6, brow2.s6, brow3.s6, brow4.s6, brow5.s6, brow6.s6, brow7.s6), _dot ); \
        _dot = mad( (Dtype8)(arow.s7), (Dtype8)(brow0.s7, brow1.s7, brow2.s7, brow3.s7, brow4.s7, brow5.s7, brow6.s7, brow7.s7), _dot );

        MM_DOT_PRODUCT( 0, dot00 );
        MM_DOT_PRODUCT( 1, dot01 );
        MM_DOT_PRODUCT( 2, dot02 );
        MM_DOT_PRODUCT( 3, dot03 );
        MM_DOT_PRODUCT( 4, dot04 );
        MM_DOT_PRODUCT( 5, dot05 );
        MM_DOT_PRODUCT( 6, dot06 );
        MM_DOT_PRODUCT( 7, dot07 );
#undef MM_DOT_PRODUCT
    }

/* Sum each accumulator across all 8 subgroup lanes. */
#define REDUCE(_dot) \
    _dot = as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) + \
           as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7));

    REDUCE(dot00);
    REDUCE(dot01);
    REDUCE(dot02);
    REDUCE(dot03);
    REDUCE(dot04);
    REDUCE(dot05);
    REDUCE(dot06);
    REDUCE(dot07);
#undef REDUCE

    /* Each lane writes the accumulator component matching its lane id:
     * dst = alpha * sum + beta * dst. */
    Dtype output = 0.0f;
#define OUTPUT( _dot ) \
    output = (local_x == 0) ? _dot.s0 : output; \
    output = (local_x == 1) ? _dot.s1 : output; \
    output = (local_x == 2) ? _dot.s2 : output; \
    output = (local_x == 3) ? _dot.s3 : output; \
    output = (local_x == 4) ? _dot.s4 : output; \
    output = (local_x == 5) ? _dot.s5 : output; \
    output = (local_x == 6) ? _dot.s6 : output; \
    output = (local_x == 7) ? _dot.s7 : output; \
    dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \
    dst_write0 += N;

    if(global_x < N && global_y * 8 < M) {
        OUTPUT(dot00);
        if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); }
        if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); }
        if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); }
        if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); }
        if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); }
        if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); }
        if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); }
    }
#undef OUTPUT
}
#else

/*
 * C = alpha * A * B^T + beta * C, single-precision variant.
 * A is M x K (src0), B is N x K (src1, accessed row-wise so the product is
 * with B transposed), C is M x N (dst).
 *
 * Same structure as the half variant: B rows are staged through local
 * memory in SLM_BLOCK-wide strips, each work-item accumulates 8 partial
 * Dtype8 sums, and a subgroup-shuffle reduction combines the 8 lanes
 * before the scalar writeback.
 */
__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))
__attribute__((intel_reqd_sub_group_size(8)))
__kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
    const __global Dtype *src0, int off0,
    const __global Dtype *src1, int off1,
    __global Dtype *dst, int offd,
    int M,
    int N,
    int K,
    KERNEL_ARG_DTYPE alpha_in,
    KERNEL_ARG_DTYPE beta_in)
{
    const Dtype alpha = (Dtype)alpha_in;
    const Dtype beta = (Dtype)beta_in;
    const int group_x = get_group_id(0);
    const int group_y = get_group_id(1);
    const int local_x = get_local_id(0);
    const int local_y = get_local_id(1);
    const int global_x = get_global_id(0);
    const int global_y = get_global_id(1);

    Dtype8 dot00 = 0.f;
    Dtype8 dot01 = 0.f;
    Dtype8 dot02 = 0.f;
    Dtype8 dot03 = 0.f;
    Dtype8 dot04 = 0.f;
    Dtype8 dot05 = 0.f;
    Dtype8 dot06 = 0.f;
    Dtype8 dot07 = 0.f;

    Dtype4 brow0;
    Dtype4 brow1;
    Dtype4 brow2;
    Dtype4 brow3;
    Dtype4 brow4;
    Dtype4 brow5;
    Dtype4 brow6;
    Dtype4 brow7;

    __global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;

    const __global Dtype *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0;

    const __global Dtype *src1_read0 = src1 + (group_x * TILE_N) * K + off1;

    /* 8 rows of B staged SLM_BLOCK elements at a time. */
    __local Dtype slm_brow[8 * SLM_BLOCK];
    __local Dtype* slm_brow0;

    /* Each work-item copies 4 floats per B row into local memory. */
    int local_index = mad24(local_y, 8, local_x) * 4;
    int w;
    for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) {
        barrier(CLK_LOCAL_MEM_FENCE);
        vstore4(vload4(0, src1_read0 + mad24(0, K, local_index)), 0, slm_brow + mad24(0, SLM_BLOCK, local_index));
        vstore4(vload4(0, src1_read0 + mad24(1, K, local_index)), 0, slm_brow + mad24(1, SLM_BLOCK, local_index));
        vstore4(vload4(0, src1_read0 + mad24(2, K, local_index)), 0, slm_brow + mad24(2, SLM_BLOCK, local_index));
        vstore4(vload4(0, src1_read0 + mad24(3, K, local_index)), 0, slm_brow + mad24(3, SLM_BLOCK, local_index));
        vstore4(vload4(0, src1_read0 + mad24(4, K, local_index)), 0, slm_brow + mad24(4, SLM_BLOCK, local_index));
        vstore4(vload4(0, src1_read0 + mad24(5, K, local_index)), 0, slm_brow + mad24(5, SLM_BLOCK, local_index));
        vstore4(vload4(0, src1_read0 + mad24(6, K, local_index)), 0, slm_brow + mad24(6, SLM_BLOCK, local_index));
        vstore4(vload4(0, src1_read0 + mad24(7, K, local_index)), 0, slm_brow + mad24(7, SLM_BLOCK, local_index));
        barrier(CLK_LOCAL_MEM_FENCE);

        /* Each lane of the SIMD-8 subgroup works on its own TILE_K/8 slice. */
        slm_brow0 = slm_brow + local_x * (TILE_K / 8);
        w = b_tile;
        int end_w = min(b_tile + SLM_BLOCK, K);
        while( w + TILE_K <= end_w ) {
            Dtype4 arow;

            brow0 = vload4(0, slm_brow0 + 0 * SLM_BLOCK);
            brow1 = vload4(0, slm_brow0 + 1 * SLM_BLOCK);
            brow2 = vload4(0, slm_brow0 + 2 * SLM_BLOCK);
            brow3 = vload4(0, slm_brow0 + 3 * SLM_BLOCK);
            brow4 = vload4(0, slm_brow0 + 4 * SLM_BLOCK);
            brow5 = vload4(0, slm_brow0 + 5 * SLM_BLOCK);
            brow6 = vload4(0, slm_brow0 + 6 * SLM_BLOCK);
            brow7 = vload4(0, slm_brow0 + 7 * SLM_BLOCK);

/* Accumulate 4 K-elements of A row `_row` against the 8 staged B rows. */
#define MM_DOT_PRODUCT( _row, _dot ) \
            arow = vload4(0, src0_read + _row * K); \
            _dot = mad( (Dtype8)(arow.x), (Dtype8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \
            _dot = mad( (Dtype8)(arow.y), (Dtype8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \
            _dot = mad( (Dtype8)(arow.z), (Dtype8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \
            _dot = mad( (Dtype8)(arow.w), (Dtype8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );

            MM_DOT_PRODUCT( 0, dot00 );
            MM_DOT_PRODUCT( 1, dot01 );
            MM_DOT_PRODUCT( 2, dot02 );
            MM_DOT_PRODUCT( 3, dot03 );
            MM_DOT_PRODUCT( 4, dot04 );
            MM_DOT_PRODUCT( 5, dot05 );
            MM_DOT_PRODUCT( 6, dot06 );
            MM_DOT_PRODUCT( 7, dot07 );
#undef MM_DOT_PRODUCT

            src0_read += TILE_K;
            slm_brow0 += TILE_K;
            w += TILE_K;
        }
        src1_read0 += SLM_BLOCK;
    }

    /* K remainder: mask elements past K to zero before accumulating. */
    if(w < K) {
        Dtype4 arow;

#define READ_BROW(_brow, _row) \
        _brow = vload4(0, slm_brow0 + _row * SLM_BLOCK); \
        _brow.x = (mad24(local_x, 4, w) < K) ? _brow.x : 0.0f; \
        _brow.y = (mad24(local_x, 4, w + 1) < K) ? _brow.y : 0.0f; \
        _brow.z = (mad24(local_x, 4, w + 2) < K) ? _brow.z : 0.0f; \
        _brow.w = (mad24(local_x, 4, w + 3) < K) ? _brow.w : 0.0f;

        READ_BROW(brow0, 0);
        READ_BROW(brow1, 1);
        READ_BROW(brow2, 2);
        READ_BROW(brow3, 3);
        READ_BROW(brow4, 4);
        READ_BROW(brow5, 5);
        READ_BROW(brow6, 6);
        READ_BROW(brow7, 7);

#undef READ_BROW

#define MM_DOT_PRODUCT( _row, _dot ) \
        arow = vload4(0, src0_read + _row * K); \
        arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; \
        arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; \
        arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; \
        arow.w = (mad24(local_x, 4, w + 3) < K) ? arow.w : 0.0f; \
        _dot = mad( (Dtype8)(arow.x), (Dtype8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \
        _dot = mad( (Dtype8)(arow.y), (Dtype8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \
        _dot = mad( (Dtype8)(arow.z), (Dtype8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \
        _dot = mad( (Dtype8)(arow.w), (Dtype8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );

        MM_DOT_PRODUCT( 0, dot00 );
        MM_DOT_PRODUCT( 1, dot01 );
        MM_DOT_PRODUCT( 2, dot02 );
        MM_DOT_PRODUCT( 3, dot03 );
        MM_DOT_PRODUCT( 4, dot04 );
        MM_DOT_PRODUCT( 5, dot05 );
        MM_DOT_PRODUCT( 6, dot06 );
        MM_DOT_PRODUCT( 7, dot07 );
#undef MM_DOT_PRODUCT
    }

/* Sum each accumulator across all 8 subgroup lanes. */
#define REDUCE(_dot) \
    _dot = as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) + \
           as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7));

    REDUCE(dot00);
    REDUCE(dot01);
    REDUCE(dot02);
    REDUCE(dot03);
    REDUCE(dot04);
    REDUCE(dot05);
    REDUCE(dot06);
    REDUCE(dot07);
#undef REDUCE

    /* Each lane writes the accumulator component matching its lane id:
     * dst = alpha * sum + beta * dst. */
    Dtype output = 0.0f;
#define OUTPUT( _dot ) \
    output = (local_x == 0) ? _dot.s0 : output; \
    output = (local_x == 1) ? _dot.s1 : output; \
    output = (local_x == 2) ? _dot.s2 : output; \
    output = (local_x == 3) ? _dot.s3 : output; \
    output = (local_x == 4) ? _dot.s4 : output; \
    output = (local_x == 5) ? _dot.s5 : output; \
    output = (local_x == 6) ? _dot.s6 : output; \
    output = (local_x == 7) ? _dot.s7 : output; \
    dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \
    dst_write0 += N;

    if(global_x < N && global_y * 8 < M) {
        OUTPUT(dot00);
        if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); }
        if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); }
        if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); }
        if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); }
        if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); }
        if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); }
        if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); }
    }
#undef OUTPUT
}
#endif
785
#undef VEC_SIZE
786
#undef LWG_HEIGHT
787
#undef TILE_M
788
#undef TILE_K
789
#undef TILE_N
790
#undef SLM_BLOCK
791
792
#define SLM_SIZE 64
793
// Helper for gemm_buffer_NT_M_2: handles the right-edge column group where
// fewer than 4 output columns remain (rows = N % 4, in 1..3). Computes two
// rows of C (dstc0/dstc1) against `rows` columns of B using a work-group
// wide local-memory tree reduction.
void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(
    const __global Dtype* srca_read0,   // row 0 of A (length K)
    const __global Dtype* srca_read1,   // row 1 of A (length K)
    const __global Dtype* srcb_read,    // this group's columns of B (row-major, row stride K)
    __local Dtype4* work0,              // per-lane partial sums for C row 0
    __local Dtype4* work1,              // per-lane partial sums for C row 1
    int N,
    int K,
    int x_gid,                          // group index along N/4
    int lid,                            // local lane id
    Dtype alpha,
    Dtype beta,
    __global Dtype* dstc0,              // output row 0 of C
    __global Dtype* dstc1)              // output row 1 of C
{
    // Scalar views of the local scratch (4 scalars per lane).
    __local Dtype* work_each0 = (__local Dtype*)work0;
    __local Dtype* work_each1 = (__local Dtype*)work1;

    // Number of valid output columns in this edge group (1..3).
    int rows = N - x_gid * 4;

    Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
    Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};

    // Each lane strides over K in chunks of 4 elements.
    int i = lid;
    while( i < K / 4) {
        const Dtype4 b0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};
        const Dtype4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};
#pragma unroll
        for(int j = 0; j < rows; ++j) {
            dot0[j] += b0 * vload4(i, srcb_read + j * K);
            dot1[j] += b1 * vload4(i, srcb_read + j * K);
        }

        i += get_local_size(0);
    }
    // Horizontal add of each per-lane 4-wide accumulator into local memory.
#pragma unroll
    for(int j = 0; j < rows; ++j) {
        work_each0[lid * 4 + j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;
        work_each1[lid * 4 + j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;
    }

    // Exactly one lane ends its stride at the last full 4-chunk; it
    // accumulates the K % 4 tail elements scalar-wise.
    if(i == K / 4) {
        short tail_items = K % 4;

        if(tail_items != 0) {
            const __global Dtype *srcb_tail = srcb_read + i * 4;
            const __global Dtype *srca_tail0 = srca_read0 + i * 4;
            const __global Dtype *srca_tail1 = srca_read1 + i * 4;
#pragma unroll
            for(short i = 0; i < tail_items; ++i) {
                const Dtype at0 = srca_tail0[i];
                const Dtype at1 = srca_tail1[i];
#pragma unroll
                for(int j = 0; j < rows; ++j) {
                    work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K];
                    work_each1[lid * 4 + j] += at1 * srcb_tail[i + j * K];
                }
            }
        }
    }

    // Pairwise tree reduction across the work-group; totals land in work*[0].
    for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {
        barrier(CLK_LOCAL_MEM_FENCE);
        if(lid < stride) {
            work0[lid] += work0[lid+stride];
            work1[lid] += work1[lid+stride];
        }
    }

    // Lane 0 wrote work*[0] itself in the last reduction step, so no extra
    // barrier is needed before it applies alpha/beta and stores the results.
    if(lid == 0) {
#pragma unroll
        for(int j = 0; j < rows; ++j) {
            dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
            dstc1[(x_gid * 4 + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
        }
    }
}
870
871
// GEMM  C = alpha * A * B^T + beta * C  specialized for M == 2 (two rows of
// A and C). Each work-group produces a 2x4 tile of C; the final, possibly
// partial column group (when N % 4 != 0) is delegated to the edge-rows
// helper.
__kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)(
    __global const Dtype * A,
    int offA,
    __global const Dtype * B,
    int offB,
    __global Dtype * C,
    int offC,
    int M,
    int N,
    int K,
    KERNEL_ARG_DTYPE alpha_f,
    KERNEL_ARG_DTYPE beta_f)
{
    Dtype alpha = (Dtype)alpha_f;
    Dtype beta = (Dtype)beta_f;
    int x_gid = get_group_id(0);
    int lid = get_local_id(0);

    // The two rows of A, stored back to back with stride K.
    const __global Dtype *srca_read0 = A + offA;
    const __global Dtype *srca_read1 = srca_read0 + K;

    // Start of this group's 4 columns of B (B is row-major with stride K).
    const __global Dtype *srcb_read = B + x_gid * 4 * K + offB;

    // Vector views of the two output rows of C.
    __global Dtype4 *dstc0 = (__global Dtype4*)(C + offC);
    __global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N);

    // Per-lane partial sums (4 scalars per lane) shared across the group.
    __local Dtype4 work0[SLM_SIZE];
    __local Dtype4 work1[SLM_SIZE];
    __local Dtype* work_each0 = (__local Dtype*)work0;
    __local Dtype* work_each1 = (__local Dtype*)work1;

    if(x_gid == N / 4) {
        // Partial tile at the right edge: fewer than 4 columns remain.
        TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype) \
        (srca_read0, srca_read1, srcb_read, work0, work1, N, K, x_gid, lid, alpha, beta, (__global Dtype*)dstc0, (__global Dtype*)dstc1);
    } else {
        Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
        Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
        // Each lane strides over K in chunks of 4 elements.
        int i = lid;
        while( i < K / 4) {
            const Dtype4 b0 = vload4(i, srca_read0);
            const Dtype4 b1 = vload4(i, srca_read1);
#pragma unroll
            for(int j = 0; j < 4; ++j) {
                // One B chunk is shared by both A rows.
                Dtype4 a = vload4(i, srcb_read + j * K);
                dot0[j] += b0 * a;
                dot1[j] += b1 * a;
            }
            i += get_local_size(0);
        }

        // Horizontal add of each per-lane accumulator into local memory.
#pragma unroll
        for(int j = 0; j < 4; ++j) {
            work_each0[lid * 4 + j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;
            work_each1[lid * 4 + j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;
        }

        // Exactly one lane ends its stride at the last full 4-chunk; it
        // accumulates the K % 4 tail elements scalar-wise.
        if(i == K / 4) {
            short tail_items = K % 4;
            if(tail_items != 0) {
                const __global Dtype *srcb_tail = srcb_read + i * 4;

                const __global Dtype *srca_tail0 = srca_read0 + i * 4;
                const __global Dtype *srca_tail1 = srca_read1 + i * 4;
#pragma unroll
                for(short i = 0; i < tail_items; ++i) {
                    const Dtype at0 = srca_tail0[i];
                    const Dtype at1 = srca_tail1[i];
#pragma unroll
                    for(int j = 0; j < 4; ++j) {
                        work_each0[lid * 4 + j] += at0 * srcb_tail[i + j * K];
                        work_each1[lid * 4 + j] += at1 * srcb_tail[i + j * K];
                    }
                }
            }
        }

        // Pairwise tree reduction; totals land in work*[0].
        for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {
            barrier(CLK_LOCAL_MEM_FENCE);
            if(lid < stride) {
                work0[lid] += work0[lid+stride];
                work1[lid] += work1[lid+stride];
            }
        }

        // Lane 0 applies alpha/beta and writes the 4-wide tile of each row.
        if(lid == 0) {
            dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
            dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
        }
    }
}
961
#undef SLM_SIZE
962
963
#define SLM_SIZE 32
964
// Helper for gemm_buffer_NT_M_4: handles the right-edge column group where
// fewer than 4 output columns remain (rows = N % 4, in 1..3). Computes four
// rows of C against `rows` columns of B using a work-group wide
// local-memory tree reduction.
void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)(
    const __global Dtype* srca_read0,   // row 0 of A (length K)
    const __global Dtype* srca_read1,   // row 1 of A
    const __global Dtype* srca_read2,   // row 2 of A
    const __global Dtype* srca_read3,   // row 3 of A
    const __global Dtype* srcb_read,    // this group's columns of B (row-major, row stride K)
    __local Dtype4* work0,              // per-lane partial sums for C row 0
    __local Dtype4* work1,              // per-lane partial sums for C row 1
    __local Dtype4* work2,              // per-lane partial sums for C row 2
    __local Dtype4* work3,              // per-lane partial sums for C row 3
    int N,
    int K,
    int x_gid,                          // group index along N/4
    int lid,                            // local lane id
    Dtype alpha,
    Dtype beta,
    __global Dtype* dstc0,              // output row 0 of C
    __global Dtype* dstc1,              // output row 1 of C
    __global Dtype* dstc2,              // output row 2 of C
    __global Dtype* dstc3)              // output row 3 of C
{
    // Per-lane scalar views of the local scratch (4 scalars per lane).
    __local Dtype* work_each0 = (__local Dtype*)(work0 + lid);
    __local Dtype* work_each1 = (__local Dtype*)(work1 + lid);
    __local Dtype* work_each2 = (__local Dtype*)(work2 + lid);
    __local Dtype* work_each3 = (__local Dtype*)(work3 + lid);

    // Number of valid output columns in this edge group (1..3).
    int rows = N - x_gid * 4;

    Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
    Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
    Dtype4 dot2[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
    Dtype4 dot3[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};

    // Each lane strides over K in chunks of 4 elements.
    int i = lid;
    while( i < K / 4) {
        const Dtype4 a0 = vload4(i, srca_read0);
        const Dtype4 a1 = vload4(i, srca_read1);
        const Dtype4 a2 = vload4(i, srca_read2);
        const Dtype4 a3 = vload4(i, srca_read3);
        // BUGFIX: was "#pragma unrol" (misspelled), which silently disabled
        // the unrolling hint.
#pragma unroll
        for(int j = 0; j < rows; ++j) {
            // Load the B chunk once and reuse it for all four A rows
            // (matches the non-edge path in gemm_buffer_NT_M_4).
            const Dtype4 b = vload4(i, srcb_read + j * K);
            dot0[j] += a0 * b;
            dot1[j] += a1 * b;
            dot2[j] += a2 * b;
            dot3[j] += a3 * b;
        }

        i += get_local_size(0);
    }
    // Horizontal add of each per-lane accumulator into local memory.
#pragma unroll
    for(int j = 0; j < rows; ++j) {
        work_each0[j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;
        work_each1[j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;
        work_each2[j] = dot2[j].x + dot2[j].y + dot2[j].z + dot2[j].w;
        work_each3[j] = dot3[j].x + dot3[j].y + dot3[j].z + dot3[j].w;
    }

    // Exactly one lane ends its stride at the last full 4-chunk; it
    // accumulates the K % 4 tail elements scalar-wise.
    if(i == K / 4) {
        short tail_items = K % 4;

        if(tail_items != 0) {
            const __global Dtype *srcb_tail = srcb_read + i * 4;

            const __global Dtype *srca_tail0 = srca_read0 + i * 4;
            const __global Dtype *srca_tail1 = srca_read1 + i * 4;
            const __global Dtype *srca_tail2 = srca_read2 + i * 4;
            const __global Dtype *srca_tail3 = srca_read3 + i * 4;
#pragma unroll
            for(short i = 0; i < tail_items; ++i) {
                const Dtype at0 = srca_tail0[i];
                const Dtype at1 = srca_tail1[i];
                const Dtype at2 = srca_tail2[i];
                const Dtype at3 = srca_tail3[i];
#pragma unroll
                for(int j = 0; j < rows; ++j) {
                    work_each0[j] += at0 * srcb_tail[i + j * K];
                    work_each1[j] += at1 * srcb_tail[i + j * K];
                    work_each2[j] += at2 * srcb_tail[i + j * K];
                    work_each3[j] += at3 * srcb_tail[i + j * K];
                }
            }
        }
    }

    // Pairwise tree reduction across the work-group; totals land in work*[0].
    for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {
        barrier(CLK_LOCAL_MEM_FENCE);
        if(lid < stride) {
            work0[lid] += work0[lid+stride];
            work1[lid] += work1[lid+stride];
            work2[lid] += work2[lid+stride];
            work3[lid] += work3[lid+stride];
        }
    }

    // Lane 0 wrote work*[0] itself in the last reduction step (and
    // work_each* == work*[0] for lid == 0), so it can safely apply
    // alpha/beta and store the valid columns.
    if(lid == 0) {
#pragma unroll
        for(int j = 0; j < rows; ++j) {
            dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
            dstc1[(x_gid * 4 + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
            dstc2[(x_gid * 4 + j)] = alpha * work_each2[j] + beta * dstc2[(x_gid * 4 + j)];
            dstc3[(x_gid * 4 + j)] = alpha * work_each3[j] + beta * dstc3[(x_gid * 4 + j)];
        }
    }
}
1068
1069
// GEMM  C = alpha * A * B^T + beta * C  specialized for M == 4 (four rows
// of A and C). Each work-group produces a 4x4 tile of C; the final,
// possibly partial column group (when N % 4 != 0) is delegated to the
// edge-rows helper.
__kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)(
    __global const Dtype * A,
    int offA,
    __global const Dtype * B,
    int offB,
    __global Dtype * C,
    int offC,
    int M,
    int N,
    int K,
    KERNEL_ARG_DTYPE alpha_f,
    KERNEL_ARG_DTYPE beta_f)
{
    Dtype alpha = (Dtype)alpha_f;
    Dtype beta = (Dtype)beta_f;
    int x_gid = get_group_id(0);
    int lid = get_local_id(0);
    int lsize = get_local_size(0);

    // The four rows of A, stored back to back with stride K.
    const __global Dtype *srca_read0 = A + offA;
    const __global Dtype *srca_read1 = srca_read0 + K;
    const __global Dtype *srca_read2 = srca_read1 + K;
    const __global Dtype *srca_read3 = srca_read2 + K;

    // Start of this group's 4 columns of B (B is row-major with stride K).
    const __global Dtype *srcb_read = B + x_gid * 4 * K + offB;

    // Vector views of the four output rows of C.
    __global Dtype4 *dstc0 = (__global Dtype4*)(C + offC);
    __global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N);
    __global Dtype4 *dstc2 = (__global Dtype4*)((__global Dtype*)(dstc1) + N);
    __global Dtype4 *dstc3 = (__global Dtype4*)((__global Dtype*)(dstc2) + N);

    // Per-lane partial sums (4 scalars per lane) shared across the group.
    __local Dtype4 work0[SLM_SIZE];
    __local Dtype4 work1[SLM_SIZE];
    __local Dtype4 work2[SLM_SIZE];
    __local Dtype4 work3[SLM_SIZE];
    __local Dtype* work_each0 = (__local Dtype*)(work0 + lid);
    __local Dtype* work_each1 = (__local Dtype*)(work1 + lid);
    __local Dtype* work_each2 = (__local Dtype*)(work2 + lid);
    __local Dtype* work_each3 = (__local Dtype*)(work3 + lid);

    if(x_gid == N / 4) {
        // Partial tile at the right edge: fewer than 4 columns remain.
        TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype) \
        (srca_read0, srca_read1, srca_read2, srca_read3, srcb_read, \
        work0, work1, work2, work3, N, K, x_gid, lid, alpha, beta, \
        (__global Dtype*)dstc0, (__global Dtype*)dstc1, (__global Dtype*)dstc2, (__global Dtype*)dstc3);
    } else {
        Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
        Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
        Dtype4 dot2[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
        Dtype4 dot3[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};

        // Each lane strides over K in chunks of 4 elements.
        int kid = lid;
        while( kid < K / 4) {
            const Dtype4 b0 = vload4(kid, srca_read0);
            const Dtype4 b1 = vload4(kid, srca_read1);
            const Dtype4 b2 = vload4(kid, srca_read2);
            const Dtype4 b3 = vload4(kid, srca_read3);
#pragma unroll
            for(int j = 0; j < 4; ++j) {
                // One B chunk is shared by all four A rows.
                Dtype4 a = vload4(kid, srcb_read + j * K);
                dot0[j] += b0 * a;
                dot1[j] += b1 * a;
                dot2[j] += b2 * a;
                dot3[j] += b3 * a;
            }
            kid += lsize;
        }
        // Horizontal add of each per-lane accumulator into local memory.
#pragma unroll
        for(int j = 0; j < 4; ++j) {
            work_each0[j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;
            work_each1[j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;
            work_each2[j] = dot2[j].x + dot2[j].y + dot2[j].z + dot2[j].w;
            work_each3[j] = dot3[j].x + dot3[j].y + dot3[j].z + dot3[j].w;
        }

        // Exactly one lane ends its stride at the last full 4-chunk; it
        // accumulates the K % 4 tail elements scalar-wise.
        if(kid == (K >> 2)) {
            short tail_items = K % 4;
            if(tail_items != 0) {
                int offset = kid << 2;
                const __global Dtype *srcb_tail = srcb_read + offset;

                const __global Dtype *srca_tail0 = srca_read0 + offset;
                const __global Dtype *srca_tail1 = srca_read1 + offset;
                const __global Dtype *srca_tail2 = srca_read2 + offset;
                const __global Dtype *srca_tail3 = srca_read3 + offset;
#pragma unroll
                for(short i = 0; i < tail_items; ++i) {
                    const Dtype at0 = srca_tail0[i];
                    const Dtype at1 = srca_tail1[i];
                    const Dtype at2 = srca_tail2[i];
                    const Dtype at3 = srca_tail3[i];
#pragma unroll
                    for(int j = 0; j < 4; ++j) {
                        work_each0[j] += at0 * srcb_tail[i + j * K];
                        work_each1[j] += at1 * srcb_tail[i + j * K];
                        work_each2[j] += at2 * srcb_tail[i + j * K];
                        work_each3[j] += at3 * srcb_tail[i + j * K];
                    }
                }
            }
        }

        // Pairwise tree reduction; totals land in work*[0].
        for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {
            barrier(CLK_LOCAL_MEM_FENCE);
            if(lid < stride) {
                work0[lid] += work0[lid+stride];
                work1[lid] += work1[lid+stride];
                work2[lid] += work2[lid+stride];
                work3[lid] += work3[lid+stride];
            }
        }

        // Lane 0 applies alpha/beta and writes the 4-wide tile of each row.
        if(lid == 0) {
            dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
            dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
            dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid];
            dstc3[x_gid] = alpha * work3[0] + beta * dstc3[x_gid];
        }
    }
}
1189
#undef SLM_SIZE
1190
1191
#define SLM_SIZE 16
1192
// GEMM  C = alpha * A * B^T + beta * C  specialized for M == 8: every
// work-group owns a single column of B (index = group id), reduces it
// against all eight rows of A, and emits one scalar per row of C.
__kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)(
    __global const Dtype * A,
    int offA,
    __global const Dtype * B,
    int offB,
    __global Dtype * C,
    int offC,
    int M,
    int N,
    int K,
    KERNEL_ARG_DTYPE alpha_f,
    KERNEL_ARG_DTYPE beta_f)
{
    const Dtype alpha = (Dtype)alpha_f;
    const Dtype beta = (Dtype)beta_f;
    const int grp = get_group_id(0);
    const int lane = get_local_id(0);
    const int nlanes = get_local_size(0);

    // Rows 0..7 of A, stored back to back with stride K.
    const __global Dtype *arow0 = A + offA;
    const __global Dtype *arow1 = arow0 + K;
    const __global Dtype *arow2 = arow1 + K;
    const __global Dtype *arow3 = arow2 + K;
    const __global Dtype *arow4 = arow3 + K;
    const __global Dtype *arow5 = arow4 + K;
    const __global Dtype *arow6 = arow5 + K;
    const __global Dtype *arow7 = arow6 + K;

    // This group's column of B (B is row-major with stride K).
    const __global Dtype *bcol = B + grp * K + offB;

    // One output scalar per row of C.
    __global Dtype *out0 = C + offC;
    __global Dtype *out1 = out0 + N;
    __global Dtype *out2 = out1 + N;
    __global Dtype *out3 = out2 + N;
    __global Dtype *out4 = out3 + N;
    __global Dtype *out5 = out4 + N;
    __global Dtype *out6 = out5 + N;
    __global Dtype *out7 = out6 + N;

    // Per-lane partial sums, tree-reduced across the group at the end.
    __local Dtype sums0[SLM_SIZE];
    __local Dtype sums1[SLM_SIZE];
    __local Dtype sums2[SLM_SIZE];
    __local Dtype sums3[SLM_SIZE];
    __local Dtype sums4[SLM_SIZE];
    __local Dtype sums5[SLM_SIZE];
    __local Dtype sums6[SLM_SIZE];
    __local Dtype sums7[SLM_SIZE];

    Dtype4 acc0 = (Dtype4)(0.);
    Dtype4 acc1 = (Dtype4)(0.);
    Dtype4 acc2 = (Dtype4)(0.);
    Dtype4 acc3 = (Dtype4)(0.);
    Dtype4 acc4 = (Dtype4)(0.);
    Dtype4 acc5 = (Dtype4)(0.);
    Dtype4 acc6 = (Dtype4)(0.);
    Dtype4 acc7 = (Dtype4)(0.);

    // Lanes stride across K in 4-element chunks; each chunk of B is loaded
    // once and multiplied against the matching chunk of every A row.
    int k4 = lane;
    for(; k4 < K / 4; k4 += nlanes) {
        const Dtype4 bq = vload4(k4, bcol);
        acc0 += vload4(k4, arow0) * bq;
        acc1 += vload4(k4, arow1) * bq;
        acc2 += vload4(k4, arow2) * bq;
        acc3 += vload4(k4, arow3) * bq;
        acc4 += vload4(k4, arow4) * bq;
        acc5 += vload4(k4, arow5) * bq;
        acc6 += vload4(k4, arow6) * bq;
        acc7 += vload4(k4, arow7) * bq;
    }
    // Horizontal add of each vector accumulator into this lane's slot.
    sums0[lane] = acc0.x + acc0.y + acc0.z + acc0.w;
    sums1[lane] = acc1.x + acc1.y + acc1.z + acc1.w;
    sums2[lane] = acc2.x + acc2.y + acc2.z + acc2.w;
    sums3[lane] = acc3.x + acc3.y + acc3.z + acc3.w;
    sums4[lane] = acc4.x + acc4.y + acc4.z + acc4.w;
    sums5[lane] = acc5.x + acc5.y + acc5.z + acc5.w;
    sums6[lane] = acc6.x + acc6.y + acc6.z + acc6.w;
    sums7[lane] = acc7.x + acc7.y + acc7.z + acc7.w;

    // Exactly one lane stops its stride at the last full 4-chunk; it picks
    // up the K % 4 tail elements scalar-wise.
    if(k4 == (K >> 2)) {
        const short rem = K % 4;
        if(rem != 0) {
            const int base = k4 << 2;
            const __global Dtype *btail = bcol + base;

            const __global Dtype *atail0 = arow0 + base;
            const __global Dtype *atail1 = arow1 + base;
            const __global Dtype *atail2 = arow2 + base;
            const __global Dtype *atail3 = arow3 + base;
            const __global Dtype *atail4 = arow4 + base;
            const __global Dtype *atail5 = arow5 + base;
            const __global Dtype *atail6 = arow6 + base;
            const __global Dtype *atail7 = arow7 + base;
#pragma unroll
            for(short t = 0; t < rem; ++t) {
                const Dtype bt = btail[t];
                sums0[lane] += atail0[t] * bt;
                sums1[lane] += atail1[t] * bt;
                sums2[lane] += atail2[t] * bt;
                sums3[lane] += atail3[t] * bt;
                sums4[lane] += atail4[t] * bt;
                sums5[lane] += atail5[t] * bt;
                sums6[lane] += atail6[t] * bt;
                sums7[lane] += atail7[t] * bt;
            }
        }
    }

    // Pairwise tree reduction across the group; totals land in sums*[0].
    for(int stride = nlanes >> 1; stride > 0; stride >>= 1) {
        barrier(CLK_LOCAL_MEM_FENCE);
        if(lane < stride) {
            sums0[lane] += sums0[lane + stride];
            sums1[lane] += sums1[lane + stride];
            sums2[lane] += sums2[lane + stride];
            sums3[lane] += sums3[lane + stride];
            sums4[lane] += sums4[lane + stride];
            sums5[lane] += sums5[lane + stride];
            sums6[lane] += sums6[lane + stride];
            sums7[lane] += sums7[lane + stride];
        }
    }

    // Lane 0 applies alpha/beta and stores one scalar per output row.
    if(lane == 0) {
        out0[grp] = alpha * sums0[0] + beta * out0[grp];
        out1[grp] = alpha * sums1[0] + beta * out1[grp];
        out2[grp] = alpha * sums2[0] + beta * out2[grp];
        out3[grp] = alpha * sums3[0] + beta * out3[grp];
        out4[grp] = alpha * sums4[0] + beta * out4[grp];
        out5[grp] = alpha * sums5[0] + beta * out5[grp];
        out6[grp] = alpha * sums6[0] + beta * out6[grp];
        out7[grp] = alpha * sums7[0] + beta * out7[grp];
    }
}
1333
#undef SLM_SIZE
1334
1335
#undef VEC_SIZE
1336
#undef LWG_HEIGHT
1337
#undef TILE_M
1338
#undef TILE_K
1339
#undef TILE_N
1340
#undef SIMD_SIZE_GEMM
1341
#undef SHUFFLE_TYPE2
1342
#undef SHUFFLE_TYPE8
1343
1344