Path: blob/main/contrib/llvm-project/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
213845 views
//===----------------------------------------------------------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//78#include "PassDetail.h"9#include "mlir/Dialect/Func/IR/FuncOps.h"10#include "mlir/IR/Block.h"11#include "mlir/IR/Operation.h"12#include "mlir/IR/PatternMatch.h"13#include "mlir/IR/Region.h"14#include "mlir/Support/LogicalResult.h"15#include "mlir/Transforms/GreedyPatternRewriteDriver.h"16#include "clang/CIR/Dialect/IR/CIRDialect.h"17#include "clang/CIR/Dialect/Passes.h"18#include "llvm/ADT/SmallVector.h"1920using namespace mlir;21using namespace cir;2223//===----------------------------------------------------------------------===//24// Rewrite patterns25//===----------------------------------------------------------------------===//2627namespace {2829/// Simplify suitable ternary operations into select operations.30///31/// For now we only simplify those ternary operations whose true and false32/// branches directly yield a value or a constant. That is, both of the true and33/// the false branch must either contain a cir.yield operation as the only34/// operation in the branch, or contain a cir.const operation followed by a35/// cir.yield operation that yields the constant value.36///37/// For example, we will simplify the following ternary operation:38///39/// %0 = ...40/// %1 = cir.ternary (%condition, true {41/// %2 = cir.const ...42/// cir.yield %243/// } false {44/// cir.yield %045///46/// into the following sequence of operations:47///48/// %1 = cir.const ...49/// %0 = cir.select if %condition then %1 else %250struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {51using OpRewritePattern<TernaryOp>::OpRewritePattern;5253LogicalResult matchAndRewrite(TernaryOp op,54PatternRewriter &rewriter) const override {55if (op->getNumResults() != 1)56return mlir::failure();5758if (!isSimpleTernaryBranch(op.getTrueRegion()) ||59!isSimpleTernaryBranch(op.getFalseRegion()))60return mlir::failure();6162cir::YieldOp trueBranchYieldOp =63mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());64cir::YieldOp falseBranchYieldOp =65mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());66mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];67mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];6869rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);70rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);71rewriter.eraseOp(trueBranchYieldOp);72rewriter.eraseOp(falseBranchYieldOp);73rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,74falseValue);7576return mlir::success();77}7879private:80bool isSimpleTernaryBranch(mlir::Region ®ion) const {81if (!region.hasOneBlock())82return false;8384mlir::Block &onlyBlock = region.front();85mlir::Block::OpListType &ops = onlyBlock.getOperations();8687// The region/block could only contain at most 2 operations.88if (ops.size() > 2)89return false;9091if (ops.size() == 1) {92// The region/block only contain a cir.yield operation.93return true;94}9596// Check whether the region/block contains a cir.const followed by a97// cir.yield that yields the value.98auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());99auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>(100yieldOp.getArgs()[0].getDefiningOp());101return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;102}103};104105/// Simplify select operations with boolean constants into simpler forms.106///107/// This pattern simplifies select operations where both true and false values108/// are boolean constants. Two specific cases are handled:109///110/// 1. When selecting between true and false based on a condition,111/// the operation simplifies to just the condition itself:112///113/// %0 = cir.select if %condition then true else false114/// ->115/// (replaced with %condition directly)116///117/// 2. When selecting between false and true based on a condition,118/// the operation simplifies to the logical negation of the condition:119///120/// %0 = cir.select if %condition then false else true121/// ->122/// %0 = cir.unary not %condition123struct SimplifySelect : public OpRewritePattern<SelectOp> {124using OpRewritePattern<SelectOp>::OpRewritePattern;125126LogicalResult matchAndRewrite(SelectOp op,127PatternRewriter &rewriter) const final {128mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp();129mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp();130auto trueValueConstOp =131mlir::dyn_cast_if_present<cir::ConstantOp>(trueValueOp);132auto falseValueConstOp =133mlir::dyn_cast_if_present<cir::ConstantOp>(falseValueOp);134if (!trueValueConstOp || !falseValueConstOp)135return mlir::failure();136137auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue());138auto falseValue =139mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue());140if (!trueValue || !falseValue)141return mlir::failure();142143// cir.select if %0 then #true else #false -> %0144if (trueValue.getValue() && !falseValue.getValue()) {145rewriter.replaceAllUsesWith(op, op.getCondition());146rewriter.eraseOp(op);147return mlir::success();148}149150// cir.select if %0 then #false else #true -> cir.unary not %0151if (!trueValue.getValue() && falseValue.getValue()) {152rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,153op.getCondition());154return mlir::success();155}156157return mlir::failure();158}159};160161/// Simplify `cir.switch` operations by folding cascading cases162/// into a single `cir.case` with the `anyof` kind.163///164/// This pattern identifies cascading cases within a `cir.switch` operation.165/// Cascading cases are defined as consecutive `cir.case` operations of kind166/// `equal`, each containing a single `cir.yield` operation in their body.167///168/// The pattern merges these cascading cases into a single `cir.case` operation169/// with kind `anyof`, aggregating all the case values.170///171/// The merging process continues until a `cir.case` with a different body172/// (e.g., containing `cir.break` or compound stmt) is encountered, which173/// breaks the chain.174///175/// Example:176///177/// Before:178/// cir.case equal, [#cir.int<0> : !s32i] {179/// cir.yield180/// }181/// cir.case equal, [#cir.int<1> : !s32i] {182/// cir.yield183/// }184/// cir.case equal, [#cir.int<2> : !s32i] {185/// cir.break186/// }187///188/// After applying SimplifySwitch:189/// cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :190/// !s32i] {191/// cir.break192/// }193struct SimplifySwitch : public OpRewritePattern<SwitchOp> {194using OpRewritePattern<SwitchOp>::OpRewritePattern;195LogicalResult matchAndRewrite(SwitchOp op,196PatternRewriter &rewriter) const override {197198LogicalResult changed = mlir::failure();199SmallVector<CaseOp, 8> cases;200SmallVector<CaseOp, 4> cascadingCases;201SmallVector<mlir::Attribute, 4> cascadingCaseValues;202203op.collectCases(cases);204if (cases.empty())205return mlir::failure();206207auto flushMergedOps = [&]() {208for (CaseOp &c : cascadingCases)209rewriter.eraseOp(c);210cascadingCases.clear();211cascadingCaseValues.clear();212};213214auto mergeCascadingInto = [&](CaseOp &target) {215rewriter.modifyOpInPlace(target, [&]() {216target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));217target.setKind(CaseOpKind::Anyof);218});219changed = mlir::success();220};221222for (CaseOp c : cases) {223cir::CaseOpKind kind = c.getKind();224if (kind == cir::CaseOpKind::Equal &&225isa<YieldOp>(c.getCaseRegion().front().front())) {226// If the case contains only a YieldOp, collect it for cascading merge227cascadingCases.push_back(c);228cascadingCaseValues.push_back(c.getValue()[0]);229} else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {230// merge previously collected cascading cases231cascadingCaseValues.push_back(c.getValue()[0]);232mergeCascadingInto(c);233flushMergedOps();234} else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {235// If a Default, Anyof or Range case is found and there are previous236// cascading cases, merge all of them into the last cascading case.237// We don't currently fold case range statements with other case238// statements.239assert(!cir::MissingFeatures::foldRangeCase());240CaseOp lastCascadingCase = cascadingCases.back();241mergeCascadingInto(lastCascadingCase);242cascadingCases.pop_back();243flushMergedOps();244} else {245cascadingCases.clear();246cascadingCaseValues.clear();247}248}249250// Edge case: all cases are simple cascading cases251if (cascadingCases.size() == cases.size()) {252CaseOp lastCascadingCase = cascadingCases.back();253mergeCascadingInto(lastCascadingCase);254cascadingCases.pop_back();255flushMergedOps();256}257258return changed;259}260};261262struct SimplifyVecSplat : public OpRewritePattern<VecSplatOp> {263using OpRewritePattern<VecSplatOp>::OpRewritePattern;264LogicalResult matchAndRewrite(VecSplatOp op,265PatternRewriter &rewriter) const override {266mlir::Value splatValue = op.getValue();267auto constant =268mlir::dyn_cast_if_present<cir::ConstantOp>(splatValue.getDefiningOp());269if (!constant)270return mlir::failure();271272auto value = constant.getValue();273if (!mlir::isa_and_nonnull<cir::IntAttr>(value) &&274!mlir::isa_and_nonnull<cir::FPAttr>(value))275return mlir::failure();276277cir::VectorType resultType = op.getResult().getType();278SmallVector<mlir::Attribute, 16> elements(resultType.getSize(), value);279auto constVecAttr = cir::ConstVectorAttr::get(280resultType, mlir::ArrayAttr::get(getContext(), elements));281282rewriter.replaceOpWithNewOp<cir::ConstantOp>(op, constVecAttr);283return mlir::success();284}285};286287//===----------------------------------------------------------------------===//288// CIRSimplifyPass289//===----------------------------------------------------------------------===//290291struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {292using CIRSimplifyBase::CIRSimplifyBase;293294void runOnOperation() override;295};296297void populateMergeCleanupPatterns(RewritePatternSet &patterns) {298// clang-format off299patterns.add<300SimplifyTernary,301SimplifySelect,302SimplifySwitch,303SimplifyVecSplat304>(patterns.getContext());305// clang-format on306}307308void CIRSimplifyPass::runOnOperation() {309// Collect rewrite patterns.310RewritePatternSet patterns(&getContext());311populateMergeCleanupPatterns(patterns);312313// Collect operations to apply patterns.314llvm::SmallVector<Operation *, 16> ops;315getOperation()->walk([&](Operation *op) {316if (isa<TernaryOp, SelectOp, SwitchOp, VecSplatOp>(op))317ops.push_back(op);318});319320// Apply patterns.321if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())322signalPassFailure();323}324325} // namespace326327std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {328return std::make_unique<CIRSimplifyPass>();329}330331332