// Path: blob/master/modules/dnn/src/opencl/gemm_image.cl
/*M///////////////////////////////////////////////////////////////////////////////////////
//
// IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
// By downloading, copying, installing or using the software you agree to this license.
// If you do not agree to this license, do not download, install,
// copy or use the software.
//
//
//                          License Agreement
//               For Open Source Computer Vision Library
//
// Copyright (C) 2017, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of the copyright holders may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/

// Image-based GEMM kernels (C = alpha * A x B + beta * C), instantiated by macro
// expansion for the four transpose combinations (NN, TN, NT, TT) and for each
// (ALPHA1, BETA_NOT0) pair.  The compile-time TYPE define selects the float path
// (SIMD width 8, TILE_N 8) or the half path (SIMD width 16, TILE_N 16); both rely
// on the Intel subgroup extension for block reads and shuffles.

#if defined(cl_khr_fp16)
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#endif

#define CONCAT(A,B) A##_##B
#define TEMPLATE(name,type) CONCAT(name,type)

// Host always passes alpha/beta as float; they are cast to Dtype inside the kernels.
#define KERNEL_ARG_DTYPE float
#define TYPE_FLOAT 1
#define TYPE_HALF 2

#if TYPE == TYPE_HALF
#define Dtype half
#define Dtype2 half2
#define Dtype4 half4
#define Dtype8 half8
#define Dtype16 half16

#define as_Dtype as_half
#define as_Dtype2 as_half2
#define as_Dtype4 as_half4
#define as_Dtype8 as_half8
#define as_Dtype16 as_half16
#else
#define Dtype float
#define Dtype2 float2
#define Dtype4 float4
#define Dtype8 float8
#define Dtype16 float16

#define as_Dtype as_float
#define as_Dtype2 as_float2
#define as_Dtype4 as_float4
#define as_Dtype8 as_float8
#define as_Dtype16 as_float16
#endif

#if defined(cl_intel_subgroups)
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#endif

#define TILE_M 32
#define TILE_K 8

// common block to calculate (alpha * AxB + beta * C) and output to destination image.

#if TYPE == TYPE_HALF
#define SUBGROUP_BLOCK_READ8( __image, __coord ) intel_sub_group_block_read_us8( __image, __coord )
#define SHUFFLE_TYPE2(val) as_ushort2(val)
#define SHUFFLE_TYPE8(val) as_ushort8(val)
#define READ_IMAGE(__image, __coord) read_imageh(__image, sampler, __coord)
#define SIZE_OF_ELEMENT sizeof(ushort)
#define SIMD_SIZE_GEMM 16
#define TILE_N 16
#else
#define SUBGROUP_BLOCK_READ8( __image, __coord ) intel_sub_group_block_read8( __image, __coord )
#define SHUFFLE_TYPE2(val) val
#define SHUFFLE_TYPE8(val) val
#define READ_IMAGE(__image, __coord) read_imagef(__image, sampler, __coord)
#define SIZE_OF_ELEMENT sizeof(uint)
#define SIMD_SIZE_GEMM 8
#define TILE_N 8
#endif

//#define USE_IMAGE_C
#ifdef USE_IMAGE_C
#if TYPE == TYPE_HALF
#define BLOCKC_READ8( _C, _coordC ) as_Dtype8( intel_sub_group_block_read_us8( _C, _coordC ) )
#define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write_us8( _C, _coordC, as_ushort8( _val ) )
#else
#define BLOCKC_READ8( _C, _coordC ) as_Dtype8( intel_sub_group_block_read8( _C, _coordC ) )
#define BLOCKC_WRITE8( _C, _coordC, _val ) intel_sub_group_block_write8( _C, _coordC, as_uint8( _val ) )
#endif
#define MATC_PARAMETER __read_only image2d_t C, __write_only image2d_t dst
#define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, C, dst, sizeof(uint))
#else
// Buffer-backed C: bounds-checked gather/scatter of an 8-row column strip.
// Out-of-range reads yield 0; out-of-range writes are skipped.
#define BLOCKC_READ8( _C, _coordC ) \
    (Dtype8) ( (_coordC.x + get_local_id(0) < N && _coordC.y < M) ? _C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] : 0, \
    (_coordC.x + get_local_id(0) < N && _coordC.y + 1 < M) ? _C[ ( _coordC.y + 1 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
    (_coordC.x + get_local_id(0) < N && _coordC.y + 2 < M) ? _C[ ( _coordC.y + 2 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
    (_coordC.x + get_local_id(0) < N && _coordC.y + 3 < M) ? _C[ ( _coordC.y + 3 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
    (_coordC.x + get_local_id(0) < N && _coordC.y + 4 < M) ? _C[ ( _coordC.y + 4 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
    (_coordC.x + get_local_id(0) < N && _coordC.y + 5 < M) ? _C[ ( _coordC.y + 5 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
    (_coordC.x + get_local_id(0) < N && _coordC.y + 6 < M) ? _C[ ( _coordC.y + 6 ) * ldc + _coordC.x + get_local_id(0) ] : 0, \
    (_coordC.x + get_local_id(0) < N && _coordC.y + 7 < M) ? _C[ ( _coordC.y + 7 ) * ldc + _coordC.x + get_local_id(0) ] : 0)

#define BLOCKC_WRITE8( _C, _coordC, _val) do {\
    if (_coordC.x + get_local_id(0) < N) { \
      if (_coordC.y < M) \
        _C[ _coordC.y * ldc + _coordC.x + get_local_id(0) ] = _val.s0; \
      if (_coordC.y + 1 < M) \
        _C[ ( _coordC.y + 1 )* ldc + _coordC.x + get_local_id(0) ] = _val.s1; \
      if (_coordC.y + 2 < M) \
        _C[ ( _coordC.y + 2 )* ldc + _coordC.x + get_local_id(0) ] = _val.s2; \
      if (_coordC.y + 3 < M) \
        _C[ ( _coordC.y + 3 )* ldc + _coordC.x + get_local_id(0) ] = _val.s3; \
      if (_coordC.y + 4 < M) \
        _C[ ( _coordC.y + 4 )* ldc + _coordC.x + get_local_id(0) ] = _val.s4; \
      if (_coordC.y + 5 < M) \
        _C[ ( _coordC.y + 5 )* ldc + _coordC.x + get_local_id(0) ] = _val.s5; \
      if (_coordC.y + 6 < M) \
        _C[ ( _coordC.y + 6 )* ldc + _coordC.x + get_local_id(0) ] = _val.s6; \
      if (_coordC.y + 7 < M) \
        _C[ ( _coordC.y + 7 )* ldc + _coordC.x + get_local_id(0) ] = _val.s7; \
    }} while(0)
#define MATC_PARAMETER __global Dtype * C, const int offC, const int M, const int N, const int ldc
#define GEMM_OUTPUT(ALPHA1, BETA_NOT0) GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, (C + offC), (C + offC), 1)
#endif

// Epilogue shared by all kernels: combine the accumulated AxB tile with C
// (scaled by beta on the first column block only) and write the 32-row result.
#define GEMM_OUTPUT_EXT(ALPHA1, BETA_NOT0, _C, _dst, _C_step) \
    int2 coordDst = (int2)( ( group_x * TILE_N ) * _C_step, ( group_y * TILE_M ) ); \
    int2 coordC = coordDst; \
    Dtype8 blockC00; \
    Dtype8 blockC01; \
    Dtype8 blockC02; \
    Dtype8 blockC03; \
    if (BETA_NOT0) { \
        blockC00 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
        blockC01 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
        blockC02 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
        blockC03 = isFirstColBlock ? BLOCKC_READ8( _C, coordC ) * beta : BLOCKC_READ8( _C, coordC ); \
        if (!ALPHA1) { \
            blockC00 = mad(blockAxB00, (Dtype8)alpha, blockC00); \
            blockC01 = mad(blockAxB01, (Dtype8)alpha, blockC01); \
            blockC02 = mad(blockAxB02, (Dtype8)alpha, blockC02); \
            blockC03 = mad(blockAxB03, (Dtype8)alpha, blockC03); \
        } else { \
            blockC00 += blockAxB00; \
            blockC01 += blockAxB01; \
            blockC02 += blockAxB02; \
            blockC03 += blockAxB03; \
        } \
    } else { \
        blockC00 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
        blockC01 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
        blockC02 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC ); coordC.y += 8; \
        blockC03 = isFirstColBlock ? (Dtype)0. : BLOCKC_READ8( _C, coordC ); \
        if (!ALPHA1) { \
            blockC00 = mad(blockAxB00, (Dtype8)alpha, blockC00); \
            blockC01 = mad(blockAxB01, (Dtype8)alpha, blockC01); \
            blockC02 = mad(blockAxB02, (Dtype8)alpha, blockC02); \
            blockC03 = mad(blockAxB03, (Dtype8)alpha, blockC03); \
        } else { \
            blockC00 += blockAxB00; \
            blockC01 += blockAxB01; \
            blockC02 += blockAxB02; \
            blockC03 += blockAxB03; \
        } \
    } \
    BLOCKC_WRITE8( _dst, coordDst, blockC00 ); coordDst.y += 8; \
    BLOCKC_WRITE8( _dst, coordDst, blockC01 ); coordDst.y += 8; \
    BLOCKC_WRITE8( _dst, coordDst, blockC02 ); coordDst.y += 8; \
    BLOCKC_WRITE8( _dst, coordDst, blockC03 );

// Get the specified column of the block.
#define TRANSPOSE_BLOCK_8( _block, _col ) \
        (Dtype8)( intel_sub_group_shuffle( _block.s0, _col ), \
                  intel_sub_group_shuffle( _block.s1, _col ), \
                  intel_sub_group_shuffle( _block.s2, _col ), \
                  intel_sub_group_shuffle( _block.s3, _col ), \
                  intel_sub_group_shuffle( _block.s4, _col ), \
                  intel_sub_group_shuffle( _block.s5, _col ), \
                  intel_sub_group_shuffle( _block.s6, _col ), \
                  intel_sub_group_shuffle( _block.s7, _col ) );

// A's column block multiply B's row block.
#if TYPE == TYPE_HALF
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB00, _blockB01 ) \
        { \
            const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \
            const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \
            const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \
            const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \
            const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \
            const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \
            const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \
            const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \
            const Dtype8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); \
            const Dtype8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); \
            const Dtype8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); \
            const Dtype8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); \
            const Dtype8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); \
            const Dtype8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); \
            const Dtype8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); \
            const Dtype8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); \
            _result = mad( (Dtype8)(_blockB00.s0), acol0, _result ); \
            _result = mad( (Dtype8)(_blockB00.s1), acol1, _result ); \
            _result = mad( (Dtype8)(_blockB00.s2), acol2, _result ); \
            _result = mad( (Dtype8)(_blockB00.s3), acol3, _result ); \
            _result = mad( (Dtype8)(_blockB00.s4), acol4, _result ); \
            _result = mad( (Dtype8)(_blockB00.s5), acol5, _result ); \
            _result = mad( (Dtype8)(_blockB00.s6), acol6, _result ); \
            _result = mad( (Dtype8)(_blockB00.s7), acol7, _result ); \
            _result = mad( (Dtype8)(_blockB01.s0), acol8, _result ); \
            _result = mad( (Dtype8)(_blockB01.s1), acol9, _result ); \
            _result = mad( (Dtype8)(_blockB01.s2), acola, _result ); \
            _result = mad( (Dtype8)(_blockB01.s3), acolb, _result ); \
            _result = mad( (Dtype8)(_blockB01.s4), acolc, _result ); \
            _result = mad( (Dtype8)(_blockB01.s5), acold, _result ); \
            _result = mad( (Dtype8)(_blockB01.s6), acole, _result ); \
            _result = mad( (Dtype8)(_blockB01.s7), acolf, _result ); \
        }
#else
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \
        { \
            const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \
            const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \
            const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \
            const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \
            const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \
            const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \
            const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \
            const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \
            _result = mad( (Dtype8)(_blockB.s0), acol0, _result ); \
            _result = mad( (Dtype8)(_blockB.s1), acol1, _result ); \
            _result = mad( (Dtype8)(_blockB.s2), acol2, _result ); \
            _result = mad( (Dtype8)(_blockB.s3), acol3, _result ); \
            _result = mad( (Dtype8)(_blockB.s4), acol4, _result ); \
            _result = mad( (Dtype8)(_blockB.s5), acol5, _result ); \
            _result = mad( (Dtype8)(_blockB.s6), acol6, _result ); \
            _result = mad( (Dtype8)(_blockB.s7), acol7, _result ); \
        }
#endif

#if TYPE == TYPE_HALF
// NOTE(review): blockB01 below is read at the unmodified coordBTemp (only
// coordB.y is advanced between the two reads), so it appears to fetch the same
// rows as blockB00 — looks suspicious; confirm against the upstream kernel.
#define GEMM_NN(ALPHA1, BETA_NOT0) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \
    __read_only image2d_t A, \
    __read_only image2d_t B, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int width0, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0; \
    Dtype8 blockAxB01 = 0; \
    Dtype8 blockAxB02 = 0; \
    Dtype8 blockAxB03 = 0; \
    int2 coordA = (int2)( 0, group_y * TILE_M ); \
    int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); \
    do \
    { \
        int2 coordBTemp = coordB; \
        Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \
        Dtype8 blockB01 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; \
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, blockB01 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, blockB01 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, blockB01 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, blockB01 ); \
    } \
    while( coordB.y < width0 ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#else
#define GEMM_NN(ALPHA1, BETA_NOT0) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_NN_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \
    __read_only image2d_t A, \
    __read_only image2d_t B, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int width0, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0.0f; \
    Dtype8 blockAxB01 = 0.0f; \
    Dtype8 blockAxB02 = 0.0f; \
    Dtype8 blockAxB03 = 0.0f; \
    int2 coordA = (int2)( 0, group_y * TILE_M ); \
    int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); \
    do \
    { \
        int2 coordBTemp = coordB; \
        Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; \
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \
    } \
    while( coordB.y < width0 ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#endif

GEMM_NN(1, 0) // ALPHA == 1, BETA == 0
GEMM_NN(1, 1) // ALPHA == 1, BETA != 0
GEMM_NN(0, 0) // ALPHA != 1, BETA == 0
GEMM_NN(0, 1) // ALPHA != 1, BETA != 0

#undef TRANSPOSE_BLOCK_8
#undef MULTIPLY_BLOCKS_8x8
#undef GEMM_NN

// replicate the first row to column block.
#define TRANSPOSE_BLOCK_8(_vec, _col) \
        (Dtype8)( intel_sub_group_shuffle(_vec, _col + 0), \
                  intel_sub_group_shuffle(_vec, _col + 1), \
                  intel_sub_group_shuffle(_vec, _col + 2), \
                  intel_sub_group_shuffle(_vec, _col + 3), \
                  intel_sub_group_shuffle(_vec, _col + 4), \
                  intel_sub_group_shuffle(_vec, _col + 5), \
                  intel_sub_group_shuffle(_vec, _col + 6), \
                  intel_sub_group_shuffle(_vec, _col + 7) )

#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) \
        { \
            _result = mad( (Dtype8)(_blockB.s0), TRANSPOSE_BLOCK_8(_blockA.s0, _col), _result ); \
            _result = mad( (Dtype8)(_blockB.s1), TRANSPOSE_BLOCK_8(_blockA.s1, _col), _result ); \
            _result = mad( (Dtype8)(_blockB.s2), TRANSPOSE_BLOCK_8(_blockA.s2, _col), _result ); \
            _result = mad( (Dtype8)(_blockB.s3), TRANSPOSE_BLOCK_8(_blockA.s3, _col), _result ); \
            _result = mad( (Dtype8)(_blockB.s4), TRANSPOSE_BLOCK_8(_blockA.s4, _col), _result ); \
            _result = mad( (Dtype8)(_blockB.s5), TRANSPOSE_BLOCK_8(_blockA.s5, _col), _result ); \
            _result = mad( (Dtype8)(_blockB.s6), TRANSPOSE_BLOCK_8(_blockA.s6, _col), _result ); \
            _result = mad( (Dtype8)(_blockB.s7), TRANSPOSE_BLOCK_8(_blockA.s7, _col), _result ); \
        }

#if TYPE == TYPE_HALF
#define GEMM_TN(ALPHA1, BETA_NOT0) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \
    __read_only image2d_t A, \
    __read_only image2d_t B, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int width0, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0; \
    Dtype8 blockAxB01 = 0; \
    Dtype8 blockAxB02 = 0; \
    Dtype8 blockAxB03 = 0; \
    int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); \
    int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); \
    do \
    { \
        int2 coordBTemp = coordB; \
        Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT; \
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; \
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); \
    } \
    while( coordB.y < width0 ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#else
#define GEMM_TN(ALPHA1, BETA_NOT0) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_TN_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \
    __read_only image2d_t A, \
    __read_only image2d_t B, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int width0, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0.0f; \
    Dtype8 blockAxB01 = 0.0f; \
    Dtype8 blockAxB02 = 0.0f; \
    Dtype8 blockAxB03 = 0.0f; \
    int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); \
    int2 coordB = (int2)( ( group_x * TILE_N ) * SIZE_OF_ELEMENT, 0 ); \
    do \
    { \
        int2 coordBTemp = coordB; \
        Dtype8 blockB00 = as_Dtype8( SUBGROUP_BLOCK_READ8( B, coordBTemp ) ); coordB.y += TILE_K; \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \
        Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \
        Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; \
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00, 0 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00, 0 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00, 0 ); \
    } \
    while( coordB.y < width0 ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#endif

GEMM_TN(1, 0) // ALPHA == 1, BETA == 0
GEMM_TN(1, 1) // ALPHA == 1, BETA != 0
GEMM_TN(0, 0) // ALPHA != 1, BETA == 0
GEMM_TN(0, 1) // ALPHA != 1, BETA != 0

#undef MULTIPLY_BLOCKS_8x8
#undef TRANSPOSE_BLOCK_8
#undef GEMM_TN

// The same as GEMM_NN.
#define TRANSPOSE_BLOCK_8( _block, _col ) \
        (Dtype8)( intel_sub_group_shuffle( _block.s0, _col), \
                  intel_sub_group_shuffle( _block.s1, _col), \
                  intel_sub_group_shuffle( _block.s2, _col), \
                  intel_sub_group_shuffle( _block.s3, _col), \
                  intel_sub_group_shuffle( _block.s4, _col), \
                  intel_sub_group_shuffle( _block.s5, _col), \
                  intel_sub_group_shuffle( _block.s6, _col), \
                  intel_sub_group_shuffle( _block.s7, _col) )

#if TYPE == TYPE_HALF
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \
        { \
            const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \
            const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \
            const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \
            const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \
            const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \
            const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \
            const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \
            const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \
            const Dtype8 acol8 = TRANSPOSE_BLOCK_8( _blockA, 8 ); \
            const Dtype8 acol9 = TRANSPOSE_BLOCK_8( _blockA, 9 ); \
            const Dtype8 acola = TRANSPOSE_BLOCK_8( _blockA, 10 ); \
            const Dtype8 acolb = TRANSPOSE_BLOCK_8( _blockA, 11 ); \
            const Dtype8 acolc = TRANSPOSE_BLOCK_8( _blockA, 12 ); \
            const Dtype8 acold = TRANSPOSE_BLOCK_8( _blockA, 13 ); \
            const Dtype8 acole = TRANSPOSE_BLOCK_8( _blockA, 14 ); \
            const Dtype8 acolf = TRANSPOSE_BLOCK_8( _blockA, 15 ); \
            _result = mad( (Dtype8)_blockB.s0, acol0, _result ); \
            _result = mad( (Dtype8)_blockB.s1, acol1, _result ); \
            _result = mad( (Dtype8)_blockB.s2, acol2, _result ); \
            _result = mad( (Dtype8)_blockB.s3, acol3, _result ); \
            _result = mad( (Dtype8)_blockB.s4, acol4, _result ); \
            _result = mad( (Dtype8)_blockB.s5, acol5, _result ); \
            _result = mad( (Dtype8)_blockB.s6, acol6, _result ); \
            _result = mad( (Dtype8)_blockB.s7, acol7, _result ); \
            _result = mad( (Dtype8)_blockB.s8, acol8, _result ); \
            _result = mad( (Dtype8)_blockB.s9, acol9, _result ); \
            _result = mad( (Dtype8)_blockB.sa, acola, _result ); \
            _result = mad( (Dtype8)_blockB.sb, acolb, _result ); \
            _result = mad( (Dtype8)_blockB.sc, acolc, _result ); \
            _result = mad( (Dtype8)_blockB.sd, acold, _result ); \
            _result = mad( (Dtype8)_blockB.se, acole, _result ); \
            _result = mad( (Dtype8)_blockB.sf, acolf, _result ); \
        }
#else
#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB ) \
        { \
            const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA, 0 ); \
            const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA, 1 ); \
            const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA, 2 ); \
            const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA, 3 ); \
            const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA, 4 ); \
            const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA, 5 ); \
            const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA, 6 ); \
            const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA, 7 ); \
            _result = mad( (Dtype8)_blockB.s0, acol0, _result ); \
            _result = mad( (Dtype8)_blockB.s1, acol1, _result ); \
            _result = mad( (Dtype8)_blockB.s2, acol2, _result ); \
            _result = mad( (Dtype8)_blockB.s3, acol3, _result ); \
            _result = mad( (Dtype8)_blockB.s4, acol4, _result ); \
            _result = mad( (Dtype8)_blockB.s5, acol5, _result ); \
            _result = mad( (Dtype8)_blockB.s6, acol6, _result ); \
            _result = mad( (Dtype8)_blockB.s7, acol7, _result ); \
        }
#endif

#if TYPE == TYPE_HALF
#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \
    __read_only image2d_t A, \
    MATB_PARAMETER, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int padded_k, \
    int k, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0; \
    Dtype8 blockAxB01 = 0; \
    Dtype8 blockAxB02 = 0; \
    Dtype8 blockAxB03 = 0; \
    int2 coordA = (int2)( 0, group_y * TILE_M ); \
    int2 coordB = (int2)( 0, ( group_x * TILE_N )); \
    const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \
    do \
    { \
        Dtype16 blockB00; \
        BLOCKB_READ8(blockB00, B, coordB); \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT * 2; \
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \
    } \
    while( coordB.x < padded_k / VECSIZE ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#else
#define GEMM_NT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_NT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0,Dtype)( \
    __read_only image2d_t A, \
    MATB_PARAMETER, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int padded_k, \
    int k, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0.0f; \
    Dtype8 blockAxB01 = 0.0f; \
    Dtype8 blockAxB02 = 0.0f; \
    Dtype8 blockAxB03 = 0.0f; \
    int2 coordA = (int2)( 0, group_y * TILE_M ); \
    int2 coordB = (int2)( 0, ( group_x * TILE_N )); \
    const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \
    do \
    { \
        Dtype8 blockB00; \
        BLOCKB_READ8(blockB00, B, coordB); \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.y += 8; \
        Dtype8 blockA03 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.x += TILE_K * SIZE_OF_ELEMENT; \
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02, blockB00 ); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03, blockB00 ); \
    } \
    while( coordB.x < padded_k / VECSIZE ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#endif

// B stored as an image of 4-wide texels: one vload-like read per texel.
#if TYPE == TYPE_HALF
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        _blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s89ab = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.scdef = READ_IMAGE(_B, _coordBTemp); _coordB.x += 4;
#else
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        _blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordB.x += 2;
#endif

#define MATB_PARAMETER __read_only image2d_t B

GEMM_NT(1, 0, VEC4, 4) // ALPHA == 1, BETA == 0
GEMM_NT(1, 1, VEC4, 4) // ALPHA == 1, BETA != 0
GEMM_NT(0, 0, VEC4, 4) // ALPHA != 1, BETA == 0
GEMM_NT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0
#undef BLOCKB_READ8
#undef MATB_PARAMETER

// B stored as a plain buffer: a single contiguous vload8 per row strip.
#if TYPE == TYPE_HALF
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); \
        _blockb = as_Dtype16(as_ushort16(vload8(0, B_read))); \
        _coordB.x += TILE_K * 2;
#else
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        const __global Dtype *B_read = (__global Dtype *)(_B + (_coordBTemp.y * ldb) + _coordBTemp.x + offB); \
        _blockb = vload8(0, B_read); \
        _coordB.x += TILE_K;
#endif

#define MATB_PARAMETER __global Dtype *B, int offB, int ldb

GEMM_NT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0
GEMM_NT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0
GEMM_NT(0, 0, BUFFER, 1) // ALPHA != 1, BETA == 0
GEMM_NT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0
#undef BLOCKB_READ8
#undef MATB_PARAMETER

// B stored as an image of 1-wide texels: one scalar element per read.
#if TYPE == TYPE_HALF
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        Dtype4 temp; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s0 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s1 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s2 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s3 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s4 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s5 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s6 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s7 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s8 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s9 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.sa = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.sb = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.sc = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.sd = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.se = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.sf = temp.s0; \
        _coordB.x += 16;
#else
#define BLOCKB_READ8(_blockb, _B, _coordB) \
        int2 _coordBTemp = _coordB; \
        _coordBTemp.y += get_local_id(0); \
        Dtype4 temp; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s0 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s1 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s2 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s3 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s4 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s5 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s6 = temp.s0; \
        temp = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \
        _blockb.s7 = temp.s0; \
        _coordB.x += 8;
#endif

#define MATB_PARAMETER __read_only image2d_t B

GEMM_NT(1, 0, SCALAR, 1) // ALPHA == 1, BETA == 0
GEMM_NT(1, 1, SCALAR, 1) // ALPHA == 1, BETA != 0
GEMM_NT(0, 0, SCALAR, 1) // ALPHA != 1, BETA == 0
GEMM_NT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0
#undef BLOCKB_READ8
#undef MATB_PARAMETER

#undef MULTIPLY_BLOCKS_8x8
#undef TRANSPOSE_BLOCK_8
#undef GEMM_NT

// The same as GEMM_TN.
#define TRANSPOSE_BLOCK_8(_vec, _col) \
        (Dtype8)( intel_sub_group_shuffle(_vec, _col + 0), \
                  intel_sub_group_shuffle(_vec, _col + 1), \
                  intel_sub_group_shuffle(_vec, _col + 2), \
                  intel_sub_group_shuffle(_vec, _col + 3), \
                  intel_sub_group_shuffle(_vec, _col + 4), \
                  intel_sub_group_shuffle(_vec, _col + 5), \
                  intel_sub_group_shuffle(_vec, _col + 6), \
                  intel_sub_group_shuffle(_vec, _col + 7) );

#define MULTIPLY_BLOCKS_8x8( _result, _blockA, _blockB, _col ) \
        { \
            const Dtype8 acol0 = TRANSPOSE_BLOCK_8( _blockA.s0, _col ); \
            const Dtype8 acol1 = TRANSPOSE_BLOCK_8( _blockA.s1, _col ); \
            const Dtype8 acol2 = TRANSPOSE_BLOCK_8( _blockA.s2, _col ); \
            const Dtype8 acol3 = TRANSPOSE_BLOCK_8( _blockA.s3, _col ); \
            const Dtype8 acol4 = TRANSPOSE_BLOCK_8( _blockA.s4, _col ); \
            const Dtype8 acol5 = TRANSPOSE_BLOCK_8( _blockA.s5, _col ); \
            const Dtype8 acol6 = TRANSPOSE_BLOCK_8( _blockA.s6, _col ); \
            const Dtype8 acol7 = TRANSPOSE_BLOCK_8( _blockA.s7, _col ); \
            _result = mad( (Dtype8)_blockB.s0, acol0, _result ); \
            _result = mad( (Dtype8)_blockB.s1, acol1, _result ); \
            _result = mad( (Dtype8)_blockB.s2, acol2, _result ); \
            _result = mad( (Dtype8)_blockB.s3, acol3, _result ); \
            _result = mad( (Dtype8)_blockB.s4, acol4, _result ); \
            _result = mad( (Dtype8)_blockB.s5, acol5, _result ); \
            _result = mad( (Dtype8)_blockB.s6, acol6, _result ); \
            _result = mad( (Dtype8)_blockB.s7, acol7, _result ); \
        }

#if TYPE == TYPE_HALF
#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \
    __read_only image2d_t A, \
    MATB_PARAMETER, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int padded_k, \
    int k, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0; \
    Dtype8 blockAxB01 = 0; \
    Dtype8 blockAxB02 = 0; \
    Dtype8 blockAxB03 = 0; \
    int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); \
    int2 coordB = (int2)( 0, ( group_x * TILE_N )); \
    const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \
    do \
    { \
        Dtype8 blockB00; \
        BLOCKB_READ8(blockB00, B, coordB); \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 16 * SIZE_OF_ELEMENT; \
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; \
        MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00, blockB00, 0); \
        MULTIPLY_BLOCKS_8x8( blockAxB01, blockA00, blockB00, 8); \
        MULTIPLY_BLOCKS_8x8( blockAxB02, blockA01, blockB00, 0); \
        MULTIPLY_BLOCKS_8x8( blockAxB03, blockA01, blockB00, 8); \
    } \
    while( coordB.x < padded_k / VECSIZE ); \
    GEMM_OUTPUT(ALPHA1, BETA_NOT0); \
}
#else
#define GEMM_TT(ALPHA1, BETA_NOT0, VECSCALAR, VECSIZE) \
__attribute__((intel_reqd_sub_group_size(SIMD_SIZE_GEMM))) \
__attribute__((reqd_work_group_size(SIMD_SIZE_GEMM, 1, 1))) \
__kernel void TEMPLATE(gemm_32_1_TT_ ##VECSCALAR ##_ ##ALPHA1 ##_ ##BETA_NOT0, Dtype)( \
    __read_only image2d_t A, \
    MATB_PARAMETER, \
    MATC_PARAMETER, \
    KERNEL_ARG_DTYPE alpha_in, \
    KERNEL_ARG_DTYPE beta_in, \
    int padded_k, \
    int k, \
    int isFirstColBlock) \
{ \
    const Dtype alpha = (Dtype)alpha_in; \
    const Dtype beta = (Dtype)beta_in; \
    const int group_x = get_group_id(0); \
    const int group_y = get_group_id(1); \
    Dtype8 blockAxB00 = 0.0f; \
    Dtype8 blockAxB01 = 0.0f; \
    Dtype8 blockAxB02 = 0.0f; \
    Dtype8 blockAxB03 = 0.0f; \
    int2 coordA = (int2)( group_y * TILE_M * SIZE_OF_ELEMENT, 0 ); \
    int2 coordB = (int2)( 0, ( group_x * TILE_N )); \
    const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; \
    do \
    { \
        Dtype8 blockB00; \
        BLOCKB_READ8(blockB00, B, coordB); \
        int2 coordATemp = coordA; \
        Dtype8 blockA00 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \
        Dtype8 blockA01 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \
        Dtype8 blockA02 = as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordATemp.x += 8 * SIZE_OF_ELEMENT; \
        Dtype8 blockA03 = 
as_Dtype8( SUBGROUP_BLOCK_READ8( A, coordATemp ) ); coordA.y += TILE_K; \854MULTIPLY_BLOCKS_8x8( blockAxB00, blockA00 , blockB00, 0 ); \855MULTIPLY_BLOCKS_8x8( blockAxB01, blockA01 , blockB00, 0 ); \856MULTIPLY_BLOCKS_8x8( blockAxB02, blockA02 , blockB00, 0 ); \857MULTIPLY_BLOCKS_8x8( blockAxB03, blockA03 , blockB00, 0 ); \858} \859while( coordB.x < padded_k / VECSIZE ); \860GEMM_OUTPUT(ALPHA1, BETA_NOT0);\861}862#endif863864#define BLOCKB_READ8(_blockb, _B, _coordB) \865int2 _coordBTemp = _coordB; \866_coordBTemp.y += get_local_id(0); \867_blockb.s0123 = READ_IMAGE(_B, _coordBTemp); _coordBTemp.x += 1; \868_blockb.s4567 = READ_IMAGE(_B, _coordBTemp); _coordB.x += 2;869870#define MATB_PARAMETER __read_only image2d_t B871872GEMM_TT(1, 0, VEC4, 4) // ALPHA == 1, BETA == 0873GEMM_TT(1, 1, VEC4, 4) // ALPHA == 1, BETA != 0874GEMM_TT(0, 0, VEC4, 4) // ALPHA != 1, BETA == 0875GEMM_TT(0, 1, VEC4, 4) // ALPHA != 1, BETA != 0876#undef BLOCKB_READ8877#undef MATB_PARAMETER878879#if TYPE == TYPE_HALF880#define BLOCKB_READ8(_blockb, _B, _coordB) \881int2 _coordBTemp = _coordB; \882_coordBTemp.y += get_local_id(0); \883const __global float *B_read = (__global float *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); \884_blockb = as_Dtype8(as_ushort8(vload4(0, B_read))); \885_coordB.x += TILE_K;886#else887#define BLOCKB_READ8(_blockb, _B, _coordB) \888int2 _coordBTemp = _coordB; \889_coordBTemp.y += get_local_id(0); \890const __global Dtype *B_read = (__global Dtype *)(_B + (_coordBTemp.y * k) + _coordBTemp.x + offB); \891_blockb = vload8(0, B_read); \892_coordB.x += TILE_K;893#endif894895#define MATB_PARAMETER __global Dtype *B, int offB, int ldb896897GEMM_TT(1, 0, BUFFER, 1) // ALPHA == 1, BETA == 0898GEMM_TT(1, 1, BUFFER, 1) // ALPHA == 1, BETA != 0899GEMM_TT(0, 0, BUFFER, 1) // ALPHA != 1, BETA == 0900GEMM_TT(0, 1, BUFFER, 1) // ALPHA != 1, BETA != 0901#undef BLOCKB_READ8902#undef MATB_PARAMETER903904#define BLOCKB_READ8(_blockb, _B, _coordB) \905int2 _coordBTemp = _coordB; 
\906_coordBTemp.y += get_local_id(0); \907Dtype4 temp; \908temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \909_blockb.s0 = temp.s0; \910temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \911_blockb.s1 = temp.s0; \912temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \913_blockb.s2 = temp.s0; \914temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \915_blockb.s3 = temp.s0; \916temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \917_blockb.s4 = temp.s0; \918temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \919_blockb.s5 = temp.s0; \920temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \921_blockb.s6 = temp.s0; \922temp = READ_IMAGE(B, _coordBTemp); _coordBTemp.x += 1; \923_blockb.s7 = temp.s0; \924_coordB.x += 8;925926#define MATB_PARAMETER __read_only image2d_t B927928GEMM_TT(1, 0, SCALAR, 1) // ALPHA == 1, BETA == 0929GEMM_TT(1, 1, SCALAR, 1) // ALPHA == 1, BETA != 0930GEMM_TT(0, 0, SCALAR, 1) // ALPHA != 1, BETA == 0931GEMM_TT(0, 1, SCALAR, 1) // ALPHA != 1, BETA != 0932#undef BLOCKB_READ8933#undef MATB_PARAMETER934935#undef MULTIPLY_BLOCKS_8x8936#undef TRANSPOSE_BLOCK_8937#undef GEMM_TT938939#undef TILE_M940#undef TILE_K941#undef TILE_N942#undef SUBGROUP_BLOCK_READ8943#undef READ_IMAGE944#undef SIZE_OF_ELEMENT945946__kernel void TEMPLATE(gemm_buffer_copy_image_transpose, Dtype)(947__global Dtype* A,948__write_only image2d_t ImA,949int offA,950int width,951int height,952int ldA)953{954const int gidx = get_global_id(0);955const int gidy = get_global_id(1);956int2 coord_dst = (int2)(gidx, gidy);957__global Dtype* A_off = A + offA;958Dtype srcA = A_off[gidy * ldA + gidx];959#if TYPE == TYPE_HALF960write_imageh(ImA, coord_dst, (Dtype4)srcA);961#else962write_imagef(ImA, coord_dst, (Dtype4)srcA);963#endif964}965966__kernel void TEMPLATE(gemm_buffer_copy_image_no_transpose, Dtype)(967__global Dtype* A,968__write_only image2d_t ImA,969int offA,970int width,971int height,972int ldA)973{974const int gidx = get_global_id(0);975const int gidy 
= get_global_id(1);976int2 coord_dst = (int2)(gidx, gidy);977#if TYPE == TYPE_HALF978if (gidx >= width || gidy >= height) {979write_imageh(ImA, coord_dst, 0);980return;981}982__global Dtype* A_off = A + offA;983write_imageh(ImA, coord_dst, A_off[gidy * ldA + gidx]);984#else985if (gidx >= width || gidy >= height) {986write_imageui(ImA, coord_dst, (uint4)0);987return;988}989__global Dtype* A_off = A + offA;990uint4 srcA = convert_uint4(as_uchar4(A_off[gidy * ldA + gidx]));991write_imageui(ImA, coord_dst, srcA);992#endif993}994995996