// Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/AggressiveInstCombine/TruncInstCombine.cpp
// 35266 views
//===- TruncInstCombine.cpp -----------------------------------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// TruncInstCombine - looks for expression graphs post-dominated by TruncInst9// and for each eligible graph, it will create a reduced bit-width expression,10// replace the old expression with this new one and remove the old expression.11// Eligible expression graph is such that:12// 1. Contains only supported instructions.13// 2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.14// 3. Can be evaluated into type with reduced legal bit-width.15// 4. All instructions in the graph must not have users outside the graph.16// The only exception is for {ZExt, SExt}Inst with operand type equal to17// the new reduced type evaluated in (3).18//19// The motivation for this optimization is that evaluating and expression using20// smaller bit-width is preferable, especially for vectorization where we can21// fit more values in one vectorized instruction. 
In addition, this optimization22// may decrease the number of cast instructions, but will not increase it.23//24//===----------------------------------------------------------------------===//2526#include "AggressiveInstCombineInternal.h"27#include "llvm/ADT/STLExtras.h"28#include "llvm/ADT/Statistic.h"29#include "llvm/Analysis/ConstantFolding.h"30#include "llvm/IR/DataLayout.h"31#include "llvm/IR/Dominators.h"32#include "llvm/IR/IRBuilder.h"33#include "llvm/IR/Instruction.h"34#include "llvm/Support/KnownBits.h"3536using namespace llvm;3738#define DEBUG_TYPE "aggressive-instcombine"3940STATISTIC(NumExprsReduced, "Number of truncations eliminated by reducing bit "41"width of expression graph");42STATISTIC(NumInstrsReduced,43"Number of instructions whose bit width was reduced");4445/// Given an instruction and a container, it fills all the relevant operands of46/// that instruction, with respect to the Trunc expression graph optimizaton.47static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {48unsigned Opc = I->getOpcode();49switch (Opc) {50case Instruction::Trunc:51case Instruction::ZExt:52case Instruction::SExt:53// These CastInst are considered leaves of the evaluated expression, thus,54// their operands are not relevent.55break;56case Instruction::Add:57case Instruction::Sub:58case Instruction::Mul:59case Instruction::And:60case Instruction::Or:61case Instruction::Xor:62case Instruction::Shl:63case Instruction::LShr:64case Instruction::AShr:65case Instruction::UDiv:66case Instruction::URem:67case Instruction::InsertElement:68Ops.push_back(I->getOperand(0));69Ops.push_back(I->getOperand(1));70break;71case Instruction::ExtractElement:72Ops.push_back(I->getOperand(0));73break;74case Instruction::Select:75Ops.push_back(I->getOperand(1));76Ops.push_back(I->getOperand(2));77break;78case Instruction::PHI:79for (Value *V : cast<PHINode>(I)->incoming_values())80Ops.push_back(V);81break;82default:83llvm_unreachable("Unreachable!");84}85}8687bool 
TruncInstCombine::buildTruncExpressionGraph() {88SmallVector<Value *, 8> Worklist;89SmallVector<Instruction *, 8> Stack;90// Clear old instructions info.91InstInfoMap.clear();9293Worklist.push_back(CurrentTruncInst->getOperand(0));9495while (!Worklist.empty()) {96Value *Curr = Worklist.back();9798if (isa<Constant>(Curr)) {99Worklist.pop_back();100continue;101}102103auto *I = dyn_cast<Instruction>(Curr);104if (!I)105return false;106107if (!Stack.empty() && Stack.back() == I) {108// Already handled all instruction operands, can remove it from both the109// Worklist and the Stack, and add it to the instruction info map.110Worklist.pop_back();111Stack.pop_back();112// Insert I to the Info map.113InstInfoMap.insert(std::make_pair(I, Info()));114continue;115}116117if (InstInfoMap.count(I)) {118Worklist.pop_back();119continue;120}121122// Add the instruction to the stack before start handling its operands.123Stack.push_back(I);124125unsigned Opc = I->getOpcode();126switch (Opc) {127case Instruction::Trunc:128case Instruction::ZExt:129case Instruction::SExt:130// trunc(trunc(x)) -> trunc(x)131// trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest132// trunc(ext(x)) -> trunc(x) if the source type is larger than the new133// dest134break;135case Instruction::Add:136case Instruction::Sub:137case Instruction::Mul:138case Instruction::And:139case Instruction::Or:140case Instruction::Xor:141case Instruction::Shl:142case Instruction::LShr:143case Instruction::AShr:144case Instruction::UDiv:145case Instruction::URem:146case Instruction::InsertElement:147case Instruction::ExtractElement:148case Instruction::Select: {149SmallVector<Value *, 2> Operands;150getRelevantOperands(I, Operands);151append_range(Worklist, Operands);152break;153}154case Instruction::PHI: {155SmallVector<Value *, 2> Operands;156getRelevantOperands(I, Operands);157// Add only operands not in Stack to prevent cycle158for (auto *Op : Operands)159if (!llvm::is_contained(Stack, 
Op))160Worklist.push_back(Op);161break;162}163default:164// TODO: Can handle more cases here:165// 1. shufflevector166// 2. sdiv, srem167// ...168return false;169}170}171return true;172}173174unsigned TruncInstCombine::getMinBitWidth() {175SmallVector<Value *, 8> Worklist;176SmallVector<Instruction *, 8> Stack;177178Value *Src = CurrentTruncInst->getOperand(0);179Type *DstTy = CurrentTruncInst->getType();180unsigned TruncBitWidth = DstTy->getScalarSizeInBits();181unsigned OrigBitWidth =182CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();183184if (isa<Constant>(Src))185return TruncBitWidth;186187Worklist.push_back(Src);188InstInfoMap[cast<Instruction>(Src)].ValidBitWidth = TruncBitWidth;189190while (!Worklist.empty()) {191Value *Curr = Worklist.back();192193if (isa<Constant>(Curr)) {194Worklist.pop_back();195continue;196}197198// Otherwise, it must be an instruction.199auto *I = cast<Instruction>(Curr);200201auto &Info = InstInfoMap[I];202203SmallVector<Value *, 2> Operands;204getRelevantOperands(I, Operands);205206if (!Stack.empty() && Stack.back() == I) {207// Already handled all instruction operands, can remove it from both, the208// Worklist and the Stack, and update MinBitWidth.209Worklist.pop_back();210Stack.pop_back();211for (auto *Operand : Operands)212if (auto *IOp = dyn_cast<Instruction>(Operand))213Info.MinBitWidth =214std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth);215continue;216}217218// Add the instruction to the stack before start handling its operands.219Stack.push_back(I);220unsigned ValidBitWidth = Info.ValidBitWidth;221222// Update minimum bit-width before handling its operands. 
This is required223// when the instruction is part of a loop.224Info.MinBitWidth = std::max(Info.MinBitWidth, Info.ValidBitWidth);225226for (auto *Operand : Operands)227if (auto *IOp = dyn_cast<Instruction>(Operand)) {228// If we already calculated the minimum bit-width for this valid229// bit-width, or for a smaller valid bit-width, then just keep the230// answer we already calculated.231unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth;232if (IOpBitwidth >= ValidBitWidth)233continue;234InstInfoMap[IOp].ValidBitWidth = ValidBitWidth;235Worklist.push_back(IOp);236}237}238unsigned MinBitWidth = InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth;239assert(MinBitWidth >= TruncBitWidth);240241if (MinBitWidth > TruncBitWidth) {242// In this case reducing expression with vector type might generate a new243// vector type, which is not preferable as it might result in generating244// sub-optimal code.245if (DstTy->isVectorTy())246return OrigBitWidth;247// Use the smallest integer type in the range [MinBitWidth, OrigBitWidth).248Type *Ty = DL.getSmallestLegalIntType(DstTy->getContext(), MinBitWidth);249// Update minimum bit-width with the new destination type bit-width if250// succeeded to find such, otherwise, with original bit-width.251MinBitWidth = Ty ? Ty->getScalarSizeInBits() : OrigBitWidth;252} else { // MinBitWidth == TruncBitWidth253// In this case the expression can be evaluated with the trunc instruction254// destination type, and trunc instruction can be omitted. 
However, we255// should not perform the evaluation if the original type is a legal scalar256// type and the target type is illegal.257bool FromLegal = MinBitWidth == 1 || DL.isLegalInteger(OrigBitWidth);258bool ToLegal = MinBitWidth == 1 || DL.isLegalInteger(MinBitWidth);259if (!DstTy->isVectorTy() && FromLegal && !ToLegal)260return OrigBitWidth;261}262return MinBitWidth;263}264265Type *TruncInstCombine::getBestTruncatedType() {266if (!buildTruncExpressionGraph())267return nullptr;268269// We don't want to duplicate instructions, which isn't profitable. Thus, we270// can't shrink something that has multiple users, unless all users are271// post-dominated by the trunc instruction, i.e., were visited during the272// expression evaluation.273unsigned DesiredBitWidth = 0;274for (auto Itr : InstInfoMap) {275Instruction *I = Itr.first;276if (I->hasOneUse())277continue;278bool IsExtInst = (isa<ZExtInst>(I) || isa<SExtInst>(I));279for (auto *U : I->users())280if (auto *UI = dyn_cast<Instruction>(U))281if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) {282if (!IsExtInst)283return nullptr;284// If this is an extension from the dest type, we can eliminate it,285// even if it has multiple users. Thus, update the DesiredBitWidth and286// validate all extension instructions agrees on same DesiredBitWidth.287unsigned ExtInstBitWidth =288I->getOperand(0)->getType()->getScalarSizeInBits();289if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth)290return nullptr;291DesiredBitWidth = ExtInstBitWidth;292}293}294295unsigned OrigBitWidth =296CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();297298// Initialize MinBitWidth for shift instructions with the minimum number299// that is greater than shift amount (i.e. 
shift amount + 1).300// For `lshr` adjust MinBitWidth so that all potentially truncated301// bits of the value-to-be-shifted are zeros.302// For `ashr` adjust MinBitWidth so that all potentially truncated303// bits of the value-to-be-shifted are sign bits (all zeros or ones)304// and even one (first) untruncated bit is sign bit.305// Exit early if MinBitWidth is not less than original bitwidth.306for (auto &Itr : InstInfoMap) {307Instruction *I = Itr.first;308if (I->isShift()) {309KnownBits KnownRHS = computeKnownBits(I->getOperand(1));310unsigned MinBitWidth = KnownRHS.getMaxValue()311.uadd_sat(APInt(OrigBitWidth, 1))312.getLimitedValue(OrigBitWidth);313if (MinBitWidth == OrigBitWidth)314return nullptr;315if (I->getOpcode() == Instruction::LShr) {316KnownBits KnownLHS = computeKnownBits(I->getOperand(0));317MinBitWidth =318std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits());319}320if (I->getOpcode() == Instruction::AShr) {321unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0));322MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1);323}324if (MinBitWidth >= OrigBitWidth)325return nullptr;326Itr.second.MinBitWidth = MinBitWidth;327}328if (I->getOpcode() == Instruction::UDiv ||329I->getOpcode() == Instruction::URem) {330unsigned MinBitWidth = 0;331for (const auto &Op : I->operands()) {332KnownBits Known = computeKnownBits(Op);333MinBitWidth =334std::max(Known.getMaxValue().getActiveBits(), MinBitWidth);335if (MinBitWidth >= OrigBitWidth)336return nullptr;337}338Itr.second.MinBitWidth = MinBitWidth;339}340}341342// Calculate minimum allowed bit-width allowed for shrinking the currently343// visited truncate's operand.344unsigned MinBitWidth = getMinBitWidth();345346// Check that we can shrink to smaller bit-width than original one and that347// it is similar to the DesiredBitWidth is such exists.348if (MinBitWidth >= OrigBitWidth ||349(DesiredBitWidth && DesiredBitWidth != MinBitWidth))350return nullptr;351352return 
IntegerType::get(CurrentTruncInst->getContext(), MinBitWidth);353}354355/// Given a reduced scalar type \p Ty and a \p V value, return a reduced type356/// for \p V, according to its type, if it vector type, return the vector357/// version of \p Ty, otherwise return \p Ty.358static Type *getReducedType(Value *V, Type *Ty) {359assert(Ty && !Ty->isVectorTy() && "Expect Scalar Type");360if (auto *VTy = dyn_cast<VectorType>(V->getType()))361return VectorType::get(Ty, VTy->getElementCount());362return Ty;363}364365Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) {366Type *Ty = getReducedType(V, SclTy);367if (auto *C = dyn_cast<Constant>(V)) {368C = ConstantExpr::getTrunc(C, Ty);369// If we got a constantexpr back, try to simplify it with DL info.370return ConstantFoldConstant(C, DL, &TLI);371}372373auto *I = cast<Instruction>(V);374Info Entry = InstInfoMap.lookup(I);375assert(Entry.NewValue);376return Entry.NewValue;377}378379void TruncInstCombine::ReduceExpressionGraph(Type *SclTy) {380NumInstrsReduced += InstInfoMap.size();381// Pairs of old and new phi-nodes382SmallVector<std::pair<PHINode *, PHINode *>, 2> OldNewPHINodes;383for (auto &Itr : InstInfoMap) { // Forward384Instruction *I = Itr.first;385TruncInstCombine::Info &NodeInfo = Itr.second;386387assert(!NodeInfo.NewValue && "Instruction has been evaluated");388389IRBuilder<> Builder(I);390Value *Res = nullptr;391unsigned Opc = I->getOpcode();392switch (Opc) {393case Instruction::Trunc:394case Instruction::ZExt:395case Instruction::SExt: {396Type *Ty = getReducedType(I, SclTy);397// If the source type of the cast is the type we're trying for then we can398// just return the source. 
There's no need to insert it because it is not399// new.400if (I->getOperand(0)->getType() == Ty) {401assert(!isa<TruncInst>(I) && "Cannot reach here with TruncInst");402NodeInfo.NewValue = I->getOperand(0);403continue;404}405// Otherwise, must be the same type of cast, so just reinsert a new one.406// This also handles the case of zext(trunc(x)) -> zext(x).407Res = Builder.CreateIntCast(I->getOperand(0), Ty,408Opc == Instruction::SExt);409410// Update Worklist entries with new value if needed.411// There are three possible changes to the Worklist:412// 1. Update Old-TruncInst -> New-TruncInst.413// 2. Remove Old-TruncInst (if New node is not TruncInst).414// 3. Add New-TruncInst (if Old node was not TruncInst).415auto *Entry = find(Worklist, I);416if (Entry != Worklist.end()) {417if (auto *NewCI = dyn_cast<TruncInst>(Res))418*Entry = NewCI;419else420Worklist.erase(Entry);421} else if (auto *NewCI = dyn_cast<TruncInst>(Res))422Worklist.push_back(NewCI);423break;424}425case Instruction::Add:426case Instruction::Sub:427case Instruction::Mul:428case Instruction::And:429case Instruction::Or:430case Instruction::Xor:431case Instruction::Shl:432case Instruction::LShr:433case Instruction::AShr:434case Instruction::UDiv:435case Instruction::URem: {436Value *LHS = getReducedOperand(I->getOperand(0), SclTy);437Value *RHS = getReducedOperand(I->getOperand(1), SclTy);438Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS);439// Preserve `exact` flag since truncation doesn't change exactness440if (auto *PEO = dyn_cast<PossiblyExactOperator>(I))441if (auto *ResI = dyn_cast<Instruction>(Res))442ResI->setIsExact(PEO->isExact());443break;444}445case Instruction::ExtractElement: {446Value *Vec = getReducedOperand(I->getOperand(0), SclTy);447Value *Idx = I->getOperand(1);448Res = Builder.CreateExtractElement(Vec, Idx);449break;450}451case Instruction::InsertElement: {452Value *Vec = getReducedOperand(I->getOperand(0), SclTy);453Value *NewElt = 
getReducedOperand(I->getOperand(1), SclTy);454Value *Idx = I->getOperand(2);455Res = Builder.CreateInsertElement(Vec, NewElt, Idx);456break;457}458case Instruction::Select: {459Value *Op0 = I->getOperand(0);460Value *LHS = getReducedOperand(I->getOperand(1), SclTy);461Value *RHS = getReducedOperand(I->getOperand(2), SclTy);462Res = Builder.CreateSelect(Op0, LHS, RHS);463break;464}465case Instruction::PHI: {466Res = Builder.CreatePHI(getReducedType(I, SclTy), I->getNumOperands());467OldNewPHINodes.push_back(468std::make_pair(cast<PHINode>(I), cast<PHINode>(Res)));469break;470}471default:472llvm_unreachable("Unhandled instruction");473}474475NodeInfo.NewValue = Res;476if (auto *ResI = dyn_cast<Instruction>(Res))477ResI->takeName(I);478}479480for (auto &Node : OldNewPHINodes) {481PHINode *OldPN = Node.first;482PHINode *NewPN = Node.second;483for (auto Incoming : zip(OldPN->incoming_values(), OldPN->blocks()))484NewPN->addIncoming(getReducedOperand(std::get<0>(Incoming), SclTy),485std::get<1>(Incoming));486}487488Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy);489Type *DstTy = CurrentTruncInst->getType();490if (Res->getType() != DstTy) {491IRBuilder<> Builder(CurrentTruncInst);492Res = Builder.CreateIntCast(Res, DstTy, false);493if (auto *ResI = dyn_cast<Instruction>(Res))494ResI->takeName(CurrentTruncInst);495}496CurrentTruncInst->replaceAllUsesWith(Res);497498// Erase old expression graph, which was replaced by the reduced expression499// graph.500CurrentTruncInst->eraseFromParent();501// First, erase old phi-nodes and its uses502for (auto &Node : OldNewPHINodes) {503PHINode *OldPN = Node.first;504OldPN->replaceAllUsesWith(PoisonValue::get(OldPN->getType()));505InstInfoMap.erase(OldPN);506OldPN->eraseFromParent();507}508// Now we have expression graph turned into dag.509// We iterate backward, which means we visit the instruction before we510// visit any of its operands, this way, when we get to the operand, we already511// removed the 
instructions (from the expression dag) that uses it.512for (auto &I : llvm::reverse(InstInfoMap)) {513// We still need to check that the instruction has no users before we erase514// it, because {SExt, ZExt}Inst Instruction might have other users that was515// not reduced, in such case, we need to keep that instruction.516if (I.first->use_empty())517I.first->eraseFromParent();518else519assert((isa<SExtInst>(I.first) || isa<ZExtInst>(I.first)) &&520"Only {SExt, ZExt}Inst might have unreduced users");521}522}523524bool TruncInstCombine::run(Function &F) {525bool MadeIRChange = false;526527// Collect all TruncInst in the function into the Worklist for evaluating.528for (auto &BB : F) {529// Ignore unreachable basic block.530if (!DT.isReachableFromEntry(&BB))531continue;532for (auto &I : BB)533if (auto *CI = dyn_cast<TruncInst>(&I))534Worklist.push_back(CI);535}536537// Process all TruncInst in the Worklist, for each instruction:538// 1. Check if it dominates an eligible expression graph to be reduced.539// 2. Create a reduced expression graph and replace the old one with it.540while (!Worklist.empty()) {541CurrentTruncInst = Worklist.pop_back_val();542543if (Type *NewDstSclTy = getBestTruncatedType()) {544LLVM_DEBUG(545dbgs() << "ICE: TruncInstCombine reducing type of expression graph "546"dominated by: "547<< CurrentTruncInst << '\n');548ReduceExpressionGraph(NewDstSclTy);549++NumExprsReduced;550MadeIRChange = true;551}552}553554return MadeIRChange;555}556557558