Path: blob/main/contrib/llvm-project/llvm/lib/ExecutionEngine/Orc/ReOptimizeLayer.cpp
213799 views
#include "llvm/ExecutionEngine/Orc/ReOptimizeLayer.h"
#include "llvm/ExecutionEngine/Orc/Mangling.h"

using namespace llvm;
using namespace orc;

/// Try to claim the "reoptimization in progress" flag for this unit.
///
/// \returns false if a reoptimization is already running; otherwise marks one
/// as started and returns true.
bool ReOptimizeLayer::ReOptMaterializationUnitState::tryStartReoptimize() {
  std::unique_lock<std::mutex> Lock(Mutex);
  if (Reoptimizing)
    return false;

  Reoptimizing = true;
  return true;
}

/// Clear the in-progress flag and advance the unit to its next version.
void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeSucceeded() {
  std::unique_lock<std::mutex> Lock(Mutex);
  assert(Reoptimizing && "Tried to mark unstarted reoptimization as done");
  Reoptimizing = false;
  CurVersion++;
}

/// Clear the in-progress flag without changing the current version.
void ReOptimizeLayer::ReOptMaterializationUnitState::reoptimizeFailed() {
  std::unique_lock<std::mutex> Lock(Mutex);
  assert(Reoptimizing && "Tried to mark unstarted reoptimization as done");
  Reoptimizing = false;
}

/// Register rt_reoptimize as the JIT dispatch handler associated with the
/// mangled "__orc_rt_reoptimize_tag" symbol in \p PlatformJD.
///
/// The (misspelled) name is this function's public name and is kept as-is.
Error ReOptimizeLayer::reigsterRuntimeFunctions(JITDylib &PlatformJD) {
  ExecutionSession::JITDispatchHandlerAssociationMap WFs;
  using ReoptimizeSPSSig = shared::SPSError(uint64_t, uint32_t);
  WFs[Mangle("__orc_rt_reoptimize_tag")] =
      ES.wrapAsyncWithSPS<ReoptimizeSPSSig>(this,
                                            &ReOptimizeLayer::rt_reoptimize);
  return ES.registerJITDispatchHandlers(PlatformJD, std::move(WFs));
}

/// Emit \p TSM with reoptimization support.
///
/// If any symbol in \p R is not callable the module is handed to the base
/// layer unchanged. Otherwise a materialization-unit state is created and
/// tied to R's resource key, ProfilerFunc is run over the module, the
/// implementation symbols for the current version are emitted, and the
/// original symbols are emitted as redirectable symbols pointing at them.
/// Any failure is reported to the ExecutionSession and fails materialization.
void ReOptimizeLayer::emit(std::unique_ptr<MaterializationResponsibility> R,
                           ThreadSafeModule TSM) {
  auto &JD = R->getTargetJITDylib();

  bool HasNonCallable = false;
  for (auto &KV : R->getSymbols()) {
    if (!KV.second.isCallable()) {
      HasNonCallable = true;
      break; // One non-callable symbol is enough to decide.
    }
  }

  if (HasNonCallable) {
    BaseLayer.emit(std::move(R), std::move(TSM));
    return;
  }

  auto &MUState = createMaterializationUnitState(TSM);

  if (auto Err = R->withResourceKeyDo([&](ResourceKey Key) {
        registerMaterializationUnitResource(Key, MUState);
      })) {
    ES.reportError(std::move(Err));
    R->failMaterialization();
    return;
  }

  if (auto Err =
          ProfilerFunc(*this, MUState.getID(), MUState.getCurVersion(), TSM)) {
    ES.reportError(std::move(Err));
    R->failMaterialization();
    return;
  }

  auto InitialDests =
      emitMUImplSymbols(MUState, MUState.getCurVersion(), JD, std::move(TSM));
  if (!InitialDests) {
    ES.reportError(InitialDests.takeError());
    R->failMaterialization();
    return;
  }

  RSManager.emitRedirectableSymbols(std::move(R), std::move(*InitialDests));
}

