Path: blob/main/contrib/llvm-project/compiler-rt/lib/orc/wrapper_function_utils.h
39566 views
//===-- wrapper_function_utils.h - Utilities for wrapper funcs --*- C++ -*-===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This file is a part of the ORC runtime support library.9//10//===----------------------------------------------------------------------===//1112#ifndef ORC_RT_WRAPPER_FUNCTION_UTILS_H13#define ORC_RT_WRAPPER_FUNCTION_UTILS_H1415#include "orc_rt/c_api.h"16#include "common.h"17#include "error.h"18#include "executor_address.h"19#include "simple_packed_serialization.h"20#include <type_traits>2122namespace __orc_rt {2324/// C++ wrapper function result: Same as CWrapperFunctionResult but25/// auto-releases memory.26class WrapperFunctionResult {27public:28/// Create a default WrapperFunctionResult.29WrapperFunctionResult() { orc_rt_CWrapperFunctionResultInit(&R); }3031/// Create a WrapperFunctionResult from a CWrapperFunctionResult. This32/// instance takes ownership of the result object and will automatically33/// call dispose on the result upon destruction.34WrapperFunctionResult(orc_rt_CWrapperFunctionResult R) : R(R) {}3536WrapperFunctionResult(const WrapperFunctionResult &) = delete;37WrapperFunctionResult &operator=(const WrapperFunctionResult &) = delete;3839WrapperFunctionResult(WrapperFunctionResult &&Other) {40orc_rt_CWrapperFunctionResultInit(&R);41std::swap(R, Other.R);42}4344WrapperFunctionResult &operator=(WrapperFunctionResult &&Other) {45orc_rt_CWrapperFunctionResult Tmp;46orc_rt_CWrapperFunctionResultInit(&Tmp);47std::swap(Tmp, Other.R);48std::swap(R, Tmp);49return *this;50}5152~WrapperFunctionResult() { orc_rt_DisposeCWrapperFunctionResult(&R); }5354/// Relinquish ownership of and return the55/// orc_rt_CWrapperFunctionResult.56orc_rt_CWrapperFunctionResult release() {57orc_rt_CWrapperFunctionResult Tmp;58orc_rt_CWrapperFunctionResultInit(&Tmp);59std::swap(R, Tmp);60return Tmp;61}6263/// Get a pointer to the data contained in this instance.64char *data() { return orc_rt_CWrapperFunctionResultData(&R); }6566/// Returns the size of the data contained in this instance.67size_t size() const { return orc_rt_CWrapperFunctionResultSize(&R); }6869/// Returns true if this value is equivalent to a default-constructed70/// WrapperFunctionResult.71bool empty() const { return orc_rt_CWrapperFunctionResultEmpty(&R); }7273/// Create a WrapperFunctionResult with the given size and return a pointer74/// to the underlying memory.75static WrapperFunctionResult allocate(size_t Size) {76WrapperFunctionResult R;77R.R = orc_rt_CWrapperFunctionResultAllocate(Size);78return R;79}8081/// Copy from the given char range.82static WrapperFunctionResult copyFrom(const char *Source, size_t Size) {83return orc_rt_CreateCWrapperFunctionResultFromRange(Source, Size);84}8586/// Copy from the given null-terminated string (includes the null-terminator).87static WrapperFunctionResult copyFrom(const char *Source) {88return orc_rt_CreateCWrapperFunctionResultFromString(Source);89}9091/// Copy from the given std::string (includes the null terminator).92static WrapperFunctionResult copyFrom(const std::string &Source) {93return copyFrom(Source.c_str());94}9596/// Create an out-of-band error by copying the given string.97static WrapperFunctionResult createOutOfBandError(const char *Msg) {98return orc_rt_CreateCWrapperFunctionResultFromOutOfBandError(Msg);99}100101/// Create an out-of-band error by copying the given string.102static WrapperFunctionResult createOutOfBandError(const std::string &Msg) {103return createOutOfBandError(Msg.c_str());104}105106template <typename SPSArgListT, typename... ArgTs>107static WrapperFunctionResult fromSPSArgs(const ArgTs &...Args) {108auto Result = allocate(SPSArgListT::size(Args...));109SPSOutputBuffer OB(Result.data(), Result.size());110if (!SPSArgListT::serialize(OB, Args...))111return createOutOfBandError(112"Error serializing arguments to blob in call");113return Result;114}115116/// If this value is an out-of-band error then this returns the error message,117/// otherwise returns nullptr.118const char *getOutOfBandError() const {119return orc_rt_CWrapperFunctionResultGetOutOfBandError(&R);120}121122private:123orc_rt_CWrapperFunctionResult R;124};125126namespace detail {127128template <typename RetT> class WrapperFunctionHandlerCaller {129public:130template <typename HandlerT, typename ArgTupleT, std::size_t... I>131static decltype(auto) call(HandlerT &&H, ArgTupleT &Args,132std::index_sequence<I...>) {133return std::forward<HandlerT>(H)(std::get<I>(Args)...);134}135};136137template <> class WrapperFunctionHandlerCaller<void> {138public:139template <typename HandlerT, typename ArgTupleT, std::size_t... I>140static SPSEmpty call(HandlerT &&H, ArgTupleT &Args,141std::index_sequence<I...>) {142std::forward<HandlerT>(H)(std::get<I>(Args)...);143return SPSEmpty();144}145};146147template <typename WrapperFunctionImplT,148template <typename> class ResultSerializer, typename... SPSTagTs>149class WrapperFunctionHandlerHelper150: public WrapperFunctionHandlerHelper<151decltype(&std::remove_reference_t<WrapperFunctionImplT>::operator()),152ResultSerializer, SPSTagTs...> {};153154template <typename RetT, typename... ArgTs,155template <typename> class ResultSerializer, typename... SPSTagTs>156class WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,157SPSTagTs...> {158public:159using ArgTuple = std::tuple<std::decay_t<ArgTs>...>;160using ArgIndices = std::make_index_sequence<std::tuple_size<ArgTuple>::value>;161162template <typename HandlerT>163static WrapperFunctionResult apply(HandlerT &&H, const char *ArgData,164size_t ArgSize) {165ArgTuple Args;166if (!deserialize(ArgData, ArgSize, Args, ArgIndices{}))167return WrapperFunctionResult::createOutOfBandError(168"Could not deserialize arguments for wrapper function call");169170auto HandlerResult = WrapperFunctionHandlerCaller<RetT>::call(171std::forward<HandlerT>(H), Args, ArgIndices{});172173return ResultSerializer<decltype(HandlerResult)>::serialize(174std::move(HandlerResult));175}176177private:178template <std::size_t... I>179static bool deserialize(const char *ArgData, size_t ArgSize, ArgTuple &Args,180std::index_sequence<I...>) {181SPSInputBuffer IB(ArgData, ArgSize);182return SPSArgList<SPSTagTs...>::deserialize(IB, std::get<I>(Args)...);183}184};185186// Map function pointers to function types.187template <typename RetT, typename... ArgTs,188template <typename> class ResultSerializer, typename... SPSTagTs>189class WrapperFunctionHandlerHelper<RetT (*)(ArgTs...), ResultSerializer,190SPSTagTs...>191: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,192SPSTagTs...> {};193194// Map non-const member function types to function types.195template <typename ClassT, typename RetT, typename... ArgTs,196template <typename> class ResultSerializer, typename... SPSTagTs>197class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...), ResultSerializer,198SPSTagTs...>199: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,200SPSTagTs...> {};201202// Map const member function types to function types.203template <typename ClassT, typename RetT, typename... ArgTs,204template <typename> class ResultSerializer, typename... SPSTagTs>205class WrapperFunctionHandlerHelper<RetT (ClassT::*)(ArgTs...) const,206ResultSerializer, SPSTagTs...>207: public WrapperFunctionHandlerHelper<RetT(ArgTs...), ResultSerializer,208SPSTagTs...> {};209210template <typename SPSRetTagT, typename RetT> class ResultSerializer {211public:212static WrapperFunctionResult serialize(RetT Result) {213return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(Result);214}215};216217template <typename SPSRetTagT> class ResultSerializer<SPSRetTagT, Error> {218public:219static WrapperFunctionResult serialize(Error Err) {220return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(221toSPSSerializable(std::move(Err)));222}223};224225template <typename SPSRetTagT, typename T>226class ResultSerializer<SPSRetTagT, Expected<T>> {227public:228static WrapperFunctionResult serialize(Expected<T> E) {229return WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSRetTagT>>(230toSPSSerializable(std::move(E)));231}232};233234template <typename SPSRetTagT, typename RetT> class ResultDeserializer {235public:236static void makeSafe(RetT &Result) {}237238static Error deserialize(RetT &Result, const char *ArgData, size_t ArgSize) {239SPSInputBuffer IB(ArgData, ArgSize);240if (!SPSArgList<SPSRetTagT>::deserialize(IB, Result))241return make_error<StringError>(242"Error deserializing return value from blob in call");243return Error::success();244}245};246247template <> class ResultDeserializer<SPSError, Error> {248public:249static void makeSafe(Error &Err) { cantFail(std::move(Err)); }250251static Error deserialize(Error &Err, const char *ArgData, size_t ArgSize) {252SPSInputBuffer IB(ArgData, ArgSize);253SPSSerializableError BSE;254if (!SPSArgList<SPSError>::deserialize(IB, BSE))255return make_error<StringError>(256"Error deserializing return value from blob in call");257Err = fromSPSSerializable(std::move(BSE));258return Error::success();259}260};261262template <typename SPSTagT, typename T>263class ResultDeserializer<SPSExpected<SPSTagT>, Expected<T>> {264public:265static void makeSafe(Expected<T> &E) { cantFail(E.takeError()); }266267static Error deserialize(Expected<T> &E, const char *ArgData,268size_t ArgSize) {269SPSInputBuffer IB(ArgData, ArgSize);270SPSSerializableExpected<T> BSE;271if (!SPSArgList<SPSExpected<SPSTagT>>::deserialize(IB, BSE))272return make_error<StringError>(273"Error deserializing return value from blob in call");274E = fromSPSSerializable(std::move(BSE));275return Error::success();276}277};278279} // end namespace detail280281template <typename SPSSignature> class WrapperFunction;282283template <typename SPSRetTagT, typename... SPSTagTs>284class WrapperFunction<SPSRetTagT(SPSTagTs...)> {285private:286template <typename RetT>287using ResultSerializer = detail::ResultSerializer<SPSRetTagT, RetT>;288289public:290template <typename RetT, typename... ArgTs>291static Error call(const void *FnTag, RetT &Result, const ArgTs &...Args) {292293// RetT might be an Error or Expected value. Set the checked flag now:294// we don't want the user to have to check the unused result if this295// operation fails.296detail::ResultDeserializer<SPSRetTagT, RetT>::makeSafe(Result);297298// Since the functions cannot be zero/unresolved on Windows, the following299// reference taking would always be non-zero, thus generating a compiler300// warning otherwise.301#if !defined(_WIN32)302if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch_ctx))303return make_error<StringError>("__orc_rt_jit_dispatch_ctx not set");304if (ORC_RT_UNLIKELY(!&__orc_rt_jit_dispatch))305return make_error<StringError>("__orc_rt_jit_dispatch not set");306#endif307auto ArgBuffer =308WrapperFunctionResult::fromSPSArgs<SPSArgList<SPSTagTs...>>(Args...);309if (const char *ErrMsg = ArgBuffer.getOutOfBandError())310return make_error<StringError>(ErrMsg);311312WrapperFunctionResult ResultBuffer = __orc_rt_jit_dispatch(313&__orc_rt_jit_dispatch_ctx, FnTag, ArgBuffer.data(), ArgBuffer.size());314if (auto ErrMsg = ResultBuffer.getOutOfBandError())315return make_error<StringError>(ErrMsg);316317return detail::ResultDeserializer<SPSRetTagT, RetT>::deserialize(318Result, ResultBuffer.data(), ResultBuffer.size());319}320321template <typename HandlerT>322static WrapperFunctionResult handle(const char *ArgData, size_t ArgSize,323HandlerT &&Handler) {324using WFHH =325detail::WrapperFunctionHandlerHelper<std::remove_reference_t<HandlerT>,326ResultSerializer, SPSTagTs...>;327return WFHH::apply(std::forward<HandlerT>(Handler), ArgData, ArgSize);328}329330private:331template <typename T> static const T &makeSerializable(const T &Value) {332return Value;333}334335static detail::SPSSerializableError makeSerializable(Error Err) {336return detail::toSPSSerializable(std::move(Err));337}338339template <typename T>340static detail::SPSSerializableExpected<T> makeSerializable(Expected<T> E) {341return detail::toSPSSerializable(std::move(E));342}343};344345template <typename... SPSTagTs>346class WrapperFunction<void(SPSTagTs...)>347: private WrapperFunction<SPSEmpty(SPSTagTs...)> {348public:349template <typename... ArgTs>350static Error call(const void *FnTag, const ArgTs &...Args) {351SPSEmpty BE;352return WrapperFunction<SPSEmpty(SPSTagTs...)>::call(FnTag, BE, Args...);353}354355using WrapperFunction<SPSEmpty(SPSTagTs...)>::handle;356};357358/// A function object that takes an ExecutorAddr as its first argument,359/// casts that address to a ClassT*, then calls the given method on that360/// pointer passing in the remaining function arguments. This utility361/// removes some of the boilerplate from writing wrappers for method calls.362///363/// @code{.cpp}364/// class MyClass {365/// public:366/// void myMethod(uint32_t, bool) { ... }367/// };368///369/// // SPS Method signature -- note MyClass object address as first argument.370/// using SPSMyMethodWrapperSignature =371/// SPSTuple<SPSExecutorAddr, uint32_t, bool>;372///373/// WrapperFunctionResult374/// myMethodCallWrapper(const char *ArgData, size_t ArgSize) {375/// return WrapperFunction<SPSMyMethodWrapperSignature>::handle(376/// ArgData, ArgSize, makeMethodWrapperHandler(&MyClass::myMethod));377/// }378/// @endcode379///380template <typename RetT, typename ClassT, typename... ArgTs>381class MethodWrapperHandler {382public:383using MethodT = RetT (ClassT::*)(ArgTs...);384MethodWrapperHandler(MethodT M) : M(M) {}385RetT operator()(ExecutorAddr ObjAddr, ArgTs &...Args) {386return (ObjAddr.toPtr<ClassT *>()->*M)(std::forward<ArgTs>(Args)...);387}388389private:390MethodT M;391};392393/// Create a MethodWrapperHandler object from the given method pointer.394template <typename RetT, typename ClassT, typename... ArgTs>395MethodWrapperHandler<RetT, ClassT, ArgTs...>396makeMethodWrapperHandler(RetT (ClassT::*Method)(ArgTs...)) {397return MethodWrapperHandler<RetT, ClassT, ArgTs...>(Method);398}399400/// Represents a call to a wrapper function.401class WrapperFunctionCall {402public:403// FIXME: Switch to a SmallVector<char, 24> once ORC runtime has a404// smallvector.405using ArgDataBufferType = std::vector<char>;406407/// Create a WrapperFunctionCall using the given SPS serializer to serialize408/// the arguments.409template <typename SPSSerializer, typename... ArgTs>410static Expected<WrapperFunctionCall> Create(ExecutorAddr FnAddr,411const ArgTs &...Args) {412ArgDataBufferType ArgData;413ArgData.resize(SPSSerializer::size(Args...));414SPSOutputBuffer OB(ArgData.empty() ? nullptr : ArgData.data(),415ArgData.size());416if (SPSSerializer::serialize(OB, Args...))417return WrapperFunctionCall(FnAddr, std::move(ArgData));418return make_error<StringError>("Cannot serialize arguments for "419"AllocActionCall");420}421422WrapperFunctionCall() = default;423424/// Create a WrapperFunctionCall from a target function and arg buffer.425WrapperFunctionCall(ExecutorAddr FnAddr, ArgDataBufferType ArgData)426: FnAddr(FnAddr), ArgData(std::move(ArgData)) {}427428/// Returns the address to be called.429const ExecutorAddr &getCallee() const { return FnAddr; }430431/// Returns the argument data.432const ArgDataBufferType &getArgData() const { return ArgData; }433434/// WrapperFunctionCalls convert to true if the callee is non-null.435explicit operator bool() const { return !!FnAddr; }436437/// Run call returning raw WrapperFunctionResult.438WrapperFunctionResult run() const {439using FnTy =440orc_rt_CWrapperFunctionResult(const char *ArgData, size_t ArgSize);441return WrapperFunctionResult(442FnAddr.toPtr<FnTy *>()(ArgData.data(), ArgData.size()));443}444445/// Run call and deserialize result using SPS.446template <typename SPSRetT, typename RetT>447std::enable_if_t<!std::is_same<SPSRetT, void>::value, Error>448runWithSPSRet(RetT &RetVal) const {449auto WFR = run();450if (const char *ErrMsg = WFR.getOutOfBandError())451return make_error<StringError>(ErrMsg);452SPSInputBuffer IB(WFR.data(), WFR.size());453if (!SPSSerializationTraits<SPSRetT, RetT>::deserialize(IB, RetVal))454return make_error<StringError>("Could not deserialize result from "455"serialized wrapper function call");456return Error::success();457}458459/// Overload for SPS functions returning void.460template <typename SPSRetT>461std::enable_if_t<std::is_same<SPSRetT, void>::value, Error>462runWithSPSRet() const {463SPSEmpty E;464return runWithSPSRet<SPSEmpty>(E);465}466467/// Run call and deserialize an SPSError result. SPSError returns and468/// deserialization failures are merged into the returned error.469Error runWithSPSRetErrorMerged() const {470detail::SPSSerializableError RetErr;471if (auto Err = runWithSPSRet<SPSError>(RetErr))472return Err;473return detail::fromSPSSerializable(std::move(RetErr));474}475476private:477ExecutorAddr FnAddr;478std::vector<char> ArgData;479};480481using SPSWrapperFunctionCall = SPSTuple<SPSExecutorAddr, SPSSequence<char>>;482483template <>484class SPSSerializationTraits<SPSWrapperFunctionCall, WrapperFunctionCall> {485public:486static size_t size(const WrapperFunctionCall &WFC) {487return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::size(488WFC.getCallee(), WFC.getArgData());489}490491static bool serialize(SPSOutputBuffer &OB, const WrapperFunctionCall &WFC) {492return SPSArgList<SPSExecutorAddr, SPSSequence<char>>::serialize(493OB, WFC.getCallee(), WFC.getArgData());494}495496static bool deserialize(SPSInputBuffer &IB, WrapperFunctionCall &WFC) {497ExecutorAddr FnAddr;498WrapperFunctionCall::ArgDataBufferType ArgData;499if (!SPSWrapperFunctionCall::AsArgList::deserialize(IB, FnAddr, ArgData))500return false;501WFC = WrapperFunctionCall(FnAddr, std::move(ArgData));502return true;503}504};505506} // end namespace __orc_rt507508#endif // ORC_RT_WRAPPER_FUNCTION_UTILS_H509510511