// Path: blob/master/modules/dnn/src/opencl/gemm_buffer.cl
// 16337 views
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
// License Agreement
// For Open Source Computer Vision Library
//
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
// * Redistribution's of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
//
// * Redistribution's in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// * The name of the copyright holders may not be used to endorse or promote products
// derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if 
advised of the possibility of such damage.
//
//M*/

#if defined(cl_khr_fp16)
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif

// TEMPLATE(name, type) appends the element type to a kernel name,
// e.g. TEMPLATE(gemm_buffer_NN, float) -> gemm_buffer_NN_float.
#define CONCAT(A,B) A##_##B
#define TEMPLATE(name,type) CONCAT(name,type)

// alpha/beta kernel arguments are always declared as float and converted to
// Dtype inside each kernel.
#define KERNEL_ARG_DTYPE float
#define TYPE_FLOAT 1
#define TYPE_HALF 2

// Element type selection: half when TYPE == TYPE_HALF, float otherwise.
// NOTE(review): TYPE is not defined in this file; presumably it is passed as a
// build option by the host -- confirm.
#if TYPE == TYPE_HALF
#define Dtype half
#define Dtype2 half2
#define Dtype4 half4
#define Dtype8 half8
#define Dtype16 half16

#define as_Dtype as_half
#define as_Dtype2 as_half2
#define as_Dtype4 as_half4
#define as_Dtype8 as_half8
#define as_Dtype16 as_half16
#else
#define Dtype float
#define Dtype2 float2
#define Dtype4 float4
#define Dtype8 float8
#define Dtype16 float16

#define as_Dtype as_float
#define as_Dtype2 as_float2
#define as_Dtype4 as_float4
#define as_Dtype8 as_float8
#define as_Dtype16 as_float16
#endif

// Half data is passed through intel_sub_group_shuffle reinterpreted as ushort
// vectors of the same width (and reinterpreted back with as_Dtype*); float data
// is shuffled directly.
#if TYPE == TYPE_HALF
#define SHUFFLE_TYPE2(val) as_ushort2(val)
#define SHUFFLE_TYPE8(val) as_ushort8(val)
#define SIMD_SIZE_GEMM 16
#else
#define SHUFFLE_TYPE2(val) val
#define SHUFFLE_TYPE8(val) val
#define SIMD_SIZE_GEMM 8
#endif

#if defined(cl_intel_subgroups)
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#endif

// Tiling for gemm_buffer_NN: each work-item owns a TILE_M x VEC_SIZE (8 x 4)
// tile of dst; a SIMD_SIZE_GEMM x LWG_HEIGHT work-group therefore covers
// LWG_HEIGHT * TILE_M rows and TILE_N (= SIMD_SIZE_GEMM * VEC_SIZE) columns.
// TILE_K is the K step per loop iteration (2 elements per sub-group lane).
#define VEC_SIZE 4
#define LWG_HEIGHT 4
#define TILE_M 8
#if TYPE == TYPE_HALF
#define TILE_K 32
#define TILE_N 64
#else
#define TILE_K 16
#define TILE_N 32
#endif

/*
 * gemm_buffer_NN: dst = alpha * src0 * src1 + beta * dst.
 *
 * All matrices are addressed row-major: src0 is M x K (row stride K), src1 is
 * K x N (row stride N), dst is M x N (row stride N). off0/off1/offd are element
 * offsets into the respective buffers.
 *
 * One launch accumulates only the K indices [start_index,
 * min(start_index + 256, K)). When start_index == 0 the existing dst tile is
 * scaled by beta first; otherwise the partial product is added to dst
 * unscaled. Presumably the host enqueues this kernel once per 256-wide slice
 * of K -- confirm on the host side.
 *
 * alpha is folded into the src0 values as they are loaded. Each lane loads
 * TILE_K / SIMD_SIZE_GEMM (= 2) consecutive K elements from each of its 8 src0
 * rows. intel_sub_group_shuffle then broadcasts lane `index`'s pair to every
 * lane, which multiplies it by its own 4 src1 columns.
 */
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, LWG_HEIGHT, 1)))
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM)))
__kernel void TEMPLATE(gemm_buffer_NN, Dtype)(
    const __global Dtype *src0, int off0,
    const __global Dtype *src1, int off1,
    __global Dtype *dst, int offd,
    int M,
    int N,
    int K,
    KERNEL_ARG_DTYPE alpha_in,
    KERNEL_ARG_DTYPE beta_in,
    int start_index)
{
    const Dtype alpha = (Dtype)alpha_in;
    const Dtype beta = (Dtype)beta_in;
    const int group_x = get_group_id(0);
    const int group_y = get_group_id(1);
    const int local_x = get_local_id(0);
    const int local_y = get_local_id(1);
    const int global_x = get_global_id(0);
    const int global_y = get_global_id(1);

    Dtype4 brow;
    Dtype2 arow0, arow1, arow2, arow3, arow4, arow5, arow6, arow7;

    // Top-left element of this work-item's 8 x 4 dst tile.
    __global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;

    // First of the 2 src0 elements this lane loads from row 0 of its tile.
    const __global Dtype *src0_read = src0 + local_x * (TILE_K / SIMD_SIZE_GEMM) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + start_index + off0;

    // This work-item's 4 src1 columns, starting at K index start_index.
    const __global Dtype *src1_read0 = src1 + local_x * VEC_SIZE + (group_x * TILE_N) + start_index * N + off1;

    // Tile rows at or beyond M use `border` instead of their row index.
    // border cancels the tile's row base, so their src0 loads land on row 0
    // of src0 (inside the matrix). Their results are never stored.
    int border = -(group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M);

    int row0 = mad24(global_y, TILE_M, 0) < M ? 0 : border;
    int row1 = mad24(global_y, TILE_M, 1) < M ? 1 : border;
    int row2 = mad24(global_y, TILE_M, 2) < M ? 2 : border;
    int row3 = mad24(global_y, TILE_M, 3) < M ? 3 : border;
    int row4 = mad24(global_y, TILE_M, 4) < M ? 4 : border;
    int row5 = mad24(global_y, TILE_M, 5) < M ? 5 : border;
    int row6 = mad24(global_y, TILE_M, 6) < M ? 6 : border;
    int row7 = mad24(global_y, TILE_M, 7) < M ? 7 : border;

    // Accumulators start from the current dst tile, pre-scaled by beta only
    // on the first K slice.
    // NOTE(review): these loads are not guarded against M/N, so edge tiles read
    // past the last valid row/column of dst. Presumably the host pads dst or
    // otherwise guarantees this is safe -- confirm.
    Dtype4 dot00 = (start_index != 0) ? vload4(0, dst_write0) : beta * vload4(0, dst_write0);
    Dtype4 dot01 = (start_index != 0) ? vload4(0, dst_write0 + 1 * N) : beta * vload4(0, dst_write0 + 1 * N);
    Dtype4 dot02 = (start_index != 0) ? vload4(0, dst_write0 + 2 * N) : beta * vload4(0, dst_write0 + 2 * N);
    Dtype4 dot03 = (start_index != 0) ? vload4(0, dst_write0 + 3 * N) : beta * vload4(0, dst_write0 + 3 * N);
    Dtype4 dot04 = (start_index != 0) ? vload4(0, dst_write0 + 4 * N) : beta * vload4(0, dst_write0 + 4 * N);
    Dtype4 dot05 = (start_index != 0) ? vload4(0, dst_write0 + 5 * N) : beta * vload4(0, dst_write0 + 5 * N);
    Dtype4 dot06 = (start_index != 0) ? vload4(0, dst_write0 + 6 * N) : beta * vload4(0, dst_write0 + 6 * N);
    Dtype4 dot07 = (start_index != 0) ? vload4(0, dst_write0 + 7 * N) : beta * vload4(0, dst_write0 + 7 * N);

    int end_index = min(start_index + 256, K);
    int w = start_index;

    // Full TILE_K steps: no per-element K bounds checks are needed.
    while( w + TILE_K <= end_index ) {
        arow0 = alpha * vload2(0, src0_read + row0 * K);
        arow1 = alpha * vload2(0, src0_read + row1 * K);
        arow2 = alpha * vload2(0, src0_read + row2 * K);
        arow3 = alpha * vload2(0, src0_read + row3 * K);
        arow4 = alpha * vload2(0, src0_read + row4 * K);
        arow5 = alpha * vload2(0, src0_read + row5 * K);
        arow6 = alpha * vload2(0, src0_read + row6 * K);
        arow7 = alpha * vload2(0, src0_read + row7 * K);

        // One K index per expansion: reads the next src1 row, takes element
        // `suffix` of lane `index`'s src0 pair for every tile row, and
        // accumulates it times the 4 src1 columns.
#define MM_DOT_PRODUCT( index, suffix ) \
        brow = vload4(0, src1_read0); src1_read0 += N; \
        dot00 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); \
        dot01 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); \
        dot02 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); \
        dot03 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); \
        dot04 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); \
        dot05 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); \
        dot06 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); \
        dot07 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 );

        MM_DOT_PRODUCT(0, 0);
        MM_DOT_PRODUCT(0, 1);
        MM_DOT_PRODUCT(1, 0);
        MM_DOT_PRODUCT(1, 1);
        MM_DOT_PRODUCT(2, 0);
        MM_DOT_PRODUCT(2, 1);
        MM_DOT_PRODUCT(3, 0);
        MM_DOT_PRODUCT(3, 1);
        MM_DOT_PRODUCT(4, 0);
        MM_DOT_PRODUCT(4, 1);
        MM_DOT_PRODUCT(5, 0);
        MM_DOT_PRODUCT(5, 1);
        MM_DOT_PRODUCT(6, 0);
        MM_DOT_PRODUCT(6, 1);
        MM_DOT_PRODUCT(7, 0);
        MM_DOT_PRODUCT(7, 1);