/// Instrument every defined function in \p TSM so that a reoptimize request
/// is dispatched when a call counter reaches CallCountThreshold.
///
/// One internal i64 counter ("__orc_reopt_counter") and one constant argument
/// buffer encoding (\p MUID, \p CurVersion) are added to the module and shared
/// by all instrumented functions.
///
/// \returns the serialization error from createReoptimizeArgBuffer, if any;
/// success otherwise.
Error ReOptimizeLayer::reoptimizeIfCallFrequent(ReOptimizeLayer &Parent,
                                                ReOptMaterializationUnitID MUID,
                                                unsigned CurVersion,
                                                ThreadSafeModule &TSM) {
  return TSM.withModuleDo([&](Module &M) -> Error {
    Type *I64Ty = Type::getInt64Ty(M.getContext());
    GlobalVariable *Counter = new GlobalVariable(
        M, I64Ty, /*isConstant=*/false, GlobalValue::InternalLinkage,
        Constant::getNullValue(I64Ty), "__orc_reopt_counter");
    auto ArgBufferConst = createReoptimizeArgBuffer(M, MUID, CurVersion);
    if (auto Err = ArgBufferConst.takeError())
      return Err;
    GlobalVariable *ArgBuffer =
        new GlobalVariable(M, (*ArgBufferConst)->getType(), /*isConstant=*/true,
                           GlobalValue::InternalLinkage, (*ArgBufferConst));

    // Loop-invariant constants.
    Value *Threshold = ConstantInt::get(I64Ty, CallCountThreshold, true);
    Value *One = ConstantInt::get(I64Ty, 1);

    for (auto &F : M) {
      if (F.isDeclaration())
        continue;
      auto &BB = F.getEntryBlock();
      auto *IP = &*BB.getFirstInsertionPt();
      IRBuilder<> IRB(IP);
      Value *Cnt = IRB.CreateLoad(I64Ty, Counter);
      // Use EQ to prevent further reoptimize calls.
      Value *Cmp = IRB.CreateICmpEQ(Cnt, Threshold);
      Value *Added = IRB.CreateAdd(Cnt, One);
      (void)IRB.CreateStore(Added, Counter);
      // The dispatch call goes into the conditionally executed "then" block.
      Instruction *SplitTerminator =
          SplitBlockAndInsertIfThen(Cmp, IP, /*Unreachable=*/false);
      createReoptimizeCall(M, *SplitTerminator, ArgBuffer);
    }
    return Error::success();
  });
}

/// Define version \p Version of the unit's function implementations in \p JD.
///
/// Every defined function in \p TSM is renamed to
/// "<name>.__def__.<Version>". The module is then defined in \p JD through
/// BaseLayer under a fresh resource tracker (recorded in \p MUState), and the
/// renamed symbols are looked up until they reach the Resolved state.
///
/// \returns a map from each original mangled name to the looked-up symbol of
/// its renamed implementation, or the error from define/lookup.
Expected<SymbolMap>
ReOptimizeLayer::emitMUImplSymbols(ReOptMaterializationUnitState &MUState,
                                   uint32_t Version, JITDylib &JD,
                                   ThreadSafeModule TSM) {
  DenseMap<SymbolStringPtr, SymbolStringPtr> RenamedMap;
  // The callback below only ever returns success, hence cantFail.
  cantFail(TSM.withModuleDo([&](Module &M) -> Error {
    MangleAndInterner Mangle(ES, M.getDataLayout());
    for (auto &F : M)
      if (!F.isDeclaration()) {
        std::string NewName =
            (F.getName() + ".__def__." + Twine(Version)).str();
        RenamedMap[Mangle(F.getName())] = Mangle(NewName);
        F.setName(NewName);
      }
    return Error::success();
  }));

  auto RT = JD.createResourceTracker();
  if (auto Err =
          JD.define(std::make_unique<BasicIRLayerMaterializationUnit>(
                        BaseLayer, *getManglingOptions(), std::move(TSM)),
                    RT))
    return Err;
  MUState.setResourceTracker(RT);

  // Bind by reference: copying the symbol pointers per element is unneeded.
  SymbolLookupSet LookupSymbols;
  for (const auto &[OrigName, ImplName] : RenamedMap)
    LookupSymbols.add(ImplName);

  auto ImplSymbols =
      ES.lookup({{&JD, JITDylibLookupFlags::MatchAllSymbols}}, LookupSymbols,
                LookupKind::Static, SymbolState::Resolved);
  if (auto Err = ImplSymbols.takeError())
    return Err;

  SymbolMap Result;
  for (const auto &[OrigName, ImplName] : RenamedMap)
    Result[OrigName] = (*ImplSymbols)[ImplName];

  return Result;
}

