// Source: GitHub repository official-stockfish/Stockfish, path blob/master/src/nnue/nnue_architecture.h
/*
  Stockfish, a UCI chess playing engine derived from Glaurung 2.1
  Copyright (C) 2004-2026 The Stockfish developers (see AUTHORS file)

  Stockfish is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  Stockfish is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

// Input features and network structure used in NNUE evaluation function

#ifndef NNUE_ARCHITECTURE_H_INCLUDED
22
#define NNUE_ARCHITECTURE_H_INCLUDED
23
24
#include <cstdint>
25
#include <cstring>
26
#include <iosfwd>
27
28
#include "features/half_ka_v2_hm.h"
29
#include "features/full_threats.h"
30
#include "layers/affine_transform.h"
31
#include "layers/affine_transform_sparse_input.h"
32
#include "layers/clipped_relu.h"
33
#include "layers/sqr_clipped_relu.h"
34
#include "nnue_common.h"
35
36
namespace Stockfish::Eval::NNUE {
37
38
// Input features used in evaluation function
39
using ThreatFeatureSet = Features::FullThreats;
40
using PSQFeatureSet = Features::HalfKAv2_hm;
41
42
// Number of input feature dimensions after conversion
43
constexpr IndexType TransformedFeatureDimensionsBig = 1024;
44
constexpr int L2Big = 15;
45
constexpr int L3Big = 32;
46
47
constexpr IndexType TransformedFeatureDimensionsSmall = 128;
48
constexpr int L2Small = 15;
49
constexpr int L3Small = 32;
50
51
constexpr IndexType PSQTBuckets = 8;
52
constexpr IndexType LayerStacks = 8;
53
54
// If vector instructions are enabled, we update and refresh the
55
// accumulator tile by tile such that each tile fits in the CPU's
56
// vector registers.
57
static_assert(PSQTBuckets % 8 == 0,
58
"Per feature PSQT values cannot be processed at granularity lower than 8 at a time.");
59
60
template<IndexType L1, int L2, int L3>
61
struct NetworkArchitecture {
62
static constexpr IndexType TransformedFeatureDimensions = L1;
63
static constexpr int FC_0_OUTPUTS = L2;
64
static constexpr int FC_1_OUTPUTS = L3;
65
66
Layers::AffineTransformSparseInput<TransformedFeatureDimensions, FC_0_OUTPUTS + 1> fc_0;
67
Layers::SqrClippedReLU<FC_0_OUTPUTS + 1> ac_sqr_0;
68
Layers::ClippedReLU<FC_0_OUTPUTS + 1> ac_0;
69
Layers::AffineTransform<FC_0_OUTPUTS * 2, FC_1_OUTPUTS> fc_1;
70
Layers::ClippedReLU<FC_1_OUTPUTS> ac_1;
71
Layers::AffineTransform<FC_1_OUTPUTS, 1> fc_2;
72
73
// Hash value embedded in the evaluation file
74
static constexpr std::uint32_t get_hash_value() {
75
// input slice hash
76
std::uint32_t hashValue = 0xEC42E90Du;
77
hashValue ^= TransformedFeatureDimensions * 2;
78
79
hashValue = decltype(fc_0)::get_hash_value(hashValue);
80
hashValue = decltype(ac_0)::get_hash_value(hashValue);
81
hashValue = decltype(fc_1)::get_hash_value(hashValue);
82
hashValue = decltype(ac_1)::get_hash_value(hashValue);
83
hashValue = decltype(fc_2)::get_hash_value(hashValue);
84
85
return hashValue;
86
}
87
88
// Read network parameters
89
bool read_parameters(std::istream& stream) {
90
return fc_0.read_parameters(stream) && ac_0.read_parameters(stream)
91
&& fc_1.read_parameters(stream) && ac_1.read_parameters(stream)
92
&& fc_2.read_parameters(stream);
93
}
94
95
// Write network parameters
96
bool write_parameters(std::ostream& stream) const {
97
return fc_0.write_parameters(stream) && ac_0.write_parameters(stream)
98
&& fc_1.write_parameters(stream) && ac_1.write_parameters(stream)
99
&& fc_2.write_parameters(stream);
100
}
101
102
std::int32_t propagate(const TransformedFeatureType* transformedFeatures) const {
103
struct alignas(CacheLineSize) Buffer {
104
alignas(CacheLineSize) typename decltype(fc_0)::OutputBuffer fc_0_out;
105
alignas(CacheLineSize) typename decltype(ac_sqr_0)::OutputType
106
ac_sqr_0_out[ceil_to_multiple<IndexType>(FC_0_OUTPUTS * 2, 32)];
107
alignas(CacheLineSize) typename decltype(ac_0)::OutputBuffer ac_0_out;
108
alignas(CacheLineSize) typename decltype(fc_1)::OutputBuffer fc_1_out;
109
alignas(CacheLineSize) typename decltype(ac_1)::OutputBuffer ac_1_out;
110
alignas(CacheLineSize) typename decltype(fc_2)::OutputBuffer fc_2_out;
111
112
Buffer() { std::memset(this, 0, sizeof(*this)); }
113
};
114
115
#if defined(__clang__) && (__APPLE__)
116
// workaround for a bug reported with xcode 12
117
static thread_local auto tlsBuffer = std::make_unique<Buffer>();
118
// Access TLS only once, cache result.
119
Buffer& buffer = *tlsBuffer;
120
#else
121
alignas(CacheLineSize) static thread_local Buffer buffer;
122
#endif
123
124
fc_0.propagate(transformedFeatures, buffer.fc_0_out);
125
ac_sqr_0.propagate(buffer.fc_0_out, buffer.ac_sqr_0_out);
126
ac_0.propagate(buffer.fc_0_out, buffer.ac_0_out);
127
std::memcpy(buffer.ac_sqr_0_out + FC_0_OUTPUTS, buffer.ac_0_out,
128
FC_0_OUTPUTS * sizeof(typename decltype(ac_0)::OutputType));
129
fc_1.propagate(buffer.ac_sqr_0_out, buffer.fc_1_out);
130
ac_1.propagate(buffer.fc_1_out, buffer.ac_1_out);
131
fc_2.propagate(buffer.ac_1_out, buffer.fc_2_out);
132
133
// buffer.fc_0_out[FC_0_OUTPUTS] is such that 1.0 is equal to 127*(1<<WeightScaleBits) in
134
// quantized form, but we want 1.0 to be equal to 600*OutputScale
135
std::int32_t fwdOut =
136
(buffer.fc_0_out[FC_0_OUTPUTS]) * (600 * OutputScale) / (127 * (1 << WeightScaleBits));
137
std::int32_t outputValue = buffer.fc_2_out[0] + fwdOut;
138
139
return outputValue;
140
}
141
142
std::size_t get_content_hash() const {
143
std::size_t h = 0;
144
hash_combine(h, fc_0.get_content_hash());
145
hash_combine(h, ac_sqr_0.get_content_hash());
146
hash_combine(h, ac_0.get_content_hash());
147
hash_combine(h, fc_1.get_content_hash());
148
hash_combine(h, ac_1.get_content_hash());
149
hash_combine(h, fc_2.get_content_hash());
150
hash_combine(h, get_hash_value());
151
return h;
152
}
153
};
154
155
} // namespace Stockfish::Eval::NNUE
156
157
template<Stockfish::Eval::NNUE::IndexType L1, int L2, int L3>
158
struct std::hash<Stockfish::Eval::NNUE::NetworkArchitecture<L1, L2, L3>> {
159
std::size_t
160
operator()(const Stockfish::Eval::NNUE::NetworkArchitecture<L1, L2, L3>& arch) const noexcept {
161
return arch.get_content_hash();
162
}
163
};
164
165
#endif // #ifndef NNUE_ARCHITECTURE_H_INCLUDED
166
167