Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
official-stockfish
GitHub Repository: official-stockfish/Stockfish
Path: blob/master/src/nnue/nnue_feature_transformer.h
649 views
1
/*
2
Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3
Copyright (C) 2004-2026 The Stockfish developers (see AUTHORS file)
4
5
Stockfish is free software: you can redistribute it and/or modify
6
it under the terms of the GNU General Public License as published by
7
the Free Software Foundation, either version 3 of the License, or
8
(at your option) any later version.
9
10
Stockfish is distributed in the hope that it will be useful,
11
but WITHOUT ANY WARRANTY; without even the implied warranty of
12
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
GNU General Public License for more details.
14
15
You should have received a copy of the GNU General Public License
16
along with this program. If not, see <http://www.gnu.org/licenses/>.
17
*/
18
19
// A class that converts the input features of the NNUE evaluation function
20
21
#ifndef NNUE_FEATURE_TRANSFORMER_H_INCLUDED
22
#define NNUE_FEATURE_TRANSFORMER_H_INCLUDED
23
24
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iosfwd>
#include <iterator>
#include <memory>

#include "../position.h"
#include "../types.h"
#include "nnue_accumulator.h"
#include "nnue_architecture.h"
#include "nnue_common.h"
#include "simd.h"
36
37
namespace Stockfish::Eval::NNUE {
38
39
// Returns the inverse of a permutation: the result maps each value of
// `order` back to the index it came from, so composing the two yields
// the identity permutation.
template<std::size_t Len>
constexpr std::array<std::size_t, Len>
invert_permutation(const std::array<std::size_t, Len>& order) {
    std::array<std::size_t, Len> inv{};
    for (std::size_t idx = 0; idx < Len; idx++)
        inv[order[idx]] = idx;
    return inv;
}
48
49
// Divide a byte region of size TotalSize to chunks of size
// BlockSize, and permute the blocks by a given order.
// Within each chunk of BlockSize * OrderSize bytes, block j of the
// result is block order[j] of the input.
template<std::size_t BlockSize, typename T, std::size_t N, std::size_t OrderSize>
void permute(std::array<T, N>& data, const std::array<std::size_t, OrderSize>& order) {
    constexpr std::size_t TotalSize = N * sizeof(T);

    static_assert(TotalSize % (BlockSize * OrderSize) == 0,
                  "ChunkSize * OrderSize must perfectly divide TotalSize");

    constexpr std::size_t ProcessChunkSize = BlockSize * OrderSize;

    std::array<std::byte, ProcessChunkSize> scratch{};

    std::byte* const raw = reinterpret_cast<std::byte*>(data.data());

    for (std::size_t base = 0; base < TotalSize; base += ProcessChunkSize)
    {
        std::byte* const chunk = raw + base;

        // Gather the blocks of this chunk into the scratch buffer,
        // already arranged in permuted order
        for (std::size_t slot = 0; slot < OrderSize; slot++)
        {
            const std::byte* const src = chunk + order[slot] * BlockSize;
            std::copy(src, src + BlockSize, scratch.data() + slot * BlockSize);
        }

        // Overwrite the chunk with the permuted bytes
        std::copy(scratch.cbegin(), scratch.cend(), chunk);
    }
}
79
80
// Input feature converter: holds the first-layer weights of the network and
// turns the incrementally maintained accumulators into the clipped,
// pairwise-multiplied activations consumed by the following layers.
template<IndexType TransformedFeatureDimensions>
class FeatureTransformer {
    // Threat features are only present in the big network
    // (identified here by its transformed-feature dimension count).
    static constexpr bool UseThreats =
      (TransformedFeatureDimensions == TransformedFeatureDimensionsBig);
    // Number of output dimensions for one side
    static constexpr IndexType HalfDimensions = TransformedFeatureDimensions;

   public:
    // Output type
    using OutputType = TransformedFeatureType;

    // Number of input/output dimensions
    static constexpr IndexType InputDimensions       = PSQFeatureSet::Dimensions;
    static constexpr IndexType ThreatInputDimensions = ThreatFeatureSet::Dimensions;
    static constexpr IndexType TotalInputDimensions =
      InputDimensions + (UseThreats ? ThreatInputDimensions : 0);
    static constexpr IndexType OutputDimensions = HalfDimensions;

