Path: blob/master/src/nnue/nnue_accumulator.cpp
636 views
/*1Stockfish, a UCI chess playing engine derived from Glaurung 2.12Copyright (C) 2004-2026 The Stockfish developers (see AUTHORS file)34Stockfish is free software: you can redistribute it and/or modify5it under the terms of the GNU General Public License as published by6the Free Software Foundation, either version 3 of the License, or7(at your option) any later version.89Stockfish is distributed in the hope that it will be useful,10but WITHOUT ANY WARRANTY; without even the implied warranty of11MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the12GNU General Public License for more details.1314You should have received a copy of the GNU General Public License15along with this program. If not, see <http://www.gnu.org/licenses/>.16*/1718#include "nnue_accumulator.h"1920#include <cassert>21#include <cstdint>22#include <new>23#include <type_traits>2425#include "../bitboard.h"26#include "../misc.h"27#include "../position.h"28#include "../types.h"29#include "features/half_ka_v2_hm.h"30#include "nnue_architecture.h"31#include "nnue_common.h"32#include "nnue_feature_transformer.h" // IWYU pragma: keep33#include "simd.h"3435namespace Stockfish::Eval::NNUE {3637using namespace SIMD;3839namespace {4041template<IndexType TransformedFeatureDimensions>42void double_inc_update(Color perspective,43const FeatureTransformer<TransformedFeatureDimensions>& featureTransformer,44const Square ksq,45AccumulatorState<PSQFeatureSet>& middle_state,46AccumulatorState<PSQFeatureSet>& target_state,47const AccumulatorState<PSQFeatureSet>& computed);4849template<IndexType TransformedFeatureDimensions>50void double_inc_update(Color perspective,51const FeatureTransformer<TransformedFeatureDimensions>& featureTransformer,52const Square ksq,53AccumulatorState<ThreatFeatureSet>& middle_state,54AccumulatorState<ThreatFeatureSet>& target_state,55const AccumulatorState<ThreatFeatureSet>& computed,56const DirtyPiece& dp2);5758template<bool Forward, typename FeatureSet, IndexType TransformedFeatureDimensions>59void update_accumulator_incremental(60Color perspective,61const FeatureTransformer<TransformedFeatureDimensions>& featureTransformer,62const Square ksq,63AccumulatorState<FeatureSet>& target_state,64const AccumulatorState<FeatureSet>& computed);6566template<IndexType Dimensions>67void update_accumulator_refresh_cache(Color perspective,68const FeatureTransformer<Dimensions>& featureTransformer,69const Position& pos,70AccumulatorState<PSQFeatureSet>& accumulatorState,71AccumulatorCaches::Cache<Dimensions>& cache);7273template<IndexType Dimensions>74void update_threats_accumulator_full(Color perspective,75const FeatureTransformer<Dimensions>& featureTransformer,76const Position& pos,77AccumulatorState<ThreatFeatureSet>& accumulatorState);78}7980template<typename T>81const AccumulatorState<T>& AccumulatorStack::latest() const noexcept {82return accumulators<T>()[size - 1];83}8485// Explicit template instantiations86template const AccumulatorState<PSQFeatureSet>& AccumulatorStack::latest() const noexcept;87template const AccumulatorState<ThreatFeatureSet>& AccumulatorStack::latest() const noexcept;8889template<typename T>90AccumulatorState<T>& AccumulatorStack::mut_latest() noexcept {91return mut_accumulators<T>()[size - 1];92}9394template<typename T>95const std::array<AccumulatorState<T>, AccumulatorStack::MaxSize>&96AccumulatorStack::accumulators() const noexcept {97static_assert(std::is_same_v<T, PSQFeatureSet> || std::is_same_v<T, ThreatFeatureSet>,98"Invalid Feature Set Type");99100if constexpr (std::is_same_v<T, PSQFeatureSet>)101return psq_accumulators;102103if constexpr (std::is_same_v<T, ThreatFeatureSet>)104return threat_accumulators;105}106107template<typename T>108std::array<AccumulatorState<T>, AccumulatorStack::MaxSize>&109AccumulatorStack::mut_accumulators() noexcept {110static_assert(std::is_same_v<T, PSQFeatureSet> || std::is_same_v<T, ThreatFeatureSet>,111"Invalid Feature Set Type");112113if constexpr (std::is_same_v<T, PSQFeatureSet>)114return psq_accumulators;115116if constexpr (std::is_same_v<T, ThreatFeatureSet>)117return threat_accumulators;118}119120void AccumulatorStack::reset() noexcept {121psq_accumulators[0].reset({});122threat_accumulators[0].reset({});123size = 1;124}125126std::pair<DirtyPiece&, DirtyThreats&> AccumulatorStack::push() noexcept {127assert(size < MaxSize);128auto& dp = psq_accumulators[size].reset();129auto& dts = threat_accumulators[size].reset();130new (&dts) DirtyThreats;131size++;132return {dp, dts};133}134135void AccumulatorStack::pop() noexcept {136assert(size > 1);137size--;138}139140template<IndexType Dimensions>141void AccumulatorStack::evaluate(const Position& pos,142const FeatureTransformer<Dimensions>& featureTransformer,143AccumulatorCaches::Cache<Dimensions>& cache) noexcept {144constexpr bool UseThreats = (Dimensions == TransformedFeatureDimensionsBig);145146evaluate_side<PSQFeatureSet>(WHITE, pos, featureTransformer, cache);147148if (UseThreats)149evaluate_side<ThreatFeatureSet>(WHITE, pos, featureTransformer, cache);150151evaluate_side<PSQFeatureSet>(BLACK, pos, featureTransformer, cache);152153if (UseThreats)154evaluate_side<ThreatFeatureSet>(BLACK, pos, featureTransformer, cache);155}156157template<typename FeatureSet, IndexType Dimensions>158void AccumulatorStack::evaluate_side(Color perspective,159const Position& pos,160const FeatureTransformer<Dimensions>& featureTransformer,161AccumulatorCaches::Cache<Dimensions>& cache) noexcept {162163const auto last_usable_accum =164find_last_usable_accumulator<FeatureSet, Dimensions>(perspective);165166if ((accumulators<FeatureSet>()[last_usable_accum].template acc<Dimensions>())167.computed[perspective])168forward_update_incremental<FeatureSet>(perspective, pos, featureTransformer,169last_usable_accum);170171else172{173if constexpr (std::is_same_v<FeatureSet, PSQFeatureSet>)174update_accumulator_refresh_cache(perspective, featureTransformer, pos,175mut_latest<PSQFeatureSet>(), cache);176else177update_threats_accumulator_full(perspective, featureTransformer, pos,178mut_latest<ThreatFeatureSet>());179180backward_update_incremental<FeatureSet>(perspective, pos, featureTransformer,181last_usable_accum);182}183}184185// Find the earliest usable accumulator, this can either be a computed accumulator or the accumulator186// state just before a change that requires full refresh.187template<typename FeatureSet, IndexType Dimensions>188std::size_t AccumulatorStack::find_last_usable_accumulator(Color perspective) const noexcept {189190for (std::size_t curr_idx = size - 1; curr_idx > 0; curr_idx--)191{192if ((accumulators<FeatureSet>()[curr_idx].template acc<Dimensions>()).computed[perspective])193return curr_idx;194195if (FeatureSet::requires_refresh(accumulators<FeatureSet>()[curr_idx].diff, perspective))196return curr_idx;197}198199return 0;200}201202template<typename FeatureSet, IndexType Dimensions>203void AccumulatorStack::forward_update_incremental(204Color perspective,205const Position& pos,206const FeatureTransformer<Dimensions>& featureTransformer,207const std::size_t begin) noexcept {208209assert(begin < accumulators<FeatureSet>().size());210assert((accumulators<FeatureSet>()[begin].template acc<Dimensions>()).computed[perspective]);211212const Square ksq = pos.square<KING>(perspective);213214for (std::size_t next = begin + 1; next < size; next++)215{216if (next + 1 < size)217{218DirtyPiece& dp1 = mut_accumulators<PSQFeatureSet>()[next].diff;219DirtyPiece& dp2 = mut_accumulators<PSQFeatureSet>()[next + 1].diff;220221auto& accumulators = mut_accumulators<FeatureSet>();222223if constexpr (std::is_same_v<FeatureSet, ThreatFeatureSet>)224{225if (dp2.remove_sq != SQ_NONE226&& (accumulators[next].diff.threateningSqs & square_bb(dp2.remove_sq)))227{228double_inc_update(perspective, featureTransformer, ksq, accumulators[next],229accumulators[next + 1], accumulators[next - 1], dp2);230next++;231continue;232}233}234235if constexpr (std::is_same_v<FeatureSet, PSQFeatureSet>)236{237if (dp1.to != SQ_NONE && dp1.to == dp2.remove_sq)238{239const Square captureSq = dp1.to;240dp1.to = dp2.remove_sq = SQ_NONE;241double_inc_update(perspective, featureTransformer, ksq, accumulators[next],242accumulators[next + 1], accumulators[next - 1]);243dp1.to = dp2.remove_sq = captureSq;244next++;245continue;246}247}248}249250update_accumulator_incremental<true>(perspective, featureTransformer, ksq,251mut_accumulators<FeatureSet>()[next],252accumulators<FeatureSet>()[next - 1]);253}254255assert((latest<PSQFeatureSet>().acc<Dimensions>()).computed[perspective]);256}257258template<typename FeatureSet, IndexType Dimensions>259void AccumulatorStack::backward_update_incremental(260Color perspective,261262const Position& pos,263const FeatureTransformer<Dimensions>& featureTransformer,264const std::size_t end) noexcept {265266assert(end < accumulators<FeatureSet>().size());267assert(end < size);268assert((latest<FeatureSet>().template acc<Dimensions>()).computed[perspective]);269270const Square ksq = pos.square<KING>(perspective);271272for (std::int64_t next = std::int64_t(size) - 2; next >= std::int64_t(end); next--)273update_accumulator_incremental<false>(perspective, featureTransformer, ksq,274mut_accumulators<FeatureSet>()[next],275accumulators<FeatureSet>()[next + 1]);276277assert((accumulators<FeatureSet>()[end].template acc<Dimensions>()).computed[perspective]);278}279280// Explicit template instantiations281template void AccumulatorStack::evaluate<TransformedFeatureDimensionsBig>(282const Position& pos,283const FeatureTransformer<TransformedFeatureDimensionsBig>& featureTransformer,284AccumulatorCaches::Cache<TransformedFeatureDimensionsBig>& cache) noexcept;285template void AccumulatorStack::evaluate<TransformedFeatureDimensionsSmall>(286const Position& pos,287const FeatureTransformer<TransformedFeatureDimensionsSmall>& featureTransformer,288AccumulatorCaches::Cache<TransformedFeatureDimensionsSmall>& cache) noexcept;289290291namespace {292293template<typename VectorWrapper,294IndexType Width,295UpdateOperation... ops,296typename ElementType,297typename... Ts,298std::enable_if_t<is_all_same_v<ElementType, Ts...>, bool> = true>299void fused_row_reduce(const ElementType* in, ElementType* out, const Ts* const... rows) {300constexpr IndexType size = Width * sizeof(ElementType) / sizeof(typename VectorWrapper::type);301302auto* vecIn = reinterpret_cast<const typename VectorWrapper::type*>(in);303auto* vecOut = reinterpret_cast<typename VectorWrapper::type*>(out);304305for (IndexType i = 0; i < size; ++i)306vecOut[i] = fused<VectorWrapper, ops...>(307vecIn[i], reinterpret_cast<const typename VectorWrapper::type*>(rows)[i]...);308}309310template<typename FeatureSet, IndexType Dimensions>311struct AccumulatorUpdateContext {312Color perspective;313const FeatureTransformer<Dimensions>& featureTransformer;314const AccumulatorState<FeatureSet>& from;315AccumulatorState<FeatureSet>& to;316317AccumulatorUpdateContext(Color persp,318const FeatureTransformer<Dimensions>& ft,319const AccumulatorState<FeatureSet>& accF,320AccumulatorState<FeatureSet>& accT) noexcept :321perspective{persp},322featureTransformer{ft},323from{accF},324to{accT} {}325326template<UpdateOperation... ops,327typename... Ts,328std::enable_if_t<is_all_same_v<IndexType, Ts...>, bool> = true>329void apply(const Ts... indices) {330auto to_weight_vector = [&](const IndexType index) {331return &featureTransformer.weights[index * Dimensions];332};333334auto to_psqt_weight_vector = [&](const IndexType index) {335return &featureTransformer.psqtWeights[index * PSQTBuckets];336};337338fused_row_reduce<Vec16Wrapper, Dimensions, ops...>(339(from.template acc<Dimensions>()).accumulation[perspective].data(),340(to.template acc<Dimensions>()).accumulation[perspective].data(),341to_weight_vector(indices)...);342343fused_row_reduce<Vec32Wrapper, PSQTBuckets, ops...>(344(from.template acc<Dimensions>()).psqtAccumulation[perspective].data(),345(to.template acc<Dimensions>()).psqtAccumulation[perspective].data(),346to_psqt_weight_vector(indices)...);347}348349void apply(const typename FeatureSet::IndexList& added,350const typename FeatureSet::IndexList& removed) {351const auto& fromAcc = from.template acc<Dimensions>().accumulation[perspective];352auto& toAcc = to.template acc<Dimensions>().accumulation[perspective];353354const auto& fromPsqtAcc = from.template acc<Dimensions>().psqtAccumulation[perspective];355auto& toPsqtAcc = to.template acc<Dimensions>().psqtAccumulation[perspective];356357#ifdef VECTOR358using Tiling = SIMDTiling<Dimensions, Dimensions, PSQTBuckets>;359vec_t acc[Tiling::NumRegs];360psqt_vec_t psqt[Tiling::NumPsqtRegs];361362const auto* threatWeights = &featureTransformer.threatWeights[0];363364for (IndexType j = 0; j < Dimensions / Tiling::TileHeight; ++j)365{366auto* fromTile = reinterpret_cast<const vec_t*>(&fromAcc[j * Tiling::TileHeight]);367auto* toTile = reinterpret_cast<vec_t*>(&toAcc[j * Tiling::TileHeight]);368369for (IndexType k = 0; k < Tiling::NumRegs; ++k)370acc[k] = fromTile[k];371372for (int i = 0; i < removed.ssize(); ++i)373{374size_t index = removed[i];375const size_t offset = Dimensions * index;376auto* column = reinterpret_cast<const vec_i8_t*>(&threatWeights[offset]);377378#ifdef USE_NEON379for (IndexType k = 0; k < Tiling::NumRegs; k += 2)380{381acc[k] = vec_sub_16(acc[k], vmovl_s8(vget_low_s8(column[k / 2])));382acc[k + 1] = vec_sub_16(acc[k + 1], vmovl_high_s8(column[k / 2]));383}384#else385for (IndexType k = 0; k < Tiling::NumRegs; ++k)386acc[k] = vec_sub_16(acc[k], vec_convert_8_16(column[k]));387#endif388}389390for (int i = 0; i < added.ssize(); ++i)391{392size_t index = added[i];393const size_t offset = Dimensions * index;394auto* column = reinterpret_cast<const vec_i8_t*>(&threatWeights[offset]);395396#ifdef USE_NEON397for (IndexType k = 0; k < Tiling::NumRegs; k += 2)398{399acc[k] = vec_add_16(acc[k], vmovl_s8(vget_low_s8(column[k / 2])));400acc[k + 1] = vec_add_16(acc[k + 1], vmovl_high_s8(column[k / 2]));401}402#else403for (IndexType k = 0; k < Tiling::NumRegs; ++k)404acc[k] = vec_add_16(acc[k], vec_convert_8_16(column[k]));405#endif406}407408for (IndexType k = 0; k < Tiling::NumRegs; k++)409vec_store(&toTile[k], acc[k]);410411threatWeights += Tiling::TileHeight;412}413414for (IndexType j = 0; j < PSQTBuckets / Tiling::PsqtTileHeight; ++j)415{416auto* fromTilePsqt =417reinterpret_cast<const psqt_vec_t*>(&fromPsqtAcc[j * Tiling::PsqtTileHeight]);418auto* toTilePsqt =419reinterpret_cast<psqt_vec_t*>(&toPsqtAcc[j * Tiling::PsqtTileHeight]);420421for (IndexType k = 0; k < Tiling::NumPsqtRegs; ++k)422psqt[k] = fromTilePsqt[k];423424for (int i = 0; i < removed.ssize(); ++i)425{426size_t index = removed[i];427const size_t offset = PSQTBuckets * index + j * Tiling::PsqtTileHeight;428auto* columnPsqt = reinterpret_cast<const psqt_vec_t*>(429&featureTransformer.threatPsqtWeights[offset]);430431for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k)432psqt[k] = vec_sub_psqt_32(psqt[k], columnPsqt[k]);433}434435for (int i = 0; i < added.ssize(); ++i)436{437size_t index = added[i];438const size_t offset = PSQTBuckets * index + j * Tiling::PsqtTileHeight;439auto* columnPsqt = reinterpret_cast<const psqt_vec_t*>(440&featureTransformer.threatPsqtWeights[offset]);441442for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k)443psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]);444}445446for (IndexType k = 0; k < Tiling::NumPsqtRegs; ++k)447vec_store_psqt(&toTilePsqt[k], psqt[k]);448}449450#else451452toAcc = fromAcc;453toPsqtAcc = fromPsqtAcc;454455for (const auto index : removed)456{457const IndexType offset = Dimensions * index;458459for (IndexType j = 0; j < Dimensions; ++j)460toAcc[j] -= featureTransformer.threatWeights[offset + j];461462for (std::size_t k = 0; k < PSQTBuckets; ++k)463toPsqtAcc[k] -= featureTransformer.threatPsqtWeights[index * PSQTBuckets + k];464}465466for (const auto index : added)467{468const IndexType offset = Dimensions * index;469470for (IndexType j = 0; j < Dimensions; ++j)471toAcc[j] += featureTransformer.threatWeights[offset + j];472473for (std::size_t k = 0; k < PSQTBuckets; ++k)474toPsqtAcc[k] += featureTransformer.threatPsqtWeights[index * PSQTBuckets + k];475}476477#endif478}479};480481template<typename FeatureSet, IndexType Dimensions>482auto make_accumulator_update_context(Color perspective,483const FeatureTransformer<Dimensions>& featureTransformer,484const AccumulatorState<FeatureSet>& accumulatorFrom,485AccumulatorState<FeatureSet>& accumulatorTo) noexcept {486return AccumulatorUpdateContext<FeatureSet, Dimensions>{perspective, featureTransformer,487accumulatorFrom, accumulatorTo};488}489490template<IndexType TransformedFeatureDimensions>491void double_inc_update(Color perspective,492const FeatureTransformer<TransformedFeatureDimensions>& featureTransformer,493const Square ksq,494AccumulatorState<PSQFeatureSet>& middle_state,495AccumulatorState<PSQFeatureSet>& target_state,496const AccumulatorState<PSQFeatureSet>& computed) {497498assert(computed.acc<TransformedFeatureDimensions>().computed[perspective]);499assert(!middle_state.acc<TransformedFeatureDimensions>().computed[perspective]);500assert(!target_state.acc<TransformedFeatureDimensions>().computed[perspective]);501502PSQFeatureSet::IndexList removed, added;503PSQFeatureSet::append_changed_indices(perspective, ksq, middle_state.diff, removed, added);504// you can't capture a piece that was just involved in castling since the rook ends up505// in a square that the king passed506assert(added.size() < 2);507PSQFeatureSet::append_changed_indices(perspective, ksq, target_state.diff, removed, added);508509[[maybe_unused]] const int addedSize = added.ssize();510[[maybe_unused]] const int removedSize = removed.ssize();511512assert(addedSize == 1);513assert(removedSize == 2 || removedSize == 3);514515// Workaround compiler warning for uninitialized variables, replicated on516// profile builds on windows with gcc 14.2.0.517// Also helps with optimizations on some compilers.518519sf_assume(addedSize == 1);520sf_assume(removedSize == 2 || removedSize == 3);521522auto updateContext =523make_accumulator_update_context(perspective, featureTransformer, computed, target_state);524525if (removedSize == 2)526{527updateContext.template apply<Add, Sub, Sub>(added[0], removed[0], removed[1]);528}529else530{531updateContext.template apply<Add, Sub, Sub, Sub>(added[0], removed[0], removed[1],532removed[2]);533}534535target_state.acc<TransformedFeatureDimensions>().computed[perspective] = true;536}537538template<IndexType TransformedFeatureDimensions>539void double_inc_update(Color perspective,540const FeatureTransformer<TransformedFeatureDimensions>& featureTransformer,541const Square ksq,542AccumulatorState<ThreatFeatureSet>& middle_state,543AccumulatorState<ThreatFeatureSet>& target_state,544const AccumulatorState<ThreatFeatureSet>& computed,545const DirtyPiece& dp2) {546547assert(computed.acc<TransformedFeatureDimensions>().computed[perspective]);548assert(!middle_state.acc<TransformedFeatureDimensions>().computed[perspective]);549assert(!target_state.acc<TransformedFeatureDimensions>().computed[perspective]);550551ThreatFeatureSet::FusedUpdateData fusedData;552553fusedData.dp2removed = dp2.remove_sq;554555ThreatFeatureSet::IndexList removed, added;556const auto* pfBase = &featureTransformer.threatWeights[0];557auto pfStride = static_cast<IndexType>(TransformedFeatureDimensions);558ThreatFeatureSet::append_changed_indices(perspective, ksq, middle_state.diff, removed, added,559&fusedData, true, pfBase, pfStride);560ThreatFeatureSet::append_changed_indices(perspective, ksq, target_state.diff, removed, added,561&fusedData, false, pfBase, pfStride);562563auto updateContext =564make_accumulator_update_context(perspective, featureTransformer, computed, target_state);565566updateContext.apply(added, removed);567568target_state.acc<TransformedFeatureDimensions>().computed[perspective] = true;569}570571template<bool Forward, typename FeatureSet, IndexType TransformedFeatureDimensions>572void update_accumulator_incremental(573Color perspective,574const FeatureTransformer<TransformedFeatureDimensions>& featureTransformer,575const Square ksq,576AccumulatorState<FeatureSet>& target_state,577const AccumulatorState<FeatureSet>& computed) {578579assert((computed.template acc<TransformedFeatureDimensions>()).computed[perspective]);580assert(!(target_state.template acc<TransformedFeatureDimensions>()).computed[perspective]);581582// The size must be enough to contain the largest possible update.583// That might depend on the feature set and generally relies on the584// feature set's update cost calculation to be correct and never allow585// updates with more added/removed features than MaxActiveDimensions.586// In this case, the maximum size of both feature addition and removal587// is 2, since we are incrementally updating one move at a time.588typename FeatureSet::IndexList removed, added;589if constexpr (std::is_same_v<FeatureSet, ThreatFeatureSet>)590{591const auto* pfBase = &featureTransformer.threatWeights[0];592auto pfStride = static_cast<IndexType>(TransformedFeatureDimensions);593if constexpr (Forward)594FeatureSet::append_changed_indices(perspective, ksq, target_state.diff, removed, added,595nullptr, false, pfBase, pfStride);596else597FeatureSet::append_changed_indices(perspective, ksq, computed.diff, added, removed,598nullptr, false, pfBase, pfStride);599}600else601{602if constexpr (Forward)603FeatureSet::append_changed_indices(perspective, ksq, target_state.diff, removed, added);604else605FeatureSet::append_changed_indices(perspective, ksq, computed.diff, added, removed);606}607608auto updateContext =609make_accumulator_update_context(perspective, featureTransformer, computed, target_state);610611if constexpr (std::is_same_v<FeatureSet, ThreatFeatureSet>)612updateContext.apply(added, removed);613else614{615[[maybe_unused]] const int addedSize = added.ssize();616[[maybe_unused]] const int removedSize = removed.ssize();617618assert(addedSize == 1 || addedSize == 2);619assert(removedSize == 1 || removedSize == 2);620assert((Forward && addedSize <= removedSize) || (!Forward && addedSize >= removedSize));621622// Workaround compiler warning for uninitialized variables, replicated623// on profile builds on windows with gcc 14.2.0.624// Also helps with optimizations on some compilers.625626sf_assume(addedSize == 1 || addedSize == 2);627sf_assume(removedSize == 1 || removedSize == 2);628629if (!(removedSize == 1 || removedSize == 2) || !(addedSize == 1 || addedSize == 2))630sf_unreachable();631632if ((Forward && removedSize == 1) || (!Forward && addedSize == 1))633{634assert(addedSize == 1 && removedSize == 1);635updateContext.template apply<Add, Sub>(added[0], removed[0]);636}637else if (Forward && addedSize == 1)638{639assert(removedSize == 2);640updateContext.template apply<Add, Sub, Sub>(added[0], removed[0], removed[1]);641}642else if (!Forward && removedSize == 1)643{644assert(addedSize == 2);645updateContext.template apply<Add, Add, Sub>(added[0], added[1], removed[0]);646}647else648{649assert(addedSize == 2 && removedSize == 2);650updateContext.template apply<Add, Add, Sub, Sub>(added[0], added[1], removed[0],651removed[1]);652}653}654655(target_state.template acc<TransformedFeatureDimensions>()).computed[perspective] = true;656}657658Bitboard get_changed_pieces(const std::array<Piece, SQUARE_NB>& oldPieces,659const std::array<Piece, SQUARE_NB>& newPieces) {660#if defined(USE_AVX512) || defined(USE_AVX2)661static_assert(sizeof(Piece) == 1);662Bitboard sameBB = 0;663664for (int i = 0; i < 64; i += 32)665{666const __m256i old_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(&oldPieces[i]));667const __m256i new_v = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(&newPieces[i]));668const __m256i cmpEqual = _mm256_cmpeq_epi8(old_v, new_v);669const std::uint32_t equalMask = _mm256_movemask_epi8(cmpEqual);670sameBB |= static_cast<Bitboard>(equalMask) << i;671}672return ~sameBB;673#elif defined(USE_NEON)674uint8x16x4_t old_v = vld4q_u8(reinterpret_cast<const uint8_t*>(oldPieces.data()));675uint8x16x4_t new_v = vld4q_u8(reinterpret_cast<const uint8_t*>(newPieces.data()));676auto cmp = [=](const int i) { return vceqq_u8(old_v.val[i], new_v.val[i]); };677678uint8x16_t cmp0_1 = vsriq_n_u8(cmp(1), cmp(0), 1);679uint8x16_t cmp2_3 = vsriq_n_u8(cmp(3), cmp(2), 1);680uint8x16_t merged = vsriq_n_u8(cmp2_3, cmp0_1, 2);681merged = vsriq_n_u8(merged, merged, 4);682uint8x8_t sameBB = vshrn_n_u16(vreinterpretq_u16_u8(merged), 4);683684return ~vget_lane_u64(vreinterpret_u64_u8(sameBB), 0);685#else686Bitboard changed = 0;687688for (Square sq = SQUARE_ZERO; sq < SQUARE_NB; ++sq)689changed |= static_cast<Bitboard>(oldPieces[sq] != newPieces[sq]) << sq;690691return changed;692#endif693}694695template<IndexType Dimensions>696void update_accumulator_refresh_cache(Color perspective,697const FeatureTransformer<Dimensions>& featureTransformer,698const Position& pos,699AccumulatorState<PSQFeatureSet>& accumulatorState,700AccumulatorCaches::Cache<Dimensions>& cache) {701702using Tiling [[maybe_unused]] = SIMDTiling<Dimensions, Dimensions, PSQTBuckets>;703704const Square ksq = pos.square<KING>(perspective);705auto& entry = cache[ksq][perspective];706PSQFeatureSet::IndexList removed, added;707708const Bitboard changedBB = get_changed_pieces(entry.pieces, pos.piece_array());709Bitboard removedBB = changedBB & entry.pieceBB;710Bitboard addedBB = changedBB & pos.pieces();711712while (removedBB)713{714Square sq = pop_lsb(removedBB);715removed.push_back(PSQFeatureSet::make_index(perspective, sq, entry.pieces[sq], ksq));716}717while (addedBB)718{719Square sq = pop_lsb(addedBB);720added.push_back(PSQFeatureSet::make_index(perspective, sq, pos.piece_on(sq), ksq));721}722723entry.pieceBB = pos.pieces();724entry.pieces = pos.piece_array();725726auto& accumulator = accumulatorState.acc<Dimensions>();727accumulator.computed[perspective] = true;728729#ifdef VECTOR730vec_t acc[Tiling::NumRegs];731psqt_vec_t psqt[Tiling::NumPsqtRegs];732733const auto* weights = &featureTransformer.weights[0];734735for (IndexType j = 0; j < Dimensions / Tiling::TileHeight; ++j)736{737auto* accTile =738reinterpret_cast<vec_t*>(&accumulator.accumulation[perspective][j * Tiling::TileHeight]);739auto* entryTile = reinterpret_cast<vec_t*>(&entry.accumulation[j * Tiling::TileHeight]);740741for (IndexType k = 0; k < Tiling::NumRegs; ++k)742acc[k] = entryTile[k];743744int i = 0;745for (; i < std::min(removed.ssize(), added.ssize()); ++i)746{747size_t indexR = removed[i];748const size_t offsetR = Dimensions * indexR;749auto* columnR = reinterpret_cast<const vec_t*>(&weights[offsetR]);750size_t indexA = added[i];751const size_t offsetA = Dimensions * indexA;752auto* columnA = reinterpret_cast<const vec_t*>(&weights[offsetA]);753754for (IndexType k = 0; k < Tiling::NumRegs; ++k)755acc[k] = fused<Vec16Wrapper, Add, Sub>(acc[k], columnA[k], columnR[k]);756}757for (; i < removed.ssize(); ++i)758{759size_t index = removed[i];760const size_t offset = Dimensions * index;761auto* column = reinterpret_cast<const vec_t*>(&weights[offset]);762763for (IndexType k = 0; k < Tiling::NumRegs; ++k)764acc[k] = vec_sub_16(acc[k], column[k]);765}766for (; i < added.ssize(); ++i)767{768size_t index = added[i];769const size_t offset = Dimensions * index;770auto* column = reinterpret_cast<const vec_t*>(&weights[offset]);771772for (IndexType k = 0; k < Tiling::NumRegs; ++k)773acc[k] = vec_add_16(acc[k], column[k]);774}775776for (IndexType k = 0; k < Tiling::NumRegs; k++)777vec_store(&entryTile[k], acc[k]);778for (IndexType k = 0; k < Tiling::NumRegs; k++)779vec_store(&accTile[k], acc[k]);780781weights += Tiling::TileHeight;782}783784for (IndexType j = 0; j < PSQTBuckets / Tiling::PsqtTileHeight; ++j)785{786auto* accTilePsqt = reinterpret_cast<psqt_vec_t*>(787&accumulator.psqtAccumulation[perspective][j * Tiling::PsqtTileHeight]);788auto* entryTilePsqt =789reinterpret_cast<psqt_vec_t*>(&entry.psqtAccumulation[j * Tiling::PsqtTileHeight]);790791for (IndexType k = 0; k < Tiling::NumPsqtRegs; ++k)792psqt[k] = entryTilePsqt[k];793794for (int i = 0; i < removed.ssize(); ++i)795{796size_t index = removed[i];797const size_t offset = PSQTBuckets * index + j * Tiling::PsqtTileHeight;798auto* columnPsqt =799reinterpret_cast<const psqt_vec_t*>(&featureTransformer.psqtWeights[offset]);800801for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k)802psqt[k] = vec_sub_psqt_32(psqt[k], columnPsqt[k]);803}804for (int i = 0; i < added.ssize(); ++i)805{806size_t index = added[i];807const size_t offset = PSQTBuckets * index + j * Tiling::PsqtTileHeight;808auto* columnPsqt =809reinterpret_cast<const psqt_vec_t*>(&featureTransformer.psqtWeights[offset]);810811for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k)812psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]);813}814815for (IndexType k = 0; k < Tiling::NumPsqtRegs; ++k)816vec_store_psqt(&entryTilePsqt[k], psqt[k]);817for (IndexType k = 0; k < Tiling::NumPsqtRegs; ++k)818vec_store_psqt(&accTilePsqt[k], psqt[k]);819}820821#else822823for (const auto index : removed)824{825const IndexType offset = Dimensions * index;826for (IndexType j = 0; j < Dimensions; ++j)827entry.accumulation[j] -= featureTransformer.weights[offset + j];828829for (std::size_t k = 0; k < PSQTBuckets; ++k)830entry.psqtAccumulation[k] -= featureTransformer.psqtWeights[index * PSQTBuckets + k];831}832for (const auto index : added)833{834const IndexType offset = Dimensions * index;835for (IndexType j = 0; j < Dimensions; ++j)836entry.accumulation[j] += featureTransformer.weights[offset + j];837838for (std::size_t k = 0; k < PSQTBuckets; ++k)839entry.psqtAccumulation[k] += featureTransformer.psqtWeights[index * PSQTBuckets + k];840}841842// The accumulator of the refresh entry has been updated.843// Now copy its content to the actual accumulator we were refreshing.844accumulator.accumulation[perspective] = entry.accumulation;845accumulator.psqtAccumulation[perspective] = entry.psqtAccumulation;846#endif847}848849template<IndexType Dimensions>850void update_threats_accumulator_full(Color perspective,851const FeatureTransformer<Dimensions>& featureTransformer,852const Position& pos,853AccumulatorState<ThreatFeatureSet>& accumulatorState) {854using Tiling [[maybe_unused]] = SIMDTiling<Dimensions, Dimensions, PSQTBuckets>;855856ThreatFeatureSet::IndexList active;857ThreatFeatureSet::append_active_indices(perspective, pos, active);858859auto& accumulator = accumulatorState.acc<Dimensions>();860accumulator.computed[perspective] = true;861862#ifdef VECTOR863vec_t acc[Tiling::NumRegs];864psqt_vec_t psqt[Tiling::NumPsqtRegs];865866const auto* threatWeights = &featureTransformer.threatWeights[0];867868for (IndexType j = 0; j < Dimensions / Tiling::TileHeight; ++j)869{870auto* accTile =871reinterpret_cast<vec_t*>(&accumulator.accumulation[perspective][j * Tiling::TileHeight]);872873for (IndexType k = 0; k < Tiling::NumRegs; ++k)874acc[k] = vec_zero();875876int i = 0;877878for (; i < active.ssize(); ++i)879{880size_t index = active[i];881const size_t offset = Dimensions * index;882auto* column = reinterpret_cast<const vec_i8_t*>(&threatWeights[offset]);883884#ifdef USE_NEON885for (IndexType k = 0; k < Tiling::NumRegs; k += 2)886{887acc[k] = vec_add_16(acc[k], vmovl_s8(vget_low_s8(column[k / 2])));888acc[k + 1] = vec_add_16(acc[k + 1], vmovl_high_s8(column[k / 2]));889}890#else891for (IndexType k = 0; k < Tiling::NumRegs; ++k)892acc[k] = vec_add_16(acc[k], vec_convert_8_16(column[k]));893#endif894}895896for (IndexType k = 0; k < Tiling::NumRegs; k++)897vec_store(&accTile[k], acc[k]);898899threatWeights += Tiling::TileHeight;900}901902for (IndexType j = 0; j < PSQTBuckets / Tiling::PsqtTileHeight; ++j)903{904auto* accTilePsqt = reinterpret_cast<psqt_vec_t*>(905&accumulator.psqtAccumulation[perspective][j * Tiling::PsqtTileHeight]);906907for (IndexType k = 0; k < Tiling::NumPsqtRegs; ++k)908psqt[k] = vec_zero_psqt();909910for (int i = 0; i < active.ssize(); ++i)911{912size_t index = active[i];913const size_t offset = PSQTBuckets * index + j * Tiling::PsqtTileHeight;914auto* columnPsqt =915reinterpret_cast<const psqt_vec_t*>(&featureTransformer.threatPsqtWeights[offset]);916917for (std::size_t k = 0; k < Tiling::NumPsqtRegs; ++k)918psqt[k] = vec_add_psqt_32(psqt[k], columnPsqt[k]);919}920921for (IndexType k = 0; k < Tiling::NumPsqtRegs; ++k)922vec_store_psqt(&accTilePsqt[k], psqt[k]);923}924925#else926927for (IndexType j = 0; j < Dimensions; ++j)928accumulator.accumulation[perspective][j] = 0;929930for (std::size_t k = 0; k < PSQTBuckets; ++k)931accumulator.psqtAccumulation[perspective][k] = 0;932933for (const auto index : active)934{935const IndexType offset = Dimensions * index;936937for (IndexType j = 0; j < Dimensions; ++j)938accumulator.accumulation[perspective][j] +=939featureTransformer.threatWeights[offset + j];940941for (std::size_t k = 0; k < PSQTBuckets; ++k)942accumulator.psqtAccumulation[perspective][k] +=943featureTransformer.threatPsqtWeights[index * PSQTBuckets + k];944}945946#endif947}948949}950951}952953954