Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/compiler-rt/lib/orc/wrapper_function_utils.h
39566 views
1
//===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file is a part of the ORC runtime support library.
10
//
11
//===----------------------------------------------------------------------===//
12
13
#ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H
14
#define ORC_RT_WRAPPER_FUNCTION_UTILS_H
15
16
#include "orc_rt/c_api.h"
17
#include "common.h"
18
#include "error.h"
19
#include "executor_address.h"
20
#include "simple_packed_serialization.h"
21
#include <type_traits>
22
23
namespace __orc_rt {
24
25
/// C++ wrapper function result: Same as CWrapperFunctionResult but
26
/// auto-releases memory.
27
class WrapperFunctionResult {
28
public:
29
/// Create a default WrapperFunctionResult.
30
WrapperFunctionResult() { orc_rt_CWrapperFunctionResultInit(&R); }
31
32
/// Create a WrapperFunctionResult from a CWrapperFunctionResult. This
33
/// instance takes ownership of the result object and will automatically
34
/// call dispose on the result upon destruction.
35
WrapperFunctionResult(orc_rt_CWrapperFunctionResult R) : R(R) {}
36
37
WrapperFunctionResult(const WrapperFunctionResult &) = delete;
38
WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;
39
40
WrapperFunctionResult(WrapperFunctionResult &&Other) {
41
orc_rt_CWrapperFunctionResultInit(&R);
42
std::swap(R, Other.R);
43
}
44
45
WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) {
46
orc_rt_CWrapperFunctionResult Tmp;
47
orc_rt_CWrapperFunctionResultInit(&Tmp);
48
std::swap(Tmp, Other.R);
49
std::swap(R, Tmp);
50
return *this;
51
}
52
53
~WrapperFunctionResult() { orc_rt_DisposeCWrapperFunctionResult(&R); }
54
55
/// Relinquish ownership of and return the
56
/// orc_rt_CWrapperFunctionResult.
57
orc_rt_CWrapperFunctionResult release() {
58
orc_rt_CWrapperFunctionResult Tmp;
59
orc_rt_CWrapperFunctionResultInit(&Tmp);
60
std::swap(R, Tmp);
61
return Tmp;
62
}
63
64
/// Get a pointer to the data contained in this instance.
65
char *data() { return orc_rt_CWrapperFunctionResultData(&R); }
66
67
/// Returns the size of the data contained in this instance.
68
size_t size() const { return orc_rt_CWrapperFunctionResultSize(&R); }
69
70
/// Returns true if this value is equivalent to a default-constructed
71
/// WrapperFunctionResult.
72
bool empty() const { return orc_rt_CWrapperFunctionResultEmpty(&R); }
73
74
/// Create a WrapperFunctionResult with the given size and return a pointer
75
/// to the underlying memory.
76
static WrapperFunctionResult allocate(size_t Size) {
77
WrapperFunctionResult R;
78
R.R = orc_rt_CWrapperFunctionResultAllocate(Size);
79
return R;
80
}
81
82
/// Copy from the given char range.
83
static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {
84
return orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size);
85
}
86
87
/// Copy from the given null-terminated string (includes the null-terminator).
88
static WrapperFunctionResult copyFrom(const char *Source) {
89
return orc_rt_CreateCWrapperFunctionResultFromString(Source);
90
}
91
92
/// Copy from the given std::string (includes the null terminator).
93
static WrapperFunctionResult copyFrom(const std::string &Source) {
94
return copyFrom(Source.c_str());
95
}
96
97
/// Create an out-of-band error by copying the given string.
98
static WrapperFunctionResult createOutOfBandError(const char *Msg) {
99
return orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg);
100
}
101
102
/// Create an out-of-band error by copying the given string.
103
static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {
104
return createOutOfBandError(Msg.c_str());
105
}
106
107
template <typename SPSArgListT, typename... ArgTs>
108
static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) {
109
auto Result = allocate(SPSArgListT::size(Args...));
110
SPSOutputBuffer OB(Result.data(), Result.size());
111
if (!SPSArgListT::serialize(OB, Args...))
112
return createOutOfBandError(
113
"Error serializing arguments to blob in call");
114
return Result;
115
}
116
117
/// If this value is an out-of-band error then this returns the error message,
118
/// otherwise returns nullptr.
119
const char *getOutOfBandError() const {
120
return orc_rt_CWrapperFunctionResultGetOutOfBandError(&R);
121
}
122
123
private:
124
orc_rt_CWrapperFunctionResult R;
125
};
126
127
namespace detail {
128
129
/// Invokes a wrapper function handler on the elements of an argument tuple
/// and hands back whatever the handler returns (see the void specialization
/// for handlers that return nothing).
template <typename RetT> class WrapperFunctionHandlerCaller {
public:
  /// Call Handler with std::get<Idx>(ArgValues)... and return its result
  /// unchanged (decltype(auto) preserves reference-ness).
  template <typename HandlerT, typename ArgTupleT, std::size_t... Idx>
  static decltype(auto) call(HandlerT &&Handler, ArgTupleT &ArgValues,
                             std::index_sequence<Idx...>) {
    return std::forward<HandlerT>(Handler)(std::get<Idx>(ArgValues)...);
  }
};
137
138
template <> class WrapperFunctionHandlerCaller<void> {
139
public:
140
template <typename HandlerT, typename ArgTupleT, std::size_t... I>
141
static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,
142
std::index_sequence<I...>) {
143
std::forward<HandlerT>(H)(std::get<I>(Args)...);
144
return SPSEmpty();
145
}
146
};
147
148
/// Primary template, used for callable objects: re-dispatches on the type of
/// the object's operator(). The specializations that follow reduce
/// (member) function pointer types to a plain function type, whose
/// specialization does the actual argument unpacking and result packing.
template <typename WrapperFunctionImplT,
          template <typename> class ResultSerializer, typename... SPSTagTs>
