Path: blob/master/src/nnue/features/full_threats.cpp
473 views
/*1Stockfish, a UCI chess playing engine derived from Glaurung 2.12Copyright (C) 2004-2025 The Stockfish developers (see AUTHORS file)34Stockfish is free software: you can redistribute it and/or modify5it under the terms of the GNU General Public License as published by6the Free Software Foundation, either version 3 of the License, or7(at your option) any later version.89Stockfish is distributed in the hope that it will be useful,10but WITHOUT ANY WARRANTY; without even the implied warranty of11MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the12GNU General Public License for more details.1314You should have received a copy of the GNU General Public License15along with this program. If not, see <http://www.gnu.org/licenses/>.16*/1718//Definition of input features FullThreats of NNUE evaluation function1920#include "full_threats.h"2122#include <array>23#include <initializer_list>2425#include "../../bitboard.h"26#include "../../misc.h"27#include "../../position.h"28#include "../../types.h"29#include "../nnue_common.h"3031namespace Stockfish::Eval::NNUE::Features {3233// Lookup array for indexing threats34IndexType offsets[PIECE_NB][SQUARE_NB];3536struct HelperOffsets {37int cumulativePieceOffset, cumulativeOffset;38};39std::array<HelperOffsets, PIECE_NB> helper_offsets;4041// Information on a particular pair of pieces and whether they should be excluded42struct PiecePairData {43// Layout: bits 8..31 are the index contribution of this piece pair, bits 0 and 1 are exclusion info44uint32_t data;45PiecePairData() {}46PiecePairData(bool excluded_pair, bool semi_excluded_pair, IndexType feature_index_base) {47data =48excluded_pair << 1 | (semi_excluded_pair && !excluded_pair) | feature_index_base << 8;49}50// lsb: excluded if from < to; 2nd lsb: always excluded51uint8_t excluded_pair_info() const { return (uint8_t) data; }52IndexType feature_index_base() const { return data >> 8; }53};5455constexpr std::array<Piece, 12> AllPieces = {56W_PAWN, W_KNIGHT, W_BISHOP, W_ROOK, W_QUEEN, W_KING,57B_PAWN, B_KNIGHT, B_BISHOP, B_ROOK, B_QUEEN, B_KING,58};5960// The final index is calculated from summing data found in these two LUTs, as well61// as offsets[attacker][from]62PiecePairData index_lut1[PIECE_NB][PIECE_NB]; // [attacker][attacked]63uint8_t index_lut2[PIECE_NB][SQUARE_NB][SQUARE_NB]; // [attacker][from][to]6465static void init_index_luts() {66for (Piece attacker : AllPieces)67{68for (Piece attacked : AllPieces)69{70bool enemy = (attacker ^ attacked) == 8;71PieceType attackerType = type_of(attacker);72PieceType attackedType = type_of(attacked);7374int map = FullThreats::map[attackerType - 1][attackedType - 1];75bool semi_excluded = attackerType == attackedType && (enemy || attackerType != PAWN);76IndexType feature = helper_offsets[attacker].cumulativeOffset77+ (color_of(attacked) * (numValidTargets[attacker] / 2) + map)78* helper_offsets[attacker].cumulativePieceOffset;7980bool excluded = map < 0;81index_lut1[attacker][attacked] = PiecePairData(excluded, semi_excluded, feature);82}83}8485for (Piece attacker : AllPieces)86{87for (int from = 0; from < SQUARE_NB; ++from)88{89for (int to = 0; to < SQUARE_NB; ++to)90{91Bitboard attacks = attacks_bb(attacker, Square(from));92index_lut2[attacker][from][to] = popcount((square_bb(Square(to)) - 1) & attacks);93}94}95}96}9798void init_threat_offsets() {99int cumulativeOffset = 0;100for (Piece piece : AllPieces)101{102int pieceIdx = piece;103int cumulativePieceOffset = 0;104105for (Square from = SQ_A1; from <= SQ_H8; ++from)106{107offsets[pieceIdx][from] = cumulativePieceOffset;108109if (type_of(piece) != PAWN)110{111Bitboard attacks = attacks_bb(piece, from, 0ULL);112cumulativePieceOffset += popcount(attacks);113}114115else if (from >= SQ_A2 && from <= SQ_H7)116{117Bitboard attacks = (pieceIdx < 8) ? pawn_attacks_bb<WHITE>(square_bb(from))118: pawn_attacks_bb<BLACK>(square_bb(from));119cumulativePieceOffset += popcount(attacks);120}121}122123helper_offsets[pieceIdx] = {cumulativePieceOffset, cumulativeOffset};124125cumulativeOffset += numValidTargets[pieceIdx] * cumulativePieceOffset;126}127128init_index_luts();129}130131// Index of a feature for a given king position and another piece on some square132inline sf_always_inline IndexType FullThreats::make_index(133Color perspective, Piece attacker, Square from, Square to, Piece attacked, Square ksq) {134const std::int8_t orientation = OrientTBL[ksq] ^ (56 * perspective);135unsigned from_oriented = uint8_t(from) ^ orientation;136unsigned to_oriented = uint8_t(to) ^ orientation;137138std::int8_t swap = 8 * perspective;139unsigned attacker_oriented = attacker ^ swap;140unsigned attacked_oriented = attacked ^ swap;141142const auto piecePairData = index_lut1[attacker_oriented][attacked_oriented];143144const bool less_than = from_oriented < to_oriented;145if ((piecePairData.excluded_pair_info() + less_than) & 2)146return FullThreats::Dimensions;147148const IndexType index = piecePairData.feature_index_base()149+ offsets[attacker_oriented][from_oriented]150+ index_lut2[attacker_oriented][from_oriented][to_oriented];151sf_assume(index < Dimensions);152return index;153}154155// Get a list of indices for active features in ascending order156157void FullThreats::append_active_indices(Color perspective, const Position& pos, IndexList& active) {158Square ksq = pos.square<KING>(perspective);159Bitboard occupied = pos.pieces();160161for (Color color : {WHITE, BLACK})162{163for (PieceType pt = PAWN; pt <= KING; ++pt)164{165Color c = Color(perspective ^ color);166Piece attacker = make_piece(c, pt);167Bitboard bb = pos.pieces(c, pt);168169if (pt == PAWN)170{171auto right = (c == WHITE) ? NORTH_EAST : SOUTH_WEST;172auto left = (c == WHITE) ? NORTH_WEST : SOUTH_EAST;173auto attacks_left =174((c == WHITE) ? shift<NORTH_EAST>(bb) : shift<SOUTH_WEST>(bb)) & occupied;175auto attacks_right =176((c == WHITE) ? shift<NORTH_WEST>(bb) : shift<SOUTH_EAST>(bb)) & occupied;177178while (attacks_left)179{180Square to = pop_lsb(attacks_left);181Square from = to - right;182Piece attacked = pos.piece_on(to);183IndexType index = make_index(perspective, attacker, from, to, attacked, ksq);184185if (index < Dimensions)186active.push_back(index);187}188189while (attacks_right)190{191Square to = pop_lsb(attacks_right);192Square from = to - left;193Piece attacked = pos.piece_on(to);194IndexType index = make_index(perspective, attacker, from, to, attacked, ksq);195196if (index < Dimensions)197active.push_back(index);198}199}200else201{202while (bb)203{204Square from = pop_lsb(bb);205Bitboard attacks = (attacks_bb(pt, from, occupied)) & occupied;206207while (attacks)208{209Square to = pop_lsb(attacks);210Piece attacked = pos.piece_on(to);211IndexType index =212make_index(perspective, attacker, from, to, attacked, ksq);213214if (index < Dimensions)215active.push_back(index);216}217}218}219}220}221}222223// Get a list of indices for recently changed features224225void FullThreats::append_changed_indices(Color perspective,226Square ksq,227const DiffType& diff,228IndexList& removed,229IndexList& added,230FusedUpdateData* fusedData,231bool first) {232233for (const auto& dirty : diff.list)234{235auto attacker = dirty.pc();236auto attacked = dirty.threatened_pc();237auto from = dirty.pc_sq();238auto to = dirty.threatened_sq();239auto add = dirty.add();240241if (fusedData)242{243if (from == fusedData->dp2removed)244{245if (add)246{247if (first)248{249fusedData->dp2removedOriginBoard |= square_bb(to);250continue;251}252}253else if (fusedData->dp2removedOriginBoard & square_bb(to))254continue;255}256257if (to != SQ_NONE && to == fusedData->dp2removed)258{259if (add)260{261if (first)262{263fusedData->dp2removedTargetBoard |= square_bb(from);264continue;265}266}267else if (fusedData->dp2removedTargetBoard & square_bb(from))268continue;269}270}271272auto& insert = add ? added : removed;273const IndexType index = make_index(perspective, attacker, from, to, attacked, ksq);274275if (index < Dimensions)276insert.push_back(index);277}278}279280bool FullThreats::requires_refresh(const DiffType& diff, Color perspective) {281return perspective == diff.us && (int8_t(diff.ksq) & 0b100) != (int8_t(diff.prevKsq) & 0b100);282}283284} // namespace Stockfish::Eval::NNUE::Features285286287