    // Size of forward propagation buffer
    static constexpr std::size_t BufferSize = OutputDimensions * sizeof(OutputType);

    // Store the order by which 128-bit blocks of a 1024-bit data must
    // be permuted so that calling packus on adjacent vectors of 16-bit
    // integers loaded from the data results in the pre-permutation order
    static constexpr auto PackusEpi16Order = []() -> std::array<std::size_t, 8> {
#if defined(USE_AVX512)
        // _mm512_packus_epi16 after permutation:
        // |  0  |  2  |  4  |  6  | // Vector 0
        // |  1  |  3  |  5  |  7  | // Vector 1
        // | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | // Packed Result
        return {0, 2, 4, 6, 1, 3, 5, 7};
#elif defined(USE_AVX2)
        // _mm256_packus_epi16 after permutation:
        // |  0  |  2  |     |  4  |  6  | // Vector 0, 2
        // |  1  |  3  |     |  5  |  7  | // Vector 1, 3
        // | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | // Packed Result
        return {0, 2, 1, 3, 4, 6, 5, 7};
#else
        // 128-bit (or scalar) targets need no permutation
        return {0, 1, 2, 3, 4, 5, 6, 7};
#endif
    }();

    // Inverse order, used to restore the on-disk layout when writing
    static constexpr auto InversePackusEpi16Order = invert_permutation(PackusEpi16Order);

    // Hash value embedded in the evaluation file
    static constexpr std::uint32_t get_hash_value() {
        return (UseThreats ? ThreatFeatureSet::HashValue : PSQFeatureSet::HashValue)
             ^ (OutputDimensions * 2);
    }

    // Rearrange weights/biases from the serialized order into the in-memory
    // order expected by the SIMD packus-based output code (see transform()).
    void permute_weights() {
        permute<16>(biases, PackusEpi16Order);
        permute<16>(weights, PackusEpi16Order);

        // Threat weights are half-width elements, hence half the block size
        if constexpr (UseThreats)
            permute<8>(threatWeights, PackusEpi16Order);
    }

    // Inverse of permute_weights(): restore the serialized block order
    void unpermute_weights() {
        permute<16>(biases, InversePackusEpi16Order);
        permute<16>(weights, InversePackusEpi16Order);

        if constexpr (UseThreats)
            permute<8>(threatWeights, InversePackusEpi16Order);
    }

    // Scale weights and biases by 2 on read (and undo on write). This is the
    // doubling required by the mulhi shift trick documented in transform();
    // it is only applied to the non-threat network (the threat network's
    // clipping range of 255 already occupies 8 bits).
    inline void scale_weights(bool read) {
        for (auto& w : weights)
            w = read ? w * 2 : w / 2;
        for (auto& b : biases)
            b = read ? b * 2 : b / 2;
    }

    // Read network parameters. Returns false on stream failure.
    // Note the asymmetric file layout: with threats, the threat weights come
    // first (raw little-endian), and the PSQT weights for both feature sets
    // are stored as one combined LEB128 run split across two arrays.
    bool read_parameters(std::istream& stream) {
        read_leb_128(stream, biases);

        if constexpr (UseThreats)
        {
            read_little_endian<ThreatWeightType>(stream, threatWeights.data(),
                                                 ThreatInputDimensions * HalfDimensions);
            read_leb_128(stream, weights);

            read_leb_128(stream, threatPsqtWeights, psqtWeights);
        }
        else
        {
            read_leb_128(stream, weights);
            read_leb_128(stream, psqtWeights);
        }

        // Convert from serialized layout to the runtime SIMD-friendly layout
        permute_weights();

        if constexpr (!UseThreats)
            scale_weights(true);

        return !stream.fail();
    }

