#ifndef NETWORK_H_INCLUDED
#define NETWORK_H_INCLUDED
#include <cstdint>
#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <string_view>
#include <tuple>
#include <utility>
#include "../memory.h"
#include "../types.h"
#include "nnue_accumulator.h"
#include "nnue_architecture.h"
#include "nnue_common.h"
#include "nnue_feature_transformer.h"
#include "nnue_misc.h"
// Forward declaration only: this header refers to Position purely by
// reference, so the full class definition is not needed here.
namespace Stockfish {

class Position;

}  // namespace Stockfish
namespace Stockfish::Eval::NNUE {
// Tag passed to the Network constructor to identify which of the two
// network flavours (see NetworkBig / NetworkSmall) an instance represents.
// Enumerators carry their implicit values: BIG == 0, SMALL == 1.
enum class EmbeddedNNUEType {
    BIG,
    SMALL
};
// Result type of Network::evaluate(): a pair of Values.
// NOTE(review): the meaning of each element is not visible in this header
// (presumably a PSQT part and a positional part) - confirm against the
// definition of evaluate().
using NetworkOutput = std::tuple<Value, Value>;
// Bundles one NNUE network: a feature transformer (Transformer) and an array
// of Arch objects, together with the EvalFile description and the
// EmbeddedNNUEType it was created with.
//
// Only the constructor is defined inline; every other member function is
// declared here and defined elsewhere.
template<typename Arch, typename Transformer>
class Network {
    // Transformed-feature width of the architecture; selects the matching
    // AccumulatorCaches::Cache instantiation used by the evaluate functions.
    static constexpr IndexType FTDimensions = Arch::TransformedFeatureDimensions;

   public:
    // Stores the eval-file description and the network type.
    // `file` is a sink parameter: it is taken by value and moved into the
    // member, so callers passing an rvalue pay no copy and callers passing an
    // lvalue pay exactly one.
    Network(EvalFile file, EmbeddedNNUEType type) :
        evalFile(std::move(file)),
        embeddedType(type) {}

    // Copy operations are user-provided (defined out of line); move operations
    // are the compiler-generated member-wise moves.
    Network(const Network& other);
    Network(Network&& other) = default;

    Network& operator=(const Network& other);
    Network& operator=(Network&& other) = default;

    // Loads the network given a root directory and an eval-file path.
    // NOTE(review): the exact path-resolution rules live in the definition.
    void load(const std::string& rootDirectory, std::string evalfilePath);

    // Saves the network; `filename` may be empty (std::nullopt).
    // Returns a success flag.
    bool save(const std::optional<std::string>& filename) const;

    // Evaluates `pos` using the given accumulator stack and cache and returns
    // the resulting pair of Values.
    NetworkOutput evaluate(const Position&                         pos,
                           AccumulatorStack&                       accumulatorStack,
                           AccumulatorCaches::Cache<FTDimensions>* cache) const;

    // Checks the network against `evalfilePath`; the callback receives text as
    // a std::string_view.
    // NOTE(review): what is reported through the callback is not visible here.
    void verify(std::string evalfilePath, const std::function<void(std::string_view)>&) const;

    // Same inputs as evaluate(), but returns an NnueEvalTrace instead of the
    // plain output pair.
    NnueEvalTrace trace_evaluate(const Position&                         pos,
                                 AccumulatorStack&                       accumulatorStack,
                                 AccumulatorCaches::Cache<FTDimensions>* cache) const;

   private:
    // Loading helpers behind the public load().
    void load_user_net(const std::string&, const std::string&);
    void load_internal();

    void initialize();

    // Stream-level counterparts of the public save()/load().
    bool                       save(std::ostream&, const std::string&, const std::string&) const;
    std::optional<std::string> load(std::istream&);

    // Header and parameter (de)serialisation; each returns a success flag.
    bool read_header(std::istream&, std::uint32_t*, std::string*) const;
    bool write_header(std::ostream&, std::uint32_t, const std::string&) const;

    bool read_parameters(std::istream&, std::string&) const;
    bool write_parameters(std::ostream&, const std::string&) const;

    // Input feature transformer.
    LargePagePtr<Transformer> featureTransformer;

    // Array of Arch instances making up the rest of the network.
    AlignedPtr<Arch[]> network;

    EvalFile         evalFile;
    EmbeddedNNUEType embeddedType;

    // Combined structure hash: XOR of the transformer's and the architecture's
    // hash values.
    static constexpr std::uint32_t hash = Transformer::get_hash_value() ^ Arch::get_hash_value();

    // The accumulator machinery needs access to the private members above.
    template<IndexType Size>
    friend struct AccumulatorCaches::Cache;

    friend class AccumulatorStack;
};
// Concrete instantiations, grouped by net size. Each Network alias pairs the
// architecture and feature transformer built from the same
// TransformedFeatureDimensions* constant.

// Big net.
using BigFeatureTransformer  = FeatureTransformer<TransformedFeatureDimensionsBig>;
using BigNetworkArchitecture = NetworkArchitecture<TransformedFeatureDimensionsBig, L2Big, L3Big>;
using NetworkBig             = Network<BigNetworkArchitecture, BigFeatureTransformer>;

// Small net.
using SmallFeatureTransformer = FeatureTransformer<TransformedFeatureDimensionsSmall>;
using SmallNetworkArchitecture =
  NetworkArchitecture<TransformedFeatureDimensionsSmall, L2Small, L3Small>;
using NetworkSmall = Network<SmallNetworkArchitecture, SmallFeatureTransformer>;
// Holds one big and one small network side by side.
struct Networks {
    // Declaration order is also initialisation order: big first, then small.
    NetworkBig   big;
    NetworkSmall small;

    // Takes both networks as rvalues and moves them into the members, so
    // construction never copies a network.
    Networks(NetworkBig&& nB, NetworkSmall&& nS) :
        big(std::move(nB)),
        small(std::move(nS)) {}
};
}
#endif