Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
213845 views
1
//===----------------------------------------------------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
9
#include "PassDetail.h"
10
#include "mlir/Dialect/Func/IR/FuncOps.h"
11
#include "mlir/IR/Block.h"
12
#include "mlir/IR/Operation.h"
13
#include "mlir/IR/PatternMatch.h"
14
#include "mlir/IR/Region.h"
15
#include "mlir/Support/LogicalResult.h"
16
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17
#include "clang/CIR/Dialect/IR/CIRDialect.h"
18
#include "clang/CIR/Dialect/Passes.h"
19
#include "llvm/ADT/SmallVector.h"
20
21
using namespace mlir;
22
using namespace cir;
23
24
//===----------------------------------------------------------------------===//
25
// Rewrite patterns
26
//===----------------------------------------------------------------------===//
27
28
namespace {
29
30
/// Simplify suitable ternary operations into select operations.
///
/// For now we only simplify those ternary operations whose true and false
/// branches directly yield a value or a constant. That is, both the true and
/// the false branch must either contain a cir.yield operation as the only
/// operation in the branch, or contain a cir.const operation followed by a
/// cir.yield operation that yields the constant value.
///
/// For example, we will simplify the following ternary operation:
///
///   %0 = ...
///   %1 = cir.ternary (%condition, true {
///     %2 = cir.const ...
///     cir.yield %2
///   } false {
///     cir.yield %0
///   })
///
/// into the following sequence of operations:
///
///   %0 = ...
///   %2 = cir.const ...
///   %1 = cir.select if %condition then %2 else %0
51
struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
  using OpRewritePattern<TernaryOp>::OpRewritePattern;

  /// Replaces a single-result cir.ternary whose two branches are both
  /// "simple" (see isSimpleTernaryBranch) with a cir.select. The branch
  /// bodies (at most one cir.const each) are hoisted in front of the ternary.
  LogicalResult matchAndRewrite(TernaryOp op,
                                PatternRewriter &rewriter) const override {
    // A select produces exactly one value, so nothing else can be rewritten.
    if (op->getNumResults() != 1)
      return mlir::failure();

    mlir::Region &trueRegion = op.getTrueRegion();
    mlir::Region &falseRegion = op.getFalseRegion();
    if (!isSimpleTernaryBranch(trueRegion) ||
        !isSimpleTernaryBranch(falseRegion))
      return mlir::failure();

    mlir::Block *trueBlock = &trueRegion.front();
    mlir::Block *falseBlock = &falseRegion.front();
    auto trueYield = mlir::cast<cir::YieldOp>(trueBlock->getTerminator());
    auto falseYield = mlir::cast<cir::YieldOp>(falseBlock->getTerminator());

    // Grab the yielded values now; the yields are erased below.
    mlir::Value selectedIfTrue = trueYield.getArgs()[0];
    mlir::Value selectedIfFalse = falseYield.getArgs()[0];

    // Move each branch's operations in front of the ternary, then drop the
    // yields, which are meaningless outside their regions.
    rewriter.inlineBlockBefore(trueBlock, op);
    rewriter.inlineBlockBefore(falseBlock, op);
    rewriter.eraseOp(trueYield);
    rewriter.eraseOp(falseYield);

    rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(),
                                               selectedIfTrue, selectedIfFalse);
    return mlir::success();
  }

