Path: blob/main/contrib/llvm-project/llvm/lib/ExecutionEngine/Orc/SimpleRemoteEPC.cpp
35266 views
//===------- SimpleRemoteEPC.cpp -- Simple remote executor control --------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//78#include "llvm/ExecutionEngine/Orc/SimpleRemoteEPC.h"9#include "llvm/ExecutionEngine/Orc/EPCGenericJITLinkMemoryManager.h"10#include "llvm/ExecutionEngine/Orc/EPCGenericMemoryAccess.h"11#include "llvm/ExecutionEngine/Orc/Shared/OrcRTBridge.h"12#include "llvm/Support/FormatVariadic.h"1314#define DEBUG_TYPE "orc"1516namespace llvm {17namespace orc {1819SimpleRemoteEPC::~SimpleRemoteEPC() {20#ifndef NDEBUG21std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);22assert(Disconnected && "Destroyed without disconnection");23#endif // NDEBUG24}2526Expected<tpctypes::DylibHandle>27SimpleRemoteEPC::loadDylib(const char *DylibPath) {28return DylibMgr->open(DylibPath, 0);29}3031/// Async helper to chain together calls to DylibMgr::lookupAsync to fulfill all32/// all the requests.33/// FIXME: The dylib manager should support multiple LookupRequests natively.34static void35lookupSymbolsAsyncHelper(EPCGenericDylibManager &DylibMgr,36ArrayRef<SimpleRemoteEPC::LookupRequest> Request,37std::vector<tpctypes::LookupResult> Result,38SimpleRemoteEPC::SymbolLookupCompleteFn Complete) {39if (Request.empty())40return Complete(std::move(Result));4142auto &Element = Request.front();43DylibMgr.lookupAsync(Element.Handle, Element.Symbols,44[&DylibMgr, Request, Complete = std::move(Complete),45Result = std::move(Result)](auto R) mutable {46if (!R)47return Complete(R.takeError());48Result.push_back({});49Result.back().reserve(R->size());50for (auto Addr : *R)51Result.back().push_back(Addr);5253lookupSymbolsAsyncHelper(54DylibMgr, Request.drop_front(), std::move(Result),55std::move(Complete));56});57}5859void SimpleRemoteEPC::lookupSymbolsAsync(ArrayRef<LookupRequest> Request,60SymbolLookupCompleteFn Complete) {61lookupSymbolsAsyncHelper(*DylibMgr, Request, {}, std::move(Complete));62}6364Expected<int32_t> SimpleRemoteEPC::runAsMain(ExecutorAddr MainFnAddr,65ArrayRef<std::string> Args) {66int64_t Result = 0;67if (auto Err = callSPSWrapper<rt::SPSRunAsMainSignature>(68RunAsMainAddr, Result, MainFnAddr, Args))69return std::move(Err);70return Result;71}7273Expected<int32_t> SimpleRemoteEPC::runAsVoidFunction(ExecutorAddr VoidFnAddr) {74int32_t Result = 0;75if (auto Err = callSPSWrapper<rt::SPSRunAsVoidFunctionSignature>(76RunAsVoidFunctionAddr, Result, VoidFnAddr))77return std::move(Err);78return Result;79}8081Expected<int32_t> SimpleRemoteEPC::runAsIntFunction(ExecutorAddr IntFnAddr,82int Arg) {83int32_t Result = 0;84if (auto Err = callSPSWrapper<rt::SPSRunAsIntFunctionSignature>(85RunAsIntFunctionAddr, Result, IntFnAddr, Arg))86return std::move(Err);87return Result;88}8990void SimpleRemoteEPC::callWrapperAsync(ExecutorAddr WrapperFnAddr,91IncomingWFRHandler OnComplete,92ArrayRef<char> ArgBuffer) {93uint64_t SeqNo;94{95std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);96SeqNo = getNextSeqNo();97assert(!PendingCallWrapperResults.count(SeqNo) && "SeqNo already in use");98PendingCallWrapperResults[SeqNo] = std::move(OnComplete);99}100101if (auto Err = sendMessage(SimpleRemoteEPCOpcode::CallWrapper, SeqNo,102WrapperFnAddr, ArgBuffer)) {103IncomingWFRHandler H;104105// We just registered OnComplete, but there may be a race between this106// thread returning from sendMessage and handleDisconnect being called from107// the transport's listener thread. If handleDisconnect gets there first108// then it will have failed 'H' for us. If we get there first (or if109// handleDisconnect already ran) then we need to take care of it.110{111std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);112auto I = PendingCallWrapperResults.find(SeqNo);113if (I != PendingCallWrapperResults.end()) {114H = std::move(I->second);115PendingCallWrapperResults.erase(I);116}117}118119if (H)120H(shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));121122getExecutionSession().reportError(std::move(Err));123}124}125126Error SimpleRemoteEPC::disconnect() {127T->disconnect();128D->shutdown();129std::unique_lock<std::mutex> Lock(SimpleRemoteEPCMutex);130DisconnectCV.wait(Lock, [this] { return Disconnected; });131return std::move(DisconnectErr);132}133134Expected<SimpleRemoteEPCTransportClient::HandleMessageAction>135SimpleRemoteEPC::handleMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,136ExecutorAddr TagAddr,137SimpleRemoteEPCArgBytesVector ArgBytes) {138139LLVM_DEBUG({140dbgs() << "SimpleRemoteEPC::handleMessage: opc = ";141switch (OpC) {142case SimpleRemoteEPCOpcode::Setup:143dbgs() << "Setup";144assert(SeqNo == 0 && "Non-zero SeqNo for Setup?");145assert(!TagAddr && "Non-zero TagAddr for Setup?");146break;147case SimpleRemoteEPCOpcode::Hangup:148dbgs() << "Hangup";149assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");150assert(!TagAddr && "Non-zero TagAddr for Hangup?");151break;152case SimpleRemoteEPCOpcode::Result:153dbgs() << "Result";154assert(!TagAddr && "Non-zero TagAddr for Result?");155break;156case SimpleRemoteEPCOpcode::CallWrapper:157dbgs() << "CallWrapper";158break;159}160dbgs() << ", seqno = " << SeqNo << ", tag-addr = " << TagAddr161<< ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())162<< " bytes\n";163});164165using UT = std::underlying_type_t<SimpleRemoteEPCOpcode>;166if (static_cast<UT>(OpC) > static_cast<UT>(SimpleRemoteEPCOpcode::LastOpC))167return make_error<StringError>("Unexpected opcode",168inconvertibleErrorCode());169170switch (OpC) {171case SimpleRemoteEPCOpcode::Setup:172if (auto Err = handleSetup(SeqNo, TagAddr, std::move(ArgBytes)))173return std::move(Err);174break;175case SimpleRemoteEPCOpcode::Hangup:176T->disconnect();177if (auto Err = handleHangup(std::move(ArgBytes)))178return std::move(Err);179return EndSession;180case SimpleRemoteEPCOpcode::Result:181if (auto Err = handleResult(SeqNo, TagAddr, std::move(ArgBytes)))182return std::move(Err);183break;184case SimpleRemoteEPCOpcode::CallWrapper:185handleCallWrapper(SeqNo, TagAddr, std::move(ArgBytes));186break;187}188return ContinueSession;189}190191void SimpleRemoteEPC::handleDisconnect(Error Err) {192LLVM_DEBUG({193dbgs() << "SimpleRemoteEPC::handleDisconnect: "194<< (Err ? "failure" : "success") << "\n";195});196197PendingCallWrapperResultsMap TmpPending;198199{200std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);201std::swap(TmpPending, PendingCallWrapperResults);202}203204for (auto &KV : TmpPending)205KV.second(206shared::WrapperFunctionResult::createOutOfBandError("disconnecting"));207208std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);209DisconnectErr = joinErrors(std::move(DisconnectErr), std::move(Err));210Disconnected = true;211DisconnectCV.notify_all();212}213214Expected<std::unique_ptr<jitlink::JITLinkMemoryManager>>215SimpleRemoteEPC::createDefaultMemoryManager(SimpleRemoteEPC &SREPC) {216EPCGenericJITLinkMemoryManager::SymbolAddrs SAs;217if (auto Err = SREPC.getBootstrapSymbols(218{{SAs.Allocator, rt::SimpleExecutorMemoryManagerInstanceName},219{SAs.Reserve, rt::SimpleExecutorMemoryManagerReserveWrapperName},220{SAs.Finalize, rt::SimpleExecutorMemoryManagerFinalizeWrapperName},221{SAs.Deallocate,222rt::SimpleExecutorMemoryManagerDeallocateWrapperName}}))223return std::move(Err);224225return std::make_unique<EPCGenericJITLinkMemoryManager>(SREPC, SAs);226}227228Expected<std::unique_ptr<ExecutorProcessControl::MemoryAccess>>229SimpleRemoteEPC::createDefaultMemoryAccess(SimpleRemoteEPC &SREPC) {230return nullptr;231}232233Error SimpleRemoteEPC::sendMessage(SimpleRemoteEPCOpcode OpC, uint64_t SeqNo,234ExecutorAddr TagAddr,235ArrayRef<char> ArgBytes) {236assert(OpC != SimpleRemoteEPCOpcode::Setup &&237"SimpleRemoteEPC sending Setup message? That's the wrong direction.");238239LLVM_DEBUG({240dbgs() << "SimpleRemoteEPC::sendMessage: opc = ";241switch (OpC) {242case SimpleRemoteEPCOpcode::Hangup:243dbgs() << "Hangup";244assert(SeqNo == 0 && "Non-zero SeqNo for Hangup?");245assert(!TagAddr && "Non-zero TagAddr for Hangup?");246break;247case SimpleRemoteEPCOpcode::Result:248dbgs() << "Result";249assert(!TagAddr && "Non-zero TagAddr for Result?");250break;251case SimpleRemoteEPCOpcode::CallWrapper:252dbgs() << "CallWrapper";253break;254default:255llvm_unreachable("Invalid opcode");256}257dbgs() << ", seqno = " << SeqNo << ", tag-addr = " << TagAddr258<< ", arg-buffer = " << formatv("{0:x}", ArgBytes.size())259<< " bytes\n";260});261auto Err = T->sendMessage(OpC, SeqNo, TagAddr, ArgBytes);262LLVM_DEBUG({263if (Err)264dbgs() << " \\--> SimpleRemoteEPC::sendMessage failed\n";265});266return Err;267}268269Error SimpleRemoteEPC::handleSetup(uint64_t SeqNo, ExecutorAddr TagAddr,270SimpleRemoteEPCArgBytesVector ArgBytes) {271if (SeqNo != 0)272return make_error<StringError>("Setup packet SeqNo not zero",273inconvertibleErrorCode());274275if (TagAddr)276return make_error<StringError>("Setup packet TagAddr not zero",277inconvertibleErrorCode());278279std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);280auto I = PendingCallWrapperResults.find(0);281assert(PendingCallWrapperResults.size() == 1 &&282I != PendingCallWrapperResults.end() &&283"Setup message handler not connectly set up");284auto SetupMsgHandler = std::move(I->second);285PendingCallWrapperResults.erase(I);286287auto WFR =288shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());289SetupMsgHandler(std::move(WFR));290return Error::success();291}292293Error SimpleRemoteEPC::setup(Setup S) {294using namespace SimpleRemoteEPCDefaultBootstrapSymbolNames;295296std::promise<MSVCPExpected<SimpleRemoteEPCExecutorInfo>> EIP;297auto EIF = EIP.get_future();298299// Prepare a handler for the setup packet.300PendingCallWrapperResults[0] =301RunInPlace()(302[&](shared::WrapperFunctionResult SetupMsgBytes) {303if (const char *ErrMsg = SetupMsgBytes.getOutOfBandError()) {304EIP.set_value(305make_error<StringError>(ErrMsg, inconvertibleErrorCode()));306return;307}308using SPSSerialize =309shared::SPSArgList<shared::SPSSimpleRemoteEPCExecutorInfo>;310shared::SPSInputBuffer IB(SetupMsgBytes.data(), SetupMsgBytes.size());311SimpleRemoteEPCExecutorInfo EI;312if (SPSSerialize::deserialize(IB, EI))313EIP.set_value(EI);314else315EIP.set_value(make_error<StringError>(316"Could not deserialize setup message", inconvertibleErrorCode()));317});318319// Start the transport.320if (auto Err = T->start())321return Err;322323// Wait for setup packet to arrive.324auto EI = EIF.get();325if (!EI) {326T->disconnect();327return EI.takeError();328}329330LLVM_DEBUG({331dbgs() << "SimpleRemoteEPC received setup message:\n"332<< " Triple: " << EI->TargetTriple << "\n"333<< " Page size: " << EI->PageSize << "\n"334<< " Bootstrap map" << (EI->BootstrapMap.empty() ? " empty" : ":")335<< "\n";336for (const auto &KV : EI->BootstrapMap)337dbgs() << " " << KV.first() << ": " << KV.second.size()338<< "-byte SPS encoded buffer\n";339dbgs() << " Bootstrap symbols"340<< (EI->BootstrapSymbols.empty() ? " empty" : ":") << "\n";341for (const auto &KV : EI->BootstrapSymbols)342dbgs() << " " << KV.first() << ": " << KV.second << "\n";343});344TargetTriple = Triple(EI->TargetTriple);345PageSize = EI->PageSize;346BootstrapMap = std::move(EI->BootstrapMap);347BootstrapSymbols = std::move(EI->BootstrapSymbols);348349if (auto Err = getBootstrapSymbols(350{{JDI.JITDispatchContext, ExecutorSessionObjectName},351{JDI.JITDispatchFunction, DispatchFnName},352{RunAsMainAddr, rt::RunAsMainWrapperName},353{RunAsVoidFunctionAddr, rt::RunAsVoidFunctionWrapperName},354{RunAsIntFunctionAddr, rt::RunAsIntFunctionWrapperName}}))355return Err;356357if (auto DM =358EPCGenericDylibManager::CreateWithDefaultBootstrapSymbols(*this))359DylibMgr = std::make_unique<EPCGenericDylibManager>(std::move(*DM));360else361return DM.takeError();362363// Set a default CreateMemoryManager if none is specified.364if (!S.CreateMemoryManager)365S.CreateMemoryManager = createDefaultMemoryManager;366367if (auto MemMgr = S.CreateMemoryManager(*this)) {368OwnedMemMgr = std::move(*MemMgr);369this->MemMgr = OwnedMemMgr.get();370} else371return MemMgr.takeError();372373// Set a default CreateMemoryAccess if none is specified.374if (!S.CreateMemoryAccess)375S.CreateMemoryAccess = createDefaultMemoryAccess;376377if (auto MemAccess = S.CreateMemoryAccess(*this)) {378OwnedMemAccess = std::move(*MemAccess);379this->MemAccess = OwnedMemAccess.get();380} else381return MemAccess.takeError();382383return Error::success();384}385386Error SimpleRemoteEPC::handleResult(uint64_t SeqNo, ExecutorAddr TagAddr,387SimpleRemoteEPCArgBytesVector ArgBytes) {388IncomingWFRHandler SendResult;389390if (TagAddr)391return make_error<StringError>("Unexpected TagAddr in result message",392inconvertibleErrorCode());393394{395std::lock_guard<std::mutex> Lock(SimpleRemoteEPCMutex);396auto I = PendingCallWrapperResults.find(SeqNo);397if (I == PendingCallWrapperResults.end())398return make_error<StringError>("No call for sequence number " +399Twine(SeqNo),400inconvertibleErrorCode());401SendResult = std::move(I->second);402PendingCallWrapperResults.erase(I);403releaseSeqNo(SeqNo);404}405406auto WFR =407shared::WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());408SendResult(std::move(WFR));409return Error::success();410}411412void SimpleRemoteEPC::handleCallWrapper(413uint64_t RemoteSeqNo, ExecutorAddr TagAddr,414SimpleRemoteEPCArgBytesVector ArgBytes) {415assert(ES && "No ExecutionSession attached");416D->dispatch(makeGenericNamedTask(417[this, RemoteSeqNo, TagAddr, ArgBytes = std::move(ArgBytes)]() {418ES->runJITDispatchHandler(419[this, RemoteSeqNo](shared::WrapperFunctionResult WFR) {420if (auto Err =421sendMessage(SimpleRemoteEPCOpcode::Result, RemoteSeqNo,422ExecutorAddr(), {WFR.data(), WFR.size()}))423getExecutionSession().reportError(std::move(Err));424},425TagAddr, ArgBytes);426},427"callWrapper task"));428}429430Error SimpleRemoteEPC::handleHangup(SimpleRemoteEPCArgBytesVector ArgBytes) {431using namespace llvm::orc::shared;432auto WFR = WrapperFunctionResult::copyFrom(ArgBytes.data(), ArgBytes.size());433if (const char *ErrMsg = WFR.getOutOfBandError())434return make_error<StringError>(ErrMsg, inconvertibleErrorCode());435436detail::SPSSerializableError Info;437SPSInputBuffer IB(WFR.data(), WFR.size());438if (!SPSArgList<SPSError>::deserialize(IB, Info))439return make_error<StringError>("Could not deserialize hangup info",440inconvertibleErrorCode());441return fromSPSSerializable(std::move(Info));442}443444} // end namespace orc445} // end namespace llvm446447448