    // Write network parameters. Returns false on stream failure.
    // Works on a heap copy so the live (permuted/scaled) weights are untouched.
    bool write_parameters(std::ostream& stream) const {
        std::unique_ptr<FeatureTransformer> copy = std::make_unique<FeatureTransformer>(*this);

        copy->unpermute_weights();

        if constexpr (!UseThreats)
            copy->scale_weights(false);

        write_leb_128<BiasType>(stream, copy->biases);

        if constexpr (UseThreats)
        {
            write_little_endian<ThreatWeightType>(stream, copy->threatWeights.data(),
                                                  ThreatInputDimensions * HalfDimensions);
            write_leb_128<WeightType>(stream, copy->weights);

            // Re-join the two PSQT arrays into the single run the file
            // format expects (threat PSQT first, then regular PSQT),
            // mirroring the combined read in read_parameters().
            auto combinedPsqtWeights =
              std::make_unique<std::array<PSQTWeightType, TotalInputDimensions * PSQTBuckets>>();

            std::copy(std::begin(copy->threatPsqtWeights),
                      std::begin(copy->threatPsqtWeights) + ThreatInputDimensions * PSQTBuckets,
                      combinedPsqtWeights->begin());

            std::copy(std::begin(copy->psqtWeights),
                      std::begin(copy->psqtWeights) + InputDimensions * PSQTBuckets,
                      combinedPsqtWeights->begin() + ThreatInputDimensions * PSQTBuckets);

            write_leb_128<PSQTWeightType>(stream, *combinedPsqtWeights);
        }
        else
        {
            write_leb_128<WeightType>(stream, copy->weights);
            write_leb_128<PSQTWeightType>(stream, copy->psqtWeights);
        }

        return !stream.fail();
    }

    // Hash of the raw parameter data combined with the architecture hash,
    // used via the std::hash specialization below.
    std::size_t get_content_hash() const {
        std::size_t h = 0;

        hash_combine(h, get_raw_data_hash(biases));
        hash_combine(h, get_raw_data_hash(weights));
        hash_combine(h, get_raw_data_hash(psqtWeights));

        if constexpr (UseThreats)
        {
            hash_combine(h, get_raw_data_hash(threatWeights));
            hash_combine(h, get_raw_data_hash(threatPsqtWeights));
        }

        hash_combine(h, get_hash_value());

        return h;
    }