#if TYPE == TYPE_HALF
        MM_DOT_PRODUCT(8, 0);
        MM_DOT_PRODUCT(8, 1);
        MM_DOT_PRODUCT(9, 0);
        MM_DOT_PRODUCT(9, 1);
        MM_DOT_PRODUCT(10, 0);
        MM_DOT_PRODUCT(10, 1);
        MM_DOT_PRODUCT(11, 0);
        MM_DOT_PRODUCT(11, 1);
        MM_DOT_PRODUCT(12, 0);
        MM_DOT_PRODUCT(12, 1);
        MM_DOT_PRODUCT(13, 0);
        MM_DOT_PRODUCT(13, 1);
        MM_DOT_PRODUCT(14, 0);
        MM_DOT_PRODUCT(14, 1);
        MM_DOT_PRODUCT(15, 0);
        MM_DOT_PRODUCT(15, 1);
#endif
#undef MM_DOT_PRODUCT

        src0_read += TILE_K;
        w += TILE_K;
    }

    // Partial last step (fewer than TILE_K indices remain before end_index).
    // src0 elements at or past K become 0 and src1 rows at or past K are
    // replaced with zeros, so the padding contributes nothing.
    if(w < end_index) {
        arow0.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row0 * K)[0] : 0.0f;
        arow0.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row0 * K)[1] : 0.0f;
        arow1.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row1 * K)[0] : 0.0f;
        arow1.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row1 * K)[1] : 0.0f;
        arow2.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row2 * K)[0] : 0.0f;
        arow2.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row2 * K)[1] : 0.0f;
        arow3.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row3 * K)[0] : 0.0f;
        arow3.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row3 * K)[1] : 0.0f;
        arow4.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row4 * K)[0] : 0.0f;
        arow4.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row4 * K)[1] : 0.0f;
        arow5.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row5 * K)[0] : 0.0f;
        arow5.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row5 * K)[1] : 0.0f;
        arow6.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row6 * K)[0] : 0.0f;
        arow6.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row6 * K)[1] : 0.0f;
        arow7.x = ((w + local_x * 2) < K) ? alpha * (src0_read + row7 * K)[0] : 0.0f;
        arow7.y = ((w + local_x * 2 + 1) < K) ? alpha * (src0_read + row7 * K)[1] : 0.0f;

        // Same as the main-loop step, but src1 rows at or beyond K read as 0.
#define MM_DOT_PRODUCT( index, suffix ) \
        brow = (w < K) ? vload4(0, src1_read0) : (Dtype4)0.0f; src1_read0 += N; w++; \
        dot00 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow0), index )).s##suffix), brow, dot00 ); \
        dot01 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow1), index )).s##suffix), brow, dot01 ); \
        dot02 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow2), index )).s##suffix), brow, dot02 ); \
        dot03 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow3), index )).s##suffix), brow, dot03 ); \
        dot04 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow4), index )).s##suffix), brow, dot04 ); \
        dot05 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow5), index )).s##suffix), brow, dot05 ); \
        dot06 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow6), index )).s##suffix), brow, dot06 ); \
        dot07 = mad( (Dtype4)(as_Dtype2(intel_sub_group_shuffle( SHUFFLE_TYPE2(arow7), index )).s##suffix), brow, dot07 );

        MM_DOT_PRODUCT(0, 0);
        MM_DOT_PRODUCT(0, 1);
        MM_DOT_PRODUCT(1, 0);
        MM_DOT_PRODUCT(1, 1);
        MM_DOT_PRODUCT(2, 0);
        MM_DOT_PRODUCT(2, 1);
        MM_DOT_PRODUCT(3, 0);
        MM_DOT_PRODUCT(3, 1);
        MM_DOT_PRODUCT(4, 0);
        MM_DOT_PRODUCT(4, 1);
        MM_DOT_PRODUCT(5, 0);
        MM_DOT_PRODUCT(5, 1);
        MM_DOT_PRODUCT(6, 0);
        MM_DOT_PRODUCT(6, 1);
        MM_DOT_PRODUCT(7, 0);
        MM_DOT_PRODUCT(7, 1);
#if TYPE == TYPE_HALF
        MM_DOT_PRODUCT(8, 0);
        MM_DOT_PRODUCT(8, 1);
        MM_DOT_PRODUCT(9, 0);
        MM_DOT_PRODUCT(9, 1);
        MM_DOT_PRODUCT(10, 0);
        MM_DOT_PRODUCT(10, 1);
        MM_DOT_PRODUCT(11, 0);
        MM_DOT_PRODUCT(11, 1);
        MM_DOT_PRODUCT(12, 0);
        MM_DOT_PRODUCT(12, 1);
        MM_DOT_PRODUCT(13, 0);
        MM_DOT_PRODUCT(13, 1);
        MM_DOT_PRODUCT(14, 0);
        MM_DOT_PRODUCT(14, 1);
        MM_DOT_PRODUCT(15, 0);
        MM_DOT_PRODUCT(15, 1);
#endif
#undef MM_DOT_PRODUCT
    }

    // Write back only the rows < M and columns < N of this 8 x 4 tile. The
    // number of valid columns (4, 3, 2 or 1) selects the store width. Rows are
    // consecutive, so the first out-of-range row ends the work-item.
    if(global_x * 4 < N && global_y * 8 < M) {
        if(mad24(global_x, 4, 3) < N) {
            vstore4(dot00, 0, dst_write0); dst_write0 += N;
            if(mad24(global_y, 8, 1) < M) { vstore4(dot01, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 2) < M) { vstore4(dot02, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 3) < M) { vstore4(dot03, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 4) < M) { vstore4(dot04, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 5) < M) { vstore4(dot05, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 6) < M) { vstore4(dot06, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 7) < M) { vstore4(dot07, 0, dst_write0); }
        } else if(mad24(global_x, 4, 2) < N) {
            vstore2(dot00.xy, 0, dst_write0);
            dst_write0[2] = dot00.z;
            dst_write0 += N;
            if(mad24(global_y, 8, 1) < M) {
                vstore2(dot01.xy, 0, dst_write0);
                dst_write0[2] = dot01.z;
                dst_write0 += N;
            } else
                return;
            if(mad24(global_y, 8, 2) < M) {
                vstore2(dot02.xy, 0, dst_write0);
                dst_write0[2] = dot02.z;
                dst_write0 += N;
            } else
                return;
            if(mad24(global_y, 8, 3) < M) {
                vstore2(dot03.xy, 0, dst_write0);
                dst_write0[2] = dot03.z;
                dst_write0 += N;
            } else
                return;
            if(mad24(global_y, 8, 4) < M) {
                vstore2(dot04.xy, 0, dst_write0);
                dst_write0[2] = dot04.z;
                dst_write0 += N;
            } else
                return;
            if(mad24(global_y, 8, 5) < M) {
                vstore2(dot05.xy, 0, dst_write0);
                dst_write0[2] = dot05.z;
                dst_write0 += N;
            } else
                return;
            if(mad24(global_y, 8, 6) < M) {
                vstore2(dot06.xy, 0, dst_write0);
                dst_write0[2] = dot06.z;
                dst_write0 += N;
            } else
                return;
            if(mad24(global_y, 8, 7) < M) {
                vstore2(dot07.xy, 0, dst_write0);
                dst_write0[2] = dot07.z;
            }
        } else if(mad24(global_x, 4, 1) < N) {
            vstore2(dot00.xy, 0, dst_write0); dst_write0 += N;
            if(mad24(global_y, 8, 1) < M) { vstore2(dot01.xy, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 2) < M) { vstore2(dot02.xy, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 3) < M) { vstore2(dot03.xy, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 4) < M) { vstore2(dot04.xy, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 5) < M) { vstore2(dot05.xy, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 6) < M) { vstore2(dot06.xy, 0, dst_write0); dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 7) < M) { vstore2(dot07.xy, 0, dst_write0); }
        } else {
            dst_write0[0] = dot00.x; dst_write0 += N;
            if(mad24(global_y, 8, 1) < M) { dst_write0[0] = dot01.x; dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 2) < M) { dst_write0[0] = dot02.x; dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 3) < M) { dst_write0[0] = dot03.x; dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 4) < M) { dst_write0[0] = dot04.x; dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 5) < M) { dst_write0[0] = dot05.x; dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 6) < M) { dst_write0[0] = dot06.x; dst_write0 += N; }
            else return;
            if(mad24(global_y, 8, 7) < M) { dst_write0[0] = dot07.x; }
        }
    }
}

