// Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
// 35266 views
//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This file implements the targeting of the Machinelegalizer class for SPIR-V.9//10//===----------------------------------------------------------------------===//1112#include "SPIRVLegalizerInfo.h"13#include "SPIRV.h"14#include "SPIRVGlobalRegistry.h"15#include "SPIRVSubtarget.h"16#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"17#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"18#include "llvm/CodeGen/MachineInstr.h"19#include "llvm/CodeGen/MachineRegisterInfo.h"20#include "llvm/CodeGen/TargetOpcodes.h"2122using namespace llvm;23using namespace llvm::LegalizeActions;24using namespace llvm::LegalityPredicates;2526static const std::set<unsigned> TypeFoldingSupportingOpcs = {27TargetOpcode::G_ADD,28TargetOpcode::G_FADD,29TargetOpcode::G_SUB,30TargetOpcode::G_FSUB,31TargetOpcode::G_MUL,32TargetOpcode::G_FMUL,33TargetOpcode::G_SDIV,34TargetOpcode::G_UDIV,35TargetOpcode::G_FDIV,36TargetOpcode::G_SREM,37TargetOpcode::G_UREM,38TargetOpcode::G_FREM,39TargetOpcode::G_FNEG,40TargetOpcode::G_CONSTANT,41TargetOpcode::G_FCONSTANT,42TargetOpcode::G_AND,43TargetOpcode::G_OR,44TargetOpcode::G_XOR,45TargetOpcode::G_SHL,46TargetOpcode::G_ASHR,47TargetOpcode::G_LSHR,48TargetOpcode::G_SELECT,49TargetOpcode::G_EXTRACT_VECTOR_ELT,50};5152bool isTypeFoldingSupported(unsigned Opcode) {53return TypeFoldingSupportingOpcs.count(Opcode) > 0;54}5556SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {57using namespace TargetOpcode;5859this->ST = &ST;60GR = ST.getSPIRVGlobalRegistry();6162const LLT s1 = LLT::scalar(1);63const LLT s8 = LLT::scalar(8);64const LLT s16 = LLT::scalar(16);65const LLT s32 = 
LLT::scalar(32);66const LLT s64 = LLT::scalar(64);6768const LLT v16s64 = LLT::fixed_vector(16, 64);69const LLT v16s32 = LLT::fixed_vector(16, 32);70const LLT v16s16 = LLT::fixed_vector(16, 16);71const LLT v16s8 = LLT::fixed_vector(16, 8);72const LLT v16s1 = LLT::fixed_vector(16, 1);7374const LLT v8s64 = LLT::fixed_vector(8, 64);75const LLT v8s32 = LLT::fixed_vector(8, 32);76const LLT v8s16 = LLT::fixed_vector(8, 16);77const LLT v8s8 = LLT::fixed_vector(8, 8);78const LLT v8s1 = LLT::fixed_vector(8, 1);7980const LLT v4s64 = LLT::fixed_vector(4, 64);81const LLT v4s32 = LLT::fixed_vector(4, 32);82const LLT v4s16 = LLT::fixed_vector(4, 16);83const LLT v4s8 = LLT::fixed_vector(4, 8);84const LLT v4s1 = LLT::fixed_vector(4, 1);8586const LLT v3s64 = LLT::fixed_vector(3, 64);87const LLT v3s32 = LLT::fixed_vector(3, 32);88const LLT v3s16 = LLT::fixed_vector(3, 16);89const LLT v3s8 = LLT::fixed_vector(3, 8);90const LLT v3s1 = LLT::fixed_vector(3, 1);9192const LLT v2s64 = LLT::fixed_vector(2, 64);93const LLT v2s32 = LLT::fixed_vector(2, 32);94const LLT v2s16 = LLT::fixed_vector(2, 16);95const LLT v2s8 = LLT::fixed_vector(2, 8);96const LLT v2s1 = LLT::fixed_vector(2, 1);9798const unsigned PSize = ST.getPointerSize();99const LLT p0 = LLT::pointer(0, PSize); // Function100const LLT p1 = LLT::pointer(1, PSize); // CrossWorkgroup101const LLT p2 = LLT::pointer(2, PSize); // UniformConstant102const LLT p3 = LLT::pointer(3, PSize); // Workgroup103const LLT p4 = LLT::pointer(4, PSize); // Generic104const LLT p5 =105LLT::pointer(5, PSize); // Input, SPV_INTEL_usm_storage_classes (Device)106const LLT p6 = LLT::pointer(6, PSize); // SPV_INTEL_usm_storage_classes (Host)107108// TODO: remove copy-pasting here by using concatenation in some way.109auto allPtrsScalarsAndVectors = {110p0, p1, p2, p3, p4, p5, p6, s1, s8, s16,111s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8, v3s16,112v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64, v8s1, v8s8, v8s16,113v8s32, v8s64, v16s1, v16s8, v16s16, 
v16s32, v16s64};114115auto allVectors = {v2s1, v2s8, v2s16, v2s32, v2s64, v3s1, v3s8,116v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32,117v4s64, v8s1, v8s8, v8s16, v8s32, v8s64, v16s1,118v16s8, v16s16, v16s32, v16s64};119120auto allScalarsAndVectors = {121s1, s8, s16, s32, s64, v2s1, v2s8, v2s16, v2s32, v2s64,122v3s1, v3s8, v3s16, v3s32, v3s64, v4s1, v4s8, v4s16, v4s32, v4s64,123v8s1, v8s8, v8s16, v8s32, v8s64, v16s1, v16s8, v16s16, v16s32, v16s64};124125auto allIntScalarsAndVectors = {s8, s16, s32, s64, v2s8, v2s16,126v2s32, v2s64, v3s8, v3s16, v3s32, v3s64,127v4s8, v4s16, v4s32, v4s64, v8s8, v8s16,128v8s32, v8s64, v16s8, v16s16, v16s32, v16s64};129130auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};131132auto allIntScalars = {s8, s16, s32, s64};133134auto allFloatScalars = {s16, s32, s64};135136auto allFloatScalarsAndVectors = {137s16, s32, s64, v2s16, v2s32, v2s64, v3s16, v3s32, v3s64,138v4s16, v4s32, v4s64, v8s16, v8s32, v8s64, v16s16, v16s32, v16s64};139140auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1,141p2, p3, p4, p5, p6};142143auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};144auto allWritablePtrs = {p0, p1, p3, p4, p5, p6};145146for (auto Opc : TypeFoldingSupportingOpcs)147getActionDefinitionsBuilder(Opc).custom();148149getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();150151// TODO: add proper rules for vectors legalization.152getActionDefinitionsBuilder(153{G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})154.alwaysLegal();155156// Vector Reduction Operations157getActionDefinitionsBuilder(158{G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,159G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,160G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,161G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})162.legalFor(allVectors)163.scalarize(1)164.lower();165166getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})167.scalarize(2)168.lower();169170// 
Merge/Unmerge171// TODO: add proper legalization rules.172getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();173174getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})175.legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));176177getActionDefinitionsBuilder(G_MEMSET).legalIf(178all(typeInSet(0, allWritablePtrs), typeInSet(1, allIntScalars)));179180getActionDefinitionsBuilder(G_ADDRSPACE_CAST)181.legalForCartesianProduct(allPtrs, allPtrs);182183getActionDefinitionsBuilder({G_LOAD, G_STORE}).legalIf(typeInSet(1, allPtrs));184185getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allIntScalarsAndVectors);186187getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);188189getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})190.legalForCartesianProduct(allIntScalarsAndVectors,191allFloatScalarsAndVectors);192193getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})194.legalForCartesianProduct(allFloatScalarsAndVectors,195allScalarsAndVectors);196197getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})198.legalFor(allIntScalarsAndVectors);199200getActionDefinitionsBuilder(G_CTPOP).legalForCartesianProduct(201allIntScalarsAndVectors, allIntScalarsAndVectors);202203getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);204205getActionDefinitionsBuilder(G_BITCAST).legalIf(206all(typeInSet(0, allPtrsScalarsAndVectors),207typeInSet(1, allPtrsScalarsAndVectors)));208209getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();210211getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();212213getActionDefinitionsBuilder(G_INTTOPTR)214.legalForCartesianProduct(allPtrs, allIntScalars);215getActionDefinitionsBuilder(G_PTRTOINT)216.legalForCartesianProduct(allIntScalars, allPtrs);217getActionDefinitionsBuilder(G_PTR_ADD).legalForCartesianProduct(218allPtrs, allIntScalars);219220// ST.canDirectlyComparePointers() for pointer args is supported in221// 
legalizeCustom().222getActionDefinitionsBuilder(G_ICMP).customIf(223all(typeInSet(0, allBoolScalarsAndVectors),224typeInSet(1, allPtrsScalarsAndVectors)));225226getActionDefinitionsBuilder(G_FCMP).legalIf(227all(typeInSet(0, allBoolScalarsAndVectors),228typeInSet(1, allFloatScalarsAndVectors)));229230getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,231G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,232G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,233G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})234.legalForCartesianProduct(allIntScalars, allWritablePtrs);235236getActionDefinitionsBuilder(237{G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})238.legalForCartesianProduct(allFloatScalars, allWritablePtrs);239240getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)241.legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allWritablePtrs);242243getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();244// TODO: add proper legalization rules.245getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();246247getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})248.alwaysLegal();249250// Extensions.251getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})252.legalForCartesianProduct(allScalarsAndVectors);253254// FP conversions.255getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})256.legalForCartesianProduct(allFloatScalarsAndVectors);257258// Pointer-handling.259getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});260261// Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.262getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});263264// TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to265// tighten these requirements. 
Many of these math functions are only legal on266// specific bitwidths, so they are not selectable for267// allFloatScalarsAndVectors.268getActionDefinitionsBuilder({G_FPOW,269G_FEXP,270G_FEXP2,271G_FLOG,272G_FLOG2,273G_FLOG10,274G_FABS,275G_FMINNUM,276G_FMAXNUM,277G_FCEIL,278G_FCOS,279G_FSIN,280G_FTAN,281G_FACOS,282G_FASIN,283G_FATAN,284G_FCOSH,285G_FSINH,286G_FTANH,287G_FSQRT,288G_FFLOOR,289G_FRINT,290G_FNEARBYINT,291G_INTRINSIC_ROUND,292G_INTRINSIC_TRUNC,293G_FMINIMUM,294G_FMAXIMUM,295G_INTRINSIC_ROUNDEVEN})296.legalFor(allFloatScalarsAndVectors);297298getActionDefinitionsBuilder(G_FCOPYSIGN)299.legalForCartesianProduct(allFloatScalarsAndVectors,300allFloatScalarsAndVectors);301302getActionDefinitionsBuilder(G_FPOWI).legalForCartesianProduct(303allFloatScalarsAndVectors, allIntScalarsAndVectors);304305if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {306getActionDefinitionsBuilder(307{G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})308.legalForCartesianProduct(allIntScalarsAndVectors,309allIntScalarsAndVectors);310311// Struct return types become a single scalar, so cannot easily legalize.312getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();313314// supported saturation arithmetic315getActionDefinitionsBuilder({G_SADDSAT, G_UADDSAT, G_SSUBSAT, G_USUBSAT})316.legalFor(allIntScalarsAndVectors);317}318319getLegacyLegalizerInfo().computeTables();320verify(*ST.getInstrInfo());321}322323static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,324LegalizerHelper &Helper,325MachineRegisterInfo &MRI,326SPIRVGlobalRegistry *GR) {327Register ConvReg = MRI.createGenericVirtualRegister(ConvTy);328GR->assignSPIRVTypeToVReg(SpirvType, ConvReg, Helper.MIRBuilder.getMF());329Helper.MIRBuilder.buildInstr(TargetOpcode::G_PTRTOINT)330.addDef(ConvReg)331.addUse(Reg);332return ConvReg;333}334335bool SPIRVLegalizerInfo::legalizeCustom(336LegalizerHelper &Helper, MachineInstr &MI,337LostDebugLocObserver &LocObserver) const {338auto Opc 
= MI.getOpcode();339MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();340if (!isTypeFoldingSupported(Opc)) {341assert(Opc == TargetOpcode::G_ICMP);342assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));343auto &Op0 = MI.getOperand(2);344auto &Op1 = MI.getOperand(3);345Register Reg0 = Op0.getReg();346Register Reg1 = Op1.getReg();347CmpInst::Predicate Cond =348static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());349if ((!ST->canDirectlyComparePointers() ||350(Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&351MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {352LLT ConvT = LLT::scalar(ST->getPointerSize());353Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),354ST->getPointerSize());355SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);356Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));357Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));358}359return true;360}361// TODO: implement legalization for other opcodes.362return true;363}364365366