// Path: blob/master/src/nnue/nnue_feature_transformer.h
/*1Stockfish, a UCI chess playing engine derived from Glaurung 2.12Copyright (C) 2004-2026 The Stockfish developers (see AUTHORS file)34Stockfish is free software: you can redistribute it and/or modify5it under the terms of the GNU General Public License as published by6the Free Software Foundation, either version 3 of the License, or7(at your option) any later version.89Stockfish is distributed in the hope that it will be useful,10but WITHOUT ANY WARRANTY; without even the implied warranty of11MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the12GNU General Public License for more details.1314You should have received a copy of the GNU General Public License15along with this program. If not, see <http://www.gnu.org/licenses/>.16*/1718// A class that converts the input features of the NNUE evaluation function1920#ifndef NNUE_FEATURE_TRANSFORMER_H_INCLUDED21#define NNUE_FEATURE_TRANSFORMER_H_INCLUDED2223#include <algorithm>24#include <cstdint>25#include <cstring>26#include <iosfwd>27#include <iterator>2829#include "../position.h"30#include "../types.h"31#include "nnue_accumulator.h"32#include "nnue_architecture.h"33#include "nnue_common.h"34#include "simd.h"3536namespace Stockfish::Eval::NNUE {3738// Returns the inverse of a permutation39template<std::size_t Len>40constexpr std::array<std::size_t, Len>41invert_permutation(const std::array<std::size_t, Len>& order) {42std::array<std::size_t, Len> inverse{};43for (std::size_t i = 0; i < order.size(); i++)44inverse[order[i]] = i;45return inverse;46}4748// Divide a byte region of size TotalSize to chunks of size49// BlockSize, and permute the blocks by a given order50template<std::size_t BlockSize, typename T, std::size_t N, std::size_t OrderSize>51void permute(std::array<T, N>& data, const std::array<std::size_t, OrderSize>& order) {52constexpr std::size_t TotalSize = N * sizeof(T);5354static_assert(TotalSize % (BlockSize * OrderSize) == 0,55"ChunkSize * OrderSize must perfectly divide TotalSize");5657constexpr 
std::size_t ProcessChunkSize = BlockSize * OrderSize;5859std::array<std::byte, ProcessChunkSize> buffer{};6061std::byte* const bytes = reinterpret_cast<std::byte*>(data.data());6263for (std::size_t i = 0; i < TotalSize; i += ProcessChunkSize)64{65std::byte* const values = &bytes[i];6667for (std::size_t j = 0; j < OrderSize; j++)68{69auto* const buffer_chunk = &buffer[j * BlockSize];70auto* const value_chunk = &values[order[j] * BlockSize];7172std::copy(value_chunk, value_chunk + BlockSize, buffer_chunk);73}7475std::copy(std::begin(buffer), std::end(buffer), values);76}77}7879// Input feature converter80template<IndexType TransformedFeatureDimensions>81class FeatureTransformer {82static constexpr bool UseThreats =83(TransformedFeatureDimensions == TransformedFeatureDimensionsBig);84// Number of output dimensions for one side85static constexpr IndexType HalfDimensions = TransformedFeatureDimensions;8687public:88// Output type89using OutputType = TransformedFeatureType;9091// Number of input/output dimensions92static constexpr IndexType InputDimensions = PSQFeatureSet::Dimensions;93static constexpr IndexType ThreatInputDimensions = ThreatFeatureSet::Dimensions;94static constexpr IndexType TotalInputDimensions =95InputDimensions + (UseThreats ? 
ThreatInputDimensions : 0);96static constexpr IndexType OutputDimensions = HalfDimensions;9798// Size of forward propagation buffer99static constexpr std::size_t BufferSize = OutputDimensions * sizeof(OutputType);100101// Store the order by which 128-bit blocks of a 1024-bit data must102// be permuted so that calling packus on adjacent vectors of 16-bit103// integers loaded from the data results in the pre-permutation order104static constexpr auto PackusEpi16Order = []() -> std::array<std::size_t, 8> {105#if defined(USE_AVX512)106// _mm512_packus_epi16 after permutation:107// | 0 | 2 | 4 | 6 | // Vector 0108// | 1 | 3 | 5 | 7 | // Vector 1109// | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | // Packed Result110return {0, 2, 4, 6, 1, 3, 5, 7};111#elif defined(USE_AVX2)112// _mm256_packus_epi16 after permutation:113// | 0 | 2 | | 4 | 6 | // Vector 0, 2114// | 1 | 3 | | 5 | 7 | // Vector 1, 3115// | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | // Packed Result116return {0, 2, 1, 3, 4, 6, 5, 7};117#else118return {0, 1, 2, 3, 4, 5, 6, 7};119#endif120}();121122static constexpr auto InversePackusEpi16Order = invert_permutation(PackusEpi16Order);123124// Hash value embedded in the evaluation file125static constexpr std::uint32_t get_hash_value() {126return (UseThreats ? ThreatFeatureSet::HashValue : PSQFeatureSet::HashValue)127^ (OutputDimensions * 2);128}129130void permute_weights() {131permute<16>(biases, PackusEpi16Order);132permute<16>(weights, PackusEpi16Order);133134if constexpr (UseThreats)135permute<8>(threatWeights, PackusEpi16Order);136}137138void unpermute_weights() {139permute<16>(biases, InversePackusEpi16Order);140permute<16>(weights, InversePackusEpi16Order);141142if constexpr (UseThreats)143permute<8>(threatWeights, InversePackusEpi16Order);144}145146inline void scale_weights(bool read) {147for (auto& w : weights)148w = read ? w * 2 : w / 2;149for (auto& b : biases)150b = read ? 
b * 2 : b / 2;151}152153// Read network parameters154bool read_parameters(std::istream& stream) {155read_leb_128(stream, biases);156157if constexpr (UseThreats)158{159read_little_endian<ThreatWeightType>(stream, threatWeights.data(),160ThreatInputDimensions * HalfDimensions);161read_leb_128(stream, weights);162163read_leb_128(stream, threatPsqtWeights, psqtWeights);164}165else166{167read_leb_128(stream, weights);168read_leb_128(stream, psqtWeights);169}170171permute_weights();172173if constexpr (!UseThreats)174scale_weights(true);175176return !stream.fail();177}178179// Write network parameters180bool write_parameters(std::ostream& stream) const {181std::unique_ptr<FeatureTransformer> copy = std::make_unique<FeatureTransformer>(*this);182183copy->unpermute_weights();184185if constexpr (!UseThreats)186copy->scale_weights(false);187188write_leb_128<BiasType>(stream, copy->biases);189190if constexpr (UseThreats)191{192write_little_endian<ThreatWeightType>(stream, copy->threatWeights.data(),193ThreatInputDimensions * HalfDimensions);194write_leb_128<WeightType>(stream, copy->weights);195196auto combinedPsqtWeights =197std::make_unique<std::array<PSQTWeightType, TotalInputDimensions * PSQTBuckets>>();198199std::copy(std::begin(copy->threatPsqtWeights),200std::begin(copy->threatPsqtWeights) + ThreatInputDimensions * PSQTBuckets,201combinedPsqtWeights->begin());202203std::copy(std::begin(copy->psqtWeights),204std::begin(copy->psqtWeights) + InputDimensions * PSQTBuckets,205combinedPsqtWeights->begin() + ThreatInputDimensions * PSQTBuckets);206207write_leb_128<PSQTWeightType>(stream, *combinedPsqtWeights);208}209else210{211write_leb_128<WeightType>(stream, copy->weights);212write_leb_128<PSQTWeightType>(stream, copy->psqtWeights);213}214215return !stream.fail();216}217218std::size_t get_content_hash() const {219std::size_t h = 0;220221hash_combine(h, get_raw_data_hash(biases));222hash_combine(h, get_raw_data_hash(weights));223hash_combine(h, 
get_raw_data_hash(psqtWeights));224225if constexpr (UseThreats)226{227hash_combine(h, get_raw_data_hash(threatWeights));228hash_combine(h, get_raw_data_hash(threatPsqtWeights));229}230231hash_combine(h, get_hash_value());232233return h;234}235236// Convert input features237std::int32_t transform(const Position& pos,238AccumulatorStack& accumulatorStack,239AccumulatorCaches::Cache<HalfDimensions>& cache,240OutputType* output,241int bucket) const {242243using namespace SIMD;244accumulatorStack.evaluate(pos, *this, cache);245const auto& accumulatorState = accumulatorStack.latest<PSQFeatureSet>();246const auto& threatAccumulatorState = accumulatorStack.latest<ThreatFeatureSet>();247248const Color perspectives[2] = {pos.side_to_move(), ~pos.side_to_move()};249const auto& psqtAccumulation = (accumulatorState.acc<HalfDimensions>()).psqtAccumulation;250auto psqt =251(psqtAccumulation[perspectives[0]][bucket] - psqtAccumulation[perspectives[1]][bucket]);252253if constexpr (UseThreats)254{255const auto& threatPsqtAccumulation =256(threatAccumulatorState.acc<HalfDimensions>()).psqtAccumulation;257psqt = (psqt + threatPsqtAccumulation[perspectives[0]][bucket]258- threatPsqtAccumulation[perspectives[1]][bucket])259/ 2;260}261else262psqt /= 2;263264const auto& accumulation = (accumulatorState.acc<HalfDimensions>()).accumulation;265const auto& threatAccumulation =266(threatAccumulatorState.acc<HalfDimensions>()).accumulation;267268for (IndexType p = 0; p < 2; ++p)269{270const IndexType offset = (HalfDimensions / 2) * p;271272#if defined(VECTOR)273274constexpr IndexType OutputChunkSize = MaxChunkSize;275static_assert((HalfDimensions / 2) % OutputChunkSize == 0);276constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize;277278const vec_t Zero = vec_zero();279const vec_t One = vec_set_16(UseThreats ? 
255 : 127 * 2);280281const vec_t* in0 = reinterpret_cast<const vec_t*>(&(accumulation[perspectives[p]][0]));282const vec_t* in1 =283reinterpret_cast<const vec_t*>(&(accumulation[perspectives[p]][HalfDimensions / 2]));284vec_t* out = reinterpret_cast<vec_t*>(output + offset);285286// Per the NNUE architecture, here we want to multiply pairs of287// clipped elements and divide the product by 128. To do this,288// we can naively perform min/max operation to clip each of the289// four int16 vectors, mullo pairs together, then pack them into290// one int8 vector. However, there exists a faster way.291292// The idea here is to use the implicit clipping from packus to293// save us two vec_max_16 instructions. This clipping works due294// to the fact that any int16 integer below zero will be zeroed295// on packus.296297// Consider the case where the second element is negative.298// If we do standard clipping, that element will be zero, which299// means our pairwise product is zero. If we perform packus and300// remove the lower-side clip for the second element, then our301// product before packus will be negative, and is zeroed on pack.302// The two operation produce equivalent results, but the second303// one (using packus) saves one max operation per pair.304305// But here we run into a problem: mullo does not preserve the306// sign of the multiplication. We can get around this by doing307// mulhi, which keeps the sign. But that requires an additional308// tweak.309310// mulhi cuts off the last 16 bits of the resulting product,311// which is the same as performing a rightward shift of 16 bits.312// We can use this to our advantage. Recall that we want to313// divide the final product by 128, which is equivalent to a314// 7-bit right shift. Intuitively, if we shift the clipped315// value left by 9, and perform mulhi, which shifts the product316// right by 16 bits, then we will net a right shift of 7 bits.317// However, this won't work as intended. 
Since we clip the318// values to have a maximum value of 127, shifting it by 9 bits319// might occupy the signed bit, resulting in some positive320// values being interpreted as negative after the shift.321322// There is a way, however, to get around this limitation. When323// loading the network, scale accumulator weights and biases by324// 2. To get the same pairwise multiplication result as before,325// we need to divide the product by 128 * 2 * 2 = 512, which326// amounts to a right shift of 9 bits. So now we only have to327// shift left by 7 bits, perform mulhi (shifts right by 16 bits)328// and net a 9 bit right shift. Since we scaled everything by329// two, the values are clipped at 127 * 2 = 254, which occupies330// 8 bits. Shifting it by 7 bits left will no longer occupy the331// signed bit, so we are safe.332333// Note that on NEON processors, we shift left by 6 instead334// because the instruction "vqdmulhq_s16" also doubles the335// return value after the multiplication, adding an extra shift336// to the left by 1, so we compensate by shifting less before337// the multiplication.338339constexpr int shift =340#if defined(USE_SSE2)3417;342#else3436;344#endif345if constexpr (UseThreats)346{347const vec_t* tin0 =348reinterpret_cast<const vec_t*>(&(threatAccumulation[perspectives[p]][0]));349const vec_t* tin1 = reinterpret_cast<const vec_t*>(350&(threatAccumulation[perspectives[p]][HalfDimensions / 2]));351for (IndexType j = 0; j < NumOutputChunks; ++j)352{353const vec_t acc0a = vec_add_16(in0[j * 2 + 0], tin0[j * 2 + 0]);354const vec_t acc0b = vec_add_16(in0[j * 2 + 1], tin0[j * 2 + 1]);355const vec_t acc1a = vec_add_16(in1[j * 2 + 0], tin1[j * 2 + 0]);356const vec_t acc1b = vec_add_16(in1[j * 2 + 1], tin1[j * 2 + 1]);357358const vec_t sum0a =359vec_slli_16(vec_max_16(vec_min_16(acc0a, One), Zero), shift);360const vec_t sum0b =361vec_slli_16(vec_max_16(vec_min_16(acc0b, One), Zero), shift);362const vec_t sum1a = vec_min_16(acc1a, One);363const vec_t sum1b = 
vec_min_16(acc1b, One);364365const vec_t pa = vec_mulhi_16(sum0a, sum1a);366const vec_t pb = vec_mulhi_16(sum0b, sum1b);367368out[j] = vec_packus_16(pa, pb);369}370}371else372{373for (IndexType j = 0; j < NumOutputChunks; ++j)374{375const vec_t sum0a =376vec_slli_16(vec_max_16(vec_min_16(in0[j * 2 + 0], One), Zero), shift);377const vec_t sum0b =378vec_slli_16(vec_max_16(vec_min_16(in0[j * 2 + 1], One), Zero), shift);379const vec_t sum1a = vec_min_16(in1[j * 2 + 0], One);380const vec_t sum1b = vec_min_16(in1[j * 2 + 1], One);381382const vec_t pa = vec_mulhi_16(sum0a, sum1a);383const vec_t pb = vec_mulhi_16(sum0b, sum1b);384385out[j] = vec_packus_16(pa, pb);386}387}388389#else390391for (IndexType j = 0; j < HalfDimensions / 2; ++j)392{393BiasType sum0 = accumulation[static_cast<int>(perspectives[p])][j + 0];394BiasType sum1 =395accumulation[static_cast<int>(perspectives[p])][j + HalfDimensions / 2];396397if constexpr (UseThreats)398{399BiasType sum0t = threatAccumulation[static_cast<int>(perspectives[p])][j + 0];400BiasType sum1t =401threatAccumulation[static_cast<int>(perspectives[p])][j + HalfDimensions / 2];402sum0 = std::clamp<BiasType>(sum0 + sum0t, 0, 255);403sum1 = std::clamp<BiasType>(sum1 + sum1t, 0, 255);404}405else406{407sum0 = std::clamp<BiasType>(sum0, 0, 127 * 2);408sum1 = std::clamp<BiasType>(sum1, 0, 127 * 2);409}410411output[offset + j] = static_cast<OutputType>(unsigned(sum0 * sum1) / 512);412}413414#endif415}416417return psqt;418} // end of function transform()419420alignas(CacheLineSize) std::array<BiasType, HalfDimensions> biases;421alignas(CacheLineSize) std::array<WeightType, HalfDimensions * InputDimensions> weights;422alignas(CacheLineSize)423std::array<ThreatWeightType,424UseThreats ? HalfDimensions * ThreatInputDimensions : 0> threatWeights;425alignas(CacheLineSize) std::array<PSQTWeightType, InputDimensions * PSQTBuckets> psqtWeights;426alignas(CacheLineSize)427std::array<PSQTWeightType,428UseThreats ? 
ThreatInputDimensions * PSQTBuckets : 0> threatPsqtWeights;429};430431} // namespace Stockfish::Eval::NNUE432433434template<Stockfish::Eval::NNUE::IndexType TransformedFeatureDimensions>435struct std::hash<Stockfish::Eval::NNUE::FeatureTransformer<TransformedFeatureDimensions>> {436std::size_t437operator()(const Stockfish::Eval::NNUE::FeatureTransformer<TransformedFeatureDimensions>& ft)438const noexcept {439return ft.get_content_hash();440}441};442443#endif // #ifndef NNUE_FEATURE_TRANSFORMER_H_INCLUDED444445446