#undef VEC_SIZE
#undef LWG_HEIGHT
#undef TILE_M
#undef TILE_K
#undef TILE_N

// Tiling for gemm_buffer_NT: a work-group is 8 x LWG_HEIGHT work-items covering
// LWG_HEIGHT * TILE_M rows and TILE_N (8) columns of dst. Each work-item stores
// VEC_SIZE (1) column. src1 is staged through local memory SLM_BLOCK elements
// of K at a time, and each sub-group lane handles TILE_K / 8 of every TILE_K
// step.
#define VEC_SIZE 1
#define TILE_M 8
#define TILE_N 8
#define SLM_BLOCK 128

#if TYPE == TYPE_HALF
#define LWG_HEIGHT 2
#define TILE_K 64
#else
#define LWG_HEIGHT 4
#define TILE_K 32
#endif

#if TYPE == TYPE_HALF
/*
 * gemm_buffer_NT (half variant): dst = alpha * src0 * src1^T + beta * dst.
 *
 * src0 is M x K and src1 is N x K, both row-major with row stride K. dst is
 * M x N with row stride N. off0/off1/offd are element offsets.
 *
 * The 8 src1 rows needed by the work-group are copied to local memory
 * SLM_BLOCK elements of K at a time. Each lane accumulates, for its 8 dst rows,
 * partial dot products against all 8 staged src1 rows over its own 8-element
 * slice of each TILE_K step. The partials are summed across the sub-group with
 * shuffles, and lane local_x stores column local_x as
 * alpha * sum + beta * dst.
 *
 * Halves are moved 8 at a time by reinterpreting them as float4
 * (vload4 on a float pointer, then as_half8).
 * NOTE(review): this assumes the src0/src1 addresses used are 4-byte aligned
 * (even K, off0, off1) -- confirm on the host side.
 */
__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))
__attribute__((intel_reqd_sub_group_size(8)))
__kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
    const __global Dtype *src0, int off0,
    const __global Dtype *src1, int off1,
    __global Dtype *dst, int offd,
    int M,
    int N,
    int K,
    KERNEL_ARG_DTYPE alpha_in,
    KERNEL_ARG_DTYPE beta_in)
{
    const Dtype alpha = (Dtype)alpha_in;
    const Dtype beta = (Dtype)beta_in;
    const int group_x = get_group_id(0);
    const int group_y = get_group_id(1);
    const int local_x = get_local_id(0);
    const int local_y = get_local_id(1);
    const int global_x = get_global_id(0);
    const int global_y = get_global_id(1);

    // dotRR.sC: this lane's partial sum for tile row RR against staged column C.
    Dtype8 dot00 = 0.f;
    Dtype8 dot01 = 0.f;
    Dtype8 dot02 = 0.f;
    Dtype8 dot03 = 0.f;
    Dtype8 dot04 = 0.f;
    Dtype8 dot05 = 0.f;
    Dtype8 dot06 = 0.f;
    Dtype8 dot07 = 0.f;

    Dtype8 brow0;
    Dtype8 brow1;
    Dtype8 brow2;
    Dtype8 brow3;
    Dtype8 brow4;
    Dtype8 brow5;
    Dtype8 brow6;
    Dtype8 brow7;

    __global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;

    const __global Dtype *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0;

    // First of the 8 src1 rows (dst columns) handled by this work-group.
    const __global Dtype *src1_read0 = src1 + (group_x * TILE_N) * K + off1;

    __local Dtype slm_brow[8 * SLM_BLOCK];
    __local Dtype* slm_brow0;

    // Each work-item copies 8 consecutive elements of every staged row.
    int local_index = mad24(local_y, 8, local_x) * 8;
    // Initialised so the tail check below is well-defined even when the K loop
    // does not run at all (K <= 0).
    int w = 0;
    for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) {
        // The leading barrier keeps this copy from overwriting slm_brow while
        // other work-items may still be reading the previous block.
        // NOTE(review): the copy reads a full SLM_BLOCK of each of the 8 rows
        // without checking K or N; presumably src1 is padded -- confirm.
        barrier(CLK_LOCAL_MEM_FENCE);
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(0, K, local_index))), 0, (__local float *)(slm_brow + mad24(0, SLM_BLOCK, local_index)));
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(1, K, local_index))), 0, (__local float *)(slm_brow + mad24(1, SLM_BLOCK, local_index)));
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(2, K, local_index))), 0, (__local float *)(slm_brow + mad24(2, SLM_BLOCK, local_index)));
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(3, K, local_index))), 0, (__local float *)(slm_brow + mad24(3, SLM_BLOCK, local_index)));
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(4, K, local_index))), 0, (__local float *)(slm_brow + mad24(4, SLM_BLOCK, local_index)));
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(5, K, local_index))), 0, (__local float *)(slm_brow + mad24(5, SLM_BLOCK, local_index)));
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(6, K, local_index))), 0, (__local float *)(slm_brow + mad24(6, SLM_BLOCK, local_index)));
        vstore4(vload4(0, (__global float *)(src1_read0 + mad24(7, K, local_index))), 0, (__local float *)(slm_brow + mad24(7, SLM_BLOCK, local_index)));
        barrier(CLK_LOCAL_MEM_FENCE);

        slm_brow0 = slm_brow + local_x * (TILE_K / 8);
        w = b_tile;
        int end_w = min(b_tile + SLM_BLOCK, K);
        while( w + TILE_K <= end_w ) {
            Dtype8 arow;

            brow0 = as_half8(vload4(0, (__local float *)(slm_brow0 + 0 * SLM_BLOCK)));
            brow1 = as_half8(vload4(0, (__local float *)(slm_brow0 + 1 * SLM_BLOCK)));
            brow2 = as_half8(vload4(0, (__local float *)(slm_brow0 + 2 * SLM_BLOCK)));
            brow3 = as_half8(vload4(0, (__local float *)(slm_brow0 + 3 * SLM_BLOCK)));
            brow4 = as_half8(vload4(0, (__local float *)(slm_brow0 + 4 * SLM_BLOCK)));
            brow5 = as_half8(vload4(0, (__local float *)(slm_brow0 + 5 * SLM_BLOCK)));
            brow6 = as_half8(vload4(0, (__local float *)(slm_brow0 + 6 * SLM_BLOCK)));
            brow7 = as_half8(vload4(0, (__local float *)(slm_brow0 + 7 * SLM_BLOCK)));

#define MM_DOT_PRODUCT( _row, _dot ) \
            arow = as_half8(vload4(0, (__global float *)(src0_read + _row * K))); \
            _dot = mad( (Dtype8)(arow.s0), (Dtype8)(brow0.s0, brow1.s0, brow2.s0, brow3.s0, brow4.s0, brow5.s0, brow6.s0, brow7.s0), _dot ); \
            _dot = mad( (Dtype8)(arow.s1), (Dtype8)(brow0.s1, brow1.s1, brow2.s1, brow3.s1, brow4.s1, brow5.s1, brow6.s1, brow7.s1), _dot ); \
            _dot = mad( (Dtype8)(arow.s2), (Dtype8)(brow0.s2, brow1.s2, brow2.s2, brow3.s2, brow4.s2, brow5.s2, brow6.s2, brow7.s2), _dot ); \
            _dot = mad( (Dtype8)(arow.s3), (Dtype8)(brow0.s3, brow1.s3, brow2.s3, brow3.s3, brow4.s3, brow5.s3, brow6.s3, brow7.s3), _dot ); \
            _dot = mad( (Dtype8)(arow.s4), (Dtype8)(brow0.s4, brow1.s4, brow2.s4, brow3.s4, brow4.s4, brow5.s4, brow6.s4, brow7.s4), _dot ); \
            _dot = mad( (Dtype8)(arow.s5), (Dtype8)(brow0.s5, brow1.s5, brow2.s5, brow3.s5, brow4.s5, brow5.s5, brow6.s5, brow7.s5), _dot ); \
            _dot = mad( (Dtype8)(arow.s6), (Dtype8)(brow0.s6, brow1.s6, brow2.s6, brow3.s6, brow4.s6, brow5.s6, brow6.s6, brow7.s6), _dot ); \
            _dot = mad( (Dtype8)(arow.s7), (Dtype8)(brow0.s7, brow1.s7, brow2.s7, brow3.s7, brow4.s7, brow5.s7, brow6.s7, brow7.s7), _dot );

            MM_DOT_PRODUCT( 0, dot00 );
            MM_DOT_PRODUCT( 1, dot01 );
            MM_DOT_PRODUCT( 2, dot02 );
            MM_DOT_PRODUCT( 3, dot03 );
            MM_DOT_PRODUCT( 4, dot04 );
            MM_DOT_PRODUCT( 5, dot05 );
            MM_DOT_PRODUCT( 6, dot06 );
            MM_DOT_PRODUCT( 7, dot07 );
#undef MM_DOT_PRODUCT

            src0_read += TILE_K;
            slm_brow0 += TILE_K;
            w += TILE_K;
        }
        src1_read0 += SLM_BLOCK;
    }

    // Final partial TILE_K step: elements whose K index is >= K are zeroed in
    // both operands before accumulating.
    if(w < K) {
        Dtype8 arow;

#define READ_BROW(_brow, _row) \
        _brow = as_half8(vload4(0, (__local float *)(slm_brow0 + _row * SLM_BLOCK))); \
        _brow.s0 = (mad24(local_x, 8, w) < K) ? _brow.s0 : 0.0f; \
        _brow.s1 = (mad24(local_x, 8, w + 1) < K) ? _brow.s1 : 0.0f; \
        _brow.s2 = (mad24(local_x, 8, w + 2) < K) ? _brow.s2 : 0.0f; \
        _brow.s3 = (mad24(local_x, 8, w + 3) < K) ? _brow.s3 : 0.0f; \
        _brow.s4 = (mad24(local_x, 8, w + 4) < K) ? _brow.s4 : 0.0f; \
        _brow.s5 = (mad24(local_x, 8, w + 5) < K) ? _brow.s5 : 0.0f; \
        _brow.s6 = (mad24(local_x, 8, w + 6) < K) ? _brow.s6 : 0.0f; \
        _brow.s7 = (mad24(local_x, 8, w + 7) < K) ? _brow.s7 : 0.0f;

        READ_BROW(brow0, 0);
        READ_BROW(brow1, 1);
        READ_BROW(brow2, 2);
        READ_BROW(brow3, 3);
        READ_BROW(brow4, 4);
        READ_BROW(brow5, 5);
        READ_BROW(brow6, 6);
        READ_BROW(brow7, 7);

#undef READ_BROW

#define MM_DOT_PRODUCT( _row, _dot ) \
        arow = as_half8(vload4(0, (__global float *)(src0_read + _row * K))); \
        arow.s0 = (mad24(local_x, 8, w) < K) ? arow.s0 : 0.0f; \
        arow.s1 = (mad24(local_x, 8, w + 1) < K) ? arow.s1 : 0.0f; \
        arow.s2 = (mad24(local_x, 8, w + 2) < K) ? arow.s2 : 0.0f; \
        arow.s3 = (mad24(local_x, 8, w + 3) < K) ? arow.s3 : 0.0f; \
        arow.s4 = (mad24(local_x, 8, w + 4) < K) ? arow.s4 : 0.0f; \
        arow.s5 = (mad24(local_x, 8, w + 5) < K) ? arow.s5 : 0.0f; \
        arow.s6 = (mad24(local_x, 8, w + 6) < K) ? arow.s6 : 0.0f; \
        arow.s7 = (mad24(local_x, 8, w + 7) < K) ? arow.s7 : 0.0f; \
        _dot = mad( (Dtype8)(arow.s0), (Dtype8)(brow0.s0, brow1.s0, brow2.s0, brow3.s0, brow4.s0, brow5.s0, brow6.s0, brow7.s0), _dot ); \
        _dot = mad( (Dtype8)(arow.s1), (Dtype8)(brow0.s1, brow1.s1, brow2.s1, brow3.s1, brow4.s1, brow5.s1, brow6.s1, brow7.s1), _dot ); \
        _dot = mad( (Dtype8)(arow.s2), (Dtype8)(brow0.s2, brow1.s2, brow2.s2, brow3.s2, brow4.s2, brow5.s2, brow6.s2, brow7.s2), _dot ); \
        _dot = mad( (Dtype8)(arow.s3), (Dtype8)(brow0.s3, brow1.s3, brow2.s3, brow3.s3, brow4.s3, brow5.s3, brow6.s3, brow7.s3), _dot ); \
        _dot = mad( (Dtype8)(arow.s4), (Dtype8)(brow0.s4, brow1.s4, brow2.s4, brow3.s4, brow4.s4, brow5.s4, brow6.s4, brow7.s4), _dot ); \
        _dot = mad( (Dtype8)(arow.s5), (Dtype8)(brow0.s5, brow1.s5, brow2.s5, brow3.s5, brow4.s5, brow5.s5, brow6.s5, brow7.s5), _dot ); \
        _dot = mad( (Dtype8)(arow.s6), (Dtype8)(brow0.s6, brow1.s6, brow2.s6, brow3.s6, brow4.s6, brow5.s6, brow6.s6, brow7.s6), _dot ); \
        _dot = mad( (Dtype8)(arow.s7), (Dtype8)(brow0.s7, brow1.s7, brow2.s7, brow3.s7, brow4.s7, brow5.s7, brow6.s7, brow7.s7), _dot );

        MM_DOT_PRODUCT( 0, dot00 );
        MM_DOT_PRODUCT( 1, dot01 );
        MM_DOT_PRODUCT( 2, dot02 );
        MM_DOT_PRODUCT( 3, dot03 );
        MM_DOT_PRODUCT( 4, dot04 );
        MM_DOT_PRODUCT( 5, dot05 );
        MM_DOT_PRODUCT( 6, dot06 );
        MM_DOT_PRODUCT( 7, dot07 );
#undef MM_DOT_PRODUCT
    }

    // Sum the 8 lanes' partial vectors; afterwards every lane holds the full
    // dot products for all 8 columns.