class WrapperFunctionHandlerHelper
    : public WrapperFunctionHandlerHelper<
          decltype(&std::remove_reference<WrapperFunctionImplT>::type::
                       operator()),
          ResultSerializer, SPSTagTs...> {};
154
155
template <typename RetT, typename... ArgTs,
156
template <typename> class ResultSerializer, typename... SPSTagTs>
157
class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
158
SPSTagTs...> {
159
public:
160
using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;
161
using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;
162
163
template <typename HandlerT>
164
static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,
165
size_t ArgSize) {
166
ArgTuple Args;
167
if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))
168
return WrapperFunctionResult::createOutOfBandError(
169
"Could not deserialize arguments for wrapper function call");
170
171
auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(
172
std::forward<HandlerT>(H), Args, ArgIndices{});
173
174
return ResultSerializer<decltype(HandlerResult)>::serialize(
175
std::move(HandlerResult));
176
}
177
178
private:
179
template <std::size_t... I>
180
static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,
181
std::index_sequence<I...>) {
182
SPSInputBuffer IB(ArgData, ArgSize);
183
return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);
184
}
185
};
186
187
// Map function pointers to function types.
188
template <typename RetT, typename... ArgTs,
189
template <typename> class ResultSerializer, typename... SPSTagTs>
190
class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,
191
SPSTagTs...>
192
: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
193
SPSTagTs...> {};
194
195
// Map non-const member function types to function types.
196
template <typename ClassT, typename RetT, typename... ArgTs,
197
template <typename> class ResultSerializer, typename... SPSTagTs>
198
class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,
199
SPSTagTs...>
200
: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
201
SPSTagTs...> {};
202
203
// Map const member function types to function types.
204
template <typename ClassT, typename RetT, typename... ArgTs,
205
template <typename> class ResultSerializer, typename... SPSTagTs>
206
class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,
207
ResultSerializer, SPSTagTs...>
208
: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,
209
SPSTagTs...> {};
210
211
template <typename SPSRetTagT, typename RetT> class ResultSerializer {
212
public:
213
static WrapperFunctionResult serialize(RetT Result) {
214
return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result);
215
}
216
};
217
218
template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {
219
public:
220
static WrapperFunctionResult serialize(Error Err) {
221
return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
222
toSPSSerializable(std::move(Err)));
223
}
224
};
225
226
template <typename SPSRetTagT, typename T>
227
class ResultSerializer<SPSRetTagT, Expected<T>> {
228
public:
229
static WrapperFunctionResult serialize(Expected<T> E) {
230
return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(
231
toSPSSerializable(std::move(E)));
232
}
233
};
234
235
template <typename SPSRetTagT, typename RetT> class ResultDeserializer {
236
public:
237
static void makeSafe(RetT &Result) {}
238
239
static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {
240
SPSInputBuffer IB(ArgData, ArgSize);
241
if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))
242
return make_error<StringError>(
243
"Error deserializing return value from blob in call");
244
return Error::success();
245
}
246
};
247
248
template <> class ResultDeserializer<SPSError, Error> {
249
public:
250
static void makeSafe(Error &Err) { cantFail(std::move(Err)); }
251
252
static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {
253
SPSInputBuffer IB(ArgData, ArgSize);
254
SPSSerializableError BSE;
255
if (!SPSArgList<SPSError>::deserialize(IB, BSE))
256
return make_error<StringError>(
257
"Error deserializing return value from blob in call");
258
Err = fromSPSSerializable(std::move(BSE));
259
return Error::success();
260
}
261
};
262
263
template <typename SPSTagT, typename T>
264
class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {
265
public:
266
static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }
267
268
static Error deserialize(Expected<T> &E, const char *ArgData,
269
size_t ArgSize) {
270
SPSInputBuffer IB(ArgData, ArgSize);
271
SPSSerializableExpected<T> BSE;
272
if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))
273
return make_error<StringError>(
274
"Error deserializing return value from blob in call");
275
E = fromSPSSerializable(std::move(BSE));
276
return Error::success();
277
}
278
};
279
280
} // end namespace detail
281
282
template <typename SPSSignature> class WrapperFunction;
283
284
template <typename SPSRetTagT, typename... SPSTagTs>
285
class WrapperFunction<SPSRetTagT(SPSTagTs...)> {
286
private:
287
template <typename RetT>
288
using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;
289
290
public:
291
template <typename RetT, typename... ArgTs>
292
static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) {
293
294
// RetT might be an Error or Expected value. Set the checked flag now:
295
// we don't want the user to have to check the unused result if this
296
// operation fails.
297
detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);
298
299
// Since the functions cannot be zero/unresolved on Windows, the following
300
// reference taking would always be non-zero, thus generating a compiler
301
// warning otherwise.
302
#if !defined(_WIN32)
303
if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx))
304
return make_error<StringError>("__orc_rt_jit_dispatch_ctx not set");
305
if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch))
306
return make_error<StringError>("__orc_rt_jit_dispatch not set");
307
#endif
308
auto ArgBuffer =
309
WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...);
310
if (const char *ErrMsg = ArgBuffer.getOutOfBandError())
311
return make_error<StringError>(ErrMsg);
312
313
WrapperFunctionResult ResultBuffer = __orc_rt_jit_dispatch(
314
&__orc_rt_jit_dispatch_ctx, FnTag, ArgBuffer.data(), ArgBuffer.size());
315
if (auto ErrMsg = ResultBuffer.getOutOfBandError())
316
return make_error<StringError>(ErrMsg);
317
318
return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(
319
Result, ResultBuffer.data(), ResultBuffer.size());
320
}
321
322
template <typename HandlerT>
323
static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,
324
HandlerT &&Handler) {
325
using WFHH =
326
detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,
327
ResultSerializer, SPSTagTs...>;
328
return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);
329
}
330
331
private:
332
template <typename T> static const T &makeSerializable(const T &Value) {
333
return Value;
334
}
335
336
static detail::SPSSerializableError makeSerializable(Error Err) {
337
return detail::toSPSSerializable(std::move(Err));
338
}
339
340
template <typename T>
341
static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {
342
return detail::toSPSSerializable(std::move(E));
343
}
344
};
345
346
template <typename... SPSTagTs>
347
class WrapperFunction<void(SPSTagTs...)>
348
: private WrapperFunction<SPSEmpty(SPSTagTs...)> {
349
public:
350
template <typename... ArgTs>
351
static Error call(const void *FnTag, const ArgTs &...Args) {
352
SPSEmpty BE;
353
return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...);
354
}
355
356
using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;
357
};
358
359
/// A function object that takes an ExecutorAddr as its first argument,
360
/// casts that address to a ClassT*, then calls the given method on that
361
/// pointer passing in the remaining function arguments. This utility
362
/// removes some of the boilerplate from writing wrappers for method calls.
363
///
364
/// @code{.cpp}
365
/// class MyClass {
366
/// public:
367
/// void myMethod(uint32_t, bool) { ... }
368
/// };
369
///
370
/// // SPS Method signature -- note MyClass object address as first argument.
371
/// using SPSMyMethodWrapperSignature =
372
/// SPSTuple<SPSExecutorAddr, uint32_t, bool>;
373
///
374
/// WrapperFunctionResult
375
/// myMethodCallWrapper(const char *ArgData, size_t ArgSize) {
376
/// return WrapperFunction<SPSMyMethodWrapperSignature>::handle(
377
/// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod));
378
/// }
379
/// @endcode
380
///
381
template <typename RetT, typename ClassT, typename... ArgTs>
382
class MethodWrapperHandler {
383
public:
384
using MethodT = RetT (ClassT::*)(ArgTs...);
385
MethodWrapperHandler(MethodT M) : M(M) {}
386
RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) {
387
return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...);
388
}
389
390
private:
391
MethodT M;
392
};
393
394
/// Create a MethodWrapperHandler object from the given method pointer.
395
template <typename RetT, typename ClassT, typename... ArgTs>
396
MethodWrapperHandler<RetT, ClassT, ArgTs...>
397
makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {
398
return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);
399
}
400
401
/// Represents a call to a wrapper function.
402
class WrapperFunctionCall {
403
public:
404
// FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a
405
// smallvector.
406
using ArgDataBufferType = std::vector<char>;
407
408
/// Create a WrapperFunctionCall using the given SPS serializer to serialize
409
/// the arguments.
410
template <typename SPSSerializer, typename... ArgTs>
411
static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr,
412
const ArgTs &...Args) {
413
ArgDataBufferType ArgData;
414
ArgData.resize(SPSSerializer::size(Args...));
415
SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(),
416
ArgData.size());
417
if (SPSSerializer::serialize(OB, Args...))
418
return WrapperFunctionCall(FnAddr, std::move(ArgData));
419
return make_error<StringError>("Cannot serialize arguments for "
420
"AllocActionCall");
421
}
422
423
WrapperFunctionCall() = default;
424
425
/// Create a WrapperFunctionCall from a target function and arg buffer.
426
WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData)
427
: FnAddr(FnAddr), ArgData(std::move(ArgData)) {}
428
429
/// Returns the address to be called.
430
const ExecutorAddr &getCallee() const { return FnAddr; }
431
432
/// Returns the argument data.
433
const ArgDataBufferType &getArgData() const { return ArgData; }
434
435
/// WrapperFunctionCalls convert to true if the callee is non-null.
436
explicit operator bool() const { return !!FnAddr; }
437
438
/// Run call returning raw WrapperFunctionResult.
439
WrapperFunctionResult run() const {
440
using FnTy =
441
orc_rt_CWrapperFunctionResult(const char *ArgData, size_t ArgSize);
442
return WrapperFunctionResult(
443
FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size()));
444
}
445
446
/// Run call and deserialize result using SPS.
447
template <typename SPSRetT, typename RetT>
448
std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error>
449
runWithSPSRet(RetT &RetVal) const {
450
auto WFR = run();
451
if (const char *ErrMsg = WFR.getOutOfBandError())
452
return make_error<StringError>(ErrMsg);
453
SPSInputBuffer IB(WFR.data(), WFR.size());
454
if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))
455
return make_error<StringError>("Could not deserialize result from "
456
"serialized wrapper function call");
457
return Error::success();
458
}
459
460
/// Overload for SPS functions returning void.
461
template <typename SPSRetT>
462
std::enable_if_t<std::is_same<SPSRetT, void>::value, Error>
463
runWithSPSRet() const {
464
SPSEmpty E;
465
return runWithSPSRet<SPSEmpty>(E);
466
}
467
468
/// Run call and deserialize an SPSError result. SPSError returns and
469
/// deserialization failures are merged into the returned error.
470
Error runWithSPSRetErrorMerged() const {
471
detail::SPSSerializableError RetErr;
472
if (auto Err = runWithSPSRet<SPSError>(RetErr))
473
return Err;
474
return detail::fromSPSSerializable(std::move(RetErr));
475
}
476
477
private:
478
ExecutorAddr FnAddr;
479
std::vector<char> ArgData;
480
};
481
482
using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>;
483
484
template <>
485
class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {
486
public:
487
static size_t size(const WrapperFunctionCall &WFC) {
488
return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size(
489
WFC.getCallee(), WFC.getArgData());
490
}
491
492
static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {
493
return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize(
494
OB, WFC.getCallee(), WFC.getArgData());
495
}
496
497
static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {
498
ExecutorAddr FnAddr;
499
WrapperFunctionCall::ArgDataBufferType ArgData;
500
if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData))
501
return false;
502
WFC = WrapperFunctionCall(FnAddr, std::move(ArgData));
503
return true;
504
}
505
};
506
507
} // end namespace __orc_rt
508
509
#endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H
510
511