Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
35269 views
//===----------------------- AlignmentFromAssumptions.cpp -----------------===//
// Set Load/Store Alignments From Assumptions
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a ScalarEvolution-based transformation to set
// the alignments of load, stores and memory intrinsics based on the truth
// expressions of assume intrinsics. The primary motivation is to handle
// complex alignment assumptions that apply to vector loads and stores that
// appear after vectorization and unrolling.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/GlobalsModRef.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "alignment-from-assumptions"
using namespace llvm;

STATISTIC(NumLoadAlignChanged,
          "Number of loads changed by alignment assumptions");
STATISTIC(NumStoreAlignChanged,
          "Number of stores changed by alignment assumptions");
STATISTIC(NumMemIntAlignChanged,
          "Number of memory intrinsics changed by alignment assumptions");

// Given an expression for the (constant) alignment, AlignSCEV, and an
// expression for the displacement between a pointer and the aligned address,
// DiffSCEV, compute the alignment of the displaced pointer if it can be reduced
// to a constant. Using SCEV to compute alignment handles the case where
// DiffSCEV is a recurrence with constant start such that the aligned offset
// is constant. e.g. {16,+,32} % 32 -> 16.
//
// Returns std::nullopt when Diff % Alignment does not fold to a usable
// power-of-two constant.
static MaybeAlign getNewAlignmentDiff(const SCEV *DiffSCEV,
                                      const SCEV *AlignSCEV,
                                      ScalarEvolution *SE) {
  // DiffUnits = Diff % int64_t(Alignment)
  const SCEV *DiffUnitsSCEV = SE->getURemExpr(DiffSCEV, AlignSCEV);

  LLVM_DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is "
                    << *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n");

  if (const SCEVConstant *ConstDUSCEV =
          dyn_cast<SCEVConstant>(DiffUnitsSCEV)) {
    int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue();

    // If the displacement is an exact multiple of the alignment, then the
    // displaced pointer has the same alignment as the aligned pointer, so
    // return the alignment value.
    if (!DiffUnits)
      return cast<SCEVConstant>(AlignSCEV)->getValue()->getAlignValue();

    // If the displacement is not an exact multiple, but the remainder is a
    // constant, then return this remainder (but only if it is a power of 2).
    uint64_t DiffUnitsAbs = std::abs(DiffUnits);
    if (isPowerOf2_64(DiffUnitsAbs))
      return Align(DiffUnitsAbs);
  }

  return std::nullopt;
}

