Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/IPO/OpenMPOpt.cpp
35266 views
//===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// OpenMP specific optimizations:9//10// - Deduplication of runtime calls, e.g., omp_get_thread_num.11// - Replacing globalized device memory with stack memory.12// - Replacing globalized device memory with shared memory.13// - Parallel region merging.14// - Transforming generic-mode device kernels to SPMD mode.15// - Specializing the state machine for generic-mode device kernels.16//17//===----------------------------------------------------------------------===//1819#include "llvm/Transforms/IPO/OpenMPOpt.h"2021#include "llvm/ADT/EnumeratedArray.h"22#include "llvm/ADT/PostOrderIterator.h"23#include "llvm/ADT/SetVector.h"24#include "llvm/ADT/SmallPtrSet.h"25#include "llvm/ADT/SmallVector.h"26#include "llvm/ADT/Statistic.h"27#include "llvm/ADT/StringExtras.h"28#include "llvm/ADT/StringRef.h"29#include "llvm/Analysis/CallGraph.h"30#include "llvm/Analysis/CallGraphSCCPass.h"31#include "llvm/Analysis/MemoryLocation.h"32#include "llvm/Analysis/OptimizationRemarkEmitter.h"33#include "llvm/Analysis/ValueTracking.h"34#include "llvm/Frontend/OpenMP/OMPConstants.h"35#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"36#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"37#include "llvm/IR/Assumptions.h"38#include "llvm/IR/BasicBlock.h"39#include "llvm/IR/Constants.h"40#include "llvm/IR/DiagnosticInfo.h"41#include "llvm/IR/Dominators.h"42#include "llvm/IR/Function.h"43#include "llvm/IR/GlobalValue.h"44#include "llvm/IR/GlobalVariable.h"45#include "llvm/IR/InstrTypes.h"46#include "llvm/IR/Instruction.h"47#include "llvm/IR/Instructions.h"48#include "llvm/IR/IntrinsicInst.h"49#include "llvm/IR/IntrinsicsAMDGPU.h"50#include "llvm/IR/IntrinsicsNVPTX.h"51#include "llvm/IR/LLVMContext.h"52#include "llvm/Support/Casting.h"53#include "llvm/Support/CommandLine.h"54#include "llvm/Support/Debug.h"55#include "llvm/Transforms/IPO/Attributor.h"56#include "llvm/Transforms/Utils/BasicBlockUtils.h"57#include "llvm/Transforms/Utils/CallGraphUpdater.h"5859#include <algorithm>60#include <optional>61#include <string>6263using namespace llvm;64using namespace omp;6566#define DEBUG_TYPE "openmp-opt"6768static cl::opt<bool> DisableOpenMPOptimizations(69"openmp-opt-disable", cl::desc("Disable OpenMP specific optimizations."),70cl::Hidden, cl::init(false));7172static cl::opt<bool> EnableParallelRegionMerging(73"openmp-opt-enable-merging",74cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,75cl::init(false));7677static cl::opt<bool>78DisableInternalization("openmp-opt-disable-internalization",79cl::desc("Disable function internalization."),80cl::Hidden, cl::init(false));8182static cl::opt<bool> DeduceICVValues("openmp-deduce-icv-values",83cl::init(false), cl::Hidden);84static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),85cl::Hidden);86static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",87cl::init(false), cl::Hidden);8889static cl::opt<bool> HideMemoryTransferLatency(90"openmp-hide-memory-transfer-latency",91cl::desc("[WIP] Tries to hide the latency of host to device memory"92" transfers"),93cl::Hidden, cl::init(false));9495static cl::opt<bool> 
DisableOpenMPOptDeglobalization(96"openmp-opt-disable-deglobalization",97cl::desc("Disable OpenMP optimizations involving deglobalization."),98cl::Hidden, cl::init(false));99100static cl::opt<bool> DisableOpenMPOptSPMDization(101"openmp-opt-disable-spmdization",102cl::desc("Disable OpenMP optimizations involving SPMD-ization."),103cl::Hidden, cl::init(false));104105static cl::opt<bool> DisableOpenMPOptFolding(106"openmp-opt-disable-folding",107cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,108cl::init(false));109110static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(111"openmp-opt-disable-state-machine-rewrite",112cl::desc("Disable OpenMP optimizations that replace the state machine."),113cl::Hidden, cl::init(false));114115static cl::opt<bool> DisableOpenMPOptBarrierElimination(116"openmp-opt-disable-barrier-elimination",117cl::desc("Disable OpenMP optimizations that eliminate barriers."),118cl::Hidden, cl::init(false));119120static cl::opt<bool> PrintModuleAfterOptimizations(121"openmp-opt-print-module-after",122cl::desc("Print the current module after OpenMP optimizations."),123cl::Hidden, cl::init(false));124125static cl::opt<bool> PrintModuleBeforeOptimizations(126"openmp-opt-print-module-before",127cl::desc("Print the current module before OpenMP optimizations."),128cl::Hidden, cl::init(false));129130static cl::opt<bool> AlwaysInlineDeviceFunctions(131"openmp-opt-inline-device",132cl::desc("Inline all applicible functions on the device."), cl::Hidden,133cl::init(false));134135static cl::opt<bool>136EnableVerboseRemarks("openmp-opt-verbose-remarks",137cl::desc("Enables more verbose remarks."), cl::Hidden,138cl::init(false));139140static cl::opt<unsigned>141SetFixpointIterations("openmp-opt-max-iterations", cl::Hidden,142cl::desc("Maximal number of attributor iterations."),143cl::init(256));144145static cl::opt<unsigned>146SharedMemoryLimit("openmp-opt-shared-limit", cl::Hidden,147cl::desc("Maximum amount of shared memory to use."),148cl::init(std::numeric_limits<unsigned>::max()));149150STATISTIC(NumOpenMPRuntimeCallsDeduplicated,151"Number of OpenMP runtime calls deduplicated");152STATISTIC(NumOpenMPParallelRegionsDeleted,153"Number of OpenMP parallel regions deleted");154STATISTIC(NumOpenMPRuntimeFunctionsIdentified,155"Number of OpenMP runtime functions identified");156STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,157"Number of OpenMP runtime function uses identified");158STATISTIC(NumOpenMPTargetRegionKernels,159"Number of OpenMP target region entry points (=kernels) identified");160STATISTIC(NumNonOpenMPTargetRegionKernels,161"Number of non-OpenMP target region kernels identified");162STATISTIC(NumOpenMPTargetRegionKernelsSPMD,163"Number of OpenMP target region entry points (=kernels) executed in "164"SPMD-mode instead of generic-mode");165STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,166"Number of OpenMP target region entry points (=kernels) executed in "167"generic-mode without a state machines");168STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,169"Number of OpenMP target region entry points (=kernels) executed in "170"generic-mode with customized state machines with fallback");171STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,172"Number of OpenMP target region entry points (=kernels) executed in "173"generic-mode with customized state machines without fallback");174STATISTIC(175NumOpenMPParallelRegionsReplacedInGPUStateMachine,176"Number of OpenMP parallel regions replaced with ID in 
GPU state machines");177STATISTIC(NumOpenMPParallelRegionsMerged,178"Number of OpenMP parallel regions merged");179STATISTIC(NumBytesMovedToSharedMemory,180"Amount of memory pushed to shared memory");181STATISTIC(NumBarriersEliminated, "Number of redundant barriers eliminated");182183#if !defined(NDEBUG)184static constexpr auto TAG = "[" DEBUG_TYPE "]";185#endif186187namespace KernelInfo {188189// struct ConfigurationEnvironmentTy {190// uint8_t UseGenericStateMachine;191// uint8_t MayUseNestedParallelism;192// llvm::omp::OMPTgtExecModeFlags ExecMode;193// int32_t MinThreads;194// int32_t MaxThreads;195// int32_t MinTeams;196// int32_t MaxTeams;197// };198199// struct DynamicEnvironmentTy {200// uint16_t DebugIndentionLevel;201// };202203// struct KernelEnvironmentTy {204// ConfigurationEnvironmentTy Configuration;205// IdentTy *Ident;206// DynamicEnvironmentTy *DynamicEnv;207// };208209#define KERNEL_ENVIRONMENT_IDX(MEMBER, IDX) \210constexpr const unsigned MEMBER##Idx = IDX;211212KERNEL_ENVIRONMENT_IDX(Configuration, 0)213KERNEL_ENVIRONMENT_IDX(Ident, 1)214215#undef KERNEL_ENVIRONMENT_IDX216217#define KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MEMBER, IDX) \218constexpr const unsigned MEMBER##Idx = IDX;219220KERNEL_ENVIRONMENT_CONFIGURATION_IDX(UseGenericStateMachine, 0)221KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MayUseNestedParallelism, 1)222KERNEL_ENVIRONMENT_CONFIGURATION_IDX(ExecMode, 2)223KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinThreads, 3)224KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxThreads, 4)225KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MinTeams, 5)226KERNEL_ENVIRONMENT_CONFIGURATION_IDX(MaxTeams, 6)227228#undef KERNEL_ENVIRONMENT_CONFIGURATION_IDX229230#define KERNEL_ENVIRONMENT_GETTER(MEMBER, RETURNTYPE) \231RETURNTYPE *get##MEMBER##FromKernelEnvironment(ConstantStruct *KernelEnvC) { \232return cast<RETURNTYPE>(KernelEnvC->getAggregateElement(MEMBER##Idx)); \233}234235KERNEL_ENVIRONMENT_GETTER(Ident, Constant)236KERNEL_ENVIRONMENT_GETTER(Configuration, ConstantStruct)237238#undef KERNEL_ENVIRONMENT_GETTER239240#define KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MEMBER) \241ConstantInt *get##MEMBER##FromKernelEnvironment( \242ConstantStruct *KernelEnvC) { \243ConstantStruct *ConfigC = \244getConfigurationFromKernelEnvironment(KernelEnvC); \245return dyn_cast<ConstantInt>(ConfigC->getAggregateElement(MEMBER##Idx)); \246}247248KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(UseGenericStateMachine)249KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MayUseNestedParallelism)250KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(ExecMode)251KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinThreads)252KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxThreads)253KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MinTeams)254KERNEL_ENVIRONMENT_CONFIGURATION_GETTER(MaxTeams)255256#undef KERNEL_ENVIRONMENT_CONFIGURATION_GETTER257258GlobalVariable *259getKernelEnvironementGVFromKernelInitCB(CallBase *KernelInitCB) {260constexpr const int InitKernelEnvironmentArgNo = 0;261return cast<GlobalVariable>(262KernelInitCB->getArgOperand(InitKernelEnvironmentArgNo)263->stripPointerCasts());264}265266ConstantStruct *getKernelEnvironementFromKernelInitCB(CallBase *KernelInitCB) {267GlobalVariable *KernelEnvGV =268getKernelEnvironementGVFromKernelInitCB(KernelInitCB);269return cast<ConstantStruct>(KernelEnvGV->getInitializer());270}271} // namespace KernelInfo272273namespace {274275struct AAHeapToShared;276277struct AAICVTracker;278279/// OpenMP specific information. 
For now, stores RFIs and ICVs also needed for280/// Attributor runs.281struct OMPInformationCache : public InformationCache {282OMPInformationCache(Module &M, AnalysisGetter &AG,283BumpPtrAllocator &Allocator, SetVector<Function *> *CGSCC,284bool OpenMPPostLink)285: InformationCache(M, AG, Allocator, CGSCC), OMPBuilder(M),286OpenMPPostLink(OpenMPPostLink) {287288OMPBuilder.Config.IsTargetDevice = isOpenMPDevice(OMPBuilder.M);289OMPBuilder.initialize();290initializeRuntimeFunctions(M);291initializeInternalControlVars();292}293294/// Generic information that describes an internal control variable.295struct InternalControlVarInfo {296/// The kind, as described by InternalControlVar enum.297InternalControlVar Kind;298299/// The name of the ICV.300StringRef Name;301302/// Environment variable associated with this ICV.303StringRef EnvVarName;304305/// Initial value kind.306ICVInitValue InitKind;307308/// Initial value.309ConstantInt *InitValue;310311/// Setter RTL function associated with this ICV.312RuntimeFunction Setter;313314/// Getter RTL function associated with this ICV.315RuntimeFunction Getter;316317/// RTL Function corresponding to the override clause of this ICV318RuntimeFunction Clause;319};320321/// Generic information that describes a runtime function322struct RuntimeFunctionInfo {323324/// The kind, as described by the RuntimeFunction enum.325RuntimeFunction Kind;326327/// The name of the function.328StringRef Name;329330/// Flag to indicate a variadic function.331bool IsVarArg;332333/// The return type of the function.334Type *ReturnType;335336/// The argument types of the function.337SmallVector<Type *, 8> ArgumentTypes;338339/// The declaration if available.340Function *Declaration = nullptr;341342/// Uses of this runtime function per function containing the use.343using UseVector = SmallVector<Use *, 16>;344345/// Clear UsesMap for runtime function.346void clearUsesMap() { UsesMap.clear(); }347348/// Boolean conversion that is true if the runtime function was found.349operator bool() const { return Declaration; }350351/// Return the vector of uses in function \p F.352UseVector &getOrCreateUseVector(Function *F) {353std::shared_ptr<UseVector> &UV = UsesMap[F];354if (!UV)355UV = std::make_shared<UseVector>();356return *UV;357}358359/// Return the vector of uses in function \p F or `nullptr` if there are360/// none.361const UseVector *getUseVector(Function &F) const {362auto I = UsesMap.find(&F);363if (I != UsesMap.end())364return I->second.get();365return nullptr;366}367368/// Return how many functions contain uses of this runtime function.369size_t getNumFunctionsWithUses() const { return UsesMap.size(); }370371/// Return the number of arguments (or the minimal number for variadic372/// functions).373size_t getNumArgs() const { return ArgumentTypes.size(); }374375/// Run the callback \p CB on each use and forget the use if the result is376/// true. 
The callback will be fed the function in which the use was377/// encountered as second argument.378void foreachUse(SmallVectorImpl<Function *> &SCC,379function_ref<bool(Use &, Function &)> CB) {380for (Function *F : SCC)381foreachUse(CB, F);382}383384/// Run the callback \p CB on each use within the function \p F and forget385/// the use if the result is true.386void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {387SmallVector<unsigned, 8> ToBeDeleted;388ToBeDeleted.clear();389390unsigned Idx = 0;391UseVector &UV = getOrCreateUseVector(F);392393for (Use *U : UV) {394if (CB(*U, *F))395ToBeDeleted.push_back(Idx);396++Idx;397}398399// Remove the to-be-deleted indices in reverse order as prior400// modifications will not modify the smaller indices.401while (!ToBeDeleted.empty()) {402unsigned Idx = ToBeDeleted.pop_back_val();403UV[Idx] = UV.back();404UV.pop_back();405}406}407408private:409/// Map from functions to all uses of this runtime function contained in410/// them.411DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;412413public:414/// Iterators for the uses of this runtime function.415decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }416decltype(UsesMap)::iterator end() { return UsesMap.end(); }417};418419/// An OpenMP-IR-Builder instance420OpenMPIRBuilder OMPBuilder;421422/// Map from runtime function kind to the runtime function description.423EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,424RuntimeFunction::OMPRTL___last>425RFIs;426427/// Map from function declarations/definitions to their runtime enum type.428DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;429430/// Map from ICV kind to the ICV description.431EnumeratedArray<InternalControlVarInfo, InternalControlVar,432InternalControlVar::ICV___last>433ICVs;434435/// Helper to initialize all internal control variable information for those436/// defined in OMPKinds.def.437void initializeInternalControlVars() {438#define ICV_RT_SET(_Name, RTL) \439{ \440auto &ICV = ICVs[_Name]; \441ICV.Setter = RTL; \442}443#define ICV_RT_GET(Name, RTL) \444{ \445auto &ICV = ICVs[Name]; \446ICV.Getter = RTL; \447}448#define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init) \449{ \450auto &ICV = ICVs[Enum]; \451ICV.Name = _Name; \452ICV.Kind = Enum; \453ICV.InitKind = Init; \454ICV.EnvVarName = _EnvVarName; \455switch (ICV.InitKind) { \456case ICV_IMPLEMENTATION_DEFINED: \457ICV.InitValue = nullptr; \458break; \459case ICV_ZERO: \460ICV.InitValue = ConstantInt::get( \461Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0); \462break; \463case ICV_FALSE: \464ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext()); \465break; \466case ICV_LAST: \467break; \468} \469}470#include "llvm/Frontend/OpenMP/OMPKinds.def"471}472473/// Returns true if the function declaration \p F matches the runtime474/// function types, that is, return type \p RTFRetType, and argument types475/// \p RTFArgTypes.476static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,477SmallVector<Type *, 8> &RTFArgTypes) {478// TODO: We should output information to the user (under debug output479// and via remarks).480481if (!F)482return false;483if (F->getReturnType() != RTFRetType)484return false;485if (F->arg_size() != RTFArgTypes.size())486return false;487488auto *RTFTyIt = RTFArgTypes.begin();489for (Argument &Arg : F->args()) {490if (Arg.getType() != *RTFTyIt)491return false;492493++RTFTyIt;494}495496return true;497}498499// Helper to collect all uses of the declaration in the UsesMap.500unsigned 
collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {501unsigned NumUses = 0;502if (!RFI.Declaration)503return NumUses;504OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);505506if (CollectStats) {507NumOpenMPRuntimeFunctionsIdentified += 1;508NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();509}510511// TODO: We directly convert uses into proper calls and unknown uses.512for (Use &U : RFI.Declaration->uses()) {513if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {514if (!CGSCC || CGSCC->empty() || CGSCC->contains(UserI->getFunction())) {515RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);516++NumUses;517}518} else {519RFI.getOrCreateUseVector(nullptr).push_back(&U);520++NumUses;521}522}523return NumUses;524}525526// Helper function to recollect uses of a runtime function.527void recollectUsesForFunction(RuntimeFunction RTF) {528auto &RFI = RFIs[RTF];529RFI.clearUsesMap();530collectUses(RFI, /*CollectStats*/ false);531}532533// Helper function to recollect uses of all runtime functions.534void recollectUses() {535for (int Idx = 0; Idx < RFIs.size(); ++Idx)536recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));537}538539// Helper function to inherit the calling convention of the function callee.540void setCallingConvention(FunctionCallee Callee, CallInst *CI) {541if (Function *Fn = dyn_cast<Function>(Callee.getCallee()))542CI->setCallingConv(Fn->getCallingConv());543}544545// Helper function to determine if it's legal to create a call to the runtime546// functions.547bool runtimeFnsAvailable(ArrayRef<RuntimeFunction> Fns) {548// We can always emit calls if we haven't yet linked in the runtime.549if (!OpenMPPostLink)550return true;551552// Once the runtime has been already been linked in we cannot emit calls to553// any undefined functions.554for (RuntimeFunction Fn : Fns) {555RuntimeFunctionInfo &RFI = RFIs[Fn];556557if (RFI.Declaration && RFI.Declaration->isDeclaration())558return false;559}560return true;561}562563/// Helper to initialize all runtime function information for those defined564/// in OpenMPKinds.def.565void initializeRuntimeFunctions(Module &M) {566567// Helper macros for handling __VA_ARGS__ in OMP_RTL568#define OMP_TYPE(VarName, ...) \569Type *VarName = OMPBuilder.VarName; \570(void)VarName;571572#define OMP_ARRAY_TYPE(VarName, ...) \573ArrayType *VarName##Ty = OMPBuilder.VarName##Ty; \574(void)VarName##Ty; \575PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy; \576(void)VarName##PtrTy;577578#define OMP_FUNCTION_TYPE(VarName, ...) \579FunctionType *VarName = OMPBuilder.VarName; \580(void)VarName; \581PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \582(void)VarName##Ptr;583584#define OMP_STRUCT_TYPE(VarName, ...) \585StructType *VarName = OMPBuilder.VarName; \586(void)VarName; \587PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr; \588(void)VarName##Ptr;589590#define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...) \591{ \592SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__}); \593Function *F = M.getFunction(_Name); \594RTLFunctions.insert(F); \595if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) { \596RuntimeFunctionIDMap[F] = _Enum; \597auto &RFI = RFIs[_Enum]; \598RFI.Kind = _Enum; \599RFI.Name = _Name; \600RFI.IsVarArg = _IsVarArg; \601RFI.ReturnType = OMPBuilder._ReturnType; \602RFI.ArgumentTypes = std::move(ArgsTypes); \603RFI.Declaration = F; \604unsigned NumUses = collectUses(RFI); \605(void)NumUses; \606LLVM_DEBUG({ \607dbgs() << TAG << RFI.Name << (RFI.Declaration ? 
"" : " not") \608<< " found\n"; \609if (RFI.Declaration) \610dbgs() << TAG << "-> got " << NumUses << " uses in " \611<< RFI.getNumFunctionsWithUses() \612<< " different functions.\n"; \613}); \614} \615}616#include "llvm/Frontend/OpenMP/OMPKinds.def"617618// Remove the `noinline` attribute from `__kmpc`, `ompx::` and `omp_`619// functions, except if `optnone` is present.620if (isOpenMPDevice(M)) {621for (Function &F : M) {622for (StringRef Prefix : {"__kmpc", "_ZN4ompx", "omp_"})623if (F.hasFnAttribute(Attribute::NoInline) &&624F.getName().starts_with(Prefix) &&625!F.hasFnAttribute(Attribute::OptimizeNone))626F.removeFnAttr(Attribute::NoInline);627}628}629630// TODO: We should attach the attributes defined in OMPKinds.def.631}632633/// Collection of known OpenMP runtime functions..634DenseSet<const Function *> RTLFunctions;635636/// Indicates if we have already linked in the OpenMP device library.637bool OpenMPPostLink = false;638};639640template <typename Ty, bool InsertInvalidates = true>641struct BooleanStateWithSetVector : public BooleanState {642bool contains(const Ty &Elem) const { return Set.contains(Elem); }643bool insert(const Ty &Elem) {644if (InsertInvalidates)645BooleanState::indicatePessimisticFixpoint();646return Set.insert(Elem);647}648649const Ty &operator[](int Idx) const { return Set[Idx]; }650bool operator==(const BooleanStateWithSetVector &RHS) const {651return BooleanState::operator==(RHS) && Set == RHS.Set;652}653bool operator!=(const BooleanStateWithSetVector &RHS) const {654return !(*this == RHS);655}656657bool empty() const { return Set.empty(); }658size_t size() const { return Set.size(); }659660/// "Clamp" this state with \p RHS.661BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {662BooleanState::operator^=(RHS);663Set.insert(RHS.Set.begin(), RHS.Set.end());664return *this;665}666667private:668/// A set to keep track of elements.669SetVector<Ty> Set;670671public:672typename decltype(Set)::iterator begin() { return Set.begin(); }673typename decltype(Set)::iterator end() { return Set.end(); }674typename decltype(Set)::const_iterator begin() const { return Set.begin(); }675typename decltype(Set)::const_iterator end() const { return Set.end(); }676};677678template <typename Ty, bool InsertInvalidates = true>679using BooleanStateWithPtrSetVector =680BooleanStateWithSetVector<Ty *, InsertInvalidates>;681682struct KernelInfoState : AbstractState {683/// Flag to track if we reached a fixpoint.684bool IsAtFixpoint = false;685686/// The parallel regions (identified by the outlined parallel functions) that687/// can be reached from the associated function.688BooleanStateWithPtrSetVector<CallBase, /* InsertInvalidates */ false>689ReachedKnownParallelRegions;690691/// State to track what parallel region we might reach.692BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;693694/// State to track if we are in SPMD-mode, assumed or know, and why we decided695/// we cannot be. If it is assumed, then RequiresFullRuntime should also be696/// false.697BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;698699/// The __kmpc_target_init call in this kernel, if any. If we find more than700/// one we abort as the kernel is malformed.701CallBase *KernelInitCB = nullptr;702703/// The constant kernel environement as taken from and passed to704/// __kmpc_target_init.705ConstantStruct *KernelEnvC = nullptr;706707/// The __kmpc_target_deinit call in this kernel, if any. 
If we find more than708/// one we abort as the kernel is malformed.709CallBase *KernelDeinitCB = nullptr;710711/// Flag to indicate if the associated function is a kernel entry.712bool IsKernelEntry = false;713714/// State to track what kernel entries can reach the associated function.715BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;716717/// State to indicate if we can track parallel level of the associated718/// function. We will give up tracking if we encounter unknown caller or the719/// caller is __kmpc_parallel_51.720BooleanStateWithSetVector<uint8_t> ParallelLevels;721722/// Flag that indicates if the kernel has nested Parallelism723bool NestedParallelism = false;724725/// Abstract State interface726///{727728KernelInfoState() = default;729KernelInfoState(bool BestState) {730if (!BestState)731indicatePessimisticFixpoint();732}733734/// See AbstractState::isValidState(...)735bool isValidState() const override { return true; }736737/// See AbstractState::isAtFixpoint(...)738bool isAtFixpoint() const override { return IsAtFixpoint; }739740/// See AbstractState::indicatePessimisticFixpoint(...)741ChangeStatus indicatePessimisticFixpoint() override {742IsAtFixpoint = true;743ParallelLevels.indicatePessimisticFixpoint();744ReachingKernelEntries.indicatePessimisticFixpoint();745SPMDCompatibilityTracker.indicatePessimisticFixpoint();746ReachedKnownParallelRegions.indicatePessimisticFixpoint();747ReachedUnknownParallelRegions.indicatePessimisticFixpoint();748NestedParallelism = true;749return ChangeStatus::CHANGED;750}751752/// See AbstractState::indicateOptimisticFixpoint(...)753ChangeStatus indicateOptimisticFixpoint() override {754IsAtFixpoint = true;755ParallelLevels.indicateOptimisticFixpoint();756ReachingKernelEntries.indicateOptimisticFixpoint();757SPMDCompatibilityTracker.indicateOptimisticFixpoint();758ReachedKnownParallelRegions.indicateOptimisticFixpoint();759ReachedUnknownParallelRegions.indicateOptimisticFixpoint();760return ChangeStatus::UNCHANGED;761}762763/// Return the assumed state764KernelInfoState &getAssumed() { return *this; }765const KernelInfoState &getAssumed() const { return *this; }766767bool operator==(const KernelInfoState &RHS) const {768if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)769return false;770if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)771return false;772if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)773return false;774if (ReachingKernelEntries != RHS.ReachingKernelEntries)775return false;776if (ParallelLevels != RHS.ParallelLevels)777return false;778if (NestedParallelism != RHS.NestedParallelism)779return false;780return true;781}782783/// Returns true if this kernel contains any OpenMP parallel regions.784bool mayContainParallelRegion() {785return !ReachedKnownParallelRegions.empty() ||786!ReachedUnknownParallelRegions.empty();787}788789/// Return empty set as the best state of potential values.790static KernelInfoState getBestState() { return KernelInfoState(true); }791792static KernelInfoState getBestState(KernelInfoState &KIS) {793return getBestState();794}795796/// Return full set as the worst state of potential values.797static KernelInfoState getWorstState() { return KernelInfoState(false); }798799/// "Clamp" this state with \p KIS.800KernelInfoState operator^=(const KernelInfoState &KIS) {801// Do not merge two different _init and _deinit call sites.802if (KIS.KernelInitCB) {803if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)804llvm_unreachable("Kernel that 
calls another kernel violates OpenMP-Opt "805"assumptions.");806KernelInitCB = KIS.KernelInitCB;807}808if (KIS.KernelDeinitCB) {809if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)810llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "811"assumptions.");812KernelDeinitCB = KIS.KernelDeinitCB;813}814if (KIS.KernelEnvC) {815if (KernelEnvC && KernelEnvC != KIS.KernelEnvC)816llvm_unreachable("Kernel that calls another kernel violates OpenMP-Opt "817"assumptions.");818KernelEnvC = KIS.KernelEnvC;819}820SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;821ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;822ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;823NestedParallelism |= KIS.NestedParallelism;824return *this;825}826827KernelInfoState operator&=(const KernelInfoState &KIS) {828return (*this ^= KIS);829}830831///}832};833834/// Used to map the values physically (in the IR) stored in an offload835/// array, to a vector in memory.836struct OffloadArray {837/// Physical array (in the IR).838AllocaInst *Array = nullptr;839/// Mapped values.840SmallVector<Value *, 8> StoredValues;841/// Last stores made in the offload array.842SmallVector<StoreInst *, 8> LastAccesses;843844OffloadArray() = default;845846/// Initializes the OffloadArray with the values stored in \p Array before847/// instruction \p Before is reached. Returns false if the initialization848/// fails.849/// This MUST be used immediately after the construction of the object.850bool initialize(AllocaInst &Array, Instruction &Before) {851if (!Array.getAllocatedType()->isArrayTy())852return false;853854if (!getValues(Array, Before))855return false;856857this->Array = &Array;858return true;859}860861static const unsigned DeviceIDArgNum = 1;862static const unsigned BasePtrsArgNum = 3;863static const unsigned PtrsArgNum = 4;864static const unsigned SizesArgNum = 5;865866private:867/// Traverses the BasicBlock where \p Array is, collecting the stores made to868/// \p Array, leaving StoredValues with the values stored before the869/// instruction \p Before is reached.870bool getValues(AllocaInst &Array, Instruction &Before) {871// Initialize container.872const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();873StoredValues.assign(NumValues, nullptr);874LastAccesses.assign(NumValues, nullptr);875876// TODO: This assumes the instruction \p Before is in the same877// BasicBlock as Array. 
Make it general, for any control flow graph.878BasicBlock *BB = Array.getParent();879if (BB != Before.getParent())880return false;881882const DataLayout &DL = Array.getDataLayout();883const unsigned int PointerSize = DL.getPointerSize();884885for (Instruction &I : *BB) {886if (&I == &Before)887break;888889if (!isa<StoreInst>(&I))890continue;891892auto *S = cast<StoreInst>(&I);893int64_t Offset = -1;894auto *Dst =895GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);896if (Dst == &Array) {897int64_t Idx = Offset / PointerSize;898StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());899LastAccesses[Idx] = S;900}901}902903return isFilled();904}905906/// Returns true if all values in StoredValues and907/// LastAccesses are not nullptrs.908bool isFilled() {909const unsigned NumValues = StoredValues.size();910for (unsigned I = 0; I < NumValues; ++I) {911if (!StoredValues[I] || !LastAccesses[I])912return false;913}914915return true;916}917};918919struct OpenMPOpt {920921using OptimizationRemarkGetter =922function_ref<OptimizationRemarkEmitter &(Function *)>;923924OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,925OptimizationRemarkGetter OREGetter,926OMPInformationCache &OMPInfoCache, Attributor &A)927: M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),928OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}929930/// Check if any remarks are enabled for openmp-opt931bool remarksEnabled() {932auto &Ctx = M.getContext();933return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);934}935936/// Run all OpenMP optimizations on the underlying SCC.937bool run(bool IsModulePass) {938if (SCC.empty())939return false;940941bool Changed = false;942943LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()944<< " functions\n");945946if (IsModulePass) {947Changed |= runAttributor(IsModulePass);948949// Recollect uses, in case Attributor deleted any.950OMPInfoCache.recollectUses();951952// TODO: This should be folded into buildCustomStateMachine.953Changed |= rewriteDeviceCodeStateMachine();954955if (remarksEnabled())956analysisGlobalization();957} else {958if (PrintICVValues)959printICVs();960if (PrintOpenMPKernels)961printKernels();962963Changed |= runAttributor(IsModulePass);964965// Recollect uses, in case Attributor deleted any.966OMPInfoCache.recollectUses();967968Changed |= deleteParallelRegions();969970if (HideMemoryTransferLatency)971Changed |= hideMemTransfersLatency();972Changed |= deduplicateRuntimeCalls();973if (EnableParallelRegionMerging) {974if (mergeParallelRegions()) {975deduplicateRuntimeCalls();976Changed = true;977}978}979}980981if (OMPInfoCache.OpenMPPostLink)982Changed |= removeRuntimeSymbols();983984return Changed;985}986987/// Print initial ICV values for testing.988/// FIXME: This should be done from the Attributor once it is added.989void printICVs() const {990InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,991ICV_proc_bind};992993for (Function *F : SCC) {994for (auto ICV : ICVs) {995auto ICVInfo = OMPInfoCache.ICVs[ICV];996auto Remark = [&](OptimizationRemarkAnalysis ORA) {997return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)998<< " Value: "999<< (ICVInfo.InitValue1000? 
toString(ICVInfo.InitValue->getValue(), 10, true)1001: "IMPLEMENTATION_DEFINED");1002};10031004emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);1005}1006}1007}10081009/// Print OpenMP GPU kernels for testing.1010void printKernels() const {1011for (Function *F : SCC) {1012if (!omp::isOpenMPKernel(*F))1013continue;10141015auto Remark = [&](OptimizationRemarkAnalysis ORA) {1016return ORA << "OpenMP GPU kernel "1017<< ore::NV("OpenMPGPUKernel", F->getName()) << "\n";1018};10191020emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);1021}1022}10231024/// Return the call if \p U is a callee use in a regular call. If \p RFI is1025/// given it has to be the callee or a nullptr is returned.1026static CallInst *getCallIfRegularCall(1027Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {1028CallInst *CI = dyn_cast<CallInst>(U.getUser());1029if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&1030(!RFI ||1031(RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))1032return CI;1033return nullptr;1034}10351036/// Return the call if \p V is a regular call. If \p RFI is given it has to be1037/// the callee or a nullptr is returned.1038static CallInst *getCallIfRegularCall(1039Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {1040CallInst *CI = dyn_cast<CallInst>(&V);1041if (CI && !CI->hasOperandBundles() &&1042(!RFI ||1043(RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))1044return CI;1045return nullptr;1046}10471048private:1049/// Merge parallel regions when it is safe.1050bool mergeParallelRegions() {1051const unsigned CallbackCalleeOperand = 2;1052const unsigned CallbackFirstArgOperand = 3;1053using InsertPointTy = OpenMPIRBuilder::InsertPointTy;10541055// Check if there are any __kmpc_fork_call calls to merge.1056OMPInformationCache::RuntimeFunctionInfo &RFI =1057OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];10581059if (!RFI.Declaration)1060return false;10611062// Unmergable calls that prevent merging a parallel region.1063OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {1064OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],1065OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],1066};10671068bool Changed = false;1069LoopInfo *LI = nullptr;1070DominatorTree *DT = nullptr;10711072SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;10731074BasicBlock *StartBB = nullptr, *EndBB = nullptr;1075auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {1076BasicBlock *CGStartBB = CodeGenIP.getBlock();1077BasicBlock *CGEndBB =1078SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);1079assert(StartBB != nullptr && "StartBB should not be null");1080CGStartBB->getTerminator()->setSuccessor(0, StartBB);1081assert(EndBB != nullptr && "EndBB should not be null");1082EndBB->getTerminator()->setSuccessor(0, CGEndBB);1083};10841085auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,1086Value &Inner, Value *&ReplacementValue) -> InsertPointTy {1087ReplacementValue = &Inner;1088return CodeGenIP;1089};10901091auto FiniCB = [&](InsertPointTy CodeGenIP) {};10921093/// Create a sequential execution region within a merged parallel region,1094/// encapsulated in a master construct with a barrier for synchronization.1095auto CreateSequentialRegion = [&](Function *OuterFn,1096BasicBlock *OuterPredBB,1097Instruction *SeqStartI,1098Instruction *SeqEndI) {1099// Isolate the instructions of the sequential region to a separate1100// block.1101BasicBlock *ParentBB = 
SeqStartI->getParent();1102BasicBlock *SeqEndBB =1103SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);1104BasicBlock *SeqAfterBB =1105SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);1106BasicBlock *SeqStartBB =1107SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");11081109assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&1110"Expected a different CFG");1111const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();1112ParentBB->getTerminator()->eraseFromParent();11131114auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {1115BasicBlock *CGStartBB = CodeGenIP.getBlock();1116BasicBlock *CGEndBB =1117SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);1118assert(SeqStartBB != nullptr && "SeqStartBB should not be null");1119CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);1120assert(SeqEndBB != nullptr && "SeqEndBB should not be null");1121SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);1122};1123auto FiniCB = [&](InsertPointTy CodeGenIP) {};11241125// Find outputs from the sequential region to outside users and1126// broadcast their values to them.1127for (Instruction &I : *SeqStartBB) {1128SmallPtrSet<Instruction *, 4> OutsideUsers;1129for (User *Usr : I.users()) {1130Instruction &UsrI = *cast<Instruction>(Usr);1131// Ignore outputs to LT intrinsics, code extraction for the merged1132// parallel region will fix them.1133if (UsrI.isLifetimeStartOrEnd())1134continue;11351136if (UsrI.getParent() != SeqStartBB)1137OutsideUsers.insert(&UsrI);1138}11391140if (OutsideUsers.empty())1141continue;11421143// Emit an alloca in the outer region to store the broadcasted1144// value.1145const DataLayout &DL = M.getDataLayout();1146AllocaInst *AllocaI = new AllocaInst(1147I.getType(), DL.getAllocaAddrSpace(), nullptr,1148I.getName() + ".seq.output.alloc", OuterFn->front().begin());11491150// Emit a store instruction in the sequential BB to update the1151// value.1152new StoreInst(&I, AllocaI, SeqStartBB->getTerminator()->getIterator());11531154// Emit a load instruction and replace the use of the output value1155// with it.1156for (Instruction *UsrI : OutsideUsers) {1157LoadInst *LoadI = new LoadInst(I.getType(), AllocaI,1158I.getName() + ".seq.output.load",1159UsrI->getIterator());1160UsrI->replaceUsesOfWith(&I, LoadI);1161}1162}11631164OpenMPIRBuilder::LocationDescription Loc(1165InsertPointTy(ParentBB, ParentBB->end()), DL);1166InsertPointTy SeqAfterIP =1167OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);11681169OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);11701171BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());11721173LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn1174<< "\n");1175};11761177// Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all1178// contained in BB and only separated by instructions that can be1179// redundantly executed in parallel. The block BB is split before the first1180// call (in MergableCIs) and after the last so the entire region we merge1181// into a single parallel region is contained in a single basic block1182// without any other instructions. 
We use the OpenMPIRBuilder to outline1183// that block and call the resulting function via __kmpc_fork_call.1184auto Merge = [&](const SmallVectorImpl<CallInst *> &MergableCIs,1185BasicBlock *BB) {1186// TODO: Change the interface to allow single CIs expanded, e.g, to1187// include an outer loop.1188assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");11891190auto Remark = [&](OptimizationRemark OR) {1191OR << "Parallel region merged with parallel region"1192<< (MergableCIs.size() > 2 ? "s" : "") << " at ";1193for (auto *CI : llvm::drop_begin(MergableCIs)) {1194OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());1195if (CI != MergableCIs.back())1196OR << ", ";1197}1198return OR << ".";1199};12001201emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);12021203Function *OriginalFn = BB->getParent();1204LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()1205<< " parallel regions in " << OriginalFn->getName()1206<< "\n");12071208// Isolate the calls to merge in a separate block.1209EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);1210BasicBlock *AfterBB =1211SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);1212StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,1213"omp.par.merged");12141215assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");1216const DebugLoc DL = BB->getTerminator()->getDebugLoc();1217BB->getTerminator()->eraseFromParent();12181219// Create sequential regions for sequential instructions that are1220// in-between mergable parallel regions.1221for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;1222It != End; ++It) {1223Instruction *ForkCI = *It;1224Instruction *NextForkCI = *(It + 1);12251226// Continue if there are not in-between instructions.1227if (ForkCI->getNextNode() == NextForkCI)1228continue;12291230CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),1231NextForkCI->getPrevNode());1232}12331234OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),1235DL);1236IRBuilder<>::InsertPoint AllocaIP(1237&OriginalFn->getEntryBlock(),1238OriginalFn->getEntryBlock().getFirstInsertionPt());1239// Create the merged parallel region with default proc binding, to1240// avoid overriding binding settings, and without explicit cancellation.1241InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(1242Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,1243OMP_PROC_BIND_default, /* IsCancellable */ false);1244BranchInst::Create(AfterBB, AfterIP.getBlock());12451246// Perform the actual outlining.1247OMPInfoCache.OMPBuilder.finalize(OriginalFn);12481249Function *OutlinedFn = MergableCIs.front()->getCaller();12501251// Replace the __kmpc_fork_call calls with direct calls to the outlined1252// callbacks.1253SmallVector<Value *, 8> Args;1254for (auto *CI : MergableCIs) {1255Value *Callee = CI->getArgOperand(CallbackCalleeOperand);1256FunctionType *FT = OMPInfoCache.OMPBuilder.ParallelTask;1257Args.clear();1258Args.push_back(OutlinedFn->getArg(0));1259Args.push_back(OutlinedFn->getArg(1));1260for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;1261++U)1262Args.push_back(CI->getArgOperand(U));12631264CallInst *NewCI =1265CallInst::Create(FT, Callee, Args, "", CI->getIterator());1266if (CI->getDebugLoc())1267NewCI->setDebugLoc(CI->getDebugLoc());12681269// Forward parameter attributes from the callback to the callee.1270for (unsigned U = CallbackFirstArgOperand, E = CI->arg_size(); U < E;1271++U)1272for (const Attribute &A : 
CI->getAttributes().getParamAttrs(U))1273NewCI->addParamAttr(1274U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);12751276// Emit an explicit barrier to replace the implicit fork-join barrier.1277if (CI != MergableCIs.back()) {1278// TODO: Remove barrier if the merged parallel region includes the1279// 'nowait' clause.1280OMPInfoCache.OMPBuilder.createBarrier(1281InsertPointTy(NewCI->getParent(),1282NewCI->getNextNode()->getIterator()),1283OMPD_parallel);1284}12851286CI->eraseFromParent();1287}12881289assert(OutlinedFn != OriginalFn && "Outlining failed");1290CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);1291CGUpdater.reanalyzeFunction(*OriginalFn);12921293NumOpenMPParallelRegionsMerged += MergableCIs.size();12941295return true;1296};12971298// Helper function that identifes sequences of1299// __kmpc_fork_call uses in a basic block.1300auto DetectPRsCB = [&](Use &U, Function &F) {1301CallInst *CI = getCallIfRegularCall(U, &RFI);1302BB2PRMap[CI->getParent()].insert(CI);13031304return false;1305};13061307BB2PRMap.clear();1308RFI.foreachUse(SCC, DetectPRsCB);1309SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;1310// Find mergable parallel regions within a basic block that are1311// safe to merge, that is any in-between instructions can safely1312// execute in parallel after merging.1313// TODO: support merging across basic-blocks.1314for (auto &It : BB2PRMap) {1315auto &CIs = It.getSecond();1316if (CIs.size() < 2)1317continue;13181319BasicBlock *BB = It.getFirst();1320SmallVector<CallInst *, 4> MergableCIs;13211322/// Returns true if the instruction is mergable, false otherwise.1323/// A terminator instruction is unmergable by definition since merging1324/// works within a BB. Instructions before the mergable region are1325/// mergable if they are not calls to OpenMP runtime functions that may1326/// set different execution parameters for subsequent parallel regions.1327/// Instructions in-between parallel regions are mergable if they are not1328/// calls to any non-intrinsic function since that may call a non-mergable1329/// OpenMP runtime function.1330auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {1331// We do not merge across BBs, hence return false (unmergable) if the1332// instruction is a terminator.1333if (I.isTerminator())1334return false;13351336if (!isa<CallInst>(&I))1337return true;13381339CallInst *CI = cast<CallInst>(&I);1340if (IsBeforeMergableRegion) {1341Function *CalledFunction = CI->getCalledFunction();1342if (!CalledFunction)1343return false;1344// Return false (unmergable) if the call before the parallel1345// region calls an explicit affinity (proc_bind) or number of1346// threads (num_threads) compiler-generated function. Those settings1347// may be incompatible with following parallel regions.1348// TODO: ICV tracking to detect compatibility.1349for (const auto &RFI : UnmergableCallsInfo) {1350if (CalledFunction == RFI.Declaration)1351return false;1352}1353} else {1354// Return false (unmergable) if there is a call instruction1355// in-between parallel regions when it is not an intrinsic. 
It1356// may call an unmergable OpenMP runtime function in its callpath.1357// TODO: Keep track of possible OpenMP calls in the callpath.1358if (!isa<IntrinsicInst>(CI))1359return false;1360}13611362return true;1363};1364// Find maximal number of parallel region CIs that are safe to merge.1365for (auto It = BB->begin(), End = BB->end(); It != End;) {1366Instruction &I = *It;1367++It;13681369if (CIs.count(&I)) {1370MergableCIs.push_back(cast<CallInst>(&I));1371continue;1372}13731374// Continue expanding if the instruction is mergable.1375if (IsMergable(I, MergableCIs.empty()))1376continue;13771378// Forward the instruction iterator to skip the next parallel region1379// since there is an unmergable instruction which can affect it.1380for (; It != End; ++It) {1381Instruction &SkipI = *It;1382if (CIs.count(&SkipI)) {1383LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI1384<< " due to " << I << "\n");1385++It;1386break;1387}1388}13891390// Store mergable regions found.1391if (MergableCIs.size() > 1) {1392MergableCIsVector.push_back(MergableCIs);1393LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()1394<< " parallel regions in block " << BB->getName()1395<< " of function " << BB->getParent()->getName()1396<< "\n";);1397}13981399MergableCIs.clear();1400}14011402if (!MergableCIsVector.empty()) {1403Changed = true;14041405for (auto &MergableCIs : MergableCIsVector)1406Merge(MergableCIs, BB);1407MergableCIsVector.clear();1408}1409}14101411if (Changed) {1412/// Re-collect use for fork calls, emitted barrier calls, and1413/// any emitted master/end_master calls.1414OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);1415OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);1416OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);1417OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);1418}14191420return Changed;1421}14221423/// Try to delete parallel regions if possible.1424bool deleteParallelRegions() {1425const unsigned CallbackCalleeOperand = 2;14261427OMPInformationCache::RuntimeFunctionInfo &RFI =1428OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];14291430if (!RFI.Declaration)1431return false;14321433bool Changed = false;1434auto DeleteCallCB = [&](Use &U, Function &) {1435CallInst *CI = getCallIfRegularCall(U);1436if (!CI)1437return false;1438auto *Fn = dyn_cast<Function>(1439CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());1440if (!Fn)1441return false;1442if (!Fn->onlyReadsMemory())1443return false;1444if (!Fn->hasFnAttribute(Attribute::WillReturn))1445return false;14461447LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "1448<< CI->getCaller()->getName() << "\n");14491450auto Remark = [&](OptimizationRemark OR) {1451return OR << "Removing parallel region with no side-effects.";1452};1453emitRemark<OptimizationRemark>(CI, "OMP160", Remark);14541455CI->eraseFromParent();1456Changed = true;1457++NumOpenMPParallelRegionsDeleted;1458return true;1459};14601461RFI.foreachUse(SCC, DeleteCallCB);14621463return Changed;1464}14651466/// Try to eliminate runtime calls by reusing existing ones.1467bool deduplicateRuntimeCalls() {1468bool Changed = false;14691470RuntimeFunction DeduplicableRuntimeCallIDs[] = 
{1471OMPRTL_omp_get_num_threads,1472OMPRTL_omp_in_parallel,1473OMPRTL_omp_get_cancellation,1474OMPRTL_omp_get_supported_active_levels,1475OMPRTL_omp_get_level,1476OMPRTL_omp_get_ancestor_thread_num,1477OMPRTL_omp_get_team_size,1478OMPRTL_omp_get_active_level,1479OMPRTL_omp_in_final,1480OMPRTL_omp_get_proc_bind,1481OMPRTL_omp_get_num_places,1482OMPRTL_omp_get_num_procs,1483OMPRTL_omp_get_place_num,1484OMPRTL_omp_get_partition_num_places,1485OMPRTL_omp_get_partition_place_nums};14861487// Global-tid is handled separately.1488SmallSetVector<Value *, 16> GTIdArgs;1489collectGlobalThreadIdArguments(GTIdArgs);1490LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()1491<< " global thread ID arguments\n");14921493for (Function *F : SCC) {1494for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)1495Changed |= deduplicateRuntimeCalls(1496*F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);14971498// __kmpc_global_thread_num is special as we can replace it with an1499// argument in enough cases to make it worth trying.1500Value *GTIdArg = nullptr;1501for (Argument &Arg : F->args())1502if (GTIdArgs.count(&Arg)) {1503GTIdArg = &Arg;1504break;1505}1506Changed |= deduplicateRuntimeCalls(1507*F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);1508}15091510return Changed;1511}15121513/// Tries to remove known runtime symbols that are optional from the module.1514bool removeRuntimeSymbols() {1515// The RPC client symbol is defined in `libc` and indicates that something1516// required an RPC server. If its users were all optimized out then we can1517// safely remove it.1518// TODO: This should be somewhere more common in the future.1519if (GlobalVariable *GV = M.getNamedGlobal("__llvm_libc_rpc_client")) {1520if (!GV->getType()->isPointerTy())1521return false;15221523Constant *C = GV->getInitializer();1524if (!C)1525return false;15261527// Check to see if the only user of the RPC client is the external handle.1528GlobalVariable *Client = dyn_cast<GlobalVariable>(C->stripPointerCasts());1529if (!Client || Client->getNumUses() > 1 ||1530Client->user_back() != GV->getInitializer())1531return false;15321533Client->replaceAllUsesWith(PoisonValue::get(Client->getType()));1534Client->eraseFromParent();15351536GV->replaceAllUsesWith(PoisonValue::get(GV->getType()));1537GV->eraseFromParent();15381539return true;1540}1541return false;1542}15431544/// Tries to hide the latency of runtime calls that involve host to1545/// device memory transfers by splitting them into their "issue" and "wait"1546/// versions. The "issue" is moved upwards as much as possible. The "wait" is1547/// moved downards as much as possible. The "issue" issues the memory transfer1548/// asynchronously, returning a handle. 
The "wait" waits in the returned1549/// handle for the memory transfer to finish.1550bool hideMemTransfersLatency() {1551auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];1552bool Changed = false;1553auto SplitMemTransfers = [&](Use &U, Function &Decl) {1554auto *RTCall = getCallIfRegularCall(U, &RFI);1555if (!RTCall)1556return false;15571558OffloadArray OffloadArrays[3];1559if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))1560return false;15611562LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));15631564// TODO: Check if can be moved upwards.1565bool WasSplit = false;1566Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);1567if (WaitMovementPoint)1568WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);15691570Changed |= WasSplit;1571return WasSplit;1572};1573if (OMPInfoCache.runtimeFnsAvailable(1574{OMPRTL___tgt_target_data_begin_mapper_issue,1575OMPRTL___tgt_target_data_begin_mapper_wait}))1576RFI.foreachUse(SCC, SplitMemTransfers);15771578return Changed;1579}15801581void analysisGlobalization() {1582auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];15831584auto CheckGlobalization = [&](Use &U, Function &Decl) {1585if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {1586auto Remark = [&](OptimizationRemarkMissed ORM) {1587return ORM1588<< "Found thread data sharing on the GPU. "1589<< "Expect degraded performance due to data globalization.";1590};1591emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);1592}15931594return false;1595};15961597RFI.foreachUse(SCC, CheckGlobalization);1598}15991600/// Maps the values stored in the offload arrays passed as arguments to1601/// \p RuntimeCall into the offload arrays in \p OAs.1602bool getValuesInOffloadArrays(CallInst &RuntimeCall,1603MutableArrayRef<OffloadArray> OAs) {1604assert(OAs.size() == 3 && "Need space for three offload arrays!");16051606// A runtime call that involves memory offloading looks something like:1607// call void @__tgt_target_data_begin_mapper(arg0, arg1,1608// i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,1609// ...)1610// So, the idea is to access the allocas that allocate space for these1611// offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.1612// Therefore:1613// i8** %offload_baseptrs.1614Value *BasePtrsArg =1615RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);1616// i8** %offload_ptrs.1617Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);1618// i8** %offload_sizes.1619Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);16201621// Get values stored in **offload_baseptrs.1622auto *V = getUnderlyingObject(BasePtrsArg);1623if (!isa<AllocaInst>(V))1624return false;1625auto *BasePtrsArray = cast<AllocaInst>(V);1626if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))1627return false;16281629// Get values stored in **offload_baseptrs.1630V = getUnderlyingObject(PtrsArg);1631if (!isa<AllocaInst>(V))1632return false;1633auto *PtrsArray = cast<AllocaInst>(V);1634if (!OAs[1].initialize(*PtrsArray, RuntimeCall))1635return false;16361637// Get values stored in **offload_sizes.1638V = getUnderlyingObject(SizesArg);1639// If it's a [constant] global array don't analyze it.1640if (isa<GlobalValue>(V))1641return isa<Constant>(V);1642if (!isa<AllocaInst>(V))1643return false;16441645auto *SizesArray = cast<AllocaInst>(V);1646if (!OAs[2].initialize(*SizesArray, RuntimeCall))1647return false;16481649return true;1650}16511652/// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.1653/// For 
now this is a way to test that the function getValuesInOffloadArrays1654/// is working properly.1655/// TODO: Move this to a unittest when unittests are available for OpenMPOpt.1656void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {1657assert(OAs.size() == 3 && "There are three offload arrays to debug!");16581659LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");1660std::string ValuesStr;1661raw_string_ostream Printer(ValuesStr);1662std::string Separator = " --- ";16631664for (auto *BP : OAs[0].StoredValues) {1665BP->print(Printer);1666Printer << Separator;1667}1668LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << ValuesStr << "\n");1669ValuesStr.clear();16701671for (auto *P : OAs[1].StoredValues) {1672P->print(Printer);1673Printer << Separator;1674}1675LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << ValuesStr << "\n");1676ValuesStr.clear();16771678for (auto *S : OAs[2].StoredValues) {1679S->print(Printer);1680Printer << Separator;1681}1682LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << ValuesStr << "\n");1683}16841685/// Returns the instruction where the "wait" counterpart \p RuntimeCall can be1686/// moved. Returns nullptr if the movement is not possible, or not worth it.1687Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {1688// FIXME: This traverses only the BasicBlock where RuntimeCall is.1689// Make it traverse the CFG.16901691Instruction *CurrentI = &RuntimeCall;1692bool IsWorthIt = false;1693while ((CurrentI = CurrentI->getNextNode())) {16941695// TODO: Once we detect the regions to be offloaded we should use the1696// alias analysis manager to check if CurrentI may modify one of1697// the offloaded regions.1698if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {1699if (IsWorthIt)1700return CurrentI;17011702return nullptr;1703}17041705// FIXME: For now if we move it over anything without side effect1706// is worth it.1707IsWorthIt = true;1708}17091710// Return end of BasicBlock.1711return RuntimeCall.getParent()->getTerminator();1712}17131714/// Splits \p RuntimeCall into its "issue" and "wait" counterparts.1715bool splitTargetDataBeginRTC(CallInst &RuntimeCall,1716Instruction &WaitMovementPoint) {1717// Create stack allocated handle (__tgt_async_info) at the beginning of the1718// function. 
  /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
  bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
                               Instruction &WaitMovementPoint) {
    // Create a stack-allocated handle (__tgt_async_info) at the beginning of
    // the function. It is used to store information about the async transfer,
    // allowing us to wait on it later.
    auto &IRBuilder = OMPInfoCache.OMPBuilder;
    Function *F = RuntimeCall.getCaller();
    BasicBlock &Entry = F->getEntryBlock();
    IRBuilder.Builder.SetInsertPoint(&Entry,
                                     Entry.getFirstNonPHIOrDbgOrAlloca());
    Value *Handle = IRBuilder.Builder.CreateAlloca(
        IRBuilder.AsyncInfo, /*ArraySize=*/nullptr, "handle");
    Handle =
        IRBuilder.Builder.CreateAddrSpaceCast(Handle, IRBuilder.AsyncInfoPtr);

    // Add "issue" runtime call declaration:
    // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
    //   i8**, i8**, i64*, i64*)
    FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_issue);

    // Replace the RuntimeCall call site with its asynchronous version.
    SmallVector<Value *, 16> Args;
    for (auto &Arg : RuntimeCall.args())
      Args.push_back(Arg.get());
    Args.push_back(Handle);

    CallInst *IssueCallsite = CallInst::Create(IssueDecl, Args, /*NameStr=*/"",
                                               RuntimeCall.getIterator());
    OMPInfoCache.setCallingConvention(IssueDecl, IssueCallsite);
    RuntimeCall.eraseFromParent();

    // Add "wait" runtime call declaration:
    // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
    FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
        M, OMPRTL___tgt_target_data_begin_mapper_wait);

    Value *WaitParams[2] = {
        IssueCallsite->getArgOperand(
            OffloadArray::DeviceIDArgNum), // device_id.
        Handle                             // handle to wait on.
    };
    CallInst *WaitCallsite = CallInst::Create(
        WaitDecl, WaitParams, /*NameStr=*/"", WaitMovementPoint.getIterator());
    OMPInfoCache.setCallingConvention(WaitDecl, WaitCallsite);

    return true;
  }

  static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
                                    bool GlobalOnly, bool &SingleChoice) {
    if (CurrentIdent == NextIdent)
      return CurrentIdent;

    // TODO: Figure out how to actually combine multiple debug locations. For
    //       now we just keep an existing one if there is a single choice.
    if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
      SingleChoice = !CurrentIdent;
      return NextIdent;
    }
    return nullptr;
  }
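  // Note on the fold above: combinedIdentStruct is applied over all ident_t*
  // arguments seen at the call sites of a runtime function (see
  // getCombinedIdentFromCallUsesIn below). Roughly: identical idents are
  // kept, the first global ident seen is kept as long as it stays the only
  // candidate, and in all other cases Ident/SingleChoice force a fresh
  // default ident to be built from scratch.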
  /// Return a `struct ident_t*` value that represents the ones used in the
  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
  /// return value we create one from scratch. We also do not yet combine
  /// information, e.g., the source locations, see combinedIdentStruct.
  Value *
  getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
                                 Function &F, bool GlobalOnly) {
    bool SingleChoice = true;
    Value *Ident = nullptr;
    auto CombineIdentStruct = [&](Use &U, Function &Caller) {
      CallInst *CI = getCallIfRegularCall(U, &RFI);
      if (!CI || &F != &Caller)
        return false;
      Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
                                  /* GlobalOnly */ true, SingleChoice);
      return false;
    };
    RFI.foreachUse(SCC, CombineIdentStruct);

    if (!Ident || !SingleChoice) {
      // The IRBuilder uses the insertion block to get to the module, this is
      // unfortunate but we work around it for now.
      if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
        OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
            &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
      // TODO: Use the debug locations of the calls instead.
      uint32_t SrcLocStrSize;
      Constant *Loc =
          OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
      Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc, SrcLocStrSize);
    }
    return Ident;
  }

  /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
  /// \p ReplVal if given.
  bool deduplicateRuntimeCalls(Function &F,
                               OMPInformationCache::RuntimeFunctionInfo &RFI,
                               Value *ReplVal = nullptr) {
    auto *UV = RFI.getUseVector(F);
    if (!UV || UV->size() + (ReplVal != nullptr) < 2)
      return false;

    LLVM_DEBUG(
        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value\n" : "\n") << "\n");

    assert((!ReplVal || (isa<Argument>(ReplVal) &&
                         cast<Argument>(ReplVal)->getParent() == &F)) &&
           "Unexpected replacement value!");

    // TODO: Use dominance to find a good position instead.
    auto CanBeMoved = [this](CallBase &CB) {
      unsigned NumArgs = CB.arg_size();
      if (NumArgs == 0)
        return true;
      if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
        return false;
      for (unsigned U = 1; U < NumArgs; ++U)
        if (isa<Instruction>(CB.getArgOperand(U)))
          return false;
      return true;
    };

    if (!ReplVal) {
      auto *DT =
          OMPInfoCache.getAnalysisResultForFunction<DominatorTreeAnalysis>(F);
      if (!DT)
        return false;
      Instruction *IP = nullptr;
      for (Use *U : *UV) {
        if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
          if (IP)
            IP = DT->findNearestCommonDominator(IP, CI);
          else
            IP = CI;
          if (!CanBeMoved(*CI))
            continue;
          if (!ReplVal)
            ReplVal = CI;
        }
      }
      if (!ReplVal)
        return false;
      assert(IP && "Expected insertion point!");
      cast<Instruction>(ReplVal)->moveBefore(IP);
    }

    // If we use a call as a replacement value we need to make sure the ident is
    // valid at the new location.
For now we just pick a global one, either1870// existing and used by one of the calls, or created from scratch.1871if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {1872if (!CI->arg_empty() &&1873CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {1874Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,1875/* GlobalOnly */ true);1876CI->setArgOperand(0, Ident);1877}1878}18791880bool Changed = false;1881auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {1882CallInst *CI = getCallIfRegularCall(U, &RFI);1883if (!CI || CI == ReplVal || &F != &Caller)1884return false;1885assert(CI->getCaller() == &F && "Unexpected call!");18861887auto Remark = [&](OptimizationRemark OR) {1888return OR << "OpenMP runtime call "1889<< ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";1890};1891if (CI->getDebugLoc())1892emitRemark<OptimizationRemark>(CI, "OMP170", Remark);1893else1894emitRemark<OptimizationRemark>(&F, "OMP170", Remark);18951896CI->replaceAllUsesWith(ReplVal);1897CI->eraseFromParent();1898++NumOpenMPRuntimeCallsDeduplicated;1899Changed = true;1900return true;1901};1902RFI.foreachUse(SCC, ReplaceAndDeleteCB);19031904return Changed;1905}19061907/// Collect arguments that represent the global thread id in \p GTIdArgs.1908void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> >IdArgs) {1909// TODO: Below we basically perform a fixpoint iteration with a pessimistic1910// initialization. We could define an AbstractAttribute instead and1911// run the Attributor here once it can be run as an SCC pass.19121913// Helper to check the argument \p ArgNo at all call sites of \p F for1914// a GTId.1915auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {1916if (!F.hasLocalLinkage())1917return false;1918for (Use &U : F.uses()) {1919if (CallInst *CI = getCallIfRegularCall(U)) {1920Value *ArgOp = CI->getArgOperand(ArgNo);1921if (CI == &RefCI || GTIdArgs.count(ArgOp) ||1922getCallIfRegularCall(1923*ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))1924continue;1925}1926return false;1927}1928return true;1929};19301931// Helper to identify uses of a GTId as GTId arguments.1932auto AddUserArgs = [&](Value >Id) {1933for (Use &U : GTId.uses())1934if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))1935if (CI->isArgOperand(&U))1936if (Function *Callee = CI->getCalledFunction())1937if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))1938GTIdArgs.insert(Callee->getArg(U.getOperandNo()));1939};19401941// The argument users of __kmpc_global_thread_num calls are GTIds.1942OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =1943OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];19441945GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {1946if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))1947AddUserArgs(*CI);1948return false;1949});19501951// Transitively search for more arguments by looking at the users of the1952// ones we know already. 
During the search the GTIdArgs vector is extended1953// so we cannot cache the size nor can we use a range based for.1954for (unsigned U = 0; U < GTIdArgs.size(); ++U)1955AddUserArgs(*GTIdArgs[U]);1956}19571958/// Kernel (=GPU) optimizations and utility functions1959///1960///{{19611962/// Cache to remember the unique kernel for a function.1963DenseMap<Function *, std::optional<Kernel>> UniqueKernelMap;19641965/// Find the unique kernel that will execute \p F, if any.1966Kernel getUniqueKernelFor(Function &F);19671968/// Find the unique kernel that will execute \p I, if any.1969Kernel getUniqueKernelFor(Instruction &I) {1970return getUniqueKernelFor(*I.getFunction());1971}19721973/// Rewrite the device (=GPU) code state machine create in non-SPMD mode in1974/// the cases we can avoid taking the address of a function.1975bool rewriteDeviceCodeStateMachine();19761977///1978///}}19791980/// Emit a remark generically1981///1982/// This template function can be used to generically emit a remark. The1983/// RemarkKind should be one of the following:1984/// - OptimizationRemark to indicate a successful optimization attempt1985/// - OptimizationRemarkMissed to report a failed optimization attempt1986/// - OptimizationRemarkAnalysis to provide additional information about an1987/// optimization attempt1988///1989/// The remark is built using a callback function provided by the caller that1990/// takes a RemarkKind as input and returns a RemarkKind.1991template <typename RemarkKind, typename RemarkCallBack>1992void emitRemark(Instruction *I, StringRef RemarkName,1993RemarkCallBack &&RemarkCB) const {1994Function *F = I->getParent()->getParent();1995auto &ORE = OREGetter(F);19961997if (RemarkName.starts_with("OMP"))1998ORE.emit([&]() {1999return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))2000<< " [" << RemarkName << "]";2001});2002else2003ORE.emit(2004[&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });2005}20062007/// Emit a remark on a function.2008template <typename RemarkKind, typename RemarkCallBack>2009void emitRemark(Function *F, StringRef RemarkName,2010RemarkCallBack &&RemarkCB) const {2011auto &ORE = OREGetter(F);20122013if (RemarkName.starts_with("OMP"))2014ORE.emit([&]() {2015return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))2016<< " [" << RemarkName << "]";2017});2018else2019ORE.emit(2020[&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });2021}20222023/// The underlying module.2024Module &M;20252026/// The SCC we are operating on.2027SmallVectorImpl<Function *> &SCC;20282029/// Callback to update the call graph, the first argument is a removed call,2030/// the second an optional replacement call.2031CallGraphUpdater &CGUpdater;20322033/// Callback to get an OptimizationRemarkEmitter from a Function *2034OptimizationRemarkGetter OREGetter;20352036/// OpenMP-specific information cache. 
Also Used for Attributor runs.2037OMPInformationCache &OMPInfoCache;20382039/// Attributor instance.2040Attributor &A;20412042/// Helper function to run Attributor on SCC.2043bool runAttributor(bool IsModulePass) {2044if (SCC.empty())2045return false;20462047registerAAs(IsModulePass);20482049ChangeStatus Changed = A.run();20502051LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()2052<< " functions, result: " << Changed << ".\n");20532054if (Changed == ChangeStatus::CHANGED)2055OMPInfoCache.invalidateAnalyses();20562057return Changed == ChangeStatus::CHANGED;2058}20592060void registerFoldRuntimeCall(RuntimeFunction RF);20612062/// Populate the Attributor with abstract attribute opportunities in the2063/// functions.2064void registerAAs(bool IsModulePass);20652066public:2067/// Callback to register AAs for live functions, including internal functions2068/// marked live during the traversal.2069static void registerAAsForFunction(Attributor &A, const Function &F);2070};20712072Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {2073if (OMPInfoCache.CGSCC && !OMPInfoCache.CGSCC->empty() &&2074!OMPInfoCache.CGSCC->contains(&F))2075return nullptr;20762077// Use a scope to keep the lifetime of the CachedKernel short.2078{2079std::optional<Kernel> &CachedKernel = UniqueKernelMap[&F];2080if (CachedKernel)2081return *CachedKernel;20822083// TODO: We should use an AA to create an (optimistic and callback2084// call-aware) call graph. For now we stick to simple patterns that2085// are less powerful, basically the worst fixpoint.2086if (isOpenMPKernel(F)) {2087CachedKernel = Kernel(&F);2088return *CachedKernel;2089}20902091CachedKernel = nullptr;2092if (!F.hasLocalLinkage()) {20932094// See https://openmp.llvm.org/remarks/OptimizationRemarks.html2095auto Remark = [&](OptimizationRemarkAnalysis ORA) {2096return ORA << "Potentially unknown OpenMP target region caller.";2097};2098emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);20992100return nullptr;2101}2102}21032104auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {2105if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {2106// Allow use in equality comparisons.2107if (Cmp->isEquality())2108return getUniqueKernelFor(*Cmp);2109return nullptr;2110}2111if (auto *CB = dyn_cast<CallBase>(U.getUser())) {2112// Allow direct calls.2113if (CB->isCallee(&U))2114return getUniqueKernelFor(*CB);21152116OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =2117OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];2118// Allow the use in __kmpc_parallel_51 calls.2119if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))2120return getUniqueKernelFor(*CB);2121return nullptr;2122}2123// Disallow every other use.2124return nullptr;2125};21262127// TODO: In the future we want to track more than just a unique kernel.2128SmallPtrSet<Kernel, 2> PotentialKernels;2129OMPInformationCache::foreachUse(F, [&](const Use &U) {2130PotentialKernels.insert(GetUniqueKernelForUse(U));2131});21322133Kernel K = nullptr;2134if (PotentialKernels.size() == 1)2135K = *PotentialKernels.begin();21362137// Cache the result.2138UniqueKernelMap[&F] = K;21392140return K;2141}21422143bool OpenMPOpt::rewriteDeviceCodeStateMachine() {2144OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =2145OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];21462147bool Changed = false;2148if (!KernelParallelRFI)2149return Changed;21502151// If we have disabled state machine changes, exit2152if (DisableOpenMPOptStateMachineRewrite)2153return Changed;21542155for (Function *F : 
SCC) {21562157// Check if the function is a use in a __kmpc_parallel_51 call at2158// all.2159bool UnknownUse = false;2160bool KernelParallelUse = false;2161unsigned NumDirectCalls = 0;21622163SmallVector<Use *, 2> ToBeReplacedStateMachineUses;2164OMPInformationCache::foreachUse(*F, [&](Use &U) {2165if (auto *CB = dyn_cast<CallBase>(U.getUser()))2166if (CB->isCallee(&U)) {2167++NumDirectCalls;2168return;2169}21702171if (isa<ICmpInst>(U.getUser())) {2172ToBeReplacedStateMachineUses.push_back(&U);2173return;2174}21752176// Find wrapper functions that represent parallel kernels.2177CallInst *CI =2178OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);2179const unsigned int WrapperFunctionArgNo = 6;2180if (!KernelParallelUse && CI &&2181CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {2182KernelParallelUse = true;2183ToBeReplacedStateMachineUses.push_back(&U);2184return;2185}2186UnknownUse = true;2187});21882189// Do not emit a remark if we haven't seen a __kmpc_parallel_512190// use.2191if (!KernelParallelUse)2192continue;21932194// If this ever hits, we should investigate.2195// TODO: Checking the number of uses is not a necessary restriction and2196// should be lifted.2197if (UnknownUse || NumDirectCalls != 1 ||2198ToBeReplacedStateMachineUses.size() > 2) {2199auto Remark = [&](OptimizationRemarkAnalysis ORA) {2200return ORA << "Parallel region is used in "2201<< (UnknownUse ? "unknown" : "unexpected")2202<< " ways. Will not attempt to rewrite the state machine.";2203};2204emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);2205continue;2206}22072208// Even if we have __kmpc_parallel_51 calls, we (for now) give2209// up if the function is not called from a unique kernel.2210Kernel K = getUniqueKernelFor(*F);2211if (!K) {2212auto Remark = [&](OptimizationRemarkAnalysis ORA) {2213return ORA << "Parallel region is not called from a unique kernel. "2214"Will not attempt to rewrite the state machine.";2215};2216emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);2217continue;2218}22192220// We now know F is a parallel body function called only from the kernel K.2221// We also identified the state machine uses in which we replace the2222// function pointer by a new global symbol for identification purposes. 
This2223// ensures only direct calls to the function are left.22242225Module &M = *F->getParent();2226Type *Int8Ty = Type::getInt8Ty(M.getContext());22272228auto *ID = new GlobalVariable(2229M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,2230UndefValue::get(Int8Ty), F->getName() + ".ID");22312232for (Use *U : ToBeReplacedStateMachineUses)2233U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(2234ID, U->get()->getType()));22352236++NumOpenMPParallelRegionsReplacedInGPUStateMachine;22372238Changed = true;2239}22402241return Changed;2242}22432244/// Abstract Attribute for tracking ICV values.2245struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {2246using Base = StateWrapper<BooleanState, AbstractAttribute>;2247AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}22482249/// Returns true if value is assumed to be tracked.2250bool isAssumedTracked() const { return getAssumed(); }22512252/// Returns true if value is known to be tracked.2253bool isKnownTracked() const { return getAssumed(); }22542255/// Create an abstract attribute biew for the position \p IRP.2256static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);22572258/// Return the value with which \p I can be replaced for specific \p ICV.2259virtual std::optional<Value *> getReplacementValue(InternalControlVar ICV,2260const Instruction *I,2261Attributor &A) const {2262return std::nullopt;2263}22642265/// Return an assumed unique ICV value if a single candidate is found. If2266/// there cannot be one, return a nullptr. If it is not clear yet, return2267/// std::nullopt.2268virtual std::optional<Value *>2269getUniqueReplacementValue(InternalControlVar ICV) const = 0;22702271// Currently only nthreads is being tracked.2272// this array will only grow with time.2273InternalControlVar TrackableICVs[1] = {ICV_nthreads};22742275/// See AbstractAttribute::getName()2276const std::string getName() const override { return "AAICVTracker"; }22772278/// See AbstractAttribute::getIdAddr()2279const char *getIdAddr() const override { return &ID; }22802281/// This function should return true if the type of the \p AA is AAICVTracker2282static bool classof(const AbstractAttribute *AA) {2283return (AA->getIdAddr() == &ID);2284}22852286static const char ID;2287};22882289struct AAICVTrackerFunction : public AAICVTracker {2290AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)2291: AAICVTracker(IRP, A) {}22922293// FIXME: come up with better string.2294const std::string getAsStr(Attributor *) const override {2295return "ICVTrackerFunction";2296}22972298// FIXME: come up with some stats.2299void trackStatistics() const override {}23002301/// We don't manifest anything for this AA.2302ChangeStatus manifest(Attributor &A) override {2303return ChangeStatus::UNCHANGED;2304}23052306// Map of ICV to their values at specific program point.2307EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,2308InternalControlVar::ICV___last>2309ICVReplacementValuesMap;23102311ChangeStatus updateImpl(Attributor &A) override {2312ChangeStatus HasChanged = ChangeStatus::UNCHANGED;23132314Function *F = getAnchorScope();23152316auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());23172318for (InternalControlVar ICV : TrackableICVs) {2319auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];23202321auto &ValuesMap = ICVReplacementValuesMap[ICV];2322auto TrackValues = [&](Use &U, Function &) {2323CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);2324if 
(!CI)2325return false;23262327// FIXME: handle setters with more that 1 arguments.2328/// Track new value.2329if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)2330HasChanged = ChangeStatus::CHANGED;23312332return false;2333};23342335auto CallCheck = [&](Instruction &I) {2336std::optional<Value *> ReplVal = getValueForCall(A, I, ICV);2337if (ReplVal && ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)2338HasChanged = ChangeStatus::CHANGED;23392340return true;2341};23422343// Track all changes of an ICV.2344SetterRFI.foreachUse(TrackValues, F);23452346bool UsedAssumedInformation = false;2347A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},2348UsedAssumedInformation,2349/* CheckBBLivenessOnly */ true);23502351/// TODO: Figure out a way to avoid adding entry in2352/// ICVReplacementValuesMap2353Instruction *Entry = &F->getEntryBlock().front();2354if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))2355ValuesMap.insert(std::make_pair(Entry, nullptr));2356}23572358return HasChanged;2359}23602361/// Helper to check if \p I is a call and get the value for it if it is2362/// unique.2363std::optional<Value *> getValueForCall(Attributor &A, const Instruction &I,2364InternalControlVar &ICV) const {23652366const auto *CB = dyn_cast<CallBase>(&I);2367if (!CB || CB->hasFnAttr("no_openmp") ||2368CB->hasFnAttr("no_openmp_routines"))2369return std::nullopt;23702371auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());2372auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];2373auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];2374Function *CalledFunction = CB->getCalledFunction();23752376// Indirect call, assume ICV changes.2377if (CalledFunction == nullptr)2378return nullptr;2379if (CalledFunction == GetterRFI.Declaration)2380return std::nullopt;2381if (CalledFunction == SetterRFI.Declaration) {2382if (ICVReplacementValuesMap[ICV].count(&I))2383return ICVReplacementValuesMap[ICV].lookup(&I);23842385return nullptr;2386}23872388// Since we don't know, assume it changes the ICV.2389if (CalledFunction->isDeclaration())2390return nullptr;23912392const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(2393*this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);23942395if (ICVTrackingAA->isAssumedTracked()) {2396std::optional<Value *> URV =2397ICVTrackingAA->getUniqueReplacementValue(ICV);2398if (!URV || (*URV && AA::isValidAtPosition(AA::ValueAndContext(**URV, I),2399OMPInfoCache)))2400return URV;2401}24022403// If we don't know, assume it changes.2404return nullptr;2405}24062407// We don't check unique value for a function, so return std::nullopt.2408std::optional<Value *>2409getUniqueReplacementValue(InternalControlVar ICV) const override {2410return std::nullopt;2411}24122413/// Return the value with which \p I can be replaced for specific \p ICV.2414std::optional<Value *> getReplacementValue(InternalControlVar ICV,2415const Instruction *I,2416Attributor &A) const override {2417const auto &ValuesMap = ICVReplacementValuesMap[ICV];2418if (ValuesMap.count(I))2419return ValuesMap.lookup(I);24202421SmallVector<const Instruction *, 16> Worklist;2422SmallPtrSet<const Instruction *, 16> Visited;2423Worklist.push_back(I);24242425std::optional<Value *> ReplVal;24262427while (!Worklist.empty()) {2428const Instruction *CurrInst = Worklist.pop_back_val();2429if (!Visited.insert(CurrInst).second)2430continue;24312432const BasicBlock *CurrBB = CurrInst->getParent();24332434// Go up and look for all potential 
setters/calls that might change the2435// ICV.2436while ((CurrInst = CurrInst->getPrevNode())) {2437if (ValuesMap.count(CurrInst)) {2438std::optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);2439// Unknown value, track new.2440if (!ReplVal) {2441ReplVal = NewReplVal;2442break;2443}24442445// If we found a new value, we can't know the icv value anymore.2446if (NewReplVal)2447if (ReplVal != NewReplVal)2448return nullptr;24492450break;2451}24522453std::optional<Value *> NewReplVal = getValueForCall(A, *CurrInst, ICV);2454if (!NewReplVal)2455continue;24562457// Unknown value, track new.2458if (!ReplVal) {2459ReplVal = NewReplVal;2460break;2461}24622463// if (NewReplVal.hasValue())2464// We found a new value, we can't know the icv value anymore.2465if (ReplVal != NewReplVal)2466return nullptr;2467}24682469// If we are in the same BB and we have a value, we are done.2470if (CurrBB == I->getParent() && ReplVal)2471return ReplVal;24722473// Go through all predecessors and add terminators for analysis.2474for (const BasicBlock *Pred : predecessors(CurrBB))2475if (const Instruction *Terminator = Pred->getTerminator())2476Worklist.push_back(Terminator);2477}24782479return ReplVal;2480}2481};24822483struct AAICVTrackerFunctionReturned : AAICVTracker {2484AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)2485: AAICVTracker(IRP, A) {}24862487// FIXME: come up with better string.2488const std::string getAsStr(Attributor *) const override {2489return "ICVTrackerFunctionReturned";2490}24912492// FIXME: come up with some stats.2493void trackStatistics() const override {}24942495/// We don't manifest anything for this AA.2496ChangeStatus manifest(Attributor &A) override {2497return ChangeStatus::UNCHANGED;2498}24992500// Map of ICV to their values at specific program point.2501EnumeratedArray<std::optional<Value *>, InternalControlVar,2502InternalControlVar::ICV___last>2503ICVReplacementValuesMap;25042505/// Return the value with which \p I can be replaced for specific \p ICV.2506std::optional<Value *>2507getUniqueReplacementValue(InternalControlVar ICV) const override {2508return ICVReplacementValuesMap[ICV];2509}25102511ChangeStatus updateImpl(Attributor &A) override {2512ChangeStatus Changed = ChangeStatus::UNCHANGED;2513const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(2514*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);25152516if (!ICVTrackingAA->isAssumedTracked())2517return indicatePessimisticFixpoint();25182519for (InternalControlVar ICV : TrackableICVs) {2520std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];2521std::optional<Value *> UniqueICVValue;25222523auto CheckReturnInst = [&](Instruction &I) {2524std::optional<Value *> NewReplVal =2525ICVTrackingAA->getReplacementValue(ICV, &I, A);25262527// If we found a second ICV value there is no unique returned value.2528if (UniqueICVValue && UniqueICVValue != NewReplVal)2529return false;25302531UniqueICVValue = NewReplVal;25322533return true;2534};25352536bool UsedAssumedInformation = false;2537if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},2538UsedAssumedInformation,2539/* CheckBBLivenessOnly */ true))2540UniqueICVValue = nullptr;25412542if (UniqueICVValue == ReplVal)2543continue;25442545ReplVal = UniqueICVValue;2546Changed = ChangeStatus::CHANGED;2547}25482549return Changed;2550}2551};25522553struct AAICVTrackerCallSite : AAICVTracker {2554AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)2555: AAICVTracker(IRP, A) {}25562557void initialize(Attributor &A) 
override {2558assert(getAnchorScope() && "Expected anchor function");25592560// We only initialize this AA for getters, so we need to know which ICV it2561// gets.2562auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());2563for (InternalControlVar ICV : TrackableICVs) {2564auto ICVInfo = OMPInfoCache.ICVs[ICV];2565auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];2566if (Getter.Declaration == getAssociatedFunction()) {2567AssociatedICV = ICVInfo.Kind;2568return;2569}2570}25712572/// Unknown ICV.2573indicatePessimisticFixpoint();2574}25752576ChangeStatus manifest(Attributor &A) override {2577if (!ReplVal || !*ReplVal)2578return ChangeStatus::UNCHANGED;25792580A.changeAfterManifest(IRPosition::inst(*getCtxI()), **ReplVal);2581A.deleteAfterManifest(*getCtxI());25822583return ChangeStatus::CHANGED;2584}25852586// FIXME: come up with better string.2587const std::string getAsStr(Attributor *) const override {2588return "ICVTrackerCallSite";2589}25902591// FIXME: come up with some stats.2592void trackStatistics() const override {}25932594InternalControlVar AssociatedICV;2595std::optional<Value *> ReplVal;25962597ChangeStatus updateImpl(Attributor &A) override {2598const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(2599*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);26002601// We don't have any information, so we assume it changes the ICV.2602if (!ICVTrackingAA->isAssumedTracked())2603return indicatePessimisticFixpoint();26042605std::optional<Value *> NewReplVal =2606ICVTrackingAA->getReplacementValue(AssociatedICV, getCtxI(), A);26072608if (ReplVal == NewReplVal)2609return ChangeStatus::UNCHANGED;26102611ReplVal = NewReplVal;2612return ChangeStatus::CHANGED;2613}26142615// Return the value with which associated value can be replaced for specific2616// \p ICV.2617std::optional<Value *>2618getUniqueReplacementValue(InternalControlVar ICV) const override {2619return ReplVal;2620}2621};26222623struct AAICVTrackerCallSiteReturned : AAICVTracker {2624AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)2625: AAICVTracker(IRP, A) {}26262627// FIXME: come up with better string.2628const std::string getAsStr(Attributor *) const override {2629return "ICVTrackerCallSiteReturned";2630}26312632// FIXME: come up with some stats.2633void trackStatistics() const override {}26342635/// We don't manifest anything for this AA.2636ChangeStatus manifest(Attributor &A) override {2637return ChangeStatus::UNCHANGED;2638}26392640// Map of ICV to their values at specific program point.2641EnumeratedArray<std::optional<Value *>, InternalControlVar,2642InternalControlVar::ICV___last>2643ICVReplacementValuesMap;26442645/// Return the value with which associated value can be replaced for specific2646/// \p ICV.2647std::optional<Value *>2648getUniqueReplacementValue(InternalControlVar ICV) const override {2649return ICVReplacementValuesMap[ICV];2650}26512652ChangeStatus updateImpl(Attributor &A) override {2653ChangeStatus Changed = ChangeStatus::UNCHANGED;2654const auto *ICVTrackingAA = A.getAAFor<AAICVTracker>(2655*this, IRPosition::returned(*getAssociatedFunction()),2656DepClassTy::REQUIRED);26572658// We don't have any information, so we assume it changes the ICV.2659if (!ICVTrackingAA->isAssumedTracked())2660return indicatePessimisticFixpoint();26612662for (InternalControlVar ICV : TrackableICVs) {2663std::optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];2664std::optional<Value *> NewReplVal =2665ICVTrackingAA->getUniqueReplacementValue(ICV);26662667if 
(ReplVal == NewReplVal)2668continue;26692670ReplVal = NewReplVal;2671Changed = ChangeStatus::CHANGED;2672}2673return Changed;2674}2675};26762677/// Determines if \p BB exits the function unconditionally itself or reaches a2678/// block that does through only unique successors.2679static bool hasFunctionEndAsUniqueSuccessor(const BasicBlock *BB) {2680if (succ_empty(BB))2681return true;2682const BasicBlock *const Successor = BB->getUniqueSuccessor();2683if (!Successor)2684return false;2685return hasFunctionEndAsUniqueSuccessor(Successor);2686}26872688struct AAExecutionDomainFunction : public AAExecutionDomain {2689AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)2690: AAExecutionDomain(IRP, A) {}26912692~AAExecutionDomainFunction() { delete RPOT; }26932694void initialize(Attributor &A) override {2695Function *F = getAnchorScope();2696assert(F && "Expected anchor function");2697RPOT = new ReversePostOrderTraversal<Function *>(F);2698}26992700const std::string getAsStr(Attributor *) const override {2701unsigned TotalBlocks = 0, InitialThreadBlocks = 0, AlignedBlocks = 0;2702for (auto &It : BEDMap) {2703if (!It.getFirst())2704continue;2705TotalBlocks++;2706InitialThreadBlocks += It.getSecond().IsExecutedByInitialThreadOnly;2707AlignedBlocks += It.getSecond().IsReachedFromAlignedBarrierOnly &&2708It.getSecond().IsReachingAlignedBarrierOnly;2709}2710return "[AAExecutionDomain] " + std::to_string(InitialThreadBlocks) + "/" +2711std::to_string(AlignedBlocks) + " of " +2712std::to_string(TotalBlocks) +2713" executed by initial thread / aligned";2714}27152716/// See AbstractAttribute::trackStatistics().2717void trackStatistics() const override {}27182719ChangeStatus manifest(Attributor &A) override {2720LLVM_DEBUG({2721for (const BasicBlock &BB : *getAnchorScope()) {2722if (!isExecutedByInitialThreadOnly(BB))2723continue;2724dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "2725<< BB.getName() << " is executed by a single thread.\n";2726}2727});27282729ChangeStatus Changed = ChangeStatus::UNCHANGED;27302731if (DisableOpenMPOptBarrierElimination)2732return Changed;27332734SmallPtrSet<CallBase *, 16> DeletedBarriers;2735auto HandleAlignedBarrier = [&](CallBase *CB) {2736const ExecutionDomainTy &ED = CB ? CEDMap[{CB, PRE}] : BEDMap[nullptr];2737if (!ED.IsReachedFromAlignedBarrierOnly ||2738ED.EncounteredNonLocalSideEffect)2739return;2740if (!ED.EncounteredAssumes.empty() && !A.isModulePass())2741return;27422743// We can remove this barrier, if it is one, or aligned barriers reaching2744// the kernel end (if CB is nullptr). Aligned barriers reaching the kernel2745// end should only be removed if the kernel end is their unique successor;2746// otherwise, they may have side-effects that aren't accounted for in the2747// kernel end in their other successors. 
If those barriers have other2748// barriers reaching them, those can be transitively removed as well as2749// long as the kernel end is also their unique successor.2750if (CB) {2751DeletedBarriers.insert(CB);2752A.deleteAfterManifest(*CB);2753++NumBarriersEliminated;2754Changed = ChangeStatus::CHANGED;2755} else if (!ED.AlignedBarriers.empty()) {2756Changed = ChangeStatus::CHANGED;2757SmallVector<CallBase *> Worklist(ED.AlignedBarriers.begin(),2758ED.AlignedBarriers.end());2759SmallSetVector<CallBase *, 16> Visited;2760while (!Worklist.empty()) {2761CallBase *LastCB = Worklist.pop_back_val();2762if (!Visited.insert(LastCB))2763continue;2764if (LastCB->getFunction() != getAnchorScope())2765continue;2766if (!hasFunctionEndAsUniqueSuccessor(LastCB->getParent()))2767continue;2768if (!DeletedBarriers.count(LastCB)) {2769++NumBarriersEliminated;2770A.deleteAfterManifest(*LastCB);2771continue;2772}2773// The final aligned barrier (LastCB) reaching the kernel end was2774// removed already. This means we can go one step further and remove2775// the barriers encoutered last before (LastCB).2776const ExecutionDomainTy &LastED = CEDMap[{LastCB, PRE}];2777Worklist.append(LastED.AlignedBarriers.begin(),2778LastED.AlignedBarriers.end());2779}2780}27812782// If we actually eliminated a barrier we need to eliminate the associated2783// llvm.assumes as well to avoid creating UB.2784if (!ED.EncounteredAssumes.empty() && (CB || !ED.AlignedBarriers.empty()))2785for (auto *AssumeCB : ED.EncounteredAssumes)2786A.deleteAfterManifest(*AssumeCB);2787};27882789for (auto *CB : AlignedBarriers)2790HandleAlignedBarrier(CB);27912792// Handle the "kernel end barrier" for kernels too.2793if (omp::isOpenMPKernel(*getAnchorScope()))2794HandleAlignedBarrier(nullptr);27952796return Changed;2797}27982799bool isNoOpFence(const FenceInst &FI) const override {2800return getState().isValidState() && !NonNoOpFences.count(&FI);2801}28022803/// Merge barrier and assumption information from \p PredED into the successor2804/// \p ED.2805void2806mergeInPredecessorBarriersAndAssumptions(Attributor &A, ExecutionDomainTy &ED,2807const ExecutionDomainTy &PredED);28082809/// Merge all information from \p PredED into the successor \p ED. 
If2810/// \p InitialEdgeOnly is set, only the initial edge will enter the block2811/// represented by \p ED from this predecessor.2812bool mergeInPredecessor(Attributor &A, ExecutionDomainTy &ED,2813const ExecutionDomainTy &PredED,2814bool InitialEdgeOnly = false);28152816/// Accumulate information for the entry block in \p EntryBBED.2817bool handleCallees(Attributor &A, ExecutionDomainTy &EntryBBED);28182819/// See AbstractAttribute::updateImpl.2820ChangeStatus updateImpl(Attributor &A) override;28212822/// Query interface, see AAExecutionDomain2823///{2824bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {2825if (!isValidState())2826return false;2827assert(BB.getParent() == getAnchorScope() && "Block is out of scope!");2828return BEDMap.lookup(&BB).IsExecutedByInitialThreadOnly;2829}28302831bool isExecutedInAlignedRegion(Attributor &A,2832const Instruction &I) const override {2833assert(I.getFunction() == getAnchorScope() &&2834"Instruction is out of scope!");2835if (!isValidState())2836return false;28372838bool ForwardIsOk = true;2839const Instruction *CurI;28402841// Check forward until a call or the block end is reached.2842CurI = &I;2843do {2844auto *CB = dyn_cast<CallBase>(CurI);2845if (!CB)2846continue;2847if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))2848return true;2849const auto &It = CEDMap.find({CB, PRE});2850if (It == CEDMap.end())2851continue;2852if (!It->getSecond().IsReachingAlignedBarrierOnly)2853ForwardIsOk = false;2854break;2855} while ((CurI = CurI->getNextNonDebugInstruction()));28562857if (!CurI && !BEDMap.lookup(I.getParent()).IsReachingAlignedBarrierOnly)2858ForwardIsOk = false;28592860// Check backward until a call or the block beginning is reached.2861CurI = &I;2862do {2863auto *CB = dyn_cast<CallBase>(CurI);2864if (!CB)2865continue;2866if (CB != &I && AlignedBarriers.contains(const_cast<CallBase *>(CB)))2867return true;2868const auto &It = CEDMap.find({CB, POST});2869if (It == CEDMap.end())2870continue;2871if (It->getSecond().IsReachedFromAlignedBarrierOnly)2872break;2873return false;2874} while ((CurI = CurI->getPrevNonDebugInstruction()));28752876// Delayed decision on the forward pass to allow aligned barrier detection2877// in the backwards traversal.2878if (!ForwardIsOk)2879return false;28802881if (!CurI) {2882const BasicBlock *BB = I.getParent();2883if (BB == &BB->getParent()->getEntryBlock())2884return BEDMap.lookup(nullptr).IsReachedFromAlignedBarrierOnly;2885if (!llvm::all_of(predecessors(BB), [&](const BasicBlock *PredBB) {2886return BEDMap.lookup(PredBB).IsReachedFromAlignedBarrierOnly;2887})) {2888return false;2889}2890}28912892// On neither traversal we found a anything but aligned barriers.2893return true;2894}28952896ExecutionDomainTy getExecutionDomain(const BasicBlock &BB) const override {2897assert(isValidState() &&2898"No request should be made against an invalid state!");2899return BEDMap.lookup(&BB);2900}2901std::pair<ExecutionDomainTy, ExecutionDomainTy>2902getExecutionDomain(const CallBase &CB) const override {2903assert(isValidState() &&2904"No request should be made against an invalid state!");2905return {CEDMap.lookup({&CB, PRE}), CEDMap.lookup({&CB, POST})};2906}2907ExecutionDomainTy getFunctionExecutionDomain() const override {2908assert(isValidState() &&2909"No request should be made against an invalid state!");2910return InterProceduralED;2911}2912///}29132914// Check if the edge into the successor block contains a condition that only2915// lets the main thread execute it.2916static bool 
isInitialThreadOnlyEdge(Attributor &A, BranchInst *Edge,2917BasicBlock &SuccessorBB) {2918if (!Edge || !Edge->isConditional())2919return false;2920if (Edge->getSuccessor(0) != &SuccessorBB)2921return false;29222923auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());2924if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())2925return false;29262927ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));2928if (!C)2929return false;29302931// Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)2932if (C->isAllOnesValue()) {2933auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));2934auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());2935auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];2936CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;2937if (!CB)2938return false;2939ConstantStruct *KernelEnvC =2940KernelInfo::getKernelEnvironementFromKernelInitCB(CB);2941ConstantInt *ExecModeC =2942KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);2943return ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC;2944}29452946if (C->isZero()) {2947// Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()2948if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))2949if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)2950return true;29512952// Match: 0 == llvm.amdgcn.workitem.id.x()2953if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))2954if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)2955return true;2956}29572958return false;2959};29602961/// Mapping containing information about the function for other AAs.2962ExecutionDomainTy InterProceduralED;29632964enum Direction { PRE = 0, POST = 1 };2965/// Mapping containing information per block.2966DenseMap<const BasicBlock *, ExecutionDomainTy> BEDMap;2967DenseMap<PointerIntPair<const CallBase *, 1, Direction>, ExecutionDomainTy>2968CEDMap;2969SmallSetVector<CallBase *, 16> AlignedBarriers;29702971ReversePostOrderTraversal<Function *> *RPOT = nullptr;29722973/// Set \p R to \V and report true if that changed \p R.2974static bool setAndRecord(bool &R, bool V) {2975bool Eq = (R == V);2976R = V;2977return !Eq;2978}29792980/// Collection of fences known to be non-no-opt. 
All fences not in this set2981/// can be assumed no-opt.2982SmallPtrSet<const FenceInst *, 8> NonNoOpFences;2983};29842985void AAExecutionDomainFunction::mergeInPredecessorBarriersAndAssumptions(2986Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED) {2987for (auto *EA : PredED.EncounteredAssumes)2988ED.addAssumeInst(A, *EA);29892990for (auto *AB : PredED.AlignedBarriers)2991ED.addAlignedBarrier(A, *AB);2992}29932994bool AAExecutionDomainFunction::mergeInPredecessor(2995Attributor &A, ExecutionDomainTy &ED, const ExecutionDomainTy &PredED,2996bool InitialEdgeOnly) {29972998bool Changed = false;2999Changed |=3000setAndRecord(ED.IsExecutedByInitialThreadOnly,3001InitialEdgeOnly || (PredED.IsExecutedByInitialThreadOnly &&3002ED.IsExecutedByInitialThreadOnly));30033004Changed |= setAndRecord(ED.IsReachedFromAlignedBarrierOnly,3005ED.IsReachedFromAlignedBarrierOnly &&3006PredED.IsReachedFromAlignedBarrierOnly);3007Changed |= setAndRecord(ED.EncounteredNonLocalSideEffect,3008ED.EncounteredNonLocalSideEffect |3009PredED.EncounteredNonLocalSideEffect);3010// Do not track assumptions and barriers as part of Changed.3011if (ED.IsReachedFromAlignedBarrierOnly)3012mergeInPredecessorBarriersAndAssumptions(A, ED, PredED);3013else3014ED.clearAssumeInstAndAlignedBarriers();3015return Changed;3016}30173018bool AAExecutionDomainFunction::handleCallees(Attributor &A,3019ExecutionDomainTy &EntryBBED) {3020SmallVector<std::pair<ExecutionDomainTy, ExecutionDomainTy>, 4> CallSiteEDs;3021auto PredForCallSite = [&](AbstractCallSite ACS) {3022const auto *EDAA = A.getAAFor<AAExecutionDomain>(3023*this, IRPosition::function(*ACS.getInstruction()->getFunction()),3024DepClassTy::OPTIONAL);3025if (!EDAA || !EDAA->getState().isValidState())3026return false;3027CallSiteEDs.emplace_back(3028EDAA->getExecutionDomain(*cast<CallBase>(ACS.getInstruction())));3029return true;3030};30313032ExecutionDomainTy ExitED;3033bool AllCallSitesKnown;3034if (A.checkForAllCallSites(PredForCallSite, *this,3035/* RequiresAllCallSites */ true,3036AllCallSitesKnown)) {3037for (const auto &[CSInED, CSOutED] : CallSiteEDs) {3038mergeInPredecessor(A, EntryBBED, CSInED);3039ExitED.IsReachingAlignedBarrierOnly &=3040CSOutED.IsReachingAlignedBarrierOnly;3041}30423043} else {3044// We could not find all predecessors, so this is either a kernel or a3045// function with external linkage (or with some other weird uses).3046if (omp::isOpenMPKernel(*getAnchorScope())) {3047EntryBBED.IsExecutedByInitialThreadOnly = false;3048EntryBBED.IsReachedFromAlignedBarrierOnly = true;3049EntryBBED.EncounteredNonLocalSideEffect = false;3050ExitED.IsReachingAlignedBarrierOnly = false;3051} else {3052EntryBBED.IsExecutedByInitialThreadOnly = false;3053EntryBBED.IsReachedFromAlignedBarrierOnly = false;3054EntryBBED.EncounteredNonLocalSideEffect = true;3055ExitED.IsReachingAlignedBarrierOnly = false;3056}3057}30583059bool Changed = false;3060auto &FnED = BEDMap[nullptr];3061Changed |= setAndRecord(FnED.IsReachedFromAlignedBarrierOnly,3062FnED.IsReachedFromAlignedBarrierOnly &3063EntryBBED.IsReachedFromAlignedBarrierOnly);3064Changed |= setAndRecord(FnED.IsReachingAlignedBarrierOnly,3065FnED.IsReachingAlignedBarrierOnly &3066ExitED.IsReachingAlignedBarrierOnly);3067Changed |= setAndRecord(FnED.IsExecutedByInitialThreadOnly,3068EntryBBED.IsExecutedByInitialThreadOnly);3069return Changed;3070}30713072ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {30733074bool Changed = false;30753076// Helper to deal with an aligned barrier encountered 
during the forward3077// traversal. \p CB is the aligned barrier, \p ED is the execution domain when3078// it was encountered.3079auto HandleAlignedBarrier = [&](CallBase &CB, ExecutionDomainTy &ED) {3080Changed |= AlignedBarriers.insert(&CB);3081// First, update the barrier ED kept in the separate CEDMap.3082auto &CallInED = CEDMap[{&CB, PRE}];3083Changed |= mergeInPredecessor(A, CallInED, ED);3084CallInED.IsReachingAlignedBarrierOnly = true;3085// Next adjust the ED we use for the traversal.3086ED.EncounteredNonLocalSideEffect = false;3087ED.IsReachedFromAlignedBarrierOnly = true;3088// Aligned barrier collection has to come last.3089ED.clearAssumeInstAndAlignedBarriers();3090ED.addAlignedBarrier(A, CB);3091auto &CallOutED = CEDMap[{&CB, POST}];3092Changed |= mergeInPredecessor(A, CallOutED, ED);3093};30943095auto *LivenessAA =3096A.getAAFor<AAIsDead>(*this, getIRPosition(), DepClassTy::OPTIONAL);30973098Function *F = getAnchorScope();3099BasicBlock &EntryBB = F->getEntryBlock();3100bool IsKernel = omp::isOpenMPKernel(*F);31013102SmallVector<Instruction *> SyncInstWorklist;3103for (auto &RIt : *RPOT) {3104BasicBlock &BB = *RIt;31053106bool IsEntryBB = &BB == &EntryBB;3107// TODO: We use local reasoning since we don't have a divergence analysis3108// running as well. We could basically allow uniform branches here.3109bool AlignedBarrierLastInBlock = IsEntryBB && IsKernel;3110bool IsExplicitlyAligned = IsEntryBB && IsKernel;3111ExecutionDomainTy ED;3112// Propagate "incoming edges" into information about this block.3113if (IsEntryBB) {3114Changed |= handleCallees(A, ED);3115} else {3116// For live non-entry blocks we only propagate3117// information via live edges.3118if (LivenessAA && LivenessAA->isAssumedDead(&BB))3119continue;31203121for (auto *PredBB : predecessors(&BB)) {3122if (LivenessAA && LivenessAA->isEdgeDead(PredBB, &BB))3123continue;3124bool InitialEdgeOnly = isInitialThreadOnlyEdge(3125A, dyn_cast<BranchInst>(PredBB->getTerminator()), BB);3126mergeInPredecessor(A, ED, BEDMap[PredBB], InitialEdgeOnly);3127}3128}31293130// Now we traverse the block, accumulate effects in ED and attach3131// information to calls.3132for (Instruction &I : BB) {3133bool UsedAssumedInformation;3134if (A.isAssumedDead(I, *this, LivenessAA, UsedAssumedInformation,3135/* CheckBBLivenessOnly */ false, DepClassTy::OPTIONAL,3136/* CheckForDeadStore */ true))3137continue;31383139// Asummes and "assume-like" (dbg, lifetime, ...) 
are handled first, the3140// former is collected the latter is ignored.3141if (auto *II = dyn_cast<IntrinsicInst>(&I)) {3142if (auto *AI = dyn_cast_or_null<AssumeInst>(II)) {3143ED.addAssumeInst(A, *AI);3144continue;3145}3146// TODO: Should we also collect and delete lifetime markers?3147if (II->isAssumeLikeIntrinsic())3148continue;3149}31503151if (auto *FI = dyn_cast<FenceInst>(&I)) {3152if (!ED.EncounteredNonLocalSideEffect) {3153// An aligned fence without non-local side-effects is a no-op.3154if (ED.IsReachedFromAlignedBarrierOnly)3155continue;3156// A non-aligned fence without non-local side-effects is a no-op3157// if the ordering only publishes non-local side-effects (or less).3158switch (FI->getOrdering()) {3159case AtomicOrdering::NotAtomic:3160continue;3161case AtomicOrdering::Unordered:3162continue;3163case AtomicOrdering::Monotonic:3164continue;3165case AtomicOrdering::Acquire:3166break;3167case AtomicOrdering::Release:3168continue;3169case AtomicOrdering::AcquireRelease:3170break;3171case AtomicOrdering::SequentiallyConsistent:3172break;3173};3174}3175NonNoOpFences.insert(FI);3176}31773178auto *CB = dyn_cast<CallBase>(&I);3179bool IsNoSync = AA::isNoSyncInst(A, I, *this);3180bool IsAlignedBarrier =3181!IsNoSync && CB &&3182AANoSync::isAlignedBarrier(*CB, AlignedBarrierLastInBlock);31833184AlignedBarrierLastInBlock &= IsNoSync;3185IsExplicitlyAligned &= IsNoSync;31863187// Next we check for calls. Aligned barriers are handled3188// explicitly, everything else is kept for the backward traversal and will3189// also affect our state.3190if (CB) {3191if (IsAlignedBarrier) {3192HandleAlignedBarrier(*CB, ED);3193AlignedBarrierLastInBlock = true;3194IsExplicitlyAligned = true;3195continue;3196}31973198// Check the pointer(s) of a memory intrinsic explicitly.3199if (isa<MemIntrinsic>(&I)) {3200if (!ED.EncounteredNonLocalSideEffect &&3201AA::isPotentiallyAffectedByBarrier(A, I, *this))3202ED.EncounteredNonLocalSideEffect = true;3203if (!IsNoSync) {3204ED.IsReachedFromAlignedBarrierOnly = false;3205SyncInstWorklist.push_back(&I);3206}3207continue;3208}32093210// Record how we entered the call, then accumulate the effect of the3211// call in ED for potential use by the callee.3212auto &CallInED = CEDMap[{CB, PRE}];3213Changed |= mergeInPredecessor(A, CallInED, ED);32143215// If we have a sync-definition we can check if it starts/ends in an3216// aligned barrier. 
If we are unsure we assume any sync breaks3217// alignment.3218Function *Callee = CB->getCalledFunction();3219if (!IsNoSync && Callee && !Callee->isDeclaration()) {3220const auto *EDAA = A.getAAFor<AAExecutionDomain>(3221*this, IRPosition::function(*Callee), DepClassTy::OPTIONAL);3222if (EDAA && EDAA->getState().isValidState()) {3223const auto &CalleeED = EDAA->getFunctionExecutionDomain();3224ED.IsReachedFromAlignedBarrierOnly =3225CalleeED.IsReachedFromAlignedBarrierOnly;3226AlignedBarrierLastInBlock = ED.IsReachedFromAlignedBarrierOnly;3227if (IsNoSync || !CalleeED.IsReachedFromAlignedBarrierOnly)3228ED.EncounteredNonLocalSideEffect |=3229CalleeED.EncounteredNonLocalSideEffect;3230else3231ED.EncounteredNonLocalSideEffect =3232CalleeED.EncounteredNonLocalSideEffect;3233if (!CalleeED.IsReachingAlignedBarrierOnly) {3234Changed |=3235setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);3236SyncInstWorklist.push_back(&I);3237}3238if (CalleeED.IsReachedFromAlignedBarrierOnly)3239mergeInPredecessorBarriersAndAssumptions(A, ED, CalleeED);3240auto &CallOutED = CEDMap[{CB, POST}];3241Changed |= mergeInPredecessor(A, CallOutED, ED);3242continue;3243}3244}3245if (!IsNoSync) {3246ED.IsReachedFromAlignedBarrierOnly = false;3247Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);3248SyncInstWorklist.push_back(&I);3249}3250AlignedBarrierLastInBlock &= ED.IsReachedFromAlignedBarrierOnly;3251ED.EncounteredNonLocalSideEffect |= !CB->doesNotAccessMemory();3252auto &CallOutED = CEDMap[{CB, POST}];3253Changed |= mergeInPredecessor(A, CallOutED, ED);3254}32553256if (!I.mayHaveSideEffects() && !I.mayReadFromMemory())3257continue;32583259// If we have a callee we try to use fine-grained information to3260// determine local side-effects.3261if (CB) {3262const auto *MemAA = A.getAAFor<AAMemoryLocation>(3263*this, IRPosition::callsite_function(*CB), DepClassTy::OPTIONAL);32643265auto AccessPred = [&](const Instruction *I, const Value *Ptr,3266AAMemoryLocation::AccessKind,3267AAMemoryLocation::MemoryLocationsKind) {3268return !AA::isPotentiallyAffectedByBarrier(A, {Ptr}, *this, I);3269};3270if (MemAA && MemAA->getState().isValidState() &&3271MemAA->checkForAllAccessesToMemoryKind(3272AccessPred, AAMemoryLocation::ALL_LOCATIONS))3273continue;3274}32753276auto &InfoCache = A.getInfoCache();3277if (!I.mayHaveSideEffects() && InfoCache.isOnlyUsedByAssume(I))3278continue;32793280if (auto *LI = dyn_cast<LoadInst>(&I))3281if (LI->hasMetadata(LLVMContext::MD_invariant_load))3282continue;32833284if (!ED.EncounteredNonLocalSideEffect &&3285AA::isPotentiallyAffectedByBarrier(A, I, *this))3286ED.EncounteredNonLocalSideEffect = true;3287}32883289bool IsEndAndNotReachingAlignedBarriersOnly = false;3290if (!isa<UnreachableInst>(BB.getTerminator()) &&3291!BB.getTerminator()->getNumSuccessors()) {32923293Changed |= mergeInPredecessor(A, InterProceduralED, ED);32943295auto &FnED = BEDMap[nullptr];3296if (IsKernel && !IsExplicitlyAligned)3297FnED.IsReachingAlignedBarrierOnly = false;3298Changed |= mergeInPredecessor(A, FnED, ED);32993300if (!FnED.IsReachingAlignedBarrierOnly) {3301IsEndAndNotReachingAlignedBarriersOnly = true;3302SyncInstWorklist.push_back(BB.getTerminator());3303auto &BBED = BEDMap[&BB];3304Changed |= setAndRecord(BBED.IsReachingAlignedBarrierOnly, false);3305}3306}33073308ExecutionDomainTy &StoredED = BEDMap[&BB];3309ED.IsReachingAlignedBarrierOnly = StoredED.IsReachingAlignedBarrierOnly &3310!IsEndAndNotReachingAlignedBarriersOnly;33113312// Check if we computed anything different as part 
of the forward3313// traversal. We do not take assumptions and aligned barriers into account3314// as they do not influence the state we iterate. Backward traversal values3315// are handled later on.3316if (ED.IsExecutedByInitialThreadOnly !=3317StoredED.IsExecutedByInitialThreadOnly ||3318ED.IsReachedFromAlignedBarrierOnly !=3319StoredED.IsReachedFromAlignedBarrierOnly ||3320ED.EncounteredNonLocalSideEffect !=3321StoredED.EncounteredNonLocalSideEffect)3322Changed = true;33233324// Update the state with the new value.3325StoredED = std::move(ED);3326}33273328// Propagate (non-aligned) sync instruction effects backwards until the3329// entry is hit or an aligned barrier.3330SmallSetVector<BasicBlock *, 16> Visited;3331while (!SyncInstWorklist.empty()) {3332Instruction *SyncInst = SyncInstWorklist.pop_back_val();3333Instruction *CurInst = SyncInst;3334bool HitAlignedBarrierOrKnownEnd = false;3335while ((CurInst = CurInst->getPrevNode())) {3336auto *CB = dyn_cast<CallBase>(CurInst);3337if (!CB)3338continue;3339auto &CallOutED = CEDMap[{CB, POST}];3340Changed |= setAndRecord(CallOutED.IsReachingAlignedBarrierOnly, false);3341auto &CallInED = CEDMap[{CB, PRE}];3342HitAlignedBarrierOrKnownEnd =3343AlignedBarriers.count(CB) || !CallInED.IsReachingAlignedBarrierOnly;3344if (HitAlignedBarrierOrKnownEnd)3345break;3346Changed |= setAndRecord(CallInED.IsReachingAlignedBarrierOnly, false);3347}3348if (HitAlignedBarrierOrKnownEnd)3349continue;3350BasicBlock *SyncBB = SyncInst->getParent();3351for (auto *PredBB : predecessors(SyncBB)) {3352if (LivenessAA && LivenessAA->isEdgeDead(PredBB, SyncBB))3353continue;3354if (!Visited.insert(PredBB))3355continue;3356auto &PredED = BEDMap[PredBB];3357if (setAndRecord(PredED.IsReachingAlignedBarrierOnly, false)) {3358Changed = true;3359SyncInstWorklist.push_back(PredBB->getTerminator());3360}3361}3362if (SyncBB != &EntryBB)3363continue;3364Changed |=3365setAndRecord(InterProceduralED.IsReachingAlignedBarrierOnly, false);3366}33673368return Changed ? 
ChangeStatus::CHANGED : ChangeStatus::UNCHANGED;3369}33703371/// Try to replace memory allocation calls called by a single thread with a3372/// static buffer of shared memory.3373struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {3374using Base = StateWrapper<BooleanState, AbstractAttribute>;3375AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}33763377/// Create an abstract attribute view for the position \p IRP.3378static AAHeapToShared &createForPosition(const IRPosition &IRP,3379Attributor &A);33803381/// Returns true if HeapToShared conversion is assumed to be possible.3382virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;33833384/// Returns true if HeapToShared conversion is assumed and the CB is a3385/// callsite to a free operation to be removed.3386virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;33873388/// See AbstractAttribute::getName().3389const std::string getName() const override { return "AAHeapToShared"; }33903391/// See AbstractAttribute::getIdAddr().3392const char *getIdAddr() const override { return &ID; }33933394/// This function should return true if the type of the \p AA is3395/// AAHeapToShared.3396static bool classof(const AbstractAttribute *AA) {3397return (AA->getIdAddr() == &ID);3398}33993400/// Unique ID (due to the unique address)3401static const char ID;3402};34033404struct AAHeapToSharedFunction : public AAHeapToShared {3405AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)3406: AAHeapToShared(IRP, A) {}34073408const std::string getAsStr(Attributor *) const override {3409return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +3410" malloc calls eligible.";3411}34123413/// See AbstractAttribute::trackStatistics().3414void trackStatistics() const override {}34153416/// This functions finds free calls that will be removed by the3417/// HeapToShared transformation.3418void findPotentialRemovedFreeCalls(Attributor &A) {3419auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());3420auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];34213422PotentialRemovedFreeCalls.clear();3423// Update free call users of found malloc calls.3424for (CallBase *CB : MallocCalls) {3425SmallVector<CallBase *, 4> FreeCalls;3426for (auto *U : CB->users()) {3427CallBase *C = dyn_cast<CallBase>(U);3428if (C && C->getCalledFunction() == FreeRFI.Declaration)3429FreeCalls.push_back(C);3430}34313432if (FreeCalls.size() != 1)3433continue;34343435PotentialRemovedFreeCalls.insert(FreeCalls.front());3436}3437}34383439void initialize(Attributor &A) override {3440if (DisableOpenMPOptDeglobalization) {3441indicatePessimisticFixpoint();3442return;3443}34443445auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());3446auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];3447if (!RFI.Declaration)3448return;34493450Attributor::SimplifictionCallbackTy SCB =3451[](const IRPosition &, const AbstractAttribute *,3452bool &) -> std::optional<Value *> { return nullptr; };34533454Function *F = getAnchorScope();3455for (User *U : RFI.Declaration->users())3456if (CallBase *CB = dyn_cast<CallBase>(U)) {3457if (CB->getFunction() != F)3458continue;3459MallocCalls.insert(CB);3460A.registerSimplificationCallback(IRPosition::callsite_returned(*CB),3461SCB);3462}34633464findPotentialRemovedFreeCalls(A);3465}34663467bool isAssumedHeapToShared(CallBase &CB) const override {3468return isValidState() && MallocCalls.count(&CB);3469}34703471bool 
isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {3472return isValidState() && PotentialRemovedFreeCalls.count(&CB);3473}34743475ChangeStatus manifest(Attributor &A) override {3476if (MallocCalls.empty())3477return ChangeStatus::UNCHANGED;34783479auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());3480auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];34813482Function *F = getAnchorScope();3483auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,3484DepClassTy::OPTIONAL);34853486ChangeStatus Changed = ChangeStatus::UNCHANGED;3487for (CallBase *CB : MallocCalls) {3488// Skip replacing this if HeapToStack has already claimed it.3489if (HS && HS->isAssumedHeapToStack(*CB))3490continue;34913492// Find the unique free call to remove it.3493SmallVector<CallBase *, 4> FreeCalls;3494for (auto *U : CB->users()) {3495CallBase *C = dyn_cast<CallBase>(U);3496if (C && C->getCalledFunction() == FreeCall.Declaration)3497FreeCalls.push_back(C);3498}3499if (FreeCalls.size() != 1)3500continue;35013502auto *AllocSize = cast<ConstantInt>(CB->getArgOperand(0));35033504if (AllocSize->getZExtValue() + SharedMemoryUsed > SharedMemoryLimit) {3505LLVM_DEBUG(dbgs() << TAG << "Cannot replace call " << *CB3506<< " with shared memory."3507<< " Shared memory usage is limited to "3508<< SharedMemoryLimit << " bytes\n");3509continue;3510}35113512LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB3513<< " with " << AllocSize->getZExtValue()3514<< " bytes of shared memory\n");35153516// Create a new shared memory buffer of the same size as the allocation3517// and replace all the uses of the original allocation with it.3518Module *M = CB->getModule();3519Type *Int8Ty = Type::getInt8Ty(M->getContext());3520Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());3521auto *SharedMem = new GlobalVariable(3522*M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,3523PoisonValue::get(Int8ArrTy), CB->getName() + "_shared", nullptr,3524GlobalValue::NotThreadLocal,3525static_cast<unsigned>(AddressSpace::Shared));3526auto *NewBuffer =3527ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());35283529auto Remark = [&](OptimizationRemark OR) {3530return OR << "Replaced globalized variable with "3531<< ore::NV("SharedMemory", AllocSize->getZExtValue())3532<< (AllocSize->isOne() ? 
" byte " : " bytes ")3533<< "of shared memory.";3534};3535A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);35363537MaybeAlign Alignment = CB->getRetAlign();3538assert(Alignment &&3539"HeapToShared on allocation without alignment attribute");3540SharedMem->setAlignment(*Alignment);35413542A.changeAfterManifest(IRPosition::callsite_returned(*CB), *NewBuffer);3543A.deleteAfterManifest(*CB);3544A.deleteAfterManifest(*FreeCalls.front());35453546SharedMemoryUsed += AllocSize->getZExtValue();3547NumBytesMovedToSharedMemory = SharedMemoryUsed;3548Changed = ChangeStatus::CHANGED;3549}35503551return Changed;3552}35533554ChangeStatus updateImpl(Attributor &A) override {3555if (MallocCalls.empty())3556return indicatePessimisticFixpoint();3557auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());3558auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];3559if (!RFI.Declaration)3560return ChangeStatus::UNCHANGED;35613562Function *F = getAnchorScope();35633564auto NumMallocCalls = MallocCalls.size();35653566// Only consider malloc calls executed by a single thread with a constant.3567for (User *U : RFI.Declaration->users()) {3568if (CallBase *CB = dyn_cast<CallBase>(U)) {3569if (CB->getCaller() != F)3570continue;3571if (!MallocCalls.count(CB))3572continue;3573if (!isa<ConstantInt>(CB->getArgOperand(0))) {3574MallocCalls.remove(CB);3575continue;3576}3577const auto *ED = A.getAAFor<AAExecutionDomain>(3578*this, IRPosition::function(*F), DepClassTy::REQUIRED);3579if (!ED || !ED->isExecutedByInitialThreadOnly(*CB))3580MallocCalls.remove(CB);3581}3582}35833584findPotentialRemovedFreeCalls(A);35853586if (NumMallocCalls != MallocCalls.size())3587return ChangeStatus::CHANGED;35883589return ChangeStatus::UNCHANGED;3590}35913592/// Collection of all malloc calls in a function.3593SmallSetVector<CallBase *, 4> MallocCalls;3594/// Collection of potentially removed free calls in a function.3595SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;3596/// The total amount of shared memory that has been used for HeapToShared.3597unsigned SharedMemoryUsed = 0;3598};35993600struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {3601using Base = StateWrapper<KernelInfoState, AbstractAttribute>;3602AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}36033604/// The callee value is tracked beyond a simple stripPointerCasts, so we allow3605/// unknown callees.3606static bool requiresCalleeForCallBase() { return false; }36073608/// Statistics are tracked as part of manifest for now.3609void trackStatistics() const override {}36103611/// See AbstractAttribute::getAsStr()3612const std::string getAsStr(Attributor *) const override {3613if (!isValidState())3614return "<invalid>";3615return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"3616: "generic") +3617std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"3618: "") +3619std::string(" #PRs: ") +3620(ReachedKnownParallelRegions.isValidState()3621? std::to_string(ReachedKnownParallelRegions.size())3622: "<invalid>") +3623", #Unknown PRs: " +3624(ReachedUnknownParallelRegions.isValidState()3625? std::to_string(ReachedUnknownParallelRegions.size())3626: "<invalid>") +3627", #Reaching Kernels: " +3628(ReachingKernelEntries.isValidState()3629? std::to_string(ReachingKernelEntries.size())3630: "<invalid>") +3631", #ParLevels: " +3632(ParallelLevels.isValidState()3633? std::to_string(ParallelLevels.size())3634: "<invalid>") +3635", NestedPar: " + (NestedParallelism ? 
"yes" : "no");3636}36373638/// Create an abstract attribute biew for the position \p IRP.3639static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);36403641/// See AbstractAttribute::getName()3642const std::string getName() const override { return "AAKernelInfo"; }36433644/// See AbstractAttribute::getIdAddr()3645const char *getIdAddr() const override { return &ID; }36463647/// This function should return true if the type of the \p AA is AAKernelInfo3648static bool classof(const AbstractAttribute *AA) {3649return (AA->getIdAddr() == &ID);3650}36513652static const char ID;3653};36543655/// The function kernel info abstract attribute, basically, what can we say3656/// about a function with regards to the KernelInfoState.3657struct AAKernelInfoFunction : AAKernelInfo {3658AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)3659: AAKernelInfo(IRP, A) {}36603661SmallPtrSet<Instruction *, 4> GuardedInstructions;36623663SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {3664return GuardedInstructions;3665}36663667void setConfigurationOfKernelEnvironment(ConstantStruct *ConfigC) {3668Constant *NewKernelEnvC = ConstantFoldInsertValueInstruction(3669KernelEnvC, ConfigC, {KernelInfo::ConfigurationIdx});3670assert(NewKernelEnvC && "Failed to create new kernel environment");3671KernelEnvC = cast<ConstantStruct>(NewKernelEnvC);3672}36733674#define KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MEMBER) \3675void set##MEMBER##OfKernelEnvironment(ConstantInt *NewVal) { \3676ConstantStruct *ConfigC = \3677KernelInfo::getConfigurationFromKernelEnvironment(KernelEnvC); \3678Constant *NewConfigC = ConstantFoldInsertValueInstruction( \3679ConfigC, NewVal, {KernelInfo::MEMBER##Idx}); \3680assert(NewConfigC && "Failed to create new configuration environment"); \3681setConfigurationOfKernelEnvironment(cast<ConstantStruct>(NewConfigC)); \3682}36833684KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(UseGenericStateMachine)3685KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MayUseNestedParallelism)3686KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(ExecMode)3687KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinThreads)3688KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxThreads)3689KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MinTeams)3690KERNEL_ENVIRONMENT_CONFIGURATION_SETTER(MaxTeams)36913692#undef KERNEL_ENVIRONMENT_CONFIGURATION_SETTER36933694/// See AbstractAttribute::initialize(...).3695void initialize(Attributor &A) override {3696// This is a high-level transform that might change the constant arguments3697// of the init and dinit calls. 
We need to tell the Attributor about this3698// to avoid other parts using the current constant value for simpliication.3699auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());37003701Function *Fn = getAnchorScope();37023703OMPInformationCache::RuntimeFunctionInfo &InitRFI =3704OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];3705OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =3706OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];37073708// For kernels we perform more initialization work, first we find the init3709// and deinit calls.3710auto StoreCallBase = [](Use &U,3711OMPInformationCache::RuntimeFunctionInfo &RFI,3712CallBase *&Storage) {3713CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);3714assert(CB &&3715"Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");3716assert(!Storage &&3717"Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");3718Storage = CB;3719return false;3720};3721InitRFI.foreachUse(3722[&](Use &U, Function &) {3723StoreCallBase(U, InitRFI, KernelInitCB);3724return false;3725},3726Fn);3727DeinitRFI.foreachUse(3728[&](Use &U, Function &) {3729StoreCallBase(U, DeinitRFI, KernelDeinitCB);3730return false;3731},3732Fn);37333734// Ignore kernels without initializers such as global constructors.3735if (!KernelInitCB || !KernelDeinitCB)3736return;37373738// Add itself to the reaching kernel and set IsKernelEntry.3739ReachingKernelEntries.insert(Fn);3740IsKernelEntry = true;37413742KernelEnvC =3743KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);3744GlobalVariable *KernelEnvGV =3745KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);37463747Attributor::GlobalVariableSimplifictionCallbackTy3748KernelConfigurationSimplifyCB =3749[&](const GlobalVariable &GV, const AbstractAttribute *AA,3750bool &UsedAssumedInformation) -> std::optional<Constant *> {3751if (!isAtFixpoint()) {3752if (!AA)3753return nullptr;3754UsedAssumedInformation = true;3755A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);3756}3757return KernelEnvC;3758};37593760A.registerGlobalVariableSimplificationCallback(3761*KernelEnvGV, KernelConfigurationSimplifyCB);37623763// Check if we know we are in SPMD-mode already.3764ConstantInt *ExecModeC =3765KernelInfo::getExecModeFromKernelEnvironment(KernelEnvC);3766ConstantInt *AssumedExecModeC = ConstantInt::get(3767ExecModeC->getIntegerType(),3768ExecModeC->getSExtValue() | OMP_TGT_EXEC_MODE_GENERIC_SPMD);3769if (ExecModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD)3770SPMDCompatibilityTracker.indicateOptimisticFixpoint();3771else if (DisableOpenMPOptSPMDization)3772// This is a generic region but SPMDization is disabled so stop3773// tracking.3774SPMDCompatibilityTracker.indicatePessimisticFixpoint();3775else3776setExecModeOfKernelEnvironment(AssumedExecModeC);37773778const Triple T(Fn->getParent()->getTargetTriple());3779auto *Int32Ty = Type::getInt32Ty(Fn->getContext());3780auto [MinThreads, MaxThreads] =3781OpenMPIRBuilder::readThreadBoundsForKernel(T, *Fn);3782if (MinThreads)3783setMinThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinThreads));3784if (MaxThreads)3785setMaxThreadsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxThreads));3786auto [MinTeams, MaxTeams] =3787OpenMPIRBuilder::readTeamBoundsForKernel(T, *Fn);3788if (MinTeams)3789setMinTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MinTeams));3790if (MaxTeams)3791setMaxTeamsOfKernelEnvironment(ConstantInt::get(Int32Ty, MaxTeams));37923793ConstantInt *MayUseNestedParallelismC 
=3794KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(KernelEnvC);3795ConstantInt *AssumedMayUseNestedParallelismC = ConstantInt::get(3796MayUseNestedParallelismC->getIntegerType(), NestedParallelism);3797setMayUseNestedParallelismOfKernelEnvironment(3798AssumedMayUseNestedParallelismC);37993800if (!DisableOpenMPOptStateMachineRewrite) {3801ConstantInt *UseGenericStateMachineC =3802KernelInfo::getUseGenericStateMachineFromKernelEnvironment(3803KernelEnvC);3804ConstantInt *AssumedUseGenericStateMachineC =3805ConstantInt::get(UseGenericStateMachineC->getIntegerType(), false);3806setUseGenericStateMachineOfKernelEnvironment(3807AssumedUseGenericStateMachineC);3808}38093810// Register virtual uses of functions we might need to preserve.3811auto RegisterVirtualUse = [&](RuntimeFunction RFKind,3812Attributor::VirtualUseCallbackTy &CB) {3813if (!OMPInfoCache.RFIs[RFKind].Declaration)3814return;3815A.registerVirtualUseCallback(*OMPInfoCache.RFIs[RFKind].Declaration, CB);3816};38173818// Add a dependence to ensure updates if the state changes.3819auto AddDependence = [](Attributor &A, const AAKernelInfo *KI,3820const AbstractAttribute *QueryingAA) {3821if (QueryingAA) {3822A.recordDependence(*KI, *QueryingAA, DepClassTy::OPTIONAL);3823}3824return true;3825};38263827Attributor::VirtualUseCallbackTy CustomStateMachineUseCB =3828[&](Attributor &A, const AbstractAttribute *QueryingAA) {3829// Whenever we create a custom state machine we will insert calls to3830// __kmpc_get_hardware_num_threads_in_block,3831// __kmpc_get_warp_size,3832// __kmpc_barrier_simple_generic,3833// __kmpc_kernel_parallel, and3834// __kmpc_kernel_end_parallel.3835// Not needed if we are on track for SPMDzation.3836if (SPMDCompatibilityTracker.isValidState())3837return AddDependence(A, this, QueryingAA);3838// Not needed if we can't rewrite due to an invalid state.3839if (!ReachedKnownParallelRegions.isValidState())3840return AddDependence(A, this, QueryingAA);3841return false;3842};38433844// Not needed if we are pre-runtime merge.3845if (!KernelInitCB->getCalledFunction()->isDeclaration()) {3846RegisterVirtualUse(OMPRTL___kmpc_get_hardware_num_threads_in_block,3847CustomStateMachineUseCB);3848RegisterVirtualUse(OMPRTL___kmpc_get_warp_size, CustomStateMachineUseCB);3849RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_generic,3850CustomStateMachineUseCB);3851RegisterVirtualUse(OMPRTL___kmpc_kernel_parallel,3852CustomStateMachineUseCB);3853RegisterVirtualUse(OMPRTL___kmpc_kernel_end_parallel,3854CustomStateMachineUseCB);3855}38563857// If we do not perform SPMDzation we do not need the virtual uses below.3858if (SPMDCompatibilityTracker.isAtFixpoint())3859return;38603861Attributor::VirtualUseCallbackTy HWThreadIdUseCB =3862[&](Attributor &A, const AbstractAttribute *QueryingAA) {3863// Whenever we perform SPMDzation we will insert3864// __kmpc_get_hardware_thread_id_in_block calls.3865if (!SPMDCompatibilityTracker.isValidState())3866return AddDependence(A, this, QueryingAA);3867return false;3868};3869RegisterVirtualUse(OMPRTL___kmpc_get_hardware_thread_id_in_block,3870HWThreadIdUseCB);38713872Attributor::VirtualUseCallbackTy SPMDBarrierUseCB =3873[&](Attributor &A, const AbstractAttribute *QueryingAA) {3874// Whenever we perform SPMDzation with guarding we will insert3875// __kmpc_simple_barrier_spmd calls. 
If SPMDzation failed, there is3876// nothing to guard, or there are no parallel regions, we don't need3877// the calls.3878if (!SPMDCompatibilityTracker.isValidState())3879return AddDependence(A, this, QueryingAA);3880if (SPMDCompatibilityTracker.empty())3881return AddDependence(A, this, QueryingAA);3882if (!mayContainParallelRegion())3883return AddDependence(A, this, QueryingAA);3884return false;3885};3886RegisterVirtualUse(OMPRTL___kmpc_barrier_simple_spmd, SPMDBarrierUseCB);3887}38883889/// Sanitize the string \p S such that it is a suitable global symbol name.3890static std::string sanitizeForGlobalName(std::string S) {3891std::replace_if(3892S.begin(), S.end(),3893[](const char C) {3894return !((C >= 'a' && C <= 'z') || (C >= 'A' && C <= 'Z') ||3895(C >= '0' && C <= '9') || C == '_');3896},3897'.');3898return S;3899}39003901/// Modify the IR based on the KernelInfoState as the fixpoint iteration is3902/// finished now.3903ChangeStatus manifest(Attributor &A) override {3904// If we are not looking at a kernel with __kmpc_target_init and3905// __kmpc_target_deinit call we cannot actually manifest the information.3906if (!KernelInitCB || !KernelDeinitCB)3907return ChangeStatus::UNCHANGED;39083909ChangeStatus Changed = ChangeStatus::UNCHANGED;39103911bool HasBuiltStateMachine = true;3912if (!changeToSPMDMode(A, Changed)) {3913if (!KernelInitCB->getCalledFunction()->isDeclaration())3914HasBuiltStateMachine = buildCustomStateMachine(A, Changed);3915else3916HasBuiltStateMachine = false;3917}39183919// We need to reset KernelEnvC if specific rewriting is not done.3920ConstantStruct *ExistingKernelEnvC =3921KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);3922ConstantInt *OldUseGenericStateMachineVal =3923KernelInfo::getUseGenericStateMachineFromKernelEnvironment(3924ExistingKernelEnvC);3925if (!HasBuiltStateMachine)3926setUseGenericStateMachineOfKernelEnvironment(3927OldUseGenericStateMachineVal);39283929// At last, update the KernelEnvc3930GlobalVariable *KernelEnvGV =3931KernelInfo::getKernelEnvironementGVFromKernelInitCB(KernelInitCB);3932if (KernelEnvGV->getInitializer() != KernelEnvC) {3933KernelEnvGV->setInitializer(KernelEnvC);3934Changed = ChangeStatus::CHANGED;3935}39363937return Changed;3938}39393940void insertInstructionGuardsHelper(Attributor &A) {3941auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());39423943auto CreateGuardedRegion = [&](Instruction *RegionStartI,3944Instruction *RegionEndI) {3945LoopInfo *LI = nullptr;3946DominatorTree *DT = nullptr;3947MemorySSAUpdater *MSU = nullptr;3948using InsertPointTy = OpenMPIRBuilder::InsertPointTy;39493950BasicBlock *ParentBB = RegionStartI->getParent();3951Function *Fn = ParentBB->getParent();3952Module &M = *Fn->getParent();39533954// Create all the blocks and logic.3955// ParentBB:3956// goto RegionCheckTidBB3957// RegionCheckTidBB:3958// Tid = __kmpc_hardware_thread_id()3959// if (Tid != 0)3960// goto RegionBarrierBB3961// RegionStartBB:3962// <execute instructions guarded>3963// goto RegionEndBB3964// RegionEndBB:3965// <store escaping values to shared mem>3966// goto RegionBarrierBB3967// RegionBarrierBB:3968// __kmpc_simple_barrier_spmd()3969// // second barrier is omitted if lacking escaping values.3970// <load escaping values from shared mem>3971// __kmpc_simple_barrier_spmd()3972// goto RegionExitBB3973// RegionExitBB:3974// <execute rest of instructions>39753976BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),3977DT, LI, MSU, "region.guarded.end");3978BasicBlock 
*RegionBarrierBB =3979SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,3980MSU, "region.barrier");3981BasicBlock *RegionExitBB =3982SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),3983DT, LI, MSU, "region.exit");3984BasicBlock *RegionStartBB =3985SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");39863987assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&3988"Expected a different CFG");39893990BasicBlock *RegionCheckTidBB = SplitBlock(3991ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");39923993// Register basic blocks with the Attributor.3994A.registerManifestAddedBasicBlock(*RegionEndBB);3995A.registerManifestAddedBasicBlock(*RegionBarrierBB);3996A.registerManifestAddedBasicBlock(*RegionExitBB);3997A.registerManifestAddedBasicBlock(*RegionStartBB);3998A.registerManifestAddedBasicBlock(*RegionCheckTidBB);39994000bool HasBroadcastValues = false;4001// Find escaping outputs from the guarded region to outside users and4002// broadcast their values to them.4003for (Instruction &I : *RegionStartBB) {4004SmallVector<Use *, 4> OutsideUses;4005for (Use &U : I.uses()) {4006Instruction &UsrI = *cast<Instruction>(U.getUser());4007if (UsrI.getParent() != RegionStartBB)4008OutsideUses.push_back(&U);4009}40104011if (OutsideUses.empty())4012continue;40134014HasBroadcastValues = true;40154016// Emit a global variable in shared memory to store the broadcasted4017// value.4018auto *SharedMem = new GlobalVariable(4019M, I.getType(), /* IsConstant */ false,4020GlobalValue::InternalLinkage, UndefValue::get(I.getType()),4021sanitizeForGlobalName(4022(I.getName() + ".guarded.output.alloc").str()),4023nullptr, GlobalValue::NotThreadLocal,4024static_cast<unsigned>(AddressSpace::Shared));40254026// Emit a store instruction to update the value.4027new StoreInst(&I, SharedMem,4028RegionEndBB->getTerminator()->getIterator());40294030LoadInst *LoadI = new LoadInst(4031I.getType(), SharedMem, I.getName() + ".guarded.output.load",4032RegionBarrierBB->getTerminator()->getIterator());40334034// Emit a load instruction and replace uses of the output value.4035for (Use *U : OutsideUses)4036A.changeUseAfterManifest(*U, *LoadI);4037}40384039auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());40404041// Go to tid check BB in ParentBB.4042const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();4043ParentBB->getTerminator()->eraseFromParent();4044OpenMPIRBuilder::LocationDescription Loc(4045InsertPointTy(ParentBB, ParentBB->end()), DL);4046OMPInfoCache.OMPBuilder.updateToLocation(Loc);4047uint32_t SrcLocStrSize;4048auto *SrcLocStr =4049OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc, SrcLocStrSize);4050Value *Ident =4051OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize);4052BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);40534054// Add check for Tid in RegionCheckTidBB4055RegionCheckTidBB->getTerminator()->eraseFromParent();4056OpenMPIRBuilder::LocationDescription LocRegionCheckTid(4057InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);4058OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);4059FunctionCallee HardwareTidFn =4060OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(4061M, OMPRTL___kmpc_get_hardware_thread_id_in_block);4062CallInst *Tid =4063OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});4064Tid->setDebugLoc(DL);4065OMPInfoCache.setCallingConvention(HardwareTidFn, Tid);4066Value *TidCheck = 
OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);4067OMPInfoCache.OMPBuilder.Builder4068.CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)4069->setDebugLoc(DL);40704071// First barrier for synchronization, ensures main thread has updated4072// values.4073FunctionCallee BarrierFn =4074OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(4075M, OMPRTL___kmpc_barrier_simple_spmd);4076OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(4077RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));4078CallInst *Barrier =4079OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid});4080Barrier->setDebugLoc(DL);4081OMPInfoCache.setCallingConvention(BarrierFn, Barrier);40824083// Second barrier ensures workers have read broadcast values.4084if (HasBroadcastValues) {4085CallInst *Barrier =4086CallInst::Create(BarrierFn, {Ident, Tid}, "",4087RegionBarrierBB->getTerminator()->getIterator());4088Barrier->setDebugLoc(DL);4089OMPInfoCache.setCallingConvention(BarrierFn, Barrier);4090}4091};40924093auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];4094SmallPtrSet<BasicBlock *, 8> Visited;4095for (Instruction *GuardedI : SPMDCompatibilityTracker) {4096BasicBlock *BB = GuardedI->getParent();4097if (!Visited.insert(BB).second)4098continue;40994100SmallVector<std::pair<Instruction *, Instruction *>> Reorders;4101Instruction *LastEffect = nullptr;4102BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();4103while (++IP != IPEnd) {4104if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())4105continue;4106Instruction *I = &*IP;4107if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))4108continue;4109if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {4110LastEffect = nullptr;4111continue;4112}4113if (LastEffect)4114Reorders.push_back({I, LastEffect});4115LastEffect = &*IP;4116}4117for (auto &Reorder : Reorders)4118Reorder.first->moveBefore(Reorder.second);4119}41204121SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;41224123for (Instruction *GuardedI : SPMDCompatibilityTracker) {4124BasicBlock *BB = GuardedI->getParent();4125auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(4126IRPosition::function(*GuardedI->getFunction()), nullptr,4127DepClassTy::NONE);4128assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");4129auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);4130// Continue if instruction is already guarded.4131if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))4132continue;41334134Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;4135for (Instruction &I : *BB) {4136// If instruction I needs to be guarded update the guarded region4137// bounds.4138if (SPMDCompatibilityTracker.contains(&I)) {4139CalleeAAFunction.getGuardedInstructions().insert(&I);4140if (GuardedRegionStart)4141GuardedRegionEnd = &I;4142else4143GuardedRegionStart = GuardedRegionEnd = &I;41444145continue;4146}41474148// Instruction I does not need guarding, store4149// any region found and reset bounds.4150if (GuardedRegionStart) {4151GuardedRegions.push_back(4152std::make_pair(GuardedRegionStart, GuardedRegionEnd));4153GuardedRegionStart = nullptr;4154GuardedRegionEnd = nullptr;4155}4156}4157}41584159for (auto &GR : GuardedRegions)4160CreateGuardedRegion(GR.first, GR.second);4161}41624163void forceSingleThreadPerWorkgroupHelper(Attributor &A) {4164// Only allow 1 thread per workgroup to continue executing the user code.4165//4166// InitCB = __kmpc_target_init(...)4167// ThreadIdInBlock = 
__kmpc_get_hardware_thread_id_in_block();
    //     if (ThreadIdInBlock != 0) return;
    // UserCode:
    //     // user code
    //
    auto &Ctx = getAnchorValue().getContext();
    Function *Kernel = getAssociatedFunction();
    assert(Kernel && "Expected an associated function!");

    // Create block for user code to branch to from initial block.
    BasicBlock *InitBB = KernelInitCB->getParent();
    BasicBlock *UserCodeBB = InitBB->splitBasicBlock(
        KernelInitCB->getNextNode(), "main.thread.user_code");
    BasicBlock *ReturnBB =
        BasicBlock::Create(Ctx, "exit.threads", Kernel, UserCodeBB);

    // Register blocks with attributor:
    A.registerManifestAddedBasicBlock(*InitBB);
    A.registerManifestAddedBasicBlock(*UserCodeBB);
    A.registerManifestAddedBasicBlock(*ReturnBB);

    // Debug location:
    const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
    ReturnInst::Create(Ctx, ReturnBB)->setDebugLoc(DLoc);
    InitBB->getTerminator()->eraseFromParent();

    // Prepare call to OMPRTL___kmpc_get_hardware_thread_id_in_block.
    Module &M = *Kernel->getParent();
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    FunctionCallee ThreadIdInBlockFn =
        OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
            M, OMPRTL___kmpc_get_hardware_thread_id_in_block);

    // Get thread ID in block.
    CallInst *ThreadIdInBlock =
        CallInst::Create(ThreadIdInBlockFn, "thread_id.in.block", InitBB);
    OMPInfoCache.setCallingConvention(ThreadIdInBlockFn, ThreadIdInBlock);
    ThreadIdInBlock->setDebugLoc(DLoc);

    // Eliminate all threads in the block with ID not equal to 0:
    Instruction *IsMainThread =
        ICmpInst::Create(ICmpInst::ICmp, CmpInst::ICMP_NE, ThreadIdInBlock,
                         ConstantInt::get(ThreadIdInBlock->getType(), 0),
                         "thread.is_main", InitBB);
    IsMainThread->setDebugLoc(DLoc);
    BranchInst::Create(ReturnBB, UserCodeBB, IsMainThread, InitBB);
  }

  bool changeToSPMDMode(Attributor &A, ChangeStatus &Changed) {
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());

    // We cannot change to SPMD mode if the runtime functions aren't available.
    if (!OMPInfoCache.runtimeFnsAvailable(
            {OMPRTL___kmpc_get_hardware_thread_id_in_block,
             OMPRTL___kmpc_barrier_simple_spmd}))
      return false;

    if (!SPMDCompatibilityTracker.isAssumed()) {
      for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
        if (!NonCompatibleI)
          continue;

        // Skip diagnostics on calls to known OpenMP runtime functions for now.
        if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
          if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
            continue;

        auto Remark = [&](OptimizationRemarkAnalysis ORA) {
          ORA << "Value has potential side effects preventing SPMD-mode "
                 "execution";
          if (isa<CallBase>(NonCompatibleI)) {
            ORA << ". 
Add `[[omp::assume(\"ompx_spmd_amenable\")]]` to "4239"the called function to override";4240}4241return ORA << ".";4242};4243A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",4244Remark);42454246LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "4247<< *NonCompatibleI << "\n");4248}42494250return false;4251}42524253// Get the actual kernel, could be the caller of the anchor scope if we have4254// a debug wrapper.4255Function *Kernel = getAnchorScope();4256if (Kernel->hasLocalLinkage()) {4257assert(Kernel->hasOneUse() && "Unexpected use of debug kernel wrapper.");4258auto *CB = cast<CallBase>(Kernel->user_back());4259Kernel = CB->getCaller();4260}4261assert(omp::isOpenMPKernel(*Kernel) && "Expected kernel function!");42624263// Check if the kernel is already in SPMD mode, if so, return success.4264ConstantStruct *ExistingKernelEnvC =4265KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);4266auto *ExecModeC =4267KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);4268const int8_t ExecModeVal = ExecModeC->getSExtValue();4269if (ExecModeVal != OMP_TGT_EXEC_MODE_GENERIC)4270return true;42714272// We will now unconditionally modify the IR, indicate a change.4273Changed = ChangeStatus::CHANGED;42744275// Do not use instruction guards when no parallel is present inside4276// the target region.4277if (mayContainParallelRegion())4278insertInstructionGuardsHelper(A);4279else4280forceSingleThreadPerWorkgroupHelper(A);42814282// Adjust the global exec mode flag that tells the runtime what mode this4283// kernel is executed in.4284assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&4285"Initially non-SPMD kernel has SPMD exec mode!");4286setExecModeOfKernelEnvironment(4287ConstantInt::get(ExecModeC->getIntegerType(),4288ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));42894290++NumOpenMPTargetRegionKernelsSPMD;42914292auto Remark = [&](OptimizationRemark OR) {4293return OR << "Transformed generic-mode kernel to SPMD-mode.";4294};4295A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);4296return true;4297};42984299bool buildCustomStateMachine(Attributor &A, ChangeStatus &Changed) {4300// If we have disabled state machine rewrites, don't make a custom one4301if (DisableOpenMPOptStateMachineRewrite)4302return false;43034304// Don't rewrite the state machine if we are not in a valid state.4305if (!ReachedKnownParallelRegions.isValidState())4306return false;43074308auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());4309if (!OMPInfoCache.runtimeFnsAvailable(4310{OMPRTL___kmpc_get_hardware_num_threads_in_block,4311OMPRTL___kmpc_get_warp_size, OMPRTL___kmpc_barrier_simple_generic,4312OMPRTL___kmpc_kernel_parallel, OMPRTL___kmpc_kernel_end_parallel}))4313return false;43144315ConstantStruct *ExistingKernelEnvC =4316KernelInfo::getKernelEnvironementFromKernelInitCB(KernelInitCB);43174318// Check if the current configuration is non-SPMD and generic state machine.4319// If we already have SPMD mode or a custom state machine we do not need to4320// go any further. 
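    // (Note: if the UseGenericStateMachine flag is already zero, a custom
    //  state machine was presumably installed earlier or none is required, so
    //  there is nothing left to rewrite; SPMD kernels never use the generic
    //  state machine.)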
If it is anything but a constant something is weird and4321// we give up.4322ConstantInt *UseStateMachineC =4323KernelInfo::getUseGenericStateMachineFromKernelEnvironment(4324ExistingKernelEnvC);4325ConstantInt *ModeC =4326KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC);43274328// If we are stuck with generic mode, try to create a custom device (=GPU)4329// state machine which is specialized for the parallel regions that are4330// reachable by the kernel.4331if (UseStateMachineC->isZero() ||4332(ModeC->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))4333return false;43344335Changed = ChangeStatus::CHANGED;43364337// If not SPMD mode, indicate we use a custom state machine now.4338setUseGenericStateMachineOfKernelEnvironment(4339ConstantInt::get(UseStateMachineC->getIntegerType(), false));43404341// If we don't actually need a state machine we are done here. This can4342// happen if there simply are no parallel regions. In the resulting kernel4343// all worker threads will simply exit right away, leaving the main thread4344// to do the work alone.4345if (!mayContainParallelRegion()) {4346++NumOpenMPTargetRegionKernelsWithoutStateMachine;43474348auto Remark = [&](OptimizationRemark OR) {4349return OR << "Removing unused state machine from generic-mode kernel.";4350};4351A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);43524353return true;4354}43554356// Keep track in the statistics of our new shiny custom state machine.4357if (ReachedUnknownParallelRegions.empty()) {4358++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;43594360auto Remark = [&](OptimizationRemark OR) {4361return OR << "Rewriting generic-mode kernel with a customized state "4362"machine.";4363};4364A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);4365} else {4366++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;43674368auto Remark = [&](OptimizationRemarkAnalysis OR) {4369return OR << "Generic-mode kernel is executed with a customized state "4370"machine that requires a fallback.";4371};4372A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);43734374// Tell the user why we ended up with a fallback.4375for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {4376if (!UnknownParallelRegionCB)4377continue;4378auto Remark = [&](OptimizationRemarkAnalysis ORA) {4379return ORA << "Call may contain unknown parallel regions. 
Use "4380<< "`[[omp::assume(\"omp_no_parallelism\")]]` to "4381"override.";4382};4383A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,4384"OMP133", Remark);4385}4386}43874388// Create all the blocks:4389//4390// InitCB = __kmpc_target_init(...)4391// BlockHwSize =4392// __kmpc_get_hardware_num_threads_in_block();4393// WarpSize = __kmpc_get_warp_size();4394// BlockSize = BlockHwSize - WarpSize;4395// IsWorkerCheckBB: bool IsWorker = InitCB != -1;4396// if (IsWorker) {4397// if (InitCB >= BlockSize) return;4398// SMBeginBB: __kmpc_barrier_simple_generic(...);4399// void *WorkFn;4400// bool Active = __kmpc_kernel_parallel(&WorkFn);4401// if (!WorkFn) return;4402// SMIsActiveCheckBB: if (Active) {4403// SMIfCascadeCurrentBB: if (WorkFn == <ParFn0>)4404// ParFn0(...);4405// SMIfCascadeCurrentBB: else if (WorkFn == <ParFn1>)4406// ParFn1(...);4407// ...4408// SMIfCascadeCurrentBB: else4409// ((WorkFnTy*)WorkFn)(...);4410// SMEndParallelBB: __kmpc_kernel_end_parallel(...);4411// }4412// SMDoneBB: __kmpc_barrier_simple_generic(...);4413// goto SMBeginBB;4414// }4415// UserCodeEntryBB: // user code4416// __kmpc_target_deinit(...)4417//4418auto &Ctx = getAnchorValue().getContext();4419Function *Kernel = getAssociatedFunction();4420assert(Kernel && "Expected an associated function!");44214422BasicBlock *InitBB = KernelInitCB->getParent();4423BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(4424KernelInitCB->getNextNode(), "thread.user_code.check");4425BasicBlock *IsWorkerCheckBB =4426BasicBlock::Create(Ctx, "is_worker_check", Kernel, UserCodeEntryBB);4427BasicBlock *StateMachineBeginBB = BasicBlock::Create(4428Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);4429BasicBlock *StateMachineFinishedBB = BasicBlock::Create(4430Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);4431BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(4432Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);4433BasicBlock *StateMachineIfCascadeCurrentBB =4434BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",4435Kernel, UserCodeEntryBB);4436BasicBlock *StateMachineEndParallelBB =4437BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",4438Kernel, UserCodeEntryBB);4439BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(4440Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);4441A.registerManifestAddedBasicBlock(*InitBB);4442A.registerManifestAddedBasicBlock(*UserCodeEntryBB);4443A.registerManifestAddedBasicBlock(*IsWorkerCheckBB);4444A.registerManifestAddedBasicBlock(*StateMachineBeginBB);4445A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);4446A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);4447A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);4448A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);4449A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);44504451const DebugLoc &DLoc = KernelInitCB->getDebugLoc();4452ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);4453InitBB->getTerminator()->eraseFromParent();44544455Instruction *IsWorker =4456ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,4457ConstantInt::get(KernelInitCB->getType(), -1),4458"thread.is_worker", InitBB);4459IsWorker->setDebugLoc(DLoc);4460BranchInst::Create(IsWorkerCheckBB, UserCodeEntryBB, IsWorker, InitBB);44614462Module &M = *Kernel->getParent();4463FunctionCallee BlockHwSizeFn 
=4464OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(4465M, OMPRTL___kmpc_get_hardware_num_threads_in_block);4466FunctionCallee WarpSizeFn =4467OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(4468M, OMPRTL___kmpc_get_warp_size);4469CallInst *BlockHwSize =4470CallInst::Create(BlockHwSizeFn, "block.hw_size", IsWorkerCheckBB);4471OMPInfoCache.setCallingConvention(BlockHwSizeFn, BlockHwSize);4472BlockHwSize->setDebugLoc(DLoc);4473CallInst *WarpSize =4474CallInst::Create(WarpSizeFn, "warp.size", IsWorkerCheckBB);4475OMPInfoCache.setCallingConvention(WarpSizeFn, WarpSize);4476WarpSize->setDebugLoc(DLoc);4477Instruction *BlockSize = BinaryOperator::CreateSub(4478BlockHwSize, WarpSize, "block.size", IsWorkerCheckBB);4479BlockSize->setDebugLoc(DLoc);4480Instruction *IsMainOrWorker = ICmpInst::Create(4481ICmpInst::ICmp, llvm::CmpInst::ICMP_SLT, KernelInitCB, BlockSize,4482"thread.is_main_or_worker", IsWorkerCheckBB);4483IsMainOrWorker->setDebugLoc(DLoc);4484BranchInst::Create(StateMachineBeginBB, StateMachineFinishedBB,4485IsMainOrWorker, IsWorkerCheckBB);44864487// Create local storage for the work function pointer.4488const DataLayout &DL = M.getDataLayout();4489Type *VoidPtrTy = PointerType::getUnqual(Ctx);4490Instruction *WorkFnAI =4491new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,4492"worker.work_fn.addr", Kernel->getEntryBlock().begin());4493WorkFnAI->setDebugLoc(DLoc);44944495OMPInfoCache.OMPBuilder.updateToLocation(4496OpenMPIRBuilder::LocationDescription(4497IRBuilder<>::InsertPoint(StateMachineBeginBB,4498StateMachineBeginBB->end()),4499DLoc));45004501Value *Ident = KernelInfo::getIdentFromKernelEnvironment(KernelEnvC);4502Value *GTid = KernelInitCB;45034504FunctionCallee BarrierFn =4505OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(4506M, OMPRTL___kmpc_barrier_simple_generic);4507CallInst *Barrier =4508CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB);4509OMPInfoCache.setCallingConvention(BarrierFn, Barrier);4510Barrier->setDebugLoc(DLoc);45114512if (WorkFnAI->getType()->getPointerAddressSpace() !=4513(unsigned int)AddressSpace::Generic) {4514WorkFnAI = new AddrSpaceCastInst(4515WorkFnAI, PointerType::get(Ctx, (unsigned int)AddressSpace::Generic),4516WorkFnAI->getName() + ".generic", StateMachineBeginBB);4517WorkFnAI->setDebugLoc(DLoc);4518}45194520FunctionCallee KernelParallelFn =4521OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(4522M, OMPRTL___kmpc_kernel_parallel);4523CallInst *IsActiveWorker = CallInst::Create(4524KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);4525OMPInfoCache.setCallingConvention(KernelParallelFn, IsActiveWorker);4526IsActiveWorker->setDebugLoc(DLoc);4527Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",4528StateMachineBeginBB);4529WorkFn->setDebugLoc(DLoc);45304531FunctionType *ParallelRegionFnTy = FunctionType::get(4532Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},4533false);45344535Instruction *IsDone =4536ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,4537Constant::getNullValue(VoidPtrTy), "worker.is_done",4538StateMachineBeginBB);4539IsDone->setDebugLoc(DLoc);4540BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,4541IsDone, StateMachineBeginBB)4542->setDebugLoc(DLoc);45434544BranchInst::Create(StateMachineIfCascadeCurrentBB,4545StateMachineDoneBarrierBB, IsActiveWorker,4546StateMachineIsActiveCheckBB)4547->setDebugLoc(DLoc);45484549Value *ZeroArg 
=4550Constant::getNullValue(ParallelRegionFnTy->getParamType(0));45514552const unsigned int WrapperFunctionArgNo = 6;45534554// Now that we have most of the CFG skeleton it is time for the if-cascade4555// that checks the function pointer we got from the runtime against the4556// parallel regions we expect, if there are any.4557for (int I = 0, E = ReachedKnownParallelRegions.size(); I < E; ++I) {4558auto *CB = ReachedKnownParallelRegions[I];4559auto *ParallelRegion = dyn_cast<Function>(4560CB->getArgOperand(WrapperFunctionArgNo)->stripPointerCasts());4561BasicBlock *PRExecuteBB = BasicBlock::Create(4562Ctx, "worker_state_machine.parallel_region.execute", Kernel,4563StateMachineEndParallelBB);4564CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)4565->setDebugLoc(DLoc);4566BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)4567->setDebugLoc(DLoc);45684569BasicBlock *PRNextBB =4570BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",4571Kernel, StateMachineEndParallelBB);4572A.registerManifestAddedBasicBlock(*PRExecuteBB);4573A.registerManifestAddedBasicBlock(*PRNextBB);45744575// Check if we need to compare the pointer at all or if we can just4576// call the parallel region function.4577Value *IsPR;4578if (I + 1 < E || !ReachedUnknownParallelRegions.empty()) {4579Instruction *CmpI = ICmpInst::Create(4580ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn, ParallelRegion,4581"worker.check_parallel_region", StateMachineIfCascadeCurrentBB);4582CmpI->setDebugLoc(DLoc);4583IsPR = CmpI;4584} else {4585IsPR = ConstantInt::getTrue(Ctx);4586}45874588BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,4589StateMachineIfCascadeCurrentBB)4590->setDebugLoc(DLoc);4591StateMachineIfCascadeCurrentBB = PRNextBB;4592}45934594// At the end of the if-cascade we place the indirect function pointer call4595// in case we might need it, that is if there can be parallel regions we4596// have not handled in the if-cascade above.4597if (!ReachedUnknownParallelRegions.empty()) {4598StateMachineIfCascadeCurrentBB->setName(4599"worker_state_machine.parallel_region.fallback.execute");4600CallInst::Create(ParallelRegionFnTy, WorkFn, {ZeroArg, GTid}, "",4601StateMachineIfCascadeCurrentBB)4602->setDebugLoc(DLoc);4603}4604BranchInst::Create(StateMachineEndParallelBB,4605StateMachineIfCascadeCurrentBB)4606->setDebugLoc(DLoc);46074608FunctionCallee EndParallelFn =4609OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(4610M, OMPRTL___kmpc_kernel_end_parallel);4611CallInst *EndParallel =4612CallInst::Create(EndParallelFn, {}, "", StateMachineEndParallelBB);4613OMPInfoCache.setCallingConvention(EndParallelFn, EndParallel);4614EndParallel->setDebugLoc(DLoc);4615BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)4616->setDebugLoc(DLoc);46174618CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)4619->setDebugLoc(DLoc);4620BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)4621->setDebugLoc(DLoc);46224623return true;4624}46254626/// Fixpoint iteration update function. Will be called every time a dependence4627/// changed its state (and in the beginning).4628ChangeStatus updateImpl(Attributor &A) override {4629KernelInfoState StateBefore = getState();46304631// When we leave this function this RAII will make sure the member4632// KernelEnvC is updated properly depending on the state. 
That member is4633// used for simplification of values and needs to be up to date at all4634// times.4635struct UpdateKernelEnvCRAII {4636AAKernelInfoFunction &AA;46374638UpdateKernelEnvCRAII(AAKernelInfoFunction &AA) : AA(AA) {}46394640~UpdateKernelEnvCRAII() {4641if (!AA.KernelEnvC)4642return;46434644ConstantStruct *ExistingKernelEnvC =4645KernelInfo::getKernelEnvironementFromKernelInitCB(AA.KernelInitCB);46464647if (!AA.isValidState()) {4648AA.KernelEnvC = ExistingKernelEnvC;4649return;4650}46514652if (!AA.ReachedKnownParallelRegions.isValidState())4653AA.setUseGenericStateMachineOfKernelEnvironment(4654KernelInfo::getUseGenericStateMachineFromKernelEnvironment(4655ExistingKernelEnvC));46564657if (!AA.SPMDCompatibilityTracker.isValidState())4658AA.setExecModeOfKernelEnvironment(4659KernelInfo::getExecModeFromKernelEnvironment(ExistingKernelEnvC));46604661ConstantInt *MayUseNestedParallelismC =4662KernelInfo::getMayUseNestedParallelismFromKernelEnvironment(4663AA.KernelEnvC);4664ConstantInt *NewMayUseNestedParallelismC = ConstantInt::get(4665MayUseNestedParallelismC->getIntegerType(), AA.NestedParallelism);4666AA.setMayUseNestedParallelismOfKernelEnvironment(4667NewMayUseNestedParallelismC);4668}4669} RAII(*this);46704671// Callback to check a read/write instruction.4672auto CheckRWInst = [&](Instruction &I) {4673// We handle calls later.4674if (isa<CallBase>(I))4675return true;4676// We only care about write effects.4677if (!I.mayWriteToMemory())4678return true;4679if (auto *SI = dyn_cast<StoreInst>(&I)) {4680const auto *UnderlyingObjsAA = A.getAAFor<AAUnderlyingObjects>(4681*this, IRPosition::value(*SI->getPointerOperand()),4682DepClassTy::OPTIONAL);4683auto *HS = A.getAAFor<AAHeapToStack>(4684*this, IRPosition::function(*I.getFunction()),4685DepClassTy::OPTIONAL);4686if (UnderlyingObjsAA &&4687UnderlyingObjsAA->forallUnderlyingObjects([&](Value &Obj) {4688if (AA::isAssumedThreadLocalObject(A, Obj, *this))4689return true;4690// Check for AAHeapToStack moved objects which must not be4691// guarded.4692auto *CB = dyn_cast<CallBase>(&Obj);4693return CB && HS && HS->isAssumedHeapToStack(*CB);4694}))4695return true;4696}46974698// Insert instruction that needs guarding.4699SPMDCompatibilityTracker.insert(&I);4700return true;4701};47024703bool UsedAssumedInformationInCheckRWInst = false;4704if (!SPMDCompatibilityTracker.isAtFixpoint())4705if (!A.checkForAllReadWriteInstructions(4706CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))4707SPMDCompatibilityTracker.indicatePessimisticFixpoint();47084709bool UsedAssumedInformationFromReachingKernels = false;4710if (!IsKernelEntry) {4711updateParallelLevels(A);47124713bool AllReachingKernelsKnown = true;4714updateReachingKernelEntries(A, AllReachingKernelsKnown);4715UsedAssumedInformationFromReachingKernels = !AllReachingKernelsKnown;47164717if (!SPMDCompatibilityTracker.empty()) {4718if (!ParallelLevels.isValidState())4719SPMDCompatibilityTracker.indicatePessimisticFixpoint();4720else if (!ReachingKernelEntries.isValidState())4721SPMDCompatibilityTracker.indicatePessimisticFixpoint();4722else {4723// Check if all reaching kernels agree on the mode as we can otherwise4724// not guard instructions. 
We might not be sure about the mode, so we
          // cannot fix the internal SPMDzation state either.
          int SPMD = 0, Generic = 0;
          for (auto *Kernel : ReachingKernelEntries) {
            auto *CBAA = A.getAAFor<AAKernelInfo>(
                *this, IRPosition::function(*Kernel), DepClassTy::OPTIONAL);
            if (CBAA && CBAA->SPMDCompatibilityTracker.isValidState() &&
                CBAA->SPMDCompatibilityTracker.isAssumed())
              ++SPMD;
            else
              ++Generic;
            if (!CBAA || !CBAA->SPMDCompatibilityTracker.isAtFixpoint())
              UsedAssumedInformationFromReachingKernels = true;
          }
          if (SPMD != 0 && Generic != 0)
            SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        }
      }
    }

    // Callback to check a call instruction.
    bool AllParallelRegionStatesWereFixed = true;
    bool AllSPMDStatesWereFixed = true;
    auto CheckCallInst = [&](Instruction &I) {
      auto &CB = cast<CallBase>(I);
      auto *CBAA = A.getAAFor<AAKernelInfo>(
          *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
      if (!CBAA)
        return false;
      getState() ^= CBAA->getState();
      AllSPMDStatesWereFixed &= CBAA->SPMDCompatibilityTracker.isAtFixpoint();
      AllParallelRegionStatesWereFixed &=
          CBAA->ReachedKnownParallelRegions.isAtFixpoint();
      AllParallelRegionStatesWereFixed &=
          CBAA->ReachedUnknownParallelRegions.isAtFixpoint();
      return true;
    };

    bool UsedAssumedInformationInCheckCallInst = false;
    if (!A.checkForAllCallLikeInstructions(
            CheckCallInst, *this, UsedAssumedInformationInCheckCallInst)) {
      LLVM_DEBUG(dbgs() << TAG
                        << "Failed to visit all call-like instructions!\n";);
      return indicatePessimisticFixpoint();
    }

    // If we haven't used any assumed information for the reached parallel
    // region states we can fix it.
    if (!UsedAssumedInformationInCheckCallInst &&
        AllParallelRegionStatesWereFixed) {
      ReachedKnownParallelRegions.indicateOptimisticFixpoint();
      ReachedUnknownParallelRegions.indicateOptimisticFixpoint();
    }

    // If we haven't used any assumed information for the SPMD state we can fix
    // it.
    if (!UsedAssumedInformationInCheckRWInst &&
        !UsedAssumedInformationInCheckCallInst &&
        !UsedAssumedInformationFromReachingKernels && AllSPMDStatesWereFixed)
      SPMDCompatibilityTracker.indicateOptimisticFixpoint();

    return StateBefore == getState() ? 
ChangeStatus::UNCHANGED4786: ChangeStatus::CHANGED;4787}47884789private:4790/// Update info regarding reaching kernels.4791void updateReachingKernelEntries(Attributor &A,4792bool &AllReachingKernelsKnown) {4793auto PredCallSite = [&](AbstractCallSite ACS) {4794Function *Caller = ACS.getInstruction()->getFunction();47954796assert(Caller && "Caller is nullptr");47974798auto *CAA = A.getOrCreateAAFor<AAKernelInfo>(4799IRPosition::function(*Caller), this, DepClassTy::REQUIRED);4800if (CAA && CAA->ReachingKernelEntries.isValidState()) {4801ReachingKernelEntries ^= CAA->ReachingKernelEntries;4802return true;4803}48044805// We lost track of the caller of the associated function, any kernel4806// could reach now.4807ReachingKernelEntries.indicatePessimisticFixpoint();48084809return true;4810};48114812if (!A.checkForAllCallSites(PredCallSite, *this,4813true /* RequireAllCallSites */,4814AllReachingKernelsKnown))4815ReachingKernelEntries.indicatePessimisticFixpoint();4816}48174818/// Update info regarding parallel levels.4819void updateParallelLevels(Attributor &A) {4820auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());4821OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =4822OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];48234824auto PredCallSite = [&](AbstractCallSite ACS) {4825Function *Caller = ACS.getInstruction()->getFunction();48264827assert(Caller && "Caller is nullptr");48284829auto *CAA =4830A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));4831if (CAA && CAA->ParallelLevels.isValidState()) {4832// Any function that is called by `__kmpc_parallel_51` will not be4833// folded as the parallel level in the function is updated. In order to4834// get it right, all the analysis would depend on the implentation. That4835// said, if in the future any change to the implementation, the analysis4836// could be wrong. As a consequence, we are just conservative here.4837if (Caller == Parallel51RFI.Declaration) {4838ParallelLevels.indicatePessimisticFixpoint();4839return true;4840}48414842ParallelLevels ^= CAA->ParallelLevels;48434844return true;4845}48464847// We lost track of the caller of the associated function, any kernel4848// could reach now.4849ParallelLevels.indicatePessimisticFixpoint();48504851return true;4852};48534854bool AllCallSitesKnown = true;4855if (!A.checkForAllCallSites(PredCallSite, *this,4856true /* RequireAllCallSites */,4857AllCallSitesKnown))4858ParallelLevels.indicatePessimisticFixpoint();4859}4860};48614862/// The call site kernel info abstract attribute, basically, what can we say4863/// about a call site with regards to the KernelInfoState. For now this simply4864/// forwards the information from the callee.4865struct AAKernelInfoCallSite : AAKernelInfo {4866AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)4867: AAKernelInfo(IRP, A) {}48684869/// See AbstractAttribute::initialize(...).4870void initialize(Attributor &A) override {4871AAKernelInfo::initialize(A);48724873CallBase &CB = cast<CallBase>(getAssociatedValue());4874auto *AssumptionAA = A.getAAFor<AAAssumptionInfo>(4875*this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);48764877// Check for SPMD-mode assumptions.4878if (AssumptionAA && AssumptionAA->hasAssumption("ompx_spmd_amenable")) {4879indicateOptimisticFixpoint();4880return;4881}48824883// First weed out calls we do not care about, that is readonly/readnone4884// calls, intrinsics, and "no_openmp" calls. 
Neither of these can reach a4885// parallel region or anything else we are looking for.4886if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {4887indicateOptimisticFixpoint();4888return;4889}48904891// Next we check if we know the callee. If it is a known OpenMP function4892// we will handle them explicitly in the switch below. If it is not, we4893// will use an AAKernelInfo object on the callee to gather information and4894// merge that into the current state. The latter happens in the updateImpl.4895auto CheckCallee = [&](Function *Callee, unsigned NumCallees) {4896auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());4897const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);4898if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {4899// Unknown caller or declarations are not analyzable, we give up.4900if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {49014902// Unknown callees might contain parallel regions, except if they have4903// an appropriate assumption attached.4904if (!AssumptionAA ||4905!(AssumptionAA->hasAssumption("omp_no_openmp") ||4906AssumptionAA->hasAssumption("omp_no_parallelism")))4907ReachedUnknownParallelRegions.insert(&CB);49084909// If SPMDCompatibilityTracker is not fixed, we need to give up on the4910// idea we can run something unknown in SPMD-mode.4911if (!SPMDCompatibilityTracker.isAtFixpoint()) {4912SPMDCompatibilityTracker.indicatePessimisticFixpoint();4913SPMDCompatibilityTracker.insert(&CB);4914}49154916// We have updated the state for this unknown call properly, there4917// won't be any change so we indicate a fixpoint.4918indicateOptimisticFixpoint();4919}4920// If the callee is known and can be used in IPO, we will update the4921// state based on the callee state in updateImpl.4922return;4923}4924if (NumCallees > 1) {4925indicatePessimisticFixpoint();4926return;4927}49284929RuntimeFunction RF = It->getSecond();4930switch (RF) {4931// All the functions we know are compatible with SPMD mode.4932case OMPRTL___kmpc_is_spmd_exec_mode:4933case OMPRTL___kmpc_distribute_static_fini:4934case OMPRTL___kmpc_for_static_fini:4935case OMPRTL___kmpc_global_thread_num:4936case OMPRTL___kmpc_get_hardware_num_threads_in_block:4937case OMPRTL___kmpc_get_hardware_num_blocks:4938case OMPRTL___kmpc_single:4939case OMPRTL___kmpc_end_single:4940case OMPRTL___kmpc_master:4941case OMPRTL___kmpc_end_master:4942case OMPRTL___kmpc_barrier:4943case OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2:4944case OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2:4945case OMPRTL___kmpc_error:4946case OMPRTL___kmpc_flush:4947case OMPRTL___kmpc_get_hardware_thread_id_in_block:4948case OMPRTL___kmpc_get_warp_size:4949case OMPRTL_omp_get_thread_num:4950case OMPRTL_omp_get_num_threads:4951case OMPRTL_omp_get_max_threads:4952case OMPRTL_omp_in_parallel:4953case OMPRTL_omp_get_dynamic:4954case OMPRTL_omp_get_cancellation:4955case OMPRTL_omp_get_nested:4956case OMPRTL_omp_get_schedule:4957case OMPRTL_omp_get_thread_limit:4958case OMPRTL_omp_get_supported_active_levels:4959case OMPRTL_omp_get_max_active_levels:4960case OMPRTL_omp_get_level:4961case OMPRTL_omp_get_ancestor_thread_num:4962case OMPRTL_omp_get_team_size:4963case OMPRTL_omp_get_active_level:4964case OMPRTL_omp_in_final:4965case OMPRTL_omp_get_proc_bind:4966case OMPRTL_omp_get_num_places:4967case OMPRTL_omp_get_num_procs:4968case OMPRTL_omp_get_place_proc_ids:4969case OMPRTL_omp_get_place_num:4970case OMPRTL_omp_get_partition_num_places:4971case OMPRTL_omp_get_partition_place_nums:4972case 
      case OMPRTL_omp_get_wtime:
        break;
      case OMPRTL___kmpc_distribute_static_init_4:
      case OMPRTL___kmpc_distribute_static_init_4u:
      case OMPRTL___kmpc_distribute_static_init_8:
      case OMPRTL___kmpc_distribute_static_init_8u:
      case OMPRTL___kmpc_for_static_init_4:
      case OMPRTL___kmpc_for_static_init_4u:
      case OMPRTL___kmpc_for_static_init_8:
      case OMPRTL___kmpc_for_static_init_8u: {
        // Check the schedule and allow static schedule in SPMD mode.
        unsigned ScheduleArgOpNo = 2;
        auto *ScheduleTypeCI =
            dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
        unsigned ScheduleTypeVal =
            ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
        switch (OMPScheduleType(ScheduleTypeVal)) {
        case OMPScheduleType::UnorderedStatic:
        case OMPScheduleType::UnorderedStaticChunked:
        case OMPScheduleType::OrderedDistribute:
        case OMPScheduleType::OrderedDistributeChunked:
          break;
        default:
          SPMDCompatibilityTracker.indicatePessimisticFixpoint();
          SPMDCompatibilityTracker.insert(&CB);
          break;
        };
      } break;
      case OMPRTL___kmpc_target_init:
        KernelInitCB = &CB;
        break;
      case OMPRTL___kmpc_target_deinit:
        KernelDeinitCB = &CB;
        break;
      case OMPRTL___kmpc_parallel_51:
        if (!handleParallel51(A, CB))
          indicatePessimisticFixpoint();
        return;
      case OMPRTL___kmpc_omp_task:
        // We do not look into tasks right now, just give up.
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        SPMDCompatibilityTracker.insert(&CB);
        ReachedUnknownParallelRegions.insert(&CB);
        break;
      case OMPRTL___kmpc_alloc_shared:
      case OMPRTL___kmpc_free_shared:
        // Return without setting a fixpoint, to be resolved in updateImpl.
        return;
      default:
        // Generally, unknown OpenMP runtime calls cannot be executed in
        // SPMD-mode. However, they do not hide parallel regions.
        SPMDCompatibilityTracker.indicatePessimisticFixpoint();
        SPMDCompatibilityTracker.insert(&CB);
        break;
      }
      // All other OpenMP runtime calls will not reach parallel regions so
      // they can be safely ignored for now.
      // Since it is a known OpenMP runtime call we have now modeled all
      // effects and there is no need for any update.
      indicateOptimisticFixpoint();
    };

    const auto *AACE =
        A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);
    if (!AACE || !AACE->getState().isValidState() ||
        AACE->hasUnknownCallee()) {
      CheckCallee(getAssociatedFunction(), 1);
      return;
    }
    const auto &OptimisticEdges = AACE->getOptimisticEdges();
    for (auto *Callee : OptimisticEdges) {
      CheckCallee(Callee, OptimisticEdges.size());
      if (isAtFixpoint())
        break;
    }
  }

  ChangeStatus updateImpl(Attributor &A) override {
    // TODO: Once we have call site specific value information we can provide
    //       call site specific liveness information; then it makes sense to
    //       specialize attributes for call site arguments instead of
    //       redirecting requests to the callee argument.
    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    KernelInfoState StateBefore = getState();

    auto CheckCallee = [&](Function *F, int NumCallees) {
      const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);

      // If F is not a runtime function, propagate the AAKernelInfo of the
      // callee.
      if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
        const IRPosition &FnPos = IRPosition::function(*F);
        auto *FnAA =
            A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
        if (!FnAA)
          return indicatePessimisticFixpoint();
        if (getState() == FnAA->getState())
          return ChangeStatus::UNCHANGED;
        getState() = FnAA->getState();
        return ChangeStatus::CHANGED;
      }
      if (NumCallees > 1)
        return indicatePessimisticFixpoint();

      CallBase &CB = cast<CallBase>(getAssociatedValue());
      if (It->getSecond() == OMPRTL___kmpc_parallel_51) {
        if (!handleParallel51(A, CB))
          return indicatePessimisticFixpoint();
ChangeStatus::UNCHANGED5079: ChangeStatus::CHANGED;5080}50815082// F is a runtime function that allocates or frees memory, check5083// AAHeapToStack and AAHeapToShared.5084assert(5085(It->getSecond() == OMPRTL___kmpc_alloc_shared ||5086It->getSecond() == OMPRTL___kmpc_free_shared) &&5087"Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");50885089auto *HeapToStackAA = A.getAAFor<AAHeapToStack>(5090*this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);5091auto *HeapToSharedAA = A.getAAFor<AAHeapToShared>(5092*this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);50935094RuntimeFunction RF = It->getSecond();50955096switch (RF) {5097// If neither HeapToStack nor HeapToShared assume the call is removed,5098// assume SPMD incompatibility.5099case OMPRTL___kmpc_alloc_shared:5100if ((!HeapToStackAA || !HeapToStackAA->isAssumedHeapToStack(CB)) &&5101(!HeapToSharedAA || !HeapToSharedAA->isAssumedHeapToShared(CB)))5102SPMDCompatibilityTracker.insert(&CB);5103break;5104case OMPRTL___kmpc_free_shared:5105if ((!HeapToStackAA ||5106!HeapToStackAA->isAssumedHeapToStackRemovedFree(CB)) &&5107(!HeapToSharedAA ||5108!HeapToSharedAA->isAssumedHeapToSharedRemovedFree(CB)))5109SPMDCompatibilityTracker.insert(&CB);5110break;5111default:5112SPMDCompatibilityTracker.indicatePessimisticFixpoint();5113SPMDCompatibilityTracker.insert(&CB);5114}5115return ChangeStatus::CHANGED;5116};51175118const auto *AACE =5119A.getAAFor<AACallEdges>(*this, getIRPosition(), DepClassTy::OPTIONAL);5120if (!AACE || !AACE->getState().isValidState() || AACE->hasUnknownCallee()) {5121if (Function *F = getAssociatedFunction())5122CheckCallee(F, /*NumCallees=*/1);5123} else {5124const auto &OptimisticEdges = AACE->getOptimisticEdges();5125for (auto *Callee : OptimisticEdges) {5126CheckCallee(Callee, OptimisticEdges.size());5127if (isAtFixpoint())5128break;5129}5130}51315132return StateBefore == getState() ? ChangeStatus::UNCHANGED5133: ChangeStatus::CHANGED;5134}51355136/// Deal with a __kmpc_parallel_51 call (\p CB). Returns true if the call was5137/// handled, if a problem occurred, false is returned.5138bool handleParallel51(Attributor &A, CallBase &CB) {5139const unsigned int NonWrapperFunctionArgNo = 5;5140const unsigned int WrapperFunctionArgNo = 6;5141auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()5142? 
    const unsigned int NonWrapperFunctionArgNo = 5;
    const unsigned int WrapperFunctionArgNo = 6;
    auto ParallelRegionOpArgNo = SPMDCompatibilityTracker.isAssumed()
                                     ? NonWrapperFunctionArgNo
                                     : WrapperFunctionArgNo;

    auto *ParallelRegion = dyn_cast<Function>(
        CB.getArgOperand(ParallelRegionOpArgNo)->stripPointerCasts());
    if (!ParallelRegion)
      return false;

    ReachedKnownParallelRegions.insert(&CB);
    /// Check nested parallelism
    auto *FnAA = A.getAAFor<AAKernelInfo>(
        *this, IRPosition::function(*ParallelRegion), DepClassTy::OPTIONAL);
    NestedParallelism |= !FnAA || !FnAA->getState().isValidState() ||
                         !FnAA->ReachedKnownParallelRegions.empty() ||
                         !FnAA->ReachedKnownParallelRegions.isValidState() ||
                         !FnAA->ReachedUnknownParallelRegions.isValidState() ||
                         !FnAA->ReachedUnknownParallelRegions.empty();
    return true;
  }
};

struct AAFoldRuntimeCall
    : public StateWrapper<BooleanState, AbstractAttribute> {
  using Base = StateWrapper<BooleanState, AbstractAttribute>;

  AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}

  /// Statistics are tracked as part of manifest for now.
  void trackStatistics() const override {}

  /// Create an abstract attribute view for the position \p IRP.
  static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
                                              Attributor &A);

  /// See AbstractAttribute::getName()
  const std::string getName() const override { return "AAFoldRuntimeCall"; }

  /// See AbstractAttribute::getIdAddr()
  const char *getIdAddr() const override { return &ID; }

  /// This function should return true if the type of the \p AA is
  /// AAFoldRuntimeCall
  static bool classof(const AbstractAttribute *AA) {
    return (AA->getIdAddr() == &ID);
  }

  static const char ID;
};

struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
  AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
      : AAFoldRuntimeCall(IRP, A) {}

  /// See AbstractAttribute::getAsStr()
  const std::string getAsStr(Attributor *) const override {
    if (!isValidState())
      return "<invalid>";

    std::string Str("simplified value: ");

    if (!SimplifiedValue)
      return Str + std::string("none");

    if (!*SimplifiedValue)
      return Str + std::string("nullptr");

    if (ConstantInt *CI = dyn_cast<ConstantInt>(*SimplifiedValue))
      return Str + std::to_string(CI->getSExtValue());

    return Str + std::string("unknown");
  }

  void initialize(Attributor &A) override {
    if (DisableOpenMPOptFolding)
      indicatePessimisticFixpoint();

    Function *Callee = getAssociatedFunction();

    auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
    const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
    assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
           "Expected a known OpenMP runtime function");

    RFKind = It->getSecond();

    CallBase &CB = cast<CallBase>(getAssociatedValue());
    A.registerSimplificationCallback(
        IRPosition::callsite_returned(CB),
        [&](const IRPosition &IRP, const AbstractAttribute *AA,
            bool &UsedAssumedInformation) -> std::optional<Value *> {
          assert((isValidState() ||
                  (SimplifiedValue && *SimplifiedValue == nullptr)) &&
                 "Unexpected invalid state!");

          if (!isAtFixpoint()) {
            UsedAssumedInformation = true;
            if (AA)
              A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
          }
          return SimplifiedValue;
        });
  }

  ChangeStatus updateImpl(Attributor &A) override {
    ChangeStatus Changed = ChangeStatus::UNCHANGED;
    switch (RFKind) {
    case OMPRTL___kmpc_is_spmd_exec_mode:
      Changed |= foldIsSPMDExecMode(A);
      break;
    case OMPRTL___kmpc_parallel_level:
      Changed |= foldParallelLevel(A);
      break;
    case OMPRTL___kmpc_get_hardware_num_threads_in_block:
      Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
      break;
    case OMPRTL___kmpc_get_hardware_num_blocks:
      Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
      break;
    default:
      llvm_unreachable("Unhandled OpenMP runtime function!");
    }

    return Changed;
  }

  ChangeStatus manifest(Attributor &A) override {
    ChangeStatus Changed = ChangeStatus::UNCHANGED;

    if (SimplifiedValue && *SimplifiedValue) {
      Instruction &I = *getCtxI();
      A.changeAfterManifest(IRPosition::inst(I), **SimplifiedValue);
      A.deleteAfterManifest(I);

      CallBase *CB = dyn_cast<CallBase>(&I);
      auto Remark = [&](OptimizationRemark OR) {
        if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
          return OR << "Replacing OpenMP runtime call "
                    << CB->getCalledFunction()->getName() << " with "
                    << ore::NV("FoldedValue", C->getZExtValue()) << ".";
        return OR << "Replacing OpenMP runtime call "
                  << CB->getCalledFunction()->getName() << ".";
      };

      if (CB && EnableVerboseRemarks)
        A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);

      LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
                        << **SimplifiedValue << "\n");

      Changed = ChangeStatus::CHANGED;
    }

    return Changed;
  }

  ChangeStatus indicatePessimisticFixpoint() override {
    SimplifiedValue = nullptr;
    return AAFoldRuntimeCall::indicatePessimisticFixpoint();
  }

private:
  /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
  ChangeStatus foldIsSPMDExecMode(Attributor &A) {
    std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;

    unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
    unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);

    if (!CallerKernelInfoAA ||
        !CallerKernelInfoAA->ReachingKernelEntries.isValidState())
      return indicatePessimisticFixpoint();

    for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
      auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
                                          DepClassTy::REQUIRED);

      if (!AA || !AA->isValidState()) {
        SimplifiedValue = nullptr;
        return indicatePessimisticFixpoint();
      }

      if (AA->SPMDCompatibilityTracker.isAssumed()) {
        if (AA->SPMDCompatibilityTracker.isAtFixpoint())
          ++KnownSPMDCount;
        else
          ++AssumedSPMDCount;
      } else {
        if (AA->SPMDCompatibilityTracker.isAtFixpoint())
          ++KnownNonSPMDCount;
        else
          ++AssumedNonSPMDCount;
      }
    }

    if ((AssumedSPMDCount + KnownSPMDCount) &&
        (AssumedNonSPMDCount + KnownNonSPMDCount))
      return indicatePessimisticFixpoint();

    auto &Ctx = getAnchorValue().getContext();
    if (KnownSPMDCount || AssumedSPMDCount) {
      assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
             "Expected only SPMD kernels!");
      // All reaching kernels are in SPMD mode. Update all function calls to
      // __kmpc_is_spmd_exec_mode to 1.
      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
    } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
      assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
             "Expected only non-SPMD kernels!");
      // All reaching kernels are in non-SPMD mode. Update all function
      // calls to __kmpc_is_spmd_exec_mode to 0.
      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
    } else {
      // We have empty reaching kernels, therefore we cannot tell if the
      // associated call site can be folded. At this moment, SimplifiedValue
      // must be none.
      assert(!SimplifiedValue && "SimplifiedValue should be none");
    }

    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
                                                    : ChangeStatus::CHANGED;
  }
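  // Illustrative example: if every kernel that can reach the anchor function
  // is (assumed) SPMD, the __kmpc_is_spmd_exec_mode call folds to the i8
  // constant 1; if every reaching kernel is generic-mode it folds to 0, and a
  // mix of both prevents folding.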

  /// Fold __kmpc_parallel_level into a constant if possible.
  ChangeStatus foldParallelLevel(Attributor &A) {
    std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;

    auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
        *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);

    if (!CallerKernelInfoAA ||
        !CallerKernelInfoAA->ParallelLevels.isValidState())
      return indicatePessimisticFixpoint();

    if (!CallerKernelInfoAA->ReachingKernelEntries.isValidState())
      return indicatePessimisticFixpoint();

    if (CallerKernelInfoAA->ReachingKernelEntries.empty()) {
      assert(!SimplifiedValue &&
             "SimplifiedValue should keep none at this point");
      return ChangeStatus::UNCHANGED;
    }

    unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
    unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
    for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {
      auto *AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
                                          DepClassTy::REQUIRED);
      if (!AA || !AA->SPMDCompatibilityTracker.isValidState())
        return indicatePessimisticFixpoint();

      if (AA->SPMDCompatibilityTracker.isAssumed()) {
        if (AA->SPMDCompatibilityTracker.isAtFixpoint())
          ++KnownSPMDCount;
        else
          ++AssumedSPMDCount;
      } else {
        if (AA->SPMDCompatibilityTracker.isAtFixpoint())
          ++KnownNonSPMDCount;
        else
          ++AssumedNonSPMDCount;
      }
    }

    if ((AssumedSPMDCount + KnownSPMDCount) &&
        (AssumedNonSPMDCount + KnownNonSPMDCount))
      return indicatePessimisticFixpoint();

    auto &Ctx = getAnchorValue().getContext();
    // If the caller can only be reached by SPMD kernel entries, the parallel
    // level is 1. Similarly, if the caller can only be reached by non-SPMD
    // kernel entries, it is 0.
    if (AssumedSPMDCount || KnownSPMDCount) {
      assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
             "Expected only SPMD kernels!");
      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
    } else {
      assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
             "Expected only non-SPMD kernels!");
      SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
    }
    return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
                                                    : ChangeStatus::CHANGED;
  }
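  // Illustrative example: a helper reachable only from SPMD kernels has its
  // __kmpc_parallel_level call folded to 1, one reachable only from
  // generic-mode kernels to 0. The attribute-based fold below works the same
  // way, e.g. __kmpc_get_hardware_num_threads_in_block folds to 128 when
  // every reaching kernel carries "omp_target_thread_limit"="128".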
ChangeStatus::UNCHANGED5425: ChangeStatus::CHANGED;5426}54275428ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {5429// Specialize only if all the calls agree with the attribute constant value5430int32_t CurrentAttrValue = -1;5431std::optional<Value *> SimplifiedValueBefore = SimplifiedValue;54325433auto *CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(5434*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);54355436if (!CallerKernelInfoAA ||5437!CallerKernelInfoAA->ReachingKernelEntries.isValidState())5438return indicatePessimisticFixpoint();54395440// Iterate over the kernels that reach this function5441for (Kernel K : CallerKernelInfoAA->ReachingKernelEntries) {5442int32_t NextAttrVal = K->getFnAttributeAsParsedInteger(Attr, -1);54435444if (NextAttrVal == -1 ||5445(CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))5446return indicatePessimisticFixpoint();5447CurrentAttrValue = NextAttrVal;5448}54495450if (CurrentAttrValue != -1) {5451auto &Ctx = getAnchorValue().getContext();5452SimplifiedValue =5453ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);5454}5455return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED5456: ChangeStatus::CHANGED;5457}54585459/// An optional value the associated value is assumed to fold to. That is, we5460/// assume the associated value (which is a call) can be replaced by this5461/// simplified value.5462std::optional<Value *> SimplifiedValue;54635464/// The runtime function kind of the callee of the associated call site.5465RuntimeFunction RFKind;5466};54675468} // namespace54695470/// Register folding callsite5471void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {5472auto &RFI = OMPInfoCache.RFIs[RF];5473RFI.foreachUse(SCC, [&](Use &U, Function &F) {5474CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);5475if (!CI)5476return false;5477A.getOrCreateAAFor<AAFoldRuntimeCall>(5478IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,5479DepClassTy::NONE, /* ForceUpdate */ false,5480/* UpdateAfterInit */ false);5481return false;5482});5483}54845485void OpenMPOpt::registerAAs(bool IsModulePass) {5486if (SCC.empty())5487return;54885489if (IsModulePass) {5490// Ensure we create the AAKernelInfo AAs first and without triggering an5491// update. 
    // This will make sure we register all value simplification callbacks
    // before any other AA has the chance to create an AAValueSimplify or
    // similar.
    auto CreateKernelInfoCB = [&](Use &, Function &Kernel) {
      A.getOrCreateAAFor<AAKernelInfo>(
          IRPosition::function(Kernel), /* QueryingAA */ nullptr,
          DepClassTy::NONE, /* ForceUpdate */ false,
          /* UpdateAfterInit */ false);
      return false;
    };
    OMPInformationCache::RuntimeFunctionInfo &InitRFI =
        OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
    InitRFI.foreachUse(SCC, CreateKernelInfoCB);

    registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
    registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
    registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
    registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
  }

  // Create CallSite AA for all Getters.
  if (DeduceICVValues) {
    for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
      auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];

      auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];

      auto CreateAA = [&](Use &U, Function &Caller) {
        CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
        if (!CI)
          return false;

        auto &CB = cast<CallBase>(*CI);

        IRPosition CBPos = IRPosition::callsite_function(CB);
        A.getOrCreateAAFor<AAICVTracker>(CBPos);
        return false;
      };

      GetterRFI.foreachUse(SCC, CreateAA);
    }
  }

  // Create an ExecutionDomain AA for every function and a HeapToStack AA for
  // every function if there is a device kernel.
  if (!isOpenMPDevice(M))
    return;

  for (auto *F : SCC) {
    if (F->isDeclaration())
      continue;

    // We look at internal functions only on-demand, but if any use is not a
    // direct call or is outside the current set of analyzed functions, we
    // have to do it eagerly.
    if (F->hasLocalLinkage()) {
      if (llvm::all_of(F->uses(), [this](const Use &U) {
            const auto *CB = dyn_cast<CallBase>(U.getUser());
            return CB && CB->isCallee(&U) &&
                   A.isRunOn(const_cast<Function *>(CB->getCaller()));
          }))
        continue;
    }
    registerAAsForFunction(A, *F);
  }
}

void OpenMPOpt::registerAAsForFunction(Attributor &A, const Function &F) {
  if (!DisableOpenMPOptDeglobalization)
    A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
  A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(F));
  if (!DisableOpenMPOptDeglobalization)
    A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(F));
  if (F.hasFnAttribute(Attribute::Convergent))
    A.getOrCreateAAFor<AANonConvergent>(IRPosition::function(F));

  for (auto &I : instructions(F)) {
    if (auto *LI = dyn_cast<LoadInst>(&I)) {
      bool UsedAssumedInformation = false;
      A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
                             UsedAssumedInformation, AA::Interprocedural);
      continue;
    }
    if (auto *CI = dyn_cast<CallBase>(&I)) {
      if (CI->isIndirectCall())
        A.getOrCreateAAFor<AAIndirectCallInfo>(
            IRPosition::callsite_function(*CI));
    }
    if (auto *SI = dyn_cast<StoreInst>(&I)) {
      A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*SI));
      continue;
    }
    if (auto *FI = dyn_cast<FenceInst>(&I)) {
      A.getOrCreateAAFor<AAIsDead>(IRPosition::value(*FI));
      continue;
    }
    if (auto *II = dyn_cast<IntrinsicInst>(&I)) {
      if (II->getIntrinsicID() == Intrinsic::assume) {
        A.getOrCreateAAFor<AAPotentialValues>(
            IRPosition::value(*II->getArgOperand(0)));
        continue;
      }
    }
  }
}

const char AAICVTracker::ID = 0;
const char AAKernelInfo::ID = 0;
const char AAExecutionDomain::ID = 0;
const char AAHeapToShared::ID = 0;
const char AAFoldRuntimeCall::ID = 0;

AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
                                              Attributor &A) {
  AAICVTracker *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_INVALID:
  case IRPosition::IRP_FLOAT:
  case IRPosition::IRP_ARGUMENT:
  case IRPosition::IRP_CALL_SITE_ARGUMENT:
    llvm_unreachable("ICVTracker can only be created for function position!");
  case IRPosition::IRP_RETURNED:
    AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
    break;
  case IRPosition::IRP_CALL_SITE_RETURNED:
    AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
    break;
  case IRPosition::IRP_CALL_SITE:
    AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
    break;
  case IRPosition::IRP_FUNCTION:
    AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
    break;
  }

  return *AA;
}

AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
                                                        Attributor &A) {
  AAExecutionDomainFunction *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_INVALID:
  case IRPosition::IRP_FLOAT:
  case IRPosition::IRP_ARGUMENT:
  case IRPosition::IRP_CALL_SITE_ARGUMENT:
  case IRPosition::IRP_RETURNED:
  case IRPosition::IRP_CALL_SITE_RETURNED:
  case IRPosition::IRP_CALL_SITE:
    llvm_unreachable(
        "AAExecutionDomain can only be created for function position!");
  case IRPosition::IRP_FUNCTION:
    AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
    break;
  }

  return *AA;
}

AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
                                                  Attributor &A) {
  AAHeapToSharedFunction *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_INVALID:
  case IRPosition::IRP_FLOAT:
  case IRPosition::IRP_ARGUMENT:
  case IRPosition::IRP_CALL_SITE_ARGUMENT:
  case IRPosition::IRP_RETURNED:
  case IRPosition::IRP_CALL_SITE_RETURNED:
  case IRPosition::IRP_CALL_SITE:
    llvm_unreachable(
        "AAHeapToShared can only be created for function position!");
  case IRPosition::IRP_FUNCTION:
    AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
    break;
  }

  return *AA;
}

AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
                                              Attributor &A) {
  AAKernelInfo *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_INVALID:
  case IRPosition::IRP_FLOAT:
  case IRPosition::IRP_ARGUMENT:
  case IRPosition::IRP_RETURNED:
  case IRPosition::IRP_CALL_SITE_RETURNED:
  case IRPosition::IRP_CALL_SITE_ARGUMENT:
    llvm_unreachable("KernelInfo can only be created for function position!");
  case IRPosition::IRP_CALL_SITE:
    AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
    break;
  case IRPosition::IRP_FUNCTION:
    AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
    break;
  }

  return *AA;
}

AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
                                                        Attributor &A) {
  AAFoldRuntimeCall *AA = nullptr;
  switch (IRP.getPositionKind()) {
  case IRPosition::IRP_INVALID:
  case IRPosition::IRP_FLOAT:
  case IRPosition::IRP_ARGUMENT:
  case IRPosition::IRP_RETURNED:
  case IRPosition::IRP_FUNCTION:
  case IRPosition::IRP_CALL_SITE:
  case IRPosition::IRP_CALL_SITE_ARGUMENT:
    llvm_unreachable(
        "AAFoldRuntimeCall can only be created for call site position!");
  case IRPosition::IRP_CALL_SITE_RETURNED:
    AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
    break;
  }

  return *AA;
}

PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
  if (!containsOpenMP(M))
    return PreservedAnalyses::all();
  if (DisableOpenMPOptimizations)
    return PreservedAnalyses::all();

  FunctionAnalysisManager &FAM =
      AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
  KernelSet Kernels = getDeviceKernels(M);

  if (PrintModuleBeforeOptimizations)
    LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt Module Pass:\n" << M);

  auto IsCalled = [&](Function &F) {
    if (Kernels.contains(&F))
      return true;
    for (const User *U : F.users())
      if (!isa<BlockAddress>(U))
        return true;
    return false;
  };

  auto EmitRemark = [&](Function &F) {
    auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
    ORE.emit([&]() {
      OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
      return ORA << "Could not internalize function. "
                 << "Some optimizations may not be possible. [OMP140]";
    });
  };

  bool Changed = false;

  // Create internal copies of each function if this is a kernel module. This
  // allows interprocedural passes to see every call edge.
  DenseMap<Function *, Function *> InternalizedMap;
  if (isOpenMPDevice(M)) {
    SmallPtrSet<Function *, 16> InternalizeFns;
    for (Function &F : M)
      if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
          !DisableInternalization) {
        if (Attributor::isInternalizable(F)) {
          InternalizeFns.insert(&F);
        } else if (!F.hasLocalLinkage() &&
                   !F.hasFnAttribute(Attribute::Cold)) {
          EmitRemark(F);
        }
      }

    Changed |=
        Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
  }

  // Look at every function in the Module unless it was internalized.
  SetVector<Function *> Functions;
  SmallVector<Function *, 16> SCC;
  for (Function &F : M)
    if (!F.isDeclaration() && !InternalizedMap.lookup(&F)) {
      SCC.push_back(&F);
      Functions.insert(&F);
    }

  if (SCC.empty())
    return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();

  AnalysisGetter AG(FAM);

  auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
    return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
  };

  BumpPtrAllocator Allocator;
  CallGraphUpdater CGUpdater;

  bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
                  LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
  OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ nullptr, PostLink);

  unsigned MaxFixpointIterations =
SetFixpointIterations : 32;57925793AttributorConfig AC(CGUpdater);5794AC.DefaultInitializeLiveInternals = false;5795AC.IsModulePass = true;5796AC.RewriteSignatures = false;5797AC.MaxFixpointIterations = MaxFixpointIterations;5798AC.OREGetter = OREGetter;5799AC.PassName = DEBUG_TYPE;5800AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;5801AC.IPOAmendableCB = [](const Function &F) {5802return F.hasFnAttribute("kernel");5803};58045805Attributor A(Functions, InfoCache, AC);58065807OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);5808Changed |= OMPOpt.run(true);58095810// Optionally inline device functions for potentially better performance.5811if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))5812for (Function &F : M)5813if (!F.isDeclaration() && !Kernels.contains(&F) &&5814!F.hasFnAttribute(Attribute::NoInline))5815F.addFnAttr(Attribute::AlwaysInline);58165817if (PrintModuleAfterOptimizations)5818LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);58195820if (Changed)5821return PreservedAnalyses::none();58225823return PreservedAnalyses::all();5824}58255826PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,5827CGSCCAnalysisManager &AM,5828LazyCallGraph &CG,5829CGSCCUpdateResult &UR) {5830if (!containsOpenMP(*C.begin()->getFunction().getParent()))5831return PreservedAnalyses::all();5832if (DisableOpenMPOptimizations)5833return PreservedAnalyses::all();58345835SmallVector<Function *, 16> SCC;5836// If there are kernels in the module, we have to run on all SCC's.5837for (LazyCallGraph::Node &N : C) {5838Function *Fn = &N.getFunction();5839SCC.push_back(Fn);5840}58415842if (SCC.empty())5843return PreservedAnalyses::all();58445845Module &M = *C.begin()->getFunction().getParent();58465847if (PrintModuleBeforeOptimizations)5848LLVM_DEBUG(dbgs() << TAG << "Module before OpenMPOpt CGSCC Pass:\n" << M);58495850KernelSet Kernels = getDeviceKernels(M);58515852FunctionAnalysisManager &FAM =5853AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();58545855AnalysisGetter AG(FAM);58565857auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {5858return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);5859};58605861BumpPtrAllocator Allocator;5862CallGraphUpdater CGUpdater;5863CGUpdater.initialize(CG, C, AM, UR);58645865bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||5866LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;5867SetVector<Function *> Functions(SCC.begin(), SCC.end());5868OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,5869/*CGSCC*/ &Functions, PostLink);58705871unsigned MaxFixpointIterations =5872(isOpenMPDevice(M)) ? 
SetFixpointIterations : 32;58735874AttributorConfig AC(CGUpdater);5875AC.DefaultInitializeLiveInternals = false;5876AC.IsModulePass = false;5877AC.RewriteSignatures = false;5878AC.MaxFixpointIterations = MaxFixpointIterations;5879AC.OREGetter = OREGetter;5880AC.PassName = DEBUG_TYPE;5881AC.InitializationCallback = OpenMPOpt::registerAAsForFunction;58825883Attributor A(Functions, InfoCache, AC);58845885OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);5886bool Changed = OMPOpt.run(false);58875888if (PrintModuleAfterOptimizations)5889LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);58905891if (Changed)5892return PreservedAnalyses::none();58935894return PreservedAnalyses::all();5895}58965897bool llvm::omp::isOpenMPKernel(Function &Fn) {5898return Fn.hasFnAttribute("kernel");5899}59005901KernelSet llvm::omp::getDeviceKernels(Module &M) {5902// TODO: Create a more cross-platform way of determining device kernels.5903NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");5904KernelSet Kernels;59055906if (!MD)5907return Kernels;59085909for (auto *Op : MD->operands()) {5910if (Op->getNumOperands() < 2)5911continue;5912MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));5913if (!KindID || KindID->getString() != "kernel")5914continue;59155916Function *KernelFn =5917mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));5918if (!KernelFn)5919continue;59205921// We are only interested in OpenMP target regions. Others, such as kernels5922// generated by CUDA but linked together, are not interesting to this pass.5923if (isOpenMPKernel(*KernelFn)) {5924++NumOpenMPTargetRegionKernels;5925Kernels.insert(KernelFn);5926} else5927++NumNonOpenMPTargetRegionKernels;5928}59295930return Kernels;5931}59325933bool llvm::omp::containsOpenMP(Module &M) {5934Metadata *MD = M.getModuleFlag("openmp");5935if (!MD)5936return false;59375938return true;5939}59405941bool llvm::omp::isOpenMPDevice(Module &M) {5942Metadata *MD = M.getModuleFlag("openmp-device");5943if (!MD)5944return false;59455946return true;5947}594859495950