/// Handler for reoptimize requests dispatched from instrumented code.
///
/// The request is ignored if \p CurVersion is older than the unit's current
/// version or if a reoptimization of the unit is already in progress.
/// Otherwise the stored module is cloned, ReOptFunc is run on the clone, the
/// next version's implementation symbols are emitted, and the redirectable
/// symbols are pointed at them.
///
/// \p SendResult is always invoked with Error::success(); failures are
/// reported through the ExecutionSession instead.
void ReOptimizeLayer::rt_reoptimize(SendErrorFn SendResult,
                                    ReOptMaterializationUnitID MUID,
                                    uint32_t CurVersion) {
  auto &MUState = getMaterializationUnitState(MUID);
  if (CurVersion < MUState.getCurVersion() || !MUState.tryStartReoptimize()) {
    SendResult(Error::success());
    return;
  }

  // Shared failure path: report, release the in-progress flag, reply.
  auto FailReoptimization = [&](Error Err) {
    ES.reportError(std::move(Err));
    MUState.reoptimizeFailed();
    SendResult(Error::success());
  };

  ThreadSafeModule TSM = cloneToNewContext(MUState.getThreadSafeModule());
  auto OldRT = MUState.getResourceTracker();
  auto &JD = OldRT->getJITDylib();

  if (auto Err = ReOptFunc(*this, MUID, CurVersion + 1, OldRT, TSM)) {
    FailReoptimization(std::move(Err));
    return;
  }

  auto SymbolDests =
      emitMUImplSymbols(MUState, CurVersion + 1, JD, std::move(TSM));
  if (!SymbolDests) {
    FailReoptimization(SymbolDests.takeError());
    return;
  }

  if (auto Err = RSManager.redirect(JD, std::move(*SymbolDests))) {
    FailReoptimization(std::move(Err));
    return;
  }

  MUState.reoptimizeSucceeded();
  SendResult(Error::success());
}

/// Serialize (\p MUID, \p CurVersion) with SPSReoptimizeArgList and wrap the
/// resulting bytes in a constant data array in \p M's context.
///
/// \returns the constant, or a StringError if serialization fails.
Expected<Constant *> ReOptimizeLayer::createReoptimizeArgBuffer(
    Module &M, ReOptMaterializationUnitID MUID, uint32_t CurVersion) {
  size_t ArgBufferSize = SPSReoptimizeArgList::size(MUID, CurVersion);
  std::vector<char> ArgBuffer(ArgBufferSize);
  shared::SPSOutputBuffer OB(ArgBuffer.data(), ArgBuffer.size());
  if (!SPSReoptimizeArgList::serialize(OB, MUID, CurVersion))
    return make_error<StringError>("Could not serealize args list",
                                   inconvertibleErrorCode());
  return ConstantDataArray::get(M.getContext(), ArrayRef(ArgBuffer));
}

