Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
213845 views
1
//====- LowerToLLVM.cpp - Lowering from CIR to LLVMIR ---------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file implements lowering of CIR operations to LLVMIR.
10
//
11
//===----------------------------------------------------------------------===//
12
13
#include "LowerToLLVM.h"
14
15
#include <deque>
16
#include <optional>
17
18
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
19
#include "mlir/Dialect/DLTI/DLTI.h"
20
#include "mlir/Dialect/Func/IR/FuncOps.h"
21
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
22
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
23
#include "mlir/IR/BuiltinAttributes.h"
24
#include "mlir/IR/BuiltinDialect.h"
25
#include "mlir/IR/BuiltinOps.h"
26
#include "mlir/IR/Types.h"
27
#include "mlir/Pass/Pass.h"
28
#include "mlir/Pass/PassManager.h"
29
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
30
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
31
#include "mlir/Target/LLVMIR/Export.h"
32
#include "mlir/Transforms/DialectConversion.h"
33
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
34
#include "clang/CIR/Dialect/IR/CIRDialect.h"
35
#include "clang/CIR/Dialect/Passes.h"
36
#include "clang/CIR/LoweringHelpers.h"
37
#include "clang/CIR/MissingFeatures.h"
38
#include "clang/CIR/Passes.h"
39
#include "llvm/ADT/TypeSwitch.h"
40
#include "llvm/IR/Module.h"
41
#include "llvm/Support/ErrorHandling.h"
42
#include "llvm/Support/TimeProfiler.h"
43
44
using namespace cir;
45
using namespace llvm;
46
47
namespace cir {
48
namespace direct {
49
50
//===----------------------------------------------------------------------===//
51
// Helper Methods
52
//===----------------------------------------------------------------------===//
53
54
namespace {
55
/// If the given type is a vector type, return the vector's element type.
56
/// Otherwise return the given type unchanged.
57
mlir::Type elementTypeIfVector(mlir::Type type) {
58
return llvm::TypeSwitch<mlir::Type, mlir::Type>(type)
59
.Case<cir::VectorType, mlir::VectorType>(
60
[](auto p) { return p.getElementType(); })
61
.Default([](mlir::Type p) { return p; });
62
}
63
} // namespace
64
65
/// Given a type convertor and a data layout, convert the given type to a type
66
/// that is suitable for memory operations. For example, this can be used to
67
/// lower cir.bool accesses to i8.
68
static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter,
69
mlir::DataLayout const &dataLayout,
70
mlir::Type type) {
71
// TODO(cir): Handle other types similarly to clang's codegen
72
// convertTypeForMemory
73
if (isa<cir::BoolType>(type)) {
74
return mlir::IntegerType::get(type.getContext(),
75
dataLayout.getTypeSizeInBits(type));
76
}
77
78
return converter.convertType(type);
79
}
80
81
/// Build the LLVM integer cast that converts \p src to the integer type
/// \p dstTy.
///
/// - Widening uses sext when \p isSigned is set and zext otherwise.
/// - Narrowing uses trunc.
/// - Equal widths produce a bitcast.
///
/// \p src must already have a builtin integer type.
static mlir::Value createIntCast(mlir::OpBuilder &bld, mlir::Value src,
                                 mlir::IntegerType dstTy,
                                 bool isSigned = false) {
  mlir::Type srcTy = src.getType();
  assert(mlir::isa<mlir::IntegerType>(srcTy));

  const unsigned fromBits = mlir::cast<mlir::IntegerType>(srcTy).getWidth();
  const unsigned toBits = dstTy.getWidth();
  const mlir::Location loc = src.getLoc();

  if (toBits < fromBits)
    return bld.create<mlir::LLVM::TruncOp>(loc, dstTy, src);
  if (toBits == fromBits)
    return bld.create<mlir::LLVM::BitcastOp>(loc, dstTy, src);
  if (isSigned)
    return bld.create<mlir::LLVM::SExtOp>(loc, dstTy, src);
  return bld.create<mlir::LLVM::ZExtOp>(loc, dstTy, src);
}
99
100
/// Map a CIR symbol visibility kind onto the equivalent LLVM dialect
/// visibility.
static mlir::LLVM::Visibility
lowerCIRVisibilityToLLVMVisibility(cir::VisibilityKind visibilityKind) {
  switch (visibilityKind) {
  case cir::VisibilityKind::Default:
    return ::mlir::LLVM::Visibility::Default;
  case cir::VisibilityKind::Hidden:
    return ::mlir::LLVM::Visibility::Hidden;
  case cir::VisibilityKind::Protected:
    return ::mlir::LLVM::Visibility::Protected;
  }
  // The switch covers every enumerator, so reaching this point means an
  // out-of-range value was passed in. Without this statement, control would
  // fall off the end of a non-void function. That is undefined behavior, and
  // GCC warns about it with -Wreturn-type.
  llvm_unreachable("unknown CIR visibility kind");
}
111
112
/// Emits the value from memory as expected by its users. Should be called when
113
/// the memory represetnation of a CIR type is not equal to its scalar
114
/// representation.
115
static mlir::Value emitFromMemory(mlir::ConversionPatternRewriter &rewriter,
116
mlir::DataLayout const &dataLayout,
117
cir::LoadOp op, mlir::Value value) {
118
119
// TODO(cir): Handle other types similarly to clang's codegen EmitFromMemory
120
if (auto boolTy = mlir::dyn_cast<cir::BoolType>(op.getType())) {
121
// Create a cast value from specified size in datalayout to i1
122
assert(value.getType().isInteger(dataLayout.getTypeSizeInBits(boolTy)));
123
return createIntCast(rewriter, value, rewriter.getI1Type());
124
}
125
126
return value;
127
}
128
129
/// Emits a value to memory with the expected scalar type. Should be called when
130
/// the memory represetnation of a CIR type is not equal to its scalar
131
/// representation.
132
static mlir::Value emitToMemory(mlir::ConversionPatternRewriter &rewriter,
133
mlir::DataLayout const &dataLayout,
134
mlir::Type origType, mlir::Value value) {
135
136
// TODO(cir): Handle other types similarly to clang's codegen EmitToMemory
137
if (auto boolTy = mlir::dyn_cast<cir::BoolType>(origType)) {
138
// Create zext of value from i1 to i8
139
mlir::IntegerType memType =
140
rewriter.getIntegerType(dataLayout.getTypeSizeInBits(boolTy));
141
return createIntCast(rewriter, value, memType);
142
}
143
144
return value;
145
}
146
147
/// Translate a CIR global linkage kind into the LLVM dialect linkage enum.
mlir::LLVM::Linkage convertLinkage(cir::GlobalLinkageKind linkage) {
  switch (linkage) {
  case cir::GlobalLinkageKind::ExternalLinkage:
    return mlir::LLVM::Linkage::External;
  case cir::GlobalLinkageKind::ExternalWeakLinkage:
    return mlir::LLVM::Linkage::ExternWeak;
  case cir::GlobalLinkageKind::AvailableExternallyLinkage:
    return mlir::LLVM::Linkage::AvailableExternally;
  case cir::GlobalLinkageKind::CommonLinkage:
    return mlir::LLVM::Linkage::Common;
  case cir::GlobalLinkageKind::InternalLinkage:
    return mlir::LLVM::Linkage::Internal;
  case cir::GlobalLinkageKind::PrivateLinkage:
    return mlir::LLVM::Linkage::Private;
  case cir::GlobalLinkageKind::LinkOnceAnyLinkage:
    return mlir::LLVM::Linkage::Linkonce;
  case cir::GlobalLinkageKind::LinkOnceODRLinkage:
    return mlir::LLVM::Linkage::LinkonceODR;
  case cir::GlobalLinkageKind::WeakAnyLinkage:
    return mlir::LLVM::Linkage::Weak;
  case cir::GlobalLinkageKind::WeakODRLinkage:
    return mlir::LLVM::Linkage::WeakODR;
  }
  llvm_unreachable("Unknown CIR linkage type");
}
175
176
/// Cast an already-lowered integer \p llvmSrc to \p llvmDstIntTy.
///
/// The choice of cast uses the widths of the original CIR source and
/// destination integer types:
/// - Equal widths return \p llvmSrc as is (no op is created).
/// - Widening zero-extends when \p isUnsigned is set and sign-extends
///   otherwise.
/// - Narrowing truncates.
static mlir::Value getLLVMIntCast(mlir::ConversionPatternRewriter &rewriter,
                                  mlir::Value llvmSrc, mlir::Type llvmDstIntTy,
                                  bool isUnsigned, uint64_t cirSrcWidth,
                                  uint64_t cirDstIntWidth) {
  if (cirSrcWidth == cirDstIntWidth)
    return llvmSrc;

  const mlir::Location loc = llvmSrc.getLoc();
  if (cirSrcWidth > cirDstIntWidth)
    return rewriter.create<mlir::LLVM::TruncOp>(loc, llvmDstIntTy, llvmSrc);

  if (!isUnsigned)
    return rewriter.create<mlir::LLVM::SExtOp>(loc, llvmDstIntTy, llvmSrc);
  return rewriter.create<mlir::LLVM::ZExtOp>(loc, llvmDstIntTy, llvmSrc);
}
193
194
class CIRAttrToValue {
195
public:
196
CIRAttrToValue(mlir::Operation *parentOp,
197
mlir::ConversionPatternRewriter &rewriter,
198
const mlir::TypeConverter *converter)
199
: parentOp(parentOp), rewriter(rewriter), converter(converter) {}
200
201
mlir::Value visit(mlir::Attribute attr) {
202
return llvm::TypeSwitch<mlir::Attribute, mlir::Value>(attr)
203
.Case<cir::IntAttr, cir::FPAttr, cir::ConstComplexAttr,
204
cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
205
cir::ZeroAttr>([&](auto attrT) { return visitCirAttr(attrT); })
206
.Default([&](auto attrT) { return mlir::Value(); });
207
}
208
209
mlir::Value visitCirAttr(cir::IntAttr intAttr);
210
mlir::Value visitCirAttr(cir::FPAttr fltAttr);
211
mlir::Value visitCirAttr(cir::ConstComplexAttr complexAttr);
212
mlir::Value visitCirAttr(cir::ConstPtrAttr ptrAttr);
213
mlir::Value visitCirAttr(cir::ConstArrayAttr attr);
214
mlir::Value visitCirAttr(cir::ConstVectorAttr attr);
215
mlir::Value visitCirAttr(cir::ZeroAttr attr);
216
217
private:
218
mlir::Operation *parentOp;
219
mlir::ConversionPatternRewriter &rewriter;
220
const mlir::TypeConverter *converter;
221
};
222
223
/// Switches on the type of attribute and calls the appropriate conversion.
224
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp,
225
const mlir::Attribute attr,
226
mlir::ConversionPatternRewriter &rewriter,
227
const mlir::TypeConverter *converter) {
228
CIRAttrToValue valueConverter(parentOp, rewriter, converter);
229
mlir::Value value = valueConverter.visit(attr);
230
if (!value)
231
llvm_unreachable("unhandled attribute type");
232
return value;
233
}
234
235
void convertSideEffectForCall(mlir::Operation *callOp, bool isNothrow,
236
cir::SideEffect sideEffect,
237
mlir::LLVM::MemoryEffectsAttr &memoryEffect,
238
bool &noUnwind, bool &willReturn) {
239
using mlir::LLVM::ModRefInfo;
240
241
switch (sideEffect) {
242
case cir::SideEffect::All:
243
memoryEffect = {};
244
noUnwind = isNothrow;
245
willReturn = false;
246
break;
247
248
case cir::SideEffect::Pure:
249
memoryEffect = mlir::LLVM::MemoryEffectsAttr::get(
250
callOp->getContext(), /*other=*/ModRefInfo::Ref,
251
/*argMem=*/ModRefInfo::Ref,
252
/*inaccessibleMem=*/ModRefInfo::Ref);
253
noUnwind = true;
254
willReturn = true;
255
break;
256
257
case cir::SideEffect::Const:
258
memoryEffect = mlir::LLVM::MemoryEffectsAttr::get(
259
callOp->getContext(), /*other=*/ModRefInfo::NoModRef,
260
/*argMem=*/ModRefInfo::NoModRef,
261
/*inaccessibleMem=*/ModRefInfo::NoModRef);
262
noUnwind = true;
263
willReturn = true;
264
break;
265
}
266
}
267
268
/// IntAttr visitor.
269
mlir::Value CIRAttrToValue::visitCirAttr(cir::IntAttr intAttr) {
270
mlir::Location loc = parentOp->getLoc();
271
return rewriter.create<mlir::LLVM::ConstantOp>(
272
loc, converter->convertType(intAttr.getType()), intAttr.getValue());
273
}
274
275
/// FPAttr visitor.
276
mlir::Value CIRAttrToValue::visitCirAttr(cir::FPAttr fltAttr) {
277
mlir::Location loc = parentOp->getLoc();
278
return rewriter.create<mlir::LLVM::ConstantOp>(
279
loc, converter->convertType(fltAttr.getType()), fltAttr.getValue());
280
}
281
282
/// ConstComplexAttr visitor.
283
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstComplexAttr complexAttr) {
284
auto complexType = mlir::cast<cir::ComplexType>(complexAttr.getType());
285
mlir::Type complexElemTy = complexType.getElementType();
286
mlir::Type complexElemLLVMTy = converter->convertType(complexElemTy);
287
288
mlir::Attribute components[2];
289
if (const auto intType = mlir::dyn_cast<cir::IntType>(complexElemTy)) {
290
components[0] = rewriter.getIntegerAttr(
291
complexElemLLVMTy,
292
mlir::cast<cir::IntAttr>(complexAttr.getReal()).getValue());
293
components[1] = rewriter.getIntegerAttr(
294
complexElemLLVMTy,
295
mlir::cast<cir::IntAttr>(complexAttr.getImag()).getValue());
296
} else {
297
components[0] = rewriter.getFloatAttr(
298
complexElemLLVMTy,
299
mlir::cast<cir::FPAttr>(complexAttr.getReal()).getValue());
300
components[1] = rewriter.getFloatAttr(
301
complexElemLLVMTy,
302
mlir::cast<cir::FPAttr>(complexAttr.getImag()).getValue());
303
}
304
305
mlir::Location loc = parentOp->getLoc();
306
return rewriter.create<mlir::LLVM::ConstantOp>(
307
loc, converter->convertType(complexAttr.getType()),
308
rewriter.getArrayAttr(components));
309
}
310
311
/// ConstPtrAttr visitor.
312
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstPtrAttr ptrAttr) {
313
mlir::Location loc = parentOp->getLoc();
314
if (ptrAttr.isNullValue()) {
315
return rewriter.create<mlir::LLVM::ZeroOp>(
316
loc, converter->convertType(ptrAttr.getType()));
317
}
318
mlir::DataLayout layout(parentOp->getParentOfType<mlir::ModuleOp>());
319
mlir::Value ptrVal = rewriter.create<mlir::LLVM::ConstantOp>(
320
loc, rewriter.getIntegerType(layout.getTypeSizeInBits(ptrAttr.getType())),
321
ptrAttr.getValue().getInt());
322
return rewriter.create<mlir::LLVM::IntToPtrOp>(
323
loc, converter->convertType(ptrAttr.getType()), ptrVal);
324
}
325
326
// ConstArrayAttr visitor
327
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstArrayAttr attr) {
328
mlir::Type llvmTy = converter->convertType(attr.getType());
329
mlir::Location loc = parentOp->getLoc();
330
mlir::Value result;
331
332
if (attr.hasTrailingZeros()) {
333
mlir::Type arrayTy = attr.getType();
334
result = rewriter.create<mlir::LLVM::ZeroOp>(
335
loc, converter->convertType(arrayTy));
336
} else {
337
result = rewriter.create<mlir::LLVM::UndefOp>(loc, llvmTy);
338
}
339
340
// Iteratively lower each constant element of the array.
341
if (auto arrayAttr = mlir::dyn_cast<mlir::ArrayAttr>(attr.getElts())) {
342
for (auto [idx, elt] : llvm::enumerate(arrayAttr)) {
343
mlir::DataLayout dataLayout(parentOp->getParentOfType<mlir::ModuleOp>());
344
mlir::Value init = visit(elt);
345
result =
346
rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
347
}
348
} else if (auto strAttr = mlir::dyn_cast<mlir::StringAttr>(attr.getElts())) {
349
// TODO(cir): this diverges from traditional lowering. Normally the string
350
// would be a global constant that is memcopied.
351
auto arrayTy = mlir::dyn_cast<cir::ArrayType>(strAttr.getType());
352
assert(arrayTy && "String attribute must have an array type");
353
mlir::Type eltTy = arrayTy.getElementType();
354
for (auto [idx, elt] : llvm::enumerate(strAttr)) {
355
auto init = rewriter.create<mlir::LLVM::ConstantOp>(
356
loc, converter->convertType(eltTy), elt);
357
result =
358
rewriter.create<mlir::LLVM::InsertValueOp>(loc, result, init, idx);
359
}
360
} else {
361
llvm_unreachable("unexpected ConstArrayAttr elements");
362
}
363
364
return result;
365
}
366
367
/// ConstVectorAttr visitor.
368
mlir::Value CIRAttrToValue::visitCirAttr(cir::ConstVectorAttr attr) {
369
const mlir::Type llvmTy = converter->convertType(attr.getType());
370
const mlir::Location loc = parentOp->getLoc();
371
372
SmallVector<mlir::Attribute> mlirValues;
373
for (const mlir::Attribute elementAttr : attr.getElts()) {
374
mlir::Attribute mlirAttr;
375
if (auto intAttr = mlir::dyn_cast<cir::IntAttr>(elementAttr)) {
376
mlirAttr = rewriter.getIntegerAttr(
377
converter->convertType(intAttr.getType()), intAttr.getValue());
378
} else if (auto floatAttr = mlir::dyn_cast<cir::FPAttr>(elementAttr)) {
379
mlirAttr = rewriter.getFloatAttr(
380
converter->convertType(floatAttr.getType()), floatAttr.getValue());
381
} else {
382
llvm_unreachable(
383
"vector constant with an element that is neither an int nor a float");
384
}
385
mlirValues.push_back(mlirAttr);
386
}
387
388
return rewriter.create<mlir::LLVM::ConstantOp>(
389
loc, llvmTy,
390
mlir::DenseElementsAttr::get(mlir::cast<mlir::ShapedType>(llvmTy),
391
mlirValues));
392
}
393
394
/// ZeroAttr visitor.
395
mlir::Value CIRAttrToValue::visitCirAttr(cir::ZeroAttr attr) {
396
mlir::Location loc = parentOp->getLoc();
397
return rewriter.create<mlir::LLVM::ZeroOp>(
398
loc, converter->convertType(attr.getType()));
399
}
400
401
// This class handles rewriting initializer attributes for types that do not
402
// require region initialization.
403
class GlobalInitAttrRewriter {
404
public:
405
GlobalInitAttrRewriter(mlir::Type type,
406
mlir::ConversionPatternRewriter &rewriter)
407
: llvmType(type), rewriter(rewriter) {}
408
409
mlir::Attribute visit(mlir::Attribute attr) {
410
return llvm::TypeSwitch<mlir::Attribute, mlir::Attribute>(attr)
411
.Case<cir::IntAttr, cir::FPAttr, cir::BoolAttr>(
412
[&](auto attrT) { return visitCirAttr(attrT); })
413
.Default([&](auto attrT) { return mlir::Attribute(); });
414
}
415
416
mlir::Attribute visitCirAttr(cir::IntAttr attr) {
417
return rewriter.getIntegerAttr(llvmType, attr.getValue());
418
}
419
420
mlir::Attribute visitCirAttr(cir::FPAttr attr) {
421
return rewriter.getFloatAttr(llvmType, attr.getValue());
422
}
423
424
mlir::Attribute visitCirAttr(cir::BoolAttr attr) {
425
return rewriter.getBoolAttr(attr.getValue());
426
}
427
428
private:
429
mlir::Type llvmType;
430
mlir::ConversionPatternRewriter &rewriter;
431
};
432
433
// This pass requires the CIR to be in a "flat" state. All blocks in each
434
// function must belong to the parent region. Once scopes and control flow
435
// are implemented in CIR, a pass will be run before this one to flatten
436
// the CIR and get it into the state that this pass requires.
437
struct ConvertCIRToLLVMPass
438
: public mlir::PassWrapper<ConvertCIRToLLVMPass,
439
mlir::OperationPass<mlir::ModuleOp>> {
440
void getDependentDialects(mlir::DialectRegistry &registry) const override {
441
registry.insert<mlir::BuiltinDialect, mlir::DLTIDialect,
442
mlir::LLVM::LLVMDialect, mlir::func::FuncDialect>();
443
}
444
void runOnOperation() final;
445
446
void processCIRAttrs(mlir::ModuleOp module);
447
448
StringRef getDescription() const override {
449
return "Convert the prepared CIR dialect module to LLVM dialect";
450
}
451
452
StringRef getArgument() const override { return "cir-flat-to-llvm"; }
453
};
454
455
/// Lower cir::AssumeOp to mlir::LLVM::AssumeOp on the lowered predicate.
mlir::LogicalResult CIRToLLVMAssumeOpLowering::matchAndRewrite(
    cir::AssumeOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  rewriter.replaceOpWithNewOp<mlir::LLVM::AssumeOp>(op, adaptor.getPredicate());
  return mlir::success();
}
462
463
/// Lower cir::BitClrsbOp: the count of leading bits equal to the sign bit,
/// excluding the sign bit itself.
///
/// Computed as clz(x < 0 ? x ^ -1 : x) - 1, with a ctlz that is defined
/// (not poison) on zero.
mlir::LogicalResult CIRToLLVMBitClrsbOpLowering::matchAndRewrite(
    cir::BitClrsbOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  const mlir::Location loc = op.getLoc();
  mlir::Value input = adaptor.getInput();
  mlir::Type inputTy = input.getType();

  auto zero = rewriter.create<mlir::LLVM::ConstantOp>(loc, inputTy, 0);
  auto isNeg = rewriter.create<mlir::LLVM::ICmpOp>(
      loc,
      mlir::LLVM::ICmpPredicateAttr::get(rewriter.getContext(),
                                         mlir::LLVM::ICmpPredicate::slt),
      input, zero);

  // Invert negative inputs so the redundant sign bits become leading zeros.
  auto allOnes = rewriter.create<mlir::LLVM::ConstantOp>(loc, inputTy, -1);
  auto inverted = rewriter.create<mlir::LLVM::XOrOp>(loc, input, allOnes);
  auto normalized =
      rewriter.create<mlir::LLVM::SelectOp>(loc, isNeg, inverted, input);

  mlir::Type resTy = getTypeConverter()->convertType(op.getType());
  auto leadingZeros = rewriter.create<mlir::LLVM::CountLeadingZerosOp>(
      loc, resTy, normalized, /*is_zero_poison=*/false);

  // Do not count the sign bit itself.
  auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, resTy, 1);
  rewriter.replaceOp(op,
                     rewriter.create<mlir::LLVM::SubOp>(loc, leadingZeros, one));
  return mlir::LogicalResult::success();
}
492
493
/// Lower cir::BitClzOp to mlir::LLVM::CountLeadingZerosOp, forwarding the
/// op's poison-on-zero flag.
mlir::LogicalResult CIRToLLVMBitClzOpLowering::matchAndRewrite(
    cir::BitClzOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Type llvmResTy = getTypeConverter()->convertType(op.getType());
  rewriter.replaceOpWithNewOp<mlir::LLVM::CountLeadingZerosOp>(
      op, llvmResTy, adaptor.getInput(), op.getPoisonZero());
  return mlir::LogicalResult::success();
}
502
503
/// Lower cir::BitCtzOp to mlir::LLVM::CountTrailingZerosOp, forwarding the
/// op's poison-on-zero flag.
mlir::LogicalResult CIRToLLVMBitCtzOpLowering::matchAndRewrite(
    cir::BitCtzOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Type llvmResTy = getTypeConverter()->convertType(op.getType());
  rewriter.replaceOpWithNewOp<mlir::LLVM::CountTrailingZerosOp>(
      op, llvmResTy, adaptor.getInput(), op.getPoisonZero());
  return mlir::LogicalResult::success();
}
512
513
/// Lower cir::BitParityOp as popcount(x) & 1.
mlir::LogicalResult CIRToLLVMBitParityOpLowering::matchAndRewrite(
    cir::BitParityOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  const mlir::Location loc = op.getLoc();
  mlir::Type llvmResTy = getTypeConverter()->convertType(op.getType());

  auto setBits =
      rewriter.create<mlir::LLVM::CtPopOp>(loc, llvmResTy, adaptor.getInput());
  // Parity is the lowest bit of the population count.
  auto lowBitMask = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmResTy, 1);
  rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, setBits, lowBitMask);
  return mlir::LogicalResult::success();
}
527
528
/// Lower cir::BitPopcountOp to mlir::LLVM::CtPopOp.
mlir::LogicalResult CIRToLLVMBitPopcountOpLowering::matchAndRewrite(
    cir::BitPopcountOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Type llvmResTy = getTypeConverter()->convertType(op.getType());
  rewriter.replaceOpWithNewOp<mlir::LLVM::CtPopOp>(op, llvmResTy,
                                                   adaptor.getInput());
  return mlir::LogicalResult::success();
}
537
538
/// Lower cir::BitReverseOp to mlir::LLVM::BitReverseOp on the lowered input.
mlir::LogicalResult CIRToLLVMBitReverseOpLowering::matchAndRewrite(
    cir::BitReverseOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Value input = adaptor.getInput();
  rewriter.replaceOpWithNewOp<mlir::LLVM::BitReverseOp>(op, input);
  return mlir::success();
}
544
545
/// Lower cir::BrCondOp to mlir::LLVM::CondBrOp.
///
/// Uses the lowered condition and the lowered operands for both successors.
mlir::LogicalResult CIRToLLVMBrCondOpLowering::matchAndRewrite(
    cir::BrCondOp brOp, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  // When ZExtOp is implemented, we'll need to check if the condition is a
  // ZExtOp and if so, delete it if it has a single use.
  assert(!cir::MissingFeatures::zextOp());

  rewriter.replaceOpWithNewOp<mlir::LLVM::CondBrOp>(
      brOp, adaptor.getCond(), brOp.getDestTrue(),
      adaptor.getDestOperandsTrue(), brOp.getDestFalse(),
      adaptor.getDestOperandsFalse());
  return mlir::success();
}
560
561
/// Lower cir::ByteSwapOp to mlir::LLVM::ByteSwapOp on the lowered input.
mlir::LogicalResult CIRToLLVMByteSwapOpLowering::matchAndRewrite(
    cir::ByteSwapOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Value input = adaptor.getInput();
  rewriter.replaceOpWithNewOp<mlir::LLVM::ByteSwapOp>(op, input);
  return mlir::LogicalResult::success();
}
567
568
/// Lower \p ty with this pattern's type converter.
mlir::Type CIRToLLVMCastOpLowering::convertTy(mlir::Type ty) const {
  const mlir::TypeConverter *tc = getTypeConverter();
  return tc->convertType(ty);
}
571
572
/// Lower cir::CastOp to the equivalent LLVM dialect operation(s).
///
/// For arithmetic conversions, LLVM IR uses the same instruction to convert
/// both individual scalars and entire vectors. The integral, floating and
/// int<->float cases therefore inspect the element type (elementTypeIfVector)
/// and handle both situations.
///
/// Cast kinds that are not implemented are reported with an error diagnostic
/// and the pattern fails. It never returns success without rewriting the op.
mlir::LogicalResult CIRToLLVMCastOpLowering::matchAndRewrite(
    cir::CastOp castOp, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  switch (castOp.getKind()) {
  case cir::CastKind::array_to_ptrdecay: {
    // GEP with a single zero index typed on the result's pointee type, so the
    // address itself is unchanged.
    const auto ptrTy = mlir::cast<cir::PointerType>(castOp.getType());
    mlir::Value sourceValue = adaptor.getSrc();
    mlir::Type targetType = convertTy(ptrTy);
    mlir::Type elementTy = convertTypeForMemory(*getTypeConverter(), dataLayout,
                                                ptrTy.getPointee());
    llvm::SmallVector<mlir::LLVM::GEPArg> offset{0};
    rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
        castOp, targetType, elementTy, sourceValue, offset);
    break;
  }
  case cir::CastKind::int_to_bool: {
    mlir::Value llvmSrcVal = adaptor.getSrc();
    mlir::Value zeroInt = rewriter.create<mlir::LLVM::ConstantOp>(
        castOp.getLoc(), llvmSrcVal.getType(), 0);
    rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
        castOp, mlir::LLVM::ICmpPredicate::ne, llvmSrcVal, zeroInt);
    break;
  }
  case cir::CastKind::integral: {
    mlir::Type srcType = castOp.getSrc().getType();
    mlir::Type dstType = castOp.getType();
    mlir::Value llvmSrcVal = adaptor.getSrc();
    mlir::Type llvmDstType = getTypeConverter()->convertType(dstType);
    cir::IntType srcIntType =
        mlir::cast<cir::IntType>(elementTypeIfVector(srcType));
    cir::IntType dstIntType =
        mlir::cast<cir::IntType>(elementTypeIfVector(dstType));
    // The signedness of the source type selects zext vs. sext when widening.
    rewriter.replaceOp(castOp, getLLVMIntCast(rewriter, llvmSrcVal, llvmDstType,
                                              srcIntType.isUnsigned(),
                                              srcIntType.getWidth(),
                                              dstIntType.getWidth()));
    break;
  }
  case cir::CastKind::floating: {
    mlir::Value llvmSrcVal = adaptor.getSrc();
    mlir::Type llvmDstTy = getTypeConverter()->convertType(castOp.getType());

    mlir::Type srcTy = elementTypeIfVector(castOp.getSrc().getType());
    mlir::Type dstTy = elementTypeIfVector(castOp.getType());

    if (!mlir::isa<cir::FPTypeInterface>(dstTy) ||
        !mlir::isa<cir::FPTypeInterface>(srcTy))
      return castOp.emitError() << "NYI cast from " << srcTy << " to " << dstTy;

    auto getFloatWidth = [](mlir::Type ty) -> unsigned {
      return mlir::cast<cir::FPTypeInterface>(ty).getWidth();
    };

    if (getFloatWidth(srcTy) > getFloatWidth(dstTy))
      rewriter.replaceOpWithNewOp<mlir::LLVM::FPTruncOp>(castOp, llvmDstTy,
                                                         llvmSrcVal);
    else
      rewriter.replaceOpWithNewOp<mlir::LLVM::FPExtOp>(castOp, llvmDstTy,
                                                       llvmSrcVal);
    return mlir::success();
  }
  case cir::CastKind::int_to_ptr: {
    auto dstTy = mlir::cast<cir::PointerType>(castOp.getType());
    mlir::Value llvmSrcVal = adaptor.getSrc();
    mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
    rewriter.replaceOpWithNewOp<mlir::LLVM::IntToPtrOp>(castOp, llvmDstTy,
                                                        llvmSrcVal);
    return mlir::success();
  }
  case cir::CastKind::ptr_to_int: {
    auto dstTy = mlir::cast<cir::IntType>(castOp.getType());
    mlir::Value llvmSrcVal = adaptor.getSrc();
    mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
    rewriter.replaceOpWithNewOp<mlir::LLVM::PtrToIntOp>(castOp, llvmDstTy,
                                                        llvmSrcVal);
    return mlir::success();
  }
  case cir::CastKind::float_to_bool: {
    mlir::Value llvmSrcVal = adaptor.getSrc();
    // `une` (unordered or not equal) also yields true for a NaN input.
    auto kind = mlir::LLVM::FCmpPredicate::une;

    // Check if float is not equal to zero.
    auto zeroFloat = rewriter.create<mlir::LLVM::ConstantOp>(
        castOp.getLoc(), llvmSrcVal.getType(),
        mlir::FloatAttr::get(llvmSrcVal.getType(), 0.0));

    // The comparison result is used directly; no extension is emitted here.
    rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>(castOp, kind, llvmSrcVal,
                                                    zeroFloat);

    return mlir::success();
  }
  case cir::CastKind::bool_to_int: {
    auto dstTy = mlir::cast<cir::IntType>(castOp.getType());
    mlir::Value llvmSrcVal = adaptor.getSrc();
    auto llvmSrcTy = mlir::cast<mlir::IntegerType>(llvmSrcVal.getType());
    auto llvmDstTy =
        mlir::cast<mlir::IntegerType>(getTypeConverter()->convertType(dstTy));
    if (llvmSrcTy.getWidth() == llvmDstTy.getWidth())
      rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy,
                                                         llvmSrcVal);
    else
      rewriter.replaceOpWithNewOp<mlir::LLVM::ZExtOp>(castOp, llvmDstTy,
                                                      llvmSrcVal);
    return mlir::success();
  }
  case cir::CastKind::bool_to_float: {
    mlir::Type dstTy = castOp.getType();
    mlir::Value llvmSrcVal = adaptor.getSrc();
    mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
    rewriter.replaceOpWithNewOp<mlir::LLVM::UIToFPOp>(castOp, llvmDstTy,
                                                      llvmSrcVal);
    return mlir::success();
  }
  case cir::CastKind::int_to_float: {
    mlir::Type dstTy = castOp.getType();
    mlir::Value llvmSrcVal = adaptor.getSrc();
    mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
    if (mlir::cast<cir::IntType>(elementTypeIfVector(castOp.getSrc().getType()))
            .isSigned())
      rewriter.replaceOpWithNewOp<mlir::LLVM::SIToFPOp>(castOp, llvmDstTy,
                                                        llvmSrcVal);
    else
      rewriter.replaceOpWithNewOp<mlir::LLVM::UIToFPOp>(castOp, llvmDstTy,
                                                        llvmSrcVal);
    return mlir::success();
  }
  case cir::CastKind::float_to_int: {
    mlir::Type dstTy = castOp.getType();
    mlir::Value llvmSrcVal = adaptor.getSrc();
    mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
    if (mlir::cast<cir::IntType>(elementTypeIfVector(castOp.getType()))
            .isSigned())
      rewriter.replaceOpWithNewOp<mlir::LLVM::FPToSIOp>(castOp, llvmDstTy,
                                                        llvmSrcVal);
    else
      rewriter.replaceOpWithNewOp<mlir::LLVM::FPToUIOp>(castOp, llvmDstTy,
                                                        llvmSrcVal);
    return mlir::success();
  }
  case cir::CastKind::bitcast: {
    mlir::Type dstTy = castOp.getType();
    mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);

    assert(!MissingFeatures::cxxABI());
    assert(!MissingFeatures::dataMemberType());

    mlir::Value llvmSrcVal = adaptor.getSrc();
    rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(castOp, llvmDstTy,
                                                       llvmSrcVal);
    return mlir::success();
  }
  case cir::CastKind::ptr_to_bool: {
    mlir::Value llvmSrcVal = adaptor.getSrc();
    mlir::Value zeroPtr = rewriter.create<mlir::LLVM::ZeroOp>(
        castOp.getLoc(), llvmSrcVal.getType());
    rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
        castOp, mlir::LLVM::ICmpPredicate::ne, llvmSrcVal, zeroPtr);
    break;
  }
  case cir::CastKind::address_space: {
    mlir::Type dstTy = castOp.getType();
    mlir::Value llvmSrcVal = adaptor.getSrc();
    mlir::Type llvmDstTy = getTypeConverter()->convertType(dstTy);
    rewriter.replaceOpWithNewOp<mlir::LLVM::AddrSpaceCastOp>(castOp, llvmDstTy,
                                                             llvmSrcVal);
    break;
  }
  case cir::CastKind::member_ptr_to_bool:
    assert(!MissingFeatures::cxxABI());
    assert(!MissingFeatures::methodType());
    // Not implemented yet. Report it rather than returning success while
    // leaving the op in place: a successful conversion pattern must replace
    // or erase its root op.
    return castOp.emitError() << "NYI cast kind: member_ptr_to_bool";
  default: {
    // Name the offending kind itself; getKindAttrName() would only print the
    // attribute's name ("kind").
    return castOp.emitError("Unhandled cast kind: ")
           << cir::stringifyCastKind(castOp.getKind());
  }
  }

  return mlir::success();
}
756
757
/// Lower cir::PtrStrideOp to a GEP over the lowered element type.
///
/// The stride is first cast to the index width the data layout assigns to
/// the base pointer type: zero-extended for unsigned strides, sign-extended
/// for signed ones, truncated when wider. Without this, the GEP's implicit
/// sign extension would misinterpret a narrower unsigned stride.
mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
    cir::PtrStrideOp ptrStrideOp, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {

  const mlir::TypeConverter *tc = getTypeConverter();
  const mlir::Type resultTy = tc->convertType(ptrStrideOp.getType());

  mlir::Type elementTy =
      convertTypeForMemory(*tc, dataLayout, ptrStrideOp.getElementTy());
  mlir::MLIRContext *ctx = elementTy.getContext();

  // void and function types don't really have a layout to use in GEPs, so
  // step over them as i8 instead.
  if (mlir::isa<mlir::LLVM::LLVMVoidType>(elementTy) ||
      mlir::isa<mlir::LLVM::LLVMFunctionType>(elementTy))
    elementTy = mlir::IntegerType::get(ctx, 8, mlir::IntegerType::Signless);

  // Zero-extend, sign-extend or truncate the index value to the layout's
  // index width for the base pointer type.
  mlir::Value index = adaptor.getStride();
  const unsigned width =
      mlir::cast<mlir::IntegerType>(index.getType()).getWidth();
  const std::optional<std::uint64_t> layoutWidth =
      dataLayout.getTypeIndexBitwidth(adaptor.getBase().getType());

  // The width adjustment is needed even when the index has no defining op
  // (e.g. it is a block argument). Only the "sub 0, x" rewrite below needs
  // the defining op, and it is skipped when there is none.
  mlir::Operation *indexOp = index.getDefiningOp();
  if (layoutWidth && width != *layoutWidth) {
    // If the index comes from a subtraction, make sure the extension happens
    // before it. To achieve that, look at unary minus, which already got
    // lowered to "sub 0, x".
    const auto sub = dyn_cast_if_present<mlir::LLVM::SubOp>(indexOp);
    auto unary = dyn_cast_if_present<cir::UnaryOp>(
        ptrStrideOp.getStride().getDefiningOp());
    bool rewriteSub =
        unary && unary.getKind() == cir::UnaryOpKind::Minus && sub;
    if (rewriteSub)
      index = indexOp->getOperand(1);

    // Handle the cast
    const auto llvmDstType = mlir::IntegerType::get(ctx, *layoutWidth);
    index = getLLVMIntCast(rewriter, index, llvmDstType,
                           ptrStrideOp.getStride().getType().isUnsigned(),
                           width, *layoutWidth);

    // Rewrite the sub in front of extensions/trunc: "sub 0, cast(x)".
    if (rewriteSub) {
      index = rewriter.create<mlir::LLVM::SubOp>(
          index.getLoc(), index.getType(),
          rewriter.create<mlir::LLVM::ConstantOp>(index.getLoc(),
                                                  index.getType(), 0),
          index);
      rewriter.eraseOp(sub);
    }
  }

  rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
      ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index);
  return mlir::success();
}
815
816
/// Lower cir::BaseClassAddrOp: offset a derived-class pointer by a constant
/// byte offset.
///
/// - A zero offset is a plain bitcast.
/// - Otherwise the address is computed with a byte-wise (i8) GEP.
/// - Unless the op asserts the pointer is non-null, a select returns the
///   input pointer itself when it is null.
mlir::LogicalResult CIRToLLVMBaseClassAddrOpLowering::matchAndRewrite(
    cir::BaseClassAddrOp baseClassOp, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  const mlir::Type resultType =
      getTypeConverter()->convertType(baseClassOp.getType());
  mlir::Value derivedAddr = adaptor.getDerivedAddr();
  const uint64_t byteOffset = adaptor.getOffset().getZExtValue();

  if (byteOffset == 0) {
    rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(baseClassOp, resultType,
                                                       derivedAddr);
    return mlir::success();
  }

  mlir::Type byteType = mlir::IntegerType::get(resultType.getContext(), 8,
                                               mlir::IntegerType::Signless);
  llvm::SmallVector<mlir::LLVM::GEPArg, 1> offset = {byteOffset};

  if (baseClassOp.getAssumeNotNull()) {
    rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
        baseClassOp, resultType, byteType, derivedAddr, offset);
    return mlir::success();
  }

  const mlir::Location loc = baseClassOp.getLoc();
  mlir::Value nullPtr =
      rewriter.create<mlir::LLVM::ZeroOp>(loc, derivedAddr.getType());
  mlir::Value isNull = rewriter.create<mlir::LLVM::ICmpOp>(
      loc, mlir::LLVM::ICmpPredicate::eq, derivedAddr, nullPtr);
  mlir::Value adjusted = rewriter.create<mlir::LLVM::GEPOp>(
      loc, resultType, byteType, derivedAddr, offset);
  rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(baseClassOp, isNull,
                                                    derivedAddr, adjusted);
  return mlir::success();
}
847
848
/// Lower `cir.alloca` to `llvm.alloca` of a single element.
///
/// Dynamic allocation sizes are not supported yet, so the element count is
/// always the constant 1 in the converted index type.
mlir::LogicalResult CIRToLLVMAllocaOpLowering::matchAndRewrite(
    cir::AllocaOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  assert(!cir::MissingFeatures::opAllocaDynAllocSize());
  const mlir::Type countTy =
      typeConverter->convertType(rewriter.getIndexType());
  mlir::Value elementCount =
      rewriter.create<mlir::LLVM::ConstantOp>(op.getLoc(), countTy, 1);

  // Both the allocated element type and the resulting pointer type use their
  // in-memory representations.
  mlir::Type allocatedTy = convertTypeForMemory(*getTypeConverter(),
                                                dataLayout, op.getAllocaType());
  mlir::Type pointerTy =
      convertTypeForMemory(*getTypeConverter(), dataLayout, op.getType());

  assert(!cir::MissingFeatures::addressSpace());
  assert(!cir::MissingFeatures::opAllocaAnnotations());

  rewriter.replaceOpWithNewOp<mlir::LLVM::AllocaOp>(
      op, pointerTy, allocatedTy, elementCount,
      op.getAlignmentAttr().getInt());
  return mlir::success();
}
867
868
/// Lower `cir.return` one-to-one onto `llvm.return`, forwarding the already
/// converted operands.
mlir::LogicalResult CIRToLLVMReturnOpLowering::matchAndRewrite(
    cir::ReturnOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  auto returnedValues = adaptor.getOperands();
  rewriter.replaceOpWithNewOp<mlir::LLVM::ReturnOp>(op, returnedValues);
  return mlir::success();
}
874
875
/// Replace a CIR call-like operation \p op with an `llvm.call`.
///
/// \p calleeAttr is the callee symbol for a direct call. When it is null the
/// call is indirect and the callee pointer is taken from the first operand of
/// \p op. Side-effect information from the CIR call interface is translated
/// into memory-effects / nounwind / willreturn on the new call.
///
/// \returns failure if the result types cannot be converted.
static mlir::LogicalResult
rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands,
                    mlir::ConversionPatternRewriter &rewriter,
                    const mlir::TypeConverter *converter,
                    mlir::FlatSymbolRefAttr calleeAttr) {
  llvm::SmallVector<mlir::Type, 8> llvmResults;
  mlir::ValueTypeRange<mlir::ResultRange> cirResults = op->getResultTypes();
  auto call = cast<cir::CIRCallOpInterface>(op);

  if (converter->convertTypes(cirResults, llvmResults).failed())
    return mlir::failure();

  assert(!cir::MissingFeatures::opCallCallConv());

  mlir::LLVM::MemoryEffectsAttr memoryEffects;
  bool noUnwind = false;
  bool willReturn = false;
  convertSideEffectForCall(op, call.getNothrow(), call.getSideEffect(),
                           memoryEffects, noUnwind, willReturn);

  mlir::LLVM::LLVMFunctionType llvmFnTy;
  if (calleeAttr) { // direct call
    mlir::FunctionOpInterface fn =
        mlir::SymbolTable::lookupNearestSymbolFrom<mlir::FunctionOpInterface>(
            op, calleeAttr);
    assert(fn && "Did not find function for call");
    llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>(
        converter->convertType(fn.getFunctionType()));
  } else { // indirect call
    assert(!op->getOperands().empty() &&
           "operands list must not be empty for the indirect call");
    // The first operand is a pointer to the callee; its pointee is the
    // function type to call through.
    auto calleeTy = op->getOperands().front().getType();
    auto calleePtrTy = cast<cir::PointerType>(calleeTy);
    auto calleeFuncTy = cast<cir::FuncType>(calleePtrTy.getPointee());
    // Note: the previous debug `dump()` calls on the callee type were removed;
    // they wrote to stderr on every indirect call.
    llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>(
        converter->convertType(calleeFuncTy));
  }

  assert(!cir::MissingFeatures::opCallLandingPad());
  assert(!cir::MissingFeatures::opCallContinueBlock());
  assert(!cir::MissingFeatures::opCallCallConv());

  auto newOp = rewriter.replaceOpWithNewOp<mlir::LLVM::CallOp>(
      op, llvmFnTy, calleeAttr, callOperands);
  if (memoryEffects)
    newOp.setMemoryEffectsAttr(memoryEffects);
  newOp.setNoUnwind(noUnwind);
  newOp.setWillReturn(willReturn);

  return mlir::success();
}
928
929
/// Lower `cir.call` via the shared call/invoke helper. A null callee symbol
/// selects the indirect-call path in that helper.
mlir::LogicalResult CIRToLLVMCallOpLowering::matchAndRewrite(
    cir::CallOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Operation *callOp = op.getOperation();
  return rewriteCallOrInvoke(callOp, adaptor.getOperands(), rewriter,
                             getTypeConverter(), op.getCalleeAttr());
}
935
936
/// Lower `cir.load` to a non-atomic, non-volatile `llvm.load`.
///
/// The load is performed on the in-memory representation of the type. The
/// loaded value is then converted back with emitFromMemory before it replaces
/// the CIR result.
mlir::LogicalResult CIRToLLVMLoadOpLowering::matchAndRewrite(
    cir::LoadOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  const mlir::Type memTy =
      convertTypeForMemory(*getTypeConverter(), dataLayout, op.getType());
  assert(!cir::MissingFeatures::opLoadStoreMemOrder());

  // An explicit alignment on the op takes precedence over the ABI alignment.
  const uint64_t abiAlign = dataLayout.getTypeABIAlignment(memTy);
  const unsigned alignment =
      static_cast<unsigned>(op.getAlignment().value_or(abiAlign));

  assert(!cir::MissingFeatures::lowerModeOptLevel());

  // TODO: nontemporal, syncscope.
  assert(!cir::MissingFeatures::opLoadStoreVolatile());
  auto rawLoad = rewriter.create<mlir::LLVM::LoadOp>(
      op->getLoc(), memTy, adaptor.getAddr(), alignment,
      /*volatile=*/false, /*nontemporal=*/false,
      /*invariant=*/false, /*invariantGroup=*/false,
      mlir::LLVM::AtomicOrdering::not_atomic);

  rewriter.replaceOp(
      op, emitFromMemory(rewriter, dataLayout, op, rawLoad.getResult()));
  assert(!cir::MissingFeatures::opLoadStoreTbaa());
  return mlir::success();
}
963
964
/// Lower `cir.store` to a non-atomic, non-volatile `llvm.store`.
///
/// The value is first converted to its in-memory representation with
/// emitToMemory.
mlir::LogicalResult CIRToLLVMStoreOpLowering::matchAndRewrite(
    cir::StoreOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  assert(!cir::MissingFeatures::opLoadStoreMemOrder());

  // An explicit alignment on the op takes precedence over the ABI alignment
  // of the converted value type.
  // NOTE(review): the load lowering uses convertTypeForMemory for this query
  // while this uses convertType -- confirm both agree for every stored type.
  const mlir::Type valueTy =
      getTypeConverter()->convertType(op.getValue().getType());
  const uint64_t abiAlign = dataLayout.getTypeABIAlignment(valueTy);
  const unsigned alignment =
      static_cast<unsigned>(op.getAlignment().value_or(abiAlign));

  assert(!cir::MissingFeatures::lowerModeOptLevel());

  mlir::Value memValue = emitToMemory(rewriter, dataLayout,
                                      op.getValue().getType(),
                                      adaptor.getValue());
  // TODO: nontemporal, syncscope.
  assert(!cir::MissingFeatures::opLoadStoreVolatile());
  auto newStore = rewriter.create<mlir::LLVM::StoreOp>(
      op->getLoc(), memValue, adaptor.getAddr(), alignment,
      /*volatile=*/false, /*nontemporal=*/false, /*invariantGroup=*/false,
      mlir::LLVM::AtomicOrdering::not_atomic);
  rewriter.replaceOp(op, newStore);
  assert(!cir::MissingFeatures::opLoadStoreTbaa());
  return mlir::success();
}
989
990
/// Returns true if \p attr records trailing zero elements, or if any constant
/// array nested inside its element list (at any depth) does.
bool hasTrailingZeros(cir::ConstArrayAttr attr) {
  if (attr.hasTrailingZeros())
    return true;

  // Only an ArrayAttr element list can contain nested constant arrays.
  auto elements = mlir::dyn_cast<mlir::ArrayAttr>(attr.getElts());
  if (!elements)
    return false;

  // any_of stops at the first match; the previous count_if visited (and
  // recursed into) every element only to test the count for non-zero.
  return std::any_of(elements.begin(), elements.end(),
                     [](mlir::Attribute elt) {
                       auto nested = mlir::dyn_cast<cir::ConstArrayAttr>(elt);
                       return nested && hasTrailingZeros(nested);
                     });
}
998
999
/// Lower `cir.const` to LLVM dialect.
///
/// Most types end up as a single `llvm.mlir.constant` whose value attribute is
/// computed into `attr`. Null pointers, vectors, arrays with trailing zeros and
/// arrays that cannot be expressed as a dense attribute are materialized by
/// other means, replace `op` themselves and return early.
mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
    cir::ConstantOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Attribute attr = op.getValue();

  if (mlir::isa<mlir::IntegerType>(op.getType())) {
    // Verified cir.const operations cannot actually be of these types, but the
    // lowering pass may generate temporary cir.const operations with these
    // types. This is OK since MLIR allows unverified operations to be alive
    // during a pass as long as they don't live past the end of the pass.
    // The value attribute is already usable as-is, so `attr` is kept.
  } else if (mlir::isa<cir::BoolType>(op.getType())) {
    int value = mlir::cast<cir::BoolAttr>(op.getValue()).getValue();
    attr = rewriter.getIntegerAttr(typeConverter->convertType(op.getType()),
                                   value);
  } else if (mlir::isa<cir::IntType>(op.getType())) {
    assert(!cir::MissingFeatures::opGlobalViewAttr());

    attr = rewriter.getIntegerAttr(
        typeConverter->convertType(op.getType()),
        mlir::cast<cir::IntAttr>(op.getValue()).getValue());
  } else if (mlir::isa<cir::FPTypeInterface>(op.getType())) {
    attr = rewriter.getFloatAttr(
        typeConverter->convertType(op.getType()),
        mlir::cast<cir::FPAttr>(op.getValue()).getValue());
  } else if (mlir::isa<cir::PointerType>(op.getType())) {
    // Optimize with dedicated LLVM op for null pointers.
    auto ptrAttr = mlir::dyn_cast<cir::ConstPtrAttr>(op.getValue());
    if (ptrAttr && ptrAttr.isNullValue()) {
      rewriter.replaceOpWithNewOp<mlir::LLVM::ZeroOp>(
          op, typeConverter->convertType(op.getType()));
      return mlir::success();
    }
    assert(!cir::MissingFeatures::opGlobalViewAttr());
    attr = op.getValue();
  } else if (mlir::isa<cir::ArrayType>(op.getType())) {
    const auto constArr = mlir::dyn_cast<cir::ConstArrayAttr>(op.getValue());
    if (!constArr && !isa<cir::ZeroAttr, cir::UndefAttr>(op.getValue()))
      return op.emitError() << "array does not have a constant initializer";

    std::optional<mlir::Attribute> denseAttr;
    if (constArr && hasTrailingZeros(constArr)) {
      const mlir::Value newOp =
          lowerCirAttrAsValue(op, constArr, rewriter, getTypeConverter());
      rewriter.replaceOp(op, newOp);
      return mlir::success();
    } else if (constArr &&
               (denseAttr = lowerConstArrayAttr(constArr, typeConverter))) {
      attr = denseAttr.value();
    } else {
      // Zero/undef initializers, and constant arrays that could not be turned
      // into a dense attribute, are materialized as a sequence of ops.
      const mlir::Value initVal =
          lowerCirAttrAsValue(op, op.getValue(), rewriter, typeConverter);
      // A single replaceOp redirects all uses and erases `op`, consistent
      // with every other early-return branch in this function.
      rewriter.replaceOp(op, initVal);
      return mlir::success();
    }
  } else if (mlir::isa<cir::VectorType>(op.getType())) {
    rewriter.replaceOp(op, lowerCirAttrAsValue(op, op.getValue(), rewriter,
                                               getTypeConverter()));
    return mlir::success();
  } else if (auto complexTy = mlir::dyn_cast<cir::ComplexType>(op.getType())) {
    mlir::Type complexElemTy = complexTy.getElementType();
    mlir::Type complexElemLLVMTy = typeConverter->convertType(complexElemTy);

    // A zero-initialized complex value is the pair {0, 0}.
    if (mlir::isa<cir::ZeroAttr>(op.getValue())) {
      mlir::TypedAttr zeroAttr = rewriter.getZeroAttr(complexElemLLVMTy);
      mlir::ArrayAttr array = rewriter.getArrayAttr({zeroAttr, zeroAttr});
      rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
          op, getTypeConverter()->convertType(op.getType()), array);
      return mlir::success();
    }

    auto complexAttr = mlir::cast<cir::ConstComplexAttr>(op.getValue());

    // Build the {real, imag} pair using the element kind of the complex type.
    mlir::Attribute components[2];
    if (mlir::isa<cir::IntType>(complexElemTy)) {
      components[0] = rewriter.getIntegerAttr(
          complexElemLLVMTy,
          mlir::cast<cir::IntAttr>(complexAttr.getReal()).getValue());
      components[1] = rewriter.getIntegerAttr(
          complexElemLLVMTy,
          mlir::cast<cir::IntAttr>(complexAttr.getImag()).getValue());
    } else {
      components[0] = rewriter.getFloatAttr(
          complexElemLLVMTy,
          mlir::cast<cir::FPAttr>(complexAttr.getReal()).getValue());
      components[1] = rewriter.getFloatAttr(
          complexElemLLVMTy,
          mlir::cast<cir::FPAttr>(complexAttr.getImag()).getValue());
    }

    attr = rewriter.getArrayAttr(components);
  } else {
    return op.emitError() << "unsupported constant type " << op.getType();
  }

  rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
      op, getTypeConverter()->convertType(op.getType()), attr);

  return mlir::success();
}
1101
1102
/// Lower `cir.expect` to the LLVM expect ops.
///
/// Without a probability this becomes `LLVM::ExpectOp`. With a probability it
/// becomes `LLVM::ExpectWithProbabilityOp`.
mlir::LogicalResult CIRToLLVMExpectOpLowering::matchAndRewrite(
    cir::ExpectOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  // TODO(cir): do not generate LLVM intrinsics under -O0
  assert(!cir::MissingFeatures::optInfoAttr());

  std::optional<llvm::APFloat> prob = op.getProb();
  if (!prob) {
    rewriter.replaceOpWithNewOp<mlir::LLVM::ExpectOp>(op, adaptor.getVal(),
                                                      adaptor.getExpected());
    return mlir::success();
  }

  rewriter.replaceOpWithNewOp<mlir::LLVM::ExpectWithProbabilityOp>(
      op, adaptor.getVal(), adaptor.getExpected(), *prob);
  return mlir::success();
}
1117
1118
/// Convert the `cir.func` attributes to `llvm.func` attributes.
1119
/// Only retain those attributes that are not constructed by
1120
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out
1121
/// argument attributes.
1122
void CIRToLLVMFuncOpLowering::lowerFuncAttributes(
1123
cir::FuncOp func, bool filterArgAndResAttrs,
1124
SmallVectorImpl<mlir::NamedAttribute> &result) const {
1125
assert(!cir::MissingFeatures::opFuncCallingConv());
1126
for (mlir::NamedAttribute attr : func->getAttrs()) {
1127
assert(!cir::MissingFeatures::opFuncCallingConv());
1128
if (attr.getName() == mlir::SymbolTable::getSymbolAttrName() ||
1129
attr.getName() == func.getFunctionTypeAttrName() ||
1130
attr.getName() == getLinkageAttrNameString() ||
1131
attr.getName() == func.getGlobalVisibilityAttrName() ||
1132
attr.getName() == func.getDsoLocalAttrName() ||
1133
(filterArgAndResAttrs &&
1134
(attr.getName() == func.getArgAttrsAttrName() ||
1135
attr.getName() == func.getResAttrsAttrName())))
1136
continue;
1137
1138
assert(!cir::MissingFeatures::opFuncExtraAttrs());
1139
result.push_back(attr);
1140
}
1141
}
1142
1143
/// Lower `cir.func` to `llvm.func`.
///
/// The argument types are converted through a signature conversion. The body
/// region is moved into the new function and its block-argument types are
/// rewritten. Fails if any argument type or the body region cannot be
/// converted.
mlir::LogicalResult CIRToLLVMFuncOpLowering::matchAndRewrite(
    cir::FuncOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  cir::FuncType fnType = op.getFunctionType();
  const bool isDsoLocal = op.getDsoLocal();

  mlir::TypeConverter::SignatureConversion signatureConversion(
      fnType.getNumInputs());
  unsigned argNo = 0;
  for (mlir::Type cirArgTy : fnType.getInputs()) {
    mlir::Type llvmArgTy = typeConverter->convertType(cirArgTy);
    if (!llvmArgTy)
      return mlir::failure();
    signatureConversion.addInputs(argNo++, llvmArgTy);
  }

  // A null converted return type falls back to void.
  mlir::Type llvmResultTy =
      getTypeConverter()->convertType(fnType.getReturnType());
  if (!llvmResultTy)
    llvmResultTy = mlir::LLVM::LLVMVoidType::get(getContext());

  mlir::Type llvmFnTy = mlir::LLVM::LLVMFunctionType::get(
      llvmResultTy, signatureConversion.getConvertedTypes(),
      /*isVarArg=*/fnType.isVarArg());

  // LLVMFuncOp expects a single FileLine Location instead of a fused
  // location, so keep only the first component of a fused one.
  mlir::Location loc = op.getLoc();
  if (auto fusedLoc = mlir::dyn_cast<mlir::FusedLoc>(loc))
    loc = fusedLoc.getLocations()[0];
  assert((mlir::isa<mlir::FileLineColLoc>(loc) ||
          mlir::isa<mlir::UnknownLoc>(loc)) &&
         "expected single location or unknown location here");

  mlir::LLVM::Linkage linkage = convertLinkage(op.getLinkage());
  assert(!cir::MissingFeatures::opFuncCallingConv());
  mlir::LLVM::CConv cconv = mlir::LLVM::CConv::C;

  SmallVector<mlir::NamedAttribute, 4> attributes;
  lowerFuncAttributes(op, /*filterArgAndResAttrs=*/false, attributes);

  auto fn = rewriter.create<mlir::LLVM::LLVMFuncOp>(
      loc, op.getName(), llvmFnTy, linkage, isDsoLocal, cconv,
      mlir::SymbolRefAttr(), attributes);

  assert(!cir::MissingFeatures::opFuncMultipleReturnVals());

  fn.setVisibility_Attr(mlir::LLVM::VisibilityAttr::get(
      getContext(), lowerCIRVisibilityToLLVMVisibility(
                        op.getGlobalVisibilityAttr().getValue())));

  // Move the CIR body into the new function, then retype its blocks.
  rewriter.inlineRegionBefore(op.getBody(), fn.getBody(), fn.end());
  if (failed(rewriter.convertRegionTypes(&fn.getBody(), *typeConverter,
                                         &signatureConversion)))
    return mlir::failure();

  rewriter.eraseOp(op);
  return mlir::success();
}
1201
1202
/// Lower `cir.get_global` to `llvm.mlir.addressof` of the named global.
/// Unused get_global ops are simply erased.
mlir::LogicalResult CIRToLLVMGetGlobalOpLowering::matchAndRewrite(
    cir::GetGlobalOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  // FIXME(cir): Premature DCE to avoid lowering stuff we're not using.
  // CIRGen should mitigate this and not emit the get_global.
  if (op->use_empty()) {
    rewriter.eraseOp(op);
    return mlir::success();
  }

  assert(!cir::MissingFeatures::opGlobalThreadLocal());

  mlir::Type addrTy = getTypeConverter()->convertType(op.getType());
  mlir::Operation *addrOf = rewriter.create<mlir::LLVM::AddressOfOp>(
      op.getLoc(), addrTy, op.getName());
  rewriter.replaceOp(op, addrOf);
  return mlir::success();
}
1221
1222
/// Replace CIR global with a region initialized LLVM global and update
1223
/// insertion point to the end of the initializer block.
1224
void CIRToLLVMGlobalOpLowering::setupRegionInitializedLLVMGlobalOp(
1225
cir::GlobalOp op, mlir::ConversionPatternRewriter &rewriter) const {
1226
const mlir::Type llvmType =
1227
convertTypeForMemory(*getTypeConverter(), dataLayout, op.getSymType());
1228
1229
// FIXME: These default values are placeholders until the the equivalent
1230
// attributes are available on cir.global ops. This duplicates code
1231
// in CIRToLLVMGlobalOpLowering::matchAndRewrite() but that will go
1232
// away when the placeholders are no longer needed.
1233
assert(!cir::MissingFeatures::opGlobalConstant());
1234
const bool isConst = false;
1235
assert(!cir::MissingFeatures::addressSpace());
1236
const unsigned addrSpace = 0;
1237
const bool isDsoLocal = op.getDsoLocal();
1238
assert(!cir::MissingFeatures::opGlobalThreadLocal());
1239
const bool isThreadLocal = false;
1240
const uint64_t alignment = op.getAlignment().value_or(0);
1241
const mlir::LLVM::Linkage linkage = convertLinkage(op.getLinkage());
1242
const StringRef symbol = op.getSymName();
1243
mlir::SymbolRefAttr comdatAttr = getComdatAttr(op, rewriter);
1244
1245
SmallVector<mlir::NamedAttribute> attributes;
1246
mlir::LLVM::GlobalOp newGlobalOp =
1247
rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>(
1248
op, llvmType, isConst, linkage, symbol, nullptr, alignment, addrSpace,
1249
isDsoLocal, isThreadLocal, comdatAttr, attributes);
1250
newGlobalOp.getRegion().emplaceBlock();
1251
rewriter.setInsertionPointToEnd(newGlobalOp.getInitializerBlock());
1252
}
1253
1254
/// Lower a global whose initializer (\p init) cannot be expressed as an LLVM
/// value attribute. The initializer is materialized as ops inside the new
/// global's initializer region, and that region returns the resulting value.
mlir::LogicalResult
CIRToLLVMGlobalOpLowering::matchAndRewriteRegionInitializedGlobal(
    cir::GlobalOp op, mlir::Attribute init,
    mlir::ConversionPatternRewriter &rewriter) const {
  // TODO: Generalize this handling when more types are needed here.
  assert((isa<cir::ConstArrayAttr, cir::ConstVectorAttr, cir::ConstPtrAttr,
              cir::ConstComplexAttr, cir::ZeroAttr>(init)));

  // TODO(cir): once LLVM's dialect has proper equivalent attributes this
  // should be updated. For now, we use a custom op to initialize globals
  // to the appropriate value.

  // Capture the location before `op` is replaced by the setup call.
  const mlir::Location loc = op.getLoc();
  setupRegionInitializedLLVMGlobalOp(op, rewriter);

  CIRAttrToValue attrLowering(op, rewriter, typeConverter);
  mlir::Value initValue = attrLowering.visit(init);
  rewriter.create<mlir::LLVM::ReturnOp>(loc, initValue);
  return mlir::success();
}
1272
1273
/// Lower `cir.global` to `LLVM::GlobalOp`.
///
/// Scalar initializers (FP, int, bool) are rewritten into an LLVM value
/// attribute. Aggregate, pointer, complex and zero initializers are delegated
/// to matchAndRewriteRegionInitializedGlobal. Any other initializer kind is
/// reported as an error.
mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
    cir::GlobalOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {

  std::optional<mlir::Attribute> init = op.getInitialValue();

  // Fetch required values to create LLVM op.
  const mlir::Type cirSymType = op.getSymType();

  // This is the LLVM dialect type.
  const mlir::Type llvmType =
      convertTypeForMemory(*getTypeConverter(), dataLayout, cirSymType);
  // FIXME: These default values are placeholders until the equivalent
  // attributes are available on cir.global ops.
  assert(!cir::MissingFeatures::opGlobalConstant());
  const bool isConst = false;
  assert(!cir::MissingFeatures::addressSpace());
  const unsigned addrSpace = 0;
  const bool isDsoLocal = op.getDsoLocal();
  assert(!cir::MissingFeatures::opGlobalThreadLocal());
  const bool isThreadLocal = false;
  const uint64_t alignment = op.getAlignment().value_or(0);
  const mlir::LLVM::Linkage linkage = convertLinkage(op.getLinkage());
  const StringRef symbol = op.getSymName();
  SmallVector<mlir::NamedAttribute> attributes;
  mlir::SymbolRefAttr comdatAttr = getComdatAttr(op, rewriter);

  if (init.has_value()) {
    // Keep the CIR initializer: `init` may be overwritten below by the
    // rewritten (possibly null) attribute, and diagnostics must name the
    // attribute that actually failed to lower.
    const mlir::Attribute cirInit = init.value();
    if (mlir::isa<cir::FPAttr, cir::IntAttr, cir::BoolAttr>(cirInit)) {
      GlobalInitAttrRewriter initRewriter(llvmType, rewriter);
      init = initRewriter.visit(cirInit);
      // If initRewriter returned a null attribute, init will have a value but
      // the value will be null. If that happens, initRewriter didn't handle the
      // attribute type. It probably needs to be added to
      // GlobalInitAttrRewriter.
      if (!init.value()) {
        op.emitError() << "unsupported initializer '" << cirInit << "'";
        return mlir::failure();
      }
    } else if (mlir::isa<cir::ConstArrayAttr, cir::ConstVectorAttr,
                         cir::ConstPtrAttr, cir::ConstComplexAttr,
                         cir::ZeroAttr>(cirInit)) {
      // TODO(cir): once LLVM's dialect has proper equivalent attributes this
      // should be updated. For now, we use a custom op to initialize globals
      // to the appropriate value.
      return matchAndRewriteRegionInitializedGlobal(op, cirInit, rewriter);
    } else {
      // We will only get here if new initializer types are added and this
      // code is not updated to handle them.
      op.emitError() << "unsupported initializer '" << cirInit << "'";
      return mlir::failure();
    }
  }

  // Rewrite op.
  rewriter.replaceOpWithNewOp<mlir::LLVM::GlobalOp>(
      op, llvmType, isConst, linkage, symbol, init.value_or(mlir::Attribute()),
      alignment, addrSpace, isDsoLocal, isThreadLocal, comdatAttr, attributes);
  return mlir::success();
}
1333
1334
/// Return a reference to the comdat selector for global \p op, or a null
/// attribute if the global has no comdat.
///
/// All selectors live in one module-level `__llvm_comdat_globals` comdat op.
/// That op is created on first use and stored in `comdatOp`. Each call adds a
/// selector named after the global with `any` selection kind. The builder's
/// insertion point is restored on return.
mlir::SymbolRefAttr
CIRToLLVMGlobalOpLowering::getComdatAttr(cir::GlobalOp &op,
                                         mlir::OpBuilder &builder) const {
  if (!op.getComdat())
    return mlir::SymbolRefAttr{};

  const StringRef comdatName("__llvm_comdat_globals");
  mlir::OpBuilder::InsertionGuard guard(builder);

  if (!comdatOp) {
    mlir::ModuleOp module = op->getParentOfType<mlir::ModuleOp>();
    builder.setInsertionPointToStart(module.getBody());
    comdatOp =
        builder.create<mlir::LLVM::ComdatOp>(module.getLoc(), comdatName);
  }

  builder.setInsertionPointToStart(&comdatOp.getBody().back());
  auto selector = builder.create<mlir::LLVM::ComdatSelectorOp>(
      comdatOp.getLoc(), op.getSymName(), mlir::LLVM::comdat::Comdat::Any);

  // Nested reference: @__llvm_comdat_globals::@<selector>.
  auto selectorRef = mlir::FlatSymbolRefAttr::get(selector.getSymNameAttr());
  return mlir::SymbolRefAttr::get(builder.getContext(), comdatName,
                                  selectorRef);
}
1356
1357
/// Lower `cir.switch.flat` to `llvm.switch`.
///
/// Case values are unwrapped from `cir.int` attributes into APInts. The case
/// and default successors, along with their operand lists, are forwarded
/// unchanged.
mlir::LogicalResult CIRToLLVMSwitchFlatOpLowering::matchAndRewrite(
    cir::SwitchFlatOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  llvm::SmallVector<mlir::APInt, 8> caseValues;
  for (mlir::Attribute caseAttr : op.getCaseValues())
    caseValues.push_back(cast<cir::IntAttr>(caseAttr).getValue());

  llvm::SmallVector<mlir::Block *, 8> caseDestinations;
  for (mlir::Block *dest : op.getCaseDestinations())
    caseDestinations.push_back(dest);

  llvm::SmallVector<mlir::ValueRange, 8> caseOperands;
  for (mlir::OperandRange destOperands : op.getCaseOperands())
    caseOperands.push_back(destOperands);

  rewriter.setInsertionPoint(op);
  rewriter.replaceOpWithNewOp<mlir::LLVM::SwitchOp>(
      op, adaptor.getCondition(), op.getDefaultDestination(),
      op.getDefaultOperands(), caseValues, caseDestinations, caseOperands);
  return mlir::success();
}
1383
1384
/// Lower `cir.unary` to LLVM dialect arithmetic.
///
/// Integer: ++/-- as add/sub of 1, - as `sub 0, x`, ~ as xor with all-ones.
/// Floating point: ++/-- as fadd of +/-1.0, - as fneg.
/// Bool: only ! (xor with 1).
/// Unary + is a no-op for every supported type. Pointer and other types
/// produce an error.
mlir::LogicalResult CIRToLLVMUnaryOpLowering::matchAndRewrite(
    cir::UnaryOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  assert(op.getType() == op.getInput().getType() &&
         "Unary operation's operand type and result type are different");
  mlir::Type type = op.getType();
  mlir::Type elementType = elementTypeIfVector(type);
  bool isVector = mlir::isa<cir::VectorType>(type);
  mlir::Type llvmType = getTypeConverter()->convertType(type);
  mlir::Location loc = op.getLoc();

  // Integer unary operations: + - ~ ++ --
  if (mlir::isa<cir::IntType>(elementType)) {
    mlir::LLVM::IntegerOverflowFlags maybeNSW =
        op.getNoSignedWrap() ? mlir::LLVM::IntegerOverflowFlags::nsw
                             : mlir::LLVM::IntegerOverflowFlags::none;
    switch (op.getKind()) {
    case cir::UnaryOpKind::Inc: {
      assert(!isVector && "++ not allowed on vector types");
      auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
      rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(
          op, llvmType, adaptor.getInput(), one, maybeNSW);
      return mlir::success();
    }
    case cir::UnaryOpKind::Dec: {
      assert(!isVector && "-- not allowed on vector types");
      auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
      rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, adaptor.getInput(),
                                                     one, maybeNSW);
      return mlir::success();
    }
    case cir::UnaryOpKind::Plus:
      rewriter.replaceOp(op, adaptor.getInput());
      return mlir::success();
    case cir::UnaryOpKind::Minus: {
      // Negation is lowered as `sub 0, x`.
      mlir::Value zero;
      if (isVector)
        zero = rewriter.create<mlir::LLVM::ZeroOp>(loc, llvmType);
      else
        zero = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 0);
      rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(
          op, zero, adaptor.getInput(), maybeNSW);
      return mlir::success();
    }
    case cir::UnaryOpKind::Not: {
      // Bitwise complement, implemented as an XOR with all-ones (-1).
      mlir::Value minusOne;
      if (isVector) {
        // Build the all-ones splat with the converted vector type so every
        // lane has the operand's element width. The previous code always used
        // i32 elements, which does not match i8/i16/i64 lane vectors.
        auto llvmVecTy = mlir::cast<mlir::ShapedType>(llvmType);
        mlir::DenseElementsAttr allOnes = mlir::DenseElementsAttr::get(
            llvmVecTy,
            llvm::APInt::getAllOnes(llvmVecTy.getElementTypeBitWidth()));
        minusOne =
            rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, allOnes);
      } else {
        minusOne = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, -1);
      }
      rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(),
                                                     minusOne);
      return mlir::success();
    }
    }
    llvm_unreachable("Unexpected unary op for int");
  }

  // Floating point unary operations: + - ++ --
  if (mlir::isa<cir::FPTypeInterface>(elementType)) {
    switch (op.getKind()) {
    case cir::UnaryOpKind::Inc: {
      assert(!isVector && "++ not allowed on vector types");
      mlir::LLVM::ConstantOp one = rewriter.create<mlir::LLVM::ConstantOp>(
          loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
      rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmType, one,
                                                      adaptor.getInput());
      return mlir::success();
    }
    case cir::UnaryOpKind::Dec: {
      assert(!isVector && "-- not allowed on vector types");
      mlir::LLVM::ConstantOp minusOne = rewriter.create<mlir::LLVM::ConstantOp>(
          loc, llvmType, rewriter.getFloatAttr(llvmType, -1.0));
      rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, llvmType, minusOne,
                                                      adaptor.getInput());
      return mlir::success();
    }
    case cir::UnaryOpKind::Plus:
      rewriter.replaceOp(op, adaptor.getInput());
      return mlir::success();
    case cir::UnaryOpKind::Minus:
      rewriter.replaceOpWithNewOp<mlir::LLVM::FNegOp>(op, llvmType,
                                                      adaptor.getInput());
      return mlir::success();
    case cir::UnaryOpKind::Not:
      return op.emitError() << "Unary not is invalid for floating-point types";
    }
    llvm_unreachable("Unexpected unary op for float");
  }

  // Boolean unary operations: ! only. (For all others, the operand has
  // already been promoted to int.)
  if (mlir::isa<cir::BoolType>(elementType)) {
    switch (op.getKind()) {
    case cir::UnaryOpKind::Inc:
    case cir::UnaryOpKind::Dec:
    case cir::UnaryOpKind::Plus:
    case cir::UnaryOpKind::Minus:
      // Some of these are allowed in source code, but we shouldn't get here
      // with a boolean type.
      return op.emitError() << "Unsupported unary operation on boolean type";
    case cir::UnaryOpKind::Not: {
      assert(!isVector && "NYI: op! on vector mask");
      auto one = rewriter.create<mlir::LLVM::ConstantOp>(loc, llvmType, 1);
      rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, adaptor.getInput(),
                                                     one);
      return mlir::success();
    }
    }
    llvm_unreachable("Unexpected unary op for bool");
  }

  // Pointer unary operations: + only. (++ and -- of pointers are implemented
  // with cir.ptr_stride, not cir.unary.)
  if (mlir::isa<cir::PointerType>(elementType)) {
    return op.emitError()
           << "Unary operation on pointer types is not yet implemented";
  }

  return op.emitError() << "Unary operation has unsupported type: "
                        << elementType;
}
1513
1514
/// Translate the no-wrap markers of a CIR binary op into LLVM integer overflow
/// flags.
///
/// `nuw` and `nsw` are independent guarantees and LLVM accepts them together
/// (e.g. `add nuw nsw`). When both are set, both are now emitted; the previous
/// version returned only `nuw` and silently dropped `nsw`.
mlir::LLVM::IntegerOverflowFlags
CIRToLLVMBinOpLowering::getIntOverflowFlag(cir::BinOp op) const {
  mlir::LLVM::IntegerOverflowFlags flags =
      mlir::LLVM::IntegerOverflowFlags::none;

  if (op.getNoUnsignedWrap())
    flags = flags | mlir::LLVM::IntegerOverflowFlags::nuw;

  if (op.getNoSignedWrap())
    flags = flags | mlir::LLVM::IntegerOverflowFlags::nsw;

  return flags;
}
1524
1525
/// Return whether \p type is an unsigned integer type.
///
/// Accepts either a CIR integer type or a builtin mlir::IntegerType. Any
/// other type is a programming error and trips the mlir::cast assertion.
static bool isIntTypeUnsigned(mlir::Type type) {
  // TODO: Ideally, we should only need to check cir::IntType here.
  if (auto cirIntTy = mlir::dyn_cast<cir::IntType>(type))
    return cirIntTy.isUnsigned();
  return mlir::cast<mlir::IntegerType>(type).isUnsigned();
}
1531
1532
/// Lower `cir.binop` to the matching LLVM dialect arithmetic or bitwise op.
///
/// The integer-versus-floating-point choice is made on the converted element
/// type. Signedness (for div, rem, max and saturating add/sub) comes from the
/// CIR element type. Both operands must already have the same converted type.
mlir::LogicalResult CIRToLLVMBinOpLowering::matchAndRewrite(
    cir::BinOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  if (adaptor.getLhs().getType() != adaptor.getRhs().getType())
    return op.emitError() << "inconsistent operands' types not supported yet";

  mlir::Type type = op.getRhs().getType();
  if (!mlir::isa<cir::IntType, cir::BoolType, cir::FPTypeInterface,
                 mlir::IntegerType, cir::VectorType>(type))
    return op.emitError() << "operand type not supported yet";

  const mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
  const mlir::Type llvmEltTy = elementTypeIfVector(llvmTy);

  const mlir::Value rhs = adaptor.getRhs();
  const mlir::Value lhs = adaptor.getLhs();
  type = elementTypeIfVector(type);

  switch (op.getKind()) {
  case cir::BinOpKind::Add:
    if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
      if (op.getSaturated()) {
        if (isIntTypeUnsigned(type)) {
          rewriter.replaceOpWithNewOp<mlir::LLVM::UAddSat>(op, lhs, rhs);
          break;
        }
        rewriter.replaceOpWithNewOp<mlir::LLVM::SAddSat>(op, lhs, rhs);
        break;
      }
      rewriter.replaceOpWithNewOp<mlir::LLVM::AddOp>(op, llvmTy, lhs, rhs,
                                                     getIntOverflowFlag(op));
    } else {
      rewriter.replaceOpWithNewOp<mlir::LLVM::FAddOp>(op, lhs, rhs);
    }
    break;
  case cir::BinOpKind::Sub:
    if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
      if (op.getSaturated()) {
        if (isIntTypeUnsigned(type)) {
          rewriter.replaceOpWithNewOp<mlir::LLVM::USubSat>(op, lhs, rhs);
          break;
        }
        rewriter.replaceOpWithNewOp<mlir::LLVM::SSubSat>(op, lhs, rhs);
        break;
      }
      rewriter.replaceOpWithNewOp<mlir::LLVM::SubOp>(op, llvmTy, lhs, rhs,
                                                     getIntOverflowFlag(op));
    } else {
      rewriter.replaceOpWithNewOp<mlir::LLVM::FSubOp>(op, lhs, rhs);
    }
    break;
  case cir::BinOpKind::Mul:
    if (mlir::isa<mlir::IntegerType>(llvmEltTy))
      rewriter.replaceOpWithNewOp<mlir::LLVM::MulOp>(op, llvmTy, lhs, rhs,
                                                     getIntOverflowFlag(op));
    else
      rewriter.replaceOpWithNewOp<mlir::LLVM::FMulOp>(op, lhs, rhs);
    break;
  case cir::BinOpKind::Div:
    if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
      auto isUnsigned = isIntTypeUnsigned(type);
      if (isUnsigned)
        rewriter.replaceOpWithNewOp<mlir::LLVM::UDivOp>(op, lhs, rhs);
      else
        rewriter.replaceOpWithNewOp<mlir::LLVM::SDivOp>(op, lhs, rhs);
    } else {
      rewriter.replaceOpWithNewOp<mlir::LLVM::FDivOp>(op, lhs, rhs);
    }
    break;
  case cir::BinOpKind::Rem:
    if (mlir::isa<mlir::IntegerType>(llvmEltTy)) {
      auto isUnsigned = isIntTypeUnsigned(type);
      if (isUnsigned)
        rewriter.replaceOpWithNewOp<mlir::LLVM::URemOp>(op, lhs, rhs);
      else
        rewriter.replaceOpWithNewOp<mlir::LLVM::SRemOp>(op, lhs, rhs);
    } else {
      rewriter.replaceOpWithNewOp<mlir::LLVM::FRemOp>(op, lhs, rhs);
    }
    break;
  case cir::BinOpKind::And:
    rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, lhs, rhs);
    break;
  case cir::BinOpKind::Or:
    rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, lhs, rhs);
    break;
  case cir::BinOpKind::Xor:
    rewriter.replaceOpWithNewOp<mlir::LLVM::XOrOp>(op, lhs, rhs);
    break;
  case cir::BinOpKind::Max:
    // Only integer max is lowered. Previously a non-integer max fell through
    // and the pattern reported success without replacing the op.
    if (!mlir::isa<mlir::IntegerType>(llvmEltTy))
      return op.emitError() << "max is only supported on integer operands";
    if (isIntTypeUnsigned(type))
      rewriter.replaceOpWithNewOp<mlir::LLVM::UMaxOp>(op, llvmTy, lhs, rhs);
    else
      rewriter.replaceOpWithNewOp<mlir::LLVM::SMaxOp>(op, llvmTy, lhs, rhs);
    break;
  }
  return mlir::LogicalResult::success();
}
1633
1634
/// Convert from a CIR comparison kind to an LLVM IR integral comparison kind.
1635
static mlir::LLVM::ICmpPredicate
1636
convertCmpKindToICmpPredicate(cir::CmpOpKind kind, bool isSigned) {
1637
using CIR = cir::CmpOpKind;
1638
using LLVMICmp = mlir::LLVM::ICmpPredicate;
1639
switch (kind) {
1640
case CIR::eq:
1641
return LLVMICmp::eq;
1642
case CIR::ne:
1643
return LLVMICmp::ne;
1644
case CIR::lt:
1645
return (isSigned ? LLVMICmp::slt : LLVMICmp::ult);
1646
case CIR::le:
1647
return (isSigned ? LLVMICmp::sle : LLVMICmp::ule);
1648
case CIR::gt:
1649
return (isSigned ? LLVMICmp::sgt : LLVMICmp::ugt);
1650
case CIR::ge:
1651
return (isSigned ? LLVMICmp::sge : LLVMICmp::uge);
1652
}
1653
llvm_unreachable("Unknown CmpOpKind");
1654
}
1655
1656
/// Convert from a CIR comparison kind to an LLVM IR floating-point comparison
1657
/// kind.
1658
static mlir::LLVM::FCmpPredicate
1659
convertCmpKindToFCmpPredicate(cir::CmpOpKind kind) {
1660
using CIR = cir::CmpOpKind;
1661
using LLVMFCmp = mlir::LLVM::FCmpPredicate;
1662
switch (kind) {
1663
case CIR::eq:
1664
return LLVMFCmp::oeq;
1665
case CIR::ne:
1666
return LLVMFCmp::une;
1667
case CIR::lt:
1668
return LLVMFCmp::olt;
1669
case CIR::le:
1670
return LLVMFCmp::ole;
1671
case CIR::gt:
1672
return LLVMFCmp::ogt;
1673
case CIR::ge:
1674
return LLVMFCmp::oge;
1675
}
1676
llvm_unreachable("Unknown CmpOpKind");
1677
}
1678
1679
/// Lower cir.cmp to an LLVM dialect comparison.
///
/// - Integer operands (CIR or builtin) and pointer operands become
///   `llvm.icmp`; pointers always use the unsigned predicates.
/// - Floating-point operands become `llvm.fcmp`.
/// - Complex operands are split into real and imaginary parts: `eq` is the
///   `and` of the two per-part equalities, `ne` is the `or` of the two
///   per-part inequalities.
///
/// Any other operand type, or a complex comparison other than eq/ne, is
/// reported as an "unsupported type" error.
mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
    cir::CmpOp cmpOp, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Type type = cmpOp.getLhs().getType();

  assert(!cir::MissingFeatures::dataMemberType());
  assert(!cir::MissingFeatures::methodType());

  mlir::Value lhs = adaptor.getLhs();
  mlir::Value rhs = adaptor.getRhs();

  if (mlir::isa<cir::IntType, mlir::IntegerType, cir::PointerType>(type)) {
    // Pointers use the unsigned predicates. Builtin integers are signed only
    // when their signedness is explicitly Signed.
    bool isSigned = false;
    if (auto cirIntTy = mlir::dyn_cast<cir::IntType>(type))
      isSigned = cirIntTy.isSigned();
    else if (auto builtinIntTy = mlir::dyn_cast<mlir::IntegerType>(type))
      isSigned = builtinIntTy.isSigned();

    rewriter.replaceOpWithNewOp<mlir::LLVM::ICmpOp>(
        cmpOp, convertCmpKindToICmpPredicate(cmpOp.getKind(), isSigned), lhs,
        rhs);
    return mlir::success();
  }

  if (mlir::isa<cir::FPTypeInterface>(type)) {
    rewriter.replaceOpWithNewOp<mlir::LLVM::FCmpOp>(
        cmpOp, convertCmpKindToFCmpPredicate(cmpOp.getKind()), lhs, rhs);
    return mlir::success();
  }

  if (auto complexTy = mlir::dyn_cast<cir::ComplexType>(type)) {
    mlir::Location loc = cmpOp.getLoc();
    mlir::Type elemTy =
        getTypeConverter()->convertType(complexTy.getElementType());

    // A lowered complex value is a two-field struct {real, imag}.
    mlir::Value lhsReal =
        rewriter.create<mlir::LLVM::ExtractValueOp>(loc, elemTy, lhs, 0);
    mlir::Value lhsImag =
        rewriter.create<mlir::LLVM::ExtractValueOp>(loc, elemTy, lhs, 1);
    mlir::Value rhsReal =
        rewriter.create<mlir::LLVM::ExtractValueOp>(loc, elemTy, rhs, 0);
    mlir::Value rhsImag =
        rewriter.create<mlir::LLVM::ExtractValueOp>(loc, elemTy, rhs, 1);

    const cir::CmpOpKind kind = cmpOp.getKind();
    const bool isEq = kind == cir::CmpOpKind::eq;
    if (isEq || kind == cir::CmpOpKind::ne) {
      mlir::Value realCmp;
      mlir::Value imagCmp;
      if (elemTy.isInteger()) {
        const auto pred = isEq ? mlir::LLVM::ICmpPredicate::eq
                               : mlir::LLVM::ICmpPredicate::ne;
        realCmp =
            rewriter.create<mlir::LLVM::ICmpOp>(loc, pred, lhsReal, rhsReal);
        imagCmp =
            rewriter.create<mlir::LLVM::ICmpOp>(loc, pred, lhsImag, rhsImag);
      } else {
        const auto pred = isEq ? mlir::LLVM::FCmpPredicate::oeq
                               : mlir::LLVM::FCmpPredicate::une;
        realCmp =
            rewriter.create<mlir::LLVM::FCmpOp>(loc, pred, lhsReal, rhsReal);
        imagCmp =
            rewriter.create<mlir::LLVM::FCmpOp>(loc, pred, lhsImag, rhsImag);
      }

      // Equal iff both parts are equal; unequal iff either part differs.
      if (isEq)
        rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(cmpOp, realCmp,
                                                       imagCmp);
      else
        rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(cmpOp, realCmp, imagCmp);
      return mlir::success();
    }
  }

  return cmpOp.emitError() << "unsupported type for CmpOp: " << type;
}
1772
1773
/// Lower cir.shift to llvm.shl / llvm.lshr / llvm.ashr.
///
/// Left shifts always become `shl`. Right shifts are logical (`lshr`) when
/// the shifted value's type is unsigned and arithmetic (`ashr`) otherwise;
/// for vector shifts the signedness of the element type decides.
///
/// For scalar shifts the amount is first cast to the width of the shifted
/// value, because LLVM requires both shift operands to have the same type.
mlir::LogicalResult CIRToLLVMShiftOpLowering::matchAndRewrite(
    cir::ShiftOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  assert((op.getValue().getType() == op.getType()) &&
         "inconsistent operands' types NYI");

  const mlir::Type llvmTy = getTypeConverter()->convertType(op.getType());
  mlir::Value amt = adaptor.getAmount();
  mlir::Value val = adaptor.getValue();

  bool isUnsigned = false;
  if (auto cirAmtTy = mlir::dyn_cast<cir::IntType>(op.getAmount().getType())) {
    auto cirValTy = mlir::cast<cir::IntType>(op.getValue().getType());
    isUnsigned = cirValTy.isUnsigned();

    // Ensure the shift amount has the same type as the value. Some undefined
    // behavior might occur in this cast as per [C99 6.5.7.3].
    // NOTE(review): the `true` argument presumably makes getLLVMIntCast treat
    // the amount as unsigned (zero-extension); confirm against its definition.
    amt = getLLVMIntCast(rewriter, amt, llvmTy, true, cirAmtTy.getWidth(),
                         cirValTy.getWidth());
  } else {
    // Vector shift: the amount needs no cast, as type consistency is expected
    // to already be enforced by CIRGen.
    auto cirValVTy = mlir::cast<cir::VectorType>(op.getValue().getType());
    isUnsigned =
        mlir::cast<cir::IntType>(cirValVTy.getElementType()).isUnsigned();
  }

  // Lower to the proper LLVM shift operation.
  if (op.getIsShiftleft()) {
    rewriter.replaceOpWithNewOp<mlir::LLVM::ShlOp>(op, llvmTy, val, amt);
    return mlir::success();
  }

  if (isUnsigned)
    rewriter.replaceOpWithNewOp<mlir::LLVM::LShrOp>(op, llvmTy, val, amt);
  else
    rewriter.replaceOpWithNewOp<mlir::LLVM::AShrOp>(op, llvmTy, val, amt);
  return mlir::success();
}
1814
1815
/// Lower cir.select to llvm.select, folding two boolean special cases:
///
///   select %c, %x, false  =>  and %c, %x
///   select %c, true, %x   =>  or  %c, %x
///
/// The `and` fold is checked first, so `select %c, true, false` becomes
/// `and %c, true`.
///
/// NOTE(review): per the LLVM LangRef, `select` does not propagate poison from
/// its unselected operand while `and`/`or` do. These folds are therefore only
/// equivalent when %x cannot be poison; confirm this holds for CIR producers.
mlir::LogicalResult CIRToLLVMSelectOpLowering::matchAndRewrite(
    cir::SelectOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  // Returns the #cir.bool attribute held by a cir.const that defines `value`,
  // or a null attribute if `value` is not such a constant.
  auto getConstantBool = [](mlir::Value value) -> cir::BoolAttr {
    if (auto constOp = value.getDefiningOp<cir::ConstantOp>())
      return mlir::dyn_cast<cir::BoolAttr>(constOp.getValue());
    return {};
  };

  if (mlir::isa<cir::BoolType>(op.getTrueValue().getType())) {
    // select %c, %x, false  =>  and %c, %x
    cir::BoolAttr falseConst = getConstantBool(op.getFalseValue());
    if (falseConst && !falseConst.getValue()) {
      rewriter.replaceOpWithNewOp<mlir::LLVM::AndOp>(op, adaptor.getCondition(),
                                                     adaptor.getTrueValue());
      return mlir::success();
    }

    // select %c, true, %x  =>  or %c, %x
    cir::BoolAttr trueConst = getConstantBool(op.getTrueValue());
    if (trueConst && trueConst.getValue()) {
      rewriter.replaceOpWithNewOp<mlir::LLVM::OrOp>(op, adaptor.getCondition(),
                                                    adaptor.getFalseValue());
      return mlir::success();
    }
  }

  rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
      op, adaptor.getCondition(), adaptor.getTrueValue(),
      adaptor.getFalseValue());
  return mlir::success();
}
1857
1858
static void prepareTypeConverter(mlir::LLVMTypeConverter &converter,
1859
mlir::DataLayout &dataLayout) {
1860
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
1861
// Drop pointee type since LLVM dialect only allows opaque pointers.
1862
assert(!cir::MissingFeatures::addressSpace());
1863
unsigned targetAS = 0;
1864
1865
return mlir::LLVM::LLVMPointerType::get(type.getContext(), targetAS);
1866
});
1867
converter.addConversion([&](cir::ArrayType type) -> mlir::Type {
1868
mlir::Type ty =
1869
convertTypeForMemory(converter, dataLayout, type.getElementType());
1870
return mlir::LLVM::LLVMArrayType::get(ty, type.getSize());
1871
});
1872
converter.addConversion([&](cir::VectorType type) -> mlir::Type {
1873
const mlir::Type ty = converter.convertType(type.getElementType());
1874
return mlir::VectorType::get(type.getSize(), ty);
1875
});
1876
converter.addConversion([&](cir::BoolType type) -> mlir::Type {
1877
return mlir::IntegerType::get(type.getContext(), 1,
1878
mlir::IntegerType::Signless);
1879
});
1880
converter.addConversion([&](cir::IntType type) -> mlir::Type {
1881
// LLVM doesn't work with signed types, so we drop the CIR signs here.
1882
return mlir::IntegerType::get(type.getContext(), type.getWidth());
1883
});
1884
converter.addConversion([&](cir::SingleType type) -> mlir::Type {
1885
return mlir::Float32Type::get(type.getContext());
1886
});
1887
converter.addConversion([&](cir::DoubleType type) -> mlir::Type {
1888
return mlir::Float64Type::get(type.getContext());
1889
});
1890
converter.addConversion([&](cir::FP80Type type) -> mlir::Type {
1891
return mlir::Float80Type::get(type.getContext());
1892
});
1893
converter.addConversion([&](cir::FP128Type type) -> mlir::Type {
1894
return mlir::Float128Type::get(type.getContext());
1895
});
1896
converter.addConversion([&](cir::LongDoubleType type) -> mlir::Type {
1897
return converter.convertType(type.getUnderlying());
1898
});
1899
converter.addConversion([&](cir::FP16Type type) -> mlir::Type {
1900
return mlir::Float16Type::get(type.getContext());
1901
});
1902
converter.addConversion([&](cir::BF16Type type) -> mlir::Type {
1903
return mlir::BFloat16Type::get(type.getContext());
1904
});
1905
converter.addConversion([&](cir::ComplexType type) -> mlir::Type {
1906
// A complex type is lowered to an LLVM struct that contains the real and
1907
// imaginary part as data fields.
1908
mlir::Type elementTy = converter.convertType(type.getElementType());
1909
mlir::Type structFields[2] = {elementTy, elementTy};
1910
return mlir::LLVM::LLVMStructType::getLiteral(type.getContext(),
1911
structFields);
1912
});
1913
converter.addConversion([&](cir::FuncType type) -> std::optional<mlir::Type> {
1914
auto result = converter.convertType(type.getReturnType());
1915
llvm::SmallVector<mlir::Type> arguments;
1916
arguments.reserve(type.getNumInputs());
1917
if (converter.convertTypes(type.getInputs(), arguments).failed())
1918
return std::nullopt;
1919
auto varArg = type.isVarArg();
1920
return mlir::LLVM::LLVMFunctionType::get(result, arguments, varArg);
1921
});
1922
converter.addConversion([&](cir::RecordType type) -> mlir::Type {
1923
// Convert struct members.
1924
llvm::SmallVector<mlir::Type> llvmMembers;
1925
switch (type.getKind()) {
1926
case cir::RecordType::Class:
1927
case cir::RecordType::Struct:
1928
for (mlir::Type ty : type.getMembers())
1929
llvmMembers.push_back(convertTypeForMemory(converter, dataLayout, ty));
1930
break;
1931
// Unions are lowered as only the largest member.
1932
case cir::RecordType::Union:
1933
if (auto largestMember = type.getLargestMember(dataLayout))
1934
llvmMembers.push_back(
1935
convertTypeForMemory(converter, dataLayout, largestMember));
1936
if (type.getPadded()) {
1937
auto last = *type.getMembers().rbegin();
1938
llvmMembers.push_back(
1939
convertTypeForMemory(converter, dataLayout, last));
1940
}
1941
break;
1942
}
1943
1944
// Record has a name: lower as an identified record.
1945
mlir::LLVM::LLVMStructType llvmStruct;
1946
if (type.getName()) {
1947
llvmStruct = mlir::LLVM::LLVMStructType::getIdentified(
1948
type.getContext(), type.getPrefixedName());
1949
if (llvmStruct.setBody(llvmMembers, type.getPacked()).failed())
1950
llvm_unreachable("Failed to set body of record");
1951
} else { // Record has no name: lower as literal record.
1952
llvmStruct = mlir::LLVM::LLVMStructType::getLiteral(
1953
type.getContext(), llvmMembers, type.getPacked());
1954
}
1955
1956
return llvmStruct;
1957
});
1958
}
1959
1960
// The applyPartialConversion function traverses blocks in the dominance order,
1961
// so it does not lower and operations that are not reachachable from the
1962
// operations passed in as arguments. Since we do need to lower such code in
1963
// order to avoid verification errors occur, we cannot just pass the module op
1964
// to applyPartialConversion. We must build a set of unreachable ops and
1965
// explicitly add them, along with the module, to the vector we pass to
1966
// applyPartialConversion.
1967
//
1968
// For instance, this CIR code:
1969
//
1970
// cir.func @foo(%arg0: !s32i) -> !s32i {
1971
// %4 = cir.cast(int_to_bool, %arg0 : !s32i), !cir.bool
1972
// cir.if %4 {
1973
// %5 = cir.const #cir.int<1> : !s32i
1974
// cir.return %5 : !s32i
1975
// } else {
1976
// %5 = cir.const #cir.int<0> : !s32i
1977
// cir.return %5 : !s32i
1978
// }
1979
// cir.return %arg0 : !s32i
1980
// }
1981
//
1982
// contains an unreachable return operation (the last one). After the flattening
1983
// pass it will be placed into the unreachable block. The possible error
1984
// after the lowering pass is: error: 'cir.return' op expects parent op to be
1985
// one of 'cir.func, cir.scope, cir.if ... The reason that this operation was
1986
// not lowered and the new parent is llvm.func.
1987
//
1988
// In the future we may want to get rid of this function and use a DCE pass or
1989
// something similar. But for now we need to guarantee the absence of the
1990
// dialect verification errors.
1991
static void collectUnreachable(mlir::Operation *parent,
1992
llvm::SmallVector<mlir::Operation *> &ops) {
1993
1994
llvm::SmallVector<mlir::Block *> unreachableBlocks;
1995
parent->walk([&](mlir::Block *blk) { // check
1996
if (blk->hasNoPredecessors() && !blk->isEntryBlock())
1997
unreachableBlocks.push_back(blk);
1998
});
1999
2000
std::set<mlir::Block *> visited;
2001
for (mlir::Block *root : unreachableBlocks) {
2002
// We create a work list for each unreachable block.
2003
// Thus we traverse operations in some order.
2004
std::deque<mlir::Block *> workList;
2005
workList.push_back(root);
2006
2007
while (!workList.empty()) {
2008
mlir::Block *blk = workList.back();
2009
workList.pop_back();
2010
if (visited.count(blk))
2011
continue;
2012
visited.emplace(blk);
2013
2014
for (mlir::Operation &op : *blk)
2015
ops.push_back(&op);
2016
2017
for (mlir::Block *succ : blk->getSuccessors())
2018
workList.push_back(succ);
2019
}
2020
}
2021
}
2022
2023
/// Translate CIR-specific module attributes into their LLVM dialect
/// counterparts. Currently only the target triple is forwarded, and only when
/// the module carries one.
void ConvertCIRToLLVMPass::processCIRAttrs(mlir::ModuleOp module) {
  mlir::Attribute tripleAttr =
      module->getAttr(cir::CIRDialect::getTripleAttrName());
  if (!tripleAttr)
    return;

  module->setAttr(mlir::LLVM::LLVMDialect::getTargetTripleAttrName(),
                  tripleAttr);
}
2030
2031
void ConvertCIRToLLVMPass::runOnOperation() {
2032
llvm::TimeTraceScope scope("Convert CIR to LLVM Pass");
2033
2034
mlir::ModuleOp module = getOperation();
2035
mlir::DataLayout dl(module);
2036
mlir::LLVMTypeConverter converter(&getContext());
2037
prepareTypeConverter(converter, dl);
2038
2039
mlir::RewritePatternSet patterns(&getContext());
2040
2041
patterns.add<CIRToLLVMReturnOpLowering>(patterns.getContext());
2042
// This could currently be merged with the group below, but it will get more
2043
// arguments later, so we'll keep it separate for now.
2044
patterns.add<CIRToLLVMAllocaOpLowering>(converter, patterns.getContext(), dl);
2045
patterns.add<CIRToLLVMLoadOpLowering>(converter, patterns.getContext(), dl);
2046
patterns.add<CIRToLLVMStoreOpLowering>(converter, patterns.getContext(), dl);
2047
patterns.add<CIRToLLVMGlobalOpLowering>(converter, patterns.getContext(), dl);
2048
patterns.add<CIRToLLVMCastOpLowering>(converter, patterns.getContext(), dl);
2049
patterns.add<CIRToLLVMPtrStrideOpLowering>(converter, patterns.getContext(),
2050
dl);
2051
patterns.add<
2052
// clang-format off
2053
CIRToLLVMAssumeOpLowering,
2054
CIRToLLVMBaseClassAddrOpLowering,
2055
CIRToLLVMBinOpLowering,
2056
CIRToLLVMBitClrsbOpLowering,
2057
CIRToLLVMBitClzOpLowering,
2058
CIRToLLVMBitCtzOpLowering,
2059
CIRToLLVMBitParityOpLowering,
2060
CIRToLLVMBitPopcountOpLowering,
2061
CIRToLLVMBitReverseOpLowering,
2062
CIRToLLVMBrCondOpLowering,
2063
CIRToLLVMBrOpLowering,
2064
CIRToLLVMByteSwapOpLowering,
2065
CIRToLLVMCallOpLowering,
2066
CIRToLLVMCmpOpLowering,
2067
CIRToLLVMComplexAddOpLowering,
2068
CIRToLLVMComplexCreateOpLowering,
2069
CIRToLLVMComplexImagOpLowering,
2070
CIRToLLVMComplexImagPtrOpLowering,
2071
CIRToLLVMComplexRealOpLowering,
2072
CIRToLLVMComplexRealPtrOpLowering,
2073
CIRToLLVMComplexSubOpLowering,
2074
CIRToLLVMConstantOpLowering,
2075
CIRToLLVMExpectOpLowering,
2076
CIRToLLVMFuncOpLowering,
2077
CIRToLLVMGetBitfieldOpLowering,
2078
CIRToLLVMGetGlobalOpLowering,
2079
CIRToLLVMGetMemberOpLowering,
2080
CIRToLLVMSelectOpLowering,
2081
CIRToLLVMSetBitfieldOpLowering,
2082
CIRToLLVMShiftOpLowering,
2083
CIRToLLVMStackRestoreOpLowering,
2084
CIRToLLVMStackSaveOpLowering,
2085
CIRToLLVMSwitchFlatOpLowering,
2086
CIRToLLVMTrapOpLowering,
2087
CIRToLLVMUnaryOpLowering,
2088
CIRToLLVMVecCmpOpLowering,
2089
CIRToLLVMVecCreateOpLowering,
2090
CIRToLLVMVecExtractOpLowering,
2091
CIRToLLVMVecInsertOpLowering,
2092
CIRToLLVMVecShuffleDynamicOpLowering,
2093
CIRToLLVMVecShuffleOpLowering,
2094
CIRToLLVMVecSplatOpLowering,
2095
CIRToLLVMVecTernaryOpLowering
2096
// clang-format on
2097
>(converter, patterns.getContext());
2098
2099
processCIRAttrs(module);
2100
2101
mlir::ConversionTarget target(getContext());
2102
target.addLegalOp<mlir::ModuleOp>();
2103
target.addLegalDialect<mlir::LLVM::LLVMDialect>();
2104
target.addIllegalDialect<mlir::BuiltinDialect, cir::CIRDialect,
2105
mlir::func::FuncDialect>();
2106
2107
llvm::SmallVector<mlir::Operation *> ops;
2108
ops.push_back(module);
2109
collectUnreachable(module, ops);
2110
2111
if (failed(applyPartialConversion(ops, target, std::move(patterns))))
2112
signalPassFailure();
2113
}
2114
2115
/// Lower cir.br to llvm.br targeting the same destination block. The converted
/// operands are forwarded as the destination's block arguments.
mlir::LogicalResult CIRToLLVMBrOpLowering::matchAndRewrite(
    cir::BrOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  auto destOperands = adaptor.getOperands();
  mlir::Block *dest = op.getDest();
  rewriter.replaceOpWithNewOp<mlir::LLVM::BrOp>(op, destOperands, dest);
  return mlir::success();
}
2122
2123
/// Lower cir.get_member to address computation on the record pointer.
///
/// For classes and structs this emits a GEP with indices {0, memberIndex}
/// into the converted record type. Union members all start at the union's
/// address, so only the result type is adjusted with a bitcast.
mlir::LogicalResult CIRToLLVMGetMemberOpLowering::matchAndRewrite(
    cir::GetMemberOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Type llResTy = getTypeConverter()->convertType(op.getType());
  // mlir::cast already asserts that the pointee is a record type.
  const auto recordTy =
      mlir::cast<cir::RecordType>(op.getAddrTy().getPointee());

  switch (recordTy.getKind()) {
  case cir::RecordType::Class:
  case cir::RecordType::Struct: {
    // Since the base address is a pointer to an aggregate, the first offset
    // is always zero. The second offset tells us which member it will access.
    llvm::SmallVector<mlir::LLVM::GEPArg, 2> offset{0, op.getIndex()};
    const mlir::Type elementTy = getTypeConverter()->convertType(recordTy);
    rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, llResTy, elementTy,
                                                   adaptor.getAddr(), offset);
    return mlir::success();
  }
  case cir::RecordType::Union:
    // Union members share the address space, so we just need a bitcast to
    // conform to type-checking.
    rewriter.replaceOpWithNewOp<mlir::LLVM::BitcastOp>(op, llResTy,
                                                       adaptor.getAddr());
    return mlir::success();
  }
  // The switch is exhaustive; this keeps GCC's -Wreturn-type quiet and
  // matches the other covered switches in this file.
  llvm_unreachable("Unknown record kind");
}
2150
2151
/// Lower cir.trap to an LLVM trap followed by llvm.unreachable.
///
/// The trap op is not a terminator in the LLVM dialect, so an explicit
/// llvm.unreachable is emitted after it to terminate the current block.
mlir::LogicalResult CIRToLLVMTrapOpLowering::matchAndRewrite(
    cir::TrapOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  const mlir::Location trapLoc = op.getLoc();
  rewriter.eraseOp(op);

  rewriter.create<mlir::LLVM::Trap>(trapLoc);
  rewriter.create<mlir::LLVM::UnreachableOp>(trapLoc);

  return mlir::success();
}
2166
2167
/// Lower cir.stacksave to the LLVM dialect StackSaveOp. The result type is the
/// converted type of the original result.
mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite(
    cir::StackSaveOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Type resultTy = getTypeConverter()->convertType(op.getType());
  rewriter.replaceOpWithNewOp<mlir::LLVM::StackSaveOp>(op, resultTy);
  return mlir::success();
}
2174
2175
/// Lower cir.stackrestore to the LLVM dialect StackRestoreOp, passing the
/// converted saved pointer through.
mlir::LogicalResult CIRToLLVMStackRestoreOpLowering::matchAndRewrite(
    cir::StackRestoreOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Value savedPtr = adaptor.getPtr();
  rewriter.replaceOpWithNewOp<mlir::LLVM::StackRestoreOp>(op, savedPtr);
  return mlir::success();
}
2181
2182
/// Lower cir.vec.create by building the vector one element at a time.
///
/// Emits an llvm.mlir.poison vector, then one insertelement per element
/// operand, indexed by i64 constants 0..N-1.
mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite(
    cir::VecCreateOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  // Start with a 'poison' value for the vector. Then 'insertelement' for
  // each of the vector elements.
  const auto vecTy = mlir::cast<cir::VectorType>(op.getType());
  const mlir::Type llvmTy = typeConverter->convertType(vecTy);
  const mlir::Location loc = op.getLoc();
  mlir::Value result = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);
  assert(vecTy.getSize() == op.getElements().size() &&
         "cir.vec.create op count doesn't match vector type elements count");

  for (uint64_t i = 0; i < vecTy.getSize(); ++i) {
    const mlir::Value indexValue =
        rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
    result = rewriter.create<mlir::LLVM::InsertElementOp>(
        loc, result, adaptor.getElements()[i], indexValue);
  }

  rewriter.replaceOp(op, result);
  return mlir::success();
}
2204
2205
/// Lower cir.vec.extract to llvm.extractelement on the converted vector and
/// index.
mlir::LogicalResult CIRToLLVMVecExtractOpLowering::matchAndRewrite(
    cir::VecExtractOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Value vector = adaptor.getVec();
  mlir::Value position = adaptor.getIndex();
  rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractElementOp>(op, vector,
                                                            position);
  return mlir::success();
}
2212
2213
/// Lower cir.vec.insert to llvm.insertelement: the result is the converted
/// vector with `value` placed at `index`.
mlir::LogicalResult CIRToLLVMVecInsertOpLowering::matchAndRewrite(
    cir::VecInsertOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Value vector = adaptor.getVec();
  mlir::Value element = adaptor.getValue();
  mlir::Value position = adaptor.getIndex();
  rewriter.replaceOpWithNewOp<mlir::LLVM::InsertElementOp>(op, vector, element,
                                                           position);
  return mlir::success();
}
2220
2221
/// Lower cir.vec.cmp to an element-wise icmp/fcmp, then sign-extend the
/// resulting i1 vector to the converted result type.
///
/// Integer elements use the signed or unsigned predicate that matches the CIR
/// element type. Floating-point elements use the predicates chosen by
/// convertCmpKindToFCmpPredicate. Other element types are rejected.
mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite(
    cir::VecCmpOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Type elementType = elementTypeIfVector(op.getLhs().getType());
  if (!mlir::isa<cir::IntType, cir::FPTypeInterface>(elementType))
    return op.emitError() << "unsupported type for VecCmpOp: " << elementType;

  const mlir::Location loc = op.getLoc();
  mlir::Value lhs = adaptor.getLhs();
  mlir::Value rhs = adaptor.getRhs();

  mlir::Value maskBits;
  if (auto intType = mlir::dyn_cast<cir::IntType>(elementType)) {
    mlir::LLVM::ICmpPredicate pred =
        convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned());
    maskBits = rewriter.create<mlir::LLVM::ICmpOp>(loc, pred, lhs, rhs);
  } else {
    mlir::LLVM::FCmpPredicate pred =
        convertCmpKindToFCmpPredicate(op.getKind());
    maskBits = rewriter.create<mlir::LLVM::FCmpOp>(loc, pred, lhs, rhs);
  }

  // The comparison yields <N x i1>; widen each lane to the result element
  // type with sign extension so that "true" becomes all ones.
  mlir::Type resultTy = typeConverter->convertType(op.getType());
  rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(op, resultTy, maskBits);
  return mlir::success();
}
2245
2246
/// Lower cir.vec.splat, broadcasting a scalar into every vector lane.
///
/// - A poison scalar yields a poison vector.
/// - An integer or floating-point llvm.mlir.constant scalar yields a dense
///   splat constant.
/// - Any other scalar is inserted into lane 0 of a poison vector and
///   broadcast with an all-zero shufflevector mask.
///
/// The scalar may be a block argument (no defining op), so only null-safe
/// defining-op queries are used.
mlir::LogicalResult CIRToLLVMVecSplatOpLowering::matchAndRewrite(
    cir::VecSplatOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  // Vector splat can be implemented with an `insertelement` and a
  // `shufflevector`, which is better than an `insertelement` for each
  // element in the vector. Start with a poison vector. Insert the value into
  // the first element. Then use a `shufflevector` with a mask of all 0 to
  // fill out the entire vector with that value.
  cir::VectorType vecTy = op.getType();
  mlir::Type llvmTy = typeConverter->convertType(vecTy);
  mlir::Location loc = op.getLoc();
  mlir::Value poison = rewriter.create<mlir::LLVM::PoisonOp>(loc, llvmTy);

  mlir::Value elementValue = adaptor.getValue();
  // Value::getDefiningOp<OpTy>() returns null for block arguments instead of
  // asserting the way isa/dyn_cast on a null Operation* would.
  if (elementValue.getDefiningOp<mlir::LLVM::PoisonOp>()) {
    // If the splat value is poison, then we can just use poison value
    // for the entire vector.
    rewriter.replaceOp(op, poison);
    return mlir::success();
  }

  if (auto constValue = elementValue.getDefiningOp<mlir::LLVM::ConstantOp>()) {
    if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(constValue.getValue())) {
      mlir::DenseIntElementsAttr denseVec = mlir::DenseIntElementsAttr::get(
          mlir::cast<mlir::ShapedType>(llvmTy), intAttr.getValue());
      rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
          op, denseVec.getType(), denseVec);
      return mlir::success();
    }

    if (auto fpAttr = mlir::dyn_cast<mlir::FloatAttr>(constValue.getValue())) {
      mlir::DenseFPElementsAttr denseVec = mlir::DenseFPElementsAttr::get(
          mlir::cast<mlir::ShapedType>(llvmTy), fpAttr.getValue());
      rewriter.replaceOpWithNewOp<mlir::LLVM::ConstantOp>(
          op, denseVec.getType(), denseVec);
      return mlir::success();
    }
  }

  mlir::Value indexValue =
      rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), 0);
  mlir::Value oneElement = rewriter.create<mlir::LLVM::InsertElementOp>(
      loc, poison, elementValue, indexValue);
  SmallVector<int32_t> zeroValues(vecTy.getSize(), 0);
  rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>(op, oneElement,
                                                           poison, zeroValues);
  return mlir::success();
}
2295
2296
/// Lower cir.vec.shuffle (constant mask form) to llvm.shufflevector.
///
/// The CIR mask is an ArrayAttr of #cir.int constants. LLVM's
/// ShuffleVectorOp takes a plain list of ints, so each attribute is unpacked
/// via its sign-extended value.
mlir::LogicalResult CIRToLLVMVecShuffleOpLowering::matchAndRewrite(
    cir::VecShuffleOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  SmallVector<int, 8> mask;
  mask.reserve(op.getIndices().size());
  for (mlir::Attribute laneAttr : op.getIndices())
    mask.push_back(
        mlir::cast<cir::IntAttr>(laneAttr).getValue().getSExtValue());

  rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>(
      op, adaptor.getVec1(), adaptor.getVec2(), mask);
  return mlir::success();
}
2312
2313
/// Lower the runtime-mask form of cir.vec.shuffle.
///
/// LLVM IR has no shuffle whose mask is a runtime value, so the operation is
/// unrolled into one extractelement/insertelement pair per result lane, after
/// masking the indices down to the low bits.
mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite(
    cir::VecShuffleDynamicOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  // LLVM IR does not have an operation that corresponds to this form of
  // the built-in.
  //     __builtin_shufflevector(V, I)
  // is implemented as this pseudocode, where the for loop is unrolled
  // and N is the number of elements:
  //
  // result = undef
  // maskbits = NextPowerOf2(N - 1) - 1
  // masked = I & maskbits
  // for (i in 0 <= i < N)
  //    result[i] = V[masked[i]]
  //
  // When N is not a power of two, a masked index can still be >= N (e.g.
  // N = 3 gives maskbits = 3). LLVM's extractelement yields poison for such
  // out-of-range indices.
  mlir::Location loc = op.getLoc();
  mlir::Value input = adaptor.getVec();
  mlir::Type llvmIndexVecType =
      getTypeConverter()->convertType(op.getIndices().getType());
  mlir::Type llvmIndexType = getTypeConverter()->convertType(
      elementTypeIfVector(op.getIndices().getType()));
  uint64_t numElements =
      mlir::cast<cir::VectorType>(op.getVec().getType()).getSize();

  uint64_t maskBits = llvm::NextPowerOf2(numElements - 1) - 1;
  mlir::Value maskValue = rewriter.create<mlir::LLVM::ConstantOp>(
      loc, llvmIndexType, rewriter.getIntegerAttr(llvmIndexType, maskBits));
  // Broadcast maskValue into every lane of an index-typed vector.
  mlir::Value maskVector =
      rewriter.create<mlir::LLVM::UndefOp>(loc, llvmIndexVecType);

  for (uint64_t i = 0; i < numElements; ++i) {
    mlir::Value idxValue =
        rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
    maskVector = rewriter.create<mlir::LLVM::InsertElementOp>(
        loc, maskVector, maskValue, idxValue);
  }

  mlir::Value maskedIndices = rewriter.create<mlir::LLVM::AndOp>(
      loc, llvmIndexVecType, adaptor.getIndices(), maskVector);
  mlir::Value result = rewriter.create<mlir::LLVM::UndefOp>(
      loc, getTypeConverter()->convertType(op.getVec().getType()));
  for (uint64_t i = 0; i < numElements; ++i) {
    mlir::Value iValue =
        rewriter.create<mlir::LLVM::ConstantOp>(loc, rewriter.getI64Type(), i);
    mlir::Value indexValue = rewriter.create<mlir::LLVM::ExtractElementOp>(
        loc, maskedIndices, iValue);
    mlir::Value valueAtIndex =
        rewriter.create<mlir::LLVM::ExtractElementOp>(loc, input, indexValue);
    result = rewriter.create<mlir::LLVM::InsertElementOp>(loc, result,
                                                          valueAtIndex, iValue);
  }
  rewriter.replaceOp(op, result);
  return mlir::success();
}
2366
2367
/// Lower cir.vec.ternary to an element-wise llvm.select.
///
/// The condition vector is turned into an i1 vector by comparing each lane
/// against zero (`icmp ne`). That i1 vector then selects between the lhs and
/// rhs lanes.
mlir::LogicalResult CIRToLLVMVecTernaryOpLowering::matchAndRewrite(
    cir::VecTernaryOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  mlir::Type condTy = typeConverter->convertType(op.getCond().getType());
  mlir::Value zeroVec =
      rewriter.create<mlir::LLVM::ZeroOp>(op.getCond().getLoc(), condTy);
  mlir::Value laneMask = rewriter.create<mlir::LLVM::ICmpOp>(
      op.getLoc(), mlir::LLVM::ICmpPredicate::ne, adaptor.getCond(), zeroVec);

  rewriter.replaceOpWithNewOp<mlir::LLVM::SelectOp>(
      op, laneMask, adaptor.getLhs(), adaptor.getRhs());
  return mlir::success();
}
2380
2381
mlir::LogicalResult CIRToLLVMComplexAddOpLowering::matchAndRewrite(
2382
cir::ComplexAddOp op, OpAdaptor adaptor,
2383
mlir::ConversionPatternRewriter &rewriter) const {
2384
mlir::Value lhs = adaptor.getLhs();
2385
mlir::Value rhs = adaptor.getRhs();
2386
mlir::Location loc = op.getLoc();
2387
2388
auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType());
2389
mlir::Type complexElemTy =
2390
getTypeConverter()->convertType(complexType.getElementType());
2391
auto lhsReal =
2392
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
2393
auto lhsImag =
2394
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
2395
auto rhsReal =
2396
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
2397
auto rhsImag =
2398
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
2399
2400
mlir::Value newReal;
2401
mlir::Value newImag;
2402
if (complexElemTy.isInteger()) {
2403
newReal = rewriter.create<mlir::LLVM::AddOp>(loc, complexElemTy, lhsReal,
2404
rhsReal);
2405
newImag = rewriter.create<mlir::LLVM::AddOp>(loc, complexElemTy, lhsImag,
2406
rhsImag);
2407
} else {
2408
assert(!cir::MissingFeatures::fastMathFlags());
2409
assert(!cir::MissingFeatures::fpConstraints());
2410
newReal = rewriter.create<mlir::LLVM::FAddOp>(loc, complexElemTy, lhsReal,
2411
rhsReal);
2412
newImag = rewriter.create<mlir::LLVM::FAddOp>(loc, complexElemTy, lhsImag,
2413
rhsImag);
2414
}
2415
2416
mlir::Type complexLLVMTy =
2417
getTypeConverter()->convertType(op.getResult().getType());
2418
auto initialComplex =
2419
rewriter.create<mlir::LLVM::PoisonOp>(op->getLoc(), complexLLVMTy);
2420
2421
auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>(
2422
op->getLoc(), initialComplex, newReal, 0);
2423
2424
rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(op, realComplex,
2425
newImag, 1);
2426
2427
return mlir::success();
2428
}
2429
2430
mlir::LogicalResult CIRToLLVMComplexCreateOpLowering::matchAndRewrite(
2431
cir::ComplexCreateOp op, OpAdaptor adaptor,
2432
mlir::ConversionPatternRewriter &rewriter) const {
2433
mlir::Type complexLLVMTy =
2434
getTypeConverter()->convertType(op.getResult().getType());
2435
auto initialComplex =
2436
rewriter.create<mlir::LLVM::UndefOp>(op->getLoc(), complexLLVMTy);
2437
2438
auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>(
2439
op->getLoc(), initialComplex, adaptor.getReal(), 0);
2440
2441
auto complex = rewriter.create<mlir::LLVM::InsertValueOp>(
2442
op->getLoc(), realComplex, adaptor.getImag(), 1);
2443
2444
rewriter.replaceOp(op, complex);
2445
return mlir::success();
2446
}
2447
2448
mlir::LogicalResult CIRToLLVMComplexRealOpLowering::matchAndRewrite(
2449
cir::ComplexRealOp op, OpAdaptor adaptor,
2450
mlir::ConversionPatternRewriter &rewriter) const {
2451
mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
2452
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
2453
op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef<std::int64_t>{0});
2454
return mlir::success();
2455
}
2456
2457
mlir::LogicalResult CIRToLLVMComplexSubOpLowering::matchAndRewrite(
2458
cir::ComplexSubOp op, OpAdaptor adaptor,
2459
mlir::ConversionPatternRewriter &rewriter) const {
2460
mlir::Value lhs = adaptor.getLhs();
2461
mlir::Value rhs = adaptor.getRhs();
2462
mlir::Location loc = op.getLoc();
2463
2464
auto complexType = mlir::cast<cir::ComplexType>(op.getLhs().getType());
2465
mlir::Type complexElemTy =
2466
getTypeConverter()->convertType(complexType.getElementType());
2467
auto lhsReal =
2468
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 0);
2469
auto lhsImag =
2470
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, lhs, 1);
2471
auto rhsReal =
2472
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 0);
2473
auto rhsImag =
2474
rewriter.create<mlir::LLVM::ExtractValueOp>(loc, complexElemTy, rhs, 1);
2475
2476
mlir::Value newReal;
2477
mlir::Value newImag;
2478
if (complexElemTy.isInteger()) {
2479
newReal = rewriter.create<mlir::LLVM::SubOp>(loc, complexElemTy, lhsReal,
2480
rhsReal);
2481
newImag = rewriter.create<mlir::LLVM::SubOp>(loc, complexElemTy, lhsImag,
2482
rhsImag);
2483
} else {
2484
assert(!cir::MissingFeatures::fastMathFlags());
2485
assert(!cir::MissingFeatures::fpConstraints());
2486
newReal = rewriter.create<mlir::LLVM::FSubOp>(loc, complexElemTy, lhsReal,
2487
rhsReal);
2488
newImag = rewriter.create<mlir::LLVM::FSubOp>(loc, complexElemTy, lhsImag,
2489
rhsImag);
2490
}
2491
2492
mlir::Type complexLLVMTy =
2493
getTypeConverter()->convertType(op.getResult().getType());
2494
auto initialComplex =
2495
rewriter.create<mlir::LLVM::PoisonOp>(op->getLoc(), complexLLVMTy);
2496
2497
auto realComplex = rewriter.create<mlir::LLVM::InsertValueOp>(
2498
op->getLoc(), initialComplex, newReal, 0);
2499
2500
rewriter.replaceOpWithNewOp<mlir::LLVM::InsertValueOp>(op, realComplex,
2501
newImag, 1);
2502
2503
return mlir::success();
2504
}
2505
2506
mlir::LogicalResult CIRToLLVMComplexImagOpLowering::matchAndRewrite(
2507
cir::ComplexImagOp op, OpAdaptor adaptor,
2508
mlir::ConversionPatternRewriter &rewriter) const {
2509
mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
2510
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractValueOp>(
2511
op, resultLLVMTy, adaptor.getOperand(), llvm::ArrayRef<std::int64_t>{1});
2512
return mlir::success();
2513
}
2514
2515
mlir::IntegerType computeBitfieldIntType(mlir::Type storageType,
2516
mlir::MLIRContext *context,
2517
unsigned &storageSize) {
2518
return TypeSwitch<mlir::Type, mlir::IntegerType>(storageType)
2519
.Case<cir::ArrayType>([&](cir::ArrayType atTy) {
2520
storageSize = atTy.getSize() * 8;
2521
return mlir::IntegerType::get(context, storageSize);
2522
})
2523
.Case<cir::IntType>([&](cir::IntType intTy) {
2524
storageSize = intTy.getWidth();
2525
return mlir::IntegerType::get(context, storageSize);
2526
})
2527
.Default([](mlir::Type) -> mlir::IntegerType {
2528
llvm_unreachable(
2529
"Either ArrayType or IntType expected for bitfields storage");
2530
});
2531
}
2532
2533
mlir::LogicalResult CIRToLLVMSetBitfieldOpLowering::matchAndRewrite(
2534
cir::SetBitfieldOp op, OpAdaptor adaptor,
2535
mlir::ConversionPatternRewriter &rewriter) const {
2536
mlir::OpBuilder::InsertionGuard guard(rewriter);
2537
rewriter.setInsertionPoint(op);
2538
2539
cir::BitfieldInfoAttr info = op.getBitfieldInfo();
2540
uint64_t size = info.getSize();
2541
uint64_t offset = info.getOffset();
2542
mlir::Type storageType = info.getStorageType();
2543
mlir::MLIRContext *context = storageType.getContext();
2544
2545
unsigned storageSize = 0;
2546
2547
mlir::IntegerType intType =
2548
computeBitfieldIntType(storageType, context, storageSize);
2549
2550
mlir::Value srcVal = createIntCast(rewriter, adaptor.getSrc(), intType);
2551
unsigned srcWidth = storageSize;
2552
mlir::Value resultVal = srcVal;
2553
2554
if (storageSize != size) {
2555
assert(storageSize > size && "Invalid bitfield size.");
2556
2557
mlir::Value val = rewriter.create<mlir::LLVM::LoadOp>(
2558
op.getLoc(), intType, adaptor.getAddr(), /* alignment */ 0,
2559
op.getIsVolatile());
2560
2561
srcVal =
2562
createAnd(rewriter, srcVal, llvm::APInt::getLowBitsSet(srcWidth, size));
2563
resultVal = srcVal;
2564
srcVal = createShL(rewriter, srcVal, offset);
2565
2566
// Mask out the original value.
2567
val = createAnd(rewriter, val,
2568
~llvm::APInt::getBitsSet(srcWidth, offset, offset + size));
2569
2570
// Or together the unchanged values and the source value.
2571
srcVal = rewriter.create<mlir::LLVM::OrOp>(op.getLoc(), val, srcVal);
2572
}
2573
2574
rewriter.create<mlir::LLVM::StoreOp>(op.getLoc(), srcVal, adaptor.getAddr(),
2575
/* alignment */ 0, op.getIsVolatile());
2576
2577
mlir::Type resultTy = getTypeConverter()->convertType(op.getType());
2578
2579
if (info.getIsSigned()) {
2580
assert(size <= storageSize);
2581
unsigned highBits = storageSize - size;
2582
2583
if (highBits) {
2584
resultVal = createShL(rewriter, resultVal, highBits);
2585
resultVal = createAShR(rewriter, resultVal, highBits);
2586
}
2587
}
2588
2589
resultVal = createIntCast(rewriter, resultVal,
2590
mlir::cast<mlir::IntegerType>(resultTy),
2591
info.getIsSigned());
2592
2593
rewriter.replaceOp(op, resultVal);
2594
return mlir::success();
2595
}
2596
2597
mlir::LogicalResult CIRToLLVMComplexImagPtrOpLowering::matchAndRewrite(
2598
cir::ComplexImagPtrOp op, OpAdaptor adaptor,
2599
mlir::ConversionPatternRewriter &rewriter) const {
2600
cir::PointerType operandTy = op.getOperand().getType();
2601
mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
2602
mlir::Type elementLLVMTy =
2603
getTypeConverter()->convertType(operandTy.getPointee());
2604
2605
mlir::LLVM::GEPArg gepIndices[2] = {{0}, {1}};
2606
mlir::LLVM::GEPNoWrapFlags inboundsNuw =
2607
mlir::LLVM::GEPNoWrapFlags::inbounds | mlir::LLVM::GEPNoWrapFlags::nuw;
2608
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
2609
op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices,
2610
inboundsNuw);
2611
return mlir::success();
2612
}
2613
2614
mlir::LogicalResult CIRToLLVMComplexRealPtrOpLowering::matchAndRewrite(
2615
cir::ComplexRealPtrOp op, OpAdaptor adaptor,
2616
mlir::ConversionPatternRewriter &rewriter) const {
2617
cir::PointerType operandTy = op.getOperand().getType();
2618
mlir::Type resultLLVMTy = getTypeConverter()->convertType(op.getType());
2619
mlir::Type elementLLVMTy =
2620
getTypeConverter()->convertType(operandTy.getPointee());
2621
2622
mlir::LLVM::GEPArg gepIndices[2] = {0, 0};
2623
mlir::LLVM::GEPNoWrapFlags inboundsNuw =
2624
mlir::LLVM::GEPNoWrapFlags::inbounds | mlir::LLVM::GEPNoWrapFlags::nuw;
2625
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
2626
op, resultLLVMTy, elementLLVMTy, adaptor.getOperand(), gepIndices,
2627
inboundsNuw);
2628
return mlir::success();
2629
}
2630
2631
mlir::LogicalResult CIRToLLVMGetBitfieldOpLowering::matchAndRewrite(
2632
cir::GetBitfieldOp op, OpAdaptor adaptor,
2633
mlir::ConversionPatternRewriter &rewriter) const {
2634
2635
mlir::OpBuilder::InsertionGuard guard(rewriter);
2636
rewriter.setInsertionPoint(op);
2637
2638
cir::BitfieldInfoAttr info = op.getBitfieldInfo();
2639
uint64_t size = info.getSize();
2640
uint64_t offset = info.getOffset();
2641
mlir::Type storageType = info.getStorageType();
2642
mlir::MLIRContext *context = storageType.getContext();
2643
unsigned storageSize = 0;
2644
2645
mlir::IntegerType intType =
2646
computeBitfieldIntType(storageType, context, storageSize);
2647
2648
mlir::Value val = rewriter.create<mlir::LLVM::LoadOp>(
2649
op.getLoc(), intType, adaptor.getAddr(), 0, op.getIsVolatile());
2650
val = rewriter.create<mlir::LLVM::BitcastOp>(op.getLoc(), intType, val);
2651
2652
if (info.getIsSigned()) {
2653
assert(static_cast<unsigned>(offset + size) <= storageSize);
2654
unsigned highBits = storageSize - offset - size;
2655
val = createShL(rewriter, val, highBits);
2656
val = createAShR(rewriter, val, offset + highBits);
2657
} else {
2658
val = createLShR(rewriter, val, offset);
2659
2660
if (static_cast<unsigned>(offset) + size < storageSize)
2661
val = createAnd(rewriter, val,
2662
llvm::APInt::getLowBitsSet(storageSize, size));
2663
}
2664
2665
mlir::Type resTy = getTypeConverter()->convertType(op.getType());
2666
mlir::Value newOp = createIntCast(
2667
rewriter, val, mlir::cast<mlir::IntegerType>(resTy), info.getIsSigned());
2668
rewriter.replaceOp(op, newOp);
2669
return mlir::success();
2670
}
2671
2672
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
2673
return std::make_unique<ConvertCIRToLLVMPass>();
2674
}
2675
2676
void populateCIRToLLVMPasses(mlir::OpPassManager &pm) {
2677
mlir::populateCIRPreLoweringPasses(pm);
2678
pm.addPass(createConvertCIRToLLVMPass());
2679
}
2680
2681
std::unique_ptr<llvm::Module>
2682
lowerDirectlyFromCIRToLLVMIR(mlir::ModuleOp mlirModule, LLVMContext &llvmCtx) {
2683
llvm::TimeTraceScope scope("lower from CIR to LLVM directly");
2684
2685
mlir::MLIRContext *mlirCtx = mlirModule.getContext();
2686
2687
mlir::PassManager pm(mlirCtx);
2688
populateCIRToLLVMPasses(pm);
2689
2690
(void)mlir::applyPassManagerCLOptions(pm);
2691
2692
if (mlir::failed(pm.run(mlirModule))) {
2693
// FIXME: Handle any errors where they occurs and return a nullptr here.
2694
report_fatal_error(
2695
"The pass manager failed to lower CIR to LLVMIR dialect!");
2696
}
2697
2698
mlir::registerBuiltinDialectTranslation(*mlirCtx);
2699
mlir::registerLLVMDialectTranslation(*mlirCtx);
2700
mlir::registerCIRDialectTranslation(*mlirCtx);
2701
2702
llvm::TimeTraceScope translateScope("translateModuleToLLVMIR");
2703
2704
StringRef moduleName = mlirModule.getName().value_or("CIRToLLVMModule");
2705
std::unique_ptr<llvm::Module> llvmModule =
2706
mlir::translateModuleToLLVMIR(mlirModule, llvmCtx, moduleName);
2707
2708
if (!llvmModule) {
2709
// FIXME: Handle any errors where they occurs and return a nullptr here.
2710
report_fatal_error("Lowering from LLVMIR dialect to llvm IR failed!");
2711
}
2712
2713
return llvmModule;
2714
}
2715
} // namespace direct
2716
} // namespace cir
2717
2718