Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/WebAssembly/WebAssemblyFixFunctionBitcasts.cpp
35266 views
//===-- WebAssemblyFixFunctionBitcasts.cpp - Fix function bitcasts --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// Fix bitcasted functions.
///
/// WebAssembly requires caller and callee signatures to match, however in LLVM,
/// some amount of slop is vaguely permitted. Detect mismatch by looking for
/// bitcasts of functions and rewrite them to use wrapper functions instead.
///
/// This doesn't catch all cases, such as when a function's address is taken in
/// one place and casted in another, but it works for many common cases.
///
/// Note that LLVM already optimizes away function bitcasts in common cases by
/// dropping arguments as needed, so this pass only ends up getting used in less
/// common cases.
///
//===----------------------------------------------------------------------===//
24
25
#include "WebAssembly.h"
26
#include "llvm/IR/Constants.h"
27
#include "llvm/IR/Instructions.h"
28
#include "llvm/IR/Module.h"
29
#include "llvm/IR/Operator.h"
30
#include "llvm/Pass.h"
31
#include "llvm/Support/Debug.h"
32
#include "llvm/Support/raw_ostream.h"
33
using namespace llvm;
34
35
#define DEBUG_TYPE "wasm-fix-function-bitcasts"
36
37
namespace {
38
class FixFunctionBitcasts final : public ModulePass {
39
StringRef getPassName() const override {
40
return "WebAssembly Fix Function Bitcasts";
41
}
42
43
void getAnalysisUsage(AnalysisUsage &AU) const override {
44
AU.setPreservesCFG();
45
ModulePass::getAnalysisUsage(AU);
46
}
47
48
bool runOnModule(Module &M) override;
49
50
public:
51
static char ID;
52
FixFunctionBitcasts() : ModulePass(ID) {}
53
};
54
} // End anonymous namespace
55
56
char FixFunctionBitcasts::ID = 0;
57
INITIALIZE_PASS(FixFunctionBitcasts, DEBUG_TYPE,
58
"Fix mismatching bitcasts for WebAssembly", false, false)
59
60
/// Factory for the pass: returns a newly allocated FixFunctionBitcasts
/// instance as a ModulePass pointer.
ModulePass *llvm::createWebAssemblyFixFunctionBitcasts() {
  ModulePass *Pass = new FixFunctionBitcasts();
  return Pass;
}
63
64
// Recursively descend the def-use lists from V to find non-bitcast users of
65
// bitcasts of V.
66
static void findUses(Value *V, Function &F,
67
SmallVectorImpl<std::pair<CallBase *, Function *>> &Uses) {
68
for (User *U : V->users()) {
69
if (auto *BC = dyn_cast<BitCastOperator>(U))
70
findUses(BC, F, Uses);
71
else if (auto *A = dyn_cast<GlobalAlias>(U))
72
findUses(A, F, Uses);
73
else if (auto *CB = dyn_cast<CallBase>(U)) {
74
Value *Callee = CB->getCalledOperand();
75
if (Callee != V)
76
// Skip calls where the function isn't the callee
77
continue;
78
if (CB->getFunctionType() == F.getValueType())
79
// Skip uses that are immediately called
80
continue;
81
Uses.push_back(std::make_pair(CB, &F));
82
}
83
}
84
}
85
86
// Create a wrapper function with type Ty that calls F (which may have a
87
// different type). Attempt to support common bitcasted function idioms:
88
// - Call with more arguments than needed: arguments are dropped
89
// - Call with fewer arguments than needed: arguments are filled in with undef
90
// - Return value is not needed: drop it
91
// - Return value needed but not present: supply an undef
92
//
93
// If the all the argument types of trivially castable to one another (i.e.
94
// I32 vs pointer type) then we don't create a wrapper at all (return nullptr
95
// instead).
96
//
97
// If there is a type mismatch that we know would result in an invalid wasm
98
// module then generate wrapper that contains unreachable (i.e. abort at
99
// runtime). Such programs are deep into undefined behaviour territory,
100
// but we choose to fail at runtime rather than generate and invalid module
101
// or fail at compiler time. The reason we delay the error is that we want
102
// to support the CMake which expects to be able to compile and link programs
103
// that refer to functions with entirely incorrect signatures (this is how
104
// CMake detects the existence of a function in a toolchain).
105
//
106
// For bitcasts that involve struct types we don't know at this stage if they
107
// would be equivalent at the wasm level and so we can't know if we need to
108
// generate a wrapper.
109
static Function *createWrapper(Function *F, FunctionType *Ty) {
110
Module *M = F->getParent();
111
112
Function *Wrapper = Function::Create(Ty, Function::PrivateLinkage,
113
F->getName() + "_bitcast", M);
114
BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
115
const DataLayout &DL = BB->getDataLayout();
116
117
// Determine what arguments to pass.
118
SmallVector<Value *, 4> Args;
119
Function::arg_iterator AI = Wrapper->arg_begin();
120
Function::arg_iterator AE = Wrapper->arg_end();
121
FunctionType::param_iterator PI = F->getFunctionType()->param_begin();
122
FunctionType::param_iterator PE = F->getFunctionType()->param_end();
123
bool TypeMismatch = false;
124
bool WrapperNeeded = false;
125
126
Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
127
Type *RtnType = Ty->getReturnType();
128
129
if ((F->getFunctionType()->getNumParams() != Ty->getNumParams()) ||
130
(F->getFunctionType()->isVarArg() != Ty->isVarArg()) ||
131
(ExpectedRtnType != RtnType))
132
WrapperNeeded = true;
133
134
for (; AI != AE && PI != PE; ++AI, ++PI) {
135
Type *ArgType = AI->getType();
136
Type *ParamType = *PI;
137
138
if (ArgType == ParamType) {
139
Args.push_back(&*AI);
140
} else {
141
if (CastInst::isBitOrNoopPointerCastable(ArgType, ParamType, DL)) {
142
Instruction *PtrCast =
143
CastInst::CreateBitOrPointerCast(AI, ParamType, "cast");
144
PtrCast->insertInto(BB, BB->end());
145
Args.push_back(PtrCast);
146
} else if (ArgType->isStructTy() || ParamType->isStructTy()) {
147
LLVM_DEBUG(dbgs() << "createWrapper: struct param type in bitcast: "
148
<< F->getName() << "\n");
149
WrapperNeeded = false;
150
} else {
151
LLVM_DEBUG(dbgs() << "createWrapper: arg type mismatch calling: "
152
<< F->getName() << "\n");
153
LLVM_DEBUG(dbgs() << "Arg[" << Args.size() << "] Expected: "
154
<< *ParamType << " Got: " << *ArgType << "\n");
155
TypeMismatch = true;
156
break;
157
}
158
}
159
}
160
161
if (WrapperNeeded && !TypeMismatch) {
162
for (; PI != PE; ++PI)
163
Args.push_back(UndefValue::get(*PI));
164
if (F->isVarArg())
165
for (; AI != AE; ++AI)
166
Args.push_back(&*AI);
167
168
CallInst *Call = CallInst::Create(F, Args, "", BB);
169
170
Type *ExpectedRtnType = F->getFunctionType()->getReturnType();
171
Type *RtnType = Ty->getReturnType();
172
// Determine what value to return.
173
if (RtnType->isVoidTy()) {
174
ReturnInst::Create(M->getContext(), BB);
175
} else if (ExpectedRtnType->isVoidTy()) {
176
LLVM_DEBUG(dbgs() << "Creating dummy return: " << *RtnType << "\n");
177
ReturnInst::Create(M->getContext(), UndefValue::get(RtnType), BB);
178
} else if (RtnType == ExpectedRtnType) {
179
ReturnInst::Create(M->getContext(), Call, BB);
180
} else if (CastInst::isBitOrNoopPointerCastable(ExpectedRtnType, RtnType,
181
DL)) {
182
Instruction *Cast =
183
CastInst::CreateBitOrPointerCast(Call, RtnType, "cast");
184
Cast->insertInto(BB, BB->end());
185
ReturnInst::Create(M->getContext(), Cast, BB);
186
} else if (RtnType->isStructTy() || ExpectedRtnType->isStructTy()) {
187
LLVM_DEBUG(dbgs() << "createWrapper: struct return type in bitcast: "
188
<< F->getName() << "\n");
189
WrapperNeeded = false;
190
} else {
191
LLVM_DEBUG(dbgs() << "createWrapper: return type mismatch calling: "
192
<< F->getName() << "\n");
193
LLVM_DEBUG(dbgs() << "Expected: " << *ExpectedRtnType
194
<< " Got: " << *RtnType << "\n");
195
TypeMismatch = true;
196
}
197
}
198
199
if (TypeMismatch) {
200
// Create a new wrapper that simply contains `unreachable`.
201
Wrapper->eraseFromParent();
202
Wrapper = Function::Create(Ty, Function::PrivateLinkage,
203
F->getName() + "_bitcast_invalid", M);
204
BasicBlock *BB = BasicBlock::Create(M->getContext(), "body", Wrapper);
205
new UnreachableInst(M->getContext(), BB);
206
Wrapper->setName(F->getName() + "_bitcast_invalid");
207
} else if (!WrapperNeeded) {
208
LLVM_DEBUG(dbgs() << "createWrapper: no wrapper needed: " << F->getName()
209
<< "\n");
210
Wrapper->eraseFromParent();
211
return nullptr;
212
}
213
LLVM_DEBUG(dbgs() << "createWrapper: " << F->getName() << "\n");
214
return Wrapper;
215
}
216
217
// Test whether a main function with type FuncTy should be rewritten to have
218
// type MainTy.
219
static bool shouldFixMainFunction(FunctionType *FuncTy, FunctionType *MainTy) {
220
// Only fix the main function if it's the standard zero-arg form. That way,
221
// the standard cases will work as expected, and users will see signature
222
// mismatches from the linker for non-standard cases.
223
return FuncTy->getReturnType() == MainTy->getReturnType() &&
224
FuncTy->getNumParams() == 0 &&
225
!FuncTy->isVarArg();
226
}
227
228
/// Pass entry point. Collects the calls that need a wrapper, redirects each
/// one to a wrapper of the matching type (created once per function/type
/// pair), and finally swaps in the wrapper for a zero-argument `main`.
/// Always returns true.
bool FixFunctionBitcasts::runOnModule(Module &M) {
  LLVM_DEBUG(dbgs() << "********** Fix Function Bitcasts **********\n");

  Function *Main = nullptr;
  CallInst *CallMain = nullptr;
  SmallVector<std::pair<CallBase *, Function *>, 0> Uses;

  // Collect all the places that need wrappers.
  for (Function &F : M) {
    // Skip to fix when the function is swiftcc because swiftcc allows
    // bitcast type difference for swiftself and swifterror.
    if (F.getCallingConv() == CallingConv::Swift)
      continue;
    findUses(&F, F, Uses);

    if (F.getName() != "main")
      continue;

    // If we have a "main" function, and its type isn't
    // "int main(int argc, char *argv[])", create an artificial call with it
    // bitcasted to that type so that we generate a wrapper for it, so that
    // the C runtime can call it.
    Main = &F;
    LLVMContext &C = M.getContext();
    Type *MainArgTys[] = {Type::getInt32Ty(C), PointerType::get(C, 0)};
    FunctionType *MainTy = FunctionType::get(Type::getInt32Ty(C), MainArgTys,
                                             /*isVarArg=*/false);
    if (!shouldFixMainFunction(F.getFunctionType(), MainTy))
      continue;

    LLVM_DEBUG(dbgs() << "Found `main` function with incorrect type: "
                      << *F.getFunctionType() << "\n");
    Value *Args[] = {UndefValue::get(MainArgTys[0]),
                     UndefValue::get(MainArgTys[1])};
    // This call is never inserted into a block; it exists only so the loop
    // below creates a wrapper for main, and is deleted afterwards.
    CallMain = CallInst::Create(MainTy, Main, Args, "call_main");
    Uses.emplace_back(CallMain, &F);
  }

  // Cache of wrappers, keyed by (wrapped function, call-site type). A null
  // entry records that createWrapper decided no wrapper is needed.
  DenseMap<std::pair<Function *, FunctionType *>, Function *> Wrappers;

  for (auto &Use : Uses) {
    CallBase *Call = Use.first;
    Function *Target = Use.second;
    FunctionType *CallTy = Call->getFunctionType();

    const std::pair<Function *, FunctionType *> Key(Target, CallTy);
    auto It = Wrappers.find(Key);
    if (It == Wrappers.end()) {
      Function *NewWrapper = createWrapper(Target, CallTy);
      It = Wrappers.insert(std::make_pair(Key, NewWrapper)).first;
    }

    if (Function *Wrapper = It->second)
      Call->setCalledOperand(Wrapper);
  }

  // If we created a wrapper for main, rename the wrapper so that it's the
  // one that gets called from startup.
  if (CallMain) {
    Main->setName("__original_main");
    auto *MainWrapper =
        cast<Function>(CallMain->getCalledOperand()->stripPointerCasts());
    delete CallMain;
    if (Main->isDeclaration()) {
      // The wrapper is not needed in this case as we don't need to export
      // it to anyone else.
      MainWrapper->eraseFromParent();
    } else {
      // Otherwise give the wrapper the same linkage as the original main
      // function, so that it can be called from the same places.
      MainWrapper->setName("main");
      MainWrapper->setLinkage(Main->getLinkage());
      MainWrapper->setVisibility(Main->getVisibility());
    }
  }

  return true;
}
304
305