#define REDUCE(_dot) \
    _dot = as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) + \
           as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7));

    REDUCE(dot00);
    REDUCE(dot01);
    REDUCE(dot02);
    REDUCE(dot03);
    REDUCE(dot04);
    REDUCE(dot05);
    REDUCE(dot06);
    REDUCE(dot07);
#undef REDUCE

    // Lane local_x picks component local_x (its own column) and writes
    // alpha * sum + beta * dst, then advances one dst row.
    Dtype output = 0.0f;
#define OUTPUT( _dot) \
    output = (local_x == 0) ? _dot.s0 : output; \
    output = (local_x == 1) ? _dot.s1 : output; \
    output = (local_x == 2) ? _dot.s2 : output; \
    output = (local_x == 3) ? _dot.s3 : output; \
    output = (local_x == 4) ? _dot.s4 : output; \
    output = (local_x == 5) ? _dot.s5 : output; \
    output = (local_x == 6) ? _dot.s6 : output; \
    output = (local_x == 7) ? _dot.s7 : output; \
    dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \
    dst_write0 += N;

    if(global_x < N && global_y * 8 < M) {
        OUTPUT(dot00);
        if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); }
        if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); }
        if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); }
        if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); }
        if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); }
        if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); }
        if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); }
    }
#undef OUTPUT
}

#else

/*
 * gemm_buffer_NT (float variant): dst = alpha * src0 * src1^T + beta * dst.
 *
 * src0 is M x K and src1 is N x K, both row-major with row stride K. dst is
 * M x N with row stride N.
 *
 * The structure matches the half variant in this file: the 8 src1 rows of the
 * work-group are staged in local memory, each lane handles a 4-element slice
 * of every TILE_K step, partial sums are combined with sub-group shuffles, and
 * lane local_x writes column local_x.
 */
