Path: blob/master/src/nnue/nnue_feature_transformer.h
375 views
/*1Stockfish, a UCI chess playing engine derived from Glaurung 2.12Copyright (C) 2004-2025 The Stockfish developers (see AUTHORS file)34Stockfish is free software: you can redistribute it and/or modify5it under the terms of the GNU General Public License as published by6the Free Software Foundation, either version 3 of the License, or7(at your option) any later version.89Stockfish is distributed in the hope that it will be useful,10but WITHOUT ANY WARRANTY; without even the implied warranty of11MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the12GNU General Public License for more details.1314You should have received a copy of the GNU General Public License15along with this program. If not, see <http://www.gnu.org/licenses/>.16*/1718// A class that converts the input features of the NNUE evaluation function1920#ifndef NNUE_FEATURE_TRANSFORMER_H_INCLUDED21#define NNUE_FEATURE_TRANSFORMER_H_INCLUDED2223#include <algorithm>24#include <cstdint>25#include <cstring>26#include <iosfwd>2728#include "../position.h"29#include "../types.h"30#include "nnue_accumulator.h"31#include "nnue_architecture.h"32#include "nnue_common.h"33#include "simd.h"3435namespace Stockfish::Eval::NNUE {3637// Returns the inverse of a permutation38template<std::size_t Len>39constexpr std::array<std::size_t, Len>40invert_permutation(const std::array<std::size_t, Len>& order) {41std::array<std::size_t, Len> inverse{};42for (std::size_t i = 0; i < order.size(); i++)43inverse[order[i]] = i;44return inverse;45}4647// Divide a byte region of size TotalSize to chunks of size48// BlockSize, and permute the blocks by a given order49template<std::size_t BlockSize, typename T, std::size_t N, std::size_t OrderSize>50void permute(T (&data)[N], const std::array<std::size_t, OrderSize>& order) {51constexpr std::size_t TotalSize = N * sizeof(T);5253static_assert(TotalSize % (BlockSize * OrderSize) == 0,54"ChunkSize * OrderSize must perfectly divide TotalSize");5556constexpr std::size_t ProcessChunkSize = BlockSize * OrderSize;5758std::array<std::byte, ProcessChunkSize> buffer{};5960std::byte* const bytes = reinterpret_cast<std::byte*>(data);6162for (std::size_t i = 0; i < TotalSize; i += ProcessChunkSize)63{64std::byte* const values = &bytes[i];6566for (std::size_t j = 0; j < OrderSize; j++)67{68auto* const buffer_chunk = &buffer[j * BlockSize];69auto* const value_chunk = &values[order[j] * BlockSize];7071std::copy(value_chunk, value_chunk + BlockSize, buffer_chunk);72}7374std::copy(std::begin(buffer), std::end(buffer), values);75}76}7778// Input feature converter79template<IndexType TransformedFeatureDimensions>80class FeatureTransformer {8182// Number of output dimensions for one side83static constexpr IndexType HalfDimensions = TransformedFeatureDimensions;8485public:86// Output type87using OutputType = TransformedFeatureType;8889// Number of input/output dimensions90static constexpr IndexType InputDimensions = FeatureSet::Dimensions;91static constexpr IndexType OutputDimensions = HalfDimensions;9293// Size of forward propagation buffer94static constexpr std::size_t BufferSize = OutputDimensions * sizeof(OutputType);9596// Store the order by which 128-bit blocks of a 1024-bit data must97// be permuted so that calling packus on adjacent vectors of 16-bit98// integers loaded from the data results in the pre-permutation order99static constexpr auto PackusEpi16Order = []() -> std::array<std::size_t, 8> {100#if defined(USE_AVX512)101// _mm512_packus_epi16 after permutation:102// | 0 | 2 | 4 | 6 | // Vector 0103// | 1 | 3 | 5 | 7 | // Vector 1104// | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | // Packed Result105return {0, 2, 4, 6, 1, 3, 5, 7};106#elif defined(USE_AVX2)107// _mm256_packus_epi16 after permutation:108// | 0 | 2 | | 4 | 6 | // Vector 0, 2109// | 1 | 3 | | 5 | 7 | // Vector 1, 3110// | 0 | 1 | 2 | 3 | | 4 | 5 | 6 | 7 | // Packed Result111return {0, 2, 1, 3, 4, 6, 5, 7};112#else113return {0, 1, 2, 3, 4, 5, 6, 7};114#endif115}();116117static constexpr auto InversePackusEpi16Order = invert_permutation(PackusEpi16Order);118119// Hash value embedded in the evaluation file120static constexpr std::uint32_t get_hash_value() {121return FeatureSet::HashValue ^ (OutputDimensions * 2);122}123124void permute_weights() {125permute<16>(biases, PackusEpi16Order);126permute<16>(weights, PackusEpi16Order);127}128129void unpermute_weights() {130permute<16>(biases, InversePackusEpi16Order);131permute<16>(weights, InversePackusEpi16Order);132}133134inline void scale_weights(bool read) {135for (IndexType j = 0; j < InputDimensions; ++j)136{137WeightType* w = &weights[j * HalfDimensions];138for (IndexType i = 0; i < HalfDimensions; ++i)139w[i] = read ? w[i] * 2 : w[i] / 2;140}141142for (IndexType i = 0; i < HalfDimensions; ++i)143biases[i] = read ? biases[i] * 2 : biases[i] / 2;144}145146// Read network parameters147bool read_parameters(std::istream& stream) {148149read_leb_128<BiasType>(stream, biases, HalfDimensions);150read_leb_128<WeightType>(stream, weights, HalfDimensions * InputDimensions);151read_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets * InputDimensions);152153permute_weights();154scale_weights(true);155return !stream.fail();156}157158// Write network parameters159bool write_parameters(std::ostream& stream) {160161unpermute_weights();162scale_weights(false);163164write_leb_128<BiasType>(stream, biases, HalfDimensions);165write_leb_128<WeightType>(stream, weights, HalfDimensions * InputDimensions);166write_leb_128<PSQTWeightType>(stream, psqtWeights, PSQTBuckets * InputDimensions);167168permute_weights();169scale_weights(true);170return !stream.fail();171}172173// Convert input features174std::int32_t transform(const Position& pos,175AccumulatorStack& accumulatorStack,176AccumulatorCaches::Cache<HalfDimensions>* cache,177OutputType* output,178int bucket) const {179180using namespace SIMD;181182accumulatorStack.evaluate(pos, *this, *cache);183const auto& accumulatorState = accumulatorStack.latest();184185const Color perspectives[2] = {pos.side_to_move(), ~pos.side_to_move()};186const auto& psqtAccumulation = (accumulatorState.acc<HalfDimensions>()).psqtAccumulation;187const auto psqt =188(psqtAccumulation[perspectives[0]][bucket] - psqtAccumulation[perspectives[1]][bucket])189/ 2;190191const auto& accumulation = (accumulatorState.acc<HalfDimensions>()).accumulation;192193for (IndexType p = 0; p < 2; ++p)194{195const IndexType offset = (HalfDimensions / 2) * p;196197#if defined(VECTOR)198199constexpr IndexType OutputChunkSize = MaxChunkSize;200static_assert((HalfDimensions / 2) % OutputChunkSize == 0);201constexpr IndexType NumOutputChunks = HalfDimensions / 2 / OutputChunkSize;202203const vec_t Zero = vec_zero();204const vec_t One = vec_set_16(127 * 2);205206const vec_t* in0 = reinterpret_cast<const vec_t*>(&(accumulation[perspectives[p]][0]));207const vec_t* in1 =208reinterpret_cast<const vec_t*>(&(accumulation[perspectives[p]][HalfDimensions / 2]));209vec_t* out = reinterpret_cast<vec_t*>(output + offset);210211// Per the NNUE architecture, here we want to multiply pairs of212// clipped elements and divide the product by 128. To do this,213// we can naively perform min/max operation to clip each of the214// four int16 vectors, mullo pairs together, then pack them into215// one int8 vector. However, there exists a faster way.216217// The idea here is to use the implicit clipping from packus to218// save us two vec_max_16 instructions. This clipping works due219// to the fact that any int16 integer below zero will be zeroed220// on packus.221222// Consider the case where the second element is negative.223// If we do standard clipping, that element will be zero, which224// means our pairwise product is zero. If we perform packus and225// remove the lower-side clip for the second element, then our226// product before packus will be negative, and is zeroed on pack.227// The two operation produce equivalent results, but the second228// one (using packus) saves one max operation per pair.229230// But here we run into a problem: mullo does not preserve the231// sign of the multiplication. We can get around this by doing232// mulhi, which keeps the sign. But that requires an additional233// tweak.234235// mulhi cuts off the last 16 bits of the resulting product,236// which is the same as performing a rightward shift of 16 bits.237// We can use this to our advantage. Recall that we want to238// divide the final product by 128, which is equivalent to a239// 7-bit right shift. Intuitively, if we shift the clipped240// value left by 9, and perform mulhi, which shifts the product241// right by 16 bits, then we will net a right shift of 7 bits.242// However, this won't work as intended. Since we clip the243// values to have a maximum value of 127, shifting it by 9 bits244// might occupy the signed bit, resulting in some positive245// values being interpreted as negative after the shift.246247// There is a way, however, to get around this limitation. When248// loading the network, scale accumulator weights and biases by249// 2. To get the same pairwise multiplication result as before,250// we need to divide the product by 128 * 2 * 2 = 512, which251// amounts to a right shift of 9 bits. So now we only have to252// shift left by 7 bits, perform mulhi (shifts right by 16 bits)253// and net a 9 bit right shift. Since we scaled everything by254// two, the values are clipped at 127 * 2 = 254, which occupies255// 8 bits. Shifting it by 7 bits left will no longer occupy the256// signed bit, so we are safe.257258// Note that on NEON processors, we shift left by 6 instead259// because the instruction "vqdmulhq_s16" also doubles the260// return value after the multiplication, adding an extra shift261// to the left by 1, so we compensate by shifting less before262// the multiplication.263264constexpr int shift =265#if defined(USE_SSE2)2667;267#else2686;269#endif270271for (IndexType j = 0; j < NumOutputChunks; ++j)272{273const vec_t sum0a =274vec_slli_16(vec_max_16(vec_min_16(in0[j * 2 + 0], One), Zero), shift);275const vec_t sum0b =276vec_slli_16(vec_max_16(vec_min_16(in0[j * 2 + 1], One), Zero), shift);277const vec_t sum1a = vec_min_16(in1[j * 2 + 0], One);278const vec_t sum1b = vec_min_16(in1[j * 2 + 1], One);279280const vec_t pa = vec_mulhi_16(sum0a, sum1a);281const vec_t pb = vec_mulhi_16(sum0b, sum1b);282283out[j] = vec_packus_16(pa, pb);284}285286#else287288for (IndexType j = 0; j < HalfDimensions / 2; ++j)289{290BiasType sum0 = accumulation[static_cast<int>(perspectives[p])][j + 0];291BiasType sum1 =292accumulation[static_cast<int>(perspectives[p])][j + HalfDimensions / 2];293sum0 = std::clamp<BiasType>(sum0, 0, 127 * 2);294sum1 = std::clamp<BiasType>(sum1, 0, 127 * 2);295output[offset + j] = static_cast<OutputType>(unsigned(sum0 * sum1) / 512);296}297298#endif299}300301return psqt;302} // end of function transform()303304alignas(CacheLineSize) BiasType biases[HalfDimensions];305alignas(CacheLineSize) WeightType weights[HalfDimensions * InputDimensions];306alignas(CacheLineSize) PSQTWeightType psqtWeights[InputDimensions * PSQTBuckets];307};308309} // namespace Stockfish::Eval::NNUE310311#endif // #ifndef NNUE_FEATURE_TRANSFORMER_H_INCLUDED312313314