Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
official-stockfish
GitHub Repository: official-stockfish/Stockfish
Path: blob/master/src/nnue/simd.h
375 views
/*
  Stockfish, a UCI chess playing engine derived from Glaurung 2.1
  Copyright (C) 2004-2025 The Stockfish developers (see AUTHORS file)

  Stockfish is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  Stockfish is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
#ifndef NNUE_SIMD_H_INCLUDED
#define NNUE_SIMD_H_INCLUDED

// Pull in the intrinsics header for the highest enabled x86 ISA
// (immintrin.h covers everything up to AVX-512), or NEON on ARM.
#if defined(USE_AVX2)
#include <immintrin.h>

#elif defined(USE_SSE41)
#include <smmintrin.h>

#elif defined(USE_SSSE3)
#include <tmmintrin.h>

#elif defined(USE_SSE2)
#include <emmintrin.h>

#elif defined(USE_NEON)
#include <arm_neon.h>
#endif

#include "../types.h"
#include "nnue_common.h"

namespace Stockfish::Eval::NNUE::SIMD {

// If vector instructions are enabled, we update and refresh the
// accumulator tile by tile such that each tile fits in the CPU's
// vector registers.
#define VECTOR

// Each ISA section below defines the same macro vocabulary:
//   vec_t / psqt_vec_t      - register type for 16-bit accumulator lanes
//                             and 32-bit PSQT lanes respectively
//   vec_*_16 / vec_*_psqt_32 - elementwise ops on those types
//   vec_nnz(a)               - bitmask of 32-bit lanes of `a` that are > 0
//   vec128_*                 - fixed 128-bit helper ops
//   NumRegistersSIMD         - architectural vector register count
//   MaxChunkSize             - bytes per vec_t register
#ifdef USE_AVX512
using vec_t = __m512i;
using vec128_t = __m128i;
using psqt_vec_t = __m256i;
using vec_uint_t = __m512i;
#define vec_load(a) _mm512_load_si512(a)
#define vec_store(a, b) _mm512_store_si512(a, b)
#define vec_add_16(a, b) _mm512_add_epi16(a, b)
#define vec_sub_16(a, b) _mm512_sub_epi16(a, b)
#define vec_mulhi_16(a, b) _mm512_mulhi_epi16(a, b)
#define vec_zero() _mm512_setzero_epi32()
#define vec_set_16(a) _mm512_set1_epi16(a)
#define vec_max_16(a, b) _mm512_max_epi16(a, b)
#define vec_min_16(a, b) _mm512_min_epi16(a, b)
#define vec_slli_16(a, b) _mm512_slli_epi16(a, b)
// Inverse permuted at load time
#define vec_packus_16(a, b) _mm512_packus_epi16(a, b)
#define vec_load_psqt(a) _mm256_load_si256(a)
#define vec_store_psqt(a, b) _mm256_store_si256(a, b)
#define vec_add_psqt_32(a, b) _mm256_add_epi32(a, b)
#define vec_sub_psqt_32(a, b) _mm256_sub_epi32(a, b)
#define vec_zero_psqt() _mm256_setzero_si256()

#ifdef USE_SSSE3
// AVX-512 compare-to-mask yields the lane bitmask directly.
#define vec_nnz(a) _mm512_cmpgt_epi32_mask(a, _mm512_setzero_si512())
#endif

#define vec128_zero _mm_setzero_si128()
#define vec128_set_16(a) _mm_set1_epi16(a)
#define vec128_load(a) _mm_load_si128(a)
#define vec128_storeu(a, b) _mm_storeu_si128(a, b)
#define vec128_add(a, b) _mm_add_epi16(a, b)
#define NumRegistersSIMD 16
#define MaxChunkSize 64

#elif USE_AVX2
using vec_t = __m256i;
using vec128_t = __m128i;
using psqt_vec_t = __m256i;
using vec_uint_t = __m256i;
#define vec_load(a) _mm256_load_si256(a)
#define vec_store(a, b) _mm256_store_si256(a, b)
#define vec_add_16(a, b) _mm256_add_epi16(a, b)
#define vec_sub_16(a, b) _mm256_sub_epi16(a, b)
#define vec_mulhi_16(a, b) _mm256_mulhi_epi16(a, b)
#define vec_zero() _mm256_setzero_si256()
#define vec_set_16(a) _mm256_set1_epi16(a)
#define vec_max_16(a, b) _mm256_max_epi16(a, b)
#define vec_min_16(a, b) _mm256_min_epi16(a, b)
#define vec_slli_16(a, b) _mm256_slli_epi16(a, b)
// Inverse permuted at load time
#define vec_packus_16(a, b) _mm256_packus_epi16(a, b)
#define vec_load_psqt(a) _mm256_load_si256(a)
#define vec_store_psqt(a, b) _mm256_store_si256(a, b)
#define vec_add_psqt_32(a, b) _mm256_add_epi32(a, b)
#define vec_sub_psqt_32(a, b) _mm256_sub_epi32(a, b)
#define vec_zero_psqt() _mm256_setzero_si256()

#ifdef USE_SSSE3
#if defined(USE_VNNI) && !defined(USE_AVXVNNI)
// True VNNI (AVX-512 subset) provides compare-to-mask.
#define vec_nnz(a) _mm256_cmpgt_epi32_mask(a, _mm256_setzero_si256())
#else
// Otherwise derive the mask from the sign bits of a full-width compare.
#define vec_nnz(a) \
    _mm256_movemask_ps( \
      _mm256_castsi256_ps(_mm256_cmpgt_epi32(a, _mm256_setzero_si256())))
#endif
#endif

#define vec128_zero _mm_setzero_si128()
#define vec128_set_16(a) _mm_set1_epi16(a)
#define vec128_load(a) _mm_load_si128(a)
#define vec128_storeu(a, b) _mm_storeu_si128(a, b)
#define vec128_add(a, b) _mm_add_epi16(a, b)

#define NumRegistersSIMD 16
#define MaxChunkSize 32

#elif USE_SSE2
using vec_t = __m128i;
using vec128_t = __m128i;
using psqt_vec_t = __m128i;
using vec_uint_t = __m128i;
// Plain dereference: aligned load/store are what the compiler emits anyway.
#define vec_load(a) (*(a))
#define vec_store(a, b) *(a) = (b)
#define vec_add_16(a, b) _mm_add_epi16(a, b)
#define vec_sub_16(a, b) _mm_sub_epi16(a, b)
#define vec_mulhi_16(a, b) _mm_mulhi_epi16(a, b)
#define vec_zero() _mm_setzero_si128()
#define vec_set_16(a) _mm_set1_epi16(a)
#define vec_max_16(a, b) _mm_max_epi16(a, b)
#define vec_min_16(a, b) _mm_min_epi16(a, b)
#define vec_slli_16(a, b) _mm_slli_epi16(a, b)
#define vec_packus_16(a, b) _mm_packus_epi16(a, b)
#define vec_load_psqt(a) (*(a))
#define vec_store_psqt(a, b) *(a) = (b)
#define vec_add_psqt_32(a, b) _mm_add_epi32(a, b)
#define vec_sub_psqt_32(a, b) _mm_sub_epi32(a, b)
#define vec_zero_psqt() _mm_setzero_si128()

#ifdef USE_SSSE3
#define vec_nnz(a) \
    _mm_movemask_ps(_mm_castsi128_ps(_mm_cmpgt_epi32(a, _mm_setzero_si128())))
#endif

#define vec128_zero _mm_setzero_si128()
#define vec128_set_16(a) _mm_set1_epi16(a)
#define vec128_load(a) _mm_load_si128(a)
#define vec128_storeu(a, b) _mm_storeu_si128(a, b)
#define vec128_add(a, b) _mm_add_epi16(a, b)

// 32-bit x86 only exposes 8 XMM registers.
#define NumRegistersSIMD (Is64Bit ? 16 : 8)
#define MaxChunkSize 16

#elif USE_NEON
using vec_t = int16x8_t;
using psqt_vec_t = int32x4_t;
using vec128_t = uint16x8_t;
using vec_uint_t = uint32x4_t;
#define vec_load(a) (*(a))
#define vec_store(a, b) *(a) = (b)
#define vec_add_16(a, b) vaddq_s16(a, b)
#define vec_sub_16(a, b) vsubq_s16(a, b)
// NOTE(review): vqdmulhq_s16 is a *doubling* saturating high multiply, not a
// plain mulhi like _mm*_mulhi_epi16 — callers presumably compensate for the
// factor of two; confirm at the use site.
#define vec_mulhi_16(a, b) vqdmulhq_s16(a, b)
#define vec_zero() vec_t{0}
#define vec_set_16(a) vdupq_n_s16(a)
#define vec_max_16(a, b) vmaxq_s16(a, b)
#define vec_min_16(a, b) vminq_s16(a, b)
// NEON left shift takes the count as a vector, hence the splat.
#define vec_slli_16(a, b) vshlq_s16(a, vec_set_16(b))
#define vec_packus_16(a, b) reinterpret_cast<vec_t>(vcombine_u8(vqmovun_s16(a), vqmovun_s16(b)))
#define vec_load_psqt(a) (*(a))
#define vec_store_psqt(a, b) *(a) = (b)
#define vec_add_psqt_32(a, b) vaddq_s32(a, b)
#define vec_sub_psqt_32(a, b) vsubq_s32(a, b)
#define vec_zero_psqt() psqt_vec_t{0}

// Per-lane weights 1,2,4,8: vtst marks nonzero lanes all-ones, the AND
// selects each lane's weight, and the horizontal add forms the 4-bit mask.
static constexpr std::uint32_t Mask[4] = {1, 2, 4, 8};
#define vec_nnz(a) vaddvq_u32(vandq_u32(vtstq_u32(a, a), vld1q_u32(Mask)))
#define vec128_zero vdupq_n_u16(0)
#define vec128_set_16(a) vdupq_n_u16(a)
#define vec128_load(a) vld1q_u16(reinterpret_cast<const std::uint16_t*>(a))
#define vec128_storeu(a, b) vst1q_u16(reinterpret_cast<std::uint16_t*>(a), b)
#define vec128_add(a, b) vaddq_u16(a, b)

#define NumRegistersSIMD 16
#define MaxChunkSize 16

#else
// No SIMD target enabled: fall back to the scalar code paths.
#undef VECTOR

#endif
// Uniform add/sub interface over 16-bit accumulator lanes: a full SIMD
// register (vec_t) when a vector target is enabled, otherwise a single
// scalar BiasType value.
struct Vec16Wrapper {
#ifdef VECTOR
    using type = vec_t;
    static type add(const type& lhs, const type& rhs) { return vec_add_16(lhs, rhs); }
    static type sub(const type& lhs, const type& rhs) { return vec_sub_16(lhs, rhs); }
#else
    using type = BiasType;
    static type add(const type& lhs, const type& rhs) { return lhs + rhs; }
    static type sub(const type& lhs, const type& rhs) { return lhs - rhs; }
#endif
};
// Uniform add/sub interface over 32-bit PSQT lanes: a full SIMD register
// (psqt_vec_t) when a vector target is enabled, otherwise a single scalar
// PSQTWeightType value. Mirrors Vec16Wrapper.
struct Vec32Wrapper {
#ifdef VECTOR
    using type = psqt_vec_t;
    static type add(const type& lhs, const type& rhs) { return vec_add_psqt_32(lhs, rhs); }
    static type sub(const type& lhs, const type& rhs) { return vec_sub_psqt_32(lhs, rhs); }
#else
    using type = PSQTWeightType;
    static type add(const type& lhs, const type& rhs) { return lhs + rhs; }
    static type sub(const type& lhs, const type& rhs) { return lhs - rhs; }
#endif
};
// Operations that fused() can apply to an accumulator, one per operand.
enum UpdateOperation {
    Add,
    Sub
};
// Base case of the fused accumulator update: no operations remain, so the
// accumulator is returned unchanged. Selected by SFINAE when the operation
// pack is empty.
template<typename VecWrapper,
         UpdateOperation... ops,
         std::enable_if_t<sizeof...(ops) == 0, bool> = true>
typename VecWrapper::type fused(const typename VecWrapper::type& in) {
    return in;
}
// Applies a compile-time sequence of Add/Sub operations to `in`, one operand
// per operation: peels the first operation/operand pair and recurses on the
// rest. SFINAE requires every operand to have the wrapper's value type and
// the operand count to match the operation count.
template<typename VecWrapper,
         UpdateOperation update_op,
         UpdateOperation... ops,
         typename T,
         typename... Ts,
         std::enable_if_t<is_all_same_v<typename VecWrapper::type, T, Ts...>, bool> = true,
         std::enable_if_t<sizeof...(ops) == sizeof...(Ts), bool> = true>
typename VecWrapper::type
fused(const typename VecWrapper::type& in, const T& operand, const Ts&... operands) {
    // update_op is a template parameter, so this switch resolves at compile
    // time and the dead branches are discarded.
    switch (update_op)
    {
    case Add :
        return fused<VecWrapper, ops...>(VecWrapper::add(in, operand), operands...);
    case Sub :
        return fused<VecWrapper, ops...>(VecWrapper::sub(in, operand), operands...);
    default :
        // Dependent static_assert: only fires if instantiated with an
        // enumerator other than Add or Sub.
        static_assert(update_op == Add || update_op == Sub,
                      "Only Add and Sub are currently supported.");
        return typename VecWrapper::type();
    }
}
#if defined(USE_AVX512)

// Horizontal add: reduce the sixteen 32-bit lanes of `sum` to a single int
// and add `bias`.
[[maybe_unused]] static int m512_hadd(__m512i sum, int bias) {
    return _mm512_reduce_add_epi32(sum) + bias;
}

// acc += dot product of the unsigned bytes of `a` with the signed bytes of
// `b`, accumulated into 32-bit lanes. Uses the VNNI instruction when
// available; otherwise emulates it with maddubs (byte products summed in
// pairs to 16 bits) followed by madd against 1 (widen/reduce to 32 bits).
[[maybe_unused]] static void m512_add_dpbusd_epi32(__m512i& acc, __m512i a, __m512i b) {

#if defined(USE_VNNI)
    acc = _mm512_dpbusd_epi32(acc, a, b);
#else
    __m512i product0 = _mm512_maddubs_epi16(a, b);
    product0 = _mm512_madd_epi16(product0, _mm512_set1_epi16(1));
    acc = _mm512_add_epi32(acc, product0);
#endif
}

#endif
#if defined(USE_AVX2)

// Horizontal add: fold the 256-bit register to 128 bits, then twice
// shuffle-and-add to sum all 32-bit lanes, finally adding `bias`.
[[maybe_unused]] static int m256_hadd(__m256i sum, int bias) {
    __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1));
    sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_BADC));
    sum128 = _mm_add_epi32(sum128, _mm_shuffle_epi32(sum128, _MM_PERM_CDAB));
    return _mm_cvtsi128_si32(sum128) + bias;
}

// acc += dot product of the unsigned bytes of `a` with the signed bytes of
// `b`, accumulated into 32-bit lanes. VNNI path uses a single dpbusd;
// fallback emulates it with maddubs + madd (see m512 variant).
[[maybe_unused]] static void m256_add_dpbusd_epi32(__m256i& acc, __m256i a, __m256i b) {

#if defined(USE_VNNI)
    acc = _mm256_dpbusd_epi32(acc, a, b);
#else
    __m256i product0 = _mm256_maddubs_epi16(a, b);
    product0 = _mm256_madd_epi16(product0, _mm256_set1_epi16(1));
    acc = _mm256_add_epi32(acc, product0);
#endif
}

#endif
#if defined(USE_SSSE3)
299
300
// Horizontal add: sum the four 32-bit lanes of `sum` and add `bias`.
// Two shuffle-and-add rounds halve the number of distinct partial sums
// each time, leaving the total in every lane.
[[maybe_unused]] static int m128_hadd(__m128i sum, int bias) {
    const __m128i swappedHalves = _mm_shuffle_epi32(sum, 0x4E);   // _MM_PERM_BADC
    const __m128i partial       = _mm_add_epi32(sum, swappedHalves);
    const __m128i swappedPairs  = _mm_shuffle_epi32(partial, 0xB1);  // _MM_PERM_CDAB
    const __m128i total         = _mm_add_epi32(partial, swappedPairs);
    return _mm_cvtsi128_si32(total) + bias;
}
306
// acc += dot product of the unsigned bytes of `a` with the signed bytes of
// `b`: maddubs forms saturated 16-bit pair sums, madd against 1 widens and
// reduces them to 32-bit lanes, which are added into `acc`.
[[maybe_unused]] static void m128_add_dpbusd_epi32(__m128i& acc, __m128i a, __m128i b) {

    __m128i product0 = _mm_maddubs_epi16(a, b);
    product0 = _mm_madd_epi16(product0, _mm_set1_epi16(1));
    acc = _mm_add_epi32(acc, product0);
}
#endif
314
315
#if defined(USE_NEON_DOTPROD)

// acc += per-32-bit-lane dot product of the signed bytes of `a` and `b`,
// using the ARMv8.2 dot-product instruction (SDOT).
[[maybe_unused]] static void
dotprod_m128_add_dpbusd_epi32(int32x4_t& acc, int8x16_t a, int8x16_t b) {

    acc = vdotq_s32(acc, a, b);
}
#endif
324
#if defined(USE_NEON)

// Reduce the four 32-bit lanes of `s` to a single int. vaddvq_s32 requires
// AArch64 (USE_NEON >= 8); older targets extract and sum the lanes manually.
[[maybe_unused]] static int neon_m128_reduce_add_epi32(int32x4_t s) {
#if USE_NEON >= 8
    return vaddvq_s32(s);
#else
    return s[0] + s[1] + s[2] + s[3];
#endif
}

// Horizontal add plus bias; the NEON counterpart of m128_hadd.
[[maybe_unused]] static int neon_m128_hadd(int32x4_t sum, int bias) {
    return neon_m128_reduce_add_epi32(sum) + bias;
}

#endif
340
#if USE_NEON >= 8
// acc += dot product of the signed bytes of `a` and `b` without the SDOT
// instruction: widening byte multiplies (vmull low/high halves), pairwise
// add to 16-bit sums, then pairwise accumulate into the 32-bit lanes.
[[maybe_unused]] static void neon_m128_add_dpbusd_epi32(int32x4_t& acc, int8x16_t a, int8x16_t b) {

    int16x8_t product0 = vmull_s8(vget_low_s8(a), vget_low_s8(b));
    int16x8_t product1 = vmull_high_s8(a, b);
    int16x8_t sum = vpaddq_s16(product0, product1);
    acc = vpadalq_s16(acc, sum);
}
#endif
350
351
// Compute optimal SIMD register count for feature transformer accumulation.
template<IndexType TransformedFeatureWidth, IndexType HalfDimensions, IndexType PSQTBuckets>
class SIMDTiling {
#ifdef VECTOR
    // We use __m* types as template arguments, which causes GCC to emit warnings
    // about losing some attribute information. This is irrelevant to us as we
    // only take their size, so the following pragma are harmless.
#if defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif

    // Number of SIMDRegisterType registers to use for an array of NumLanes
    // lanes of LaneType, never exceeding MaxRegisters. If the exact count
    // fits under the cap it is used directly; otherwise a divisor of it is
    // chosen so that tiles always partition the array evenly.
    template<typename SIMDRegisterType, typename LaneType, int NumLanes, int MaxRegisters>
    static constexpr int BestRegisterCount() {
        constexpr std::size_t RegisterSize = sizeof(SIMDRegisterType);
        constexpr std::size_t LaneSize = sizeof(LaneType);

        // Sanity checks: lanes must tile registers exactly and the caps
        // must be meaningful.
        static_assert(RegisterSize >= LaneSize);
        static_assert(MaxRegisters <= NumRegistersSIMD);
        static_assert(MaxRegisters > 0);
        static_assert(NumRegistersSIMD > 0);
        static_assert(RegisterSize % LaneSize == 0);
        static_assert((NumLanes * LaneSize) % RegisterSize == 0);

        const int ideal = (NumLanes * LaneSize) / RegisterSize;
        if (ideal <= MaxRegisters)
            return ideal;

        // Look for the largest divisor of the ideal register count that is
        // no larger than MaxRegisters
        for (int divisor = MaxRegisters; divisor > 1; --divisor)
            if (ideal % divisor == 0)
                return divisor;

        return 1;
    }

#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif

public:
    // Register counts for the 16-bit accumulator lanes and the 32-bit PSQT
    // lanes respectively.
    static constexpr int NumRegs =
        BestRegisterCount<vec_t, WeightType, TransformedFeatureWidth, NumRegistersSIMD>();
    static constexpr int NumPsqtRegs =
        BestRegisterCount<psqt_vec_t, PSQTWeightType, PSQTBuckets, NumRegistersSIMD>();

    // Lanes processed per tile: sizeof(vec_t) / 2 is the number of 16-bit
    // lanes per register; sizeof(psqt_vec_t) / 4 the number of 32-bit lanes.
    static constexpr IndexType TileHeight = NumRegs * sizeof(vec_t) / 2;
    static constexpr IndexType PsqtTileHeight = NumPsqtRegs * sizeof(psqt_vec_t) / 4;

    static_assert(HalfDimensions % TileHeight == 0, "TileHeight must divide HalfDimensions");
    static_assert(PSQTBuckets % PsqtTileHeight == 0, "PsqtTileHeight must divide PSQTBuckets");
#endif
};
}  // namespace Stockfish::Eval::NNUE::SIMD

#endif  // NNUE_SIMD_H_INCLUDED