__attribute__((reqd_work_group_size(8, LWG_HEIGHT, 1)))
__attribute__((intel_reqd_sub_group_size(8)))
__kernel void TEMPLATE(gemm_buffer_NT, Dtype)(
    const __global Dtype *src0, int off0,
    const __global Dtype *src1, int off1,
    __global Dtype *dst, int offd,
    int M,
    int N,
    int K,
    KERNEL_ARG_DTYPE alpha_in,
    KERNEL_ARG_DTYPE beta_in)
{
    const Dtype alpha = (Dtype)alpha_in;
    const Dtype beta = (Dtype)beta_in;
    const int group_x = get_group_id(0);
    const int group_y = get_group_id(1);
    const int local_x = get_local_id(0);
    const int local_y = get_local_id(1);
    const int global_x = get_global_id(0);
    const int global_y = get_global_id(1);

    // dotRR.sC: this lane's partial sum for tile row RR against staged column C.
    Dtype8 dot00 = 0.f;
    Dtype8 dot01 = 0.f;
    Dtype8 dot02 = 0.f;
    Dtype8 dot03 = 0.f;
    Dtype8 dot04 = 0.f;
    Dtype8 dot05 = 0.f;
    Dtype8 dot06 = 0.f;
    Dtype8 dot07 = 0.f;

    Dtype4 brow0;
    Dtype4 brow1;
    Dtype4 brow2;
    Dtype4 brow3;
    Dtype4 brow4;
    Dtype4 brow5;
    Dtype4 brow6;
    Dtype4 brow7;

    __global Dtype *dst_write0 = dst + local_x * VEC_SIZE + (group_x * TILE_N) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * N + offd;

    const __global Dtype *src0_read = src0 + local_x * (TILE_K / 8) + (group_y * LWG_HEIGHT * TILE_M + local_y * TILE_M) * K + off0;

    // First of the 8 src1 rows (dst columns) handled by this work-group.
    const __global Dtype *src1_read0 = src1 + (group_x * TILE_N) * K + off1;

    __local Dtype slm_brow[8 * SLM_BLOCK];
    __local Dtype* slm_brow0;

    // Each work-item copies 4 consecutive elements of every staged row.
    int local_index = mad24(local_y, 8, local_x) * 4;
    // Initialised so the tail check below is well-defined even when the K loop
    // does not run at all (K <= 0).
    int w = 0;
    for(int b_tile = 0; b_tile < K; b_tile += SLM_BLOCK) {
        // The leading barrier keeps this copy from overwriting slm_brow while
        // other work-items may still be reading the previous block.
        // NOTE(review): the copy reads a full SLM_BLOCK of each of the 8 rows
        // without checking K or N; presumably src1 is padded -- confirm.
        barrier(CLK_LOCAL_MEM_FENCE);
        vstore4(vload4(0, src1_read0 + mad24(0, K, local_index)), 0, slm_brow + mad24(0, SLM_BLOCK, local_index));
        vstore4(vload4(0, src1_read0 + mad24(1, K, local_index)), 0, slm_brow + mad24(1, SLM_BLOCK, local_index));
        vstore4(vload4(0, src1_read0 + mad24(2, K, local_index)), 0, slm_brow + mad24(2, SLM_BLOCK, local_index));
        vstore4(vload4(0, src1_read0 + mad24(3, K, local_index)), 0, slm_brow + mad24(3, SLM_BLOCK, local_index));
        vstore4(vload4(0, src1_read0 + mad24(4, K, local_index)), 0, slm_brow + mad24(4, SLM_BLOCK, local_index));
        vstore4(vload4(0, src1_read0 + mad24(5, K, local_index)), 0, slm_brow + mad24(5, SLM_BLOCK, local_index));
        vstore4(vload4(0, src1_read0 + mad24(6, K, local_index)), 0, slm_brow + mad24(6, SLM_BLOCK, local_index));
        vstore4(vload4(0, src1_read0 + mad24(7, K, local_index)), 0, slm_brow + mad24(7, SLM_BLOCK, local_index));
        barrier(CLK_LOCAL_MEM_FENCE);

        slm_brow0 = slm_brow + local_x * (TILE_K / 8);
        w = b_tile;
        int end_w = min(b_tile + SLM_BLOCK, K);
        while( w + TILE_K <= end_w ) {
            Dtype4 arow;

            brow0 = vload4(0, slm_brow0 + 0 * SLM_BLOCK);
            brow1 = vload4(0, slm_brow0 + 1 * SLM_BLOCK);
            brow2 = vload4(0, slm_brow0 + 2 * SLM_BLOCK);
            brow3 = vload4(0, slm_brow0 + 3 * SLM_BLOCK);
            brow4 = vload4(0, slm_brow0 + 4 * SLM_BLOCK);
            brow5 = vload4(0, slm_brow0 + 5 * SLM_BLOCK);
            brow6 = vload4(0, slm_brow0 + 6 * SLM_BLOCK);
            brow7 = vload4(0, slm_brow0 + 7 * SLM_BLOCK);

#define MM_DOT_PRODUCT( _row, _dot ) \
            arow = vload4(0, src0_read + _row * K); \
            _dot = mad( (Dtype8)(arow.x), (Dtype8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \
            _dot = mad( (Dtype8)(arow.y), (Dtype8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \
            _dot = mad( (Dtype8)(arow.z), (Dtype8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \
            _dot = mad( (Dtype8)(arow.w), (Dtype8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );

            MM_DOT_PRODUCT( 0, dot00 );
            MM_DOT_PRODUCT( 1, dot01 );
            MM_DOT_PRODUCT( 2, dot02 );
            MM_DOT_PRODUCT( 3, dot03 );
            MM_DOT_PRODUCT( 4, dot04 );
            MM_DOT_PRODUCT( 5, dot05 );
            MM_DOT_PRODUCT( 6, dot06 );
            MM_DOT_PRODUCT( 7, dot07 );
#undef MM_DOT_PRODUCT

            src0_read += TILE_K;
            slm_brow0 += TILE_K;
            w += TILE_K;
        }
        src1_read0 += SLM_BLOCK;
    }

    // Final partial TILE_K step: elements whose K index is >= K are zeroed in
    // both operands before accumulating.
    if(w < K) {
        Dtype4 arow;

#define READ_BROW(_brow, _row) \
        _brow = vload4(0, slm_brow0 + _row * SLM_BLOCK); \
        _brow.x = (mad24(local_x, 4, w) < K) ? _brow.x : 0.0f; \
        _brow.y = (mad24(local_x, 4, w + 1) < K) ? _brow.y : 0.0f; \
        _brow.z = (mad24(local_x, 4, w + 2) < K) ? _brow.z : 0.0f; \
        _brow.w = (mad24(local_x, 4, w + 3) < K) ? _brow.w : 0.0f;

        READ_BROW(brow0, 0);
        READ_BROW(brow1, 1);
        READ_BROW(brow2, 2);
        READ_BROW(brow3, 3);
        READ_BROW(brow4, 4);
        READ_BROW(brow5, 5);
        READ_BROW(brow6, 6);
        READ_BROW(brow7, 7);

#undef READ_BROW

#define MM_DOT_PRODUCT( _row, _dot ) \
        arow = vload4(0, src0_read + _row * K); \
        arow.x = (mad24(local_x, 4, w) < K) ? arow.x : 0.0f; \
        arow.y = (mad24(local_x, 4, w + 1) < K) ? arow.y : 0.0f; \
        arow.z = (mad24(local_x, 4, w + 2) < K) ? arow.z : 0.0f; \
        arow.w = (mad24(local_x, 4, w + 3) < K) ? arow.w : 0.0f; \
        _dot = mad( (Dtype8)(arow.x), (Dtype8)(brow0.x, brow1.x, brow2.x, brow3.x, brow4.x, brow5.x, brow6.x, brow7.x), _dot ); \
        _dot = mad( (Dtype8)(arow.y), (Dtype8)(brow0.y, brow1.y, brow2.y, brow3.y, brow4.y, brow5.y, brow6.y, brow7.y), _dot ); \
        _dot = mad( (Dtype8)(arow.z), (Dtype8)(brow0.z, brow1.z, brow2.z, brow3.z, brow4.z, brow5.z, brow6.z, brow7.z), _dot ); \
        _dot = mad( (Dtype8)(arow.w), (Dtype8)(brow0.w, brow1.w, brow2.w, brow3.w, brow4.w, brow5.w, brow6.w, brow7.w), _dot );

        MM_DOT_PRODUCT( 0, dot00 );
        MM_DOT_PRODUCT( 1, dot01 );
        MM_DOT_PRODUCT( 2, dot02 );
        MM_DOT_PRODUCT( 3, dot03 );
        MM_DOT_PRODUCT( 4, dot04 );
        MM_DOT_PRODUCT( 5, dot05 );
        MM_DOT_PRODUCT( 6, dot06 );
        MM_DOT_PRODUCT( 7, dot07 );
#undef MM_DOT_PRODUCT
    }

    // Sum the 8 lanes' partial vectors; afterwards every lane holds the full
    // dot products for all 8 columns.
#define REDUCE(_dot) \
    _dot = as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 0)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 1)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 2)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 3)) + \
           as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 4)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 5)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 6)) + as_Dtype8(intel_sub_group_shuffle(SHUFFLE_TYPE8(_dot), 7));

    REDUCE(dot00);
    REDUCE(dot01);
    REDUCE(dot02);
    REDUCE(dot03);
    REDUCE(dot04);
    REDUCE(dot05);
    REDUCE(dot06);
    REDUCE(dot07);
