// Path: blob/main/contrib/llvm-project/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
//===----------------------------------------------------------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This file implements pass that inlines CIR operations regions into the parent9// function region.10//11//===----------------------------------------------------------------------===//1213#include "PassDetail.h"14#include "mlir/Dialect/Func/IR/FuncOps.h"15#include "mlir/IR/Block.h"16#include "mlir/IR/Builders.h"17#include "mlir/IR/PatternMatch.h"18#include "mlir/Support/LogicalResult.h"19#include "mlir/Transforms/DialectConversion.h"20#include "mlir/Transforms/GreedyPatternRewriteDriver.h"21#include "clang/CIR/Dialect/IR/CIRDialect.h"22#include "clang/CIR/Dialect/Passes.h"23#include "clang/CIR/MissingFeatures.h"2425using namespace mlir;26using namespace cir;2728namespace {2930/// Lowers operations with the terminator trait that have a single successor.31void lowerTerminator(mlir::Operation *op, mlir::Block *dest,32mlir::PatternRewriter &rewriter) {33assert(op->hasTrait<mlir::OpTrait::IsTerminator>() && "not a terminator");34mlir::OpBuilder::InsertionGuard guard(rewriter);35rewriter.setInsertionPoint(op);36rewriter.replaceOpWithNewOp<cir::BrOp>(op, dest);37}3839/// Walks a region while skipping operations of type `Ops`. This ensures the40/// callback is not applied to said operations and its children.41template <typename... 
Ops>42void walkRegionSkipping(43mlir::Region ®ion,44mlir::function_ref<mlir::WalkResult(mlir::Operation *)> callback) {45region.walk<mlir::WalkOrder::PreOrder>([&](mlir::Operation *op) {46if (isa<Ops...>(op))47return mlir::WalkResult::skip();48return callback(op);49});50}5152struct CIRFlattenCFGPass : public CIRFlattenCFGBase<CIRFlattenCFGPass> {5354CIRFlattenCFGPass() = default;55void runOnOperation() override;56};5758struct CIRIfFlattening : public mlir::OpRewritePattern<cir::IfOp> {59using OpRewritePattern<IfOp>::OpRewritePattern;6061mlir::LogicalResult62matchAndRewrite(cir::IfOp ifOp,63mlir::PatternRewriter &rewriter) const override {64mlir::OpBuilder::InsertionGuard guard(rewriter);65mlir::Location loc = ifOp.getLoc();66bool emptyElse = ifOp.getElseRegion().empty();67mlir::Block *currentBlock = rewriter.getInsertionBlock();68mlir::Block *remainingOpsBlock =69rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());70mlir::Block *continueBlock;71if (ifOp->getResults().empty())72continueBlock = remainingOpsBlock;73else74llvm_unreachable("NYI");7576// Inline the region77mlir::Block *thenBeforeBody = &ifOp.getThenRegion().front();78mlir::Block *thenAfterBody = &ifOp.getThenRegion().back();79rewriter.inlineRegionBefore(ifOp.getThenRegion(), continueBlock);8081rewriter.setInsertionPointToEnd(thenAfterBody);82if (auto thenYieldOp =83dyn_cast<cir::YieldOp>(thenAfterBody->getTerminator())) {84rewriter.replaceOpWithNewOp<cir::BrOp>(thenYieldOp, thenYieldOp.getArgs(),85continueBlock);86}8788rewriter.setInsertionPointToEnd(continueBlock);8990// Has else region: inline it.91mlir::Block *elseBeforeBody = nullptr;92mlir::Block *elseAfterBody = nullptr;93if (!emptyElse) {94elseBeforeBody = &ifOp.getElseRegion().front();95elseAfterBody = &ifOp.getElseRegion().back();96rewriter.inlineRegionBefore(ifOp.getElseRegion(), continueBlock);97} else {98elseBeforeBody = elseAfterBody = 
continueBlock;99}100101rewriter.setInsertionPointToEnd(currentBlock);102rewriter.create<cir::BrCondOp>(loc, ifOp.getCondition(), thenBeforeBody,103elseBeforeBody);104105if (!emptyElse) {106rewriter.setInsertionPointToEnd(elseAfterBody);107if (auto elseYieldOP =108dyn_cast<cir::YieldOp>(elseAfterBody->getTerminator())) {109rewriter.replaceOpWithNewOp<cir::BrOp>(110elseYieldOP, elseYieldOP.getArgs(), continueBlock);111}112}113114rewriter.replaceOp(ifOp, continueBlock->getArguments());115return mlir::success();116}117};118119class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> {120public:121using OpRewritePattern<cir::ScopeOp>::OpRewritePattern;122123mlir::LogicalResult124matchAndRewrite(cir::ScopeOp scopeOp,125mlir::PatternRewriter &rewriter) const override {126mlir::OpBuilder::InsertionGuard guard(rewriter);127mlir::Location loc = scopeOp.getLoc();128129// Empty scope: just remove it.130// TODO: Remove this logic once CIR uses MLIR infrastructure to remove131// trivially dead operations. 
MLIR canonicalizer is too aggressive and we132// need to either (a) make sure all our ops model all side-effects and/or133// (b) have more options in the canonicalizer in MLIR to temper134// aggressiveness level.135if (scopeOp.isEmpty()) {136rewriter.eraseOp(scopeOp);137return mlir::success();138}139140// Split the current block before the ScopeOp to create the inlining141// point.142mlir::Block *currentBlock = rewriter.getInsertionBlock();143mlir::Block *continueBlock =144rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());145if (scopeOp.getNumResults() > 0)146continueBlock->addArguments(scopeOp.getResultTypes(), loc);147148// Inline body region.149mlir::Block *beforeBody = &scopeOp.getScopeRegion().front();150mlir::Block *afterBody = &scopeOp.getScopeRegion().back();151rewriter.inlineRegionBefore(scopeOp.getScopeRegion(), continueBlock);152153// Save stack and then branch into the body of the region.154rewriter.setInsertionPointToEnd(currentBlock);155assert(!cir::MissingFeatures::stackSaveOp());156rewriter.create<cir::BrOp>(loc, mlir::ValueRange(), beforeBody);157158// Replace the scopeop return with a branch that jumps out of the body.159// Stack restore before leaving the body region.160rewriter.setInsertionPointToEnd(afterBody);161if (auto yieldOp = dyn_cast<cir::YieldOp>(afterBody->getTerminator())) {162rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getArgs(),163continueBlock);164}165166// Replace the op with values return from the body region.167rewriter.replaceOp(scopeOp, continueBlock->getArguments());168169return mlir::success();170}171};172173class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> {174public:175using OpRewritePattern<cir::SwitchOp>::OpRewritePattern;176177inline void rewriteYieldOp(mlir::PatternRewriter &rewriter,178cir::YieldOp yieldOp,179mlir::Block *destination) const {180rewriter.setInsertionPoint(yieldOp);181rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, 
yieldOp.getOperands(),182destination);183}184185// Return the new defaultDestination block.186Block *condBrToRangeDestination(cir::SwitchOp op,187mlir::PatternRewriter &rewriter,188mlir::Block *rangeDestination,189mlir::Block *defaultDestination,190const APInt &lowerBound,191const APInt &upperBound) const {192assert(lowerBound.sle(upperBound) && "Invalid range");193mlir::Block *resBlock = rewriter.createBlock(defaultDestination);194cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true);195cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false);196197cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>(198op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound));199200cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>(201op.getLoc(), cir::IntAttr::get(sIntType, lowerBound));202cir::BinOp diffValue =203rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub,204op.getCondition(), lowerBoundValue);205206// Use unsigned comparison to check if the condition is in the range.207cir::CastOp uDiffValue = rewriter.create<cir::CastOp>(208op.getLoc(), uIntType, CastKind::integral, diffValue);209cir::CastOp uRangeLength = rewriter.create<cir::CastOp>(210op.getLoc(), uIntType, CastKind::integral, rangeLength);211212cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>(213op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le,214uDiffValue, uRangeLength);215rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination,216defaultDestination);217return resBlock;218}219220mlir::LogicalResult221matchAndRewrite(cir::SwitchOp op,222mlir::PatternRewriter &rewriter) const override {223llvm::SmallVector<CaseOp> cases;224op.collectCases(cases);225226// Empty switch statement: just erase it.227if (cases.empty()) {228rewriter.eraseOp(op);229return mlir::success();230}231232// Create exit block from the next node of cir.switch op.233mlir::Block *exitBlock = 
rewriter.splitBlock(234rewriter.getBlock(), op->getNextNode()->getIterator());235236// We lower cir.switch op in the following process:237// 1. Inline the region from the switch op after switch op.238// 2. Traverse each cir.case op:239// a. Record the entry block, block arguments and condition for every240// case. b. Inline the case region after the case op.241// 3. Replace the empty cir.switch.op with the new cir.switchflat op by the242// recorded block and conditions.243244// inline everything from switch body between the switch op and the exit245// block.246{247cir::YieldOp switchYield = nullptr;248// Clear switch operation.249for (mlir::Block &block :250llvm::make_early_inc_range(op.getBody().getBlocks()))251if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator()))252switchYield = yieldOp;253254assert(!op.getBody().empty());255mlir::Block *originalBlock = op->getBlock();256mlir::Block *swopBlock =257rewriter.splitBlock(originalBlock, op->getIterator());258rewriter.inlineRegionBefore(op.getBody(), exitBlock);259260if (switchYield)261rewriteYieldOp(rewriter, switchYield, exitBlock);262263rewriter.setInsertionPointToEnd(originalBlock);264rewriter.create<cir::BrOp>(op.getLoc(), swopBlock);265}266267// Allocate required data structures (disconsider default case in268// vectors).269llvm::SmallVector<mlir::APInt, 8> caseValues;270llvm::SmallVector<mlir::Block *, 8> caseDestinations;271llvm::SmallVector<mlir::ValueRange, 8> caseOperands;272273llvm::SmallVector<std::pair<APInt, APInt>> rangeValues;274llvm::SmallVector<mlir::Block *> rangeDestinations;275llvm::SmallVector<mlir::ValueRange> rangeOperands;276277// Initialize default case as optional.278mlir::Block *defaultDestination = exitBlock;279mlir::ValueRange defaultOperands = exitBlock->getArguments();280281// Digest the case statements values and bodies.282for (cir::CaseOp caseOp : cases) {283mlir::Region ®ion = caseOp.getCaseRegion();284285// Found default case: save destination and operands.286switch 
(caseOp.getKind()) {287case cir::CaseOpKind::Default:288defaultDestination = ®ion.front();289defaultOperands = defaultDestination->getArguments();290break;291case cir::CaseOpKind::Range:292assert(caseOp.getValue().size() == 2 &&293"Case range should have 2 case value");294rangeValues.push_back(295{cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(),296cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()});297rangeDestinations.push_back(®ion.front());298rangeOperands.push_back(rangeDestinations.back()->getArguments());299break;300case cir::CaseOpKind::Anyof:301case cir::CaseOpKind::Equal:302// AnyOf cases kind can have multiple values, hence the loop below.303for (const mlir::Attribute &value : caseOp.getValue()) {304caseValues.push_back(cast<cir::IntAttr>(value).getValue());305caseDestinations.push_back(®ion.front());306caseOperands.push_back(caseDestinations.back()->getArguments());307}308break;309}310311// Handle break statements.312walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>(313region, [&](mlir::Operation *op) {314if (!isa<cir::BreakOp>(op))315return mlir::WalkResult::advance();316317lowerTerminator(op, exitBlock, rewriter);318return mlir::WalkResult::skip();319});320321// Track fallthrough in cases.322for (mlir::Block &blk : region.getBlocks()) {323if (blk.getNumSuccessors())324continue;325326if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) {327mlir::Operation *nextOp = caseOp->getNextNode();328assert(nextOp && "caseOp is not expected to be the last op");329mlir::Block *oldBlock = nextOp->getBlock();330mlir::Block *newBlock =331rewriter.splitBlock(oldBlock, nextOp->getIterator());332rewriter.setInsertionPointToEnd(oldBlock);333rewriter.create<cir::BrOp>(nextOp->getLoc(), mlir::ValueRange(),334newBlock);335rewriteYieldOp(rewriter, yieldOp, newBlock);336}337}338339mlir::Block *oldBlock = caseOp->getBlock();340mlir::Block *newBlock =341rewriter.splitBlock(oldBlock, caseOp->getIterator());342343mlir::Block &entryBlock = 
caseOp.getCaseRegion().front();344rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock);345346// Create a branch to the entry of the inlined region.347rewriter.setInsertionPointToEnd(oldBlock);348rewriter.create<cir::BrOp>(caseOp.getLoc(), &entryBlock);349}350351// Remove all cases since we've inlined the regions.352for (cir::CaseOp caseOp : cases) {353mlir::Block *caseBlock = caseOp->getBlock();354// Erase the block with no predecessors here to make the generated code355// simpler a little bit.356if (caseBlock->hasNoPredecessors())357rewriter.eraseBlock(caseBlock);358else359rewriter.eraseOp(caseOp);360}361362for (auto [rangeVal, operand, destination] :363llvm::zip(rangeValues, rangeOperands, rangeDestinations)) {364APInt lowerBound = rangeVal.first;365APInt upperBound = rangeVal.second;366367// The case range is unreachable, skip it.368if (lowerBound.sgt(upperBound))369continue;370371// If range is small, add multiple switch instruction cases.372// This magical number is from the original CGStmt code.373constexpr int kSmallRangeThreshold = 64;374if ((upperBound - lowerBound)375.ult(llvm::APInt(32, kSmallRangeThreshold))) {376for (APInt iValue = lowerBound; iValue.sle(upperBound); ++iValue) {377caseValues.push_back(iValue);378caseOperands.push_back(operand);379caseDestinations.push_back(destination);380}381continue;382}383384defaultDestination =385condBrToRangeDestination(op, rewriter, destination,386defaultDestination, lowerBound, upperBound);387defaultOperands = operand;388}389390// Set switch op to branch to the newly created blocks.391rewriter.setInsertionPoint(op);392rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>(393op, op.getCondition(), defaultDestination, defaultOperands, caseValues,394caseDestinations, caseOperands);395396return mlir::success();397}398};399400class CIRLoopOpInterfaceFlattening401: public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> {402public:403using 
mlir::OpInterfaceRewritePattern<404cir::LoopOpInterface>::OpInterfaceRewritePattern;405406inline void lowerConditionOp(cir::ConditionOp op, mlir::Block *body,407mlir::Block *exit,408mlir::PatternRewriter &rewriter) const {409mlir::OpBuilder::InsertionGuard guard(rewriter);410rewriter.setInsertionPoint(op);411rewriter.replaceOpWithNewOp<cir::BrCondOp>(op, op.getCondition(), body,412exit);413}414415mlir::LogicalResult416matchAndRewrite(cir::LoopOpInterface op,417mlir::PatternRewriter &rewriter) const final {418// Setup CFG blocks.419mlir::Block *entry = rewriter.getInsertionBlock();420mlir::Block *exit =421rewriter.splitBlock(entry, rewriter.getInsertionPoint());422mlir::Block *cond = &op.getCond().front();423mlir::Block *body = &op.getBody().front();424mlir::Block *step =425(op.maybeGetStep() ? &op.maybeGetStep()->front() : nullptr);426427// Setup loop entry branch.428rewriter.setInsertionPointToEnd(entry);429rewriter.create<cir::BrOp>(op.getLoc(), &op.getEntry().front());430431// Branch from condition region to body or exit.432auto conditionOp = cast<cir::ConditionOp>(cond->getTerminator());433lowerConditionOp(conditionOp, body, exit, rewriter);434435// TODO(cir): Remove the walks below. It visits operations unnecessarily.436// However, to solve this we would likely need a custom DialectConversion437// driver to customize the order that operations are visited.438439// Lower continue statements.440mlir::Block *dest = (step ? 
step : cond);441op.walkBodySkippingNestedLoops([&](mlir::Operation *op) {442if (!isa<cir::ContinueOp>(op))443return mlir::WalkResult::advance();444445lowerTerminator(op, dest, rewriter);446return mlir::WalkResult::skip();447});448449// Lower break statements.450assert(!cir::MissingFeatures::switchOp());451walkRegionSkipping<cir::LoopOpInterface>(452op.getBody(), [&](mlir::Operation *op) {453if (!isa<cir::BreakOp>(op))454return mlir::WalkResult::advance();455456lowerTerminator(op, exit, rewriter);457return mlir::WalkResult::skip();458});459460// Lower optional body region yield.461for (mlir::Block &blk : op.getBody().getBlocks()) {462auto bodyYield = dyn_cast<cir::YieldOp>(blk.getTerminator());463if (bodyYield)464lowerTerminator(bodyYield, (step ? step : cond), rewriter);465}466467// Lower mandatory step region yield.468if (step)469lowerTerminator(cast<cir::YieldOp>(step->getTerminator()), cond,470rewriter);471472// Move region contents out of the loop op.473rewriter.inlineRegionBefore(op.getCond(), exit);474rewriter.inlineRegionBefore(op.getBody(), exit);475if (step)476rewriter.inlineRegionBefore(*op.maybeGetStep(), exit);477478rewriter.eraseOp(op);479return mlir::success();480}481};482483class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {484public:485using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;486487mlir::LogicalResult488matchAndRewrite(cir::TernaryOp op,489mlir::PatternRewriter &rewriter) const override {490Location loc = op->getLoc();491Block *condBlock = rewriter.getInsertionBlock();492Block::iterator opPosition = rewriter.getInsertionPoint();493Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);494llvm::SmallVector<mlir::Location, 2> locs;495// Ternary result is optional, make sure to populate the location only496// when relevant.497if (op->getResultTypes().size())498locs.push_back(loc);499Block *continueBlock =500rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), 
locs);501rewriter.create<cir::BrOp>(loc, remainingOpsBlock);502503Region &trueRegion = op.getTrueRegion();504Block *trueBlock = &trueRegion.front();505mlir::Operation *trueTerminator = trueRegion.back().getTerminator();506rewriter.setInsertionPointToEnd(&trueRegion.back());507auto trueYieldOp = dyn_cast<cir::YieldOp>(trueTerminator);508509rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),510continueBlock);511rewriter.inlineRegionBefore(trueRegion, continueBlock);512513Block *falseBlock = continueBlock;514Region &falseRegion = op.getFalseRegion();515516falseBlock = &falseRegion.front();517mlir::Operation *falseTerminator = falseRegion.back().getTerminator();518rewriter.setInsertionPointToEnd(&falseRegion.back());519auto falseYieldOp = dyn_cast<cir::YieldOp>(falseTerminator);520rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(),521continueBlock);522rewriter.inlineRegionBefore(falseRegion, continueBlock);523524rewriter.setInsertionPointToEnd(condBlock);525rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock);526527rewriter.replaceOp(op, continueBlock->getArguments());528529// Ok, we're done!530return mlir::success();531}532};533534void populateFlattenCFGPatterns(RewritePatternSet &patterns) {535patterns536.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening,537CIRSwitchOpFlattening, CIRTernaryOpFlattening>(538patterns.getContext());539}540541void CIRFlattenCFGPass::runOnOperation() {542RewritePatternSet patterns(&getContext());543populateFlattenCFGPatterns(patterns);544545// Collect operations to apply patterns.546llvm::SmallVector<Operation *, 16> ops;547getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {548assert(!cir::MissingFeatures::ifOp());549assert(!cir::MissingFeatures::switchOp());550assert(!cir::MissingFeatures::tryOp());551if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp>(op))552ops.push_back(op);553});554555// Apply patterns.556if 
(applyOpPatternsGreedily(ops, std::move(patterns)).failed())557signalPassFailure();558}559560} // namespace561562namespace mlir {563564std::unique_ptr<Pass> createCIRFlattenCFGPass() {565return std::make_unique<CIRFlattenCFGPass>();566}567568} // namespace mlir569570571