Path: blob/main/contrib/llvm-project/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
35294 views
//===- DXILIntrinsicExpansion.cpp - Prepare LLVM Module for DXIL encoding--===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7///8/// \file This file contains DXIL intrinsic expansions for those that don't have9// opcodes in DirectX Intermediate Language (DXIL).10//===----------------------------------------------------------------------===//1112#include "DXILIntrinsicExpansion.h"13#include "DirectX.h"14#include "llvm/ADT/STLExtras.h"15#include "llvm/ADT/SmallVector.h"16#include "llvm/CodeGen/Passes.h"17#include "llvm/IR/IRBuilder.h"18#include "llvm/IR/Instruction.h"19#include "llvm/IR/Instructions.h"20#include "llvm/IR/Intrinsics.h"21#include "llvm/IR/IntrinsicsDirectX.h"22#include "llvm/IR/Module.h"23#include "llvm/IR/PassManager.h"24#include "llvm/IR/Type.h"25#include "llvm/Pass.h"26#include "llvm/Support/ErrorHandling.h"27#include "llvm/Support/MathExtras.h"2829#define DEBUG_TYPE "dxil-intrinsic-expansion"3031using namespace llvm;3233static bool isIntrinsicExpansion(Function &F) {34switch (F.getIntrinsicID()) {35case Intrinsic::abs:36case Intrinsic::exp:37case Intrinsic::log:38case Intrinsic::log10:39case Intrinsic::pow:40case Intrinsic::dx_any:41case Intrinsic::dx_clamp:42case Intrinsic::dx_uclamp:43case Intrinsic::dx_lerp:44case Intrinsic::dx_sdot:45case Intrinsic::dx_udot:46return true;47}48return false;49}5051static bool expandAbs(CallInst *Orig) {52Value *X = Orig->getOperand(0);53IRBuilder<> Builder(Orig->getParent());54Builder.SetInsertPoint(Orig);55Type *Ty = X->getType();56Type *EltTy = Ty->getScalarType();57Constant *Zero = Ty->isVectorTy()58? ConstantVector::getSplat(59ElementCount::getFixed(60cast<FixedVectorType>(Ty)->getNumElements()),61ConstantInt::get(EltTy, 0))62: ConstantInt::get(EltTy, 0);63auto *V = Builder.CreateSub(Zero, X);64auto *MaxCall =65Builder.CreateIntrinsic(Ty, Intrinsic::smax, {X, V}, nullptr, "dx.max");66Orig->replaceAllUsesWith(MaxCall);67Orig->eraseFromParent();68return true;69}7071static bool expandIntegerDot(CallInst *Orig, Intrinsic::ID DotIntrinsic) {72assert(DotIntrinsic == Intrinsic::dx_sdot ||73DotIntrinsic == Intrinsic::dx_udot);74Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot75? Intrinsic::dx_imad76: Intrinsic::dx_umad;77Value *A = Orig->getOperand(0);78Value *B = Orig->getOperand(1);79[[maybe_unused]] Type *ATy = A->getType();80[[maybe_unused]] Type *BTy = B->getType();81assert(ATy->isVectorTy() && BTy->isVectorTy());8283IRBuilder<> Builder(Orig->getParent());84Builder.SetInsertPoint(Orig);8586auto *AVec = dyn_cast<FixedVectorType>(A->getType());87Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);88Value *Elt1 = Builder.CreateExtractElement(B, (uint64_t)0);89Value *Result = Builder.CreateMul(Elt0, Elt1);90for (unsigned I = 1; I < AVec->getNumElements(); I++) {91Elt0 = Builder.CreateExtractElement(A, I);92Elt1 = Builder.CreateExtractElement(B, I);93Result = Builder.CreateIntrinsic(Result->getType(), MadIntrinsic,94ArrayRef<Value *>{Elt0, Elt1, Result},95nullptr, "dx.mad");96}97Orig->replaceAllUsesWith(Result);98Orig->eraseFromParent();99return true;100}101102static bool expandExpIntrinsic(CallInst *Orig) {103Value *X = Orig->getOperand(0);104IRBuilder<> Builder(Orig->getParent());105Builder.SetInsertPoint(Orig);106Type *Ty = X->getType();107Type *EltTy = Ty->getScalarType();108Constant *Log2eConst =109Ty->isVectorTy() ? ConstantVector::getSplat(110ElementCount::getFixed(111cast<FixedVectorType>(Ty)->getNumElements()),112ConstantFP::get(EltTy, numbers::log2ef))113: ConstantFP::get(EltTy, numbers::log2ef);114Value *NewX = Builder.CreateFMul(Log2eConst, X);115auto *Exp2Call =116Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {NewX}, nullptr, "dx.exp2");117Exp2Call->setTailCall(Orig->isTailCall());118Exp2Call->setAttributes(Orig->getAttributes());119Orig->replaceAllUsesWith(Exp2Call);120Orig->eraseFromParent();121return true;122}123124static bool expandAnyIntrinsic(CallInst *Orig) {125Value *X = Orig->getOperand(0);126IRBuilder<> Builder(Orig->getParent());127Builder.SetInsertPoint(Orig);128Type *Ty = X->getType();129Type *EltTy = Ty->getScalarType();130131if (!Ty->isVectorTy()) {132Value *Cond = EltTy->isFloatingPointTy()133? Builder.CreateFCmpUNE(X, ConstantFP::get(EltTy, 0))134: Builder.CreateICmpNE(X, ConstantInt::get(EltTy, 0));135Orig->replaceAllUsesWith(Cond);136} else {137auto *XVec = dyn_cast<FixedVectorType>(Ty);138Value *Cond =139EltTy->isFloatingPointTy()140? Builder.CreateFCmpUNE(141X, ConstantVector::getSplat(142ElementCount::getFixed(XVec->getNumElements()),143ConstantFP::get(EltTy, 0)))144: Builder.CreateICmpNE(145X, ConstantVector::getSplat(146ElementCount::getFixed(XVec->getNumElements()),147ConstantInt::get(EltTy, 0)));148Value *Result = Builder.CreateExtractElement(Cond, (uint64_t)0);149for (unsigned I = 1; I < XVec->getNumElements(); I++) {150Value *Elt = Builder.CreateExtractElement(Cond, I);151Result = Builder.CreateOr(Result, Elt);152}153Orig->replaceAllUsesWith(Result);154}155Orig->eraseFromParent();156return true;157}158159static bool expandLerpIntrinsic(CallInst *Orig) {160Value *X = Orig->getOperand(0);161Value *Y = Orig->getOperand(1);162Value *S = Orig->getOperand(2);163IRBuilder<> Builder(Orig->getParent());164Builder.SetInsertPoint(Orig);165auto *V = Builder.CreateFSub(Y, X);166V = Builder.CreateFMul(S, V);167auto *Result = Builder.CreateFAdd(X, V, "dx.lerp");168Orig->replaceAllUsesWith(Result);169Orig->eraseFromParent();170return true;171}172173static bool expandLogIntrinsic(CallInst *Orig,174float LogConstVal = numbers::ln2f) {175Value *X = Orig->getOperand(0);176IRBuilder<> Builder(Orig->getParent());177Builder.SetInsertPoint(Orig);178Type *Ty = X->getType();179Type *EltTy = Ty->getScalarType();180Constant *Ln2Const =181Ty->isVectorTy() ? ConstantVector::getSplat(182ElementCount::getFixed(183cast<FixedVectorType>(Ty)->getNumElements()),184ConstantFP::get(EltTy, LogConstVal))185: ConstantFP::get(EltTy, LogConstVal);186auto *Log2Call =187Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");188Log2Call->setTailCall(Orig->isTailCall());189Log2Call->setAttributes(Orig->getAttributes());190auto *Result = Builder.CreateFMul(Ln2Const, Log2Call);191Orig->replaceAllUsesWith(Result);192Orig->eraseFromParent();193return true;194}195static bool expandLog10Intrinsic(CallInst *Orig) {196return expandLogIntrinsic(Orig, numbers::ln2f / numbers::ln10f);197}198199static bool expandPowIntrinsic(CallInst *Orig) {200201Value *X = Orig->getOperand(0);202Value *Y = Orig->getOperand(1);203Type *Ty = X->getType();204IRBuilder<> Builder(Orig->getParent());205Builder.SetInsertPoint(Orig);206207auto *Log2Call =208Builder.CreateIntrinsic(Ty, Intrinsic::log2, {X}, nullptr, "elt.log2");209auto *Mul = Builder.CreateFMul(Log2Call, Y);210auto *Exp2Call =211Builder.CreateIntrinsic(Ty, Intrinsic::exp2, {Mul}, nullptr, "elt.exp2");212Exp2Call->setTailCall(Orig->isTailCall());213Exp2Call->setAttributes(Orig->getAttributes());214Orig->replaceAllUsesWith(Exp2Call);215Orig->eraseFromParent();216return true;217}218219static Intrinsic::ID getMaxForClamp(Type *ElemTy,220Intrinsic::ID ClampIntrinsic) {221if (ClampIntrinsic == Intrinsic::dx_uclamp)222return Intrinsic::umax;223assert(ClampIntrinsic == Intrinsic::dx_clamp);224if (ElemTy->isVectorTy())225ElemTy = ElemTy->getScalarType();226if (ElemTy->isIntegerTy())227return Intrinsic::smax;228assert(ElemTy->isFloatingPointTy());229return Intrinsic::maxnum;230}231232static Intrinsic::ID getMinForClamp(Type *ElemTy,233Intrinsic::ID ClampIntrinsic) {234if (ClampIntrinsic == Intrinsic::dx_uclamp)235return Intrinsic::umin;236assert(ClampIntrinsic == Intrinsic::dx_clamp);237if (ElemTy->isVectorTy())238ElemTy = ElemTy->getScalarType();239if (ElemTy->isIntegerTy())240return Intrinsic::smin;241assert(ElemTy->isFloatingPointTy());242return Intrinsic::minnum;243}244245static bool expandClampIntrinsic(CallInst *Orig, Intrinsic::ID ClampIntrinsic) {246Value *X = Orig->getOperand(0);247Value *Min = Orig->getOperand(1);248Value *Max = Orig->getOperand(2);249Type *Ty = X->getType();250IRBuilder<> Builder(Orig->getParent());251Builder.SetInsertPoint(Orig);252auto *MaxCall = Builder.CreateIntrinsic(253Ty, getMaxForClamp(Ty, ClampIntrinsic), {X, Min}, nullptr, "dx.max");254auto *MinCall =255Builder.CreateIntrinsic(Ty, getMinForClamp(Ty, ClampIntrinsic),256{MaxCall, Max}, nullptr, "dx.min");257258Orig->replaceAllUsesWith(MinCall);259Orig->eraseFromParent();260return true;261}262263static bool expandIntrinsic(Function &F, CallInst *Orig) {264switch (F.getIntrinsicID()) {265case Intrinsic::abs:266return expandAbs(Orig);267case Intrinsic::exp:268return expandExpIntrinsic(Orig);269case Intrinsic::log:270return expandLogIntrinsic(Orig);271case Intrinsic::log10:272return expandLog10Intrinsic(Orig);273case Intrinsic::pow:274return expandPowIntrinsic(Orig);275case Intrinsic::dx_any:276return expandAnyIntrinsic(Orig);277case Intrinsic::dx_uclamp:278case Intrinsic::dx_clamp:279return expandClampIntrinsic(Orig, F.getIntrinsicID());280case Intrinsic::dx_lerp:281return expandLerpIntrinsic(Orig);282case Intrinsic::dx_sdot:283case Intrinsic::dx_udot:284return expandIntegerDot(Orig, F.getIntrinsicID());285}286return false;287}288289static bool expansionIntrinsics(Module &M) {290for (auto &F : make_early_inc_range(M.functions())) {291if (!isIntrinsicExpansion(F))292continue;293bool IntrinsicExpanded = false;294for (User *U : make_early_inc_range(F.users())) {295auto *IntrinsicCall = dyn_cast<CallInst>(U);296if (!IntrinsicCall)297continue;298IntrinsicExpanded = expandIntrinsic(F, IntrinsicCall);299}300if (F.user_empty() && IntrinsicExpanded)301F.eraseFromParent();302}303return true;304}305306PreservedAnalyses DXILIntrinsicExpansion::run(Module &M,307ModuleAnalysisManager &) {308if (expansionIntrinsics(M))309return PreservedAnalyses::none();310return PreservedAnalyses::all();311}312313bool DXILIntrinsicExpansionLegacy::runOnModule(Module &M) {314return expansionIntrinsics(M);315}316317char DXILIntrinsicExpansionLegacy::ID = 0;318319INITIALIZE_PASS_BEGIN(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,320"DXIL Intrinsic Expansion", false, false)321INITIALIZE_PASS_END(DXILIntrinsicExpansionLegacy, DEBUG_TYPE,322"DXIL Intrinsic Expansion", false, false)323324ModulePass *llvm::createDXILIntrinsicExpansionLegacyPass() {325return new DXILIntrinsicExpansionLegacy();326}327328329