Path: blob/main/contrib/llvm-project/llvm/lib/Target/NVPTX/NVPTXLowerAggrCopies.cpp
35271 views
//===- NVPTXLowerAggrCopies.cpp - ------------------------------*- C++ -*--===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// \file9// Lower aggregate copies, memset, memcpy, memmov intrinsics into loops when10// the size is large or is not a compile-time constant.11//12//===----------------------------------------------------------------------===//1314#include "NVPTXLowerAggrCopies.h"15#include "llvm/Analysis/TargetTransformInfo.h"16#include "llvm/CodeGen/StackProtector.h"17#include "llvm/IR/Constants.h"18#include "llvm/IR/DataLayout.h"19#include "llvm/IR/Function.h"20#include "llvm/IR/IRBuilder.h"21#include "llvm/IR/Instructions.h"22#include "llvm/IR/IntrinsicInst.h"23#include "llvm/IR/Intrinsics.h"24#include "llvm/IR/LLVMContext.h"25#include "llvm/IR/Module.h"26#include "llvm/Support/Debug.h"27#include "llvm/Transforms/Utils/BasicBlockUtils.h"28#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"2930#define DEBUG_TYPE "nvptx"3132using namespace llvm;3334namespace {3536// actual analysis class, which is a functionpass37struct NVPTXLowerAggrCopies : public FunctionPass {38static char ID;3940NVPTXLowerAggrCopies() : FunctionPass(ID) {}4142void getAnalysisUsage(AnalysisUsage &AU) const override {43AU.addPreserved<StackProtector>();44AU.addRequired<TargetTransformInfoWrapperPass>();45}4647bool runOnFunction(Function &F) override;4849static const unsigned MaxAggrCopySize = 128;5051StringRef getPassName() const override {52return "Lower aggregate copies/intrinsics into loops";53}54};5556char NVPTXLowerAggrCopies::ID = 0;5758bool NVPTXLowerAggrCopies::runOnFunction(Function &F) {59SmallVector<LoadInst *, 4> AggrLoads;60SmallVector<MemIntrinsic *, 4> MemCalls;6162const DataLayout &DL = F.getDataLayout();63LLVMContext &Context = F.getParent()->getContext();64const TargetTransformInfo &TTI =65getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);6667// Collect all aggregate loads and mem* calls.68for (BasicBlock &BB : F) {69for (Instruction &I : BB) {70if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {71if (!LI->hasOneUse())72continue;7374if (DL.getTypeStoreSize(LI->getType()) < MaxAggrCopySize)75continue;7677if (StoreInst *SI = dyn_cast<StoreInst>(LI->user_back())) {78if (SI->getOperand(0) != LI)79continue;80AggrLoads.push_back(LI);81}82} else if (MemIntrinsic *IntrCall = dyn_cast<MemIntrinsic>(&I)) {83// Convert intrinsic calls with variable size or with constant size84// larger than the MaxAggrCopySize threshold.85if (ConstantInt *LenCI = dyn_cast<ConstantInt>(IntrCall->getLength())) {86if (LenCI->getZExtValue() >= MaxAggrCopySize) {87MemCalls.push_back(IntrCall);88}89} else {90MemCalls.push_back(IntrCall);91}92}93}94}9596if (AggrLoads.size() == 0 && MemCalls.size() == 0) {97return false;98}99100//101// Do the transformation of an aggr load/copy/set to a loop102//103for (LoadInst *LI : AggrLoads) {104auto *SI = cast<StoreInst>(*LI->user_begin());105Value *SrcAddr = LI->getOperand(0);106Value *DstAddr = SI->getOperand(1);107unsigned NumLoads = DL.getTypeStoreSize(LI->getType());108ConstantInt *CopyLen =109ConstantInt::get(Type::getInt32Ty(Context), NumLoads);110111createMemCpyLoopKnownSize(/* ConvertedInst */ SI,112/* SrcAddr */ SrcAddr, /* DstAddr */ DstAddr,113/* CopyLen */ CopyLen,114/* SrcAlign */ LI->getAlign(),115/* DestAlign */ SI->getAlign(),116/* SrcIsVolatile */ LI->isVolatile(),117/* DstIsVolatile */ SI->isVolatile(),118/* CanOverlap */ true, TTI);119120SI->eraseFromParent();121LI->eraseFromParent();122}123124// Transform mem* intrinsic calls.125for (MemIntrinsic *MemCall : MemCalls) {126if (MemCpyInst *Memcpy = dyn_cast<MemCpyInst>(MemCall)) {127expandMemCpyAsLoop(Memcpy, TTI);128} else if (MemMoveInst *Memmove = dyn_cast<MemMoveInst>(MemCall)) {129expandMemMoveAsLoop(Memmove, TTI);130} else if (MemSetInst *Memset = dyn_cast<MemSetInst>(MemCall)) {131expandMemSetAsLoop(Memset);132}133MemCall->eraseFromParent();134}135136return true;137}138139} // namespace140141namespace llvm {142void initializeNVPTXLowerAggrCopiesPass(PassRegistry &);143}144145INITIALIZE_PASS(NVPTXLowerAggrCopies, "nvptx-lower-aggr-copies",146"Lower aggregate copies, and llvm.mem* intrinsics into loops",147false, false)148149FunctionPass *llvm::createLowerAggrCopies() {150return new NVPTXLowerAggrCopies();151}152153154