/// Insert, before \p IP, a call
///   __orc_rt_jit_dispatch(&__orc_rt_jit_dispatch_ctx,
///                         &__orc_rt_reoptimize_tag, ArgBuffer, size)
/// declaring the two external globals and the dispatch function in \p M if
/// they are not already present.
void ReOptimizeLayer::createReoptimizeCall(Module &M, Instruction &IP,
                                           GlobalVariable *ArgBuffer) {
  LLVMContext &Ctx = M.getContext();
  PointerType *PtrTy = PointerType::get(Ctx, 0);
  IntegerType *I64Ty = IntegerType::get(Ctx, 64);

  GlobalVariable *DispatchCtx =
      M.getGlobalVariable("__orc_rt_jit_dispatch_ctx");
  if (!DispatchCtx)
    DispatchCtx = new GlobalVariable(M, PtrTy, /*isConstant=*/false,
                                     GlobalValue::ExternalLinkage, nullptr,
                                     "__orc_rt_jit_dispatch_ctx");

  GlobalVariable *ReoptimizeTag =
      M.getGlobalVariable("__orc_rt_reoptimize_tag");
  if (!ReoptimizeTag)
    ReoptimizeTag = new GlobalVariable(M, PtrTy, /*isConstant=*/false,
                                       GlobalValue::ExternalLinkage, nullptr,
                                       "__orc_rt_reoptimize_tag");

  Function *DispatchFunc = M.getFunction("__orc_rt_jit_dispatch");
  if (!DispatchFunc) {
    std::vector<Type *> Args = {PtrTy, PtrTy, PtrTy, I64Ty};
    FunctionType *FuncTy =
        FunctionType::get(Type::getVoidTy(Ctx), Args, /*isVarArg=*/false);
    DispatchFunc = Function::Create(FuncTy, GlobalValue::ExternalLinkage,
                                    "__orc_rt_jit_dispatch", &M);
  }

  // The size is computed from default-constructed arguments.
  // NOTE(review): presumably the encoding is fixed-width, so this matches the
  // size of the real buffer built by createReoptimizeArgBuffer -- confirm.
  size_t ArgBufferSizeConst =
      SPSReoptimizeArgList::size(ReOptMaterializationUnitID{}, uint32_t{});
  Constant *ArgBufferSize =
      ConstantInt::get(I64Ty, ArgBufferSizeConst, false);

  IRBuilder<> IRB(&IP);
  (void)IRB.CreateCall(DispatchFunc,
                       {DispatchCtx, ReoptimizeTag, ArgBuffer, ArgBufferSize});
}

/// Allocate the next unit ID and create a state for it that holds a clone of
/// \p TSM in a new context.
///
/// \returns a reference to the state stored in MUStates.
ReOptimizeLayer::ReOptMaterializationUnitState &
ReOptimizeLayer::createMaterializationUnitState(const ThreadSafeModule &TSM) {
  std::unique_lock<std::mutex> Lock(Mutex);
  ReOptMaterializationUnitID MUID = NextID;
  MUStates.emplace(MUID,
                   ReOptMaterializationUnitState(MUID, cloneToNewContext(TSM)));
  ++NextID;
  return MUStates.at(MUID);
}

/// Look up the state for \p MUID. The layer mutex is held only for the
/// lookup; the returned reference is used by callers after it is released.
ReOptimizeLayer::ReOptMaterializationUnitState &
ReOptimizeLayer::getMaterializationUnitState(ReOptMaterializationUnitID MUID) {
  std::unique_lock<std::mutex> Lock(Mutex);
  return MUStates.at(MUID);
}

/// Record that the unit described by \p State belongs to resource \p Key.
void ReOptimizeLayer::registerMaterializationUnitResource(
    ResourceKey Key, ReOptMaterializationUnitState &State) {
  std::unique_lock<std::mutex> Lock(Mutex);
  MUResources[Key].insert(State.getID());
}

/// Drop the state of every unit registered under resource \p K, then forget
/// \p K itself. Always succeeds.
Error ReOptimizeLayer::handleRemoveResources(JITDylib &JD, ResourceKey K) {
  std::unique_lock<std::mutex> Lock(Mutex);
  // Use find() rather than operator[] so that an unknown key does not get a
  // throw-away entry inserted just to be erased again.
  auto I = MUResources.find(K);
  if (I == MUResources.end())
    return Error::success();

  for (auto MUID : I->second)
    MUStates.erase(MUID);

  MUResources.erase(I);
  return Error::success();
}

/// Move every unit registered under \p SrcK to \p DstK and forget \p SrcK.
void ReOptimizeLayer::handleTransferResources(JITDylib &JD, ResourceKey DstK,
                                              ResourceKey SrcK) {
  std::unique_lock<std::mutex> Lock(Mutex);
  auto SrcI = MUResources.find(SrcK);
  if (SrcI == MUResources.end())
    return; // Nothing registered under SrcK; leave DstK untouched.

  // Detach the source set before touching DstK. Chaining two operator[]
  // calls on the same map in one expression lets the second call insert
  // while a reference from the first is still live; for a map that
  // reallocates on insertion (the type of MUResources is declared in the
  // header, not visible here) that reference would dangle.
  auto SrcMUIDs = std::move(SrcI->second);
  MUResources.erase(SrcI);
  MUResources[DstK].insert_range(SrcMUIDs);
}