    // Convert input features: refresh/update the accumulators for `pos`,
    // write the clipped pairwise-multiplied activations for both
    // perspectives into `output`, and return the PSQT value (side to move
    // minus opponent, halved) for the given bucket.
    std::int32_t transform(const Position&                           pos,
                           AccumulatorStack&                         accumulatorStack,
                           AccumulatorCaches::Cache<HalfDimensions>& cache,
                           OutputType*                               output,
                           int                                       bucket) const {

        using namespace SIMD;
        accumulatorStack.evaluate(pos, *this, cache);
        const auto& accumulatorState       = accumulatorStack.latest<PSQFeatureSet>();
        const auto& threatAccumulatorState = accumulatorStack.latest<ThreatFeatureSet>();

        // Index 0 is always the side to move's perspective
        const Color perspectives[2]  = {pos.side_to_move(), ~pos.side_to_move()};
        const auto& psqtAccumulation = (accumulatorState.acc<HalfDimensions>()).psqtAccumulation;
        auto        psqt =
          (psqtAccumulation[perspectives[0]][bucket] - psqtAccumulation[perspectives[1]][bucket]);

        if constexpr (UseThreats)
        {
            const auto& threatPsqtAccumulation =
              (threatAccumulatorState.acc<HalfDimensions>()).psqtAccumulation;
            psqt = (psqt + threatPsqtAccumulation[perspectives[0]][bucket]
                    - threatPsqtAccumulation[perspectives[1]][bucket])
                 / 2;
        }
        else
            psqt /= 2;

        const auto& accumulation = (accumulatorState.acc<HalfDimensions>()).accumulation;
        const auto& threatAccumulation =
          (threatAccumulatorState.acc<HalfDimensions>()).accumulation;

        for (IndexType p = 0; p < 2; ++p)
        {
            const IndexType offset = (HalfDimensions / 2) * p;

#if defined(VECTOR)

            constexpr IndexType OutputChunkSize = MaxChunkSize;
            static_assert((HalfDimensions / 2) % OutputChunkSize == 0);
            constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize;

            const vec_t Zero = vec_zero();
            // Upper clipping bound: 255 for the threat network, 127 * 2 for
            // the non-threat network whose weights were doubled on load.
            const vec_t One = vec_set_16(UseThreats ? 255 : 127 * 2);

            const vec_t* in0 = reinterpret_cast<const vec_t*>(&(accumulation[perspectives[p]][0]));
            const vec_t* in1 =
              reinterpret_cast<const vec_t*>(&(accumulation[perspectives[p]][HalfDimensions / 2]));
            vec_t* out = reinterpret_cast<vec_t*>(output + offset);

            // Per the NNUE architecture, here we want to multiply pairs of
            // clipped elements and divide the product by 128. To do this,
            // we can naively perform min/max operation to clip each of the
            // four int16 vectors, mullo pairs together, then pack them into
            // one int8 vector. However, there exists a faster way.

            // The idea here is to use the implicit clipping from packus to
            // save us two vec_max_16 instructions. This clipping works due
            // to the fact that any int16 integer below zero will be zeroed
            // on packus.

            // Consider the case where the second element is negative.
            // If we do standard clipping, that element will be zero, which
            // means our pairwise product is zero. If we perform packus and
            // remove the lower-side clip for the second element, then our
            // product before packus will be negative, and is zeroed on pack.
            // The two operation produce equivalent results, but the second
            // one (using packus) saves one max operation per pair.

            // But here we run into a problem: mullo does not preserve the
            // sign of the multiplication. We can get around this by doing
            // mulhi, which keeps the sign. But that requires an additional
            // tweak.

            // mulhi cuts off the last 16 bits of the resulting product,
            // which is the same as performing a rightward shift of 16 bits.
            // We can use this to our advantage. Recall that we want to
            // divide the final product by 128, which is equivalent to a
            // 7-bit right shift. Intuitively, if we shift the clipped
            // value left by 9, and perform mulhi, which shifts the product
            // right by 16 bits, then we will net a right shift of 7 bits.
            // However, this won't work as intended. Since we clip the
            // values to have a maximum value of 127, shifting it by 9 bits
            // might occupy the signed bit, resulting in some positive
            // values being interpreted as negative after the shift.

            // There is a way, however, to get around this limitation. When
            // loading the network, scale accumulator weights and biases by
            // 2. To get the same pairwise multiplication result as before,
            // we need to divide the product by 128 * 2 * 2 = 512, which
            // amounts to a right shift of 9 bits. So now we only have to
            // shift left by 7 bits, perform mulhi (shifts right by 16 bits)
            // and net a 9 bit right shift. Since we scaled everything by
            // two, the values are clipped at 127 * 2 = 254, which occupies
            // 8 bits. Shifting it by 7 bits left will no longer occupy the
            // signed bit, so we are safe.

            // Note that on NEON processors, we shift left by 6 instead
            // because the instruction "vqdmulhq_s16" also doubles the
            // return value after the multiplication, adding an extra shift
            // to the left by 1, so we compensate by shifting less before
            // the multiplication.

            constexpr int shift =
#if defined(USE_SSE2)
              7;
#else
              6;
#endif
            if constexpr (UseThreats)
            {
                const vec_t* tin0 =
                  reinterpret_cast<const vec_t*>(&(threatAccumulation[perspectives[p]][0]));
                const vec_t* tin1 = reinterpret_cast<const vec_t*>(
                  &(threatAccumulation[perspectives[p]][HalfDimensions / 2]));
                for (IndexType j = 0; j < NumOutputChunks; ++j)
                {
                    // Sum the PSQ and threat accumulators before clipping
                    const vec_t acc0a = vec_add_16(in0[j * 2 + 0], tin0[j * 2 + 0]);
                    const vec_t acc0b = vec_add_16(in0[j * 2 + 1], tin0[j * 2 + 1]);
                    const vec_t acc1a = vec_add_16(in1[j * 2 + 0], tin1[j * 2 + 0]);
                    const vec_t acc1b = vec_add_16(in1[j * 2 + 1], tin1[j * 2 + 1]);

                    // First operand: fully clipped, pre-shifted for mulhi
                    const vec_t sum0a =
                      vec_slli_16(vec_max_16(vec_min_16(acc0a, One), Zero), shift);
                    const vec_t sum0b =
                      vec_slli_16(vec_max_16(vec_min_16(acc0b, One), Zero), shift);
                    // Second operand: lower clip omitted; packus zeroes
                    // negative products (see comment above)
                    const vec_t sum1a = vec_min_16(acc1a, One);
                    const vec_t sum1b = vec_min_16(acc1b, One);

                    const vec_t pa = vec_mulhi_16(sum0a, sum1a);
                    const vec_t pb = vec_mulhi_16(sum0b, sum1b);

                    out[j] = vec_packus_16(pa, pb);
                }
            }
            else
            {
                for (IndexType j = 0; j < NumOutputChunks; ++j)
                {
                    const vec_t sum0a =
                      vec_slli_16(vec_max_16(vec_min_16(in0[j * 2 + 0], One), Zero), shift);
                    const vec_t sum0b =
                      vec_slli_16(vec_max_16(vec_min_16(in0[j * 2 + 1], One), Zero), shift);
                    const vec_t sum1a = vec_min_16(in1[j * 2 + 0], One);
                    const vec_t sum1b = vec_min_16(in1[j * 2 + 1], One);

                    const vec_t pa = vec_mulhi_16(sum0a, sum1a);
                    const vec_t pb = vec_mulhi_16(sum0b, sum1b);

                    out[j] = vec_packus_16(pa, pb);
                }
            }

#else

            // Scalar fallback: clip both halves, multiply pairwise, and
            // divide by 512 (the vector path's net 9-bit right shift).
            for (IndexType j = 0; j < HalfDimensions / 2; ++j)
            {
                BiasType sum0 = accumulation[static_cast<int>(perspectives[p])][j + 0];
                BiasType sum1 =
                  accumulation[static_cast<int>(perspectives[p])][j + HalfDimensions / 2];

                if constexpr (UseThreats)
                {
                    BiasType sum0t = threatAccumulation[static_cast<int>(perspectives[p])][j + 0];
                    BiasType sum1t =
                      threatAccumulation[static_cast<int>(perspectives[p])][j + HalfDimensions / 2];
                    sum0 = std::clamp<BiasType>(sum0 + sum0t, 0, 255);
                    sum1 = std::clamp<BiasType>(sum1 + sum1t, 0, 255);
                }
                else
                {
                    sum0 = std::clamp<BiasType>(sum0, 0, 127 * 2);
                    sum1 = std::clamp<BiasType>(sum1, 0, 127 * 2);
                }

                output[offset + j] = static_cast<OutputType>(unsigned(sum0 * sum1) / 512);
            }

#endif
        }

        return psqt;
    }  // end of function transform()

