Path: blob/main/contrib/llvm-project/llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp
35266 views
//===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7///8/// \file9/// Fix bitcasted functions.10///11/// WebAssembly requires caller and callee signatures to match, however in LLVM,12/// some amount of slop is vaguely permitted. Detect mismatch by looking for13/// bitcasts of functions and rewrite them to use wrapper functions instead.14///15/// This doesn't catch all cases, such as when a function's address is taken in16/// one place and casted in another, but it works for many common cases.17///18/// Note that LLVM already optimizes away function bitcasts in common cases by19/// dropping arguments as needed, so this pass only ends up getting used in less20/// common cases.21///22//===----------------------------------------------------------------------===//2324#include "WebAssembly.h"25#include "llvm/IR/Constants.h"26#include "llvm/IR/Instructions.h"27#include "llvm/IR/Module.h"28#include "llvm/IR/Operator.h"29#include "llvm/Pass.h"30#include "llvm/Support/Debug.h"31#include "llvm/Support/raw_ostream.h"32using namespace llvm;3334#define DEBUG_TYPE "wasm-fix-function-bitcasts"3536namespace {37class FixFunctionBitcasts final : public ModulePass {38StringRef getPassName() const override {39return "WebAssembly Fix Function Bitcasts";40}4142void getAnalysisUsage(AnalysisUsage &AU) const override {43AU.setPreservesCFG();44ModulePass::getAnalysisUsage(AU);45}4647bool runOnModule(Module &M) override;4849public:50static char ID;51FixFunctionBitcasts() : ModulePass(ID) {}52};53} // End anonymous namespace5455char FixFunctionBitcasts::ID = 0;56INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE,57"Fix mismatching bitcasts for WebAssembly", false, false)5859ModulePass *llvm::createWebAssemblyFixFunctionBitcasts() {60return new FixFunctionBitcasts();61}6263// Recursively descend the def-use lists from V to find non-bitcast users of64// bitcasts of V.65static void findUses(Value *V, Function &F,66SmallVectorImpl<std::pair<CallBase *, Function *>> &Uses) {67for (User *U : V->users()) {68if (auto *BC = dyn_cast<BitCastOperator>(U))69findUses(BC, F, Uses);70else if (auto *A = dyn_cast<GlobalAlias>(U))71findUses(A, F, Uses);72else if (auto *CB = dyn_cast<CallBase>(U)) {73Value *Callee = CB->getCalledOperand();74if (Callee != V)75// Skip calls where the function isn't the callee76continue;77if (CB->getFunctionType() == F.getValueType())78// Skip uses that are immediately called79continue;80Uses.push_back(std::make_pair(CB, &F));81}82}83}8485// Create a wrapper function with type Ty that calls F (which may have a86// different type). Attempt to support common bitcasted function idioms:87// - Call with more arguments than needed: arguments are dropped88// - Call with fewer arguments than needed: arguments are filled in with undef89// - Return value is not needed: drop it90// - Return value needed but not present: supply an undef91//92// If the all the argument types of trivially castable to one another (i.e.93// I32 vs pointer type) then we don't create a wrapper at all (return nullptr94// instead).95//96// If there is a type mismatch that we know would result in an invalid wasm97// module then generate wrapper that contains unreachable (i.e. abort at98// runtime). Such programs are deep into undefined behaviour territory,99// but we choose to fail at runtime rather than generate and invalid module100// or fail at compiler time. The reason we delay the error is that we want101// to support the CMake which expects to be able to compile and link programs102// that refer to functions with entirely incorrect signatures (this is how103// CMake detects the existence of a function in a toolchain).104//105// For bitcasts that involve struct types we don't know at this stage if they106// would be equivalent at the wasm level and so we can't know if we need to107// generate a wrapper.108static Function *createWrapper(Function *F, FunctionType *Ty) {109Module *M = F->getParent();110111Function *Wrapper = Function::Create(Ty, Function::PrivateLinkage,112F->getName() + "_bitcast", M);113BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);114const DataLayout &DL = BB->getDataLayout();115116// Determine what arguments to pass.117SmallVector<Value *, 4> Args;118Function::arg_iterator AI = Wrapper->arg_begin();119Function::arg_iterator AE = Wrapper->arg_end();120FunctionType::param_iterator PI = F->getFunctionType()->param_begin();121FunctionType::param_iterator PE = F->getFunctionType()->param_end();122bool TypeMismatch = false;123bool WrapperNeeded = false;124125Type *ExpectedRtnType = F->getFunctionType()->getReturnType();126Type *RtnType = Ty->getReturnType();127128if ((F->getFunctionType()->getNumParams() != Ty->getNumParams()) ||129(F->getFunctionType()->isVarArg() != Ty->isVarArg()) ||130(ExpectedRtnType != RtnType))131WrapperNeeded = true;132133for (; AI != AE && PI != PE; ++AI, ++PI) {134Type *ArgType = AI->getType();135Type *ParamType = *PI;136137if (ArgType == ParamType) {138Args.push_back(&*AI);139} else {140if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType, DL)) {141Instruction *PtrCast =142CastInst::CreateBitOrPointerCast(AI, ParamType, "cast");143PtrCast->insertInto(BB, BB->end());144Args.push_back(PtrCast);145} else if (ArgType->isStructTy() || ParamType->isStructTy()) {146LLVM_DEBUG(dbgs() << "createWrapper: struct param type in bitcast: "147<< F->getName() << "\n");148WrapperNeeded = false;149} else {150LLVM_DEBUG(dbgs() << "createWrapper: arg type mismatch calling: "151<< F->getName() << "\n");152LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: "153<< *ParamType << " Got: " << *ArgType << "\n");154TypeMismatch = true;155break;156}157}158}159160if (WrapperNeeded && !TypeMismatch) {161for (; PI != PE; ++PI)162Args.push_back(UndefValue::get(*PI));163if (F->isVarArg())164for (; AI != AE; ++AI)165Args.push_back(&*AI);166167CallInst *Call = CallInst::Create(F, Args, "", BB);168169Type *ExpectedRtnType = F->getFunctionType()->getReturnType();170Type *RtnType = Ty->getReturnType();171// Determine what value to return.172if (RtnType->isVoidTy()) {173ReturnInst::Create(M->getContext(), BB);174} else if (ExpectedRtnType->isVoidTy()) {175LLVM_DEBUG(dbgs() << "Creating dummy return: " << *RtnType << "\n");176ReturnInst::Create(M->getContext(), UndefValue::get(RtnType), BB);177} else if (RtnType == ExpectedRtnType) {178ReturnInst::Create(M->getContext(), Call, BB);179} else if (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType,180DL)) {181Instruction *Cast =182CastInst::CreateBitOrPointerCast(Call, RtnType, "cast");183Cast->insertInto(BB, BB->end());184ReturnInst::Create(M->getContext(), Cast, BB);185} else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) {186LLVM_DEBUG(dbgs() << "createWrapper: struct return type in bitcast: "187<< F->getName() << "\n");188WrapperNeeded = false;189} else {190LLVM_DEBUG(dbgs() << "createWrapper: return type mismatch calling: "191<< F->getName() << "\n");192LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType193<< " Got: " << *RtnType << "\n");194TypeMismatch = true;195}196}197198if (TypeMismatch) {199// Create a new wrapper that simply contains `unreachable`.200Wrapper->eraseFromParent();201Wrapper = Function::Create(Ty, Function::PrivateLinkage,202F->getName() + "_bitcast_invalid", M);203BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);204new UnreachableInst(M->getContext(), BB);205Wrapper->setName(F->getName() + "_bitcast_invalid");206} else if (!WrapperNeeded) {207LLVM_DEBUG(dbgs() << "createWrapper: no wrapper needed: " << F->getName()208<< "\n");209Wrapper->eraseFromParent();210return nullptr;211}212LLVM_DEBUG(dbgs() << "createWrapper: " << F->getName() << "\n");213return Wrapper;214}215216// Test whether a main function with type FuncTy should be rewritten to have217// type MainTy.218static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy) {219// Only fix the main function if it's the standard zero-arg form. That way,220// the standard cases will work as expected, and users will see signature221// mismatches from the linker for non-standard cases.222return FuncTy->getReturnType() == MainTy->getReturnType() &&223FuncTy->getNumParams() == 0 &&224!FuncTy->isVarArg();225}226227bool FixFunctionBitcasts::runOnModule(Module &M) {228LLVM_DEBUG(dbgs() << "********** Fix Function Bitcasts **********\n");229230Function *Main = nullptr;231CallInst *CallMain = nullptr;232SmallVector<std::pair<CallBase *, Function *>, 0> Uses;233234// Collect all the places that need wrappers.235for (Function &F : M) {236// Skip to fix when the function is swiftcc because swiftcc allows237// bitcast type difference for swiftself and swifterror.238if (F.getCallingConv() == CallingConv::Swift)239continue;240findUses(&F, F, Uses);241242// If we have a "main" function, and its type isn't243// "int main(int argc, char *argv[])", create an artificial call with it244// bitcasted to that type so that we generate a wrapper for it, so that245// the C runtime can call it.246if (F.getName() == "main") {247Main = &F;248LLVMContext &C = M.getContext();249Type *MainArgTys[] = {Type::getInt32Ty(C), PointerType::get(C, 0)};250FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,251/*isVarArg=*/false);252if (shouldFixMainFunction(F.getFunctionType(), MainTy)) {253LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: "254<< *F.getFunctionType() << "\n");255Value *Args[] = {UndefValue::get(MainArgTys[0]),256UndefValue::get(MainArgTys[1])};257CallMain = CallInst::Create(MainTy, Main, Args, "call_main");258Uses.push_back(std::make_pair(CallMain, &F));259}260}261}262263DenseMap<std::pair<Function *, FunctionType *>, Function *> Wrappers;264265for (auto &UseFunc : Uses) {266CallBase *CB = UseFunc.first;267Function *F = UseFunc.second;268FunctionType *Ty = CB->getFunctionType();269270auto Pair = Wrappers.insert(std::make_pair(std::make_pair(F, Ty), nullptr));271if (Pair.second)272Pair.first->second = createWrapper(F, Ty);273274Function *Wrapper = Pair.first->second;275if (!Wrapper)276continue;277278CB->setCalledOperand(Wrapper);279}280281// If we created a wrapper for main, rename the wrapper so that it's the282// one that gets called from startup.283if (CallMain) {284Main->setName("__original_main");285auto *MainWrapper =286cast<Function>(CallMain->getCalledOperand()->stripPointerCasts());287delete CallMain;288if (Main->isDeclaration()) {289// The wrapper is not needed in this case as we don't need to export290// it to anyone else.291MainWrapper->eraseFromParent();292} else {293// Otherwise give the wrapper the same linkage as the original main294// function, so that it can be called from the same places.295MainWrapper->setName("main");296MainWrapper->setLinkage(Main->getLinkage());297MainWrapper->setVisibility(Main->getVisibility());298}299}300301return true;302}303304305