Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
35267 views
1
//===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
///
9
/// \file
10
/// This file contains the IR transform to lower external or indirect calls for
11
/// the ARM64EC calling convention. Such calls must go through the runtime, so
12
/// we can translate the calling convention for calls into the emulator.
13
///
14
/// This subsumes Control Flow Guard handling.
15
///
16
//===----------------------------------------------------------------------===//
17
18
#include "AArch64.h"
19
#include "llvm/ADT/SetVector.h"
20
#include "llvm/ADT/SmallString.h"
21
#include "llvm/ADT/SmallVector.h"
22
#include "llvm/ADT/Statistic.h"
23
#include "llvm/IR/CallingConv.h"
24
#include "llvm/IR/GlobalAlias.h"
25
#include "llvm/IR/IRBuilder.h"
26
#include "llvm/IR/Instruction.h"
27
#include "llvm/IR/Mangler.h"
28
#include "llvm/IR/Module.h"
29
#include "llvm/InitializePasses.h"
30
#include "llvm/Object/COFF.h"
31
#include "llvm/Pass.h"
32
#include "llvm/Support/CommandLine.h"
33
#include "llvm/TargetParser/Triple.h"
34
35
using namespace llvm;
36
using namespace llvm::COFF;
37
38
using OperandBundleDef = OperandBundleDefT<Value *>;
39
40
#define DEBUG_TYPE "arm64eccalllowering"
41
42
STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");
43
44
// Hidden boolean command-line option, enabled by default. Its consumer is not
// in this part of the file.
// NOTE(review): presumably controls whether direct calls are also routed
// through the indirect-call lowering -- confirm against processFunction.
static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",
45
cl::Hidden, cl::init(true));
46
// Hidden boolean command-line option, enabled by default. When it is disabled,
// runOnModule returns false immediately without modifying the module.
static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,
47
cl::init(true));
48
49
namespace {
50
51
enum ThunkArgTranslation : uint8_t {
52
Direct,
53
Bitcast,
54
PointerIndirection,
55
};
56
57
struct ThunkArgInfo {
58
Type *Arm64Ty;
59
Type *X64Ty;
60
ThunkArgTranslation Translation;
61
};
62
63
class AArch64Arm64ECCallLowering : public ModulePass {
64
public:
65
static char ID;
66
AArch64Arm64ECCallLowering() : ModulePass(ID) {
67
initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry());
68
}
69
70
Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);
71
Function *buildEntryThunk(Function *F);
72
void lowerCall(CallBase *CB);
73
Function *buildGuestExitThunk(Function *F);
74
Function *buildPatchableThunk(GlobalAlias *UnmangledAlias,
75
GlobalAlias *MangledAlias);
76
bool processFunction(Function &F, SetVector<GlobalValue *> &DirectCalledFns,
77
DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap);
78
bool runOnModule(Module &M) override;
79
80
private:
81
int cfguard_module_flag = 0;
82
FunctionType *GuardFnType = nullptr;
83
PointerType *GuardFnPtrType = nullptr;
84
FunctionType *DispatchFnType = nullptr;
85
PointerType *DispatchFnPtrType = nullptr;
86
Constant *GuardFnCFGlobal = nullptr;
87
Constant *GuardFnGlobal = nullptr;
88
Constant *DispatchFnGlobal = nullptr;
89
Module *M = nullptr;
90
91
Type *PtrTy;
92
Type *I64Ty;
93
Type *VoidTy;
94
95
void getThunkType(FunctionType *FT, AttributeList AttrList,
96
Arm64ECThunkType TT, raw_ostream &Out,
97
FunctionType *&Arm64Ty, FunctionType *&X64Ty,
98
SmallVector<ThunkArgTranslation> &ArgTranslations);
99
void getThunkRetType(FunctionType *FT, AttributeList AttrList,
100
raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,
101
SmallVectorImpl<Type *> &Arm64ArgTypes,
102
SmallVectorImpl<Type *> &X64ArgTypes,
103
SmallVector<ThunkArgTranslation> &ArgTranslations,
104
bool &HasSretPtr);
105
void getThunkArgTypes(FunctionType *FT, AttributeList AttrList,
106
Arm64ECThunkType TT, raw_ostream &Out,
107
SmallVectorImpl<Type *> &Arm64ArgTypes,
108
SmallVectorImpl<Type *> &X64ArgTypes,
109
SmallVectorImpl<ThunkArgTranslation> &ArgTranslations,
110
bool HasSretPtr);
111
ThunkArgInfo canonicalizeThunkType(Type *T, Align Alignment, bool Ret,
112
uint64_t ArgSizeBytes, raw_ostream &Out);
113
};
114
115
} // end anonymous namespace
116
117
void AArch64Arm64ECCallLowering::getThunkType(
118
FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
119
raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty,
120
SmallVector<ThunkArgTranslation> &ArgTranslations) {
121
Out << (TT == Arm64ECThunkType::Entry ? "$ientry_thunk$cdecl$"
122
: "$iexit_thunk$cdecl$");
123
124
Type *Arm64RetTy;
125
Type *X64RetTy;
126
127
SmallVector<Type *> Arm64ArgTypes;
128
SmallVector<Type *> X64ArgTypes;
129
130
// The first argument to a thunk is the called function, stored in x9.
131
// For exit thunks, we pass the called function down to the emulator;
132
// for entry/guest exit thunks, we just call the Arm64 function directly.
133
if (TT == Arm64ECThunkType::Exit)
134
Arm64ArgTypes.push_back(PtrTy);
135
X64ArgTypes.push_back(PtrTy);
136
137
bool HasSretPtr = false;
138
getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,
139
X64ArgTypes, ArgTranslations, HasSretPtr);
140
141
getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,
142
ArgTranslations, HasSretPtr);
143
144
Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);
145
146
X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);
147
}
148
149
void AArch64Arm64ECCallLowering::getThunkArgTypes(
150
FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,
151
raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes,
152
SmallVectorImpl<Type *> &X64ArgTypes,
153
SmallVectorImpl<ThunkArgTranslation> &ArgTranslations, bool HasSretPtr) {
154
155
Out << "$";
156
if (FT->isVarArg()) {
157
// We treat the variadic function's thunk as a normal function
158
// with the following type on the ARM side:
159
// rettype exitthunk(
160
// ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)
161
//
162
// that can coverage all types of variadic function.
163
// x9 is similar to normal exit thunk, store the called function.
164
// x0-x3 is the arguments be stored in registers.
165
// x4 is the address of the arguments on the stack.
166
// x5 is the size of the arguments on the stack.
167
//
168
// On the x64 side, it's the same except that x5 isn't set.
169
//
170
// If both the ARM and X64 sides are sret, there are only three
171
// arguments in registers.
172
//
173
// If the X64 side is sret, but the ARM side isn't, we pass an extra value
174
// to/from the X64 side, and let SelectionDAG transform it into a memory
175
// location.
176
Out << "varargs";
177
178
// x0-x3
179
for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {
180
Arm64ArgTypes.push_back(I64Ty);
181
X64ArgTypes.push_back(I64Ty);
182
ArgTranslations.push_back(ThunkArgTranslation::Direct);
183
}
184
185
// x4
186
Arm64ArgTypes.push_back(PtrTy);
187
X64ArgTypes.push_back(PtrTy);
188
ArgTranslations.push_back(ThunkArgTranslation::Direct);
189
// x5
190
Arm64ArgTypes.push_back(I64Ty);
191
if (TT != Arm64ECThunkType::Entry) {
192
// FIXME: x5 isn't actually used by the x64 side; revisit once we
193
// have proper isel for varargs
194
X64ArgTypes.push_back(I64Ty);
195
ArgTranslations.push_back(ThunkArgTranslation::Direct);
196
}
197
return;
198
}
199
200
unsigned I = 0;
201
if (HasSretPtr)
202
I++;
203
204
if (I == FT->getNumParams()) {
205
Out << "v";
206
return;
207
}
208
209
for (unsigned E = FT->getNumParams(); I != E; ++I) {
210
#if 0
211
// FIXME: Need more information about argument size; see
212
// https://reviews.llvm.org/D132926
213
uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);
214
Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();
215
#else
216
uint64_t ArgSizeBytes = 0;
217
Align ParamAlign = Align();
218
#endif
219
auto [Arm64Ty, X64Ty, ArgTranslation] =
220
canonicalizeThunkType(FT->getParamType(I), ParamAlign,
221
/*Ret*/ false, ArgSizeBytes, Out);
222
Arm64ArgTypes.push_back(Arm64Ty);
223
X64ArgTypes.push_back(X64Ty);
224
ArgTranslations.push_back(ArgTranslation);
225
}
226
}
227
228
void AArch64Arm64ECCallLowering::getThunkRetType(
229
FunctionType *FT, AttributeList AttrList, raw_ostream &Out,
230
Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,
231
SmallVectorImpl<Type *> &X64ArgTypes,
232
SmallVector<ThunkArgTranslation> &ArgTranslations, bool &HasSretPtr) {
233
Type *T = FT->getReturnType();
234
#if 0
235
// FIXME: Need more information about argument size; see
236
// https://reviews.llvm.org/D132926
237
uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();
238
#else
239
int64_t ArgSizeBytes = 0;
240
#endif
241
if (T->isVoidTy()) {
242
if (FT->getNumParams()) {
243
Attribute SRetAttr0 = AttrList.getParamAttr(0, Attribute::StructRet);
244
Attribute InRegAttr0 = AttrList.getParamAttr(0, Attribute::InReg);
245
Attribute SRetAttr1, InRegAttr1;
246
if (FT->getNumParams() > 1) {
247
// Also check the second parameter (for class methods, the first
248
// parameter is "this", and the second parameter is the sret pointer.)
249
// It doesn't matter which one is sret.
250
SRetAttr1 = AttrList.getParamAttr(1, Attribute::StructRet);
251
InRegAttr1 = AttrList.getParamAttr(1, Attribute::InReg);
252
}
253
if ((SRetAttr0.isValid() && InRegAttr0.isValid()) ||
254
(SRetAttr1.isValid() && InRegAttr1.isValid())) {
255
// sret+inreg indicates a call that returns a C++ class value. This is
256
// actually equivalent to just passing and returning a void* pointer
257
// as the first or second argument. Translate it that way, instead of
258
// trying to model "inreg" in the thunk's calling convention; this
259
// simplfies the rest of the code, and matches MSVC mangling.
260
Out << "i8";
261
Arm64RetTy = I64Ty;
262
X64RetTy = I64Ty;
263
return;
264
}
265
if (SRetAttr0.isValid()) {
266
// FIXME: Sanity-check the sret type; if it's an integer or pointer,
267
// we'll get screwy mangling/codegen.
268
// FIXME: For large struct types, mangle as an integer argument and
269
// integer return, so we can reuse more thunks, instead of "m" syntax.
270
// (MSVC mangles this case as an integer return with no argument, but
271
// that's a miscompile.)
272
Type *SRetType = SRetAttr0.getValueAsType();
273
Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();
274
canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,
275
Out);
276
Arm64RetTy = VoidTy;
277
X64RetTy = VoidTy;
278
Arm64ArgTypes.push_back(FT->getParamType(0));
279
X64ArgTypes.push_back(FT->getParamType(0));
280
ArgTranslations.push_back(ThunkArgTranslation::Direct);
281
HasSretPtr = true;
282
return;
283
}
284
}
285
286
Out << "v";
287
Arm64RetTy = VoidTy;
288
X64RetTy = VoidTy;
289
return;
290
}
291
292
auto info =
293
canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out);
294
Arm64RetTy = info.Arm64Ty;
295
X64RetTy = info.X64Ty;
296
if (X64RetTy->isPointerTy()) {
297
// If the X64 type is canonicalized to a pointer, that means it's
298
// passed/returned indirectly. For a return value, that means it's an
299
// sret pointer.
300
X64ArgTypes.push_back(X64RetTy);
301
X64RetTy = VoidTy;
302
}
303
}
304
305
/// Maps \p T to the types used for it in the Arm64 and x64 thunk signatures,
/// and streams the corresponding piece of the mangled thunk name to \p Out.
///
/// \param T            Type of the argument or return value.
/// \param Alignment    Alignment of the argument; an "a<N>" suffix is emitted
///                     when it is at least 16 and \p Ret is false.
/// \param Ret          True when \p T is a return type.
/// \param ArgSizeBytes Size to mangle for aggregate types; 0 means the size
///                     is taken from the data layout.
/// \param Out          Receives the mangling.
/// \returns The Arm64 type, the x64 type and the translation between them.
ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType(
    Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes,
    raw_ostream &Out) {
  // Same type on both sides.
  auto MakeDirect = [](Type *Ty) {
    return ThunkArgInfo{Ty, Ty, ThunkArgTranslation::Direct};
  };

  // x64 side is an integer with the given number of bytes.
  auto MakeBitcast = [this](Type *Arm64Ty, uint64_t NumBytes) {
    return ThunkArgInfo{Arm64Ty,
                        llvm::Type::getIntNTy(M->getContext(), NumBytes * 8),
                        ThunkArgTranslation::Bitcast};
  };

  // x64 side is a pointer to the value.
  auto MakeIndirect = [this](Type *Arm64Ty) {
    return ThunkArgInfo{Arm64Ty, PtrTy,
                        ThunkArgTranslation::PointerIndirection};
  };

  // Over-alignment is only mangled for arguments, never for return values.
  auto EmitAlignSuffix = [&]() {
    if (Alignment.value() >= 16 && !Ret)
      Out << "a" << Alignment.value();
  };

  // Scalar floating point: only float and double are accepted.
  if (T->isFloatingPointTy()) {
    if (T->isFloatTy()) {
      Out << "f";
      return MakeDirect(T);
    }
    if (T->isDoubleTy()) {
      Out << "d";
      return MakeDirect(T);
    }
    report_fatal_error(
        "Only 32 and 64 bit floating points are supported for ARM64EC thunks");
  }

  auto &DL = M->getDataLayout();

  // A struct with exactly one element is treated as that element.
  if (auto *WrapperTy = dyn_cast<StructType>(T))
    if (WrapperTy->getNumElements() == 1)
      T = WrapperTy->getElementType(0);

  if (T->isArrayTy()) {
    Type *ElemTy = T->getArrayElementType();
    uint64_t NumElems = T->getArrayNumElements();
    uint64_t ElemBytes = DL.getTypeSizeInBits(ElemTy) / 8;
    uint64_t ArrayBytes = NumElems * ElemBytes;
    if (ElemTy->isFloatTy() || ElemTy->isDoubleTy()) {
      Out << (ElemTy->isFloatTy() ? "F" : "D") << ArrayBytes;
      EmitAlignSuffix();
      // Arm64 returns small structs of float/double in float registers;
      // X64 uses RAX.
      if (ArrayBytes <= 8)
        return MakeBitcast(T, ArrayBytes);
      // Struct is passed directly on Arm64, but indirectly on X64.
      return MakeIndirect(T);
    } else if (T->isFloatingPointTy()) {
      // NOTE(review): T is an array type here, so this condition can never
      // be true and the error below is unreachable. The check may have been
      // meant for the element type -- confirm before changing, since that
      // would reject arrays that currently fall through to the "m" mangling.
      report_fatal_error("Only 32 and 64 bit floating points are supported for "
                         "ARM64EC thunks");
    }
  }

  // Integers and pointers of at most 64 bits are all mangled as "i8" and
  // passed as i64 on both sides.
  if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {
    Out << "i8";
    return MakeDirect(I64Ty);
  }

  // Everything else is mangled as "m<size>"; a size of 4 has no suffix.
  unsigned SizeBytes = ArgSizeBytes;
  if (SizeBytes == 0)
    SizeBytes = DL.getTypeSizeInBits(T) / 8;
  Out << "m";
  if (SizeBytes != 4)
    Out << SizeBytes;
  EmitAlignSuffix();

  // FIXME: Try to canonicalize Arm64Ty more thoroughly?
  switch (SizeBytes) {
  case 1:
  case 2:
  case 4:
  case 8:
    // Pass directly in an integer register
    return MakeBitcast(T, SizeBytes);
  default:
    // Passed directly on Arm64, but indirectly on X64.
    return MakeIndirect(T);
  }
}
390
391
// This function builds the "exit thunk", a function which translates
392
// arguments and return values when calling x64 code from AArch64 code.
393
Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,
394
AttributeList Attrs) {
395
SmallString<256> ExitThunkName;
396
llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);
397
FunctionType *Arm64Ty, *X64Ty;
398
SmallVector<ThunkArgTranslation> ArgTranslations;
399
getThunkType(FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty,
400
X64Ty, ArgTranslations);
401
if (Function *F = M->getFunction(ExitThunkName))
402
return F;
403
404
Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0,
405
ExitThunkName, M);
406
F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);
407
F->setSection(".wowthk$aa");
408
F->setComdat(M->getOrInsertComdat(ExitThunkName));
409
// Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
410
F->addFnAttr("frame-pointer", "all");
411
// Only copy sret from the first argument. For C++ instance methods, clang can
412
// stick an sret marking on a later argument, but it doesn't actually affect
413
// the ABI, so we can omit it. This avoids triggering a verifier assertion.
414
if (FT->getNumParams()) {
415
auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);
416
auto InReg = Attrs.getParamAttr(0, Attribute::InReg);
417
if (SRet.isValid() && !InReg.isValid())
418
F->addParamAttr(1, SRet);
419
}
420
// FIXME: Copy anything other than sret? Shouldn't be necessary for normal
421
// C ABI, but might show up in other cases.
422
BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);
423
IRBuilder<> IRB(BB);
424
Value *CalleePtr =
425
M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy);
426
Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr);
427
auto &DL = M->getDataLayout();
428
SmallVector<Value *> Args;
429
430
// Pass the called function in x9.
431
auto X64TyOffset = 1;
432
Args.push_back(F->arg_begin());
433
434
Type *RetTy = Arm64Ty->getReturnType();
435
if (RetTy != X64Ty->getReturnType()) {
436
// If the return type is an array or struct, translate it. Values of size
437
// 8 or less go into RAX; bigger values go into memory, and we pass a
438
// pointer.
439
if (DL.getTypeStoreSize(RetTy) > 8) {
440
Args.push_back(IRB.CreateAlloca(RetTy));
441
X64TyOffset++;
442
}
443
}
444
445
for (auto [Arg, X64ArgType, ArgTranslation] : llvm::zip_equal(
446
make_range(F->arg_begin() + 1, F->arg_end()),
447
make_range(X64Ty->param_begin() + X64TyOffset, X64Ty->param_end()),
448
ArgTranslations)) {
449
// Translate arguments from AArch64 calling convention to x86 calling
450
// convention.
451
//
452
// For simple types, we don't need to do any translation: they're
453
// represented the same way. (Implicit sign extension is not part of
454
// either convention.)
455
//
456
// The big thing we have to worry about is struct types... but
457
// fortunately AArch64 clang is pretty friendly here: the cases that need
458
// translation are always passed as a struct or array. (If we run into
459
// some cases where this doesn't work, we can teach clang to mark it up
460
// with an attribute.)
461
//
462
// The first argument is the called function, stored in x9.
463
if (ArgTranslation != ThunkArgTranslation::Direct) {
464
Value *Mem = IRB.CreateAlloca(Arg.getType());
465
IRB.CreateStore(&Arg, Mem);
466
if (ArgTranslation == ThunkArgTranslation::Bitcast) {
467
Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));
468
Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));
469
} else {
470
assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
471
Args.push_back(Mem);
472
}
473
} else {
474
Args.push_back(&Arg);
475
}
476
assert(Args.back()->getType() == X64ArgType);
477
}
478
// FIXME: Transfer necessary attributes? sret? anything else?
479
480
Callee = IRB.CreateBitCast(Callee, PtrTy);
481
CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);
482
Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
483
484
Value *RetVal = Call;
485
if (RetTy != X64Ty->getReturnType()) {
486
// If we rewrote the return type earlier, convert the return value to
487
// the proper type.
488
if (DL.getTypeStoreSize(RetTy) > 8) {
489
RetVal = IRB.CreateLoad(RetTy, Args[1]);
490
} else {
491
Value *CastAlloca = IRB.CreateAlloca(RetTy);
492
IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
493
RetVal = IRB.CreateLoad(RetTy, CastAlloca);
494
}
495
}
496
497
if (RetTy->isVoidTy())
498
IRB.CreateRetVoid();
499
else
500
IRB.CreateRet(RetVal);
501
return F;
502
}
503
504
// This function builds the "entry thunk", a function which translates
505
// arguments and return values when calling AArch64 code from x64 code.
506
Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {
507
SmallString<256> EntryThunkName;
508
llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);
509
FunctionType *Arm64Ty, *X64Ty;
510
SmallVector<ThunkArgTranslation> ArgTranslations;
511
getThunkType(F->getFunctionType(), F->getAttributes(),
512
Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty,
513
ArgTranslations);
514
if (Function *F = M->getFunction(EntryThunkName))
515
return F;
516
517
Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0,
518
EntryThunkName, M);
519
Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);
520
Thunk->setSection(".wowthk$aa");
521
Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));
522
// Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)
523
Thunk->addFnAttr("frame-pointer", "all");
524
525
BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);
526
IRBuilder<> IRB(BB);
527
528
Type *RetTy = Arm64Ty->getReturnType();
529
Type *X64RetType = X64Ty->getReturnType();
530
531
bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();
532
unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;
533
unsigned PassthroughArgSize =
534
(F->isVarArg() ? 5 : Thunk->arg_size()) - ThunkArgOffset;
535
assert(ArgTranslations.size() == (F->isVarArg() ? 5 : PassthroughArgSize));
536
537
// Translate arguments to call.
538
SmallVector<Value *> Args;
539
for (unsigned i = 0; i != PassthroughArgSize; ++i) {
540
Value *Arg = Thunk->getArg(i + ThunkArgOffset);
541
Type *ArgTy = Arm64Ty->getParamType(i);
542
ThunkArgTranslation ArgTranslation = ArgTranslations[i];
543
if (ArgTranslation != ThunkArgTranslation::Direct) {
544
// Translate array/struct arguments to the expected type.
545
if (ArgTranslation == ThunkArgTranslation::Bitcast) {
546
Value *CastAlloca = IRB.CreateAlloca(ArgTy);
547
IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));
548
Arg = IRB.CreateLoad(ArgTy, CastAlloca);
549
} else {
550
assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);
551
Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));
552
}
553
}
554
assert(Arg->getType() == ArgTy);
555
Args.push_back(Arg);
556
}
557
558
if (F->isVarArg()) {
559
// The 5th argument to variadic entry thunks is used to model the x64 sp
560
// which is passed to the thunk in x4, this can be passed to the callee as
561
// the variadic argument start address after skipping over the 32 byte
562
// shadow store.
563
564
// The EC thunk CC will assign any argument marked as InReg to x4.
565
Thunk->addParamAttr(5, Attribute::InReg);
566
Value *Arg = Thunk->getArg(5);
567
Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));
568
Args.push_back(Arg);
569
570
// Pass in a zero variadic argument size (in x5).
571
Args.push_back(IRB.getInt64(0));
572
}
573
574
// Call the function passed to the thunk.
575
Value *Callee = Thunk->getArg(0);
576
Callee = IRB.CreateBitCast(Callee, PtrTy);
577
CallInst *Call = IRB.CreateCall(Arm64Ty, Callee, Args);
578
579
auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
580
auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
581
if (SRetAttr.isValid() && !InRegAttr.isValid()) {
582
Thunk->addParamAttr(1, SRetAttr);
583
Call->addParamAttr(0, SRetAttr);
584
}
585
586
Value *RetVal = Call;
587
if (TransformDirectToSRet) {
588
IRB.CreateStore(RetVal, IRB.CreateBitCast(Thunk->getArg(1), PtrTy));
589
} else if (X64RetType != RetTy) {
590
Value *CastAlloca = IRB.CreateAlloca(X64RetType);
591
IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));
592
RetVal = IRB.CreateLoad(X64RetType, CastAlloca);
593
}
594
595
// Return to the caller. Note that the isel has code to translate this
596
// "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we
597
// could emit a tail call here, but that would require a dedicated calling
598
// convention, which seems more complicated overall.)
599
if (X64RetType->isVoidTy())
600
IRB.CreateRetVoid();
601
else
602
IRB.CreateRet(RetVal);
603
604
return Thunk;
605
}
606
607
// Builds the "guest exit thunk", a helper to call a function which may or may
608
// not be an exit thunk. (We optimistically assume non-dllimport function
609
// declarations refer to functions defined in AArch64 code; if the linker
610
// can't prove that, we use this routine instead.)
611
Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
612
llvm::raw_null_ostream NullThunkName;
613
FunctionType *Arm64Ty, *X64Ty;
614
SmallVector<ThunkArgTranslation> ArgTranslations;
615
getThunkType(F->getFunctionType(), F->getAttributes(),
616
Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
617
ArgTranslations);
618
auto MangledName = getArm64ECMangledFunctionName(F->getName().str());
619
assert(MangledName && "Can't guest exit to function that's already native");
620
std::string ThunkName = *MangledName;
621
if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
622
ThunkName.insert(ThunkName.find("@"), "$exit_thunk");
623
} else {
624
ThunkName.append("$exit_thunk");
625
}
626
Function *GuestExit =
627
Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
628
GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
629
GuestExit->setSection(".wowthk$aa");
630
GuestExit->setMetadata(
631
"arm64ec_unmangled_name",
632
MDNode::get(M->getContext(),
633
MDString::get(M->getContext(), F->getName())));
634
GuestExit->setMetadata(
635
"arm64ec_ecmangled_name",
636
MDNode::get(M->getContext(),
637
MDString::get(M->getContext(), *MangledName)));
638
F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));
639
BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
640
IRBuilder<> B(BB);
641
642
// Load the global symbol as a pointer to the check function.
643
Value *GuardFn;
644
if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))
645
GuardFn = GuardFnCFGlobal;
646
else
647
GuardFn = GuardFnGlobal;
648
LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
649
650
// Create new call instruction. The CFGuard check should always be a call,
651
// even if the original CallBase is an Invoke or CallBr instruction.
652
Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());
653
CallInst *GuardCheck = B.CreateCall(
654
GuardFnType, GuardCheckLoad,
655
{B.CreateBitCast(F, B.getPtrTy()), B.CreateBitCast(Thunk, B.getPtrTy())});
656
657
// Ensure that the first argument is passed in the correct register.
658
GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
659
660
Value *GuardRetVal = B.CreateBitCast(GuardCheck, PtrTy);
661
SmallVector<Value *> Args;
662
for (Argument &Arg : GuestExit->args())
663
Args.push_back(&Arg);
664
CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args);
665
Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
666
667
if (Call->getType()->isVoidTy())
668
B.CreateRetVoid();
669
else
670
B.CreateRet(Call);
671
672
auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
673
auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
674
if (SRetAttr.isValid() && !InRegAttr.isValid()) {
675
GuestExit->addParamAttr(0, SRetAttr);
676
Call->addParamAttr(0, SRetAttr);
677
}
678
679
return GuestExit;
680
}
681
682
// Builds the "$hybpatch_thunk" for a hybrid-patchable function and makes
// MangledAlias point at it.
//
// The thunk calls the function loaded from the dispatch-call global with
// (UnmangledAlias, exit thunk, aliasee of UnmangledAlias) and then
// must-tail-calls the pointer that call returns, forwarding its own
// arguments unchanged.
Function *
AArch64Arm64ECCallLowering::buildPatchableThunk(GlobalAlias *UnmangledAlias,
                                                GlobalAlias *MangledAlias) {
  // Only the Arm64 signature is needed; the mangled signature is discarded.
  llvm::raw_null_ostream DiscardedName;
  FunctionType *Arm64Ty, *X64Ty;
  Function *F = cast<Function>(MangledAlias->getAliasee());
  SmallVector<ThunkArgTranslation> ArgTranslations;
  getThunkType(F->getFunctionType(), F->getAttributes(),
               Arm64ECThunkType::GuestExit, DiscardedName, Arm64Ty, X64Ty,
               ArgTranslations);

  // Names starting with '?' that contain an '@' get the suffix inserted in
  // front of the first '@'; all other names get it appended.
  std::string ThunkName(MangledAlias->getName());
  const std::string::size_type AtPos = ThunkName.find("@");
  if (ThunkName[0] == '?' && AtPos != std::string::npos)
    ThunkName.insert(AtPos, "$hybpatch_thunk");
  else
    ThunkName.append("$hybpatch_thunk");

  Function *PatchThunk =
      Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
  PatchThunk->setComdat(M->getOrInsertComdat(ThunkName));
  PatchThunk->setSection(".wowthk$aa");
  BasicBlock *Entry = BasicBlock::Create(M->getContext(), "", PatchThunk);
  IRBuilder<> B(Entry);

  // Load the global symbol as a pointer to the dispatch function.
  LoadInst *DispatchLoad = B.CreateLoad(DispatchFnPtrType, DispatchFnGlobal);

  // Create new dispatch call instruction.
  Function *ExitThunk =
      buildExitThunk(F->getFunctionType(), F->getAttributes());
  CallInst *Dispatch =
      B.CreateCall(DispatchFnType, DispatchLoad,
                   {UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee()});

  // Ensure that the first arguments are passed in the correct registers.
  Dispatch->setCallingConv(CallingConv::CFGuard_Check);

  // Forward every argument of the thunk to the dispatched target.
  Value *DispatchedTarget = B.CreateBitCast(Dispatch, PtrTy);
  SmallVector<Value *> ForwardedArgs;
  for (Argument &ThunkParam : PatchThunk->args())
    ForwardedArgs.push_back(&ThunkParam);
  CallInst *TailCall = B.CreateCall(Arm64Ty, DispatchedTarget, ForwardedArgs);
  TailCall->setTailCallKind(llvm::CallInst::TCK_MustTail);

  if (TailCall->getType()->isVoidTy())
    B.CreateRetVoid();
  else
    B.CreateRet(TailCall);

  // Mirror a plain (non-inreg) sret on the first parameter of F.
  auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
  auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
  if (SRetAttr.isValid() && !InRegAttr.isValid()) {
    PatchThunk->addParamAttr(0, SRetAttr);
    TailCall->addParamAttr(0, SRetAttr);
  }

  MangledAlias->setAliasee(PatchThunk);
  return PatchThunk;
}
741
742
// Lower an indirect call with inline code.
743
void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
744
assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
745
"Only applicable for Windows targets");
746
747
IRBuilder<> B(CB);
748
Value *CalledOperand = CB->getCalledOperand();
749
750
// If the indirect call is called within catchpad or cleanuppad,
751
// we need to copy "funclet" bundle of the call.
752
SmallVector<llvm::OperandBundleDef, 1> Bundles;
753
if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))
754
Bundles.push_back(OperandBundleDef(*Bundle));
755
756
// Load the global symbol as a pointer to the check function.
757
Value *GuardFn;
758
if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))
759
GuardFn = GuardFnCFGlobal;
760
else
761
GuardFn = GuardFnGlobal;
762
LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);
763
764
// Create new call instruction. The CFGuard check should always be a call,
765
// even if the original CallBase is an Invoke or CallBr instruction.
766
Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());
767
CallInst *GuardCheck =
768
B.CreateCall(GuardFnType, GuardCheckLoad,
769
{B.CreateBitCast(CalledOperand, B.getPtrTy()),
770
B.CreateBitCast(Thunk, B.getPtrTy())},
771
Bundles);
772
773
// Ensure that the first argument is passed in the correct register.
774
GuardCheck->setCallingConv(CallingConv::CFGuard_Check);
775
776
Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType());
777
CB->setCalledOperand(GuardRetVal);
778
}
779
780
/// Module entry point of the Arm64EC call lowering pass.
///
/// Does nothing (and returns false) when thunk generation is disabled via
/// the GenerateThunks option. Otherwise it:
///   1. caches the module, the "cfguard" module flag and the commonly used
///      types, and declares the `__os_arm64x_*` helper globals;
///   2. renames externally visible hybrid-patchable definitions to
///      "<mangled>$hp_target" and redirects their users to aliases;
///   3. runs processFunction() on every non-thunk definition;
///   4. builds entry, exit, guest-exit and patchable thunks; and
///   5. records all (source, destination, kind) triples in the
///      "llvm.arm64ec.symbolmap" global.
///
/// \param Mod the module to transform.
/// \returns true whenever the lowering was attempted, false when it was
///          skipped because GenerateThunks is off.
bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
  if (!GenerateThunks)
    return false;

  M = &Mod;
  LLVMContext &Ctx = M->getContext();

  // Check if this module has the cfguard flag and read its value.
  if (auto *MD =
          mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))
    cfguard_module_flag = MD->getZExtValue();

  PtrTy = PointerType::getUnqual(Ctx);
  I64Ty = Type::getInt64Ty(Ctx);
  VoidTy = Type::getVoidTy(Ctx);

  // Signatures and global slots of the runtime-provided checker/dispatcher
  // function pointers.
  GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
  GuardFnPtrType = PointerType::get(GuardFnType, 0);
  DispatchFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy, PtrTy}, false);
  DispatchFnPtrType = PointerType::get(DispatchFnType, 0);
  GuardFnCFGlobal =
      M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
  GuardFnGlobal =
      M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
  DispatchFnGlobal =
      M->getOrInsertGlobal("__os_arm64x_dispatch_call", DispatchFnPtrType);

  // Maps the unmangled alias of a hybrid-patchable function to its mangled
  // alias; PatchableFns keeps the unmangled aliases in insertion order.
  DenseMap<GlobalAlias *, GlobalAlias *> FnsMap;
  SetVector<GlobalAlias *> PatchableFns;

  for (Function &F : Mod) {
    if (!F.hasFnAttribute(Attribute::HybridPatchable) || F.isDeclaration() ||
        F.hasLocalLinkage() || F.getName().ends_with("$hp_target"))
      continue;

    // Rename hybrid patchable functions and change callers to use a global
    // alias instead.
    if (std::optional<std::string> MangledName =
            getArm64ECMangledFunctionName(F.getName().str())) {
      std::string OrigName(F.getName());
      F.setName(MangledName.value() + "$hp_target");

      // The unmangled symbol is a weak alias to an undefined symbol with the
      // "EXP+" prefix. This undefined symbol is resolved by the linker by
      // creating an x86 thunk that jumps back to the actual EC target. Since we
      // can't represent that in IR, we create an alias to the target instead.
      // The "EXP+" symbol is set as metadata, which is then used by
      // emitGlobalAlias to emit the right alias.
      auto *A =
          GlobalAlias::create(GlobalValue::LinkOnceODRLinkage, OrigName, &F);
      F.replaceAllUsesWith(A);
      F.setMetadata("arm64ec_exp_name",
                    MDNode::get(Ctx, MDString::get(
                                         Ctx, "EXP+" + MangledName.value())));
      // The replaceAllUsesWith() above also rewrote the alias's own reference
      // to F, so point the alias back at the function.
      A->setAliasee(&F);

      // The export moves from the renamed function to the unmangled alias.
      if (F.hasDLLExportStorageClass()) {
        A->setDLLStorageClass(GlobalValue::DLLExportStorageClass);
        F.setDLLStorageClass(GlobalValue::DefaultStorageClass);
      }

      FnsMap[A] = GlobalAlias::create(GlobalValue::LinkOnceODRLinkage,
                                      MangledName.value(), &F);
      PatchableFns.insert(A);
    }
  }

  // Lower calls in every non-thunk definition and collect the externally
  // defined callees that need exit thunks.
  SetVector<GlobalValue *> DirectCalledFns;
  for (Function &F : Mod)
    if (!F.isDeclaration() &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
      processFunction(F, DirectCalledFns, FnsMap);

  // One row of the "llvm.arm64ec.symbolmap" table.
  struct ThunkInfo {
    Constant *Src;
    Constant *Dst;
    Arm64ECThunkType Kind;
  };
  SmallVector<ThunkInfo> ThunkMapping;

  // Entry thunks: one per non-thunk definition that is either non-local or
  // has its address taken.
  for (Function &F : Mod) {
    if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
        F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {
      if (!F.hasComdat())
        F.setComdat(Mod.getOrInsertComdat(F.getName()));
      ThunkMapping.push_back(
          {&F, buildEntryThunk(&F), Arm64ECThunkType::Entry});
    }
  }

  // Exit thunks (and guest-exit thunks for non-dllimport functions) for the
  // callees collected by processFunction().
  for (GlobalValue *O : DirectCalledFns) {
    auto *GA = dyn_cast<GlobalAlias>(O);
    // Every entry is either a Function or one of the aliases created above,
    // whose aliasee is a Function, so the cast cannot legitimately fail.
    // cast<> asserts on that invariant instead of dereferencing a null
    // dyn_cast<> result.
    auto *F = cast<Function>(GA ? GA->getAliasee() : O);
    ThunkMapping.push_back(
        {O, buildExitThunk(F->getFunctionType(), F->getAttributes()),
         Arm64ECThunkType::Exit});
    if (!GA && !F->hasDLLImportStorageClass())
      ThunkMapping.push_back(
          {buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});
  }

  // Patchable thunks for the hybrid-patchable functions renamed above.
  for (GlobalAlias *A : PatchableFns) {
    Function *Thunk = buildPatchableThunk(A, FnsMap[A]);
    ThunkMapping.push_back({Thunk, A, Arm64ECThunkType::GuestExit});
  }

  // Emit the table as an array of anonymous {ptr, ptr, i32} structs.
  if (!ThunkMapping.empty()) {
    SmallVector<Constant *> ThunkMappingArrayElems;
    for (const ThunkInfo &Thunk : ThunkMapping) {
      ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(
          {ConstantExpr::getBitCast(Thunk.Src, PtrTy),
           ConstantExpr::getBitCast(Thunk.Dst, PtrTy),
           ConstantInt::get(Ctx,
                            APInt(32, static_cast<uint8_t>(Thunk.Kind)))}));
    }
    Constant *ThunkMappingArray = ConstantArray::get(
        llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),
                             ThunkMappingArrayElems.size()),
        ThunkMappingArrayElems);
    new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,
                       GlobalValue::ExternalLinkage, ThunkMappingArray,
                       "llvm.arm64ec.symbolmap");
  }

  return true;
}
904
905
bool AArch64Arm64ECCallLowering::processFunction(
906
Function &F, SetVector<GlobalValue *> &DirectCalledFns,
907
DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) {
908
SmallVector<CallBase *, 8> IndirectCalls;
909
910
// For ARM64EC targets, a function definition's name is mangled differently
911
// from the normal symbol. We currently have no representation of this sort
912
// of symbol in IR, so we change the name to the mangled name, then store
913
// the unmangled name as metadata. Later passes that need the unmangled
914
// name (emitting the definition) can grab it from the metadata.
915
//
916
// FIXME: Handle functions with weak linkage?
917
if (!F.hasLocalLinkage() || F.hasAddressTaken()) {
918
if (std::optional<std::string> MangledName =
919
getArm64ECMangledFunctionName(F.getName().str())) {
920
F.setMetadata("arm64ec_unmangled_name",
921
MDNode::get(M->getContext(),
922
MDString::get(M->getContext(), F.getName())));
923
if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {
924
Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());
925
SmallVector<GlobalObject *> ComdatUsers =
926
to_vector(F.getComdat()->getUsers());
927
for (GlobalObject *User : ComdatUsers)
928
User->setComdat(MangledComdat);
929
}
930
F.setName(MangledName.value());
931
}
932
}
933
934
// Iterate over the instructions to find all indirect call/invoke/callbr
935
// instructions. Make a separate list of pointers to indirect
936
// call/invoke/callbr instructions because the original instructions will be
937
// deleted as the checks are added.
938
for (BasicBlock &BB : F) {
939
for (Instruction &I : BB) {
940
auto *CB = dyn_cast<CallBase>(&I);
941
if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||
942
CB->isInlineAsm())
943
continue;
944
945
// We need to instrument any call that isn't directly calling an
946
// ARM64 function.
947
//
948
// FIXME: getCalledFunction() fails if there's a bitcast (e.g.
949
// unprototyped functions in C)
950
if (Function *F = CB->getCalledFunction()) {
951
if (!LowerDirectToIndirect || F->hasLocalLinkage() ||
952
F->isIntrinsic() || !F->isDeclaration())
953
continue;
954
955
DirectCalledFns.insert(F);
956
continue;
957
}
958
959
// Use mangled global alias for direct calls to patchable functions.
960
if (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand())) {
961
auto I = FnsMap.find(A);
962
if (I != FnsMap.end()) {
963
CB->setCalledOperand(I->second);
964
DirectCalledFns.insert(I->first);
965
continue;
966
}
967
}
968
969
IndirectCalls.push_back(CB);
970
++Arm64ECCallsLowered;
971
}
972
}
973
974
if (IndirectCalls.empty())
975
return false;
976
977
for (CallBase *CB : IndirectCalls)
978
lowerCall(CB);
979
980
return true;
981
}
982
983
char AArch64Arm64ECCallLowering::ID = 0;
984
INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",
985
"AArch64Arm64ECCallLowering", false, false)
986
987
/// Factory for the Arm64EC call lowering pass.
///
/// \returns a newly heap-allocated AArch64Arm64ECCallLowering instance as a
///          ModulePass pointer.
ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() {
  ModulePass *Pass = new AArch64Arm64ECCallLowering;
  return Pass;
}
990
991