    // Parameters, cache-line aligned for the vector loads above.
    // The threat arrays have size zero when UseThreats is false.
    alignas(CacheLineSize) std::array<BiasType, HalfDimensions> biases;
    alignas(CacheLineSize) std::array<WeightType, HalfDimensions * InputDimensions> weights;
    alignas(CacheLineSize)
      std::array<ThreatWeightType, UseThreats ? HalfDimensions * ThreatInputDimensions : 0>
        threatWeights;
    alignas(CacheLineSize) std::array<PSQTWeightType, InputDimensions * PSQTBuckets> psqtWeights;
    alignas(CacheLineSize)
      std::array<PSQTWeightType, UseThreats ? ThreatInputDimensions * PSQTBuckets : 0>
        threatPsqtWeights;
};
431
432
} // namespace Stockfish::Eval::NNUE
433
434
435
// std::hash specialization so a FeatureTransformer can be used with standard
// hashed containers; forwards to the content hash over the raw weight data.
template<Stockfish::Eval::NNUE::IndexType TransformedFeatureDimensions>
struct std::hash<Stockfish::Eval::NNUE::FeatureTransformer<TransformedFeatureDimensions>> {
    std::size_t
    operator()(const Stockfish::Eval::NNUE::FeatureTransformer<TransformedFeatureDimensions>& ft)
      const noexcept {
        return ft.get_content_hash();
    }
};
443
444
#endif // #ifndef NNUE_FEATURE_TRANSFORMER_H_INCLUDED
445
446