#include "movegen.h"
#include <cassert>
#include <initializer_list>
#include "bitboard.h"
#include "position.h"
#if defined(USE_AVX512ICL)
#include <array>
#include <algorithm>
#include <immintrin.h>
#endif
namespace Stockfish {
namespace {
#if defined(USE_AVX512ICL)
// Packs the 16-bit lanes of `vector` whose bit is set in `mask` to the front of
// a zeroed 512-bit register and stores that whole register at `moveList`.
// Returns `moveList` advanced past the selected lanes only.
// NOTE(review): the unaligned store always writes 64 bytes, so the destination
// buffer needs slack beyond the moves actually produced. Also assumes one Move
// per 16-bit lane (sizeof(Move) == 2) — TODO confirm against types.h.
inline Move* write_moves(Move* moveList, uint32_t mask, __m512i vector) {
    const __m512i packed = _mm512_maskz_compress_epi16(mask, vector);
    auto* const   dst    = reinterpret_cast<__m512i*>(moveList);
    _mm512_storeu_si512(dst, packed);

    const auto selected = popcount(mask);
    return moveList + selected;
}
// AVX-512 path: appends Move(to - offset, to) for every square `to` set in
// `to_bb`, writing moves in ascending square order (low 32 squares first, then
// high 32). Returns the advanced list pointer.
template<Direction offset>
inline Move* splat_pawn_moves(Move* moveList, Bitboard to_bb) {
// Compile-time table: entry i is the move landing on square i from i - offset.
// alignas(64) is required because the halves are read with the aligned
// _mm512_load_si512.
alignas(64) static constexpr auto SPLAT_TABLE = [] {
std::array<Move, 64> table{};
for (int8_t i = 0; i < 64; i++)
{
// Clamping keeps `from` a valid square for destinations that i - offset would
// push off the board. Those entries are only emitted if their bit is set in
// to_bb; presumably callers never set such bits — TODO confirm.
Square from{std::clamp<int8_t>(i - offset, 0, 63)};
table[i] = {Move(from, Square{i})};
}
return table;
}();
// Views the 64 moves as two 512-bit vectors: table + 0 covers squares 0..31 and
// table + 1 covers 32..63 (assumes 16-bit Move — TODO confirm).
auto table = reinterpret_cast<const __m512i*>(SPLAT_TABLE.data());
// The low 32 bits of to_bb select lanes of the first half.
moveList =
write_moves(moveList, static_cast<uint32_t>(to_bb >> 0), _mm512_load_si512(table + 0));
// The high 32 bits of to_bb select lanes of the second half.
moveList =
write_moves(moveList, static_cast<uint32_t>(to_bb >> 32), _mm512_load_si512(table + 1));
return moveList;
}
// AVX-512 path: appends Move(from, to) for every square `to` set in `to_bb`, in
// ascending square order. Returns the advanced list pointer.
inline Move* splat_moves(Move* moveList, Square from, Bitboard to_bb) {
// Entry i holds only the destination part: Move(SQUARE_ZERO, i).
alignas(64) static constexpr auto SPLAT_TABLE = [] {
std::array<Move, 64> table{};
for (int8_t i = 0; i < 64; i++)
table[i] = {Move(SQUARE_ZERO, Square{i})};
return table;
}();
// Broadcasts the origin-only encoding Move(from, SQUARE_ZERO) to every lane.
// OR-ing it with the table entries yields Move(from, i). This relies on the
// origin and destination occupying disjoint bits of the raw encoding and on
// SQUARE_ZERO encoding as 0 — NOTE(review): confirm against Move in types.h.
__m512i fromVec = _mm512_set1_epi16(Move(from, SQUARE_ZERO).raw());
auto table = reinterpret_cast<const __m512i*>(SPLAT_TABLE.data());
// Squares 0..31, selected by the low half of to_bb.
moveList = write_moves(moveList, static_cast<uint32_t>(to_bb >> 0),
_mm512_or_si512(_mm512_load_si512(table + 0), fromVec));
// Squares 32..63, selected by the high half of to_bb.
moveList = write_moves(moveList, static_cast<uint32_t>(to_bb >> 32),
_mm512_or_si512(_mm512_load_si512(table + 1), fromVec));
return moveList;
}
#else
// Portable path: appends Move(to - offset, to) for each square `to` in `to_bb`,
// least-significant square first. Returns the advanced list pointer.
template<Direction offset>
inline Move* splat_pawn_moves(Move* moveList, Bitboard to_bb) {
    for (; to_bb; ++moveList)
    {
        const Square dest = pop_lsb(to_bb);
        *moveList         = Move(dest - offset, dest);
    }
    return moveList;
}
// Portable path: appends Move(from, to) for each square `to` in `to_bb`,
// least-significant square first. Returns the advanced list pointer.
inline Move* splat_moves(Move* moveList, Square from, Bitboard to_bb) {
    for (; to_bb; ++moveList)
        *moveList = Move(from, pop_lsb(to_bb));
    return moveList;
}
#endif
// Appends the promotion moves for a pawn arriving on `to` via direction D.
// Which pieces are emitted depends on the generation type:
//   EVASIONS / NON_EVASIONS : queen, rook, bishop, knight
//   CAPTURES                : queen always; underpromotions only when capturing (Enemy)
//   QUIETS                  : underpromotions only for non-capturing pushes (!Enemy)
// Emission order is always Q, R, B, N. Returns the advanced list pointer.
template<GenType Type, Direction D, bool Enemy>
Move* make_promotions(Move* moveList, [[maybe_unused]] Square to) {
    constexpr bool everything = Type == EVASIONS || Type == NON_EVASIONS;
    // Queen promotions belong to CAPTURES even when the pawn merely pushes.
    constexpr bool wantQueen = everything || Type == CAPTURES;
    constexpr bool wantUnder =
      everything || (Type == CAPTURES && Enemy) || (Type == QUIETS && !Enemy);

    if constexpr (wantQueen)
        *moveList++ = Move::make<PROMOTION>(to - D, to, QUEEN);

    if constexpr (wantUnder)
        for (PieceType promo : {ROOK, BISHOP, KNIGHT})
            *moveList++ = Move::make<PROMOTION>(to - D, to, promo);

    return moveList;
}
// Appends the pawn moves of side Us for generation type Type, in this order:
// single pushes, double pushes, promotions (right captures, left captures,
// pushes), ordinary captures (right, then left), then en passant.
// For EVASIONS, `target` restricts push and promotion-push destinations, and
// only the checking pieces count as capture targets. For the other types the
// captures use all enemy pieces and `target` is not consulted here.
template<Color Us, GenType Type>
Move* generate_pawn_moves(const Position& pos, Move* moveList, Bitboard target) {
constexpr Color Them = ~Us;
// Relative 7th rank: a pawn pushing or capturing from here promotes.
constexpr Bitboard TRank7BB = (Us == WHITE ? Rank7BB : Rank2BB);
// Relative 3rd rank: single-push destinations here may push a second time.
constexpr Bitboard TRank3BB = (Us == WHITE ? Rank3BB : Rank6BB);
constexpr Direction Up = pawn_push(Us);
constexpr Direction UpRight = (Us == WHITE ? NORTH_EAST : SOUTH_WEST);
constexpr Direction UpLeft = (Us == WHITE ? NORTH_WEST : SOUTH_EAST);
const Bitboard emptySquares = ~pos.pieces();
// When evading, only a checker is worth capturing.
const Bitboard enemies = Type == EVASIONS ? pos.checkers() : pos.pieces(Them);
Bitboard pawnsOn7 = pos.pieces(Us, PAWN) & TRank7BB;
Bitboard pawnsNotOn7 = pos.pieces(Us, PAWN) & ~TRank7BB;
// Single and double pushes without promotion (never part of CAPTURES).
if constexpr (Type != CAPTURES)
{
Bitboard b1 = shift<Up>(pawnsNotOn7) & emptySquares;
Bitboard b2 = shift<Up>(b1 & TRank3BB) & emptySquares;
// While in check, a push is only useful if it lands on a target square.
if constexpr (Type == EVASIONS)
{
b1 &= target;
b2 &= target;
}
moveList = splat_pawn_moves<Up>(moveList, b1);
moveList = splat_pawn_moves<Up + Up>(moveList, b2);
}
// Promotions (captures and pushes); make_promotions picks the pieces per Type.
if (pawnsOn7)
{
Bitboard b1 = shift<UpRight>(pawnsOn7) & enemies;
Bitboard b2 = shift<UpLeft>(pawnsOn7) & enemies;
Bitboard b3 = shift<Up>(pawnsOn7) & emptySquares;
if constexpr (Type == EVASIONS)
b3 &= target;
while (b1)
moveList = make_promotions<Type, UpRight, true>(moveList, pop_lsb(b1));
while (b2)
moveList = make_promotions<Type, UpLeft, true>(moveList, pop_lsb(b2));
while (b3)
moveList = make_promotions<Type, Up, false>(moveList, pop_lsb(b3));
}
// Ordinary diagonal captures and en passant.
if constexpr (Type == CAPTURES || Type == EVASIONS || Type == NON_EVASIONS)
{
Bitboard b1 = shift<UpRight>(pawnsNotOn7) & enemies;
Bitboard b2 = shift<UpLeft>(pawnsNotOn7) & enemies;
moveList = splat_pawn_moves<UpRight>(moveList, b1);
moveList = splat_pawn_moves<UpLeft>(moveList, b2);
if (pos.ep_square() != SQ_NONE)
{
assert(rank_of(pos.ep_square()) == relative_rank(Us, RANK_6));
// Skip en passant when evading and the target contains the square beyond the
// ep square. Presumably the check was discovered by the double push vacating
// that square, and an ep capture cannot resolve it. Returning here skips only
// en passant, since nothing else follows.
if (Type == EVASIONS && (target & (pos.ep_square() + Up)))
return moveList;
// Our pawns that attack the ep square; the assert expects at least one
// (presumably the ep square is only set when a capture is possible).
b1 = pawnsNotOn7 & attacks_bb<PAWN>(pos.ep_square(), Them);
assert(b1);
while (b1)
*moveList++ = Move::make<EN_PASSANT>(pop_lsb(b1), pos.ep_square());
}
}
return moveList;
}
// Appends moves for every piece of type Pt (not king or pawn) owned by Us: each
// piece's attacks under the current occupancy, restricted to `target`. Pieces
// are visited least-significant square first. Returns the advanced list pointer.
template<Color Us, PieceType Pt>
Move* generate_moves(const Position& pos, Move* moveList, Bitboard target) {
    static_assert(Pt != KING && Pt != PAWN, "Unsupported piece type in generate_moves()");

    const Bitboard occupied = pos.pieces();

    for (Bitboard pieces = pos.pieces(Us, Pt); pieces;)
    {
        const Square from = pop_lsb(pieces);
        moveList          = splat_moves(moveList, from, attacks_bb<Pt>(from, occupied) & target);
    }
    return moveList;
}
// Appends every move of generation type Type for side Us. Order: pawns,
// knights, bishops, rooks, queens, king, then castling (QUIETS/NON_EVASIONS only).
// In double check only king moves are produced. Returns the advanced list pointer.
template<Color Us, GenType Type>
Move* generate_all(const Position& pos, Move* moveList) {
    static_assert(Type != LEGAL, "Unsupported type in generate_all()");

    const Square ksq         = pos.square<KING>(Us);
    const bool   doubleCheck = Type == EVASIONS && more_than_one(pos.checkers());

    // Destination mask for non-king moves. It is only assigned (and only read)
    // when non-king moves are generated, or for non-EVASIONS king moves.
    Bitboard target;

    if (!doubleCheck)
    {
        if constexpr (Type == EVASIONS)
            // Squares between the king and the single checker; presumably this
            // includes the checker's own square — confirm in bitboard.h.
            target = between_bb(ksq, lsb(pos.checkers()));
        else if constexpr (Type == NON_EVASIONS)
            target = ~pos.pieces(Us);
        else if constexpr (Type == CAPTURES)
            target = pos.pieces(~Us);
        else  // QUIETS
            target = ~pos.pieces();

        moveList = generate_pawn_moves<Us, Type>(pos, moveList, target);
        moveList = generate_moves<Us, KNIGHT>(pos, moveList, target);
        moveList = generate_moves<Us, BISHOP>(pos, moveList, target);
        moveList = generate_moves<Us, ROOK>(pos, moveList, target);
        moveList = generate_moves<Us, QUEEN>(pos, moveList, target);
    }

    // The king may step to any square not holding one of our own pieces when
    // evading; otherwise it uses the same mask as the other pieces.
    Bitboard kingMask;
    if constexpr (Type == EVASIONS)
        kingMask = ~pos.pieces(Us);
    else
        kingMask = target;
    moveList = splat_moves(moveList, ksq, attacks_bb<KING>(ksq) & kingMask);

    if constexpr (Type == QUIETS || Type == NON_EVASIONS)
    {
        if (pos.can_castle(Us & ANY_CASTLING))
            for (CastlingRights cr : {Us & KING_SIDE, Us & QUEEN_SIDE})
            {
                if (pos.castling_impeded(cr) || !pos.can_castle(cr))
                    continue;
                *moveList++ = Move::make<CASTLING>(ksq, pos.castling_rook_square(cr));
            }
    }

    return moveList;
}
}
// Generates moves of type Type (CAPTURES, QUIETS, EVASIONS or NON_EVASIONS) for
// the side to move into `moveList` and returns a pointer one past the last move.
// Moves are not filtered with Position::legal() here. EVASIONS must be requested
// exactly when the side to move is in check (asserted).
template<GenType Type>
Move* generate(const Position& pos, Move* moveList) {
    static_assert(Type != LEGAL, "Unsupported type in generate()");
    assert((Type == EVASIONS) == bool(pos.checkers()));

    if (pos.side_to_move() == WHITE)
        return generate_all<WHITE, Type>(pos, moveList);

    return generate_all<BLACK, Type>(pos, moveList);
}
// Explicit instantiations of generate() for every non-LEGAL generation type, so
// other translation units can link against them. LEGAL has its own explicit
// specialization in this file.
template Move* generate<CAPTURES>(const Position&, Move*);
template Move* generate<QUIETS>(const Position&, Move*);
template Move* generate<EVASIONS>(const Position&, Move*);
template Move* generate<NON_EVASIONS>(const Position&, Move*);
// Generates only legal moves for the side to move and returns a pointer one
// past the last one.
// It first produces EVASIONS (in check) or NON_EVASIONS. It then runs
// Position::legal() only on moves that might leave the king attacked: moves of
// pinned pieces, king moves and en passant. Rejected moves are overwritten by the
// current last move, so the surviving order is not the generation order.
template<>
Move* generate<LEGAL>(const Position& pos, Move* moveList) {
    const Color    us     = pos.side_to_move();
    const Bitboard pinned = pos.blockers_for_king(us) & pos.pieces(us);
    const Square   ksq    = pos.square<KING>(us);

    Move* scan = moveList;
    Move* last =
      pos.checkers() ? generate<EVASIONS>(pos, moveList) : generate<NON_EVASIONS>(pos, moveList);

    while (scan != last)
    {
        const Square from       = scan->from_sq();
        const bool   suspicious = (pinned & from) || from == ksq || scan->type_of() == EN_PASSANT;

        if (suspicious && !pos.legal(*scan))
            *scan = *--last;  // drop it: pull the final move into this slot, re-examine
        else
            ++scan;
    }
    return last;
}
}