Path: src/nnue/layers/affine_transform.h
/*
  Stockfish, a UCI chess playing engine derived from Glaurung 2.1
  Copyright (C) 2004-2026 The Stockfish developers (see AUTHORS file)

  Stockfish is free software: you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation, either version 3 of the License, or
  (at your option) any later version.

  Stockfish is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  GNU General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

// Definition of layer AffineTransform of NNUE evaluation function

#ifndef NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED
#define NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED

#include <cstdint>
#include <cstring>  // std::memcpy in the non-SSSE3 fallback
#include <iostream>

#include "../../memory.h"
#include "../nnue_common.h"
#include "../simd.h"

/*
  This file contains the definition for a fully connected layer (aka affine transform).

    - expected use-case is for when PaddedInputDimensions == 32 and InputDimensions <= 32
      - that's why AVX512 is hard to implement
    - expected use-case is small layers
    - inputs are processed in chunks of 4, and the weights are transposed to match
    - accumulation happens directly to int32s
*/
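
/*
  For reference, every propagate() variant below computes the same thing as this
  scalar sketch (weights shown in plain row-major order, before the scrambling
  described at get_weight_index_scrambled()):

      for (IndexType o = 0; o < OutputDimensions; ++o)
      {
          std::int32_t sum = biases[o];
          for (IndexType i = 0; i < InputDimensions; ++i)
              sum += weights[o * PaddedInputDimensions + i] * input[i];
          output[o] = sum;
      }
*/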

namespace Stockfish::Eval::NNUE::Layers {

#if defined(USE_SSSE3) || defined(USE_NEON_DOTPROD)
    #define ENABLE_SEQ_OPT
#endif

// Fallback implementation for older/other architectures.
// Requires the input to be padded to at least 16 values.
#ifndef ENABLE_SEQ_OPT

template<IndexType InputDimensions, IndexType PaddedInputDimensions, IndexType OutputDimensions>
static void affine_transform_non_ssse3(std::int32_t*       output,
                                       const std::int8_t*  weights,
                                       const std::int32_t* biases,
                                       const std::uint8_t* input) {
    #if defined(USE_SSE2) || defined(USE_NEON)
        #if defined(USE_SSE2)
    // At least a multiple of 16, with SSE2.
    constexpr IndexType NumChunks   = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
    const __m128i       Zeros       = _mm_setzero_si128();
    const auto          inputVector = reinterpret_cast<const __m128i*>(input);

        #elif defined(USE_NEON)
    constexpr IndexType NumChunks   = ceil_to_multiple<IndexType>(InputDimensions, 16) / 16;
    const auto          inputVector = reinterpret_cast<const int8x8_t*>(input);
        #endif

    for (IndexType i = 0; i < OutputDimensions; ++i)
    {
        const IndexType offset = i * PaddedInputDimensions;

        #if defined(USE_SSE2)
        __m128i    sumLo = _mm_cvtsi32_si128(biases[i]);
        __m128i    sumHi = Zeros;
        const auto row   = reinterpret_cast<const __m128i*>(&weights[offset]);
        for (IndexType j = 0; j < NumChunks; ++j)
        {
            // Sign-extend the int8 weights and zero-extend the uint8 inputs to
            // int16, then multiply-add adjacent pairs into int32 partial sums.
            __m128i row_j           = _mm_load_si128(&row[j]);
            __m128i input_j         = _mm_load_si128(&inputVector[j]);
            __m128i extendedRowLo   = _mm_srai_epi16(_mm_unpacklo_epi8(row_j, row_j), 8);
            __m128i extendedRowHi   = _mm_srai_epi16(_mm_unpackhi_epi8(row_j, row_j), 8);
            __m128i extendedInputLo = _mm_unpacklo_epi8(input_j, Zeros);
            __m128i extendedInputHi = _mm_unpackhi_epi8(input_j, Zeros);
            __m128i productLo       = _mm_madd_epi16(extendedRowLo, extendedInputLo);
            __m128i productHi       = _mm_madd_epi16(extendedRowHi, extendedInputHi);
            sumLo                   = _mm_add_epi32(sumLo, productLo);
            sumHi                   = _mm_add_epi32(sumHi, productHi);
        }
        // Horizontal reduction: fold the four int32 lanes into lane 0 and extract it.
        __m128i sum           = _mm_add_epi32(sumLo, sumHi);
        __m128i sumHigh_64    = _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2));
        sum                   = _mm_add_epi32(sum, sumHigh_64);
        __m128i sum_second_32 = _mm_shufflelo_epi16(sum, _MM_SHUFFLE(1, 0, 3, 2));
        sum                   = _mm_add_epi32(sum, sum_second_32);
        output[i]             = _mm_cvtsi128_si32(sum);

        #elif defined(USE_NEON)

        int32x4_t  sum = {biases[i]};
        const auto row = reinterpret_cast<const SIMD::vec_i8x8_t*>(&weights[offset]);
        for (IndexType j = 0; j < NumChunks; ++j)
        {
            int16x8_t product = vmull_s8(inputVector[j * 2], row[j * 2]);
            product           = vmlal_s8(product, inputVector[j * 2 + 1], row[j * 2 + 1]);
            sum               = vpadalq_s16(sum, product);
        }
        output[i] = SIMD::neon_m128_reduce_add_epi32(sum);

        #endif
    }
    #else
    std::memcpy(output, biases, sizeof(std::int32_t) * OutputDimensions);

    // Traverse weights in transpose order to take advantage of input sparsity
    for (IndexType i = 0; i < InputDimensions; ++i)
        if (input[i])
        {
            const std::int8_t* w  = &weights[i];
            const int          in = input[i];
            for (IndexType j = 0; j < OutputDimensions; ++j)
                output[j] += w[j * PaddedInputDimensions] * in;
        }
    #endif
}

#endif  // !ENABLE_SEQ_OPT

template<IndexType InDims, IndexType OutDims>
class AffineTransform {
   public:
    // Input/output type
    using InputType  = std::uint8_t;
    using OutputType = std::int32_t;

    // Number of input/output dimensions
    static constexpr IndexType InputDimensions  = InDims;
    static constexpr IndexType OutputDimensions = OutDims;

    static constexpr IndexType PaddedInputDimensions =
      ceil_to_multiple<IndexType>(InputDimensions, MaxSimdWidth);
    static constexpr IndexType PaddedOutputDimensions =
      ceil_to_multiple<IndexType>(OutputDimensions, MaxSimdWidth);

    using OutputBuffer = OutputType[PaddedOutputDimensions];

    // Hash value embedded in the evaluation file
    static constexpr std::uint32_t get_hash_value(std::uint32_t prevHash) {
        std::uint32_t hashValue = 0xCC03DAE4u;
        hashValue += OutputDimensions;
        hashValue ^= prevHash >> 1;
        hashValue ^= prevHash << 31;
        return hashValue;
    }

    static constexpr IndexType get_weight_index_scrambled(IndexType i) {
        return (i / 4) % (PaddedInputDimensions / 4) * OutputDimensions * 4
             + i / PaddedInputDimensions * 4 + i % 4;
    }
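
    // The scrambling above can be read as follows: writing the plain row-major index
    // as i = r * PaddedInputDimensions + 4 * c + k (output row r, input chunk c,
    // byte k in 0..3), the scrambled index works out to
    //
    //     c * OutputDimensions * 4 + r * 4 + k
    //
    // so the 4 weights each output row needs for one 4-byte input chunk end up
    // contiguous, chunk by chunk, and the optimized propagate() below can broadcast
    // one 32-bit input chunk against OutputDimensions * 4 sequential weight bytes.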

    static constexpr IndexType get_weight_index(IndexType i) {
#ifdef ENABLE_SEQ_OPT
        return get_weight_index_scrambled(i);
#else
        return i;
#endif
    }

    // Read network parameters
    bool read_parameters(std::istream& stream) {
        read_little_endian<BiasType>(stream, biases, OutputDimensions);
        for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
            weights[get_weight_index(i)] = read_little_endian<WeightType>(stream);

        return !stream.fail();
    }

    // Write network parameters
    bool write_parameters(std::ostream& stream) const {
        write_little_endian<BiasType>(stream, biases, OutputDimensions);

        for (IndexType i = 0; i < OutputDimensions * PaddedInputDimensions; ++i)
            write_little_endian<WeightType>(stream, weights[get_weight_index(i)]);

        return !stream.fail();
    }

    std::size_t get_content_hash() const {
        std::size_t h = 0;
        hash_combine(h, get_raw_data_hash(biases));
        hash_combine(h, get_raw_data_hash(weights));
        hash_combine(h, get_hash_value(0));
        return h;
    }

    // Forward propagation
    void propagate(const InputType* input, OutputType* output) const {

#ifdef ENABLE_SEQ_OPT

        if constexpr (OutputDimensions > 1)
        {
    #if defined(USE_AVX512)
            using vec_t = __m512i;
        #define vec_set_32 _mm512_set1_epi32
        #define vec_add_dpbusd_32 SIMD::m512_add_dpbusd_epi32
    #elif defined(USE_AVX2)
            using vec_t = __m256i;
        #define vec_set_32 _mm256_set1_epi32
        #define vec_add_dpbusd_32 SIMD::m256_add_dpbusd_epi32
    #elif defined(USE_SSSE3)
            using vec_t = __m128i;
        #define vec_set_32 _mm_set1_epi32
        #define vec_add_dpbusd_32 SIMD::m128_add_dpbusd_epi32
    #elif defined(USE_NEON_DOTPROD)
            using vec_t = int32x4_t;
        #define vec_set_32 vdupq_n_s32
        #define vec_add_dpbusd_32(acc, a, b) \
            SIMD::dotprod_m128_add_dpbusd_epi32(acc, vreinterpretq_s8_s32(a), \
                                                vreinterpretq_s8_s32(b))
    #endif

            static constexpr IndexType OutputSimdWidth = sizeof(vec_t) / sizeof(OutputType);

            static_assert(OutputDimensions % OutputSimdWidth == 0);

            constexpr IndexType NumChunks = ceil_to_multiple<IndexType>(InputDimensions, 8) / 4;
            constexpr IndexType NumRegs   = OutputDimensions / OutputSimdWidth;

            const vec_t* biasvec = reinterpret_cast<const vec_t*>(biases);
            vec_t        acc[NumRegs];
            for (IndexType k = 0; k < NumRegs; ++k)
                acc[k] = biasvec[k];

            for (IndexType i = 0; i < NumChunks; ++i)
            {
                // Broadcast 4 input bytes to every 32-bit lane and accumulate them
                // against the matching scrambled weight columns; each
                // vec_add_dpbusd_32 updates OutputSimdWidth int32 sums at once.
                const vec_t in0 =
                  vec_set_32(load_as<std::int32_t>(input + i * sizeof(std::int32_t)));
                const auto col0 =
                  reinterpret_cast<const vec_t*>(&weights[i * OutputDimensions * 4]);

                for (IndexType k = 0; k < NumRegs; ++k)
                    vec_add_dpbusd_32(acc[k], in0, col0[k]);
            }

            vec_t* outptr = reinterpret_cast<vec_t*>(output);
            for (IndexType k = 0; k < NumRegs; ++k)
                outptr[k] = acc[k];

    #undef vec_set_32
    #undef vec_add_dpbusd_32
        }
        else if constexpr (OutputDimensions == 1)
        {
    // We cannot use AVX512 for the last layer because there are only 32 inputs
    // and the buffer is not padded to 64 elements.
    #if defined(USE_AVX2)
            using vec_t = __m256i;
        #define vec_setzero() _mm256_setzero_si256()
        #define vec_add_dpbusd_32 SIMD::m256_add_dpbusd_epi32
        #define vec_hadd SIMD::m256_hadd
    #elif defined(USE_SSSE3)
            using vec_t = __m128i;
        #define vec_setzero() _mm_setzero_si128()
        #define vec_add_dpbusd_32 SIMD::m128_add_dpbusd_epi32
        #define vec_hadd SIMD::m128_hadd
    #elif defined(USE_NEON_DOTPROD)
            using vec_t = int32x4_t;
        #define vec_setzero() vdupq_n_s32(0)
        #define vec_add_dpbusd_32(acc, a, b) \
            SIMD::dotprod_m128_add_dpbusd_epi32(acc, vreinterpretq_s8_s32(a), \
                                                vreinterpretq_s8_s32(b))
        #define vec_hadd SIMD::neon_m128_hadd
    #endif

            const auto inputVector = reinterpret_cast<const vec_t*>(input);

            static constexpr IndexType InputSimdWidth = sizeof(vec_t) / sizeof(InputType);

            static_assert(PaddedInputDimensions % InputSimdWidth == 0);

            constexpr IndexType NumChunks = PaddedInputDimensions / InputSimdWidth;
            vec_t               sum0      = vec_setzero();
            const auto          row0      = reinterpret_cast<const vec_t*>(&weights[0]);

            for (int j = 0; j < int(NumChunks); ++j)
            {
                const vec_t in = inputVector[j];
                vec_add_dpbusd_32(sum0, in, row0[j]);
            }
            output[0] = vec_hadd(sum0, biases[0]);

    #undef vec_setzero
    #undef vec_add_dpbusd_32
    #undef vec_hadd
        }
#else
        // Use old implementation for the other architectures.
        affine_transform_non_ssse3<InputDimensions, PaddedInputDimensions, OutputDimensions>(
          output, weights, biases, input);
#endif
    }

   private:
    using BiasType   = OutputType;
    using WeightType = std::int8_t;

    alignas(CacheLineSize) BiasType biases[OutputDimensions];
    alignas(CacheLineSize) WeightType weights[OutputDimensions * PaddedInputDimensions];
};

}  // namespace Stockfish::Eval::NNUE::Layers

#endif  // #ifndef NNUE_LAYERS_AFFINE_TRANSFORM_H_INCLUDED