Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
official-stockfish
GitHub Repository: official-stockfish/Stockfish
Path: blob/master/src/nnue/nnue_architecture.h
375 views
1
/*
2
Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3
Copyright (C) 2004-2025 The Stockfish developers (see AUTHORS file)
4
5
Stockfish is free software: you can redistribute it and/or modify
6
it under the terms of the GNU General Public License as published by
7
the Free Software Foundation, either version 3 of the License, or
8
(at your option) any later version.
9
10
Stockfish is distributed in the hope that it will be useful,
11
but WITHOUT ANY WARRANTY; without even the implied warranty of
12
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
GNU General Public License for more details.
14
15
You should have received a copy of the GNU General Public License
16
along with this program. If not, see <http://www.gnu.org/licenses/>.
17
*/
18
19
// Input features and network structure used in the NNUE evaluation function
20
21
#ifndef NNUE_ARCHITECTURE_H_INCLUDED
22
#define NNUE_ARCHITECTURE_H_INCLUDED
23
24
#include <cstdint>
#include <cstring>
#include <iosfwd>
#include <memory>

#include "features/half_ka_v2_hm.h"
#include "layers/affine_transform.h"
#include "layers/affine_transform_sparse_input.h"
#include "layers/clipped_relu.h"
#include "layers/sqr_clipped_relu.h"
#include "nnue_common.h"
34
35
namespace Stockfish::Eval::NNUE {
36
37
// Input features used in evaluation function
38
using FeatureSet = Features::HalfKAv2_hm;
39
40
// Number of input feature dimensions after conversion
41
constexpr IndexType TransformedFeatureDimensionsBig = 3072;
42
constexpr int L2Big = 15;
43
constexpr int L3Big = 32;
44
45
constexpr IndexType TransformedFeatureDimensionsSmall = 128;
46
constexpr int L2Small = 15;
47
constexpr int L3Small = 32;
48
49
constexpr IndexType PSQTBuckets = 8;
50
constexpr IndexType LayerStacks = 8;
51
52
// If vector instructions are enabled, we update and refresh the
53
// accumulator tile by tile such that each tile fits in the CPU's
54
// vector registers.
55
static_assert(PSQTBuckets % 8 == 0,
56
"Per feature PSQT values cannot be processed at granularity lower than 8 at a time.");
57
58
template<IndexType L1, int L2, int L3>
59
struct NetworkArchitecture {
60
static constexpr IndexType TransformedFeatureDimensions = L1;
61
static constexpr int FC_0_OUTPUTS = L2;
62
static constexpr int FC_1_OUTPUTS = L3;
63
64
Layers::AffineTransformSparseInput<TransformedFeatureDimensions, FC_0_OUTPUTS + 1> fc_0;
65
Layers::SqrClippedReLU<FC_0_OUTPUTS + 1> ac_sqr_0;
66
Layers::ClippedReLU<FC_0_OUTPUTS + 1> ac_0;
67
Layers::AffineTransform<FC_0_OUTPUTS * 2, FC_1_OUTPUTS> fc_1;
68
Layers::ClippedReLU<FC_1_OUTPUTS> ac_1;
69
Layers::AffineTransform<FC_1_OUTPUTS, 1> fc_2;
70
71
// Hash value embedded in the evaluation file
72
static constexpr std::uint32_t get_hash_value() {
73
// input slice hash
74
std::uint32_t hashValue = 0xEC42E90Du;
75
hashValue ^= TransformedFeatureDimensions * 2;
76
77
hashValue = decltype(fc_0)::get_hash_value(hashValue);
78
hashValue = decltype(ac_0)::get_hash_value(hashValue);
79
hashValue = decltype(fc_1)::get_hash_value(hashValue);
80
hashValue = decltype(ac_1)::get_hash_value(hashValue);
81
hashValue = decltype(fc_2)::get_hash_value(hashValue);
82
83
return hashValue;
84
}
85
86
// Read network parameters
87
bool read_parameters(std::istream& stream) {
88
return fc_0.read_parameters(stream) && ac_0.read_parameters(stream)
89
&& fc_1.read_parameters(stream) && ac_1.read_parameters(stream)
90
&& fc_2.read_parameters(stream);
91
}
92
93
// Write network parameters
94
bool write_parameters(std::ostream& stream) const {
95
return fc_0.write_parameters(stream) && ac_0.write_parameters(stream)
96
&& fc_1.write_parameters(stream) && ac_1.write_parameters(stream)
97
&& fc_2.write_parameters(stream);
98
}
99
100
std::int32_t propagate(const TransformedFeatureType* transformedFeatures) {
101
struct alignas(CacheLineSize) Buffer {
102
alignas(CacheLineSize) typename decltype(fc_0)::OutputBuffer fc_0_out;
103
alignas(CacheLineSize) typename decltype(ac_sqr_0)::OutputType
104
ac_sqr_0_out[ceil_to_multiple<IndexType>(FC_0_OUTPUTS * 2, 32)];
105
alignas(CacheLineSize) typename decltype(ac_0)::OutputBuffer ac_0_out;
106
alignas(CacheLineSize) typename decltype(fc_1)::OutputBuffer fc_1_out;
107
alignas(CacheLineSize) typename decltype(ac_1)::OutputBuffer ac_1_out;
108
alignas(CacheLineSize) typename decltype(fc_2)::OutputBuffer fc_2_out;
109
110
Buffer() { std::memset(this, 0, sizeof(*this)); }
111
};
112
113
#if defined(__clang__) && (__APPLE__)
114
// workaround for a bug reported with xcode 12
115
static thread_local auto tlsBuffer = std::make_unique<Buffer>();
116
// Access TLS only once, cache result.
117
Buffer& buffer = *tlsBuffer;
118
#else
119
alignas(CacheLineSize) static thread_local Buffer buffer;
120
#endif
121
122
fc_0.propagate(transformedFeatures, buffer.fc_0_out);
123
ac_sqr_0.propagate(buffer.fc_0_out, buffer.ac_sqr_0_out);
124
ac_0.propagate(buffer.fc_0_out, buffer.ac_0_out);
125
std::memcpy(buffer.ac_sqr_0_out + FC_0_OUTPUTS, buffer.ac_0_out,
126
FC_0_OUTPUTS * sizeof(typename decltype(ac_0)::OutputType));
127
fc_1.propagate(buffer.ac_sqr_0_out, buffer.fc_1_out);
128
ac_1.propagate(buffer.fc_1_out, buffer.ac_1_out);
129
fc_2.propagate(buffer.ac_1_out, buffer.fc_2_out);
130
131
// buffer.fc_0_out[FC_0_OUTPUTS] is such that 1.0 is equal to 127*(1<<WeightScaleBits) in
132
// quantized form, but we want 1.0 to be equal to 600*OutputScale
133
std::int32_t fwdOut =
134
(buffer.fc_0_out[FC_0_OUTPUTS]) * (600 * OutputScale) / (127 * (1 << WeightScaleBits));
135
std::int32_t outputValue = buffer.fc_2_out[0] + fwdOut;
136
137
return outputValue;
138
}
139
};
140
141
} // namespace Stockfish::Eval::NNUE
142
143
#endif // #ifndef NNUE_ARCHITECTURE_H_INCLUDED
144
145