#undef REDUCE

    // Lane local_x picks component local_x (its own column) and writes
    // alpha * sum + beta * dst, then advances one dst row.
    Dtype output = 0.0f;
#define OUTPUT( _dot) \
    output = (local_x == 0) ? _dot.s0 : output; \
    output = (local_x == 1) ? _dot.s1 : output; \
    output = (local_x == 2) ? _dot.s2 : output; \
    output = (local_x == 3) ? _dot.s3 : output; \
    output = (local_x == 4) ? _dot.s4 : output; \
    output = (local_x == 5) ? _dot.s5 : output; \
    output = (local_x == 6) ? _dot.s6 : output; \
    output = (local_x == 7) ? _dot.s7 : output; \
    dst_write0[0] = mad(output, alpha, beta * dst_write0[0]); \
    dst_write0 += N;

    if(global_x < N && global_y * 8 < M) {
        OUTPUT(dot00);
        if(mad24(global_y, 8, 1) < M) { OUTPUT(dot01); }
        if(mad24(global_y, 8, 2) < M) { OUTPUT(dot02); }
        if(mad24(global_y, 8, 3) < M) { OUTPUT(dot03); }
        if(mad24(global_y, 8, 4) < M) { OUTPUT(dot04); }
        if(mad24(global_y, 8, 5) < M) { OUTPUT(dot05); }
        if(mad24(global_y, 8, 6) < M) { OUTPUT(dot06); }
        if(mad24(global_y, 8, 7) < M) { OUTPUT(dot07); }
    }
#undef OUTPUT
}
#endif

#undef VEC_SIZE
#undef LWG_HEIGHT
#undef TILE_M
#undef TILE_K
#undef TILE_N
#undef SLM_BLOCK

// Number of Dtype4 entries in each local reduction buffer of the M_2 kernels.
#define SLM_SIZE 64

/*
 * Helper for gemm_buffer_NT_M_2: computes the remaining columns when fewer than
 * 4 are left. rows = N - x_gid * 4, which is N % 4 when called with
 * x_gid == N / 4; dot0/dot1 hold at most 3 entries.
 *
 * For j < rows:
 *   dstc0[x_gid*4 + j] = alpha * dot(srca_read0, B row j) + beta * dstc0[...]
 *   dstc1[x_gid*4 + j] = alpha * dot(srca_read1, B row j) + beta * dstc1[...]
 * where B row j starts at srcb_read + j * K.
 *
 * Partial sums from all work-items are combined with a halving tree reduction
 * over work0/work1. NOTE(review): this assumes get_local_size(0) is a power of
 * two and at most SLM_SIZE -- confirm against the host launch size. Lanes
 * j >= rows of work0/work1 are never written here and are ignored by the final
 * store.
 */
void TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(
    const __global Dtype* srca_read0,
    const __global Dtype* srca_read1,
    const __global Dtype* srcb_read,
    __local Dtype4* work0,
    __local Dtype4* work1,
    int N,
    int K,
    int x_gid,
    int lid,
    Dtype alpha,
    Dtype beta,
    __global Dtype* dstc0,
    __global Dtype* dstc1)
{
    __local Dtype* work_each0 = (__local Dtype*)work0;
    __local Dtype* work_each1 = (__local Dtype*)work1;

    int rows = N - x_gid * 4;

    Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
    Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};

    // Strided loop over the K / 4 complete 4-element chunks.
    int i = lid;
    while( i < K / 4) {
        const Dtype4 b0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};
        const Dtype4 b1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};
#pragma unroll
        for(int j = 0; j < rows; ++j) {
            dot0[j] += b0 * vload4(i, srcb_read + j * K);
            dot1[j] += b1 * vload4(i, srcb_read + j * K);
        }

        i += get_local_size(0);
    }
#pragma unroll
    for(int j = 0; j < rows; ++j) {
        work_each0[lid * 4 + j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;
        work_each1[lid * 4 + j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;
    }

    // Exactly one work-item leaves the strided loop with i == K / 4. It also
    // folds in the K % 4 trailing elements.
    if(i == K / 4) {
        short tail_items = K % 4;

        if(tail_items != 0) {
            const __global Dtype *srcb_tail = srcb_read + i * 4;
            const __global Dtype *srca_tail0 = srca_read0 + i * 4;
            const __global Dtype *srca_tail1 = srca_read1 + i * 4;
#pragma unroll
            for(short t = 0; t < tail_items; ++t) {
                const Dtype at0 = srca_tail0[t];
                const Dtype at1 = srca_tail1[t];
#pragma unroll
                for(int j = 0; j < rows; ++j) {
                    work_each0[lid * 4 + j] += at0 * srcb_tail[t + j * K];
                    work_each1[lid * 4 + j] += at1 * srcb_tail[t + j * K];
                }
            }
        }
    }

    for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {
        barrier(CLK_LOCAL_MEM_FENCE);
        if(lid < stride) {
            work0[lid] += work0[lid+stride];
            work1[lid] += work1[lid+stride];
        }
    }

    if(lid == 0) {
#pragma unroll
        for(int j = 0; j < rows; ++j) {
            dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];
            dstc1[(x_gid * 4 + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];
        }
    }
}

/*
 * gemm_buffer_NT_M_2: C = alpha * A * B^T + beta * C for the two rows of A at
 * A + offA and A + offA + K. The M argument is not read by this kernel.
 *
 * B is row-major with row stride K. C rows are N elements apart. Work-group
 * x_gid produces columns [x_gid*4, x_gid*4 + 4) of both C rows. The group with
 * x_gid == N / 4 handles the partial remainder via
 * gemm_buffer_NT_M_2_edgerows.
 *
 * Full groups store through Dtype4 pointers.
 * NOTE(review): this needs C + offC and C + offC + N to be suitably aligned for
 * Dtype4 -- confirm the host only selects this kernel when that holds.
 * The reduction has the same local-size assumptions as the edgerows helper.
 */
__kernel void TEMPLATE(gemm_buffer_NT_M_2,Dtype)(
    __global const Dtype * A,
    int offA,
    __global const Dtype * B,
    int offB,
    __global Dtype * C,
    int offC,
    int M,
    int N,
    int K,
    KERNEL_ARG_DTYPE alpha_f,
    KERNEL_ARG_DTYPE beta_f)
{
    Dtype alpha = (Dtype)alpha_f;
    Dtype beta = (Dtype)beta_f;
    int x_gid = get_group_id(0);
    int lid = get_local_id(0);

    const __global Dtype *srca_read0 = A + offA;
    const __global Dtype *srca_read1 = srca_read0 + K;

    // First of the 4 B rows (C columns) handled by this work-group.
    const __global Dtype *srcb_read = B + x_gid * 4 * K + offB;

    __global Dtype4 *dstc0 = (__global Dtype4*)(C + offC);
    __global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N);

    __local Dtype4 work0[SLM_SIZE];
    __local Dtype4 work1[SLM_SIZE];
    __local Dtype* work_each0 = (__local Dtype*)work0;
    __local Dtype* work_each1 = (__local Dtype*)work1;

    if(x_gid == N / 4) {
        TEMPLATE(gemm_buffer_NT_M_2_edgerows,Dtype)(srca_read0, srca_read1, srcb_read, work0, work1, N, K, x_gid, lid, alpha, beta, (__global Dtype*)dstc0, (__global Dtype*)dstc1);
    } else {
        Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};
        Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};

        // Strided loop over the K / 4 complete 4-element chunks.
        int i = lid;
        while( i < K / 4) {
            const Dtype4 b0 = vload4(i, srca_read0);
            const Dtype4 b1 = vload4(i, srca_read1);
#pragma unroll
            for(int j = 0; j < 4; ++j) {
                Dtype4 a = vload4(i, srcb_read + j * K);
                dot0[j] += b0 * a;
                dot1[j] += b1 * a;
            }
            i += get_local_size(0);
        }

#pragma unroll
        for(int j = 0; j < 4; ++j) {
            work_each0[lid * 4 + j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;
            work_each1[lid * 4 + j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;
        }

        // Exactly one work-item leaves the strided loop with i == K / 4. It
        // also folds in the K % 4 trailing elements.
        if(i == K / 4) {
            short tail_items = K % 4;
            if(tail_items != 0) {
                const __global Dtype *srcb_tail = srcb_read + i * 4;

                const __global Dtype *srca_tail0 = srca_read0 + i * 4;
                const __global Dtype *srca_tail1 = srca_read1 + i * 4;
#pragma unroll
                for(short t = 0; t < tail_items; ++t) {
                    const Dtype at0 = srca_tail0[t];
                    const Dtype at1 = srca_tail1[t];
#pragma unroll
                    for(int j = 0; j < 4; ++j) {
                        work_each0[lid * 4 + j] += at0 * srcb_tail[t + j * K];
                        work_each1[lid * 4 + j] += at1 * srcb_tail[t + j * K];
                    }
                }
            }
        }

        for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {
            barrier(CLK_LOCAL_MEM_FENCE);
            if(lid < stride) {
                work0[lid] += work0[lid+stride];
                work1[lid] += work1[lid+stride];
            }
        }

        if(lid == 0) {
            dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];
            dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];
        }
    }
}
#undef SLM_SIZE

