GitHub Repository: official-stockfish/stockfish
Path: blob/master/src/nnue/simd.h
/*
  Stockfish, a UCI chess playing engine derived from Glaurung 2.1
  Copyright (C) 2004-2026 The Stockfish developers (see AUTHORS file)

  Stockfish is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  Stockfish is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef NNUE_SIMD_H_INCLUDED
#define NNUE_SIMD_H_INCLUDED

#if defined(USE_AVX2)
    #include <immintrin.h>

#elif defined(USE_SSE41)
    #include <smmintrin.h>

#elif defined(USE_SSSE3)
    #include <tmmintrin.h>

#elif defined(USE_SSE2)
    #include <emmintrin.h>

#elif defined(USE_NEON)
    #include <arm_neon.h>
#endif

#include "../types.h"
#include "nnue_common.h"

namespace Stockfish::Eval::NNUE::SIMD {

// If vector instructions are enabled, we update and refresh the
// accumulator tile by tile such that each tile fits in the CPU's
// vector registers.
#define VECTOR
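// For instance, with 256-bit AVX2 registers each vec_t holds 16 int16
// accumulator lanes, so one tile of NumRegs registers covers NumRegs * 16
// entries; the exact tile sizes are derived by the SIMDTiling helper at the
// bottom of this header.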

#ifdef USE_AVX512
using vec_t = __m512i;
using vec_i8_t = __m256i;
using vec128_t = __m128i;
using psqt_vec_t = __m256i;
using vec_uint_t = __m512i;
    #define vec_load(a) _mm512_load_si512(a)
    #define vec_store(a, b) _mm512_store_si512(a, b)
    #define vec_convert_8_16(a) _mm512_cvtepi8_epi16(a)
    #define vec_add_16(a, b) _mm512_add_epi16(a, b)
    #define vec_sub_16(a, b) _mm512_sub_epi16(a, b)
    #define vec_mulhi_16(a, b) _mm512_mulhi_epi16(a, b)
    #define vec_zero() _mm512_setzero_epi32()
    #define vec_set_16(a) _mm512_set1_epi16(a)
    #define vec_max_16(a, b) _mm512_max_epi16(a, b)
    #define vec_min_16(a, b) _mm512_min_epi16(a, b)
    #define vec_slli_16(a, b) _mm512_slli_epi16(a, b)
    // Inverse permuted at load time
    #define vec_packus_16(a, b) _mm512_packus_epi16(a, b)
    #define vec_load_psqt(a) _mm256_load_si256(a)
    #define vec_store_psqt(a, b) _mm256_store_si256(a, b)
    #define vec_add_psqt_32(a, b) _mm256_add_epi32(a, b)
    #define vec_sub_psqt_32(a, b) _mm256_sub_epi32(a, b)
    #define vec_zero_psqt() _mm256_setzero_si256()

    #ifdef USE_SSSE3
        #define vec_nnz(a) _mm512_cmpgt_epi32_mask(a, _mm512_setzero_si512())
    #endif

    #define vec128_zero _mm_setzero_si128()
    #define vec128_set_16(a) _mm_set1_epi16(a)
    #define vec128_load(a) _mm_load_si128(a)
    #define vec128_storeu(a, b) _mm_storeu_si128(a, b)
    #define vec128_add(a, b) _mm_add_epi16(a, b)
    #define NumRegistersSIMD 16
    #define MaxChunkSize 64

#elif USE_AVX2
using vec_t = __m256i;
using vec_i8_t = __m128i;
using vec128_t = __m128i;
using psqt_vec_t = __m256i;
using vec_uint_t = __m256i;
    #define vec_load(a) _mm256_load_si256(a)
    #define vec_store(a, b) _mm256_store_si256(a, b)
    #define vec_convert_8_16(a) _mm256_cvtepi8_epi16(a)
    #define vec_add_16(a, b) _mm256_add_epi16(a, b)
    #define vec_sub_16(a, b) _mm256_sub_epi16(a, b)
    #define vec_mulhi_16(a, b) _mm256_mulhi_epi16(a, b)
    #define vec_zero() _mm256_setzero_si256()
    #define vec_set_16(a) _mm256_set1_epi16(a)
    #define vec_max_16(a, b) _mm256_max_epi16(a, b)
    #define vec_min_16(a, b) _mm256_min_epi16(a, b)
    #define vec_slli_16(a, b) _mm256_slli_epi16(a, b)
    // Inverse permuted at load time
    #define vec_packus_16(a, b) _mm256_packus_epi16(a, b)
    #define vec_load_psqt(a) _mm256_load_si256(a)
    #define vec_store_psqt(a, b) _mm256_store_si256(a, b)
    #define vec_add_psqt_32(a, b) _mm256_add_epi32(a, b)
    #define vec_sub_psqt_32(a, b) _mm256_sub_epi32(a, b)
    #define vec_zero_psqt() _mm256_setzero_si256()

    #ifdef USE_SSSE3
        #if defined(USE_VNNI) && !defined(USE_AVXVNNI)
            #define vec_nnz(a) _mm256_cmpgt_epi32_mask(a, _mm256_setzero_si256())
        #else
            #define vec_nnz(a) \
                _mm256_movemask_ps( \
                    _mm256_castsi256_ps(_mm256_cmpgt_epi32(a, _mm256_setzero_si256())))
        #endif
    #endif
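    // vec_nnz(a) returns a bitmask with one bit per 32-bit lane of `a`, set when
    // that lane compares greater than zero. When USE_VNNI is defined without
    // USE_AVXVNNI, the comparison writes a mask register directly; otherwise the
    // mask is extracted with movemask on the float-cast comparison result.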

    #define vec128_zero _mm_setzero_si128()
    #define vec128_set_16(a) _mm_set1_epi16(a)
    #define vec128_load(a) _mm_load_si128(a)
    #define vec128_storeu(a, b) _mm_storeu_si128(a, b)
    #define vec128_add(a, b) _mm_add_epi16(a, b)

    #define NumRegistersSIMD 12
    #define MaxChunkSize 32

#elif USE_SSE2
using vec_t = __m128i;
using vec_i8_t = std::uint64_t;  // for the correct size -- will be loaded into an xmm reg
using vec128_t = __m128i;
using psqt_vec_t = __m128i;
using vec_uint_t = __m128i;
    #define vec_load(a) (*(a))
    #define vec_store(a, b) *(a) = (b)
    #define vec_add_16(a, b) _mm_add_epi16(a, b)
    #define vec_sub_16(a, b) _mm_sub_epi16(a, b)
    #define vec_mulhi_16(a, b) _mm_mulhi_epi16(a, b)
    #define vec_zero() _mm_setzero_si128()
    #define vec_set_16(a) _mm_set1_epi16(a)
    #define vec_max_16(a, b) _mm_max_epi16(a, b)
    #define vec_min_16(a, b) _mm_min_epi16(a, b)
    #define vec_slli_16(a, b) _mm_slli_epi16(a, b)
    #define vec_packus_16(a, b) _mm_packus_epi16(a, b)
    #define vec_load_psqt(a) (*(a))
    #define vec_store_psqt(a, b) *(a) = (b)
    #define vec_add_psqt_32(a, b) _mm_add_epi32(a, b)
    #define vec_sub_psqt_32(a, b) _mm_sub_epi32(a, b)
    #define vec_zero_psqt() _mm_setzero_si128()

    #ifdef USE_SSSE3
        #define vec_nnz(a) \
            _mm_movemask_ps(_mm_castsi128_ps(_mm_cmpgt_epi32(a, _mm_setzero_si128())))
    #endif

    #ifdef __i386__
inline __m128i _mm_cvtsi64_si128(int64_t val) {
    return _mm_loadl_epi64(reinterpret_cast<const __m128i*>(&val));
}
    #endif

    #ifdef USE_SSE41
        #define vec_convert_8_16(a) _mm_cvtepi8_epi16(_mm_cvtsi64_si128(static_cast<int64_t>(a)))
    #else
// Credit: Yoshie2000
inline __m128i vec_convert_8_16(uint64_t x) {
    __m128i v8 = _mm_cvtsi64_si128(static_cast<int64_t>(x));
    __m128i sign = _mm_cmpgt_epi8(_mm_setzero_si128(), v8);
    return _mm_unpacklo_epi8(v8, sign);
}
    #endif
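// In the non-SSE4.1 fallback above, _mm_cmpgt_epi8(zero, v8) yields 0xFF for
// every negative source byte and 0x00 otherwise; interleaving the low eight
// bytes with that mask sign-extends each int8 to int16, matching what
// _mm_cvtepi8_epi16 does in one instruction on SSE4.1.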

    #define vec128_zero _mm_setzero_si128()
    #define vec128_set_16(a) _mm_set1_epi16(a)
    #define vec128_load(a) _mm_load_si128(a)
    #define vec128_storeu(a, b) _mm_storeu_si128(a, b)
    #define vec128_add(a, b) _mm_add_epi16(a, b)

    #define NumRegistersSIMD (Is64Bit ? 12 : 6)
    #define MaxChunkSize 16

#elif USE_NEON
using vec_i8x8_t __attribute__((may_alias)) = int8x8_t;
using vec_i16x8_t __attribute__((may_alias)) = int16x8_t;
using vec_i8x16_t __attribute__((may_alias)) = int8x16_t;
using vec_u16x8_t __attribute__((may_alias)) = uint16x8_t;
using vec_i32x4_t __attribute__((may_alias)) = int32x4_t;

using vec_t __attribute__((may_alias)) = int16x8_t;
using vec_i8_t __attribute__((may_alias)) = int8x16_t;
using psqt_vec_t __attribute__((may_alias)) = int32x4_t;
using vec128_t __attribute__((may_alias)) = uint16x8_t;
using vec_uint_t __attribute__((may_alias)) = uint32x4_t;
    #define vec_load(a) (*(a))
    #define vec_store(a, b) *(a) = (b)
    #define vec_add_16(a, b) vaddq_s16(a, b)
    #define vec_sub_16(a, b) vsubq_s16(a, b)
    #define vec_mulhi_16(a, b) vqdmulhq_s16(a, b)
    #define vec_zero() vec_t{0}
    #define vec_set_16(a) vdupq_n_s16(a)
    #define vec_max_16(a, b) vmaxq_s16(a, b)
    #define vec_min_16(a, b) vminq_s16(a, b)
    #define vec_slli_16(a, b) vshlq_s16(a, vec_set_16(b))
    #define vec_packus_16(a, b) reinterpret_cast<vec_t>(vcombine_u8(vqmovun_s16(a), vqmovun_s16(b)))
    #define vec_load_psqt(a) (*(a))
    #define vec_store_psqt(a, b) *(a) = (b)
    #define vec_add_psqt_32(a, b) vaddq_s32(a, b)
    #define vec_sub_psqt_32(a, b) vsubq_s32(a, b)
    #define vec_zero_psqt() psqt_vec_t{0}

static constexpr std::uint32_t Mask[4] = {1, 2, 4, 8};
    #define vec_nnz(a) vaddvq_u32(vandq_u32(vtstq_u32(a, a), vld1q_u32(Mask)))
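    // vec_nnz above: vtstq_u32 sets a lane to all ones when it is non-zero,
    // AND-ing with {1, 2, 4, 8} keeps one distinct bit per lane, and the
    // horizontal add (vaddvq_u32) collapses that into the same 4-bit
    // "non-zero lanes" mask that the movemask-based x86 variants produce.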
    #define vec128_zero vdupq_n_u16(0)
    #define vec128_set_16(a) vdupq_n_u16(a)
    #define vec128_load(a) vld1q_u16(reinterpret_cast<const std::uint16_t*>(a))
    #define vec128_storeu(a, b) vst1q_u16(reinterpret_cast<std::uint16_t*>(a), b)
    #define vec128_add(a, b) vaddq_u16(a, b)

    #define NumRegistersSIMD 16
    #define MaxChunkSize 16

    #ifndef __aarch64__
// Single instruction doesn't exist on 32-bit ARM
inline int16x8_t vmovl_high_s8(int8x16_t val) { return vmovl_s8(vget_high_s8(val)); }
    #endif

#else
    #undef VECTOR

#endif

struct Vec16Wrapper {
#ifdef VECTOR
    using type = vec_t;
    static type add(const type& lhs, const type& rhs) { return vec_add_16(lhs, rhs); }
    static type sub(const type& lhs, const type& rhs) { return vec_sub_16(lhs, rhs); }
#else
    using type = BiasType;
    static type add(const type& lhs, const type& rhs) { return lhs + rhs; }
    static type sub(const type& lhs, const type& rhs) { return lhs - rhs; }
#endif
};

struct Vec32Wrapper {
#ifdef VECTOR
    using type = psqt_vec_t;
    static type add(const type& lhs, const type& rhs) { return vec_add_psqt_32(lhs, rhs); }
    static type sub(const type& lhs, const type& rhs) { return vec_sub_psqt_32(lhs, rhs); }
#else
    using type = PSQTWeightType;
    static type add(const type& lhs, const type& rhs) { return lhs + rhs; }
    static type sub(const type& lhs, const type& rhs) { return lhs - rhs; }
#endif
};
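// These wrappers expose the same add/sub interface whether or not a SIMD
// target is available: with VECTOR they operate on whole vectors, otherwise
// on single BiasType / PSQTWeightType values, so the fused<> helpers below
// compile unchanged in both cases.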

enum UpdateOperation {
    Add,
    Sub
};

template<typename VecWrapper,
         UpdateOperation... ops,
         std::enable_if_t<sizeof...(ops) == 0, bool> = true>
typename VecWrapper::type fused(const typename VecWrapper::type& in) {
    return in;
}

template<typename VecWrapper,
         UpdateOperation update_op,
         UpdateOperation... ops,
         typename T,
         typename... Ts,
         std::enable_if_t<is_all_same_v<typename VecWrapper::type, T, Ts...>, bool> = true,
         std::enable_if_t<sizeof...(ops) == sizeof...(Ts), bool> = true>
typename VecWrapper::type
fused(const typename VecWrapper::type& in, const T& operand, const Ts&... operands) {
    switch (update_op)
    {
    case Add :
        return fused<VecWrapper, ops...>(VecWrapper::add(in, operand), operands...);
    case Sub :
        return fused<VecWrapper, ops...>(VecWrapper::sub(in, operand), operands...);
    default :
        static_assert(update_op == Add || update_op == Sub,
                      "Only Add and Sub are currently supported.");
        return typename VecWrapper::type();
    }
}
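// For example, fused<Vec16Wrapper, Add, Add, Sub>(in, a, b, c) unfolds to
// Vec16Wrapper::sub(Vec16Wrapper::add(Vec16Wrapper::add(in, a), b), c),
// i.e. in + a + b - c evaluated lane by lane; the operation pack and the
// operand pack are consumed in lockstep, one step per recursive call.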

#if defined(USE_AVX512)

[[maybe_unused]] static int m512_hadd(__m512i sum, int bias) {
    return _mm512_reduce_add_epi32(sum) + bias;
}

[[maybe_unused]] static void m512_add_dpbusd_epi32(__m512i& acc, __m512i a, __m512i b) {

    #if defined(USE_VNNI)
    acc = _mm512_dpbusd_epi32(acc, a, b);
    #else
    __m512i product0 = _mm512_maddubs_epi16(a, b);
    product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1));
    acc = _mm512_add_epi32(acc, product0);
    #endif
}
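// Without VNNI the dot product is emulated: _mm512_maddubs_epi16 multiplies
// unsigned 8-bit lanes of `a` by signed 8-bit lanes of `b` and sums adjacent
// pairs into 16-bit lanes, then _mm512_madd_epi16 with a vector of ones adds
// adjacent 16-bit pairs into 32-bit lanes, which are accumulated into acc.
// _mm512_dpbusd_epi32 performs the same u8*i8 accumulate in one instruction
// (without the intermediate 16-bit saturation of maddubs).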

#endif

#if defined(USE_AVX2)

[[maybe_unused]] static int m256_hadd(__m256i sum, int bias) {
    __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
    sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC));
    sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB));
    return _mm_cvtsi128_si32(sum128) + bias;
}
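// m256_hadd reduces all eight 32-bit lanes of `sum` to a single integer: the
// upper 128-bit half is added onto the lower one, then two shuffled adds fold
// the remaining four lanes down to lane 0, and the bias is added to the total.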

[[maybe_unused]] static void m256_add_dpbusd_epi32(__m256i& acc, __m256i a, __m256i b) {

    #if defined(USE_VNNI)
    acc = _mm256_dpbusd_epi32(acc, a, b);
    #else
    __m256i product0 = _mm256_maddubs_epi16(a, b);
    product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1));
    acc = _mm256_add_epi32(acc, product0);
    #endif
}

#endif

#if defined(USE_SSSE3)

[[maybe_unused]] static int m128_hadd(__m128i sum, int bias) {
    sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0x4E));  //_MM_PERM_BADC
    sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, 0xB1));  //_MM_PERM_CDAB
    return _mm_cvtsi128_si32(sum) + bias;
}

[[maybe_unused]] static void m128_add_dpbusd_epi32(__m128i& acc, __m128i a, __m128i b) {

    __m128i product0 = _mm_maddubs_epi16(a, b);
    product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
    acc = _mm_add_epi32(acc, product0);
}

#endif

#if defined(USE_NEON_DOTPROD)

[[maybe_unused]] static void
dotprod_m128_add_dpbusd_epi32(int32x4_t& acc, int8x16_t a, int8x16_t b) {

    acc = vdotq_s32(acc, a, b);
}
#endif

#if defined(USE_NEON)

[[maybe_unused]] static int neon_m128_reduce_add_epi32(int32x4_t s) {
    #if USE_NEON >= 8
    return vaddvq_s32(s);
    #else
    return s[0] + s[1] + s[2] + s[3];
    #endif
}

[[maybe_unused]] static int neon_m128_hadd(int32x4_t sum, int bias) {
    return neon_m128_reduce_add_epi32(sum) + bias;
}

#endif

#if USE_NEON >= 8
[[maybe_unused]] static void neon_m128_add_dpbusd_epi32(int32x4_t& acc, int8x16_t a, int8x16_t b) {

    int16x8_t product0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
    int16x8_t product1 = vmull_high_s8(a, b);
    int16x8_t sum = vpaddq_s16(product0, product1);
    acc = vpadalq_s16(acc, sum);
}
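// vmull_s8 / vmull_high_s8 widen-multiply the low and high eight byte pairs
// to int16, vpaddq_s16 adds adjacent 16-bit products, and vpadalq_s16
// accumulates those pairwise into the 32-bit lanes of acc -- the NEON
// counterpart of the x86 *_add_dpbusd_epi32 helpers (here with a
// signed-by-signed multiply).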
#endif


// Compute optimal SIMD register count for feature transformer accumulation.
template<IndexType TransformedFeatureWidth, IndexType HalfDimensions, IndexType PSQTBuckets>
class SIMDTiling {
#ifdef VECTOR
    // We use __m* types as template arguments, which causes GCC to emit warnings
    // about losing some attribute information. This is irrelevant to us as we
    // only take their size, so the following pragmas are harmless.
    #if defined(__GNUC__)
        #pragma GCC diagnostic push
        #pragma GCC diagnostic ignored "-Wignored-attributes"
    #endif

    template<typename SIMDRegisterType, typename LaneType, int NumLanes, int MaxRegisters>
    static constexpr int BestRegisterCount() {
        constexpr std::size_t RegisterSize = sizeof(SIMDRegisterType);
        constexpr std::size_t LaneSize = sizeof(LaneType);

        static_assert(RegisterSize >= LaneSize);
        static_assert(MaxRegisters <= NumRegistersSIMD);
        static_assert(MaxRegisters > 0);
        static_assert(NumRegistersSIMD > 0);
        static_assert(RegisterSize % LaneSize == 0);
        static_assert((NumLanes * LaneSize) % RegisterSize == 0);

        const int ideal = (NumLanes * LaneSize) / RegisterSize;
        if (ideal <= MaxRegisters)
            return ideal;

        // Look for the largest divisor of the ideal register count that is no larger than MaxRegisters
        for (int divisor = MaxRegisters; divisor > 1; --divisor)
            if (ideal % divisor == 0)
                return divisor;

        return 1;
    }
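    // Illustrative example (the width figure is not taken from this file): with
    // 32-byte AVX2 registers and 16-bit accumulator lanes, a 1024-wide half
    // dimension would ideally use 1024 * 2 / 32 = 64 registers; since that
    // exceeds MaxRegisters, the loop above falls back to the largest divisor
    // of 64 that does not exceed MaxRegisters.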

    #if defined(__GNUC__)
        #pragma GCC diagnostic pop
    #endif

   public:
    static constexpr int NumRegs =
      BestRegisterCount<vec_t, WeightType, TransformedFeatureWidth, NumRegistersSIMD>();
    static constexpr int NumPsqtRegs =
      BestRegisterCount<psqt_vec_t, PSQTWeightType, PSQTBuckets, NumRegistersSIMD>();

    static constexpr IndexType TileHeight = NumRegs * sizeof(vec_t) / 2;
    static constexpr IndexType PsqtTileHeight = NumPsqtRegs * sizeof(psqt_vec_t) / 4;

    static_assert(HalfDimensions % TileHeight == 0, "TileHeight must divide HalfDimensions");
    static_assert(PSQTBuckets % PsqtTileHeight == 0, "PsqtTileHeight must divide PSQTBuckets");
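    // TileHeight is the number of 16-bit accumulator entries covered by one tile
    // (NumRegs registers times sizeof(vec_t) / 2 lanes each); PsqtTileHeight is
    // the analogous count of 32-bit PSQT entries. The asserts guarantee the
    // accumulator splits into a whole number of such tiles.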
#endif
};
}

#endif