// There is an address given by an offset OffSCEV from AASCEV which has an
// alignment AlignSCEV. Use that information, if possible, to compute a new
// alignment for Ptr.
//
// Returns Align(1) (i.e. "no improvement") when the displacement between Ptr
// and the assumed-aligned address cannot be computed or does not reduce to a
// useful constant.
static Align getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV,
                             const SCEV *OffSCEV, Value *Ptr,
                             ScalarEvolution *SE) {
  const SCEV *PtrSCEV = SE->getSCEV(Ptr);

  const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV);
  if (isa<SCEVCouldNotCompute>(DiffSCEV))
    return Align(1);

  // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always
  // sign-extended OffSCEV to i64, so make sure they agree again.
  DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType());

  // What we really want to know is the overall offset to the aligned
  // address. This address is displaced by the provided offset.
  DiffSCEV = SE->getAddExpr(DiffSCEV, OffSCEV);

  LLVM_DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to "
                    << *AlignSCEV << " and offset " << *OffSCEV
                    << " using diff " << *DiffSCEV << "\n");

  if (MaybeAlign NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE)) {
    LLVM_DEBUG(dbgs() << "\tnew alignment: " << DebugStr(NewAlignment) << "\n");
    return *NewAlignment;
  }

  if (const SCEVAddRecExpr *DiffARSCEV = dyn_cast<SCEVAddRecExpr>(DiffSCEV)) {
    // The relative offset to the alignment assumption did not yield a constant,
    // but we should try harder: if we assume that a is 32-byte aligned, then in
    // for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are
    // 32-byte aligned, but instead alternate between 32 and 16-byte alignment.
    // As a result, the new alignment will not be a constant, but can still
    // be improved over the default (of 4) to 16.

    const SCEV *DiffStartSCEV = DiffARSCEV->getStart();
    const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE);

    LLVM_DEBUG(dbgs() << "\ttrying start/inc alignment using start "
                      << *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n");

    // Now compute the new alignment using the displacement to the value in the
    // first iteration, and also the alignment using the per-iteration delta.
    // If these are the same, then use that answer. Otherwise, use the smaller
    // one, but only if it divides the larger one.
    MaybeAlign NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE);
    MaybeAlign NewIncAlignment =
        getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE);

    LLVM_DEBUG(dbgs() << "\tnew start alignment: " << DebugStr(NewAlignment)
                      << "\n");
    LLVM_DEBUG(dbgs() << "\tnew inc alignment: " << DebugStr(NewIncAlignment)
                      << "\n");

    if (!NewAlignment || !NewIncAlignment)
      return Align(1);

    // Both are powers of two, so the smaller of the two divides the larger;
    // the minimum is therefore a sound alignment for every iteration.
    const Align NewAlign = *NewAlignment;
    const Align NewIncAlign = *NewIncAlignment;
    if (NewAlign > NewIncAlign) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: "
                        << DebugStr(NewIncAlign) << "\n");
      return NewIncAlign;
    }
    if (NewIncAlign > NewAlign) {
      LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
                        << "\n");
      return NewAlign;
    }
    assert(NewIncAlign == NewAlign);
    LLVM_DEBUG(dbgs() << "\tnew start/inc alignment: " << DebugStr(NewAlign)
                      << "\n");
    return NewAlign;
  }

  return Align(1);
}

// Decompose the "align" operand bundle at index Idx of the assume call I into
// the assumed-aligned pointer (AAPtr), a constant power-of-two alignment
// expression (AlignSCEV, widened/truncated to i64), and an optional byte
// offset (OffSCEV, zero if the bundle has no third input). Returns false if
// the bundle is not an "align" bundle or the alignment is not a constant
// power of two.
bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I,
                                                        unsigned Idx,
                                                        Value *&AAPtr,
                                                        const SCEV *&AlignSCEV,
                                                        const SCEV *&OffSCEV) {
  Type *Int64Ty = Type::getInt64Ty(I->getContext());
  OperandBundleUse AlignOB = I->getOperandBundleAt(Idx);
  if (AlignOB.getTagName() != "align")
    return false;
  assert(AlignOB.Inputs.size() >= 2);
  AAPtr = AlignOB.Inputs[0].get();
  // TODO: Consider accumulating the offset to the base.
  AAPtr = AAPtr->stripPointerCastsSameRepresentation();
  AlignSCEV = SE->getSCEV(AlignOB.Inputs[1].get());
  AlignSCEV = SE->getTruncateOrZeroExtend(AlignSCEV, Int64Ty);
  if (!isa<SCEVConstant>(AlignSCEV))
    // Added to suppress a crash because consumer doesn't expect non-constant
    // alignments in the assume bundle. TODO: Consider generalizing caller.
    return false;
  if (!cast<SCEVConstant>(AlignSCEV)->getAPInt().isPowerOf2())
    // Only power of two alignments are supported.
    return false;
  if (AlignOB.Inputs.size() == 3)
    OffSCEV = SE->getSCEV(AlignOB.Inputs[2].get());
  else
    OffSCEV = SE->getZero(Int64Ty);
  OffSCEV = SE->getTruncateOrZeroExtend(OffSCEV, Int64Ty);
  return true;
}

