Path: blob/main/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
35267 views
//===-- AArch64Arm64ECCallLowering.cpp - Lower Arm64EC calls ----*- C++ -*-===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7///8/// \file9/// This file contains the IR transform to lower external or indirect calls for10/// the ARM64EC calling convention. Such calls must go through the runtime, so11/// we can translate the calling convention for calls into the emulator.12///13/// This subsumes Control Flow Guard handling.14///15//===----------------------------------------------------------------------===//1617#include "AArch64.h"18#include "llvm/ADT/SetVector.h"19#include "llvm/ADT/SmallString.h"20#include "llvm/ADT/SmallVector.h"21#include "llvm/ADT/Statistic.h"22#include "llvm/IR/CallingConv.h"23#include "llvm/IR/GlobalAlias.h"24#include "llvm/IR/IRBuilder.h"25#include "llvm/IR/Instruction.h"26#include "llvm/IR/Mangler.h"27#include "llvm/IR/Module.h"28#include "llvm/InitializePasses.h"29#include "llvm/Object/COFF.h"30#include "llvm/Pass.h"31#include "llvm/Support/CommandLine.h"32#include "llvm/TargetParser/Triple.h"3334using namespace llvm;35using namespace llvm::COFF;3637using OperandBundleDef = OperandBundleDefT<Value *>;3839#define DEBUG_TYPE "arm64eccalllowering"4041STATISTIC(Arm64ECCallsLowered, "Number of Arm64EC calls lowered");4243static cl::opt<bool> LowerDirectToIndirect("arm64ec-lower-direct-to-indirect",44cl::Hidden, cl::init(true));45static cl::opt<bool> GenerateThunks("arm64ec-generate-thunks", cl::Hidden,46cl::init(true));4748namespace {4950enum ThunkArgTranslation : uint8_t {51Direct,52Bitcast,53PointerIndirection,54};5556struct ThunkArgInfo {57Type *Arm64Ty;58Type *X64Ty;59ThunkArgTranslation Translation;60};6162class AArch64Arm64ECCallLowering : public ModulePass {63public:64static char ID;65AArch64Arm64ECCallLowering() : ModulePass(ID) {66initializeAArch64Arm64ECCallLoweringPass(*PassRegistry::getPassRegistry());67}6869Function *buildExitThunk(FunctionType *FnTy, AttributeList Attrs);70Function *buildEntryThunk(Function *F);71void lowerCall(CallBase *CB);72Function *buildGuestExitThunk(Function *F);73Function *buildPatchableThunk(GlobalAlias *UnmangledAlias,74GlobalAlias *MangledAlias);75bool processFunction(Function &F, SetVector<GlobalValue *> &DirectCalledFns,76DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap);77bool runOnModule(Module &M) override;7879private:80int cfguard_module_flag = 0;81FunctionType *GuardFnType = nullptr;82PointerType *GuardFnPtrType = nullptr;83FunctionType *DispatchFnType = nullptr;84PointerType *DispatchFnPtrType = nullptr;85Constant *GuardFnCFGlobal = nullptr;86Constant *GuardFnGlobal = nullptr;87Constant *DispatchFnGlobal = nullptr;88Module *M = nullptr;8990Type *PtrTy;91Type *I64Ty;92Type *VoidTy;9394void getThunkType(FunctionType *FT, AttributeList AttrList,95Arm64ECThunkType TT, raw_ostream &Out,96FunctionType *&Arm64Ty, FunctionType *&X64Ty,97SmallVector<ThunkArgTranslation> &ArgTranslations);98void getThunkRetType(FunctionType *FT, AttributeList AttrList,99raw_ostream &Out, Type *&Arm64RetTy, Type *&X64RetTy,100SmallVectorImpl<Type *> &Arm64ArgTypes,101SmallVectorImpl<Type *> &X64ArgTypes,102SmallVector<ThunkArgTranslation> &ArgTranslations,103bool &HasSretPtr);104void getThunkArgTypes(FunctionType *FT, AttributeList AttrList,105Arm64ECThunkType TT, raw_ostream &Out,106SmallVectorImpl<Type *> &Arm64ArgTypes,107SmallVectorImpl<Type *> &X64ArgTypes,108SmallVectorImpl<ThunkArgTranslation> &ArgTranslations,109bool HasSretPtr);110ThunkArgInfo canonicalizeThunkType(Type *T, Align Alignment, bool Ret,111uint64_t ArgSizeBytes, raw_ostream &Out);112};113114} // end anonymous namespace115116void AArch64Arm64ECCallLowering::getThunkType(117FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,118raw_ostream &Out, FunctionType *&Arm64Ty, FunctionType *&X64Ty,119SmallVector<ThunkArgTranslation> &ArgTranslations) {120Out << (TT == Arm64ECThunkType::Entry ? "$ientry_thunk$cdecl$"121: "$iexit_thunk$cdecl$");122123Type *Arm64RetTy;124Type *X64RetTy;125126SmallVector<Type *> Arm64ArgTypes;127SmallVector<Type *> X64ArgTypes;128129// The first argument to a thunk is the called function, stored in x9.130// For exit thunks, we pass the called function down to the emulator;131// for entry/guest exit thunks, we just call the Arm64 function directly.132if (TT == Arm64ECThunkType::Exit)133Arm64ArgTypes.push_back(PtrTy);134X64ArgTypes.push_back(PtrTy);135136bool HasSretPtr = false;137getThunkRetType(FT, AttrList, Out, Arm64RetTy, X64RetTy, Arm64ArgTypes,138X64ArgTypes, ArgTranslations, HasSretPtr);139140getThunkArgTypes(FT, AttrList, TT, Out, Arm64ArgTypes, X64ArgTypes,141ArgTranslations, HasSretPtr);142143Arm64Ty = FunctionType::get(Arm64RetTy, Arm64ArgTypes, false);144145X64Ty = FunctionType::get(X64RetTy, X64ArgTypes, false);146}147148void AArch64Arm64ECCallLowering::getThunkArgTypes(149FunctionType *FT, AttributeList AttrList, Arm64ECThunkType TT,150raw_ostream &Out, SmallVectorImpl<Type *> &Arm64ArgTypes,151SmallVectorImpl<Type *> &X64ArgTypes,152SmallVectorImpl<ThunkArgTranslation> &ArgTranslations, bool HasSretPtr) {153154Out << "$";155if (FT->isVarArg()) {156// We treat the variadic function's thunk as a normal function157// with the following type on the ARM side:158// rettype exitthunk(159// ptr x9, ptr x0, i64 x1, i64 x2, i64 x3, ptr x4, i64 x5)160//161// that can coverage all types of variadic function.162// x9 is similar to normal exit thunk, store the called function.163// x0-x3 is the arguments be stored in registers.164// x4 is the address of the arguments on the stack.165// x5 is the size of the arguments on the stack.166//167// On the x64 side, it's the same except that x5 isn't set.168//169// If both the ARM and X64 sides are sret, there are only three170// arguments in registers.171//172// If the X64 side is sret, but the ARM side isn't, we pass an extra value173// to/from the X64 side, and let SelectionDAG transform it into a memory174// location.175Out << "varargs";176177// x0-x3178for (int i = HasSretPtr ? 1 : 0; i < 4; i++) {179Arm64ArgTypes.push_back(I64Ty);180X64ArgTypes.push_back(I64Ty);181ArgTranslations.push_back(ThunkArgTranslation::Direct);182}183184// x4185Arm64ArgTypes.push_back(PtrTy);186X64ArgTypes.push_back(PtrTy);187ArgTranslations.push_back(ThunkArgTranslation::Direct);188// x5189Arm64ArgTypes.push_back(I64Ty);190if (TT != Arm64ECThunkType::Entry) {191// FIXME: x5 isn't actually used by the x64 side; revisit once we192// have proper isel for varargs193X64ArgTypes.push_back(I64Ty);194ArgTranslations.push_back(ThunkArgTranslation::Direct);195}196return;197}198199unsigned I = 0;200if (HasSretPtr)201I++;202203if (I == FT->getNumParams()) {204Out << "v";205return;206}207208for (unsigned E = FT->getNumParams(); I != E; ++I) {209#if 0210// FIXME: Need more information about argument size; see211// https://reviews.llvm.org/D132926212uint64_t ArgSizeBytes = AttrList.getParamArm64ECArgSizeBytes(I);213Align ParamAlign = AttrList.getParamAlignment(I).valueOrOne();214#else215uint64_t ArgSizeBytes = 0;216Align ParamAlign = Align();217#endif218auto [Arm64Ty, X64Ty, ArgTranslation] =219canonicalizeThunkType(FT->getParamType(I), ParamAlign,220/*Ret*/ false, ArgSizeBytes, Out);221Arm64ArgTypes.push_back(Arm64Ty);222X64ArgTypes.push_back(X64Ty);223ArgTranslations.push_back(ArgTranslation);224}225}226227void AArch64Arm64ECCallLowering::getThunkRetType(228FunctionType *FT, AttributeList AttrList, raw_ostream &Out,229Type *&Arm64RetTy, Type *&X64RetTy, SmallVectorImpl<Type *> &Arm64ArgTypes,230SmallVectorImpl<Type *> &X64ArgTypes,231SmallVector<ThunkArgTranslation> &ArgTranslations, bool &HasSretPtr) {232Type *T = FT->getReturnType();233#if 0234// FIXME: Need more information about argument size; see235// https://reviews.llvm.org/D132926236uint64_t ArgSizeBytes = AttrList.getRetArm64ECArgSizeBytes();237#else238int64_t ArgSizeBytes = 0;239#endif240if (T->isVoidTy()) {241if (FT->getNumParams()) {242Attribute SRetAttr0 = AttrList.getParamAttr(0, Attribute::StructRet);243Attribute InRegAttr0 = AttrList.getParamAttr(0, Attribute::InReg);244Attribute SRetAttr1, InRegAttr1;245if (FT->getNumParams() > 1) {246// Also check the second parameter (for class methods, the first247// parameter is "this", and the second parameter is the sret pointer.)248// It doesn't matter which one is sret.249SRetAttr1 = AttrList.getParamAttr(1, Attribute::StructRet);250InRegAttr1 = AttrList.getParamAttr(1, Attribute::InReg);251}252if ((SRetAttr0.isValid() && InRegAttr0.isValid()) ||253(SRetAttr1.isValid() && InRegAttr1.isValid())) {254// sret+inreg indicates a call that returns a C++ class value. This is255// actually equivalent to just passing and returning a void* pointer256// as the first or second argument. Translate it that way, instead of257// trying to model "inreg" in the thunk's calling convention; this258// simplfies the rest of the code, and matches MSVC mangling.259Out << "i8";260Arm64RetTy = I64Ty;261X64RetTy = I64Ty;262return;263}264if (SRetAttr0.isValid()) {265// FIXME: Sanity-check the sret type; if it's an integer or pointer,266// we'll get screwy mangling/codegen.267// FIXME: For large struct types, mangle as an integer argument and268// integer return, so we can reuse more thunks, instead of "m" syntax.269// (MSVC mangles this case as an integer return with no argument, but270// that's a miscompile.)271Type *SRetType = SRetAttr0.getValueAsType();272Align SRetAlign = AttrList.getParamAlignment(0).valueOrOne();273canonicalizeThunkType(SRetType, SRetAlign, /*Ret*/ true, ArgSizeBytes,274Out);275Arm64RetTy = VoidTy;276X64RetTy = VoidTy;277Arm64ArgTypes.push_back(FT->getParamType(0));278X64ArgTypes.push_back(FT->getParamType(0));279ArgTranslations.push_back(ThunkArgTranslation::Direct);280HasSretPtr = true;281return;282}283}284285Out << "v";286Arm64RetTy = VoidTy;287X64RetTy = VoidTy;288return;289}290291auto info =292canonicalizeThunkType(T, Align(), /*Ret*/ true, ArgSizeBytes, Out);293Arm64RetTy = info.Arm64Ty;294X64RetTy = info.X64Ty;295if (X64RetTy->isPointerTy()) {296// If the X64 type is canonicalized to a pointer, that means it's297// passed/returned indirectly. For a return value, that means it's an298// sret pointer.299X64ArgTypes.push_back(X64RetTy);300X64RetTy = VoidTy;301}302}303304ThunkArgInfo AArch64Arm64ECCallLowering::canonicalizeThunkType(305Type *T, Align Alignment, bool Ret, uint64_t ArgSizeBytes,306raw_ostream &Out) {307308auto direct = [](Type *T) {309return ThunkArgInfo{T, T, ThunkArgTranslation::Direct};310};311312auto bitcast = [this](Type *Arm64Ty, uint64_t SizeInBytes) {313return ThunkArgInfo{Arm64Ty,314llvm::Type::getIntNTy(M->getContext(), SizeInBytes * 8),315ThunkArgTranslation::Bitcast};316};317318auto pointerIndirection = [this](Type *Arm64Ty) {319return ThunkArgInfo{Arm64Ty, PtrTy,320ThunkArgTranslation::PointerIndirection};321};322323if (T->isFloatTy()) {324Out << "f";325return direct(T);326}327328if (T->isDoubleTy()) {329Out << "d";330return direct(T);331}332333if (T->isFloatingPointTy()) {334report_fatal_error(335"Only 32 and 64 bit floating points are supported for ARM64EC thunks");336}337338auto &DL = M->getDataLayout();339340if (auto *StructTy = dyn_cast<StructType>(T))341if (StructTy->getNumElements() == 1)342T = StructTy->getElementType(0);343344if (T->isArrayTy()) {345Type *ElementTy = T->getArrayElementType();346uint64_t ElementCnt = T->getArrayNumElements();347uint64_t ElementSizePerBytes = DL.getTypeSizeInBits(ElementTy) / 8;348uint64_t TotalSizeBytes = ElementCnt * ElementSizePerBytes;349if (ElementTy->isFloatTy() || ElementTy->isDoubleTy()) {350Out << (ElementTy->isFloatTy() ? "F" : "D") << TotalSizeBytes;351if (Alignment.value() >= 16 && !Ret)352Out << "a" << Alignment.value();353if (TotalSizeBytes <= 8) {354// Arm64 returns small structs of float/double in float registers;355// X64 uses RAX.356return bitcast(T, TotalSizeBytes);357} else {358// Struct is passed directly on Arm64, but indirectly on X64.359return pointerIndirection(T);360}361} else if (T->isFloatingPointTy()) {362report_fatal_error("Only 32 and 64 bit floating points are supported for "363"ARM64EC thunks");364}365}366367if ((T->isIntegerTy() || T->isPointerTy()) && DL.getTypeSizeInBits(T) <= 64) {368Out << "i8";369return direct(I64Ty);370}371372unsigned TypeSize = ArgSizeBytes;373if (TypeSize == 0)374TypeSize = DL.getTypeSizeInBits(T) / 8;375Out << "m";376if (TypeSize != 4)377Out << TypeSize;378if (Alignment.value() >= 16 && !Ret)379Out << "a" << Alignment.value();380// FIXME: Try to canonicalize Arm64Ty more thoroughly?381if (TypeSize == 1 || TypeSize == 2 || TypeSize == 4 || TypeSize == 8) {382// Pass directly in an integer register383return bitcast(T, TypeSize);384} else {385// Passed directly on Arm64, but indirectly on X64.386return pointerIndirection(T);387}388}389390// This function builds the "exit thunk", a function which translates391// arguments and return values when calling x64 code from AArch64 code.392Function *AArch64Arm64ECCallLowering::buildExitThunk(FunctionType *FT,393AttributeList Attrs) {394SmallString<256> ExitThunkName;395llvm::raw_svector_ostream ExitThunkStream(ExitThunkName);396FunctionType *Arm64Ty, *X64Ty;397SmallVector<ThunkArgTranslation> ArgTranslations;398getThunkType(FT, Attrs, Arm64ECThunkType::Exit, ExitThunkStream, Arm64Ty,399X64Ty, ArgTranslations);400if (Function *F = M->getFunction(ExitThunkName))401return F;402403Function *F = Function::Create(Arm64Ty, GlobalValue::LinkOnceODRLinkage, 0,404ExitThunkName, M);405F->setCallingConv(CallingConv::ARM64EC_Thunk_Native);406F->setSection(".wowthk$aa");407F->setComdat(M->getOrInsertComdat(ExitThunkName));408// Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)409F->addFnAttr("frame-pointer", "all");410// Only copy sret from the first argument. For C++ instance methods, clang can411// stick an sret marking on a later argument, but it doesn't actually affect412// the ABI, so we can omit it. This avoids triggering a verifier assertion.413if (FT->getNumParams()) {414auto SRet = Attrs.getParamAttr(0, Attribute::StructRet);415auto InReg = Attrs.getParamAttr(0, Attribute::InReg);416if (SRet.isValid() && !InReg.isValid())417F->addParamAttr(1, SRet);418}419// FIXME: Copy anything other than sret? Shouldn't be necessary for normal420// C ABI, but might show up in other cases.421BasicBlock *BB = BasicBlock::Create(M->getContext(), "", F);422IRBuilder<> IRB(BB);423Value *CalleePtr =424M->getOrInsertGlobal("__os_arm64x_dispatch_call_no_redirect", PtrTy);425Value *Callee = IRB.CreateLoad(PtrTy, CalleePtr);426auto &DL = M->getDataLayout();427SmallVector<Value *> Args;428429// Pass the called function in x9.430auto X64TyOffset = 1;431Args.push_back(F->arg_begin());432433Type *RetTy = Arm64Ty->getReturnType();434if (RetTy != X64Ty->getReturnType()) {435// If the return type is an array or struct, translate it. Values of size436// 8 or less go into RAX; bigger values go into memory, and we pass a437// pointer.438if (DL.getTypeStoreSize(RetTy) > 8) {439Args.push_back(IRB.CreateAlloca(RetTy));440X64TyOffset++;441}442}443444for (auto [Arg, X64ArgType, ArgTranslation] : llvm::zip_equal(445make_range(F->arg_begin() + 1, F->arg_end()),446make_range(X64Ty->param_begin() + X64TyOffset, X64Ty->param_end()),447ArgTranslations)) {448// Translate arguments from AArch64 calling convention to x86 calling449// convention.450//451// For simple types, we don't need to do any translation: they're452// represented the same way. (Implicit sign extension is not part of453// either convention.)454//455// The big thing we have to worry about is struct types... but456// fortunately AArch64 clang is pretty friendly here: the cases that need457// translation are always passed as a struct or array. (If we run into458// some cases where this doesn't work, we can teach clang to mark it up459// with an attribute.)460//461// The first argument is the called function, stored in x9.462if (ArgTranslation != ThunkArgTranslation::Direct) {463Value *Mem = IRB.CreateAlloca(Arg.getType());464IRB.CreateStore(&Arg, Mem);465if (ArgTranslation == ThunkArgTranslation::Bitcast) {466Type *IntTy = IRB.getIntNTy(DL.getTypeStoreSizeInBits(Arg.getType()));467Args.push_back(IRB.CreateLoad(IntTy, IRB.CreateBitCast(Mem, PtrTy)));468} else {469assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);470Args.push_back(Mem);471}472} else {473Args.push_back(&Arg);474}475assert(Args.back()->getType() == X64ArgType);476}477// FIXME: Transfer necessary attributes? sret? anything else?478479Callee = IRB.CreateBitCast(Callee, PtrTy);480CallInst *Call = IRB.CreateCall(X64Ty, Callee, Args);481Call->setCallingConv(CallingConv::ARM64EC_Thunk_X64);482483Value *RetVal = Call;484if (RetTy != X64Ty->getReturnType()) {485// If we rewrote the return type earlier, convert the return value to486// the proper type.487if (DL.getTypeStoreSize(RetTy) > 8) {488RetVal = IRB.CreateLoad(RetTy, Args[1]);489} else {490Value *CastAlloca = IRB.CreateAlloca(RetTy);491IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));492RetVal = IRB.CreateLoad(RetTy, CastAlloca);493}494}495496if (RetTy->isVoidTy())497IRB.CreateRetVoid();498else499IRB.CreateRet(RetVal);500return F;501}502503// This function builds the "entry thunk", a function which translates504// arguments and return values when calling AArch64 code from x64 code.505Function *AArch64Arm64ECCallLowering::buildEntryThunk(Function *F) {506SmallString<256> EntryThunkName;507llvm::raw_svector_ostream EntryThunkStream(EntryThunkName);508FunctionType *Arm64Ty, *X64Ty;509SmallVector<ThunkArgTranslation> ArgTranslations;510getThunkType(F->getFunctionType(), F->getAttributes(),511Arm64ECThunkType::Entry, EntryThunkStream, Arm64Ty, X64Ty,512ArgTranslations);513if (Function *F = M->getFunction(EntryThunkName))514return F;515516Function *Thunk = Function::Create(X64Ty, GlobalValue::LinkOnceODRLinkage, 0,517EntryThunkName, M);518Thunk->setCallingConv(CallingConv::ARM64EC_Thunk_X64);519Thunk->setSection(".wowthk$aa");520Thunk->setComdat(M->getOrInsertComdat(EntryThunkName));521// Copy MSVC, and always set up a frame pointer. (Maybe this isn't necessary.)522Thunk->addFnAttr("frame-pointer", "all");523524BasicBlock *BB = BasicBlock::Create(M->getContext(), "", Thunk);525IRBuilder<> IRB(BB);526527Type *RetTy = Arm64Ty->getReturnType();528Type *X64RetType = X64Ty->getReturnType();529530bool TransformDirectToSRet = X64RetType->isVoidTy() && !RetTy->isVoidTy();531unsigned ThunkArgOffset = TransformDirectToSRet ? 2 : 1;532unsigned PassthroughArgSize =533(F->isVarArg() ? 5 : Thunk->arg_size()) - ThunkArgOffset;534assert(ArgTranslations.size() == (F->isVarArg() ? 5 : PassthroughArgSize));535536// Translate arguments to call.537SmallVector<Value *> Args;538for (unsigned i = 0; i != PassthroughArgSize; ++i) {539Value *Arg = Thunk->getArg(i + ThunkArgOffset);540Type *ArgTy = Arm64Ty->getParamType(i);541ThunkArgTranslation ArgTranslation = ArgTranslations[i];542if (ArgTranslation != ThunkArgTranslation::Direct) {543// Translate array/struct arguments to the expected type.544if (ArgTranslation == ThunkArgTranslation::Bitcast) {545Value *CastAlloca = IRB.CreateAlloca(ArgTy);546IRB.CreateStore(Arg, IRB.CreateBitCast(CastAlloca, PtrTy));547Arg = IRB.CreateLoad(ArgTy, CastAlloca);548} else {549assert(ArgTranslation == ThunkArgTranslation::PointerIndirection);550Arg = IRB.CreateLoad(ArgTy, IRB.CreateBitCast(Arg, PtrTy));551}552}553assert(Arg->getType() == ArgTy);554Args.push_back(Arg);555}556557if (F->isVarArg()) {558// The 5th argument to variadic entry thunks is used to model the x64 sp559// which is passed to the thunk in x4, this can be passed to the callee as560// the variadic argument start address after skipping over the 32 byte561// shadow store.562563// The EC thunk CC will assign any argument marked as InReg to x4.564Thunk->addParamAttr(5, Attribute::InReg);565Value *Arg = Thunk->getArg(5);566Arg = IRB.CreatePtrAdd(Arg, IRB.getInt64(0x20));567Args.push_back(Arg);568569// Pass in a zero variadic argument size (in x5).570Args.push_back(IRB.getInt64(0));571}572573// Call the function passed to the thunk.574Value *Callee = Thunk->getArg(0);575Callee = IRB.CreateBitCast(Callee, PtrTy);576CallInst *Call = IRB.CreateCall(Arm64Ty, Callee, Args);577578auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);579auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);580if (SRetAttr.isValid() && !InRegAttr.isValid()) {581Thunk->addParamAttr(1, SRetAttr);582Call->addParamAttr(0, SRetAttr);583}584585Value *RetVal = Call;586if (TransformDirectToSRet) {587IRB.CreateStore(RetVal, IRB.CreateBitCast(Thunk->getArg(1), PtrTy));588} else if (X64RetType != RetTy) {589Value *CastAlloca = IRB.CreateAlloca(X64RetType);590IRB.CreateStore(Call, IRB.CreateBitCast(CastAlloca, PtrTy));591RetVal = IRB.CreateLoad(X64RetType, CastAlloca);592}593594// Return to the caller. Note that the isel has code to translate this595// "ret" to a tail call to __os_arm64x_dispatch_ret. (Alternatively, we596// could emit a tail call here, but that would require a dedicated calling597// convention, which seems more complicated overall.)598if (X64RetType->isVoidTy())599IRB.CreateRetVoid();600else601IRB.CreateRet(RetVal);602603return Thunk;604}605606// Builds the "guest exit thunk", a helper to call a function which may or may607// not be an exit thunk. (We optimistically assume non-dllimport function608// declarations refer to functions defined in AArch64 code; if the linker609// can't prove that, we use this routine instead.)610Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {611llvm::raw_null_ostream NullThunkName;612FunctionType *Arm64Ty, *X64Ty;613SmallVector<ThunkArgTranslation> ArgTranslations;614getThunkType(F->getFunctionType(), F->getAttributes(),615Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,616ArgTranslations);617auto MangledName = getArm64ECMangledFunctionName(F->getName().str());618assert(MangledName && "Can't guest exit to function that's already native");619std::string ThunkName = *MangledName;620if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {621ThunkName.insert(ThunkName.find("@"), "$exit_thunk");622} else {623ThunkName.append("$exit_thunk");624}625Function *GuestExit =626Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);627GuestExit->setComdat(M->getOrInsertComdat(ThunkName));628GuestExit->setSection(".wowthk$aa");629GuestExit->setMetadata(630"arm64ec_unmangled_name",631MDNode::get(M->getContext(),632MDString::get(M->getContext(), F->getName())));633GuestExit->setMetadata(634"arm64ec_ecmangled_name",635MDNode::get(M->getContext(),636MDString::get(M->getContext(), *MangledName)));637F->setMetadata("arm64ec_hasguestexit", MDNode::get(M->getContext(), {}));638BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);639IRBuilder<> B(BB);640641// Load the global symbol as a pointer to the check function.642Value *GuardFn;643if (cfguard_module_flag == 2 && !F->hasFnAttribute("guard_nocf"))644GuardFn = GuardFnCFGlobal;645else646GuardFn = GuardFnGlobal;647LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);648649// Create new call instruction. The CFGuard check should always be a call,650// even if the original CallBase is an Invoke or CallBr instruction.651Function *Thunk = buildExitThunk(F->getFunctionType(), F->getAttributes());652CallInst *GuardCheck = B.CreateCall(653GuardFnType, GuardCheckLoad,654{B.CreateBitCast(F, B.getPtrTy()), B.CreateBitCast(Thunk, B.getPtrTy())});655656// Ensure that the first argument is passed in the correct register.657GuardCheck->setCallingConv(CallingConv::CFGuard_Check);658659Value *GuardRetVal = B.CreateBitCast(GuardCheck, PtrTy);660SmallVector<Value *> Args;661for (Argument &Arg : GuestExit->args())662Args.push_back(&Arg);663CallInst *Call = B.CreateCall(Arm64Ty, GuardRetVal, Args);664Call->setTailCallKind(llvm::CallInst::TCK_MustTail);665666if (Call->getType()->isVoidTy())667B.CreateRetVoid();668else669B.CreateRet(Call);670671auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);672auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);673if (SRetAttr.isValid() && !InRegAttr.isValid()) {674GuestExit->addParamAttr(0, SRetAttr);675Call->addParamAttr(0, SRetAttr);676}677678return GuestExit;679}680681Function *682AArch64Arm64ECCallLowering::buildPatchableThunk(GlobalAlias *UnmangledAlias,683GlobalAlias *MangledAlias) {684llvm::raw_null_ostream NullThunkName;685FunctionType *Arm64Ty, *X64Ty;686Function *F = cast<Function>(MangledAlias->getAliasee());687SmallVector<ThunkArgTranslation> ArgTranslations;688getThunkType(F->getFunctionType(), F->getAttributes(),689Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,690ArgTranslations);691std::string ThunkName(MangledAlias->getName());692if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {693ThunkName.insert(ThunkName.find("@"), "$hybpatch_thunk");694} else {695ThunkName.append("$hybpatch_thunk");696}697698Function *GuestExit =699Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);700GuestExit->setComdat(M->getOrInsertComdat(ThunkName));701GuestExit->setSection(".wowthk$aa");702BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);703IRBuilder<> B(BB);704705// Load the global symbol as a pointer to the check function.706LoadInst *DispatchLoad = B.CreateLoad(DispatchFnPtrType, DispatchFnGlobal);707708// Create new dispatch call instruction.709Function *ExitThunk =710buildExitThunk(F->getFunctionType(), F->getAttributes());711CallInst *Dispatch =712B.CreateCall(DispatchFnType, DispatchLoad,713{UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee()});714715// Ensure that the first arguments are passed in the correct registers.716Dispatch->setCallingConv(CallingConv::CFGuard_Check);717718Value *DispatchRetVal = B.CreateBitCast(Dispatch, PtrTy);719SmallVector<Value *> Args;720for (Argument &Arg : GuestExit->args())721Args.push_back(&Arg);722CallInst *Call = B.CreateCall(Arm64Ty, DispatchRetVal, Args);723Call->setTailCallKind(llvm::CallInst::TCK_MustTail);724725if (Call->getType()->isVoidTy())726B.CreateRetVoid();727else728B.CreateRet(Call);729730auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);731auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);732if (SRetAttr.isValid() && !InRegAttr.isValid()) {733GuestExit->addParamAttr(0, SRetAttr);734Call->addParamAttr(0, SRetAttr);735}736737MangledAlias->setAliasee(GuestExit);738return GuestExit;739}740741// Lower an indirect call with inline code.742void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {743assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&744"Only applicable for Windows targets");745746IRBuilder<> B(CB);747Value *CalledOperand = CB->getCalledOperand();748749// If the indirect call is called within catchpad or cleanuppad,750// we need to copy "funclet" bundle of the call.751SmallVector<llvm::OperandBundleDef, 1> Bundles;752if (auto Bundle = CB->getOperandBundle(LLVMContext::OB_funclet))753Bundles.push_back(OperandBundleDef(*Bundle));754755// Load the global symbol as a pointer to the check function.756Value *GuardFn;757if (cfguard_module_flag == 2 && !CB->hasFnAttr("guard_nocf"))758GuardFn = GuardFnCFGlobal;759else760GuardFn = GuardFnGlobal;761LoadInst *GuardCheckLoad = B.CreateLoad(GuardFnPtrType, GuardFn);762763// Create new call instruction. The CFGuard check should always be a call,764// even if the original CallBase is an Invoke or CallBr instruction.765Function *Thunk = buildExitThunk(CB->getFunctionType(), CB->getAttributes());766CallInst *GuardCheck =767B.CreateCall(GuardFnType, GuardCheckLoad,768{B.CreateBitCast(CalledOperand, B.getPtrTy()),769B.CreateBitCast(Thunk, B.getPtrTy())},770Bundles);771772// Ensure that the first argument is passed in the correct register.773GuardCheck->setCallingConv(CallingConv::CFGuard_Check);774775Value *GuardRetVal = B.CreateBitCast(GuardCheck, CalledOperand->getType());776CB->setCalledOperand(GuardRetVal);777}778779bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {780if (!GenerateThunks)781return false;782783M = &Mod;784785// Check if this module has the cfguard flag and read its value.786if (auto *MD =787mdconst::extract_or_null<ConstantInt>(M->getModuleFlag("cfguard")))788cfguard_module_flag = MD->getZExtValue();789790PtrTy = PointerType::getUnqual(M->getContext());791I64Ty = Type::getInt64Ty(M->getContext());792VoidTy = Type::getVoidTy(M->getContext());793794GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);795GuardFnPtrType = PointerType::get(GuardFnType, 0);796DispatchFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy, PtrTy}, false);797DispatchFnPtrType = PointerType::get(DispatchFnType, 0);798GuardFnCFGlobal =799M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);800GuardFnGlobal =801M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);802DispatchFnGlobal =803M->getOrInsertGlobal("__os_arm64x_dispatch_call", DispatchFnPtrType);804805DenseMap<GlobalAlias *, GlobalAlias *> FnsMap;806SetVector<GlobalAlias *> PatchableFns;807808for (Function &F : Mod) {809if (!F.hasFnAttribute(Attribute::HybridPatchable) || F.isDeclaration() ||810F.hasLocalLinkage() || F.getName().ends_with("$hp_target"))811continue;812813// Rename hybrid patchable functions and change callers to use a global814// alias instead.815if (std::optional<std::string> MangledName =816getArm64ECMangledFunctionName(F.getName().str())) {817std::string OrigName(F.getName());818F.setName(MangledName.value() + "$hp_target");819820// The unmangled symbol is a weak alias to an undefined symbol with the821// "EXP+" prefix. This undefined symbol is resolved by the linker by822// creating an x86 thunk that jumps back to the actual EC target. Since we823// can't represent that in IR, we create an alias to the target instead.824// The "EXP+" symbol is set as metadata, which is then used by825// emitGlobalAlias to emit the right alias.826auto *A =827GlobalAlias::create(GlobalValue::LinkOnceODRLinkage, OrigName, &F);828F.replaceAllUsesWith(A);829F.setMetadata("arm64ec_exp_name",830MDNode::get(M->getContext(),831MDString::get(M->getContext(),832"EXP+" + MangledName.value())));833A->setAliasee(&F);834835if (F.hasDLLExportStorageClass()) {836A->setDLLStorageClass(GlobalValue::DLLExportStorageClass);837F.setDLLStorageClass(GlobalValue::DefaultStorageClass);838}839840FnsMap[A] = GlobalAlias::create(GlobalValue::LinkOnceODRLinkage,841MangledName.value(), &F);842PatchableFns.insert(A);843}844}845846SetVector<GlobalValue *> DirectCalledFns;847for (Function &F : Mod)848if (!F.isDeclaration() &&849F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&850F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)851processFunction(F, DirectCalledFns, FnsMap);852853struct ThunkInfo {854Constant *Src;855Constant *Dst;856Arm64ECThunkType Kind;857};858SmallVector<ThunkInfo> ThunkMapping;859for (Function &F : Mod) {860if (!F.isDeclaration() && (!F.hasLocalLinkage() || F.hasAddressTaken()) &&861F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&862F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64) {863if (!F.hasComdat())864F.setComdat(Mod.getOrInsertComdat(F.getName()));865ThunkMapping.push_back(866{&F, buildEntryThunk(&F), Arm64ECThunkType::Entry});867}868}869for (GlobalValue *O : DirectCalledFns) {870auto GA = dyn_cast<GlobalAlias>(O);871auto F = dyn_cast<Function>(GA ? GA->getAliasee() : O);872ThunkMapping.push_back(873{O, buildExitThunk(F->getFunctionType(), F->getAttributes()),874Arm64ECThunkType::Exit});875if (!GA && !F->hasDLLImportStorageClass())876ThunkMapping.push_back(877{buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});878}879for (GlobalAlias *A : PatchableFns) {880Function *Thunk = buildPatchableThunk(A, FnsMap[A]);881ThunkMapping.push_back({Thunk, A, Arm64ECThunkType::GuestExit});882}883884if (!ThunkMapping.empty()) {885SmallVector<Constant *> ThunkMappingArrayElems;886for (ThunkInfo &Thunk : ThunkMapping) {887ThunkMappingArrayElems.push_back(ConstantStruct::getAnon(888{ConstantExpr::getBitCast(Thunk.Src, PtrTy),889ConstantExpr::getBitCast(Thunk.Dst, PtrTy),890ConstantInt::get(M->getContext(), APInt(32, uint8_t(Thunk.Kind)))}));891}892Constant *ThunkMappingArray = ConstantArray::get(893llvm::ArrayType::get(ThunkMappingArrayElems[0]->getType(),894ThunkMappingArrayElems.size()),895ThunkMappingArrayElems);896new GlobalVariable(Mod, ThunkMappingArray->getType(), /*isConstant*/ false,897GlobalValue::ExternalLinkage, ThunkMappingArray,898"llvm.arm64ec.symbolmap");899}900901return true;902}903904bool AArch64Arm64ECCallLowering::processFunction(905Function &F, SetVector<GlobalValue *> &DirectCalledFns,906DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) {907SmallVector<CallBase *, 8> IndirectCalls;908909// For ARM64EC targets, a function definition's name is mangled differently910// from the normal symbol. We currently have no representation of this sort911// of symbol in IR, so we change the name to the mangled name, then store912// the unmangled name as metadata. Later passes that need the unmangled913// name (emitting the definition) can grab it from the metadata.914//915// FIXME: Handle functions with weak linkage?916if (!F.hasLocalLinkage() || F.hasAddressTaken()) {917if (std::optional<std::string> MangledName =918getArm64ECMangledFunctionName(F.getName().str())) {919F.setMetadata("arm64ec_unmangled_name",920MDNode::get(M->getContext(),921MDString::get(M->getContext(), F.getName())));922if (F.hasComdat() && F.getComdat()->getName() == F.getName()) {923Comdat *MangledComdat = M->getOrInsertComdat(MangledName.value());924SmallVector<GlobalObject *> ComdatUsers =925to_vector(F.getComdat()->getUsers());926for (GlobalObject *User : ComdatUsers)927User->setComdat(MangledComdat);928}929F.setName(MangledName.value());930}931}932933// Iterate over the instructions to find all indirect call/invoke/callbr934// instructions. Make a separate list of pointers to indirect935// call/invoke/callbr instructions because the original instructions will be936// deleted as the checks are added.937for (BasicBlock &BB : F) {938for (Instruction &I : BB) {939auto *CB = dyn_cast<CallBase>(&I);940if (!CB || CB->getCallingConv() == CallingConv::ARM64EC_Thunk_X64 ||941CB->isInlineAsm())942continue;943944// We need to instrument any call that isn't directly calling an945// ARM64 function.946//947// FIXME: getCalledFunction() fails if there's a bitcast (e.g.948// unprototyped functions in C)949if (Function *F = CB->getCalledFunction()) {950if (!LowerDirectToIndirect || F->hasLocalLinkage() ||951F->isIntrinsic() || !F->isDeclaration())952continue;953954DirectCalledFns.insert(F);955continue;956}957958// Use mangled global alias for direct calls to patchable functions.959if (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand())) {960auto I = FnsMap.find(A);961if (I != FnsMap.end()) {962CB->setCalledOperand(I->second);963DirectCalledFns.insert(I->first);964continue;965}966}967968IndirectCalls.push_back(CB);969++Arm64ECCallsLowered;970}971}972973if (IndirectCalls.empty())974return false;975976for (CallBase *CB : IndirectCalls)977lowerCall(CB);978979return true;980}981982char AArch64Arm64ECCallLowering::ID = 0;983INITIALIZE_PASS(AArch64Arm64ECCallLowering, "Arm64ECCallLowering",984"AArch64Arm64ECCallLowering", false, false)985986ModulePass *llvm::createAArch64Arm64ECCallLoweringPass() {987return new AArch64Arm64ECCallLowering;988}989990991