private:
  /// Returns true when \p region has exactly one block that is either
  ///  - a lone terminator (the yield), or
  ///  - exactly two operations where the yielded value is produced by a
  ///    cir.const living in that same block.
  bool isSimpleTernaryBranch(mlir::Region &region) const {
    if (!region.hasOneBlock())
      return false;

    mlir::Block &block = region.front();
    auto opCount = block.getOperations().size();

    // Only the cir.yield is present.
    if (opCount == 1)
      return true;

    // Anything beyond a constant plus the yield is too complex.
    if (opCount > 2)
      return false;

    auto yield = mlir::cast<cir::YieldOp>(block.getTerminator());
    mlir::Operation *producer = yield.getArgs()[0].getDefiningOp();
    auto constant = mlir::dyn_cast_if_present<cir::ConstantOp>(producer);
    if (!constant)
      return false;
    return constant->getBlock() == &block;
  }
};
105
106
/// Simplify select operations with boolean constants into simpler forms.
107
///
108
/// This pattern simplifies select operations where both true and false values
109
/// are boolean constants. Two specific cases are handled:
110
///
111
/// 1. When selecting between true and false based on a condition,
112
/// the operation simplifies to just the condition itself:
113
///
114
/// %0 = cir.select if %condition then true else false
115
/// ->
116
/// (replaced with %condition directly)
117
///
118
/// 2. When selecting between false and true based on a condition,
119
/// the operation simplifies to the logical negation of the condition:
120
///
121
/// %0 = cir.select if %condition then false else true
122
/// ->
123
/// %0 = cir.unary not %condition
124
struct SimplifySelect : public OpRewritePattern<SelectOp> {
125
using OpRewritePattern<SelectOp>::OpRewritePattern;
126
127
LogicalResult matchAndRewrite(SelectOp op,
128
PatternRewriter &rewriter) const final {
129
mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp();
130
mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp();
131
auto trueValueConstOp =
132
mlir::dyn_cast_if_present<cir::ConstantOp>(trueValueOp);
133
auto falseValueConstOp =
134
mlir::dyn_cast_if_present<cir::ConstantOp>(falseValueOp);
135
if (!trueValueConstOp || !falseValueConstOp)
136
return mlir::failure();
137
138
auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue());
139
auto falseValue =
140
mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue());
141
if (!trueValue || !falseValue)
142
return mlir::failure();
143
144
// cir.select if %0 then #true else #false -> %0
145
if (trueValue.getValue() && !falseValue.getValue()) {
146
rewriter.replaceAllUsesWith(op, op.getCondition());
147
rewriter.eraseOp(op);
148
return mlir::success();
149
}
150
151
// cir.select if %0 then #false else #true -> cir.unary not %0
152
if (!trueValue.getValue() && falseValue.getValue()) {
153
rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
154
op.getCondition());
155
return mlir::success();
156
}
157
158
return mlir::failure();
159
}
160
};
161
162
/// Simplify `cir.switch` operations by folding cascading cases
/// into a single `cir.case` with the `anyof` kind.
///
/// This pattern identifies cascading cases within a `cir.switch` operation.
/// Cascading cases are defined as consecutive `cir.case` operations of kind
/// `equal`, each containing a single `cir.yield` operation in their body.
///
/// The pattern merges these cascading cases into a single `cir.case` operation
/// with kind `anyof`, aggregating all the case values:
///  - If a run of cascading cases is followed by an `equal` case with a
///    different body, the run's values are added to that case, which becomes
///    `anyof`, and the cases of the run are erased.
///  - If a run of at least two cascading cases is followed by a case that is
///    not `equal` (default, anyof or range), the run is folded into its own
///    last case, which keeps its `cir.yield` body.
///  - If every case of the switch is a cascading case, they are all folded
///    into the last one.
///
/// The merging process continues until a `cir.case` with a different body
/// (e.g., containing `cir.break` or a compound statement) is encountered,
/// which breaks the chain.
///
/// Example:
///
/// Before:
///   cir.case equal, [#cir.int<0> : !s32i] {
///     cir.yield
///   }
///   cir.case equal, [#cir.int<1> : !s32i] {
///     cir.yield
///   }
///   cir.case equal, [#cir.int<2> : !s32i] {
///     cir.break
///   }
///
/// After applying SimplifySwitch:
///   cir.case anyof, [#cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<2> :
///   !s32i] {
///     cir.break
///   }
struct SimplifySwitch : public OpRewritePattern<SwitchOp> {
  using OpRewritePattern<SwitchOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(SwitchOp op,
                                PatternRewriter &rewriter) const override {

    // Becomes success as soon as any case is rewritten into `anyof`.
    LogicalResult changed = mlir::failure();
    SmallVector<CaseOp, 8> cases;
    // The current run of cascading (yield-only `equal`) cases and, in the
    // same order, the values they match on.
    SmallVector<CaseOp, 4> cascadingCases;
    SmallVector<mlir::Attribute, 4> cascadingCaseValues;

    op.collectCases(cases);
    if (cases.empty())
      return mlir::failure();

    // Erases every case still pending in the run and resets the run. Only
    // cases that were already visited by the loop below are ever erased here.
    auto flushMergedOps = [&]() {
      for (CaseOp &c : cascadingCases)
        rewriter.eraseOp(c);
      cascadingCases.clear();
      cascadingCaseValues.clear();
    };

    // Turns `target` into an `anyof` case matching every collected value.
    auto mergeCascadingInto = [&](CaseOp &target) {
      rewriter.modifyOpInPlace(target, [&]() {
        target.setValueAttr(rewriter.getArrayAttr(cascadingCaseValues));
        target.setKind(CaseOpKind::Anyof);
      });
      changed = mlir::success();
    };

    for (CaseOp c : cases) {
      cir::CaseOpKind kind = c.getKind();
      if (kind == cir::CaseOpKind::Equal &&
          isa<YieldOp>(c.getCaseRegion().front().front())) {
        // If the case contains only a YieldOp, collect it for cascading merge
        cascadingCases.push_back(c);
        cascadingCaseValues.push_back(c.getValue()[0]);
      } else if (kind == cir::CaseOpKind::Equal && !cascadingCases.empty()) {
        // merge previously collected cascading cases
        // The current case's own value must be part of the merged list too.
        cascadingCaseValues.push_back(c.getValue()[0]);
        mergeCascadingInto(c);
        flushMergedOps();
      } else if (kind != cir::CaseOpKind::Equal && cascadingCases.size() > 1) {
        // If a Default, Anyof or Range case is found and there are previous
        // cascading cases, merge all of them into the last cascading case.
        // We don't currently fold case range statements with other case
        // statements.
        assert(!cir::MissingFeatures::foldRangeCase());
        CaseOp lastCascadingCase = cascadingCases.back();
        mergeCascadingInto(lastCascadingCase);
        // Keep the merge target alive: remove it from the run before the
        // remaining cases of the run are erased.
        cascadingCases.pop_back();
        flushMergedOps();
      } else {
        // Nothing to merge (no run, or a lone cascading case before a
        // non-`equal` case): start over.
        cascadingCases.clear();
        cascadingCaseValues.clear();
      }
    }

    // Edge case: all cases are simple cascading cases
    // Equal sizes imply no flush/reset happened and, since `cases` is
    // non-empty, that `cascadingCases` is non-empty, so back() is safe.
    // NOTE(review): with a single yield-only case this still rewrites it to a
    // one-element `anyof` and reports a change.
    if (cascadingCases.size() == cases.size()) {
      CaseOp lastCascadingCase = cascadingCases.back();
      mergeCascadingInto(lastCascadingCase);
      cascadingCases.pop_back();
      flushMergedOps();
    }

    return changed;
  }
};
262
263
/// Folds a vector splat of an integer or floating-point cir.const into a
/// single cir.const holding a constant vector. Every lane of that vector is
/// the splatted scalar attribute.
struct SimplifyVecSplat : public OpRewritePattern<VecSplatOp> {
  using OpRewritePattern<VecSplatOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(VecSplatOp op,
                                PatternRewriter &rewriter) const override {
    auto scalarConst = mlir::dyn_cast_if_present<cir::ConstantOp>(
        op.getValue().getDefiningOp());
    if (!scalarConst)
      return mlir::failure();

    // Only #cir.int and #cir.fp scalars are folded.
    mlir::Attribute lane = scalarConst.getValue();
    if (!mlir::isa_and_nonnull<cir::IntAttr, cir::FPAttr>(lane))
      return mlir::failure();

    cir::VectorType vecTy = op.getResult().getType();
    llvm::SmallVector<mlir::Attribute, 16> lanes(vecTy.getSize(), lane);
    mlir::ArrayAttr laneArray = rewriter.getArrayAttr(lanes);

    rewriter.replaceOpWithNewOp<cir::ConstantOp>(
        op, cir::ConstVectorAttr::get(vecTy, laneArray));
    return mlir::success();
  }
};
287
288
//===----------------------------------------------------------------------===//
289
// CIRSimplifyPass
290
//===----------------------------------------------------------------------===//
291
292
/// Pass that applies this file's simplification patterns. Its constructors
/// are inherited unchanged from CIRSimplifyBase.
struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {
  using CIRSimplifyBase<CIRSimplifyPass>::CIRSimplifyBase;

  void runOnOperation() override;
};
297
298
/// Adds every simplification pattern defined in this file to \p patterns,
/// constructing them with the pattern set's own MLIRContext.
void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
  mlir::MLIRContext *ctx = patterns.getContext();
  patterns.add<SimplifyTernary, SimplifySelect, SimplifySwitch,
               SimplifyVecSplat>(ctx);
}
308
309
void CIRSimplifyPass::runOnOperation() {
310
// Collect rewrite patterns.
311
RewritePatternSet patterns(&getContext());
312
populateMergeCleanupPatterns(patterns);
313
314
// Collect operations to apply patterns.
315
llvm::SmallVector<Operation *, 16> ops;
316
getOperation()->walk([&](Operation *op) {
317
if (isa<TernaryOp, SelectOp, SwitchOp, VecSplatOp>(op))
318
ops.push_back(op);
319
});
320
321
// Apply patterns.
322
if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
323
signalPassFailure();
324
}
325
326
} // namespace
327
328
std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
329
return std::make_unique<CIRSimplifyPass>();
330
}
331
332