Path: blob/main/contrib/llvm-project/llvm/lib/Target/VE/VVPISelLowering.cpp
35294 views
//===-- VVPISelLowering.cpp - VE DAG Lowering Implementation --------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This file implements the lowering and legalization of vector instructions to9// VVP_*layer SDNodes.10//11//===----------------------------------------------------------------------===//1213#include "VECustomDAG.h"14#include "VEISelLowering.h"1516using namespace llvm;1718#define DEBUG_TYPE "ve-lower"1920SDValue VETargetLowering::splitMaskArithmetic(SDValue Op,21SelectionDAG &DAG) const {22VECustomDAG CDAG(DAG, Op);23SDValue AVL =24CDAG.getConstant(Op.getValueType().getVectorNumElements(), MVT::i32);25SDValue A = Op->getOperand(0);26SDValue B = Op->getOperand(1);27SDValue LoA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Lo, AVL);28SDValue HiA = CDAG.getUnpack(MVT::v256i1, A, PackElem::Hi, AVL);29SDValue LoB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Lo, AVL);30SDValue HiB = CDAG.getUnpack(MVT::v256i1, B, PackElem::Hi, AVL);31unsigned Opc = Op.getOpcode();32auto LoRes = CDAG.getNode(Opc, MVT::v256i1, {LoA, LoB});33auto HiRes = CDAG.getNode(Opc, MVT::v256i1, {HiA, HiB});34return CDAG.getPack(MVT::v512i1, LoRes, HiRes, AVL);35}3637SDValue VETargetLowering::lowerToVVP(SDValue Op, SelectionDAG &DAG) const {38// Can we represent this as a VVP node.39const unsigned Opcode = Op->getOpcode();40auto VVPOpcodeOpt = getVVPOpcode(Opcode);41if (!VVPOpcodeOpt)42return SDValue();43unsigned VVPOpcode = *VVPOpcodeOpt;44const bool FromVP = ISD::isVPOpcode(Opcode);4546// The representative and legalized vector type of this operation.47VECustomDAG CDAG(DAG, Op);48// Dispatch to complex lowering functions.49switch (VVPOpcode) {50case VEISD::VVP_LOAD:51case VEISD::VVP_STORE:52return lowerVVP_LOAD_STORE(Op, CDAG);53case VEISD::VVP_GATHER:54case VEISD::VVP_SCATTER:55return lowerVVP_GATHER_SCATTER(Op, CDAG);56}5758EVT OpVecVT = *getIdiomaticVectorType(Op.getNode());59EVT LegalVecVT = getTypeToTransformTo(*DAG.getContext(), OpVecVT);60auto Packing = getTypePacking(LegalVecVT.getSimpleVT());6162SDValue AVL;63SDValue Mask;6465if (FromVP) {66// All upstream VP SDNodes always have a mask and avl.67auto MaskIdx = ISD::getVPMaskIdx(Opcode);68auto AVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode);69if (MaskIdx)70Mask = Op->getOperand(*MaskIdx);71if (AVLIdx)72AVL = Op->getOperand(*AVLIdx);73}7475// Materialize default mask and avl.76if (!AVL)77AVL = CDAG.getConstant(OpVecVT.getVectorNumElements(), MVT::i32);78if (!Mask)79Mask = CDAG.getConstantMask(Packing, true);8081assert(LegalVecVT.isSimple());82if (isVVPUnaryOp(VVPOpcode))83return CDAG.getNode(VVPOpcode, LegalVecVT, {Op->getOperand(0), Mask, AVL});84if (isVVPBinaryOp(VVPOpcode))85return CDAG.getNode(VVPOpcode, LegalVecVT,86{Op->getOperand(0), Op->getOperand(1), Mask, AVL});87if (isVVPReductionOp(VVPOpcode)) {88auto SrcHasStart = hasReductionStartParam(Op->getOpcode());89SDValue StartV = SrcHasStart ? Op->getOperand(0) : SDValue();90SDValue VectorV = Op->getOperand(SrcHasStart ? 1 : 0);91return CDAG.getLegalReductionOpVVP(VVPOpcode, Op.getValueType(), StartV,92VectorV, Mask, AVL, Op->getFlags());93}9495switch (VVPOpcode) {96default:97llvm_unreachable("lowerToVVP called for unexpected SDNode.");98case VEISD::VVP_FFMA: {99// VE has a swizzled operand order in FMA (compared to LLVM IR and100// SDNodes).101auto X = Op->getOperand(2);102auto Y = Op->getOperand(0);103auto Z = Op->getOperand(1);104return CDAG.getNode(VVPOpcode, LegalVecVT, {X, Y, Z, Mask, AVL});105}106case VEISD::VVP_SELECT: {107auto Mask = Op->getOperand(0);108auto OnTrue = Op->getOperand(1);109auto OnFalse = Op->getOperand(2);110return CDAG.getNode(VVPOpcode, LegalVecVT, {OnTrue, OnFalse, Mask, AVL});111}112case VEISD::VVP_SETCC: {113EVT LegalResVT = getTypeToTransformTo(*DAG.getContext(), Op.getValueType());114auto LHS = Op->getOperand(0);115auto RHS = Op->getOperand(1);116auto Pred = Op->getOperand(2);117return CDAG.getNode(VVPOpcode, LegalResVT, {LHS, RHS, Pred, Mask, AVL});118}119}120}121122SDValue VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op,123VECustomDAG &CDAG) const {124auto VVPOpc = *getVVPOpcode(Op->getOpcode());125const bool IsLoad = (VVPOpc == VEISD::VVP_LOAD);126127// Shares.128SDValue BasePtr = getMemoryPtr(Op);129SDValue Mask = getNodeMask(Op);130SDValue Chain = getNodeChain(Op);131SDValue AVL = getNodeAVL(Op);132// Store specific.133SDValue Data = getStoredValue(Op);134// Load specific.135SDValue PassThru = getNodePassthru(Op);136137SDValue StrideV = getLoadStoreStride(Op, CDAG);138139auto DataVT = *getIdiomaticVectorType(Op.getNode());140auto Packing = getTypePacking(DataVT);141142// TODO: Infer lower AVL from mask.143if (!AVL)144AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);145146// Default to the all-true mask.147if (!Mask)148Mask = CDAG.getConstantMask(Packing, true);149150if (IsLoad) {151MVT LegalDataVT = getLegalVectorType(152Packing, DataVT.getVectorElementType().getSimpleVT());153154auto NewLoadV = CDAG.getNode(VEISD::VVP_LOAD, {LegalDataVT, MVT::Other},155{Chain, BasePtr, StrideV, Mask, AVL});156157if (!PassThru || PassThru->isUndef())158return NewLoadV;159160// Convert passthru to an explicit select node.161SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, DataVT,162{NewLoadV, PassThru, Mask, AVL});163SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);164165// Merge them back into one node.166return CDAG.getMergeValues({DataV, NewLoadChainV});167}168169// VVP_STORE170assert(VVPOpc == VEISD::VVP_STORE);171if (getTypeAction(*CDAG.getDAG()->getContext(), Data.getValueType()) !=172TargetLowering::TypeLegal)173// Doesn't lower store instruction if an operand is not lowered yet.174// If it isn't, return SDValue(). In this way, LLVM will try to lower175// store instruction again after lowering all operands.176return SDValue();177return CDAG.getNode(VEISD::VVP_STORE, Op.getNode()->getVTList(),178{Chain, Data, BasePtr, StrideV, Mask, AVL});179}180181SDValue VETargetLowering::splitPackedLoadStore(SDValue Op,182VECustomDAG &CDAG) const {183auto VVPOC = *getVVPOpcode(Op.getOpcode());184assert((VVPOC == VEISD::VVP_LOAD) || (VVPOC == VEISD::VVP_STORE));185186MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();187assert(getTypePacking(DataVT) == Packing::Dense &&188"Can only split packed load/store");189MVT SplitDataVT = splitVectorType(DataVT);190191assert(!getNodePassthru(Op) &&192"Should have been folded in lowering to VVP layer");193194// Analyze the operation195SDValue PackedMask = getNodeMask(Op);196SDValue PackedAVL = getAnnotatedNodeAVL(Op).first;197SDValue PackPtr = getMemoryPtr(Op);198SDValue PackData = getStoredValue(Op);199SDValue PackStride = getLoadStoreStride(Op, CDAG);200201unsigned ChainResIdx = PackData ? 0 : 1;202203SDValue PartOps[2];204205SDValue UpperPartAVL; // we will use this for packing things back together206for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {207// VP ops already have an explicit mask and AVL. When expanding from non-VP208// attach those additional inputs here.209auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);210211// Keep track of the (higher) lvl.212if (Part == PackElem::Hi)213UpperPartAVL = SplitTM.AVL;214215// Attach non-predicating value operands216SmallVector<SDValue, 4> OpVec;217218// Chain219OpVec.push_back(getNodeChain(Op));220221// Data222if (PackData) {223SDValue PartData =224CDAG.getUnpack(SplitDataVT, PackData, Part, SplitTM.AVL);225OpVec.push_back(PartData);226}227228// Ptr & Stride229// Push (ptr + ElemBytes * <Part>, 2 * ElemBytes)230// Stride info231// EVT DataVT = LegalizeVectorType(getMemoryDataVT(Op), Op, DAG, Mode);232OpVec.push_back(CDAG.getSplitPtrOffset(PackPtr, PackStride, Part));233OpVec.push_back(CDAG.getSplitPtrStride(PackStride));234235// Add predicating args and generate part node236OpVec.push_back(SplitTM.Mask);237OpVec.push_back(SplitTM.AVL);238239if (PackData) {240// Store241PartOps[(int)Part] = CDAG.getNode(VVPOC, MVT::Other, OpVec);242} else {243// Load244PartOps[(int)Part] =245CDAG.getNode(VVPOC, {SplitDataVT, MVT::Other}, OpVec);246}247}248249// Merge the chains250SDValue LowChain = SDValue(PartOps[(int)PackElem::Lo].getNode(), ChainResIdx);251SDValue HiChain = SDValue(PartOps[(int)PackElem::Hi].getNode(), ChainResIdx);252SDValue FusedChains =253CDAG.getNode(ISD::TokenFactor, MVT::Other, {LowChain, HiChain});254255// Chain only [store]256if (PackData)257return FusedChains;258259// Re-pack into full packed vector result260MVT PackedVT =261getLegalVectorType(Packing::Dense, DataVT.getVectorElementType());262SDValue PackedVals = CDAG.getPack(PackedVT, PartOps[(int)PackElem::Lo],263PartOps[(int)PackElem::Hi], UpperPartAVL);264265return CDAG.getMergeValues({PackedVals, FusedChains});266}267268SDValue VETargetLowering::lowerVVP_GATHER_SCATTER(SDValue Op,269VECustomDAG &CDAG) const {270EVT DataVT = *getIdiomaticVectorType(Op.getNode());271auto Packing = getTypePacking(DataVT);272MVT LegalDataVT =273getLegalVectorType(Packing, DataVT.getVectorElementType().getSimpleVT());274275SDValue AVL = getAnnotatedNodeAVL(Op).first;276SDValue Index = getGatherScatterIndex(Op);277SDValue BasePtr = getMemoryPtr(Op);278SDValue Mask = getNodeMask(Op);279SDValue Chain = getNodeChain(Op);280SDValue Scale = getGatherScatterScale(Op);281SDValue PassThru = getNodePassthru(Op);282SDValue StoredValue = getStoredValue(Op);283if (PassThru && PassThru->isUndef())284PassThru = SDValue();285286bool IsScatter = (bool)StoredValue;287288// TODO: Infer lower AVL from mask.289if (!AVL)290AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);291292// Default to the all-true mask.293if (!Mask)294Mask = CDAG.getConstantMask(Packing, true);295296SDValue AddressVec =297CDAG.getGatherScatterAddress(BasePtr, Scale, Index, Mask, AVL);298if (IsScatter)299return CDAG.getNode(VEISD::VVP_SCATTER, MVT::Other,300{Chain, StoredValue, AddressVec, Mask, AVL});301302// Gather.303SDValue NewLoadV = CDAG.getNode(VEISD::VVP_GATHER, {LegalDataVT, MVT::Other},304{Chain, AddressVec, Mask, AVL});305306if (!PassThru)307return NewLoadV;308309// TODO: Use vvp_select310SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, LegalDataVT,311{NewLoadV, PassThru, Mask, AVL});312SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);313return CDAG.getMergeValues({DataV, NewLoadChainV});314}315316SDValue VETargetLowering::legalizeInternalLoadStoreOp(SDValue Op,317VECustomDAG &CDAG) const {318LLVM_DEBUG(dbgs() << "::legalizeInternalLoadStoreOp\n";);319MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();320321// TODO: Recognize packable load,store.322if (isPackedVectorType(DataVT))323return splitPackedLoadStore(Op, CDAG);324325return legalizePackedAVL(Op, CDAG);326}327328SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op,329SelectionDAG &DAG) const {330LLVM_DEBUG(dbgs() << "::legalizeInternalVectorOp\n";);331VECustomDAG CDAG(DAG, Op);332333// Dispatch to specialized legalization functions.334switch (Op->getOpcode()) {335case VEISD::VVP_LOAD:336case VEISD::VVP_STORE:337return legalizeInternalLoadStoreOp(Op, CDAG);338}339340EVT IdiomVT = Op.getValueType();341if (isPackedVectorType(IdiomVT) &&342!supportsPackedMode(Op.getOpcode(), IdiomVT))343return splitVectorOp(Op, CDAG);344345// TODO: Implement odd/even splitting.346return legalizePackedAVL(Op, CDAG);347}348349SDValue VETargetLowering::splitVectorOp(SDValue Op, VECustomDAG &CDAG) const {350MVT ResVT = splitVectorType(Op.getValue(0).getSimpleValueType());351352auto AVLPos = getAVLPos(Op->getOpcode());353auto MaskPos = getMaskPos(Op->getOpcode());354355SDValue PackedMask = getNodeMask(Op);356auto AVLPair = getAnnotatedNodeAVL(Op);357SDValue PackedAVL = AVLPair.first;358assert(!AVLPair.second && "Expecting non pack-legalized oepration");359360// request the parts361SDValue PartOps[2];362363SDValue UpperPartAVL; // we will use this for packing things back together364for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {365// VP ops already have an explicit mask and AVL. When expanding from non-VP366// attach those additional inputs here.367auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);368369if (Part == PackElem::Hi)370UpperPartAVL = SplitTM.AVL;371372// Attach non-predicating value operands373SmallVector<SDValue, 4> OpVec;374for (unsigned i = 0; i < Op.getNumOperands(); ++i) {375if (AVLPos && ((int)i) == *AVLPos)376continue;377if (MaskPos && ((int)i) == *MaskPos)378continue;379380// Value operand381auto PackedOperand = Op.getOperand(i);382auto UnpackedOpVT = splitVectorType(PackedOperand.getSimpleValueType());383SDValue PartV =384CDAG.getUnpack(UnpackedOpVT, PackedOperand, Part, SplitTM.AVL);385OpVec.push_back(PartV);386}387388// Add predicating args and generate part node.389OpVec.push_back(SplitTM.Mask);390OpVec.push_back(SplitTM.AVL);391// Emit legal VVP nodes.392PartOps[(int)Part] =393CDAG.getNode(Op.getOpcode(), ResVT, OpVec, Op->getFlags());394}395396// Re-package vectors.397return CDAG.getPack(Op.getValueType(), PartOps[(int)PackElem::Lo],398PartOps[(int)PackElem::Hi], UpperPartAVL);399}400401SDValue VETargetLowering::legalizePackedAVL(SDValue Op,402VECustomDAG &CDAG) const {403LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";);404// Only required for VEC and VVP ops.405if (!isVVPOrVEC(Op->getOpcode()))406return Op;407408// Operation already has a legal AVL.409auto AVL = getNodeAVL(Op);410if (isLegalAVL(AVL))411return Op;412413// Half and round up EVL for 32bit element types.414SDValue LegalAVL = AVL;415MVT IdiomVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();416if (isPackedVectorType(IdiomVT)) {417assert(maySafelyIgnoreMask(Op) &&418"TODO Shift predication from EVL into Mask");419420if (auto *ConstAVL = dyn_cast<ConstantSDNode>(AVL)) {421LegalAVL = CDAG.getConstant((ConstAVL->getZExtValue() + 1) / 2, MVT::i32);422} else {423auto ConstOne = CDAG.getConstant(1, MVT::i32);424auto PlusOne = CDAG.getNode(ISD::ADD, MVT::i32, {AVL, ConstOne});425LegalAVL = CDAG.getNode(ISD::SRL, MVT::i32, {PlusOne, ConstOne});426}427}428429SDValue AnnotatedLegalAVL = CDAG.annotateLegalAVL(LegalAVL);430431// Copy the operand list.432int NumOp = Op->getNumOperands();433auto AVLPos = getAVLPos(Op->getOpcode());434std::vector<SDValue> FixedOperands;435for (int i = 0; i < NumOp; ++i) {436if (AVLPos && (i == *AVLPos)) {437FixedOperands.push_back(AnnotatedLegalAVL);438continue;439}440FixedOperands.push_back(Op->getOperand(i));441}442443// Clone the operation with fixed operands.444auto Flags = Op->getFlags();445SDValue NewN =446CDAG.getNode(Op->getOpcode(), Op->getVTList(), FixedOperands, Flags);447return NewN;448}449450451