Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
official-stockfish
GitHub Repository: official-stockfish/Stockfish
Path: blob/master/src/nnue/network.h
375 views
1
/*
2
Stockfish, a UCI chess playing engine derived from Glaurung 2.1
3
Copyright (C) 2004-2025 The Stockfish developers (see AUTHORS file)
4
5
Stockfish is free software: you can redistribute it and/or modify
6
it under the terms of the GNU General Public License as published by
7
the Free Software Foundation, either version 3 of the License, or
8
(at your option) any later version.
9
10
Stockfish is distributed in the hope that it will be useful,
11
but WITHOUT ANY WARRANTY; without even the implied warranty of
12
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13
GNU General Public License for more details.
14
15
You should have received a copy of the GNU General Public License
16
along with this program. If not, see <http://www.gnu.org/licenses/>.
17
*/
18
19
#ifndef NETWORK_H_INCLUDED
20
#define NETWORK_H_INCLUDED
21
22
#include <cstdint>
23
#include <functional>
24
#include <iostream>
25
#include <optional>
26
#include <string>
27
#include <string_view>
28
#include <tuple>
29
#include <utility>
30
31
#include "../memory.h"
32
#include "../types.h"
33
#include "nnue_accumulator.h"
34
#include "nnue_architecture.h"
35
#include "nnue_common.h"
36
#include "nnue_feature_transformer.h"
37
#include "nnue_misc.h"
38
39
// Forward declaration only: this header refers to Position solely through
// `const Position&` parameters, so the full class definition is not required.
namespace Stockfish { class Position; }
42
43
namespace Stockfish::Eval::NNUE {
44
45
enum class EmbeddedNNUEType {
46
BIG,
47
SMALL,
48
};
49
50
using NetworkOutput = std::tuple<Value, Value>;
51
52
template<typename Arch, typename Transformer>
53
class Network {
54
static constexpr IndexType FTDimensions = Arch::TransformedFeatureDimensions;
55
56
public:
57
Network(EvalFile file, EmbeddedNNUEType type) :
58
evalFile(file),
59
embeddedType(type) {}
60
61
Network(const Network& other);
62
Network(Network&& other) = default;
63
64
Network& operator=(const Network& other);
65
Network& operator=(Network&& other) = default;
66
67
void load(const std::string& rootDirectory, std::string evalfilePath);
68
bool save(const std::optional<std::string>& filename) const;
69
70
NetworkOutput evaluate(const Position& pos,
71
AccumulatorStack& accumulatorStack,
72
AccumulatorCaches::Cache<FTDimensions>* cache) const;
73
74
75
void verify(std::string evalfilePath, const std::function<void(std::string_view)>&) const;
76
NnueEvalTrace trace_evaluate(const Position& pos,
77
AccumulatorStack& accumulatorStack,
78
AccumulatorCaches::Cache<FTDimensions>* cache) const;
79
80
private:
81
void load_user_net(const std::string&, const std::string&);
82
void load_internal();
83
84
void initialize();
85
86
bool save(std::ostream&, const std::string&, const std::string&) const;
87
std::optional<std::string> load(std::istream&);
88
89
bool read_header(std::istream&, std::uint32_t*, std::string*) const;
90
bool write_header(std::ostream&, std::uint32_t, const std::string&) const;
91
92
bool read_parameters(std::istream&, std::string&) const;
93
bool write_parameters(std::ostream&, const std::string&) const;
94
95
// Input feature converter
96
LargePagePtr<Transformer> featureTransformer;
97
98
// Evaluation function
99
AlignedPtr<Arch[]> network;
100
101
EvalFile evalFile;
102
EmbeddedNNUEType embeddedType;
103
104
// Hash value of evaluation function structure
105
static constexpr std::uint32_t hash = Transformer::get_hash_value() ^ Arch::get_hash_value();
106
107
template<IndexType Size>
108
friend struct AccumulatorCaches::Cache;
109
110
friend class AccumulatorStack;
111
};
112
113
// Definitions of the network types
114
using SmallFeatureTransformer = FeatureTransformer<TransformedFeatureDimensionsSmall>;
115
using SmallNetworkArchitecture =
116
NetworkArchitecture<TransformedFeatureDimensionsSmall, L2Small, L3Small>;
117
118
using BigFeatureTransformer = FeatureTransformer<TransformedFeatureDimensionsBig>;
119
using BigNetworkArchitecture = NetworkArchitecture<TransformedFeatureDimensionsBig, L2Big, L3Big>;
120
121
using NetworkBig = Network<BigNetworkArchitecture, BigFeatureTransformer>;
122
using NetworkSmall = Network<SmallNetworkArchitecture, SmallFeatureTransformer>;
123
124
125
struct Networks {
126
Networks(NetworkBig&& nB, NetworkSmall&& nS) :
127
big(std::move(nB)),
128
small(std::move(nS)) {}
129
130
NetworkBig big;
131
NetworkSmall small;
132
};
133
134
135
} // namespace Stockfish
136
137
#endif
138
139