Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
official-stockfish
GitHub Repository: official-stockfish/Stockfish
Path: blob/master/src/nnue/network.h
632 views
1
/*
2
Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3
Copyright (C) 2004-2026 The Stockfish developers (see AUTHORS file)
4
5
Stockfish is free software: you can redistribute it and/or modify
6
it under the terms of the GNU General Public License as published by
7
the Free Software Foundation, either version 3 of the License, or
8
(at your option) any later version.
9
10
Stockfish is distributed in the hope that it will be useful,
11
but WITHOUT ANY WARRANTY; without even the implied warranty of
12
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
GNU General Public License for more details.
14
15
You should have received a copy of the GNU General Public License
16
along with this program. If not, see <http://www.gnu.org/licenses/>.
17
*/
18
19
#ifndef NETWORK_H_INCLUDED
20
#define NETWORK_H_INCLUDED
21
22
#include <cstddef>
23
#include <cstdint>
24
#include <functional>
25
#include <iostream>
26
#include <memory>
27
#include <optional>
28
#include <string>
29
#include <string_view>
30
#include <tuple>
31
32
#include "../misc.h"
33
#include "../types.h"
34
#include "nnue_accumulator.h"
35
#include "nnue_architecture.h"
36
#include "nnue_common.h"
37
#include "nnue_feature_transformer.h"
38
#include "nnue_misc.h"
39
40
namespace Stockfish {
41
class Position;
42
}
43
44
namespace Stockfish::Eval::NNUE {
45
46
// Tags which of the two network variants a Network instance represents.
// Networks constructs its `big` member with BIG and its `small` member
// with SMALL.
enum class EmbeddedNNUEType {
    BIG,
    SMALL,
};
50
51
// Result type of Network::evaluate(): a pair of Values.
// NOTE(review): the meaning of each component is not visible in this header
// (presumably the PSQT and positional parts) - confirm against network.cpp.
using NetworkOutput = std::tuple<Value, Value>;
52
53
// The network must be a trivial type, i.e. the memory must be in-line.
54
// This is required to allow sharing the network via shared memory, as
55
// there is no way to run destructors.
56
template<typename Arch, typename Transformer>
57
class Network {
58
static constexpr IndexType FTDimensions = Arch::TransformedFeatureDimensions;
59
60
public:
61
Network(EvalFile file, EmbeddedNNUEType type) :
62
evalFile(file),
63
embeddedType(type) {}
64
65
Network(const Network& other) = default;
66
Network(Network&& other) = default;
67
68
Network& operator=(const Network& other) = default;
69
Network& operator=(Network&& other) = default;
70
71
void load(const std::string& rootDirectory, std::string evalfilePath);
72
bool save(const std::optional<std::string>& filename) const;
73
74
std::size_t get_content_hash() const;
75
76
NetworkOutput evaluate(const Position& pos,
77
AccumulatorStack& accumulatorStack,
78
AccumulatorCaches::Cache<FTDimensions>& cache) const;
79
80
81
void verify(std::string evalfilePath, const std::function<void(std::string_view)>&) const;
82
NnueEvalTrace trace_evaluate(const Position& pos,
83
AccumulatorStack& accumulatorStack,
84
AccumulatorCaches::Cache<FTDimensions>& cache) const;
85
86
private:
87
void load_user_net(const std::string&, const std::string&);
88
void load_internal();
89
90
void initialize();
91
92
bool save(std::ostream&, const std::string&, const std::string&) const;
93
std::optional<std::string> load(std::istream&);
94
95
bool read_header(std::istream&, std::uint32_t*, std::string*) const;
96
bool write_header(std::ostream&, std::uint32_t, const std::string&) const;
97
98
bool read_parameters(std::istream&, std::string&);
99
bool write_parameters(std::ostream&, const std::string&) const;
100
101
// Input feature converter
102
Transformer featureTransformer;
103
104
// Evaluation function
105
Arch network[LayerStacks];
106
107
EvalFile evalFile;
108
EmbeddedNNUEType embeddedType;
109
110
bool initialized = false;
111
112
// Hash value of evaluation function structure
113
static constexpr std::uint32_t hash = Transformer::get_hash_value() ^ Arch::get_hash_value();
114
115
template<IndexType Size>
116
friend struct AccumulatorCaches::Cache;
117
};
118
119
// Definitions of the network types
120
using SmallFeatureTransformer = FeatureTransformer<TransformedFeatureDimensionsSmall>;
121
using SmallNetworkArchitecture =
122
NetworkArchitecture<TransformedFeatureDimensionsSmall, L2Small, L3Small>;
123
124
using BigFeatureTransformer = FeatureTransformer<TransformedFeatureDimensionsBig>;
125
using BigNetworkArchitecture = NetworkArchitecture<TransformedFeatureDimensionsBig, L2Big, L3Big>;
126
127
using NetworkBig = Network<BigNetworkArchitecture, BigFeatureTransformer>;
128
using NetworkSmall = Network<SmallNetworkArchitecture, SmallFeatureTransformer>;
129
130
131
struct Networks {
132
Networks(EvalFile bigFile, EvalFile smallFile) :
133
big(bigFile, EmbeddedNNUEType::BIG),
134
small(smallFile, EmbeddedNNUEType::SMALL) {}
135
136
NetworkBig big;
137
NetworkSmall small;
138
};
139
140
141
}  // namespace Stockfish::Eval::NNUE
142
143
template<typename ArchT, typename FeatureTransformerT>
144
struct std::hash<Stockfish::Eval::NNUE::Network<ArchT, FeatureTransformerT>> {
145
std::size_t operator()(
146
const Stockfish::Eval::NNUE::Network<ArchT, FeatureTransformerT>& network) const noexcept {
147
return network.get_content_hash();
148
}
149
};
150
151
template<>
152
struct std::hash<Stockfish::Eval::NNUE::Networks> {
153
std::size_t operator()(const Stockfish::Eval::NNUE::Networks& networks) const noexcept {
154
std::size_t h = 0;
155
Stockfish::hash_combine(h, networks.big);
156
Stockfish::hash_combine(h, networks.small);
157
return h;
158
}
159
};
160
161
#endif
162
163