// Apply the alignment fact carried by the Idx'th operand bundle of the assume
// call ACall: raise the alignment of loads, stores, and memory intrinsics
// whose pointer operand is (transitively, through GEPs/PHIs) the assumed
// pointer, wherever the assumption is valid at the instruction's context.
// Returns true if the bundle was a usable "align" bundle (even if no
// instruction was actually changed).
bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall,
                                                     unsigned Idx) {
  Value *AAPtr;
  const SCEV *AlignSCEV, *OffSCEV;
  if (!extractAlignmentInfo(ACall, Idx, AAPtr, AlignSCEV, OffSCEV))
    return false;

  // Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't
  // affect other users.
  if (isa<ConstantData>(AAPtr))
    return false;

  const SCEV *AASCEV = SE->getSCEV(AAPtr);

  // Apply the assumption to all other users of the specified pointer.
  SmallPtrSet<Instruction *, 32> Visited;
  SmallVector<Instruction*, 16> WorkList;
  for (User *J : AAPtr->users()) {
    if (J == ACall)
      continue;

    if (Instruction *K = dyn_cast<Instruction>(J))
      WorkList.push_back(K);
  }

  while (!WorkList.empty()) {
    Instruction *J = WorkList.pop_back_val();
    if (LoadInst *LI = dyn_cast<LoadInst>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                           LI->getPointerOperand(), SE);
      if (NewAlignment > LI->getAlign()) {
        LI->setAlignment(NewAlignment);
        ++NumLoadAlignChanged;
      }
    } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV,
                                           SI->getPointerOperand(), SE);
      if (NewAlignment > SI->getAlign()) {
        SI->setAlignment(NewAlignment);
        ++NumStoreAlignChanged;
      }
    } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) {
      if (!isValidAssumeForContext(ACall, J, DT))
        continue;
      Align NewDestAlignment =
          getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MI->getDest(), SE);

      LLVM_DEBUG(dbgs() << "\tmem inst: " << DebugStr(NewDestAlignment)
                        << "\n";);
      if (NewDestAlignment > *MI->getDestAlign()) {
        MI->setDestAlignment(NewDestAlignment);
        ++NumMemIntAlignChanged;
      }

      // For memory transfers, there is also a source alignment that
      // can be set.
      if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) {
        Align NewSrcAlignment =
            getNewAlignment(AASCEV, AlignSCEV, OffSCEV, MTI->getSource(), SE);

        LLVM_DEBUG(dbgs() << "\tmem trans: " << DebugStr(NewSrcAlignment)
                          << "\n";);

        if (NewSrcAlignment > *MTI->getSourceAlign()) {
          MTI->setSourceAlignment(NewSrcAlignment);
          ++NumMemIntAlignChanged;
        }
      }
    }

    // Now that we've updated that use of the pointer, look for other uses of
    // the pointer to update.
    Visited.insert(J);
    // Only GEPs and PHIs forward the pointer value; for a store user, skip it
    // unless the traversed value is the store's *pointer* operand (a stored
    // pointer value does not inherit the alignment fact).
    if (isa<GetElementPtrInst>(J) || isa<PHINode>(J))
      for (auto &U : J->uses()) {
        if (U->getType()->isPointerTy()) {
          Instruction *K = cast<Instruction>(U.getUser());
          StoreInst *SI = dyn_cast<StoreInst>(K);
          if (SI && SI->getPointerOperandIndex() != U.getOperandNo())
            continue;
          if (!Visited.count(K))
            WorkList.push_back(K);
        }
      }
  }

  return true;
}

// Shared driver for the pass: walk every cached assumption in F and process
// each of its operand bundles. Returns true if any alignment was changed.
bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC,
                                           ScalarEvolution *SE_,
                                           DominatorTree *DT_) {
  // Stash the analyses in members so the helper methods can reach them.
  SE = SE_;
  DT = DT_;

  bool Changed = false;
  for (auto &AssumeVH : AC.assumptions())
    if (AssumeVH) {
      CallInst *Call = cast<CallInst>(AssumeVH);
      for (unsigned Idx = 0; Idx < Call->getNumOperandBundles(); Idx++)
        Changed |= processAssumption(Call, Idx);
    }

  return Changed;
}

// New pass manager entry point: fetch the required analyses, run the
// transformation, and report which analyses survive (CFG is never modified,
// and ScalarEvolution remains valid since only alignments change).
PreservedAnalyses
AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) {

  AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F);
  ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
  DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, AC, &SE, &DT))
    return PreservedAnalyses::all();

  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  PA.preserve<ScalarEvolutionAnalysis>();
  return PA;
}