#define SLM_SIZE 32
void TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype)(
    const __global Dtype* srca_read0,
    const __global Dtype* srca_read1,
    const __global Dtype* srca_read2,
    const __global Dtype* srca_read3,
    const __global Dtype* srcb_read,
    __local Dtype4* work0,
    __local Dtype4* work1,
    __local Dtype4* work2,
    __local Dtype4* work3,
    int N,
    int K,
    int x_gid,
    int 
lid,977Dtype alpha,978Dtype beta,979__global Dtype* dstc0,980__global Dtype* dstc1,981__global Dtype* dstc2,982__global Dtype* dstc3)983{984__local Dtype* work_each0 = (__local Dtype*)(work0 + lid);985__local Dtype* work_each1 = (__local Dtype*)(work1 + lid);986__local Dtype* work_each2 = (__local Dtype*)(work2 + lid);987__local Dtype* work_each3 = (__local Dtype*)(work3 + lid);988989int rows = N - x_gid * 4;990991Dtype4 dot0[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};992Dtype4 dot1[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};993Dtype4 dot2[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};994Dtype4 dot3[3] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};995996int i = lid;997while( i < K / 4) {998const Dtype4 a0 = {srca_read0[i*4], srca_read0[(i*4+1)], srca_read0[(i*4+2)], srca_read0[(i*4+3)]};999const Dtype4 a1 = {srca_read1[i*4], srca_read1[(i*4+1)], srca_read1[(i*4+2)], srca_read1[(i*4+3)]};1000const Dtype4 a2 = {srca_read2[i*4], srca_read2[(i*4+1)], srca_read2[(i*4+2)], srca_read2[(i*4+3)]};1001const Dtype4 a3 = {srca_read3[i*4], srca_read3[(i*4+1)], srca_read3[(i*4+2)], srca_read3[(i*4+3)]};1002#pragma unrol1003for(int j = 0; j < rows; ++j) {1004dot0[j] += a0 * vload4(i, srcb_read + j * K);1005dot1[j] += a1 * vload4(i, srcb_read + j * K);1006dot2[j] += a2 * vload4(i, srcb_read + j * K);1007dot3[j] += a3 * vload4(i, srcb_read + j * K);1008}10091010i += get_local_size(0);1011}1012#pragma unroll1013for(int j = 0; j < rows; ++j) {1014work_each0[j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;1015work_each1[j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;1016work_each2[j] = dot2[j].x + dot2[j].y + dot2[j].z + dot2[j].w;1017work_each3[j] = dot3[j].x + dot3[j].y + dot3[j].z + dot3[j].w;1018}10191020if(i == K / 4) {1021short tail_items = K % 4;10221023if(tail_items != 0) {1024const __global Dtype *srcb_tail = srcb_read + i * 4;10251026const __global Dtype *srca_tail0 = srca_read0 + i * 4;1027const __global Dtype *srca_tail1 = srca_read1 + i * 4;1028const 
__global Dtype *srca_tail2 = srca_read2 + i * 4;1029const __global Dtype *srca_tail3 = srca_read3 + i * 4;1030#pragma unroll1031for(short i = 0; i < tail_items; ++i) {1032const Dtype at0 = srca_tail0[i];1033const Dtype at1 = srca_tail1[i];1034const Dtype at2 = srca_tail2[i];1035const Dtype at3 = srca_tail3[i];1036#pragma unroll1037for(int j = 0; j < rows; ++j) {1038work_each0[j] += at0 * srcb_tail[i + j * K];1039work_each1[j] += at1 * srcb_tail[i + j * K];1040work_each2[j] += at2 * srcb_tail[i + j * K];1041work_each3[j] += at3 * srcb_tail[i + j * K];1042}1043}1044}1045}10461047for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {1048barrier(CLK_LOCAL_MEM_FENCE);1049if(lid < stride) {1050work0[lid] += work0[lid+stride];1051work1[lid] += work1[lid+stride];1052work2[lid] += work2[lid+stride];1053work3[lid] += work3[lid+stride];1054}1055}10561057if(lid == 0) {1058#pragma unroll1059for(int j = 0; j < rows; ++j) {1060dstc0[(x_gid * 4 + j)] = alpha * work_each0[j] + beta * dstc0[(x_gid * 4 + j)];1061dstc1[(x_gid * 4 + j)] = alpha * work_each1[j] + beta * dstc1[(x_gid * 4 + j)];1062dstc2[(x_gid * 4 + j)] = alpha * work_each2[j] + beta * dstc2[(x_gid * 4 + j)];1063dstc3[(x_gid * 4 + j)] = alpha * work_each3[j] + beta * dstc3[(x_gid * 4 + j)];1064}1065}1066}10671068__kernel void TEMPLATE(gemm_buffer_NT_M_4,Dtype)(1069__global const Dtype * A,1070int offA,1071__global const Dtype * B,1072int offB,1073__global Dtype * C,1074int offC,1075int M,1076int N,1077int K,1078KERNEL_ARG_DTYPE alpha_f,1079KERNEL_ARG_DTYPE beta_f)1080{1081Dtype alpha = (Dtype)alpha_f;1082Dtype beta = (Dtype)beta_f;1083int x_gid = get_group_id(0);1084int lid = get_local_id(0);1085int lsize = get_local_size(0);10861087const __global Dtype *srca_read0 = A + offA;1088const __global Dtype *srca_read1 = srca_read0 + K;1089const __global Dtype *srca_read2 = srca_read1 + K;1090const __global Dtype *srca_read3 = srca_read2 + K;10911092const __global Dtype *srcb_read = B + x_gid * 4 * K + 
offB;10931094__global Dtype4 *dstc0 = (__global Dtype4*)(C + offC);1095__global Dtype4 *dstc1 = (__global Dtype4*)((__global Dtype*)(dstc0) + N);1096__global Dtype4 *dstc2 = (__global Dtype4*)((__global Dtype*)(dstc1) + N);1097__global Dtype4 *dstc3 = (__global Dtype4*)((__global Dtype*)(dstc2) + N);10981099__local Dtype4 work0[SLM_SIZE];1100__local Dtype4 work1[SLM_SIZE];1101__local Dtype4 work2[SLM_SIZE];1102__local Dtype4 work3[SLM_SIZE];1103__local Dtype* work_each0 = (__local Dtype*)(work0 + lid);1104__local Dtype* work_each1 = (__local Dtype*)(work1 + lid);1105__local Dtype* work_each2 = (__local Dtype*)(work2 + lid);1106__local Dtype* work_each3 = (__local Dtype*)(work3 + lid);11071108if(x_gid == N / 4) {1109TEMPLATE(gemm_buffer_NT_M_4_edgerows,Dtype) \1110(srca_read0, srca_read1, srca_read2, srca_read3, srcb_read, \1111work0, work1, work2, work3, N, K, x_gid, lid, alpha, beta, \1112(__global Dtype*)dstc0, (__global Dtype*)dstc1, (__global Dtype*)dstc2, (__global Dtype*)dstc3);1113} else {1114Dtype4 dot0[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};1115Dtype4 dot1[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};1116Dtype4 dot2[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};1117Dtype4 dot3[4] = {(Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.), (Dtype4)(0.)};11181119int kid = lid;1120while( kid < K / 4) {1121const Dtype4 b0 = vload4(kid, srca_read0);1122const Dtype4 b1 = vload4(kid, srca_read1);1123const Dtype4 b2 = vload4(kid, srca_read2);1124const Dtype4 b3 = vload4(kid, srca_read3);1125#pragma unroll1126for(int j = 0; j < 4; ++j) {1127Dtype4 a = vload4(kid, srcb_read + j * K);1128dot0[j] += b0 * a;1129dot1[j] += b1 * a;1130dot2[j] += b2 * a;1131dot3[j] += b3 * a;1132}1133kid += lsize;1134}1135#pragma unroll1136for(int j = 0; j < 4; ++j) {1137work_each0[j] = dot0[j].x + dot0[j].y + dot0[j].z + dot0[j].w;1138work_each1[j] = dot1[j].x + dot1[j].y + dot1[j].z + dot1[j].w;1139work_each2[j] = dot2[j].x + dot2[j].y + 
dot2[j].z + dot2[j].w;1140work_each3[j] = dot3[j].x + dot3[j].y + dot3[j].z + dot3[j].w;1141}11421143if(kid == (K >> 2)) {1144short tail_items = K % 4;1145if(tail_items != 0) {1146int offset = kid << 2;1147const __global Dtype *srcb_tail = srcb_read + offset;11481149const __global Dtype *srca_tail0 = srca_read0 + offset;1150const __global Dtype *srca_tail1 = srca_read1 + offset;1151const __global Dtype *srca_tail2 = srca_read2 + offset;1152const __global Dtype *srca_tail3 = srca_read3 + offset;1153#pragma unroll1154for(short i = 0; i < tail_items; ++i) {1155const Dtype at0 = srca_tail0[i];1156const Dtype at1 = srca_tail1[i];1157const Dtype at2 = srca_tail2[i];1158const Dtype at3 = srca_tail3[i];1159#pragma unroll1160for(int j = 0; j < 4; ++j) {1161work_each0[j] += at0 * srcb_tail[i + j * K];1162work_each1[j] += at1 * srcb_tail[i + j * K];1163work_each2[j] += at2 * srcb_tail[i + j * K];1164work_each3[j] += at3 * srcb_tail[i + j * K];1165}1166}1167}1168}11691170for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {1171barrier(CLK_LOCAL_MEM_FENCE);1172if(lid < stride) {1173work0[lid] += work0[lid+stride];1174work1[lid] += work1[lid+stride];1175work2[lid] += work2[lid+stride];1176work3[lid] += work3[lid+stride];1177}1178}11791180if(lid == 0) {1181dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];1182dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];1183dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid];1184dstc3[x_gid] = alpha * work3[0] + beta * dstc3[x_gid];1185}1186}1187}1188#undef SLM_SIZE11891190#define SLM_SIZE 161191__kernel void TEMPLATE(gemm_buffer_NT_M_8,Dtype)(1192__global const Dtype * A,1193int offA,1194__global const Dtype * B,1195int offB,1196__global Dtype * C,1197int offC,1198int M,1199int N,1200int K,1201KERNEL_ARG_DTYPE alpha_f,1202KERNEL_ARG_DTYPE beta_f)1203{1204Dtype alpha = (Dtype)alpha_f;1205Dtype beta = (Dtype)beta_f;1206int x_gid = get_group_id(0);1207int lid = get_local_id(0);1208int lsize = 
get_local_size(0);12091210const __global Dtype *srca_read0 = A + offA;1211const __global Dtype *srca_read1 = srca_read0 + K;1212const __global Dtype *srca_read2 = srca_read1 + K;1213const __global Dtype *srca_read3 = srca_read2 + K;1214const __global Dtype *srca_read4 = srca_read3 + K;1215const __global Dtype *srca_read5 = srca_read4 + K;1216const __global Dtype *srca_read6 = srca_read5 + K;1217const __global Dtype *srca_read7 = srca_read6 + K;12181219const __global Dtype *srcb_read = B + x_gid * K + offB;12201221__global Dtype *dstc0 = C + offC;1222__global Dtype *dstc1 = dstc0 + N;1223__global Dtype *dstc2 = dstc1 + N;1224__global Dtype *dstc3 = dstc2 + N;1225__global Dtype *dstc4 = dstc3 + N;1226__global Dtype *dstc5 = dstc4 + N;1227__global Dtype *dstc6 = dstc5 + N;1228__global Dtype *dstc7 = dstc6 + N;12291230__local Dtype work0[SLM_SIZE];1231__local Dtype work1[SLM_SIZE];1232__local Dtype work2[SLM_SIZE];1233__local Dtype work3[SLM_SIZE];1234__local Dtype work4[SLM_SIZE];1235__local Dtype work5[SLM_SIZE];1236__local Dtype work6[SLM_SIZE];1237__local Dtype work7[SLM_SIZE];12381239Dtype4 dot0 = (Dtype4)(0.);1240Dtype4 dot1 = (Dtype4)(0.);1241Dtype4 dot2 = (Dtype4)(0.);1242Dtype4 dot3 = (Dtype4)(0.);1243Dtype4 dot4 = (Dtype4)(0.);1244Dtype4 dot5 = (Dtype4)(0.);1245Dtype4 dot6 = (Dtype4)(0.);1246Dtype4 dot7 = (Dtype4)(0.);12471248int kid = lid;1249while( kid < K / 4) {1250const Dtype4 a0 = vload4(kid, srca_read0);1251const Dtype4 a1 = vload4(kid, srca_read1);1252const Dtype4 a2 = vload4(kid, srca_read2);1253const Dtype4 a3 = vload4(kid, srca_read3);1254const Dtype4 a4 = vload4(kid, srca_read4);1255const Dtype4 a5 = vload4(kid, srca_read5);1256const Dtype4 a6 = vload4(kid, srca_read6);1257const Dtype4 a7 = vload4(kid, srca_read7);1258Dtype4 b = vload4(kid, srcb_read);1259dot0 += a0 * b;1260dot1 += a1 * b;1261dot2 += a2 * b;1262dot3 += a3 * b;1263dot4 += a4 * b;1264dot5 += a5 * b;1265dot6 += a6 * b;1266dot7 += a7 * b;12671268kid += lsize;1269}1270work0[lid] = 
dot0.x + dot0.y + dot0.z + dot0.w;1271work1[lid] = dot1.x + dot1.y + dot1.z + dot1.w;1272work2[lid] = dot2.x + dot2.y + dot2.z + dot2.w;1273work3[lid] = dot3.x + dot3.y + dot3.z + dot3.w;1274work4[lid] = dot4.x + dot4.y + dot4.z + dot4.w;1275work5[lid] = dot5.x + dot5.y + dot5.z + dot5.w;1276work6[lid] = dot6.x + dot6.y + dot6.z + dot6.w;1277work7[lid] = dot7.x + dot7.y + dot7.z + dot7.w;12781279if(kid == (K >> 2)) {1280short tail_items = K % 4;1281if(tail_items != 0) {1282int offset = kid << 2;1283const __global Dtype *srcb_tail = srcb_read + offset;12841285const __global Dtype *srca_tail0 = srca_read0 + offset;1286const __global Dtype *srca_tail1 = srca_read1 + offset;1287const __global Dtype *srca_tail2 = srca_read2 + offset;1288const __global Dtype *srca_tail3 = srca_read3 + offset;1289const __global Dtype *srca_tail4 = srca_read4 + offset;1290const __global Dtype *srca_tail5 = srca_read5 + offset;1291const __global Dtype *srca_tail6 = srca_read6 + offset;1292const __global Dtype *srca_tail7 = srca_read7 + offset;1293#pragma unroll1294for(short item = 0; item < tail_items; ++item) {1295work0[lid] += srca_tail0[item] * srcb_tail[item];1296work1[lid] += srca_tail1[item] * srcb_tail[item];1297work2[lid] += srca_tail2[item] * srcb_tail[item];1298work3[lid] += srca_tail3[item] * srcb_tail[item];1299work4[lid] += srca_tail4[item] * srcb_tail[item];1300work5[lid] += srca_tail5[item] * srcb_tail[item];1301work6[lid] += srca_tail6[item] * srcb_tail[item];1302work7[lid] += srca_tail7[item] * srcb_tail[item];1303}1304}1305}13061307for(int stride = get_local_size(0) >> 1; stride > 0 ; stride >>= 1) {1308barrier(CLK_LOCAL_MEM_FENCE);1309if(lid < stride) {1310work0[lid] += work0[lid+stride];1311work1[lid] += work1[lid+stride];1312work2[lid] += work2[lid+stride];1313work3[lid] += work3[lid+stride];1314work4[lid] += work4[lid+stride];1315work5[lid] += work5[lid+stride];1316work6[lid] += work6[lid+stride];1317work7[lid] += work7[lid+stride];1318}1319}13201321if(lid == 0) 
{1322dstc0[x_gid] = alpha * work0[0] + beta * dstc0[x_gid];1323dstc1[x_gid] = alpha * work1[0] + beta * dstc1[x_gid];1324dstc2[x_gid] = alpha * work2[0] + beta * dstc2[x_gid];1325dstc3[x_gid] = alpha * work3[0] + beta * dstc3[x_gid];1326dstc4[x_gid] = alpha * work4[0] + beta * dstc4[x_gid];1327dstc5[x_gid] = alpha * work5[0] + beta * dstc5[x_gid];1328dstc6[x_gid] = alpha * work6[0] + beta * dstc6[x_gid];1329dstc7[x_gid] = alpha * work7[0] + beta * dstc7[x_gid];1330}1331}1332#undef SLM_SIZE13331334#undef VEC_SIZE1335#undef LWG_HEIGHT1336#undef TILE_M1337#undef TILE_K1338#undef TILE_N1339#undef SIMD_SIZE_GEMM1340#undef SHUFFLE_TYPE21341#undef SHUFFLE_TYPE8134213431344