Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Scalar/Float2Int.cpp
35266 views
//===- Float2Int.cpp - Demote floating point ops to work on integers ------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This file implements the Float2Int pass, which aims to demote floating9// point operations to work on integers, where that is losslessly possible.10//11//===----------------------------------------------------------------------===//1213#include "llvm/Transforms/Scalar/Float2Int.h"14#include "llvm/ADT/APInt.h"15#include "llvm/ADT/APSInt.h"16#include "llvm/ADT/SmallVector.h"17#include "llvm/Analysis/GlobalsModRef.h"18#include "llvm/IR/Constants.h"19#include "llvm/IR/Dominators.h"20#include "llvm/IR/IRBuilder.h"21#include "llvm/IR/Module.h"22#include "llvm/Support/CommandLine.h"23#include "llvm/Support/Debug.h"24#include "llvm/Support/raw_ostream.h"25#include <deque>2627#define DEBUG_TYPE "float2int"2829using namespace llvm;3031// The algorithm is simple. Start at instructions that convert from the32// float to the int domain: fptoui, fptosi and fcmp. Walk up the def-use33// graph, using an equivalence datastructure to unify graphs that interfere.34//35// Mappable instructions are those with an integer corrollary that, given36// integer domain inputs, produce an integer output; fadd, for example.37//38// If a non-mappable instruction is seen, this entire def-use graph is marked39// as non-transformable. If we see an instruction that converts from the40// integer domain to FP domain (uitofp,sitofp), we terminate our walk.4142/// The largest integer type worth dealing with.43static cl::opt<unsigned>44MaxIntegerBW("float2int-max-integer-bw", cl::init(64), cl::Hidden,45cl::desc("Max integer bitwidth to consider in float2int"46"(default=64)"));4748// Given a FCmp predicate, return a matching ICmp predicate if one49// exists, otherwise return BAD_ICMP_PREDICATE.50static CmpInst::Predicate mapFCmpPred(CmpInst::Predicate P) {51switch (P) {52case CmpInst::FCMP_OEQ:53case CmpInst::FCMP_UEQ:54return CmpInst::ICMP_EQ;55case CmpInst::FCMP_OGT:56case CmpInst::FCMP_UGT:57return CmpInst::ICMP_SGT;58case CmpInst::FCMP_OGE:59case CmpInst::FCMP_UGE:60return CmpInst::ICMP_SGE;61case CmpInst::FCMP_OLT:62case CmpInst::FCMP_ULT:63return CmpInst::ICMP_SLT;64case CmpInst::FCMP_OLE:65case CmpInst::FCMP_ULE:66return CmpInst::ICMP_SLE;67case CmpInst::FCMP_ONE:68case CmpInst::FCMP_UNE:69return CmpInst::ICMP_NE;70default:71return CmpInst::BAD_ICMP_PREDICATE;72}73}7475// Given a floating point binary operator, return the matching76// integer version.77static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {78switch (Opcode) {79default: llvm_unreachable("Unhandled opcode!");80case Instruction::FAdd: return Instruction::Add;81case Instruction::FSub: return Instruction::Sub;82case Instruction::FMul: return Instruction::Mul;83}84}8586// Find the roots - instructions that convert from the FP domain to87// integer domain.88void Float2IntPass::findRoots(Function &F, const DominatorTree &DT) {89for (BasicBlock &BB : F) {90// Unreachable code can take on strange forms that we are not prepared to91// handle. For example, an instruction may have itself as an operand.92if (!DT.isReachableFromEntry(&BB))93continue;9495for (Instruction &I : BB) {96if (isa<VectorType>(I.getType()))97continue;98switch (I.getOpcode()) {99default: break;100case Instruction::FPToUI:101case Instruction::FPToSI:102Roots.insert(&I);103break;104case Instruction::FCmp:105if (mapFCmpPred(cast<CmpInst>(&I)->getPredicate()) !=106CmpInst::BAD_ICMP_PREDICATE)107Roots.insert(&I);108break;109}110}111}112}113114// Helper - mark I as having been traversed, having range R.115void Float2IntPass::seen(Instruction *I, ConstantRange R) {116LLVM_DEBUG(dbgs() << "F2I: " << *I << ":" << R << "\n");117auto IT = SeenInsts.find(I);118if (IT != SeenInsts.end())119IT->second = std::move(R);120else121SeenInsts.insert(std::make_pair(I, std::move(R)));122}123124// Helper - get a range representing a poison value.125ConstantRange Float2IntPass::badRange() {126return ConstantRange::getFull(MaxIntegerBW + 1);127}128ConstantRange Float2IntPass::unknownRange() {129return ConstantRange::getEmpty(MaxIntegerBW + 1);130}131ConstantRange Float2IntPass::validateRange(ConstantRange R) {132if (R.getBitWidth() > MaxIntegerBW + 1)133return badRange();134return R;135}136137// The most obvious way to structure the search is a depth-first, eager138// search from each root. However, that require direct recursion and so139// can only handle small instruction sequences. Instead, we split the search140// up into two phases:141// - walkBackwards: A breadth-first walk of the use-def graph starting from142// the roots. Populate "SeenInsts" with interesting143// instructions and poison values if they're obvious and144// cheap to compute. Calculate the equivalance set structure145// while we're here too.146// - walkForwards: Iterate over SeenInsts in reverse order, so we visit147// defs before their uses. Calculate the real range info.148149// Breadth-first walk of the use-def graph; determine the set of nodes150// we care about and eagerly determine if some of them are poisonous.151void Float2IntPass::walkBackwards() {152std::deque<Instruction*> Worklist(Roots.begin(), Roots.end());153while (!Worklist.empty()) {154Instruction *I = Worklist.back();155Worklist.pop_back();156157if (SeenInsts.contains(I))158// Seen already.159continue;160161switch (I->getOpcode()) {162// FIXME: Handle select and phi nodes.163default:164// Path terminated uncleanly.165seen(I, badRange());166break;167168case Instruction::UIToFP:169case Instruction::SIToFP: {170// Path terminated cleanly - use the type of the integer input to seed171// the analysis.172unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits();173auto Input = ConstantRange::getFull(BW);174auto CastOp = (Instruction::CastOps)I->getOpcode();175seen(I, validateRange(Input.castOp(CastOp, MaxIntegerBW+1)));176continue;177}178179case Instruction::FNeg:180case Instruction::FAdd:181case Instruction::FSub:182case Instruction::FMul:183case Instruction::FPToUI:184case Instruction::FPToSI:185case Instruction::FCmp:186seen(I, unknownRange());187break;188}189190for (Value *O : I->operands()) {191if (Instruction *OI = dyn_cast<Instruction>(O)) {192// Unify def-use chains if they interfere.193ECs.unionSets(I, OI);194if (SeenInsts.find(I)->second != badRange())195Worklist.push_back(OI);196} else if (!isa<ConstantFP>(O)) {197// Not an instruction or ConstantFP? we can't do anything.198seen(I, badRange());199}200}201}202}203204// Calculate result range from operand ranges.205// Return std::nullopt if the range cannot be calculated yet.206std::optional<ConstantRange> Float2IntPass::calcRange(Instruction *I) {207SmallVector<ConstantRange, 4> OpRanges;208for (Value *O : I->operands()) {209if (Instruction *OI = dyn_cast<Instruction>(O)) {210auto OpIt = SeenInsts.find(OI);211assert(OpIt != SeenInsts.end() && "def not seen before use!");212if (OpIt->second == unknownRange())213return std::nullopt; // Wait until operand range has been calculated.214OpRanges.push_back(OpIt->second);215} else if (ConstantFP *CF = dyn_cast<ConstantFP>(O)) {216// Work out if the floating point number can be losslessly represented217// as an integer.218// APFloat::convertToInteger(&Exact) purports to do what we want, but219// the exactness can be too precise. For example, negative zero can220// never be exactly converted to an integer.221//222// Instead, we ask APFloat to round itself to an integral value - this223// preserves sign-of-zero - then compare the result with the original.224//225const APFloat &F = CF->getValueAPF();226227// First, weed out obviously incorrect values. Non-finite numbers228// can't be represented and neither can negative zero, unless229// we're in fast math mode.230if (!F.isFinite() ||231(F.isZero() && F.isNegative() && isa<FPMathOperator>(I) &&232!I->hasNoSignedZeros()))233return badRange();234235APFloat NewF = F;236auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven);237if (Res != APFloat::opOK || NewF != F)238return badRange();239240// OK, it's representable. Now get it.241APSInt Int(MaxIntegerBW+1, false);242bool Exact;243CF->getValueAPF().convertToInteger(Int,244APFloat::rmNearestTiesToEven,245&Exact);246OpRanges.push_back(ConstantRange(Int));247} else {248llvm_unreachable("Should have already marked this as badRange!");249}250}251252switch (I->getOpcode()) {253// FIXME: Handle select and phi nodes.254default:255case Instruction::UIToFP:256case Instruction::SIToFP:257llvm_unreachable("Should have been handled in walkForwards!");258259case Instruction::FNeg: {260assert(OpRanges.size() == 1 && "FNeg is a unary operator!");261unsigned Size = OpRanges[0].getBitWidth();262auto Zero = ConstantRange(APInt::getZero(Size));263return Zero.sub(OpRanges[0]);264}265266case Instruction::FAdd:267case Instruction::FSub:268case Instruction::FMul: {269assert(OpRanges.size() == 2 && "its a binary operator!");270auto BinOp = (Instruction::BinaryOps) I->getOpcode();271return OpRanges[0].binaryOp(BinOp, OpRanges[1]);272}273274//275// Root-only instructions - we'll only see these if they're the276// first node in a walk.277//278case Instruction::FPToUI:279case Instruction::FPToSI: {280assert(OpRanges.size() == 1 && "FPTo[US]I is a unary operator!");281// Note: We're ignoring the casts output size here as that's what the282// caller expects.283auto CastOp = (Instruction::CastOps)I->getOpcode();284return OpRanges[0].castOp(CastOp, MaxIntegerBW+1);285}286287case Instruction::FCmp:288assert(OpRanges.size() == 2 && "FCmp is a binary operator!");289return OpRanges[0].unionWith(OpRanges[1]);290}291}292293// Walk forwards down the list of seen instructions, so we visit defs before294// uses.295void Float2IntPass::walkForwards() {296std::deque<Instruction *> Worklist;297for (const auto &Pair : SeenInsts)298if (Pair.second == unknownRange())299Worklist.push_back(Pair.first);300301while (!Worklist.empty()) {302Instruction *I = Worklist.back();303Worklist.pop_back();304305if (std::optional<ConstantRange> Range = calcRange(I))306seen(I, *Range);307else308Worklist.push_front(I); // Reprocess later.309}310}311312// If there is a valid transform to be done, do it.313bool Float2IntPass::validateAndTransform(const DataLayout &DL) {314bool MadeChange = false;315316// Iterate over every disjoint partition of the def-use graph.317for (auto It = ECs.begin(), E = ECs.end(); It != E; ++It) {318ConstantRange R(MaxIntegerBW + 1, false);319bool Fail = false;320Type *ConvertedToTy = nullptr;321322// For every member of the partition, union all the ranges together.323for (auto MI = ECs.member_begin(It), ME = ECs.member_end();324MI != ME; ++MI) {325Instruction *I = *MI;326auto SeenI = SeenInsts.find(I);327if (SeenI == SeenInsts.end())328continue;329330R = R.unionWith(SeenI->second);331// We need to ensure I has no users that have not been seen.332// If it does, transformation would be illegal.333//334// Don't count the roots, as they terminate the graphs.335if (!Roots.contains(I)) {336// Set the type of the conversion while we're here.337if (!ConvertedToTy)338ConvertedToTy = I->getType();339for (User *U : I->users()) {340Instruction *UI = dyn_cast<Instruction>(U);341if (!UI || !SeenInsts.contains(UI)) {342LLVM_DEBUG(dbgs() << "F2I: Failing because of " << *U << "\n");343Fail = true;344break;345}346}347}348if (Fail)349break;350}351352// If the set was empty, or we failed, or the range is poisonous,353// bail out.354if (ECs.member_begin(It) == ECs.member_end() || Fail ||355R.isFullSet() || R.isSignWrappedSet())356continue;357assert(ConvertedToTy && "Must have set the convertedtoty by this point!");358359// The number of bits required is the maximum of the upper and360// lower limits, plus one so it can be signed.361unsigned MinBW = R.getMinSignedBits() + 1;362LLVM_DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n");363364// If we've run off the realms of the exactly representable integers,365// the floating point result will differ from an integer approximation.366367// Do we need more bits than are in the mantissa of the type we converted368// to? semanticsPrecision returns the number of mantissa bits plus one369// for the sign bit.370unsigned MaxRepresentableBits371= APFloat::semanticsPrecision(ConvertedToTy->getFltSemantics()) - 1;372if (MinBW > MaxRepresentableBits) {373LLVM_DEBUG(dbgs() << "F2I: Value not guaranteed to be representable!\n");374continue;375}376377// OK, R is known to be representable.378// Pick the smallest legal type that will fit.379Type *Ty = DL.getSmallestLegalIntType(*Ctx, MinBW);380if (!Ty) {381// Every supported target supports 64-bit and 32-bit integers,382// so fallback to a 32 or 64-bit integer if the value fits.383if (MinBW <= 32) {384Ty = Type::getInt32Ty(*Ctx);385} else if (MinBW <= 64) {386Ty = Type::getInt64Ty(*Ctx);387} else {388LLVM_DEBUG(dbgs() << "F2I: Value requires more bits to represent than "389"the target supports!\n");390continue;391}392}393394for (auto MI = ECs.member_begin(It), ME = ECs.member_end();395MI != ME; ++MI)396convert(*MI, Ty);397MadeChange = true;398}399400return MadeChange;401}402403Value *Float2IntPass::convert(Instruction *I, Type *ToTy) {404if (ConvertedInsts.contains(I))405// Already converted this instruction.406return ConvertedInsts[I];407408SmallVector<Value*,4> NewOperands;409for (Value *V : I->operands()) {410// Don't recurse if we're an instruction that terminates the path.411if (I->getOpcode() == Instruction::UIToFP ||412I->getOpcode() == Instruction::SIToFP) {413NewOperands.push_back(V);414} else if (Instruction *VI = dyn_cast<Instruction>(V)) {415NewOperands.push_back(convert(VI, ToTy));416} else if (ConstantFP *CF = dyn_cast<ConstantFP>(V)) {417APSInt Val(ToTy->getPrimitiveSizeInBits(), /*isUnsigned=*/false);418bool Exact;419CF->getValueAPF().convertToInteger(Val,420APFloat::rmNearestTiesToEven,421&Exact);422NewOperands.push_back(ConstantInt::get(ToTy, Val));423} else {424llvm_unreachable("Unhandled operand type?");425}426}427428// Now create a new instruction.429IRBuilder<> IRB(I);430Value *NewV = nullptr;431switch (I->getOpcode()) {432default: llvm_unreachable("Unhandled instruction!");433434case Instruction::FPToUI:435NewV = IRB.CreateZExtOrTrunc(NewOperands[0], I->getType());436break;437438case Instruction::FPToSI:439NewV = IRB.CreateSExtOrTrunc(NewOperands[0], I->getType());440break;441442case Instruction::FCmp: {443CmpInst::Predicate P = mapFCmpPred(cast<CmpInst>(I)->getPredicate());444assert(P != CmpInst::BAD_ICMP_PREDICATE && "Unhandled predicate!");445NewV = IRB.CreateICmp(P, NewOperands[0], NewOperands[1], I->getName());446break;447}448449case Instruction::UIToFP:450NewV = IRB.CreateZExtOrTrunc(NewOperands[0], ToTy);451break;452453case Instruction::SIToFP:454NewV = IRB.CreateSExtOrTrunc(NewOperands[0], ToTy);455break;456457case Instruction::FNeg:458NewV = IRB.CreateNeg(NewOperands[0], I->getName());459break;460461case Instruction::FAdd:462case Instruction::FSub:463case Instruction::FMul:464NewV = IRB.CreateBinOp(mapBinOpcode(I->getOpcode()),465NewOperands[0], NewOperands[1],466I->getName());467break;468}469470// If we're a root instruction, RAUW.471if (Roots.count(I))472I->replaceAllUsesWith(NewV);473474ConvertedInsts[I] = NewV;475return NewV;476}477478// Perform dead code elimination on the instructions we just modified.479void Float2IntPass::cleanup() {480for (auto &I : reverse(ConvertedInsts))481I.first->eraseFromParent();482}483484bool Float2IntPass::runImpl(Function &F, const DominatorTree &DT) {485LLVM_DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n");486// Clear out all state.487ECs = EquivalenceClasses<Instruction*>();488SeenInsts.clear();489ConvertedInsts.clear();490Roots.clear();491492Ctx = &F.getParent()->getContext();493494findRoots(F, DT);495496walkBackwards();497walkForwards();498499const DataLayout &DL = F.getDataLayout();500bool Modified = validateAndTransform(DL);501if (Modified)502cleanup();503return Modified;504}505506PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &AM) {507const DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);508if (!runImpl(F, DT))509return PreservedAnalyses::all();510511PreservedAnalyses PA;512PA.preserveSet<CFGAnalyses>();513return PA;514}515516517