Path: blob/main/contrib/llvm-project/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
35271 views
//===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This file contains a printer that converts from our internal representation9// of machine-dependent LLVM code to NVPTX assembly language.10//11//===----------------------------------------------------------------------===//1213#include "NVPTXAsmPrinter.h"14#include "MCTargetDesc/NVPTXBaseInfo.h"15#include "MCTargetDesc/NVPTXInstPrinter.h"16#include "MCTargetDesc/NVPTXMCAsmInfo.h"17#include "MCTargetDesc/NVPTXTargetStreamer.h"18#include "NVPTX.h"19#include "NVPTXMCExpr.h"20#include "NVPTXMachineFunctionInfo.h"21#include "NVPTXRegisterInfo.h"22#include "NVPTXSubtarget.h"23#include "NVPTXTargetMachine.h"24#include "NVPTXUtilities.h"25#include "TargetInfo/NVPTXTargetInfo.h"26#include "cl_common_defines.h"27#include "llvm/ADT/APFloat.h"28#include "llvm/ADT/APInt.h"29#include "llvm/ADT/DenseMap.h"30#include "llvm/ADT/DenseSet.h"31#include "llvm/ADT/SmallString.h"32#include "llvm/ADT/SmallVector.h"33#include "llvm/ADT/StringExtras.h"34#include "llvm/ADT/StringRef.h"35#include "llvm/ADT/Twine.h"36#include "llvm/Analysis/ConstantFolding.h"37#include "llvm/CodeGen/Analysis.h"38#include "llvm/CodeGen/MachineBasicBlock.h"39#include "llvm/CodeGen/MachineFrameInfo.h"40#include "llvm/CodeGen/MachineFunction.h"41#include "llvm/CodeGen/MachineInstr.h"42#include "llvm/CodeGen/MachineLoopInfo.h"43#include "llvm/CodeGen/MachineModuleInfo.h"44#include "llvm/CodeGen/MachineOperand.h"45#include "llvm/CodeGen/MachineRegisterInfo.h"46#include "llvm/CodeGen/TargetRegisterInfo.h"47#include "llvm/CodeGen/ValueTypes.h"48#include "llvm/CodeGenTypes/MachineValueType.h"49#include "llvm/IR/Attributes.h"50#include "llvm/IR/BasicBlock.h"51#include "llvm/IR/Constant.h"52#include "llvm/IR/Constants.h"53#include "llvm/IR/DataLayout.h"54#include "llvm/IR/DebugInfo.h"55#include "llvm/IR/DebugInfoMetadata.h"56#include "llvm/IR/DebugLoc.h"57#include "llvm/IR/DerivedTypes.h"58#include "llvm/IR/Function.h"59#include "llvm/IR/GlobalAlias.h"60#include "llvm/IR/GlobalValue.h"61#include "llvm/IR/GlobalVariable.h"62#include "llvm/IR/Instruction.h"63#include "llvm/IR/LLVMContext.h"64#include "llvm/IR/Module.h"65#include "llvm/IR/Operator.h"66#include "llvm/IR/Type.h"67#include "llvm/IR/User.h"68#include "llvm/MC/MCExpr.h"69#include "llvm/MC/MCInst.h"70#include "llvm/MC/MCInstrDesc.h"71#include "llvm/MC/MCStreamer.h"72#include "llvm/MC/MCSymbol.h"73#include "llvm/MC/TargetRegistry.h"74#include "llvm/Support/Alignment.h"75#include "llvm/Support/Casting.h"76#include "llvm/Support/CommandLine.h"77#include "llvm/Support/Endian.h"78#include "llvm/Support/ErrorHandling.h"79#include "llvm/Support/NativeFormatting.h"80#include "llvm/Support/Path.h"81#include "llvm/Support/raw_ostream.h"82#include "llvm/Target/TargetLoweringObjectFile.h"83#include "llvm/Target/TargetMachine.h"84#include "llvm/TargetParser/Triple.h"85#include "llvm/Transforms/Utils/UnrollLoop.h"86#include <cassert>87#include <cstdint>88#include <cstring>89#include <new>90#include <string>91#include <utility>92#include <vector>9394using namespace llvm;9596static cl::opt<bool>97LowerCtorDtor("nvptx-lower-global-ctor-dtor",98cl::desc("Lower GPU ctor / dtors to globals on the device."),99cl::init(false), cl::Hidden);100101#define DEPOTNAME "__local_depot"102103/// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V104/// depends.105static void106DiscoverDependentGlobals(const Value *V,107DenseSet<const GlobalVariable *> &Globals) {108if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))109Globals.insert(GV);110else {111if (const User *U = dyn_cast<User>(V)) {112for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {113DiscoverDependentGlobals(U->getOperand(i), Globals);114}115}116}117}118119/// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable120/// instances to be emitted, but only after any dependents have been added121/// first.s122static void123VisitGlobalVariableForEmission(const GlobalVariable *GV,124SmallVectorImpl<const GlobalVariable *> &Order,125DenseSet<const GlobalVariable *> &Visited,126DenseSet<const GlobalVariable *> &Visiting) {127// Have we already visited this one?128if (Visited.count(GV))129return;130131// Do we have a circular dependency?132if (!Visiting.insert(GV).second)133report_fatal_error("Circular dependency found in global variable set");134135// Make sure we visit all dependents first136DenseSet<const GlobalVariable *> Others;137for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)138DiscoverDependentGlobals(GV->getOperand(i), Others);139140for (const GlobalVariable *GV : Others)141VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);142143// Now we can visit ourself144Order.push_back(GV);145Visited.insert(GV);146Visiting.erase(GV);147}148149void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {150NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),151getSubtargetInfo().getFeatureBits());152153MCInst Inst;154lowerToMCInst(MI, Inst);155EmitToStreamer(*OutStreamer, Inst);156}157158// Handle symbol backtracking for targets that do not support image handles159bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,160unsigned OpNo, MCOperand &MCOp) {161const MachineOperand &MO = MI->getOperand(OpNo);162const MCInstrDesc &MCID = MI->getDesc();163164if (MCID.TSFlags & NVPTXII::IsTexFlag) {165// This is a texture fetch, so operand 4 is a texref and operand 5 is166// a samplerref167if (OpNo == 4 && MO.isImm()) {168lowerImageHandleSymbol(MO.getImm(), MCOp);169return true;170}171if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {172lowerImageHandleSymbol(MO.getImm(), MCOp);173return true;174}175176return false;177} else if (MCID.TSFlags & NVPTXII::IsSuldMask) {178unsigned VecSize =1791 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);180181// For a surface load of vector size N, the Nth operand will be the surfref182if (OpNo == VecSize && MO.isImm()) {183lowerImageHandleSymbol(MO.getImm(), MCOp);184return true;185}186187return false;188} else if (MCID.TSFlags & NVPTXII::IsSustFlag) {189// This is a surface store, so operand 0 is a surfref190if (OpNo == 0 && MO.isImm()) {191lowerImageHandleSymbol(MO.getImm(), MCOp);192return true;193}194195return false;196} else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {197// This is a query, so operand 1 is a surfref/texref198if (OpNo == 1 && MO.isImm()) {199lowerImageHandleSymbol(MO.getImm(), MCOp);200return true;201}202203return false;204}205206return false;207}208209void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {210// Ewwww211LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());212NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);213const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();214const char *Sym = MFI->getImageHandleSymbol(Index);215StringRef SymName = nvTM.getStrPool().save(Sym);216MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));217}218219void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {220OutMI.setOpcode(MI->getOpcode());221// Special: Do not mangle symbol operand of CALL_PROTOTYPE222if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {223const MachineOperand &MO = MI->getOperand(0);224OutMI.addOperand(GetSymbolRef(225OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));226return;227}228229const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();230for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {231const MachineOperand &MO = MI->getOperand(i);232233MCOperand MCOp;234if (!STI.hasImageHandles()) {235if (lowerImageHandleOperand(MI, i, MCOp)) {236OutMI.addOperand(MCOp);237continue;238}239}240241if (lowerOperand(MO, MCOp))242OutMI.addOperand(MCOp);243}244}245246bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,247MCOperand &MCOp) {248switch (MO.getType()) {249default: llvm_unreachable("unknown operand type");250case MachineOperand::MO_Register:251MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));252break;253case MachineOperand::MO_Immediate:254MCOp = MCOperand::createImm(MO.getImm());255break;256case MachineOperand::MO_MachineBasicBlock:257MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(258MO.getMBB()->getSymbol(), OutContext));259break;260case MachineOperand::MO_ExternalSymbol:261MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));262break;263case MachineOperand::MO_GlobalAddress:264MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));265break;266case MachineOperand::MO_FPImmediate: {267const ConstantFP *Cnt = MO.getFPImm();268const APFloat &Val = Cnt->getValueAPF();269270switch (Cnt->getType()->getTypeID()) {271default: report_fatal_error("Unsupported FP type"); break;272case Type::HalfTyID:273MCOp = MCOperand::createExpr(274NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));275break;276case Type::BFloatTyID:277MCOp = MCOperand::createExpr(278NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));279break;280case Type::FloatTyID:281MCOp = MCOperand::createExpr(282NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));283break;284case Type::DoubleTyID:285MCOp = MCOperand::createExpr(286NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));287break;288}289break;290}291}292return true;293}294295unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {296if (Register::isVirtualRegister(Reg)) {297const TargetRegisterClass *RC = MRI->getRegClass(Reg);298299DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];300unsigned RegNum = RegMap[Reg];301302// Encode the register class in the upper 4 bits303// Must be kept in sync with NVPTXInstPrinter::printRegName304unsigned Ret = 0;305if (RC == &NVPTX::Int1RegsRegClass) {306Ret = (1 << 28);307} else if (RC == &NVPTX::Int16RegsRegClass) {308Ret = (2 << 28);309} else if (RC == &NVPTX::Int32RegsRegClass) {310Ret = (3 << 28);311} else if (RC == &NVPTX::Int64RegsRegClass) {312Ret = (4 << 28);313} else if (RC == &NVPTX::Float32RegsRegClass) {314Ret = (5 << 28);315} else if (RC == &NVPTX::Float64RegsRegClass) {316Ret = (6 << 28);317} else if (RC == &NVPTX::Int128RegsRegClass) {318Ret = (7 << 28);319} else {320report_fatal_error("Bad register class");321}322323// Insert the vreg number324Ret |= (RegNum & 0x0FFFFFFF);325return Ret;326} else {327// Some special-use registers are actually physical registers.328// Encode this as the register class ID of 0 and the real register ID.329return Reg & 0x0FFFFFFF;330}331}332333MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {334const MCExpr *Expr;335Expr = MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None,336OutContext);337return MCOperand::createExpr(Expr);338}339340static bool ShouldPassAsArray(Type *Ty) {341return Ty->isAggregateType() || Ty->isVectorTy() || Ty->isIntegerTy(128) ||342Ty->isHalfTy() || Ty->isBFloatTy();343}344345void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {346const DataLayout &DL = getDataLayout();347const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);348const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());349350Type *Ty = F->getReturnType();351352bool isABI = (STI.getSmVersion() >= 20);353354if (Ty->getTypeID() == Type::VoidTyID)355return;356O << " (";357358if (isABI) {359if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&360!ShouldPassAsArray(Ty)) {361unsigned size = 0;362if (auto *ITy = dyn_cast<IntegerType>(Ty)) {363size = ITy->getBitWidth();364} else {365assert(Ty->isFloatingPointTy() && "Floating point type expected here");366size = Ty->getPrimitiveSizeInBits();367}368size = promoteScalarArgumentSize(size);369O << ".param .b" << size << " func_retval0";370} else if (isa<PointerType>(Ty)) {371O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()372<< " func_retval0";373} else if (ShouldPassAsArray(Ty)) {374unsigned totalsz = DL.getTypeAllocSize(Ty);375Align RetAlignment = TLI->getFunctionArgumentAlignment(376F, Ty, AttributeList::ReturnIndex, DL);377O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["378<< totalsz << "]";379} else380llvm_unreachable("Unknown return type");381} else {382SmallVector<EVT, 16> vtparts;383ComputeValueVTs(*TLI, DL, Ty, vtparts);384unsigned idx = 0;385for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {386unsigned elems = 1;387EVT elemtype = vtparts[i];388if (vtparts[i].isVector()) {389elems = vtparts[i].getVectorNumElements();390elemtype = vtparts[i].getVectorElementType();391}392393for (unsigned j = 0, je = elems; j != je; ++j) {394unsigned sz = elemtype.getSizeInBits();395if (elemtype.isInteger())396sz = promoteScalarArgumentSize(sz);397O << ".reg .b" << sz << " func_retval" << idx;398if (j < je - 1)399O << ", ";400++idx;401}402if (i < e - 1)403O << ", ";404}405}406O << ") ";407}408409void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,410raw_ostream &O) {411const Function &F = MF.getFunction();412printReturnValStr(&F, O);413}414415// Return true if MBB is the header of a loop marked with416// llvm.loop.unroll.disable or llvm.loop.unroll.count=1.417bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(418const MachineBasicBlock &MBB) const {419MachineLoopInfo &LI = getAnalysis<MachineLoopInfoWrapperPass>().getLI();420// We insert .pragma "nounroll" only to the loop header.421if (!LI.isLoopHeader(&MBB))422return false;423424// llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,425// we iterate through each back edge of the loop with header MBB, and check426// whether its metadata contains llvm.loop.unroll.disable.427for (const MachineBasicBlock *PMBB : MBB.predecessors()) {428if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {429// Edges from other loops to MBB are not back edges.430continue;431}432if (const BasicBlock *PBB = PMBB->getBasicBlock()) {433if (MDNode *LoopID =434PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {435if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))436return true;437if (MDNode *UnrollCountMD =438GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {439if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1))440->isOne())441return true;442}443}444}445}446return false;447}448449void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {450AsmPrinter::emitBasicBlockStart(MBB);451if (isLoopHeaderOfNoUnroll(MBB))452OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));453}454455void NVPTXAsmPrinter::emitFunctionEntryLabel() {456SmallString<128> Str;457raw_svector_ostream O(Str);458459if (!GlobalsEmitted) {460emitGlobals(*MF->getFunction().getParent());461GlobalsEmitted = true;462}463464// Set up465MRI = &MF->getRegInfo();466F = &MF->getFunction();467emitLinkageDirective(F, O);468if (isKernelFunction(*F))469O << ".entry ";470else {471O << ".func ";472printReturnValStr(*MF, O);473}474475CurrentFnSym->print(O, MAI);476477emitFunctionParamList(F, O);478O << "\n";479480if (isKernelFunction(*F))481emitKernelFunctionDirectives(*F, O);482483if (shouldEmitPTXNoReturn(F, TM))484O << ".noreturn";485486OutStreamer->emitRawText(O.str());487488VRegMapping.clear();489// Emit open brace for function body.490OutStreamer->emitRawText(StringRef("{\n"));491setAndEmitFunctionVirtualRegisters(*MF);492// Emit initial .loc debug directive for correct relocation symbol data.493if (const DISubprogram *SP = MF->getFunction().getSubprogram()) {494assert(SP->getUnit());495if (!SP->getUnit()->isDebugDirectivesOnly() && MMI && MMI->hasDebugInfo())496emitInitialRawDwarfLocDirective(*MF);497}498}499500bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {501bool Result = AsmPrinter::runOnMachineFunction(F);502// Emit closing brace for the body of function F.503// The closing brace must be emitted here because we need to emit additional504// debug labels/data after the last basic block.505// We need to emit the closing brace here because we don't have function that506// finished emission of the function body.507OutStreamer->emitRawText(StringRef("}\n"));508return Result;509}510511void NVPTXAsmPrinter::emitFunctionBodyStart() {512SmallString<128> Str;513raw_svector_ostream O(Str);514emitDemotedVars(&MF->getFunction(), O);515OutStreamer->emitRawText(O.str());516}517518void NVPTXAsmPrinter::emitFunctionBodyEnd() {519VRegMapping.clear();520}521522const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {523SmallString<128> Str;524raw_svector_ostream(Str) << DEPOTNAME << getFunctionNumber();525return OutContext.getOrCreateSymbol(Str);526}527528void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {529Register RegNo = MI->getOperand(0).getReg();530if (RegNo.isVirtual()) {531OutStreamer->AddComment(Twine("implicit-def: ") +532getVirtualRegisterName(RegNo));533} else {534const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();535OutStreamer->AddComment(Twine("implicit-def: ") +536STI.getRegisterInfo()->getName(RegNo));537}538OutStreamer->addBlankLine();539}540541void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,542raw_ostream &O) const {543// If the NVVM IR has some of reqntid* specified, then output544// the reqntid directive, and set the unspecified ones to 1.545// If none of Reqntid* is specified, don't output reqntid directive.546std::optional<unsigned> Reqntidx = getReqNTIDx(F);547std::optional<unsigned> Reqntidy = getReqNTIDy(F);548std::optional<unsigned> Reqntidz = getReqNTIDz(F);549550if (Reqntidx || Reqntidy || Reqntidz)551O << ".reqntid " << Reqntidx.value_or(1) << ", " << Reqntidy.value_or(1)552<< ", " << Reqntidz.value_or(1) << "\n";553554// If the NVVM IR has some of maxntid* specified, then output555// the maxntid directive, and set the unspecified ones to 1.556// If none of maxntid* is specified, don't output maxntid directive.557std::optional<unsigned> Maxntidx = getMaxNTIDx(F);558std::optional<unsigned> Maxntidy = getMaxNTIDy(F);559std::optional<unsigned> Maxntidz = getMaxNTIDz(F);560561if (Maxntidx || Maxntidy || Maxntidz)562O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)563<< ", " << Maxntidz.value_or(1) << "\n";564565unsigned Mincta = 0;566if (getMinCTASm(F, Mincta))567O << ".minnctapersm " << Mincta << "\n";568569unsigned Maxnreg = 0;570if (getMaxNReg(F, Maxnreg))571O << ".maxnreg " << Maxnreg << "\n";572573// .maxclusterrank directive requires SM_90 or higher, make sure that we574// filter it out for lower SM versions, as it causes a hard ptxas crash.575const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);576const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());577unsigned Maxclusterrank = 0;578if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)579O << ".maxclusterrank " << Maxclusterrank << "\n";580}581582std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {583const TargetRegisterClass *RC = MRI->getRegClass(Reg);584585std::string Name;586raw_string_ostream NameStr(Name);587588VRegRCMap::const_iterator I = VRegMapping.find(RC);589assert(I != VRegMapping.end() && "Bad register class");590const DenseMap<unsigned, unsigned> &RegMap = I->second;591592VRegMap::const_iterator VI = RegMap.find(Reg);593assert(VI != RegMap.end() && "Bad virtual register");594unsigned MappedVR = VI->second;595596NameStr << getNVPTXRegClassStr(RC) << MappedVR;597598NameStr.flush();599return Name;600}601602void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,603raw_ostream &O) {604O << getVirtualRegisterName(vr);605}606607void NVPTXAsmPrinter::emitAliasDeclaration(const GlobalAlias *GA,608raw_ostream &O) {609const Function *F = dyn_cast_or_null<Function>(GA->getAliaseeObject());610if (!F || isKernelFunction(*F) || F->isDeclaration())611report_fatal_error(612"NVPTX aliasee must be a non-kernel function definition");613614if (GA->hasLinkOnceLinkage() || GA->hasWeakLinkage() ||615GA->hasAvailableExternallyLinkage() || GA->hasCommonLinkage())616report_fatal_error("NVPTX aliasee must not be '.weak'");617618emitDeclarationWithName(F, getSymbol(GA), O);619}620621void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {622emitDeclarationWithName(F, getSymbol(F), O);623}624625void NVPTXAsmPrinter::emitDeclarationWithName(const Function *F, MCSymbol *S,626raw_ostream &O) {627emitLinkageDirective(F, O);628if (isKernelFunction(*F))629O << ".entry ";630else631O << ".func ";632printReturnValStr(F, O);633S->print(O, MAI);634O << "\n";635emitFunctionParamList(F, O);636O << "\n";637if (shouldEmitPTXNoReturn(F, TM))638O << ".noreturn";639O << ";\n";640}641642static bool usedInGlobalVarDef(const Constant *C) {643if (!C)644return false;645646if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {647return GV->getName() != "llvm.used";648}649650for (const User *U : C->users())651if (const Constant *C = dyn_cast<Constant>(U))652if (usedInGlobalVarDef(C))653return true;654655return false;656}657658static bool usedInOneFunc(const User *U, Function const *&oneFunc) {659if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {660if (othergv->getName() == "llvm.used")661return true;662}663664if (const Instruction *instr = dyn_cast<Instruction>(U)) {665if (instr->getParent() && instr->getParent()->getParent()) {666const Function *curFunc = instr->getParent()->getParent();667if (oneFunc && (curFunc != oneFunc))668return false;669oneFunc = curFunc;670return true;671} else672return false;673}674675for (const User *UU : U->users())676if (!usedInOneFunc(UU, oneFunc))677return false;678679return true;680}681682/* Find out if a global variable can be demoted to local scope.683* Currently, this is valid for CUDA shared variables, which have local684* scope and global lifetime. So the conditions to check are :685* 1. Is the global variable in shared address space?686* 2. Does it have local linkage?687* 3. Is the global variable referenced only in one function?688*/689static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {690if (!gv->hasLocalLinkage())691return false;692PointerType *Pty = gv->getType();693if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)694return false;695696const Function *oneFunc = nullptr;697698bool flag = usedInOneFunc(gv, oneFunc);699if (!flag)700return false;701if (!oneFunc)702return false;703f = oneFunc;704return true;705}706707static bool useFuncSeen(const Constant *C,708DenseMap<const Function *, bool> &seenMap) {709for (const User *U : C->users()) {710if (const Constant *cu = dyn_cast<Constant>(U)) {711if (useFuncSeen(cu, seenMap))712return true;713} else if (const Instruction *I = dyn_cast<Instruction>(U)) {714const BasicBlock *bb = I->getParent();715if (!bb)716continue;717const Function *caller = bb->getParent();718if (!caller)719continue;720if (seenMap.contains(caller))721return true;722}723}724return false;725}726727void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {728DenseMap<const Function *, bool> seenMap;729for (const Function &F : M) {730if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {731emitDeclaration(&F, O);732continue;733}734735if (F.isDeclaration()) {736if (F.use_empty())737continue;738if (F.getIntrinsicID())739continue;740emitDeclaration(&F, O);741continue;742}743for (const User *U : F.users()) {744if (const Constant *C = dyn_cast<Constant>(U)) {745if (usedInGlobalVarDef(C)) {746// The use is in the initialization of a global variable747// that is a function pointer, so print a declaration748// for the original function749emitDeclaration(&F, O);750break;751}752// Emit a declaration of this function if the function that753// uses this constant expr has already been seen.754if (useFuncSeen(C, seenMap)) {755emitDeclaration(&F, O);756break;757}758}759760if (!isa<Instruction>(U))761continue;762const Instruction *instr = cast<Instruction>(U);763const BasicBlock *bb = instr->getParent();764if (!bb)765continue;766const Function *caller = bb->getParent();767if (!caller)768continue;769770// If a caller has already been seen, then the caller is771// appearing in the module before the callee. so print out772// a declaration for the callee.773if (seenMap.contains(caller)) {774emitDeclaration(&F, O);775break;776}777}778seenMap[&F] = true;779}780for (const GlobalAlias &GA : M.aliases())781emitAliasDeclaration(&GA, O);782}783784static bool isEmptyXXStructor(GlobalVariable *GV) {785if (!GV) return true;786const ConstantArray *InitList = dyn_cast<ConstantArray>(GV->getInitializer());787if (!InitList) return true; // Not an array; we don't know how to parse.788return InitList->getNumOperands() == 0;789}790791void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {792// Construct a default subtarget off of the TargetMachine defaults. The793// rest of NVPTX isn't friendly to change subtargets per function and794// so the default TargetMachine will have all of the options.795const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);796const auto* STI = static_cast<const NVPTXSubtarget*>(NTM.getSubtargetImpl());797SmallString<128> Str1;798raw_svector_ostream OS1(Str1);799800// Emit header before any dwarf directives are emitted below.801emitHeader(M, OS1, *STI);802OutStreamer->emitRawText(OS1.str());803}804805bool NVPTXAsmPrinter::doInitialization(Module &M) {806const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);807const NVPTXSubtarget &STI =808*static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());809if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))810report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");811812// OpenMP supports NVPTX global constructors and destructors.813bool IsOpenMP = M.getModuleFlag("openmp") != nullptr;814815if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors")) &&816!LowerCtorDtor && !IsOpenMP) {817report_fatal_error(818"Module has a nontrivial global ctor, which NVPTX does not support.");819return true; // error820}821if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors")) &&822!LowerCtorDtor && !IsOpenMP) {823report_fatal_error(824"Module has a nontrivial global dtor, which NVPTX does not support.");825return true; // error826}827828// We need to call the parent's one explicitly.829bool Result = AsmPrinter::doInitialization(M);830831GlobalsEmitted = false;832833return Result;834}835836void NVPTXAsmPrinter::emitGlobals(const Module &M) {837SmallString<128> Str2;838raw_svector_ostream OS2(Str2);839840emitDeclarations(M, OS2);841842// As ptxas does not support forward references of globals, we need to first843// sort the list of module-level globals in def-use order. We visit each844// global variable in order, and ensure that we emit it *after* its dependent845// globals. We use a little extra memory maintaining both a set and a list to846// have fast searches while maintaining a strict ordering.847SmallVector<const GlobalVariable *, 8> Globals;848DenseSet<const GlobalVariable *> GVVisited;849DenseSet<const GlobalVariable *> GVVisiting;850851// Visit each global variable, in order852for (const GlobalVariable &I : M.globals())853VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);854855assert(GVVisited.size() == M.global_size() && "Missed a global variable");856assert(GVVisiting.size() == 0 && "Did not fully process a global variable");857858const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);859const NVPTXSubtarget &STI =860*static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());861862// Print out module-level global variables in proper order863for (const GlobalVariable *GV : Globals)864printModuleLevelGV(GV, OS2, /*processDemoted=*/false, STI);865866OS2 << '\n';867868OutStreamer->emitRawText(OS2.str());869}870871void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {872SmallString<128> Str;873raw_svector_ostream OS(Str);874875MCSymbol *Name = getSymbol(&GA);876877OS << ".alias " << Name->getName() << ", " << GA.getAliaseeObject()->getName()878<< ";\n";879880OutStreamer->emitRawText(OS.str());881}882883void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,884const NVPTXSubtarget &STI) {885O << "//\n";886O << "// Generated by LLVM NVPTX Back-End\n";887O << "//\n";888O << "\n";889890unsigned PTXVersion = STI.getPTXVersion();891O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";892893O << ".target ";894O << STI.getTargetName();895896const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);897if (NTM.getDrvInterface() == NVPTX::NVCL)898O << ", texmode_independent";899900bool HasFullDebugInfo = false;901for (DICompileUnit *CU : M.debug_compile_units()) {902switch(CU->getEmissionKind()) {903case DICompileUnit::NoDebug:904case DICompileUnit::DebugDirectivesOnly:905break;906case DICompileUnit::LineTablesOnly:907case DICompileUnit::FullDebug:908HasFullDebugInfo = true;909break;910}911if (HasFullDebugInfo)912break;913}914if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)915O << ", debug";916917O << "\n";918919O << ".address_size ";920if (NTM.is64Bit())921O << "64";922else923O << "32";924O << "\n";925926O << "\n";927}928929bool NVPTXAsmPrinter::doFinalization(Module &M) {930bool HasDebugInfo = MMI && MMI->hasDebugInfo();931932// If we did not emit any functions, then the global declarations have not933// yet been emitted.934if (!GlobalsEmitted) {935emitGlobals(M);936GlobalsEmitted = true;937}938939// call doFinalization940bool ret = AsmPrinter::doFinalization(M);941942clearAnnotationCache(&M);943944auto *TS =945static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());946// Close the last emitted section947if (HasDebugInfo) {948TS->closeLastSection();949// Emit empty .debug_loc section for better support of the empty files.950OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");951}952953// Output last DWARF .file directives, if any.954TS->outputDwarfFileDirectives();955956return ret;957}958959// This function emits appropriate linkage directives for960// functions and global variables.961//962// extern function declaration -> .extern963// extern function definition -> .visible964// external global variable with init -> .visible965// external without init -> .extern966// appending -> not allowed, assert.967// for any linkage other than968// internal, private, linker_private,969// linker_private_weak, linker_private_weak_def_auto,970// we emit -> .weak.971972void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,973raw_ostream &O) {974if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {975if (V->hasExternalLinkage()) {976if (isa<GlobalVariable>(V)) {977const GlobalVariable *GVar = cast<GlobalVariable>(V);978if (GVar) {979if (GVar->hasInitializer())980O << ".visible ";981else982O << ".extern ";983}984} else if (V->isDeclaration())985O << ".extern ";986else987O << ".visible ";988} else if (V->hasAppendingLinkage()) {989std::string msg;990msg.append("Error: ");991msg.append("Symbol ");992if (V->hasName())993msg.append(std::string(V->getName()));994msg.append("has unsupported appending linkage type");995llvm_unreachable(msg.c_str());996} else if (!V->hasInternalLinkage() &&997!V->hasPrivateLinkage()) {998O << ".weak ";999}1000}1001}10021003void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,1004raw_ostream &O, bool processDemoted,1005const NVPTXSubtarget &STI) {1006// Skip meta data1007if (GVar->hasSection()) {1008if (GVar->getSection() == "llvm.metadata")1009return;1010}10111012// Skip LLVM intrinsic global variables1013if (GVar->getName().starts_with("llvm.") ||1014GVar->getName().starts_with("nvvm."))1015return;10161017const DataLayout &DL = getDataLayout();10181019// GlobalVariables are always constant pointers themselves.1020Type *ETy = GVar->getValueType();10211022if (GVar->hasExternalLinkage()) {1023if (GVar->hasInitializer())1024O << ".visible ";1025else1026O << ".extern ";1027} else if (STI.getPTXVersion() >= 50 && GVar->hasCommonLinkage() &&1028GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) {1029O << ".common ";1030} else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||1031GVar->hasAvailableExternallyLinkage() ||1032GVar->hasCommonLinkage()) {1033O << ".weak ";1034}10351036if (isTexture(*GVar)) {1037O << ".global .texref " << getTextureName(*GVar) << ";\n";1038return;1039}10401041if (isSurface(*GVar)) {1042O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";1043return;1044}10451046if (GVar->isDeclaration()) {1047// (extern) declarations, no definition or initializer1048// Currently the only known declaration is for an automatic __local1049// (.shared) promoted to global.1050emitPTXGlobalVariable(GVar, O, STI);1051O << ";\n";1052return;1053}10541055if (isSampler(*GVar)) {1056O << ".global .samplerref " << getSamplerName(*GVar);10571058const Constant *Initializer = nullptr;1059if (GVar->hasInitializer())1060Initializer = GVar->getInitializer();1061const ConstantInt *CI = nullptr;1062if (Initializer)1063CI = dyn_cast<ConstantInt>(Initializer);1064if (CI) {1065unsigned sample = CI->getZExtValue();10661067O << " = { ";10681069for (int i = 0,1070addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);1071i < 3; i++) {1072O << "addr_mode_" << i << " = ";1073switch (addr) {1074case 0:1075O << "wrap";1076break;1077case 1:1078O << "clamp_to_border";1079break;1080case 2:1081O << "clamp_to_edge";1082break;1083case 3:1084O << "wrap";1085break;1086case 4:1087O << "mirror";1088break;1089}1090O << ", ";1091}1092O << "filter_mode = ";1093switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {1094case 0:1095O << "nearest";1096break;1097case 1:1098O << "linear";1099break;1100case 2:1101llvm_unreachable("Anisotropic filtering is not supported");1102default:1103O << "nearest";1104break;1105}1106if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {1107O << ", force_unnormalized_coords = 1";1108}1109O << " }";1110}11111112O << ";\n";1113return;1114}11151116if (GVar->hasPrivateLinkage()) {1117if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)1118return;11191120// FIXME - need better way (e.g. Metadata) to avoid generating this global1121if (strncmp(GVar->getName().data(), "filename", 8) == 0)1122return;1123if (GVar->use_empty())1124return;1125}11261127const Function *demotedFunc = nullptr;1128if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {1129O << "// " << GVar->getName() << " has been demoted\n";1130if (localDecls.find(demotedFunc) != localDecls.end())1131localDecls[demotedFunc].push_back(GVar);1132else {1133std::vector<const GlobalVariable *> temp;1134temp.push_back(GVar);1135localDecls[demotedFunc] = temp;1136}1137return;1138}11391140O << ".";1141emitPTXAddressSpace(GVar->getAddressSpace(), O);11421143if (isManaged(*GVar)) {1144if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {1145report_fatal_error(1146".attribute(.managed) requires PTX version >= 4.0 and sm_30");1147}1148O << " .attribute(.managed)";1149}11501151if (MaybeAlign A = GVar->getAlign())1152O << " .align " << A->value();1153else1154O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();11551156if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||1157(ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {1158O << " .";1159// Special case: ABI requires that we use .u8 for predicates1160if (ETy->isIntegerTy(1))1161O << "u8";1162else1163O << getPTXFundamentalTypeStr(ETy, false);1164O << " ";1165getSymbol(GVar)->print(O, MAI);11661167// Ptx allows variable initilization only for constant and global state1168// spaces.1169if (GVar->hasInitializer()) {1170if ((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||1171(GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) {1172const Constant *Initializer = GVar->getInitializer();1173// 'undef' is treated as there is no value specified.1174if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {1175O << " = ";1176printScalarConstant(Initializer, O);1177}1178} else {1179// The frontend adds zero-initializer to device and constant variables1180// that don't have an initial value, and UndefValue to shared1181// variables, so skip warning for this case.1182if (!GVar->getInitializer()->isNullValue() &&1183!isa<UndefValue>(GVar->getInitializer())) {1184report_fatal_error("initial value of '" + GVar->getName() +1185"' is not allowed in addrspace(" +1186Twine(GVar->getAddressSpace()) + ")");1187}1188}1189}1190} else {1191uint64_t ElementSize = 0;11921193// Although PTX has direct support for struct type and array type and1194// LLVM IR is very similar to PTX, the LLVM CodeGen does not support for1195// targets that support these high level field accesses. Structs, arrays1196// and vectors are lowered into arrays of bytes.1197switch (ETy->getTypeID()) {1198case Type::IntegerTyID: // Integers larger than 64 bits1199case Type::StructTyID:1200case Type::ArrayTyID:1201case Type::FixedVectorTyID:1202ElementSize = DL.getTypeStoreSize(ETy);1203// Ptx allows variable initilization only for constant and1204// global state spaces.1205if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||1206(GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) &&1207GVar->hasInitializer()) {1208const Constant *Initializer = GVar->getInitializer();1209if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {1210AggBuffer aggBuffer(ElementSize, *this);1211bufferAggregateConstant(Initializer, &aggBuffer);1212if (aggBuffer.numSymbols()) {1213unsigned int ptrSize = MAI->getCodePointerSize();1214if (ElementSize % ptrSize ||1215!aggBuffer.allSymbolsAligned(ptrSize)) {1216// Print in bytes and use the mask() operator for pointers.1217if (!STI.hasMaskOperator())1218report_fatal_error(1219"initialized packed aggregate with pointers '" +1220GVar->getName() +1221"' requires at least PTX ISA version 7.1");1222O << " .u8 ";1223getSymbol(GVar)->print(O, MAI);1224O << "[" << ElementSize << "] = {";1225aggBuffer.printBytes(O);1226O << "}";1227} else {1228O << " .u" << ptrSize * 8 << " ";1229getSymbol(GVar)->print(O, MAI);1230O << "[" << ElementSize / ptrSize << "] = {";1231aggBuffer.printWords(O);1232O << "}";1233}1234} else {1235O << " .b8 ";1236getSymbol(GVar)->print(O, MAI);1237O << "[" << ElementSize << "] = {";1238aggBuffer.printBytes(O);1239O << "}";1240}1241} else {1242O << " .b8 ";1243getSymbol(GVar)->print(O, MAI);1244if (ElementSize) {1245O << "[";1246O << ElementSize;1247O << "]";1248}1249}1250} else {1251O << " .b8 ";1252getSymbol(GVar)->print(O, MAI);1253if (ElementSize) {1254O << "[";1255O << ElementSize;1256O << "]";1257}1258}1259break;1260default:1261llvm_unreachable("type not supported yet");1262}1263}1264O << ";\n";1265}12661267void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {1268const Value *v = Symbols[nSym];1269const Value *v0 = SymbolsBeforeStripping[nSym];1270if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {1271MCSymbol *Name = AP.getSymbol(GVar);1272PointerType *PTy = dyn_cast<PointerType>(v0->getType());1273// Is v0 a generic pointer?1274bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;1275if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {1276os << "generic(";1277Name->print(os, AP.MAI);1278os << ")";1279} else {1280Name->print(os, AP.MAI);1281}1282} else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {1283const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);1284AP.printMCExpr(*Expr, os);1285} else1286llvm_unreachable("symbol type unknown");1287}12881289void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {1290unsigned int ptrSize = AP.MAI->getCodePointerSize();1291// Do not emit trailing zero initializers. They will be zero-initialized by1292// ptxas. This saves on both space requirements for the generated PTX and on1293// memory use by ptxas. (See:1294// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#global-state-space)1295unsigned int InitializerCount = size;1296// TODO: symbols make this harder, but it would still be good to trim trailing1297// 0s for aggs with symbols as well.1298if (numSymbols() == 0)1299while (InitializerCount >= 1 && !buffer[InitializerCount - 1])1300InitializerCount--;13011302symbolPosInBuffer.push_back(InitializerCount);1303unsigned int nSym = 0;1304unsigned int nextSymbolPos = symbolPosInBuffer[nSym];1305for (unsigned int pos = 0; pos < InitializerCount;) {1306if (pos)1307os << ", ";1308if (pos != nextSymbolPos) {1309os << (unsigned int)buffer[pos];1310++pos;1311continue;1312}1313// Generate a per-byte mask() operator for the symbol, which looks like:1314// .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};1315// See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers1316std::string symText;1317llvm::raw_string_ostream oss(symText);1318printSymbol(nSym, oss);1319for (unsigned i = 0; i < ptrSize; ++i) {1320if (i)1321os << ", ";1322llvm::write_hex(os, 0xFFULL << i * 8, HexPrintStyle::PrefixUpper);1323os << "(" << symText << ")";1324}1325pos += ptrSize;1326nextSymbolPos = symbolPosInBuffer[++nSym];1327assert(nextSymbolPos >= pos);1328}1329}13301331void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {1332unsigned int ptrSize = AP.MAI->getCodePointerSize();1333symbolPosInBuffer.push_back(size);1334unsigned int nSym = 0;1335unsigned int nextSymbolPos = symbolPosInBuffer[nSym];1336assert(nextSymbolPos % ptrSize == 0);1337for (unsigned int pos = 0; pos < size; pos += ptrSize) {1338if (pos)1339os << ", ";1340if (pos == nextSymbolPos) {1341printSymbol(nSym, os);1342nextSymbolPos = symbolPosInBuffer[++nSym];1343assert(nextSymbolPos % ptrSize == 0);1344assert(nextSymbolPos >= pos + ptrSize);1345} else if (ptrSize == 4)1346os << support::endian::read32le(&buffer[pos]);1347else1348os << support::endian::read64le(&buffer[pos]);1349}1350}13511352void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {1353if (localDecls.find(f) == localDecls.end())1354return;13551356std::vector<const GlobalVariable *> &gvars = localDecls[f];13571358const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);1359const NVPTXSubtarget &STI =1360*static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());13611362for (const GlobalVariable *GV : gvars) {1363O << "\t// demoted variable\n\t";1364printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);1365}1366}13671368void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,1369raw_ostream &O) const {1370switch (AddressSpace) {1371case ADDRESS_SPACE_LOCAL:1372O << "local";1373break;1374case ADDRESS_SPACE_GLOBAL:1375O << "global";1376break;1377case ADDRESS_SPACE_CONST:1378O << "const";1379break;1380case ADDRESS_SPACE_SHARED:1381O << "shared";1382break;1383default:1384report_fatal_error("Bad address space found while emitting PTX: " +1385llvm::Twine(AddressSpace));1386break;1387}1388}13891390std::string1391NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {1392switch (Ty->getTypeID()) {1393case Type::IntegerTyID: {1394unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();1395if (NumBits == 1)1396return "pred";1397else if (NumBits <= 64) {1398std::string name = "u";1399return name + utostr(NumBits);1400} else {1401llvm_unreachable("Integer too large");1402break;1403}1404break;1405}1406case Type::BFloatTyID:1407case Type::HalfTyID:1408// fp16 and bf16 are stored as .b16 for compatibility with pre-sm_531409// PTX assembly.1410return "b16";1411case Type::FloatTyID:1412return "f32";1413case Type::DoubleTyID:1414return "f64";1415case Type::PointerTyID: {1416unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace());1417assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");14181419if (PtrSize == 64)1420if (useB4PTR)1421return "b64";1422else1423return "u64";1424else if (useB4PTR)1425return "b32";1426else1427return "u32";1428}1429default:1430break;1431}1432llvm_unreachable("unexpected type");1433}14341435void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,1436raw_ostream &O,1437const NVPTXSubtarget &STI) {1438const DataLayout &DL = getDataLayout();14391440// GlobalVariables are always constant pointers themselves.1441Type *ETy = GVar->getValueType();14421443O << ".";1444emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);1445if (isManaged(*GVar)) {1446if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {1447report_fatal_error(1448".attribute(.managed) requires PTX version >= 4.0 and sm_30");1449}1450O << " .attribute(.managed)";1451}1452if (MaybeAlign A = GVar->getAlign())1453O << " .align " << A->value();1454else1455O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();14561457// Special case for i1281458if (ETy->isIntegerTy(128)) {1459O << " .b8 ";1460getSymbol(GVar)->print(O, MAI);1461O << "[16]";1462return;1463}14641465if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {1466O << " .";1467O << getPTXFundamentalTypeStr(ETy);1468O << " ";1469getSymbol(GVar)->print(O, MAI);1470return;1471}14721473int64_t ElementSize = 0;14741475// Although PTX has direct support for struct type and array type and LLVM IR1476// is very similar to PTX, the LLVM CodeGen does not support for targets that1477// support these high level field accesses. Structs and arrays are lowered1478// into arrays of bytes.1479switch (ETy->getTypeID()) {1480case Type::StructTyID:1481case Type::ArrayTyID:1482case Type::FixedVectorTyID:1483ElementSize = DL.getTypeStoreSize(ETy);1484O << " .b8 ";1485getSymbol(GVar)->print(O, MAI);1486O << "[";1487if (ElementSize) {1488O << ElementSize;1489}1490O << "]";1491break;1492default:1493llvm_unreachable("type not supported yet");1494}1495}14961497void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {1498const DataLayout &DL = getDataLayout();1499const AttributeList &PAL = F->getAttributes();1500const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);1501const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());15021503Function::const_arg_iterator I, E;1504unsigned paramIndex = 0;1505bool first = true;1506bool isKernelFunc = isKernelFunction(*F);1507bool isABI = (STI.getSmVersion() >= 20);1508bool hasImageHandles = STI.hasImageHandles();15091510if (F->arg_empty() && !F->isVarArg()) {1511O << "()";1512return;1513}15141515O << "(\n";15161517for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {1518Type *Ty = I->getType();15191520if (!first)1521O << ",\n";15221523first = false;15241525// Handle image/sampler parameters1526if (isKernelFunction(*F)) {1527if (isSampler(*I) || isImage(*I)) {1528if (isImage(*I)) {1529if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {1530if (hasImageHandles)1531O << "\t.param .u64 .ptr .surfref ";1532else1533O << "\t.param .surfref ";1534O << TLI->getParamName(F, paramIndex);1535}1536else { // Default image is read_only1537if (hasImageHandles)1538O << "\t.param .u64 .ptr .texref ";1539else1540O << "\t.param .texref ";1541O << TLI->getParamName(F, paramIndex);1542}1543} else {1544if (hasImageHandles)1545O << "\t.param .u64 .ptr .samplerref ";1546else1547O << "\t.param .samplerref ";1548O << TLI->getParamName(F, paramIndex);1549}1550continue;1551}1552}15531554auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,1555paramIndex](Type *Ty) -> Align {1556if (MaybeAlign StackAlign =1557getAlign(*F, paramIndex + AttributeList::FirstArgIndex))1558return StackAlign.value();15591560Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);1561MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);1562return std::max(TypeAlign, ParamAlign.valueOrOne());1563};15641565if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {1566if (ShouldPassAsArray(Ty)) {1567// Just print .param .align <a> .b8 .param[size];1568// <a> = optimal alignment for the element type; always multiple of1569// PAL.getParamAlignment1570// size = typeallocsize of element type1571Align OptimalAlign = getOptimalAlignForParam(Ty);15721573O << "\t.param .align " << OptimalAlign.value() << " .b8 ";1574O << TLI->getParamName(F, paramIndex);1575O << "[" << DL.getTypeAllocSize(Ty) << "]";15761577continue;1578}1579// Just a scalar1580auto *PTy = dyn_cast<PointerType>(Ty);1581unsigned PTySizeInBits = 0;1582if (PTy) {1583PTySizeInBits =1584TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();1585assert(PTySizeInBits && "Invalid pointer size");1586}15871588if (isKernelFunc) {1589if (PTy) {1590// Special handling for pointer arguments to kernel1591O << "\t.param .u" << PTySizeInBits << " ";15921593if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=1594NVPTX::CUDA) {1595int addrSpace = PTy->getAddressSpace();1596switch (addrSpace) {1597default:1598O << ".ptr ";1599break;1600case ADDRESS_SPACE_CONST:1601O << ".ptr .const ";1602break;1603case ADDRESS_SPACE_SHARED:1604O << ".ptr .shared ";1605break;1606case ADDRESS_SPACE_GLOBAL:1607O << ".ptr .global ";1608break;1609}1610Align ParamAlign = I->getParamAlign().valueOrOne();1611O << ".align " << ParamAlign.value() << " ";1612}1613O << TLI->getParamName(F, paramIndex);1614continue;1615}16161617// non-pointer scalar to kernel func1618O << "\t.param .";1619// Special case: predicate operands become .u8 types1620if (Ty->isIntegerTy(1))1621O << "u8";1622else1623O << getPTXFundamentalTypeStr(Ty);1624O << " ";1625O << TLI->getParamName(F, paramIndex);1626continue;1627}1628// Non-kernel function, just print .param .b<size> for ABI1629// and .reg .b<size> for non-ABI1630unsigned sz = 0;1631if (isa<IntegerType>(Ty)) {1632sz = cast<IntegerType>(Ty)->getBitWidth();1633sz = promoteScalarArgumentSize(sz);1634} else if (PTy) {1635assert(PTySizeInBits && "Invalid pointer size");1636sz = PTySizeInBits;1637} else1638sz = Ty->getPrimitiveSizeInBits();1639if (isABI)1640O << "\t.param .b" << sz << " ";1641else1642O << "\t.reg .b" << sz << " ";1643O << TLI->getParamName(F, paramIndex);1644continue;1645}16461647// param has byVal attribute.1648Type *ETy = PAL.getParamByValType(paramIndex);1649assert(ETy && "Param should have byval type");16501651if (isABI || isKernelFunc) {1652// Just print .param .align <a> .b8 .param[size];1653// <a> = optimal alignment for the element type; always multiple of1654// PAL.getParamAlignment1655// size = typeallocsize of element type1656Align OptimalAlign =1657isKernelFunc1658? getOptimalAlignForParam(ETy)1659: TLI->getFunctionByValParamAlign(1660F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);16611662unsigned sz = DL.getTypeAllocSize(ETy);1663O << "\t.param .align " << OptimalAlign.value() << " .b8 ";1664O << TLI->getParamName(F, paramIndex);1665O << "[" << sz << "]";1666continue;1667} else {1668// Split the ETy into constituent parts and1669// print .param .b<size> <name> for each part.1670// Further, if a part is vector, print the above for1671// each vector element.1672SmallVector<EVT, 16> vtparts;1673ComputeValueVTs(*TLI, DL, ETy, vtparts);1674for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {1675unsigned elems = 1;1676EVT elemtype = vtparts[i];1677if (vtparts[i].isVector()) {1678elems = vtparts[i].getVectorNumElements();1679elemtype = vtparts[i].getVectorElementType();1680}16811682for (unsigned j = 0, je = elems; j != je; ++j) {1683unsigned sz = elemtype.getSizeInBits();1684if (elemtype.isInteger())1685sz = promoteScalarArgumentSize(sz);1686O << "\t.reg .b" << sz << " ";1687O << TLI->getParamName(F, paramIndex);1688if (j < je - 1)1689O << ",\n";1690++paramIndex;1691}1692if (i < e - 1)1693O << ",\n";1694}1695--paramIndex;1696continue;1697}1698}16991700if (F->isVarArg()) {1701if (!first)1702O << ",\n";1703O << "\t.param .align " << STI.getMaxRequiredAlignment();1704O << " .b8 ";1705O << TLI->getParamName(F, /* vararg */ -1) << "[]";1706}17071708O << "\n)";1709}17101711void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(1712const MachineFunction &MF) {1713SmallString<128> Str;1714raw_svector_ostream O(Str);17151716// Map the global virtual register number to a register class specific1717// virtual register number starting from 1 with that class.1718const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();1719//unsigned numRegClasses = TRI->getNumRegClasses();17201721// Emit the Fake Stack Object1722const MachineFrameInfo &MFI = MF.getFrameInfo();1723int64_t NumBytes = MFI.getStackSize();1724if (NumBytes) {1725O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"1726<< DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";1727if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {1728O << "\t.reg .b64 \t%SP;\n";1729O << "\t.reg .b64 \t%SPL;\n";1730} else {1731O << "\t.reg .b32 \t%SP;\n";1732O << "\t.reg .b32 \t%SPL;\n";1733}1734}17351736// Go through all virtual registers to establish the mapping between the1737// global virtual1738// register number and the per class virtual register number.1739// We use the per class virtual register number in the ptx output.1740unsigned int numVRs = MRI->getNumVirtRegs();1741for (unsigned i = 0; i < numVRs; i++) {1742Register vr = Register::index2VirtReg(i);1743const TargetRegisterClass *RC = MRI->getRegClass(vr);1744DenseMap<unsigned, unsigned> ®map = VRegMapping[RC];1745int n = regmap.size();1746regmap.insert(std::make_pair(vr, n + 1));1747}17481749// Emit register declarations1750// @TODO: Extract out the real register usage1751// O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";1752// O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";1753// O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";1754// O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";1755// O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";1756// O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";1757// O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";17581759// Emit declaration of the virtual registers or 'physical' registers for1760// each register class1761for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {1762const TargetRegisterClass *RC = TRI->getRegClass(i);1763DenseMap<unsigned, unsigned> ®map = VRegMapping[RC];1764std::string rcname = getNVPTXRegClassName(RC);1765std::string rcStr = getNVPTXRegClassStr(RC);1766int n = regmap.size();17671768// Only declare those registers that may be used.1769if (n) {1770O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)1771<< ">;\n";1772}1773}17741775OutStreamer->emitRawText(O.str());1776}17771778void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {1779APFloat APF = APFloat(Fp->getValueAPF()); // make a copy1780bool ignored;1781unsigned int numHex;1782const char *lead;17831784if (Fp->getType()->getTypeID() == Type::FloatTyID) {1785numHex = 8;1786lead = "0f";1787APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);1788} else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {1789numHex = 16;1790lead = "0d";1791APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);1792} else1793llvm_unreachable("unsupported fp type");17941795APInt API = APF.bitcastToAPInt();1796O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);1797}17981799void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {1800if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {1801O << CI->getValue();1802return;1803}1804if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {1805printFPConstant(CFP, O);1806return;1807}1808if (isa<ConstantPointerNull>(CPV)) {1809O << "0";1810return;1811}1812if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {1813bool IsNonGenericPointer = false;1814if (GVar->getType()->getAddressSpace() != 0) {1815IsNonGenericPointer = true;1816}1817if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {1818O << "generic(";1819getSymbol(GVar)->print(O, MAI);1820O << ")";1821} else {1822getSymbol(GVar)->print(O, MAI);1823}1824return;1825}1826if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {1827const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false);1828printMCExpr(*E, O);1829return;1830}1831llvm_unreachable("Not scalar type found in printScalarConstant()");1832}18331834void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,1835AggBuffer *AggBuffer) {1836const DataLayout &DL = getDataLayout();1837int AllocSize = DL.getTypeAllocSize(CPV->getType());1838if (isa<UndefValue>(CPV) || CPV->isNullValue()) {1839// Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,1840// only the space allocated by CPV.1841AggBuffer->addZeros(Bytes ? Bytes : AllocSize);1842return;1843}18441845// Helper for filling AggBuffer with APInts.1846auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {1847size_t NumBytes = (Val.getBitWidth() + 7) / 8;1848SmallVector<unsigned char, 16> Buf(NumBytes);1849// `extractBitsAsZExtValue` does not allow the extraction of bits beyond the1850// input's bit width, and i1 arrays may not have a length that is a multuple1851// of 8. We handle the last byte separately, so we never request out of1852// bounds bits.1853for (unsigned I = 0; I < NumBytes - 1; ++I) {1854Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);1855}1856size_t LastBytePosition = (NumBytes - 1) * 8;1857size_t LastByteBits = Val.getBitWidth() - LastBytePosition;1858Buf[NumBytes - 1] =1859Val.extractBitsAsZExtValue(LastByteBits, LastBytePosition);1860AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);1861};18621863switch (CPV->getType()->getTypeID()) {1864case Type::IntegerTyID:1865if (const auto CI = dyn_cast<ConstantInt>(CPV)) {1866AddIntToBuffer(CI->getValue());1867break;1868}1869if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {1870if (const auto *CI =1871dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {1872AddIntToBuffer(CI->getValue());1873break;1874}1875if (Cexpr->getOpcode() == Instruction::PtrToInt) {1876Value *V = Cexpr->getOperand(0)->stripPointerCasts();1877AggBuffer->addSymbol(V, Cexpr->getOperand(0));1878AggBuffer->addZeros(AllocSize);1879break;1880}1881}1882llvm_unreachable("unsupported integer const type");1883break;18841885case Type::HalfTyID:1886case Type::BFloatTyID:1887case Type::FloatTyID:1888case Type::DoubleTyID:1889AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());1890break;18911892case Type::PointerTyID: {1893if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {1894AggBuffer->addSymbol(GVar, GVar);1895} else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {1896const Value *v = Cexpr->stripPointerCasts();1897AggBuffer->addSymbol(v, Cexpr);1898}1899AggBuffer->addZeros(AllocSize);1900break;1901}19021903case Type::ArrayTyID:1904case Type::FixedVectorTyID:1905case Type::StructTyID: {1906if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {1907bufferAggregateConstant(CPV, AggBuffer);1908if (Bytes > AllocSize)1909AggBuffer->addZeros(Bytes - AllocSize);1910} else if (isa<ConstantAggregateZero>(CPV))1911AggBuffer->addZeros(Bytes);1912else1913llvm_unreachable("Unexpected Constant type");1914break;1915}19161917default:1918llvm_unreachable("unsupported type");1919}1920}19211922void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,1923AggBuffer *aggBuffer) {1924const DataLayout &DL = getDataLayout();1925int Bytes;19261927// Integers of arbitrary width1928if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {1929APInt Val = CI->getValue();1930for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {1931uint8_t Byte = Val.getLoBits(8).getZExtValue();1932aggBuffer->addBytes(&Byte, 1, 1);1933Val.lshrInPlace(8);1934}1935return;1936}19371938// Old constants1939if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {1940if (CPV->getNumOperands())1941for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)1942bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);1943return;1944}19451946if (const ConstantDataSequential *CDS =1947dyn_cast<ConstantDataSequential>(CPV)) {1948if (CDS->getNumElements())1949for (unsigned i = 0; i < CDS->getNumElements(); ++i)1950bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,1951aggBuffer);1952return;1953}19541955if (isa<ConstantStruct>(CPV)) {1956if (CPV->getNumOperands()) {1957StructType *ST = cast<StructType>(CPV->getType());1958for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {1959if (i == (e - 1))1960Bytes = DL.getStructLayout(ST)->getElementOffset(0) +1961DL.getTypeAllocSize(ST) -1962DL.getStructLayout(ST)->getElementOffset(i);1963else1964Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -1965DL.getStructLayout(ST)->getElementOffset(i);1966bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);1967}1968}1969return;1970}1971llvm_unreachable("unsupported constant type in printAggregateConstant()");1972}19731974/// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly1975/// a copy from AsmPrinter::lowerConstant, except customized to only handle1976/// expressions that are representable in PTX and create1977/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.1978const MCExpr *1979NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {1980MCContext &Ctx = OutContext;19811982if (CV->isNullValue() || isa<UndefValue>(CV))1983return MCConstantExpr::create(0, Ctx);19841985if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))1986return MCConstantExpr::create(CI->getZExtValue(), Ctx);19871988if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {1989const MCSymbolRefExpr *Expr =1990MCSymbolRefExpr::create(getSymbol(GV), Ctx);1991if (ProcessingGeneric) {1992return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);1993} else {1994return Expr;1995}1996}19971998const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);1999if (!CE) {2000llvm_unreachable("Unknown constant value to lower!");2001}20022003switch (CE->getOpcode()) {2004default:2005break; // Error20062007case Instruction::AddrSpaceCast: {2008// Strip the addrspacecast and pass along the operand2009PointerType *DstTy = cast<PointerType>(CE->getType());2010if (DstTy->getAddressSpace() == 0)2011return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);20122013break; // Error2014}20152016case Instruction::GetElementPtr: {2017const DataLayout &DL = getDataLayout();20182019// Generate a symbolic expression for the byte address2020APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);2021cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);20222023const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),2024ProcessingGeneric);2025if (!OffsetAI)2026return Base;20272028int64_t Offset = OffsetAI.getSExtValue();2029return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),2030Ctx);2031}20322033case Instruction::Trunc:2034// We emit the value and depend on the assembler to truncate the generated2035// expression properly. This is important for differences between2036// blockaddress labels. Since the two labels are in the same function, it2037// is reasonable to treat their delta as a 32-bit value.2038[[fallthrough]];2039case Instruction::BitCast:2040return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);20412042case Instruction::IntToPtr: {2043const DataLayout &DL = getDataLayout();20442045// Handle casts to pointers by changing them into casts to the appropriate2046// integer type. This promotes constant folding and simplifies this code.2047Constant *Op = CE->getOperand(0);2048Op = ConstantFoldIntegerCast(Op, DL.getIntPtrType(CV->getType()),2049/*IsSigned*/ false, DL);2050if (Op)2051return lowerConstantForGV(Op, ProcessingGeneric);20522053break; // Error2054}20552056case Instruction::PtrToInt: {2057const DataLayout &DL = getDataLayout();20582059// Support only foldable casts to/from pointers that can be eliminated by2060// changing the pointer to the appropriately sized integer type.2061Constant *Op = CE->getOperand(0);2062Type *Ty = CE->getType();20632064const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);20652066// We can emit the pointer value into this slot if the slot is an2067// integer slot equal to the size of the pointer.2068if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))2069return OpExpr;20702071// Otherwise the pointer is smaller than the resultant integer, mask off2072// the high bits so we are sure to get a proper truncation if the input is2073// a constant expr.2074unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());2075const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);2076return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);2077}20782079// The MC library also has a right-shift operator, but it isn't consistently2080// signed or unsigned between different targets.2081case Instruction::Add: {2082const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);2083const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);2084switch (CE->getOpcode()) {2085default: llvm_unreachable("Unknown binary operator constant cast expr");2086case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);2087}2088}2089}20902091// If the code isn't optimized, there may be outstanding folding2092// opportunities. Attempt to fold the expression using DataLayout as a2093// last resort before giving up.2094Constant *C = ConstantFoldConstant(CE, getDataLayout());2095if (C != CE)2096return lowerConstantForGV(C, ProcessingGeneric);20972098// Otherwise report the problem to the user.2099std::string S;2100raw_string_ostream OS(S);2101OS << "Unsupported expression in static initializer: ";2102CE->printAsOperand(OS, /*PrintType=*/false,2103!MF ? nullptr : MF->getFunction().getParent());2104report_fatal_error(Twine(OS.str()));2105}21062107// Copy of MCExpr::print customized for NVPTX2108void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {2109switch (Expr.getKind()) {2110case MCExpr::Target:2111return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);2112case MCExpr::Constant:2113OS << cast<MCConstantExpr>(Expr).getValue();2114return;21152116case MCExpr::SymbolRef: {2117const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);2118const MCSymbol &Sym = SRE.getSymbol();2119Sym.print(OS, MAI);2120return;2121}21222123case MCExpr::Unary: {2124const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);2125switch (UE.getOpcode()) {2126case MCUnaryExpr::LNot: OS << '!'; break;2127case MCUnaryExpr::Minus: OS << '-'; break;2128case MCUnaryExpr::Not: OS << '~'; break;2129case MCUnaryExpr::Plus: OS << '+'; break;2130}2131printMCExpr(*UE.getSubExpr(), OS);2132return;2133}21342135case MCExpr::Binary: {2136const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);21372138// Only print parens around the LHS if it is non-trivial.2139if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||2140isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {2141printMCExpr(*BE.getLHS(), OS);2142} else {2143OS << '(';2144printMCExpr(*BE.getLHS(), OS);2145OS<< ')';2146}21472148switch (BE.getOpcode()) {2149case MCBinaryExpr::Add:2150// Print "X-42" instead of "X+-42".2151if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {2152if (RHSC->getValue() < 0) {2153OS << RHSC->getValue();2154return;2155}2156}21572158OS << '+';2159break;2160default: llvm_unreachable("Unhandled binary operator");2161}21622163// Only print parens around the LHS if it is non-trivial.2164if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {2165printMCExpr(*BE.getRHS(), OS);2166} else {2167OS << '(';2168printMCExpr(*BE.getRHS(), OS);2169OS << ')';2170}2171return;2172}2173}21742175llvm_unreachable("Invalid expression kind!");2176}21772178/// PrintAsmOperand - Print out an operand for an inline asm expression.2179///2180bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,2181const char *ExtraCode, raw_ostream &O) {2182if (ExtraCode && ExtraCode[0]) {2183if (ExtraCode[1] != 0)2184return true; // Unknown modifier.21852186switch (ExtraCode[0]) {2187default:2188// See if this is a generic print operand2189return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);2190case 'r':2191break;2192}2193}21942195printOperand(MI, OpNo, O);21962197return false;2198}21992200bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,2201unsigned OpNo,2202const char *ExtraCode,2203raw_ostream &O) {2204if (ExtraCode && ExtraCode[0])2205return true; // Unknown modifier22062207O << '[';2208printMemOperand(MI, OpNo, O);2209O << ']';22102211return false;2212}22132214void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum,2215raw_ostream &O) {2216const MachineOperand &MO = MI->getOperand(OpNum);2217switch (MO.getType()) {2218case MachineOperand::MO_Register:2219if (MO.getReg().isPhysical()) {2220if (MO.getReg() == NVPTX::VRDepot)2221O << DEPOTNAME << getFunctionNumber();2222else2223O << NVPTXInstPrinter::getRegisterName(MO.getReg());2224} else {2225emitVirtualRegister(MO.getReg(), O);2226}2227break;22282229case MachineOperand::MO_Immediate:2230O << MO.getImm();2231break;22322233case MachineOperand::MO_FPImmediate:2234printFPConstant(MO.getFPImm(), O);2235break;22362237case MachineOperand::MO_GlobalAddress:2238PrintSymbolOperand(MO, O);2239break;22402241case MachineOperand::MO_MachineBasicBlock:2242MO.getMBB()->getSymbol()->print(O, MAI);2243break;22442245default:2246llvm_unreachable("Operand type not supported.");2247}2248}22492250void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum,2251raw_ostream &O, const char *Modifier) {2252printOperand(MI, OpNum, O);22532254if (Modifier && strcmp(Modifier, "add") == 0) {2255O << ", ";2256printOperand(MI, OpNum + 1, O);2257} else {2258if (MI->getOperand(OpNum + 1).isImm() &&2259MI->getOperand(OpNum + 1).getImm() == 0)2260return; // don't print ',0' or '+0'2261O << "+";2262printOperand(MI, OpNum + 1, O);2263}2264}22652266// Force static initialization.2267extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {2268RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());2269RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());2270}227122722273