Path: blob/main/contrib/llvm-project/llvm/lib/Target/VE/VECustomDAG.cpp
35294 views
//===-- VECustomDAG.h - VE Custom DAG Nodes ------------*- C++ -*-===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This file defines the interfaces that VE uses to lower LLVM code into a9// selection DAG.10//11//===----------------------------------------------------------------------===//1213#include "VECustomDAG.h"1415#ifndef DEBUG_TYPE16#define DEBUG_TYPE "vecustomdag"17#endif1819namespace llvm {2021bool isPackedVectorType(EVT SomeVT) {22if (!SomeVT.isVector())23return false;24return SomeVT.getVectorNumElements() > StandardVectorWidth;25}2627MVT splitVectorType(MVT VT) {28if (!VT.isVector())29return VT;30return MVT::getVectorVT(VT.getVectorElementType(), StandardVectorWidth);31}3233MVT getLegalVectorType(Packing P, MVT ElemVT) {34return MVT::getVectorVT(ElemVT, P == Packing::Normal ? StandardVectorWidth35: PackedVectorWidth);36}3738Packing getTypePacking(EVT VT) {39assert(VT.isVector());40return isPackedVectorType(VT) ? Packing::Dense : Packing::Normal;41}4243bool isMaskType(EVT SomeVT) {44if (!SomeVT.isVector())45return false;46return SomeVT.getVectorElementType() == MVT::i1;47}4849bool isMaskArithmetic(SDValue Op) {50switch (Op.getOpcode()) {51default:52return false;53case ISD::AND:54case ISD::XOR:55case ISD::OR:56return isMaskType(Op.getValueType());57}58}5960/// \returns the VVP_* SDNode opcode corresponsing to \p OC.61std::optional<unsigned> getVVPOpcode(unsigned Opcode) {62switch (Opcode) {63case ISD::MLOAD:64return VEISD::VVP_LOAD;65case ISD::MSTORE:66return VEISD::VVP_STORE;67#define HANDLE_VP_TO_VVP(VPOPC, VVPNAME) \68case ISD::VPOPC: \69return VEISD::VVPNAME;70#define ADD_VVP_OP(VVPNAME, SDNAME) \71case VEISD::VVPNAME: \72case ISD::SDNAME: \73return VEISD::VVPNAME;74#include "VVPNodes.def"75// TODO: Map those in VVPNodes.def too76case ISD::EXPERIMENTAL_VP_STRIDED_LOAD:77return VEISD::VVP_LOAD;78case ISD::EXPERIMENTAL_VP_STRIDED_STORE:79return VEISD::VVP_STORE;80}81return std::nullopt;82}8384bool maySafelyIgnoreMask(SDValue Op) {85auto VVPOpc = getVVPOpcode(Op->getOpcode());86auto Opc = VVPOpc.value_or(Op->getOpcode());8788switch (Opc) {89case VEISD::VVP_SDIV:90case VEISD::VVP_UDIV:91case VEISD::VVP_FDIV:92case VEISD::VVP_SELECT:93return false;9495default:96return true;97}98}99100bool supportsPackedMode(unsigned Opcode, EVT IdiomVT) {101bool IsPackedOp = isPackedVectorType(IdiomVT);102bool IsMaskOp = isMaskType(IdiomVT);103switch (Opcode) {104default:105return false;106107case VEISD::VEC_BROADCAST:108return true;109#define REGISTER_PACKED(VVP_NAME) case VEISD::VVP_NAME:110#include "VVPNodes.def"111return IsPackedOp && !IsMaskOp;112}113}114115bool isPackingSupportOpcode(unsigned Opc) {116switch (Opc) {117case VEISD::VEC_PACK:118case VEISD::VEC_UNPACK_LO:119case VEISD::VEC_UNPACK_HI:120return true;121}122return false;123}124125bool isVVPOrVEC(unsigned Opcode) {126switch (Opcode) {127case VEISD::VEC_BROADCAST:128#define ADD_VVP_OP(VVPNAME, ...) case VEISD::VVPNAME:129#include "VVPNodes.def"130return true;131}132return false;133}134135bool isVVPUnaryOp(unsigned VVPOpcode) {136switch (VVPOpcode) {137#define ADD_UNARY_VVP_OP(VVPNAME, ...) \138case VEISD::VVPNAME: \139return true;140#include "VVPNodes.def"141}142return false;143}144145bool isVVPBinaryOp(unsigned VVPOpcode) {146switch (VVPOpcode) {147#define ADD_BINARY_VVP_OP(VVPNAME, ...) \148case VEISD::VVPNAME: \149return true;150#include "VVPNodes.def"151}152return false;153}154155bool isVVPReductionOp(unsigned Opcode) {156switch (Opcode) {157#define ADD_REDUCE_VVP_OP(VVP_NAME, SDNAME) case VEISD::VVP_NAME:158#include "VVPNodes.def"159return true;160}161return false;162}163164// Return the AVL operand position for this VVP or VEC Op.165std::optional<int> getAVLPos(unsigned Opc) {166// This is only available for VP SDNodes167auto PosOpt = ISD::getVPExplicitVectorLengthIdx(Opc);168if (PosOpt)169return *PosOpt;170171// VVP Opcodes.172if (isVVPBinaryOp(Opc))173return 3;174175// VM Opcodes.176switch (Opc) {177case VEISD::VEC_BROADCAST:178return 1;179case VEISD::VVP_SELECT:180return 3;181case VEISD::VVP_LOAD:182return 4;183case VEISD::VVP_STORE:184return 5;185}186187return std::nullopt;188}189190std::optional<int> getMaskPos(unsigned Opc) {191// This is only available for VP SDNodes192auto PosOpt = ISD::getVPMaskIdx(Opc);193if (PosOpt)194return *PosOpt;195196// VVP Opcodes.197if (isVVPBinaryOp(Opc))198return 2;199200// Other opcodes.201switch (Opc) {202case ISD::MSTORE:203return 4;204case ISD::MLOAD:205return 3;206case VEISD::VVP_SELECT:207return 2;208}209210return std::nullopt;211}212213bool isLegalAVL(SDValue AVL) { return AVL->getOpcode() == VEISD::LEGALAVL; }214215/// Node Properties {216217SDValue getNodeChain(SDValue Op) {218if (MemSDNode *MemN = dyn_cast<MemSDNode>(Op.getNode()))219return MemN->getChain();220221switch (Op->getOpcode()) {222case VEISD::VVP_LOAD:223case VEISD::VVP_STORE:224return Op->getOperand(0);225}226return SDValue();227}228229SDValue getMemoryPtr(SDValue Op) {230if (auto *MemN = dyn_cast<MemSDNode>(Op.getNode()))231return MemN->getBasePtr();232233switch (Op->getOpcode()) {234case VEISD::VVP_LOAD:235return Op->getOperand(1);236case VEISD::VVP_STORE:237return Op->getOperand(2);238}239return SDValue();240}241242std::optional<EVT> getIdiomaticVectorType(SDNode *Op) {243unsigned OC = Op->getOpcode();244245// For memory ops -> the transfered data type246if (auto MemN = dyn_cast<MemSDNode>(Op))247return MemN->getMemoryVT();248249switch (OC) {250// Standard ISD.251case ISD::SELECT: // not aliased with VVP_SELECT252case ISD::CONCAT_VECTORS:253case ISD::EXTRACT_SUBVECTOR:254case ISD::VECTOR_SHUFFLE:255case ISD::BUILD_VECTOR:256case ISD::SCALAR_TO_VECTOR:257return Op->getValueType(0);258}259260// Translate to VVP where possible.261unsigned OriginalOC = OC;262if (auto VVPOpc = getVVPOpcode(OC))263OC = *VVPOpc;264265if (isVVPReductionOp(OC))266return Op->getOperand(hasReductionStartParam(OriginalOC) ? 1 : 0)267.getValueType();268269switch (OC) {270default:271case VEISD::VVP_SETCC:272return Op->getOperand(0).getValueType();273274case VEISD::VVP_SELECT:275#define ADD_BINARY_VVP_OP(VVP_NAME, ...) case VEISD::VVP_NAME:276#include "VVPNodes.def"277return Op->getValueType(0);278279case VEISD::VVP_LOAD:280return Op->getValueType(0);281282case VEISD::VVP_STORE:283return Op->getOperand(1)->getValueType(0);284285// VEC286case VEISD::VEC_BROADCAST:287return Op->getValueType(0);288}289}290291SDValue getLoadStoreStride(SDValue Op, VECustomDAG &CDAG) {292switch (Op->getOpcode()) {293case VEISD::VVP_STORE:294return Op->getOperand(3);295case VEISD::VVP_LOAD:296return Op->getOperand(2);297}298299if (auto *StoreN = dyn_cast<VPStridedStoreSDNode>(Op.getNode()))300return StoreN->getStride();301if (auto *StoreN = dyn_cast<VPStridedLoadSDNode>(Op.getNode()))302return StoreN->getStride();303304if (isa<MemSDNode>(Op.getNode())) {305// Regular MLOAD/MSTORE/LOAD/STORE306// No stride argument -> use the contiguous element size as stride.307uint64_t ElemStride = getIdiomaticVectorType(Op.getNode())308->getVectorElementType()309.getStoreSize();310return CDAG.getConstant(ElemStride, MVT::i64);311}312return SDValue();313}314315SDValue getGatherScatterIndex(SDValue Op) {316if (auto *N = dyn_cast<MaskedGatherScatterSDNode>(Op.getNode()))317return N->getIndex();318if (auto *N = dyn_cast<VPGatherScatterSDNode>(Op.getNode()))319return N->getIndex();320return SDValue();321}322323SDValue getGatherScatterScale(SDValue Op) {324if (auto *N = dyn_cast<MaskedGatherScatterSDNode>(Op.getNode()))325return N->getScale();326if (auto *N = dyn_cast<VPGatherScatterSDNode>(Op.getNode()))327return N->getScale();328return SDValue();329}330331SDValue getStoredValue(SDValue Op) {332switch (Op->getOpcode()) {333case ISD::EXPERIMENTAL_VP_STRIDED_STORE:334case VEISD::VVP_STORE:335return Op->getOperand(1);336}337if (auto *StoreN = dyn_cast<StoreSDNode>(Op.getNode()))338return StoreN->getValue();339if (auto *StoreN = dyn_cast<MaskedStoreSDNode>(Op.getNode()))340return StoreN->getValue();341if (auto *StoreN = dyn_cast<VPStridedStoreSDNode>(Op.getNode()))342return StoreN->getValue();343if (auto *StoreN = dyn_cast<VPStoreSDNode>(Op.getNode()))344return StoreN->getValue();345if (auto *StoreN = dyn_cast<MaskedScatterSDNode>(Op.getNode()))346return StoreN->getValue();347if (auto *StoreN = dyn_cast<VPScatterSDNode>(Op.getNode()))348return StoreN->getValue();349return SDValue();350}351352SDValue getNodePassthru(SDValue Op) {353if (auto *N = dyn_cast<MaskedLoadSDNode>(Op.getNode()))354return N->getPassThru();355if (auto *N = dyn_cast<MaskedGatherSDNode>(Op.getNode()))356return N->getPassThru();357358return SDValue();359}360361bool hasReductionStartParam(unsigned OPC) {362// TODO: Ordered reduction opcodes.363if (ISD::isVPReduction(OPC))364return true;365return false;366}367368unsigned getScalarReductionOpcode(unsigned VVPOC, bool IsMask) {369assert(!IsMask && "Mask reduction isel");370371switch (VVPOC) {372#define HANDLE_VVP_REDUCE_TO_SCALAR(VVP_RED_ISD, REDUCE_ISD) \373case VEISD::VVP_RED_ISD: \374return ISD::REDUCE_ISD;375#include "VVPNodes.def"376default:377break;378}379llvm_unreachable("Cannot not scalarize this reduction Opcode!");380}381382/// } Node Properties383384SDValue getNodeAVL(SDValue Op) {385auto PosOpt = getAVLPos(Op->getOpcode());386return PosOpt ? Op->getOperand(*PosOpt) : SDValue();387}388389SDValue getNodeMask(SDValue Op) {390auto PosOpt = getMaskPos(Op->getOpcode());391return PosOpt ? Op->getOperand(*PosOpt) : SDValue();392}393394std::pair<SDValue, bool> getAnnotatedNodeAVL(SDValue Op) {395SDValue AVL = getNodeAVL(Op);396if (!AVL)397return {SDValue(), true};398if (isLegalAVL(AVL))399return {AVL->getOperand(0), true};400return {AVL, false};401}402403SDValue VECustomDAG::getConstant(uint64_t Val, EVT VT, bool IsTarget,404bool IsOpaque) const {405return DAG.getConstant(Val, DL, VT, IsTarget, IsOpaque);406}407408SDValue VECustomDAG::getConstantMask(Packing Packing, bool AllTrue) const {409auto MaskVT = getLegalVectorType(Packing, MVT::i1);410411// VEISelDAGtoDAG will replace this pattern with the constant-true VM.412auto TrueVal = DAG.getConstant(-1, DL, MVT::i32);413auto AVL = getConstant(MaskVT.getVectorNumElements(), MVT::i32);414auto Res = getNode(VEISD::VEC_BROADCAST, MaskVT, {TrueVal, AVL});415if (AllTrue)416return Res;417418return DAG.getNOT(DL, Res, Res.getValueType());419}420421SDValue VECustomDAG::getMaskBroadcast(EVT ResultVT, SDValue Scalar,422SDValue AVL) const {423// Constant mask splat.424if (auto BcConst = dyn_cast<ConstantSDNode>(Scalar))425return getConstantMask(getTypePacking(ResultVT),426BcConst->getSExtValue() != 0);427428// Expand the broadcast to a vector comparison.429auto ScalarBoolVT = Scalar.getSimpleValueType();430assert(ScalarBoolVT == MVT::i32);431432// Cast to i32 ty.433SDValue CmpElem = DAG.getSExtOrTrunc(Scalar, DL, MVT::i32);434unsigned ElemCount = ResultVT.getVectorNumElements();435MVT CmpVecTy = MVT::getVectorVT(ScalarBoolVT, ElemCount);436437// Broadcast to vector.438SDValue BCVec =439DAG.getNode(VEISD::VEC_BROADCAST, DL, CmpVecTy, {CmpElem, AVL});440SDValue ZeroVec =441getBroadcast(CmpVecTy, {DAG.getConstant(0, DL, ScalarBoolVT)}, AVL);442443MVT BoolVecTy = MVT::getVectorVT(MVT::i1, ElemCount);444445// Broadcast(Data) != Broadcast(0)446// TODO: Use a VVP operation for this.447return DAG.getSetCC(DL, BoolVecTy, BCVec, ZeroVec, ISD::CondCode::SETNE);448}449450SDValue VECustomDAG::getBroadcast(EVT ResultVT, SDValue Scalar,451SDValue AVL) const {452assert(ResultVT.isVector());453auto ScaVT = Scalar.getValueType();454455if (isMaskType(ResultVT))456return getMaskBroadcast(ResultVT, Scalar, AVL);457458if (isPackedVectorType(ResultVT)) {459// v512x packed mode broadcast460// Replicate the scalar reg (f32 or i32) onto the opposing half of the full461// scalar register. If it's an I64 type, assume that this has already462// happened.463if (ScaVT == MVT::f32) {464Scalar = getNode(VEISD::REPL_F32, MVT::i64, Scalar);465} else if (ScaVT == MVT::i32) {466Scalar = getNode(VEISD::REPL_I32, MVT::i64, Scalar);467}468}469470return getNode(VEISD::VEC_BROADCAST, ResultVT, {Scalar, AVL});471}472473SDValue VECustomDAG::annotateLegalAVL(SDValue AVL) const {474if (isLegalAVL(AVL))475return AVL;476return getNode(VEISD::LEGALAVL, AVL.getValueType(), AVL);477}478479SDValue VECustomDAG::getUnpack(EVT DestVT, SDValue Vec, PackElem Part,480SDValue AVL) const {481assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL");482483// TODO: Peek through VEC_PACK and VEC_BROADCAST(REPL_<sth> ..) operands.484unsigned OC =485(Part == PackElem::Lo) ? VEISD::VEC_UNPACK_LO : VEISD::VEC_UNPACK_HI;486return DAG.getNode(OC, DL, DestVT, Vec, AVL);487}488489SDValue VECustomDAG::getPack(EVT DestVT, SDValue LoVec, SDValue HiVec,490SDValue AVL) const {491assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL");492493// TODO: Peek through VEC_UNPACK_LO|HI operands.494return DAG.getNode(VEISD::VEC_PACK, DL, DestVT, LoVec, HiVec, AVL);495}496497VETargetMasks VECustomDAG::getTargetSplitMask(SDValue RawMask, SDValue RawAVL,498PackElem Part) const {499// Adjust AVL for this part500SDValue NewAVL;501SDValue OneV = getConstant(1, MVT::i32);502if (Part == PackElem::Hi)503NewAVL = getNode(ISD::ADD, MVT::i32, {RawAVL, OneV});504else505NewAVL = RawAVL;506NewAVL = getNode(ISD::SRL, MVT::i32, {NewAVL, OneV});507508NewAVL = annotateLegalAVL(NewAVL);509510// Legalize Mask (unpack or all-true)511SDValue NewMask;512if (!RawMask)513NewMask = getConstantMask(Packing::Normal, true);514else515NewMask = getUnpack(MVT::v256i1, RawMask, Part, NewAVL);516517return VETargetMasks(NewMask, NewAVL);518}519520SDValue VECustomDAG::getSplitPtrOffset(SDValue Ptr, SDValue ByteStride,521PackElem Part) const {522// High starts at base ptr but has more significant bits in the 64bit vector523// element.524if (Part == PackElem::Hi)525return Ptr;526return getNode(ISD::ADD, MVT::i64, {Ptr, ByteStride});527}528529SDValue VECustomDAG::getSplitPtrStride(SDValue PackStride) const {530if (auto ConstBytes = dyn_cast<ConstantSDNode>(PackStride))531return getConstant(2 * ConstBytes->getSExtValue(), MVT::i64);532return getNode(ISD::SHL, MVT::i64, {PackStride, getConstant(1, MVT::i32)});533}534535SDValue VECustomDAG::getGatherScatterAddress(SDValue BasePtr, SDValue Scale,536SDValue Index, SDValue Mask,537SDValue AVL) const {538EVT IndexVT = Index.getValueType();539540// Apply scale.541SDValue ScaledIndex;542if (!Scale || isOneConstant(Scale))543ScaledIndex = Index;544else {545SDValue ScaleBroadcast = getBroadcast(IndexVT, Scale, AVL);546ScaledIndex =547getNode(VEISD::VVP_MUL, IndexVT, {Index, ScaleBroadcast, Mask, AVL});548}549550// Add basePtr.551if (isNullConstant(BasePtr))552return ScaledIndex;553554// re-constitute pointer vector (basePtr + index * scale)555SDValue BaseBroadcast = getBroadcast(IndexVT, BasePtr, AVL);556auto ResPtr =557getNode(VEISD::VVP_ADD, IndexVT, {BaseBroadcast, ScaledIndex, Mask, AVL});558return ResPtr;559}560561SDValue VECustomDAG::getLegalReductionOpVVP(unsigned VVPOpcode, EVT ResVT,562SDValue StartV, SDValue VectorV,563SDValue Mask, SDValue AVL,564SDNodeFlags Flags) const {565566// Optionally attach the start param with a scalar op (where it is567// unsupported).568bool scalarizeStartParam = StartV && !hasReductionStartParam(VVPOpcode);569bool IsMaskReduction = isMaskType(VectorV.getValueType());570assert(!IsMaskReduction && "TODO Implement");571auto AttachStartValue = [&](SDValue ReductionResV) {572if (!scalarizeStartParam)573return ReductionResV;574auto ScalarOC = getScalarReductionOpcode(VVPOpcode, IsMaskReduction);575return getNode(ScalarOC, ResVT, {StartV, ReductionResV});576};577578// Fixup: Always Use sequential 'fmul' reduction.579if (!scalarizeStartParam && StartV) {580assert(hasReductionStartParam(VVPOpcode));581return AttachStartValue(582getNode(VVPOpcode, ResVT, {StartV, VectorV, Mask, AVL}, Flags));583} else584return AttachStartValue(585getNode(VVPOpcode, ResVT, {VectorV, Mask, AVL}, Flags));586}587588} // namespace llvm589590591