Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
35271 views
1
//===-- NVPTXAsmPrinter.cpp - NVPTX LLVM assembly writer ------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file contains a printer that converts from our internal representation
10
// of machine-dependent LLVM code to NVPTX assembly language.
11
//
12
//===----------------------------------------------------------------------===//
13
14
#include "NVPTXAsmPrinter.h"
15
#include "MCTargetDesc/NVPTXBaseInfo.h"
16
#include "MCTargetDesc/NVPTXInstPrinter.h"
17
#include "MCTargetDesc/NVPTXMCAsmInfo.h"
18
#include "MCTargetDesc/NVPTXTargetStreamer.h"
19
#include "NVPTX.h"
20
#include "NVPTXMCExpr.h"
21
#include "NVPTXMachineFunctionInfo.h"
22
#include "NVPTXRegisterInfo.h"
23
#include "NVPTXSubtarget.h"
24
#include "NVPTXTargetMachine.h"
25
#include "NVPTXUtilities.h"
26
#include "TargetInfo/NVPTXTargetInfo.h"
27
#include "cl_common_defines.h"
28
#include "llvm/ADT/APFloat.h"
29
#include "llvm/ADT/APInt.h"
30
#include "llvm/ADT/DenseMap.h"
31
#include "llvm/ADT/DenseSet.h"
32
#include "llvm/ADT/SmallString.h"
33
#include "llvm/ADT/SmallVector.h"
34
#include "llvm/ADT/StringExtras.h"
35
#include "llvm/ADT/StringRef.h"
36
#include "llvm/ADT/Twine.h"
37
#include "llvm/Analysis/ConstantFolding.h"
38
#include "llvm/CodeGen/Analysis.h"
39
#include "llvm/CodeGen/MachineBasicBlock.h"
40
#include "llvm/CodeGen/MachineFrameInfo.h"
41
#include "llvm/CodeGen/MachineFunction.h"
42
#include "llvm/CodeGen/MachineInstr.h"
43
#include "llvm/CodeGen/MachineLoopInfo.h"
44
#include "llvm/CodeGen/MachineModuleInfo.h"
45
#include "llvm/CodeGen/MachineOperand.h"
46
#include "llvm/CodeGen/MachineRegisterInfo.h"
47
#include "llvm/CodeGen/TargetRegisterInfo.h"
48
#include "llvm/CodeGen/ValueTypes.h"
49
#include "llvm/CodeGenTypes/MachineValueType.h"
50
#include "llvm/IR/Attributes.h"
51
#include "llvm/IR/BasicBlock.h"
52
#include "llvm/IR/Constant.h"
53
#include "llvm/IR/Constants.h"
54
#include "llvm/IR/DataLayout.h"
55
#include "llvm/IR/DebugInfo.h"
56
#include "llvm/IR/DebugInfoMetadata.h"
57
#include "llvm/IR/DebugLoc.h"
58
#include "llvm/IR/DerivedTypes.h"
59
#include "llvm/IR/Function.h"
60
#include "llvm/IR/GlobalAlias.h"
61
#include "llvm/IR/GlobalValue.h"
62
#include "llvm/IR/GlobalVariable.h"
63
#include "llvm/IR/Instruction.h"
64
#include "llvm/IR/LLVMContext.h"
65
#include "llvm/IR/Module.h"
66
#include "llvm/IR/Operator.h"
67
#include "llvm/IR/Type.h"
68
#include "llvm/IR/User.h"
69
#include "llvm/MC/MCExpr.h"
70
#include "llvm/MC/MCInst.h"
71
#include "llvm/MC/MCInstrDesc.h"
72
#include "llvm/MC/MCStreamer.h"
73
#include "llvm/MC/MCSymbol.h"
74
#include "llvm/MC/TargetRegistry.h"
75
#include "llvm/Support/Alignment.h"
76
#include "llvm/Support/Casting.h"
77
#include "llvm/Support/CommandLine.h"
78
#include "llvm/Support/Endian.h"
79
#include "llvm/Support/ErrorHandling.h"
80
#include "llvm/Support/NativeFormatting.h"
81
#include "llvm/Support/Path.h"
82
#include "llvm/Support/raw_ostream.h"
83
#include "llvm/Target/TargetLoweringObjectFile.h"
84
#include "llvm/Target/TargetMachine.h"
85
#include "llvm/TargetParser/Triple.h"
86
#include "llvm/Transforms/Utils/UnrollLoop.h"
87
#include <cassert>
88
#include <cstdint>
89
#include <cstring>
90
#include <new>
91
#include <string>
92
#include <utility>
93
#include <vector>
94
95
using namespace llvm;
96
97
// Hidden opt-in flag. When it is set, doInitialization() no longer rejects
// modules that carry non-empty llvm.global_ctors / llvm.global_dtors lists.
static cl::opt<bool> LowerCtorDtor(
    "nvptx-lower-global-ctor-dtor", cl::Hidden, cl::init(false),
    cl::desc("Lower GPU ctor / dtors to globals on the device."));

// Prefix of the per-function frame symbol; getFunctionFrameSymbol() appends
// the function number to it.
#define DEPOTNAME "__local_depot"
103
104
/// DiscoverDependentGlobals - Return a set of GlobalVariables on which \p V
105
/// depends.
106
static void
107
DiscoverDependentGlobals(const Value *V,
108
DenseSet<const GlobalVariable *> &Globals) {
109
if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V))
110
Globals.insert(GV);
111
else {
112
if (const User *U = dyn_cast<User>(V)) {
113
for (unsigned i = 0, e = U->getNumOperands(); i != e; ++i) {
114
DiscoverDependentGlobals(U->getOperand(i), Globals);
115
}
116
}
117
}
118
}
119
120
/// VisitGlobalVariableForEmission - Add \p GV to the list of GlobalVariable
121
/// instances to be emitted, but only after any dependents have been added
122
/// first.s
123
static void
124
VisitGlobalVariableForEmission(const GlobalVariable *GV,
125
SmallVectorImpl<const GlobalVariable *> &Order,
126
DenseSet<const GlobalVariable *> &Visited,
127
DenseSet<const GlobalVariable *> &Visiting) {
128
// Have we already visited this one?
129
if (Visited.count(GV))
130
return;
131
132
// Do we have a circular dependency?
133
if (!Visiting.insert(GV).second)
134
report_fatal_error("Circular dependency found in global variable set");
135
136
// Make sure we visit all dependents first
137
DenseSet<const GlobalVariable *> Others;
138
for (unsigned i = 0, e = GV->getNumOperands(); i != e; ++i)
139
DiscoverDependentGlobals(GV->getOperand(i), Others);
140
141
for (const GlobalVariable *GV : Others)
142
VisitGlobalVariableForEmission(GV, Order, Visited, Visiting);
143
144
// Now we can visit ourself
145
Order.push_back(GV);
146
Visited.insert(GV);
147
Visiting.erase(GV);
148
}
149
150
void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
151
NVPTX_MC::verifyInstructionPredicates(MI->getOpcode(),
152
getSubtargetInfo().getFeatureBits());
153
154
MCInst Inst;
155
lowerToMCInst(MI, Inst);
156
EmitToStreamer(*OutStreamer, Inst);
157
}
158
159
// Handle symbol backtracking for targets that do not support image handles
160
bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
161
unsigned OpNo, MCOperand &MCOp) {
162
const MachineOperand &MO = MI->getOperand(OpNo);
163
const MCInstrDesc &MCID = MI->getDesc();
164
165
if (MCID.TSFlags & NVPTXII::IsTexFlag) {
166
// This is a texture fetch, so operand 4 is a texref and operand 5 is
167
// a samplerref
168
if (OpNo == 4 && MO.isImm()) {
169
lowerImageHandleSymbol(MO.getImm(), MCOp);
170
return true;
171
}
172
if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
173
lowerImageHandleSymbol(MO.getImm(), MCOp);
174
return true;
175
}
176
177
return false;
178
} else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
179
unsigned VecSize =
180
1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
181
182
// For a surface load of vector size N, the Nth operand will be the surfref
183
if (OpNo == VecSize && MO.isImm()) {
184
lowerImageHandleSymbol(MO.getImm(), MCOp);
185
return true;
186
}
187
188
return false;
189
} else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
190
// This is a surface store, so operand 0 is a surfref
191
if (OpNo == 0 && MO.isImm()) {
192
lowerImageHandleSymbol(MO.getImm(), MCOp);
193
return true;
194
}
195
196
return false;
197
} else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
198
// This is a query, so operand 1 is a surfref/texref
199
if (OpNo == 1 && MO.isImm()) {
200
lowerImageHandleSymbol(MO.getImm(), MCOp);
201
return true;
202
}
203
204
return false;
205
}
206
207
return false;
208
}
209
210
void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
211
// Ewwww
212
LLVMTargetMachine &TM = const_cast<LLVMTargetMachine&>(MF->getTarget());
213
NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine&>(TM);
214
const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
215
const char *Sym = MFI->getImageHandleSymbol(Index);
216
StringRef SymName = nvTM.getStrPool().save(Sym);
217
MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
218
}
219
220
void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
221
OutMI.setOpcode(MI->getOpcode());
222
// Special: Do not mangle symbol operand of CALL_PROTOTYPE
223
if (MI->getOpcode() == NVPTX::CALL_PROTOTYPE) {
224
const MachineOperand &MO = MI->getOperand(0);
225
OutMI.addOperand(GetSymbolRef(
226
OutContext.getOrCreateSymbol(Twine(MO.getSymbolName()))));
227
return;
228
}
229
230
const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
231
for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
232
const MachineOperand &MO = MI->getOperand(i);
233
234
MCOperand MCOp;
235
if (!STI.hasImageHandles()) {
236
if (lowerImageHandleOperand(MI, i, MCOp)) {
237
OutMI.addOperand(MCOp);
238
continue;
239
}
240
}
241
242
if (lowerOperand(MO, MCOp))
243
OutMI.addOperand(MCOp);
244
}
245
}
246
247
bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
248
MCOperand &MCOp) {
249
switch (MO.getType()) {
250
default: llvm_unreachable("unknown operand type");
251
case MachineOperand::MO_Register:
252
MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
253
break;
254
case MachineOperand::MO_Immediate:
255
MCOp = MCOperand::createImm(MO.getImm());
256
break;
257
case MachineOperand::MO_MachineBasicBlock:
258
MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
259
MO.getMBB()->getSymbol(), OutContext));
260
break;
261
case MachineOperand::MO_ExternalSymbol:
262
MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
263
break;
264
case MachineOperand::MO_GlobalAddress:
265
MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
266
break;
267
case MachineOperand::MO_FPImmediate: {
268
const ConstantFP *Cnt = MO.getFPImm();
269
const APFloat &Val = Cnt->getValueAPF();
270
271
switch (Cnt->getType()->getTypeID()) {
272
default: report_fatal_error("Unsupported FP type"); break;
273
case Type::HalfTyID:
274
MCOp = MCOperand::createExpr(
275
NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
276
break;
277
case Type::BFloatTyID:
278
MCOp = MCOperand::createExpr(
279
NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
280
break;
281
case Type::FloatTyID:
282
MCOp = MCOperand::createExpr(
283
NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
284
break;
285
case Type::DoubleTyID:
286
MCOp = MCOperand::createExpr(
287
NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
288
break;
289
}
290
break;
291
}
292
}
293
return true;
294
}
295
296
unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
297
if (Register::isVirtualRegister(Reg)) {
298
const TargetRegisterClass *RC = MRI->getRegClass(Reg);
299
300
DenseMap<unsigned, unsigned> &RegMap = VRegMapping[RC];
301
unsigned RegNum = RegMap[Reg];
302
303
// Encode the register class in the upper 4 bits
304
// Must be kept in sync with NVPTXInstPrinter::printRegName
305
unsigned Ret = 0;
306
if (RC == &NVPTX::Int1RegsRegClass) {
307
Ret = (1 << 28);
308
} else if (RC == &NVPTX::Int16RegsRegClass) {
309
Ret = (2 << 28);
310
} else if (RC == &NVPTX::Int32RegsRegClass) {
311
Ret = (3 << 28);
312
} else if (RC == &NVPTX::Int64RegsRegClass) {
313
Ret = (4 << 28);
314
} else if (RC == &NVPTX::Float32RegsRegClass) {
315
Ret = (5 << 28);
316
} else if (RC == &NVPTX::Float64RegsRegClass) {
317
Ret = (6 << 28);
318
} else if (RC == &NVPTX::Int128RegsRegClass) {
319
Ret = (7 << 28);
320
} else {
321
report_fatal_error("Bad register class");
322
}
323
324
// Insert the vreg number
325
Ret |= (RegNum & 0x0FFFFFFF);
326
return Ret;
327
} else {
328
// Some special-use registers are actually physical registers.
329
// Encode this as the register class ID of 0 and the real register ID.
330
return Reg & 0x0FFFFFFF;
331
}
332
}
333
334
/// Wrap \p Symbol in a VK_None symbol-reference expression operand.
MCOperand NVPTXAsmPrinter::GetSymbolRef(const MCSymbol *Symbol) {
  return MCOperand::createExpr(
      MCSymbolRefExpr::create(Symbol, MCSymbolRefExpr::VK_None, OutContext));
}
340
341
/// Return true for aggregate and vector types, 128-bit integers, and the
/// half / bfloat floating-point types.
static bool ShouldPassAsArray(Type *Ty) {
  if (Ty->isAggregateType() || Ty->isVectorTy())
    return true;
  if (Ty->isIntegerTy(128))
    return true;
  return Ty->isHalfTy() || Ty->isBFloatTy();
}
345
346
void NVPTXAsmPrinter::printReturnValStr(const Function *F, raw_ostream &O) {
347
const DataLayout &DL = getDataLayout();
348
const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
349
const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
350
351
Type *Ty = F->getReturnType();
352
353
bool isABI = (STI.getSmVersion() >= 20);
354
355
if (Ty->getTypeID() == Type::VoidTyID)
356
return;
357
O << " (";
358
359
if (isABI) {
360
if ((Ty->isFloatingPointTy() || Ty->isIntegerTy()) &&
361
!ShouldPassAsArray(Ty)) {
362
unsigned size = 0;
363
if (auto *ITy = dyn_cast<IntegerType>(Ty)) {
364
size = ITy->getBitWidth();
365
} else {
366
assert(Ty->isFloatingPointTy() && "Floating point type expected here");
367
size = Ty->getPrimitiveSizeInBits();
368
}
369
size = promoteScalarArgumentSize(size);
370
O << ".param .b" << size << " func_retval0";
371
} else if (isa<PointerType>(Ty)) {
372
O << ".param .b" << TLI->getPointerTy(DL).getSizeInBits()
373
<< " func_retval0";
374
} else if (ShouldPassAsArray(Ty)) {
375
unsigned totalsz = DL.getTypeAllocSize(Ty);
376
Align RetAlignment = TLI->getFunctionArgumentAlignment(
377
F, Ty, AttributeList::ReturnIndex, DL);
378
O << ".param .align " << RetAlignment.value() << " .b8 func_retval0["
379
<< totalsz << "]";
380
} else
381
llvm_unreachable("Unknown return type");
382
} else {
383
SmallVector<EVT, 16> vtparts;
384
ComputeValueVTs(*TLI, DL, Ty, vtparts);
385
unsigned idx = 0;
386
for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
387
unsigned elems = 1;
388
EVT elemtype = vtparts[i];
389
if (vtparts[i].isVector()) {
390
elems = vtparts[i].getVectorNumElements();
391
elemtype = vtparts[i].getVectorElementType();
392
}
393
394
for (unsigned j = 0, je = elems; j != je; ++j) {
395
unsigned sz = elemtype.getSizeInBits();
396
if (elemtype.isInteger())
397
sz = promoteScalarArgumentSize(sz);
398
O << ".reg .b" << sz << " func_retval" << idx;
399
if (j < je - 1)
400
O << ", ";
401
++idx;
402
}
403
if (i < e - 1)
404
O << ", ";
405
}
406
}
407
O << ") ";
408
}
409
410
void NVPTXAsmPrinter::printReturnValStr(const MachineFunction &MF,
411
raw_ostream &O) {
412
const Function &F = MF.getFunction();
413
printReturnValStr(&F, O);
414
}
415
416
// Return true if MBB is the header of a loop marked with
417
// llvm.loop.unroll.disable or llvm.loop.unroll.count=1.
418
bool NVPTXAsmPrinter::isLoopHeaderOfNoUnroll(
419
const MachineBasicBlock &MBB) const {
420
MachineLoopInfo &LI = getAnalysis<MachineLoopInfoWrapperPass>().getLI();
421
// We insert .pragma "nounroll" only to the loop header.
422
if (!LI.isLoopHeader(&MBB))
423
return false;
424
425
// llvm.loop.unroll.disable is marked on the back edges of a loop. Therefore,
426
// we iterate through each back edge of the loop with header MBB, and check
427
// whether its metadata contains llvm.loop.unroll.disable.
428
for (const MachineBasicBlock *PMBB : MBB.predecessors()) {
429
if (LI.getLoopFor(PMBB) != LI.getLoopFor(&MBB)) {
430
// Edges from other loops to MBB are not back edges.
431
continue;
432
}
433
if (const BasicBlock *PBB = PMBB->getBasicBlock()) {
434
if (MDNode *LoopID =
435
PBB->getTerminator()->getMetadata(LLVMContext::MD_loop)) {
436
if (GetUnrollMetadata(LoopID, "llvm.loop.unroll.disable"))
437
return true;
438
if (MDNode *UnrollCountMD =
439
GetUnrollMetadata(LoopID, "llvm.loop.unroll.count")) {
440
if (mdconst::extract<ConstantInt>(UnrollCountMD->getOperand(1))
441
->isOne())
442
return true;
443
}
444
}
445
}
446
}
447
return false;
448
}
449
450
void NVPTXAsmPrinter::emitBasicBlockStart(const MachineBasicBlock &MBB) {
451
AsmPrinter::emitBasicBlockStart(MBB);
452
if (isLoopHeaderOfNoUnroll(MBB))
453
OutStreamer->emitRawText(StringRef("\t.pragma \"nounroll\";\n"));
454
}
455
456
void NVPTXAsmPrinter::emitFunctionEntryLabel() {
457
SmallString<128> Str;
458
raw_svector_ostream O(Str);
459
460
if (!GlobalsEmitted) {
461
emitGlobals(*MF->getFunction().getParent());
462
GlobalsEmitted = true;
463
}
464
465
// Set up
466
MRI = &MF->getRegInfo();
467
F = &MF->getFunction();
468
emitLinkageDirective(F, O);
469
if (isKernelFunction(*F))
470
O << ".entry ";
471
else {
472
O << ".func ";
473
printReturnValStr(*MF, O);
474
}
475
476
CurrentFnSym->print(O, MAI);
477
478
emitFunctionParamList(F, O);
479
O << "\n";
480
481
if (isKernelFunction(*F))
482
emitKernelFunctionDirectives(*F, O);
483
484
if (shouldEmitPTXNoReturn(F, TM))
485
O << ".noreturn";
486
487
OutStreamer->emitRawText(O.str());
488
489
VRegMapping.clear();
490
// Emit open brace for function body.
491
OutStreamer->emitRawText(StringRef("{\n"));
492
setAndEmitFunctionVirtualRegisters(*MF);
493
// Emit initial .loc debug directive for correct relocation symbol data.
494
if (const DISubprogram *SP = MF->getFunction().getSubprogram()) {
495
assert(SP->getUnit());
496
if (!SP->getUnit()->isDebugDirectivesOnly() && MMI && MMI->hasDebugInfo())
497
emitInitialRawDwarfLocDirective(*MF);
498
}
499
}
500
501
/// Run the generic AsmPrinter over \p F and then close the function body.
bool NVPTXAsmPrinter::runOnMachineFunction(MachineFunction &F) {
  const bool BaseResult = AsmPrinter::runOnMachineFunction(F);

  // Emit closing brace for the body of function F.
  // The closing brace must be emitted here because we need to emit additional
  // debug labels/data after the last basic block, and there is no hook that
  // runs once emission of the function body has finished.
  OutStreamer->emitRawText(StringRef("}\n"));
  return BaseResult;
}
511
512
void NVPTXAsmPrinter::emitFunctionBodyStart() {
513
SmallString<128> Str;
514
raw_svector_ostream O(Str);
515
emitDemotedVars(&MF->getFunction(), O);
516
OutStreamer->emitRawText(O.str());
517
}
518
519
void NVPTXAsmPrinter::emitFunctionBodyEnd() {
520
VRegMapping.clear();
521
}
522
523
/// Return the symbol named DEPOTNAME followed by the current function number.
const MCSymbol *NVPTXAsmPrinter::getFunctionFrameSymbol() const {
  SmallString<128> DepotName;
  raw_svector_ostream OS(DepotName);
  OS << DEPOTNAME << getFunctionNumber();
  return OutContext.getOrCreateSymbol(DepotName);
}
528
529
void NVPTXAsmPrinter::emitImplicitDef(const MachineInstr *MI) const {
530
Register RegNo = MI->getOperand(0).getReg();
531
if (RegNo.isVirtual()) {
532
OutStreamer->AddComment(Twine("implicit-def: ") +
533
getVirtualRegisterName(RegNo));
534
} else {
535
const NVPTXSubtarget &STI = MI->getMF()->getSubtarget<NVPTXSubtarget>();
536
OutStreamer->AddComment(Twine("implicit-def: ") +
537
STI.getRegisterInfo()->getName(RegNo));
538
}
539
OutStreamer->addBlankLine();
540
}
541
542
void NVPTXAsmPrinter::emitKernelFunctionDirectives(const Function &F,
543
raw_ostream &O) const {
544
// If the NVVM IR has some of reqntid* specified, then output
545
// the reqntid directive, and set the unspecified ones to 1.
546
// If none of Reqntid* is specified, don't output reqntid directive.
547
std::optional<unsigned> Reqntidx = getReqNTIDx(F);
548
std::optional<unsigned> Reqntidy = getReqNTIDy(F);
549
std::optional<unsigned> Reqntidz = getReqNTIDz(F);
550
551
if (Reqntidx || Reqntidy || Reqntidz)
552
O << ".reqntid " << Reqntidx.value_or(1) << ", " << Reqntidy.value_or(1)
553
<< ", " << Reqntidz.value_or(1) << "\n";
554
555
// If the NVVM IR has some of maxntid* specified, then output
556
// the maxntid directive, and set the unspecified ones to 1.
557
// If none of maxntid* is specified, don't output maxntid directive.
558
std::optional<unsigned> Maxntidx = getMaxNTIDx(F);
559
std::optional<unsigned> Maxntidy = getMaxNTIDy(F);
560
std::optional<unsigned> Maxntidz = getMaxNTIDz(F);
561
562
if (Maxntidx || Maxntidy || Maxntidz)
563
O << ".maxntid " << Maxntidx.value_or(1) << ", " << Maxntidy.value_or(1)
564
<< ", " << Maxntidz.value_or(1) << "\n";
565
566
unsigned Mincta = 0;
567
if (getMinCTASm(F, Mincta))
568
O << ".minnctapersm " << Mincta << "\n";
569
570
unsigned Maxnreg = 0;
571
if (getMaxNReg(F, Maxnreg))
572
O << ".maxnreg " << Maxnreg << "\n";
573
574
// .maxclusterrank directive requires SM_90 or higher, make sure that we
575
// filter it out for lower SM versions, as it causes a hard ptxas crash.
576
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
577
const auto *STI = static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
578
unsigned Maxclusterrank = 0;
579
if (getMaxClusterRank(F, Maxclusterrank) && STI->getSmVersion() >= 90)
580
O << ".maxclusterrank " << Maxclusterrank << "\n";
581
}
582
583
std::string NVPTXAsmPrinter::getVirtualRegisterName(unsigned Reg) const {
584
const TargetRegisterClass *RC = MRI->getRegClass(Reg);
585
586
std::string Name;
587
raw_string_ostream NameStr(Name);
588
589
VRegRCMap::const_iterator I = VRegMapping.find(RC);
590
assert(I != VRegMapping.end() && "Bad register class");
591
const DenseMap<unsigned, unsigned> &RegMap = I->second;
592
593
VRegMap::const_iterator VI = RegMap.find(Reg);
594
assert(VI != RegMap.end() && "Bad virtual register");
595
unsigned MappedVR = VI->second;
596
597
NameStr << getNVPTXRegClassStr(RC) << MappedVR;
598
599
NameStr.flush();
600
return Name;
601
}
602
603
void NVPTXAsmPrinter::emitVirtualRegister(unsigned int vr,
604
raw_ostream &O) {
605
O << getVirtualRegisterName(vr);
606
}
607
608
void NVPTXAsmPrinter::emitAliasDeclaration(const GlobalAlias *GA,
609
raw_ostream &O) {
610
const Function *F = dyn_cast_or_null<Function>(GA->getAliaseeObject());
611
if (!F || isKernelFunction(*F) || F->isDeclaration())
612
report_fatal_error(
613
"NVPTX aliasee must be a non-kernel function definition");
614
615
if (GA->hasLinkOnceLinkage() || GA->hasWeakLinkage() ||
616
GA->hasAvailableExternallyLinkage() || GA->hasCommonLinkage())
617
report_fatal_error("NVPTX aliasee must not be '.weak'");
618
619
emitDeclarationWithName(F, getSymbol(GA), O);
620
}
621
622
void NVPTXAsmPrinter::emitDeclaration(const Function *F, raw_ostream &O) {
623
emitDeclarationWithName(F, getSymbol(F), O);
624
}
625
626
void NVPTXAsmPrinter::emitDeclarationWithName(const Function *F, MCSymbol *S,
627
raw_ostream &O) {
628
emitLinkageDirective(F, O);
629
if (isKernelFunction(*F))
630
O << ".entry ";
631
else
632
O << ".func ";
633
printReturnValStr(F, O);
634
S->print(O, MAI);
635
O << "\n";
636
emitFunctionParamList(F, O);
637
O << "\n";
638
if (shouldEmitPTXNoReturn(F, TM))
639
O << ".noreturn";
640
O << ";\n";
641
}
642
643
static bool usedInGlobalVarDef(const Constant *C) {
644
if (!C)
645
return false;
646
647
if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) {
648
return GV->getName() != "llvm.used";
649
}
650
651
for (const User *U : C->users())
652
if (const Constant *C = dyn_cast<Constant>(U))
653
if (usedInGlobalVarDef(C))
654
return true;
655
656
return false;
657
}
658
659
static bool usedInOneFunc(const User *U, Function const *&oneFunc) {
660
if (const GlobalVariable *othergv = dyn_cast<GlobalVariable>(U)) {
661
if (othergv->getName() == "llvm.used")
662
return true;
663
}
664
665
if (const Instruction *instr = dyn_cast<Instruction>(U)) {
666
if (instr->getParent() && instr->getParent()->getParent()) {
667
const Function *curFunc = instr->getParent()->getParent();
668
if (oneFunc && (curFunc != oneFunc))
669
return false;
670
oneFunc = curFunc;
671
return true;
672
} else
673
return false;
674
}
675
676
for (const User *UU : U->users())
677
if (!usedInOneFunc(UU, oneFunc))
678
return false;
679
680
return true;
681
}
682
683
/* Find out if a global variable can be demoted to local scope.
684
* Currently, this is valid for CUDA shared variables, which have local
685
* scope and global lifetime. So the conditions to check are :
686
* 1. Is the global variable in shared address space?
687
* 2. Does it have local linkage?
688
* 3. Is the global variable referenced only in one function?
689
*/
690
static bool canDemoteGlobalVar(const GlobalVariable *gv, Function const *&f) {
691
if (!gv->hasLocalLinkage())
692
return false;
693
PointerType *Pty = gv->getType();
694
if (Pty->getAddressSpace() != ADDRESS_SPACE_SHARED)
695
return false;
696
697
const Function *oneFunc = nullptr;
698
699
bool flag = usedInOneFunc(gv, oneFunc);
700
if (!flag)
701
return false;
702
if (!oneFunc)
703
return false;
704
f = oneFunc;
705
return true;
706
}
707
708
static bool useFuncSeen(const Constant *C,
709
DenseMap<const Function *, bool> &seenMap) {
710
for (const User *U : C->users()) {
711
if (const Constant *cu = dyn_cast<Constant>(U)) {
712
if (useFuncSeen(cu, seenMap))
713
return true;
714
} else if (const Instruction *I = dyn_cast<Instruction>(U)) {
715
const BasicBlock *bb = I->getParent();
716
if (!bb)
717
continue;
718
const Function *caller = bb->getParent();
719
if (!caller)
720
continue;
721
if (seenMap.contains(caller))
722
return true;
723
}
724
}
725
return false;
726
}
727
728
void NVPTXAsmPrinter::emitDeclarations(const Module &M, raw_ostream &O) {
729
DenseMap<const Function *, bool> seenMap;
730
for (const Function &F : M) {
731
if (F.getAttributes().hasFnAttr("nvptx-libcall-callee")) {
732
emitDeclaration(&F, O);
733
continue;
734
}
735
736
if (F.isDeclaration()) {
737
if (F.use_empty())
738
continue;
739
if (F.getIntrinsicID())
740
continue;
741
emitDeclaration(&F, O);
742
continue;
743
}
744
for (const User *U : F.users()) {
745
if (const Constant *C = dyn_cast<Constant>(U)) {
746
if (usedInGlobalVarDef(C)) {
747
// The use is in the initialization of a global variable
748
// that is a function pointer, so print a declaration
749
// for the original function
750
emitDeclaration(&F, O);
751
break;
752
}
753
// Emit a declaration of this function if the function that
754
// uses this constant expr has already been seen.
755
if (useFuncSeen(C, seenMap)) {
756
emitDeclaration(&F, O);
757
break;
758
}
759
}
760
761
if (!isa<Instruction>(U))
762
continue;
763
const Instruction *instr = cast<Instruction>(U);
764
const BasicBlock *bb = instr->getParent();
765
if (!bb)
766
continue;
767
const Function *caller = bb->getParent();
768
if (!caller)
769
continue;
770
771
// If a caller has already been seen, then the caller is
772
// appearing in the module before the callee. so print out
773
// a declaration for the callee.
774
if (seenMap.contains(caller)) {
775
emitDeclaration(&F, O);
776
break;
777
}
778
}
779
seenMap[&F] = true;
780
}
781
for (const GlobalAlias &GA : M.aliases())
782
emitAliasDeclaration(&GA, O);
783
}
784
785
/// Return true if \p GV is null, if its initializer is not a ConstantArray, or
/// if that array has no operands.
static bool isEmptyXXStructor(GlobalVariable *GV) {
  if (!GV)
    return true;

  const auto *Entries = dyn_cast<ConstantArray>(GV->getInitializer());
  // Not an array; we don't know how to parse.
  return !Entries || Entries->getNumOperands() == 0;
}
791
792
/// Emit the PTX file header for \p M.
void NVPTXAsmPrinter::emitStartOfAsmFile(Module &M) {
  // Construct a default subtarget off of the TargetMachine defaults. The
  // rest of NVPTX isn't friendly to change subtargets per function and
  // so the default TargetMachine will have all of the options.
  const auto &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const auto &STI =
      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());

  // Emit header before any dwarf directives are emitted below.
  SmallString<128> Header;
  raw_svector_ostream OS(Header);
  emitHeader(M, OS, STI);
  OutStreamer->emitRawText(OS.str());
}
805
806
/// Validate \p M against the target's restrictions and run the generic
/// AsmPrinter initialization.
///
/// Fatal errors are raised when the module contains aliases but the subtarget
/// is below PTX 6.3 or sm_30, and when it has a non-empty llvm.global_ctors or
/// llvm.global_dtors list while neither the nvptx-lower-global-ctor-dtor
/// option nor the "openmp" module flag is set.
///
/// \returns the result of AsmPrinter::doInitialization().
bool NVPTXAsmPrinter::doInitialization(Module &M) {
  const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
  const NVPTXSubtarget &STI =
      *static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
  if (M.alias_size() && (STI.getPTXVersion() < 63 || STI.getSmVersion() < 30))
    report_fatal_error(".alias requires PTX version >= 6.3 and sm_30");

  // OpenMP supports NVPTX global constructors and destructors.
  bool IsOpenMP = M.getModuleFlag("openmp") != nullptr;

  // report_fatal_error() does not return, so no error value is propagated
  // from the two checks below.
  if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_ctors")) &&
      !LowerCtorDtor && !IsOpenMP)
    report_fatal_error(
        "Module has a nontrivial global ctor, which NVPTX does not support.");

  if (!isEmptyXXStructor(M.getNamedGlobal("llvm.global_dtors")) &&
      !LowerCtorDtor && !IsOpenMP)
    report_fatal_error(
        "Module has a nontrivial global dtor, which NVPTX does not support.");

  // We need to call the parent's one explicitly.
  bool Result = AsmPrinter::doInitialization(M);

  // Globals are emitted later: before the first function, or at finalization.
  GlobalsEmitted = false;

  return Result;
}
836
837
void NVPTXAsmPrinter::emitGlobals(const Module &M) {
838
SmallString<128> Str2;
839
raw_svector_ostream OS2(Str2);
840
841
emitDeclarations(M, OS2);
842
843
// As ptxas does not support forward references of globals, we need to first
844
// sort the list of module-level globals in def-use order. We visit each
845
// global variable in order, and ensure that we emit it *after* its dependent
846
// globals. We use a little extra memory maintaining both a set and a list to
847
// have fast searches while maintaining a strict ordering.
848
SmallVector<const GlobalVariable *, 8> Globals;
849
DenseSet<const GlobalVariable *> GVVisited;
850
DenseSet<const GlobalVariable *> GVVisiting;
851
852
// Visit each global variable, in order
853
for (const GlobalVariable &I : M.globals())
854
VisitGlobalVariableForEmission(&I, Globals, GVVisited, GVVisiting);
855
856
assert(GVVisited.size() == M.global_size() && "Missed a global variable");
857
assert(GVVisiting.size() == 0 && "Did not fully process a global variable");
858
859
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
860
const NVPTXSubtarget &STI =
861
*static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
862
863
// Print out module-level global variables in proper order
864
for (const GlobalVariable *GV : Globals)
865
printModuleLevelGV(GV, OS2, /*processDemoted=*/false, STI);
866
867
OS2 << '\n';
868
869
OutStreamer->emitRawText(OS2.str());
870
}
871
872
void NVPTXAsmPrinter::emitGlobalAlias(const Module &M, const GlobalAlias &GA) {
873
SmallString<128> Str;
874
raw_svector_ostream OS(Str);
875
876
MCSymbol *Name = getSymbol(&GA);
877
878
OS << ".alias " << Name->getName() << ", " << GA.getAliaseeObject()->getName()
879
<< ";\n";
880
881
OutStreamer->emitRawText(OS.str());
882
}
883
884
void NVPTXAsmPrinter::emitHeader(Module &M, raw_ostream &O,
885
const NVPTXSubtarget &STI) {
886
O << "//\n";
887
O << "// Generated by LLVM NVPTX Back-End\n";
888
O << "//\n";
889
O << "\n";
890
891
unsigned PTXVersion = STI.getPTXVersion();
892
O << ".version " << (PTXVersion / 10) << "." << (PTXVersion % 10) << "\n";
893
894
O << ".target ";
895
O << STI.getTargetName();
896
897
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
898
if (NTM.getDrvInterface() == NVPTX::NVCL)
899
O << ", texmode_independent";
900
901
bool HasFullDebugInfo = false;
902
for (DICompileUnit *CU : M.debug_compile_units()) {
903
switch(CU->getEmissionKind()) {
904
case DICompileUnit::NoDebug:
905
case DICompileUnit::DebugDirectivesOnly:
906
break;
907
case DICompileUnit::LineTablesOnly:
908
case DICompileUnit::FullDebug:
909
HasFullDebugInfo = true;
910
break;
911
}
912
if (HasFullDebugInfo)
913
break;
914
}
915
if (MMI && MMI->hasDebugInfo() && HasFullDebugInfo)
916
O << ", debug";
917
918
O << "\n";
919
920
O << ".address_size ";
921
if (NTM.is64Bit())
922
O << "64";
923
else
924
O << "32";
925
O << "\n";
926
927
O << "\n";
928
}
929
930
/// Finish emission of \p M: emit the globals if no function triggered that
/// yet, run the generic finalization, clear the annotation cache, close the
/// last debug section when debug info is present, and output the remaining
/// DWARF .file directives.
///
/// \returns the result of AsmPrinter::doFinalization().
bool NVPTXAsmPrinter::doFinalization(Module &M) {
  // Sampled up front, before the base-class finalization runs.
  const bool HadDebugInfo = MMI && MMI->hasDebugInfo();

  // If we did not emit any functions, then the global declarations have not
  // yet been emitted.
  if (!GlobalsEmitted) {
    emitGlobals(M);
    GlobalsEmitted = true;
  }

  const bool BaseResult = AsmPrinter::doFinalization(M);

  clearAnnotationCache(&M);

  auto *Streamer =
      static_cast<NVPTXTargetStreamer *>(OutStreamer->getTargetStreamer());
  if (HadDebugInfo) {
    // Close the last emitted section
    Streamer->closeLastSection();
    // Emit empty .debug_loc section for better support of the empty files.
    OutStreamer->emitRawText("\t.section\t.debug_loc\t{\t}");
  }

  // Output last DWARF .file directives, if any.
  Streamer->outputDwarfFileDirectives();

  return BaseResult;
}
959
960
// This function emits appropriate linkage directives for
961
// functions and global variables.
962
//
963
// extern function declaration -> .extern
964
// extern function definition -> .visible
965
// external global variable with init -> .visible
966
// external without init -> .extern
967
// appending -> not allowed, assert.
968
// for any linkage other than
969
// internal, private, linker_private,
970
// linker_private_weak, linker_private_weak_def_auto,
971
// we emit -> .weak.
972
973
void NVPTXAsmPrinter::emitLinkageDirective(const GlobalValue *V,
974
raw_ostream &O) {
975
if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() == NVPTX::CUDA) {
976
if (V->hasExternalLinkage()) {
977
if (isa<GlobalVariable>(V)) {
978
const GlobalVariable *GVar = cast<GlobalVariable>(V);
979
if (GVar) {
980
if (GVar->hasInitializer())
981
O << ".visible ";
982
else
983
O << ".extern ";
984
}
985
} else if (V->isDeclaration())
986
O << ".extern ";
987
else
988
O << ".visible ";
989
} else if (V->hasAppendingLinkage()) {
990
std::string msg;
991
msg.append("Error: ");
992
msg.append("Symbol ");
993
if (V->hasName())
994
msg.append(std::string(V->getName()));
995
msg.append("has unsupported appending linkage type");
996
llvm_unreachable(msg.c_str());
997
} else if (!V->hasInternalLinkage() &&
998
!V->hasPrivateLinkage()) {
999
O << ".weak ";
1000
}
1001
}
1002
}
1003
1004
void NVPTXAsmPrinter::printModuleLevelGV(const GlobalVariable *GVar,
1005
raw_ostream &O, bool processDemoted,
1006
const NVPTXSubtarget &STI) {
1007
// Skip meta data
1008
if (GVar->hasSection()) {
1009
if (GVar->getSection() == "llvm.metadata")
1010
return;
1011
}
1012
1013
// Skip LLVM intrinsic global variables
1014
if (GVar->getName().starts_with("llvm.") ||
1015
GVar->getName().starts_with("nvvm."))
1016
return;
1017
1018
const DataLayout &DL = getDataLayout();
1019
1020
// GlobalVariables are always constant pointers themselves.
1021
Type *ETy = GVar->getValueType();
1022
1023
if (GVar->hasExternalLinkage()) {
1024
if (GVar->hasInitializer())
1025
O << ".visible ";
1026
else
1027
O << ".extern ";
1028
} else if (STI.getPTXVersion() >= 50 && GVar->hasCommonLinkage() &&
1029
GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) {
1030
O << ".common ";
1031
} else if (GVar->hasLinkOnceLinkage() || GVar->hasWeakLinkage() ||
1032
GVar->hasAvailableExternallyLinkage() ||
1033
GVar->hasCommonLinkage()) {
1034
O << ".weak ";
1035
}
1036
1037
if (isTexture(*GVar)) {
1038
O << ".global .texref " << getTextureName(*GVar) << ";\n";
1039
return;
1040
}
1041
1042
if (isSurface(*GVar)) {
1043
O << ".global .surfref " << getSurfaceName(*GVar) << ";\n";
1044
return;
1045
}
1046
1047
if (GVar->isDeclaration()) {
1048
// (extern) declarations, no definition or initializer
1049
// Currently the only known declaration is for an automatic __local
1050
// (.shared) promoted to global.
1051
emitPTXGlobalVariable(GVar, O, STI);
1052
O << ";\n";
1053
return;
1054
}
1055
1056
if (isSampler(*GVar)) {
1057
O << ".global .samplerref " << getSamplerName(*GVar);
1058
1059
const Constant *Initializer = nullptr;
1060
if (GVar->hasInitializer())
1061
Initializer = GVar->getInitializer();
1062
const ConstantInt *CI = nullptr;
1063
if (Initializer)
1064
CI = dyn_cast<ConstantInt>(Initializer);
1065
if (CI) {
1066
unsigned sample = CI->getZExtValue();
1067
1068
O << " = { ";
1069
1070
for (int i = 0,
1071
addr = ((sample & __CLK_ADDRESS_MASK) >> __CLK_ADDRESS_BASE);
1072
i < 3; i++) {
1073
O << "addr_mode_" << i << " = ";
1074
switch (addr) {
1075
case 0:
1076
O << "wrap";
1077
break;
1078
case 1:
1079
O << "clamp_to_border";
1080
break;
1081
case 2:
1082
O << "clamp_to_edge";
1083
break;
1084
case 3:
1085
O << "wrap";
1086
break;
1087
case 4:
1088
O << "mirror";
1089
break;
1090
}
1091
O << ", ";
1092
}
1093
O << "filter_mode = ";
1094
switch ((sample & __CLK_FILTER_MASK) >> __CLK_FILTER_BASE) {
1095
case 0:
1096
O << "nearest";
1097
break;
1098
case 1:
1099
O << "linear";
1100
break;
1101
case 2:
1102
llvm_unreachable("Anisotropic filtering is not supported");
1103
default:
1104
O << "nearest";
1105
break;
1106
}
1107
if (!((sample & __CLK_NORMALIZED_MASK) >> __CLK_NORMALIZED_BASE)) {
1108
O << ", force_unnormalized_coords = 1";
1109
}
1110
O << " }";
1111
}
1112
1113
O << ";\n";
1114
return;
1115
}
1116
1117
if (GVar->hasPrivateLinkage()) {
1118
if (strncmp(GVar->getName().data(), "unrollpragma", 12) == 0)
1119
return;
1120
1121
// FIXME - need better way (e.g. Metadata) to avoid generating this global
1122
if (strncmp(GVar->getName().data(), "filename", 8) == 0)
1123
return;
1124
if (GVar->use_empty())
1125
return;
1126
}
1127
1128
const Function *demotedFunc = nullptr;
1129
if (!processDemoted && canDemoteGlobalVar(GVar, demotedFunc)) {
1130
O << "// " << GVar->getName() << " has been demoted\n";
1131
if (localDecls.find(demotedFunc) != localDecls.end())
1132
localDecls[demotedFunc].push_back(GVar);
1133
else {
1134
std::vector<const GlobalVariable *> temp;
1135
temp.push_back(GVar);
1136
localDecls[demotedFunc] = temp;
1137
}
1138
return;
1139
}
1140
1141
O << ".";
1142
emitPTXAddressSpace(GVar->getAddressSpace(), O);
1143
1144
if (isManaged(*GVar)) {
1145
if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1146
report_fatal_error(
1147
".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1148
}
1149
O << " .attribute(.managed)";
1150
}
1151
1152
if (MaybeAlign A = GVar->getAlign())
1153
O << " .align " << A->value();
1154
else
1155
O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1156
1157
if (ETy->isFloatingPointTy() || ETy->isPointerTy() ||
1158
(ETy->isIntegerTy() && ETy->getScalarSizeInBits() <= 64)) {
1159
O << " .";
1160
// Special case: ABI requires that we use .u8 for predicates
1161
if (ETy->isIntegerTy(1))
1162
O << "u8";
1163
else
1164
O << getPTXFundamentalTypeStr(ETy, false);
1165
O << " ";
1166
getSymbol(GVar)->print(O, MAI);
1167
1168
// Ptx allows variable initilization only for constant and global state
1169
// spaces.
1170
if (GVar->hasInitializer()) {
1171
if ((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1172
(GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) {
1173
const Constant *Initializer = GVar->getInitializer();
1174
// 'undef' is treated as there is no value specified.
1175
if (!Initializer->isNullValue() && !isa<UndefValue>(Initializer)) {
1176
O << " = ";
1177
printScalarConstant(Initializer, O);
1178
}
1179
} else {
1180
// The frontend adds zero-initializer to device and constant variables
1181
// that don't have an initial value, and UndefValue to shared
1182
// variables, so skip warning for this case.
1183
if (!GVar->getInitializer()->isNullValue() &&
1184
!isa<UndefValue>(GVar->getInitializer())) {
1185
report_fatal_error("initial value of '" + GVar->getName() +
1186
"' is not allowed in addrspace(" +
1187
Twine(GVar->getAddressSpace()) + ")");
1188
}
1189
}
1190
}
1191
} else {
1192
uint64_t ElementSize = 0;
1193
1194
// Although PTX has direct support for struct type and array type and
1195
// LLVM IR is very similar to PTX, the LLVM CodeGen does not support for
1196
// targets that support these high level field accesses. Structs, arrays
1197
// and vectors are lowered into arrays of bytes.
1198
switch (ETy->getTypeID()) {
1199
case Type::IntegerTyID: // Integers larger than 64 bits
1200
case Type::StructTyID:
1201
case Type::ArrayTyID:
1202
case Type::FixedVectorTyID:
1203
ElementSize = DL.getTypeStoreSize(ETy);
1204
// Ptx allows variable initilization only for constant and
1205
// global state spaces.
1206
if (((GVar->getAddressSpace() == ADDRESS_SPACE_GLOBAL) ||
1207
(GVar->getAddressSpace() == ADDRESS_SPACE_CONST)) &&
1208
GVar->hasInitializer()) {
1209
const Constant *Initializer = GVar->getInitializer();
1210
if (!isa<UndefValue>(Initializer) && !Initializer->isNullValue()) {
1211
AggBuffer aggBuffer(ElementSize, *this);
1212
bufferAggregateConstant(Initializer, &aggBuffer);
1213
if (aggBuffer.numSymbols()) {
1214
unsigned int ptrSize = MAI->getCodePointerSize();
1215
if (ElementSize % ptrSize ||
1216
!aggBuffer.allSymbolsAligned(ptrSize)) {
1217
// Print in bytes and use the mask() operator for pointers.
1218
if (!STI.hasMaskOperator())
1219
report_fatal_error(
1220
"initialized packed aggregate with pointers '" +
1221
GVar->getName() +
1222
"' requires at least PTX ISA version 7.1");
1223
O << " .u8 ";
1224
getSymbol(GVar)->print(O, MAI);
1225
O << "[" << ElementSize << "] = {";
1226
aggBuffer.printBytes(O);
1227
O << "}";
1228
} else {
1229
O << " .u" << ptrSize * 8 << " ";
1230
getSymbol(GVar)->print(O, MAI);
1231
O << "[" << ElementSize / ptrSize << "] = {";
1232
aggBuffer.printWords(O);
1233
O << "}";
1234
}
1235
} else {
1236
O << " .b8 ";
1237
getSymbol(GVar)->print(O, MAI);
1238
O << "[" << ElementSize << "] = {";
1239
aggBuffer.printBytes(O);
1240
O << "}";
1241
}
1242
} else {
1243
O << " .b8 ";
1244
getSymbol(GVar)->print(O, MAI);
1245
if (ElementSize) {
1246
O << "[";
1247
O << ElementSize;
1248
O << "]";
1249
}
1250
}
1251
} else {
1252
O << " .b8 ";
1253
getSymbol(GVar)->print(O, MAI);
1254
if (ElementSize) {
1255
O << "[";
1256
O << ElementSize;
1257
O << "]";
1258
}
1259
}
1260
break;
1261
default:
1262
llvm_unreachable("type not supported yet");
1263
}
1264
}
1265
O << ";\n";
1266
}
1267
1268
void NVPTXAsmPrinter::AggBuffer::printSymbol(unsigned nSym, raw_ostream &os) {
1269
const Value *v = Symbols[nSym];
1270
const Value *v0 = SymbolsBeforeStripping[nSym];
1271
if (const GlobalValue *GVar = dyn_cast<GlobalValue>(v)) {
1272
MCSymbol *Name = AP.getSymbol(GVar);
1273
PointerType *PTy = dyn_cast<PointerType>(v0->getType());
1274
// Is v0 a generic pointer?
1275
bool isGenericPointer = PTy && PTy->getAddressSpace() == 0;
1276
if (EmitGeneric && isGenericPointer && !isa<Function>(v)) {
1277
os << "generic(";
1278
Name->print(os, AP.MAI);
1279
os << ")";
1280
} else {
1281
Name->print(os, AP.MAI);
1282
}
1283
} else if (const ConstantExpr *CExpr = dyn_cast<ConstantExpr>(v0)) {
1284
const MCExpr *Expr = AP.lowerConstantForGV(cast<Constant>(CExpr), false);
1285
AP.printMCExpr(*Expr, os);
1286
} else
1287
llvm_unreachable("symbol type unknown");
1288
}
1289
1290
/// Print the buffer contents to \p os as a comma-separated list of byte
/// initializers. Each recorded symbol is expanded into one mask() entry per
/// pointer byte.
///
/// Note: this appends an end-of-data entry to symbolPosInBuffer.
void NVPTXAsmPrinter::AggBuffer::printBytes(raw_ostream &os) {
  const unsigned PointerBytes = AP.MAI->getCodePointerSize();

  // Do not emit trailing zero initializers. They will be zero-initialized by
  // ptxas. This saves on both space requirements for the generated PTX and on
  // memory use by ptxas. (See:
  // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#global-state-space)
  // TODO: symbols make this harder, but it would still be good to trim trailing
  // 0s for aggs with symbols as well.
  unsigned EmitCount = size;
  if (numSymbols() == 0) {
    while (EmitCount >= 1 && !buffer[EmitCount - 1])
      --EmitCount;
  }

  // Sentinel so the "next symbol" lookup stays valid after the last symbol.
  symbolPosInBuffer.push_back(EmitCount);
  unsigned SymIdx = 0;
  unsigned NextSymPos = symbolPosInBuffer[SymIdx];

  unsigned Pos = 0;
  while (Pos < EmitCount) {
    if (Pos != 0)
      os << ", ";

    if (Pos == NextSymPos) {
      // Generate a per-byte mask() operator for the symbol, which looks like:
      //   .global .u8 addr[] = {0xFF(foo), 0xFF00(foo), 0xFF0000(foo), ...};
      // See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#initializers
      std::string SymText;
      llvm::raw_string_ostream SymOS(SymText);
      printSymbol(SymIdx, SymOS);
      for (unsigned ByteIdx = 0; ByteIdx != PointerBytes; ++ByteIdx) {
        if (ByteIdx != 0)
          os << ", ";
        llvm::write_hex(os, 0xFFULL << ByteIdx * 8, HexPrintStyle::PrefixUpper);
        os << "(" << SymText << ")";
      }
      Pos += PointerBytes;
      NextSymPos = symbolPosInBuffer[++SymIdx];
      assert(NextSymPos >= Pos);
    } else {
      // Plain data byte.
      os << (unsigned int)buffer[Pos];
      ++Pos;
    }
  }
}
1331
1332
/// Print the buffer contents to \p os as a comma-separated list of
/// pointer-sized words. A word at a recorded symbol position is printed as
/// that symbol; any other word is printed as its little-endian integer value.
///
/// Note: this appends an end-of-data entry to symbolPosInBuffer.
void NVPTXAsmPrinter::AggBuffer::printWords(raw_ostream &os) {
  const unsigned WordBytes = AP.MAI->getCodePointerSize();

  // Sentinel so the "next symbol" lookup stays valid after the last symbol.
  symbolPosInBuffer.push_back(size);
  unsigned SymIdx = 0;
  unsigned NextSymPos = symbolPosInBuffer[SymIdx];
  assert(NextSymPos % WordBytes == 0);

  const char *Separator = "";
  for (unsigned Pos = 0; Pos < size; Pos += WordBytes) {
    os << Separator;
    Separator = ", ";

    if (Pos == NextSymPos) {
      printSymbol(SymIdx, os);
      NextSymPos = symbolPosInBuffer[++SymIdx];
      assert(NextSymPos % WordBytes == 0);
      assert(NextSymPos >= Pos + WordBytes);
      continue;
    }

    if (WordBytes == 4)
      os << support::endian::read32le(&buffer[Pos]);
    else
      os << support::endian::read64le(&buffer[Pos]);
  }
}
1352
1353
void NVPTXAsmPrinter::emitDemotedVars(const Function *f, raw_ostream &O) {
1354
if (localDecls.find(f) == localDecls.end())
1355
return;
1356
1357
std::vector<const GlobalVariable *> &gvars = localDecls[f];
1358
1359
const NVPTXTargetMachine &NTM = static_cast<const NVPTXTargetMachine &>(TM);
1360
const NVPTXSubtarget &STI =
1361
*static_cast<const NVPTXSubtarget *>(NTM.getSubtargetImpl());
1362
1363
for (const GlobalVariable *GV : gvars) {
1364
O << "\t// demoted variable\n\t";
1365
printModuleLevelGV(GV, O, /*processDemoted=*/true, STI);
1366
}
1367
}
1368
1369
void NVPTXAsmPrinter::emitPTXAddressSpace(unsigned int AddressSpace,
1370
raw_ostream &O) const {
1371
switch (AddressSpace) {
1372
case ADDRESS_SPACE_LOCAL:
1373
O << "local";
1374
break;
1375
case ADDRESS_SPACE_GLOBAL:
1376
O << "global";
1377
break;
1378
case ADDRESS_SPACE_CONST:
1379
O << "const";
1380
break;
1381
case ADDRESS_SPACE_SHARED:
1382
O << "shared";
1383
break;
1384
default:
1385
report_fatal_error("Bad address space found while emitting PTX: " +
1386
llvm::Twine(AddressSpace));
1387
break;
1388
}
1389
}
1390
1391
std::string
1392
NVPTXAsmPrinter::getPTXFundamentalTypeStr(Type *Ty, bool useB4PTR) const {
1393
switch (Ty->getTypeID()) {
1394
case Type::IntegerTyID: {
1395
unsigned NumBits = cast<IntegerType>(Ty)->getBitWidth();
1396
if (NumBits == 1)
1397
return "pred";
1398
else if (NumBits <= 64) {
1399
std::string name = "u";
1400
return name + utostr(NumBits);
1401
} else {
1402
llvm_unreachable("Integer too large");
1403
break;
1404
}
1405
break;
1406
}
1407
case Type::BFloatTyID:
1408
case Type::HalfTyID:
1409
// fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
1410
// PTX assembly.
1411
return "b16";
1412
case Type::FloatTyID:
1413
return "f32";
1414
case Type::DoubleTyID:
1415
return "f64";
1416
case Type::PointerTyID: {
1417
unsigned PtrSize = TM.getPointerSizeInBits(Ty->getPointerAddressSpace());
1418
assert((PtrSize == 64 || PtrSize == 32) && "Unexpected pointer size");
1419
1420
if (PtrSize == 64)
1421
if (useB4PTR)
1422
return "b64";
1423
else
1424
return "u64";
1425
else if (useB4PTR)
1426
return "b32";
1427
else
1428
return "u32";
1429
}
1430
default:
1431
break;
1432
}
1433
llvm_unreachable("unexpected type");
1434
}
1435
1436
void NVPTXAsmPrinter::emitPTXGlobalVariable(const GlobalVariable *GVar,
1437
raw_ostream &O,
1438
const NVPTXSubtarget &STI) {
1439
const DataLayout &DL = getDataLayout();
1440
1441
// GlobalVariables are always constant pointers themselves.
1442
Type *ETy = GVar->getValueType();
1443
1444
O << ".";
1445
emitPTXAddressSpace(GVar->getType()->getAddressSpace(), O);
1446
if (isManaged(*GVar)) {
1447
if (STI.getPTXVersion() < 40 || STI.getSmVersion() < 30) {
1448
report_fatal_error(
1449
".attribute(.managed) requires PTX version >= 4.0 and sm_30");
1450
}
1451
O << " .attribute(.managed)";
1452
}
1453
if (MaybeAlign A = GVar->getAlign())
1454
O << " .align " << A->value();
1455
else
1456
O << " .align " << (int)DL.getPrefTypeAlign(ETy).value();
1457
1458
// Special case for i128
1459
if (ETy->isIntegerTy(128)) {
1460
O << " .b8 ";
1461
getSymbol(GVar)->print(O, MAI);
1462
O << "[16]";
1463
return;
1464
}
1465
1466
if (ETy->isFloatingPointTy() || ETy->isIntOrPtrTy()) {
1467
O << " .";
1468
O << getPTXFundamentalTypeStr(ETy);
1469
O << " ";
1470
getSymbol(GVar)->print(O, MAI);
1471
return;
1472
}
1473
1474
int64_t ElementSize = 0;
1475
1476
// Although PTX has direct support for struct type and array type and LLVM IR
1477
// is very similar to PTX, the LLVM CodeGen does not support for targets that
1478
// support these high level field accesses. Structs and arrays are lowered
1479
// into arrays of bytes.
1480
switch (ETy->getTypeID()) {
1481
case Type::StructTyID:
1482
case Type::ArrayTyID:
1483
case Type::FixedVectorTyID:
1484
ElementSize = DL.getTypeStoreSize(ETy);
1485
O << " .b8 ";
1486
getSymbol(GVar)->print(O, MAI);
1487
O << "[";
1488
if (ElementSize) {
1489
O << ElementSize;
1490
}
1491
O << "]";
1492
break;
1493
default:
1494
llvm_unreachable("type not supported yet");
1495
}
1496
}
1497
1498
void NVPTXAsmPrinter::emitFunctionParamList(const Function *F, raw_ostream &O) {
1499
const DataLayout &DL = getDataLayout();
1500
const AttributeList &PAL = F->getAttributes();
1501
const NVPTXSubtarget &STI = TM.getSubtarget<NVPTXSubtarget>(*F);
1502
const auto *TLI = cast<NVPTXTargetLowering>(STI.getTargetLowering());
1503
1504
Function::const_arg_iterator I, E;
1505
unsigned paramIndex = 0;
1506
bool first = true;
1507
bool isKernelFunc = isKernelFunction(*F);
1508
bool isABI = (STI.getSmVersion() >= 20);
1509
bool hasImageHandles = STI.hasImageHandles();
1510
1511
if (F->arg_empty() && !F->isVarArg()) {
1512
O << "()";
1513
return;
1514
}
1515
1516
O << "(\n";
1517
1518
for (I = F->arg_begin(), E = F->arg_end(); I != E; ++I, paramIndex++) {
1519
Type *Ty = I->getType();
1520
1521
if (!first)
1522
O << ",\n";
1523
1524
first = false;
1525
1526
// Handle image/sampler parameters
1527
if (isKernelFunction(*F)) {
1528
if (isSampler(*I) || isImage(*I)) {
1529
if (isImage(*I)) {
1530
if (isImageWriteOnly(*I) || isImageReadWrite(*I)) {
1531
if (hasImageHandles)
1532
O << "\t.param .u64 .ptr .surfref ";
1533
else
1534
O << "\t.param .surfref ";
1535
O << TLI->getParamName(F, paramIndex);
1536
}
1537
else { // Default image is read_only
1538
if (hasImageHandles)
1539
O << "\t.param .u64 .ptr .texref ";
1540
else
1541
O << "\t.param .texref ";
1542
O << TLI->getParamName(F, paramIndex);
1543
}
1544
} else {
1545
if (hasImageHandles)
1546
O << "\t.param .u64 .ptr .samplerref ";
1547
else
1548
O << "\t.param .samplerref ";
1549
O << TLI->getParamName(F, paramIndex);
1550
}
1551
continue;
1552
}
1553
}
1554
1555
auto getOptimalAlignForParam = [TLI, &DL, &PAL, F,
1556
paramIndex](Type *Ty) -> Align {
1557
if (MaybeAlign StackAlign =
1558
getAlign(*F, paramIndex + AttributeList::FirstArgIndex))
1559
return StackAlign.value();
1560
1561
Align TypeAlign = TLI->getFunctionParamOptimizedAlign(F, Ty, DL);
1562
MaybeAlign ParamAlign = PAL.getParamAlignment(paramIndex);
1563
return std::max(TypeAlign, ParamAlign.valueOrOne());
1564
};
1565
1566
if (!PAL.hasParamAttr(paramIndex, Attribute::ByVal)) {
1567
if (ShouldPassAsArray(Ty)) {
1568
// Just print .param .align <a> .b8 .param[size];
1569
// <a> = optimal alignment for the element type; always multiple of
1570
// PAL.getParamAlignment
1571
// size = typeallocsize of element type
1572
Align OptimalAlign = getOptimalAlignForParam(Ty);
1573
1574
O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1575
O << TLI->getParamName(F, paramIndex);
1576
O << "[" << DL.getTypeAllocSize(Ty) << "]";
1577
1578
continue;
1579
}
1580
// Just a scalar
1581
auto *PTy = dyn_cast<PointerType>(Ty);
1582
unsigned PTySizeInBits = 0;
1583
if (PTy) {
1584
PTySizeInBits =
1585
TLI->getPointerTy(DL, PTy->getAddressSpace()).getSizeInBits();
1586
assert(PTySizeInBits && "Invalid pointer size");
1587
}
1588
1589
if (isKernelFunc) {
1590
if (PTy) {
1591
// Special handling for pointer arguments to kernel
1592
O << "\t.param .u" << PTySizeInBits << " ";
1593
1594
if (static_cast<NVPTXTargetMachine &>(TM).getDrvInterface() !=
1595
NVPTX::CUDA) {
1596
int addrSpace = PTy->getAddressSpace();
1597
switch (addrSpace) {
1598
default:
1599
O << ".ptr ";
1600
break;
1601
case ADDRESS_SPACE_CONST:
1602
O << ".ptr .const ";
1603
break;
1604
case ADDRESS_SPACE_SHARED:
1605
O << ".ptr .shared ";
1606
break;
1607
case ADDRESS_SPACE_GLOBAL:
1608
O << ".ptr .global ";
1609
break;
1610
}
1611
Align ParamAlign = I->getParamAlign().valueOrOne();
1612
O << ".align " << ParamAlign.value() << " ";
1613
}
1614
O << TLI->getParamName(F, paramIndex);
1615
continue;
1616
}
1617
1618
// non-pointer scalar to kernel func
1619
O << "\t.param .";
1620
// Special case: predicate operands become .u8 types
1621
if (Ty->isIntegerTy(1))
1622
O << "u8";
1623
else
1624
O << getPTXFundamentalTypeStr(Ty);
1625
O << " ";
1626
O << TLI->getParamName(F, paramIndex);
1627
continue;
1628
}
1629
// Non-kernel function, just print .param .b<size> for ABI
1630
// and .reg .b<size> for non-ABI
1631
unsigned sz = 0;
1632
if (isa<IntegerType>(Ty)) {
1633
sz = cast<IntegerType>(Ty)->getBitWidth();
1634
sz = promoteScalarArgumentSize(sz);
1635
} else if (PTy) {
1636
assert(PTySizeInBits && "Invalid pointer size");
1637
sz = PTySizeInBits;
1638
} else
1639
sz = Ty->getPrimitiveSizeInBits();
1640
if (isABI)
1641
O << "\t.param .b" << sz << " ";
1642
else
1643
O << "\t.reg .b" << sz << " ";
1644
O << TLI->getParamName(F, paramIndex);
1645
continue;
1646
}
1647
1648
// param has byVal attribute.
1649
Type *ETy = PAL.getParamByValType(paramIndex);
1650
assert(ETy && "Param should have byval type");
1651
1652
if (isABI || isKernelFunc) {
1653
// Just print .param .align <a> .b8 .param[size];
1654
// <a> = optimal alignment for the element type; always multiple of
1655
// PAL.getParamAlignment
1656
// size = typeallocsize of element type
1657
Align OptimalAlign =
1658
isKernelFunc
1659
? getOptimalAlignForParam(ETy)
1660
: TLI->getFunctionByValParamAlign(
1661
F, ETy, PAL.getParamAlignment(paramIndex).valueOrOne(), DL);
1662
1663
unsigned sz = DL.getTypeAllocSize(ETy);
1664
O << "\t.param .align " << OptimalAlign.value() << " .b8 ";
1665
O << TLI->getParamName(F, paramIndex);
1666
O << "[" << sz << "]";
1667
continue;
1668
} else {
1669
// Split the ETy into constituent parts and
1670
// print .param .b<size> <name> for each part.
1671
// Further, if a part is vector, print the above for
1672
// each vector element.
1673
SmallVector<EVT, 16> vtparts;
1674
ComputeValueVTs(*TLI, DL, ETy, vtparts);
1675
for (unsigned i = 0, e = vtparts.size(); i != e; ++i) {
1676
unsigned elems = 1;
1677
EVT elemtype = vtparts[i];
1678
if (vtparts[i].isVector()) {
1679
elems = vtparts[i].getVectorNumElements();
1680
elemtype = vtparts[i].getVectorElementType();
1681
}
1682
1683
for (unsigned j = 0, je = elems; j != je; ++j) {
1684
unsigned sz = elemtype.getSizeInBits();
1685
if (elemtype.isInteger())
1686
sz = promoteScalarArgumentSize(sz);
1687
O << "\t.reg .b" << sz << " ";
1688
O << TLI->getParamName(F, paramIndex);
1689
if (j < je - 1)
1690
O << ",\n";
1691
++paramIndex;
1692
}
1693
if (i < e - 1)
1694
O << ",\n";
1695
}
1696
--paramIndex;
1697
continue;
1698
}
1699
}
1700
1701
if (F->isVarArg()) {
1702
if (!first)
1703
O << ",\n";
1704
O << "\t.param .align " << STI.getMaxRequiredAlignment();
1705
O << " .b8 ";
1706
O << TLI->getParamName(F, /* vararg */ -1) << "[]";
1707
}
1708
1709
O << "\n)";
1710
}
1711
1712
void NVPTXAsmPrinter::setAndEmitFunctionVirtualRegisters(
1713
const MachineFunction &MF) {
1714
SmallString<128> Str;
1715
raw_svector_ostream O(Str);
1716
1717
// Map the global virtual register number to a register class specific
1718
// virtual register number starting from 1 with that class.
1719
const TargetRegisterInfo *TRI = MF.getSubtarget().getRegisterInfo();
1720
//unsigned numRegClasses = TRI->getNumRegClasses();
1721
1722
// Emit the Fake Stack Object
1723
const MachineFrameInfo &MFI = MF.getFrameInfo();
1724
int64_t NumBytes = MFI.getStackSize();
1725
if (NumBytes) {
1726
O << "\t.local .align " << MFI.getMaxAlign().value() << " .b8 \t"
1727
<< DEPOTNAME << getFunctionNumber() << "[" << NumBytes << "];\n";
1728
if (static_cast<const NVPTXTargetMachine &>(MF.getTarget()).is64Bit()) {
1729
O << "\t.reg .b64 \t%SP;\n";
1730
O << "\t.reg .b64 \t%SPL;\n";
1731
} else {
1732
O << "\t.reg .b32 \t%SP;\n";
1733
O << "\t.reg .b32 \t%SPL;\n";
1734
}
1735
}
1736
1737
// Go through all virtual registers to establish the mapping between the
1738
// global virtual
1739
// register number and the per class virtual register number.
1740
// We use the per class virtual register number in the ptx output.
1741
unsigned int numVRs = MRI->getNumVirtRegs();
1742
for (unsigned i = 0; i < numVRs; i++) {
1743
Register vr = Register::index2VirtReg(i);
1744
const TargetRegisterClass *RC = MRI->getRegClass(vr);
1745
DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1746
int n = regmap.size();
1747
regmap.insert(std::make_pair(vr, n + 1));
1748
}
1749
1750
// Emit register declarations
1751
// @TODO: Extract out the real register usage
1752
// O << "\t.reg .pred %p<" << NVPTXNumRegisters << ">;\n";
1753
// O << "\t.reg .s16 %rc<" << NVPTXNumRegisters << ">;\n";
1754
// O << "\t.reg .s16 %rs<" << NVPTXNumRegisters << ">;\n";
1755
// O << "\t.reg .s32 %r<" << NVPTXNumRegisters << ">;\n";
1756
// O << "\t.reg .s64 %rd<" << NVPTXNumRegisters << ">;\n";
1757
// O << "\t.reg .f32 %f<" << NVPTXNumRegisters << ">;\n";
1758
// O << "\t.reg .f64 %fd<" << NVPTXNumRegisters << ">;\n";
1759
1760
// Emit declaration of the virtual registers or 'physical' registers for
1761
// each register class
1762
for (unsigned i=0; i< TRI->getNumRegClasses(); i++) {
1763
const TargetRegisterClass *RC = TRI->getRegClass(i);
1764
DenseMap<unsigned, unsigned> &regmap = VRegMapping[RC];
1765
std::string rcname = getNVPTXRegClassName(RC);
1766
std::string rcStr = getNVPTXRegClassStr(RC);
1767
int n = regmap.size();
1768
1769
// Only declare those registers that may be used.
1770
if (n) {
1771
O << "\t.reg " << rcname << " \t" << rcStr << "<" << (n+1)
1772
<< ">;\n";
1773
}
1774
}
1775
1776
OutStreamer->emitRawText(O.str());
1777
}
1778
1779
void NVPTXAsmPrinter::printFPConstant(const ConstantFP *Fp, raw_ostream &O) {
1780
APFloat APF = APFloat(Fp->getValueAPF()); // make a copy
1781
bool ignored;
1782
unsigned int numHex;
1783
const char *lead;
1784
1785
if (Fp->getType()->getTypeID() == Type::FloatTyID) {
1786
numHex = 8;
1787
lead = "0f";
1788
APF.convert(APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &ignored);
1789
} else if (Fp->getType()->getTypeID() == Type::DoubleTyID) {
1790
numHex = 16;
1791
lead = "0d";
1792
APF.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, &ignored);
1793
} else
1794
llvm_unreachable("unsupported fp type");
1795
1796
APInt API = APF.bitcastToAPInt();
1797
O << lead << format_hex_no_prefix(API.getZExtValue(), numHex, /*Upper=*/true);
1798
}
1799
1800
void NVPTXAsmPrinter::printScalarConstant(const Constant *CPV, raw_ostream &O) {
1801
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1802
O << CI->getValue();
1803
return;
1804
}
1805
if (const ConstantFP *CFP = dyn_cast<ConstantFP>(CPV)) {
1806
printFPConstant(CFP, O);
1807
return;
1808
}
1809
if (isa<ConstantPointerNull>(CPV)) {
1810
O << "0";
1811
return;
1812
}
1813
if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1814
bool IsNonGenericPointer = false;
1815
if (GVar->getType()->getAddressSpace() != 0) {
1816
IsNonGenericPointer = true;
1817
}
1818
if (EmitGeneric && !isa<Function>(CPV) && !IsNonGenericPointer) {
1819
O << "generic(";
1820
getSymbol(GVar)->print(O, MAI);
1821
O << ")";
1822
} else {
1823
getSymbol(GVar)->print(O, MAI);
1824
}
1825
return;
1826
}
1827
if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1828
const MCExpr *E = lowerConstantForGV(cast<Constant>(Cexpr), false);
1829
printMCExpr(*E, O);
1830
return;
1831
}
1832
llvm_unreachable("Not scalar type found in printScalarConstant()");
1833
}
1834
1835
void NVPTXAsmPrinter::bufferLEByte(const Constant *CPV, int Bytes,
1836
AggBuffer *AggBuffer) {
1837
const DataLayout &DL = getDataLayout();
1838
int AllocSize = DL.getTypeAllocSize(CPV->getType());
1839
if (isa<UndefValue>(CPV) || CPV->isNullValue()) {
1840
// Non-zero Bytes indicates that we need to zero-fill everything. Otherwise,
1841
// only the space allocated by CPV.
1842
AggBuffer->addZeros(Bytes ? Bytes : AllocSize);
1843
return;
1844
}
1845
1846
// Helper for filling AggBuffer with APInts.
1847
auto AddIntToBuffer = [AggBuffer, Bytes](const APInt &Val) {
1848
size_t NumBytes = (Val.getBitWidth() + 7) / 8;
1849
SmallVector<unsigned char, 16> Buf(NumBytes);
1850
// `extractBitsAsZExtValue` does not allow the extraction of bits beyond the
1851
// input's bit width, and i1 arrays may not have a length that is a multuple
1852
// of 8. We handle the last byte separately, so we never request out of
1853
// bounds bits.
1854
for (unsigned I = 0; I < NumBytes - 1; ++I) {
1855
Buf[I] = Val.extractBitsAsZExtValue(8, I * 8);
1856
}
1857
size_t LastBytePosition = (NumBytes - 1) * 8;
1858
size_t LastByteBits = Val.getBitWidth() - LastBytePosition;
1859
Buf[NumBytes - 1] =
1860
Val.extractBitsAsZExtValue(LastByteBits, LastBytePosition);
1861
AggBuffer->addBytes(Buf.data(), NumBytes, Bytes);
1862
};
1863
1864
switch (CPV->getType()->getTypeID()) {
1865
case Type::IntegerTyID:
1866
if (const auto CI = dyn_cast<ConstantInt>(CPV)) {
1867
AddIntToBuffer(CI->getValue());
1868
break;
1869
}
1870
if (const auto *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1871
if (const auto *CI =
1872
dyn_cast<ConstantInt>(ConstantFoldConstant(Cexpr, DL))) {
1873
AddIntToBuffer(CI->getValue());
1874
break;
1875
}
1876
if (Cexpr->getOpcode() == Instruction::PtrToInt) {
1877
Value *V = Cexpr->getOperand(0)->stripPointerCasts();
1878
AggBuffer->addSymbol(V, Cexpr->getOperand(0));
1879
AggBuffer->addZeros(AllocSize);
1880
break;
1881
}
1882
}
1883
llvm_unreachable("unsupported integer const type");
1884
break;
1885
1886
case Type::HalfTyID:
1887
case Type::BFloatTyID:
1888
case Type::FloatTyID:
1889
case Type::DoubleTyID:
1890
AddIntToBuffer(cast<ConstantFP>(CPV)->getValueAPF().bitcastToAPInt());
1891
break;
1892
1893
case Type::PointerTyID: {
1894
if (const GlobalValue *GVar = dyn_cast<GlobalValue>(CPV)) {
1895
AggBuffer->addSymbol(GVar, GVar);
1896
} else if (const ConstantExpr *Cexpr = dyn_cast<ConstantExpr>(CPV)) {
1897
const Value *v = Cexpr->stripPointerCasts();
1898
AggBuffer->addSymbol(v, Cexpr);
1899
}
1900
AggBuffer->addZeros(AllocSize);
1901
break;
1902
}
1903
1904
case Type::ArrayTyID:
1905
case Type::FixedVectorTyID:
1906
case Type::StructTyID: {
1907
if (isa<ConstantAggregate>(CPV) || isa<ConstantDataSequential>(CPV)) {
1908
bufferAggregateConstant(CPV, AggBuffer);
1909
if (Bytes > AllocSize)
1910
AggBuffer->addZeros(Bytes - AllocSize);
1911
} else if (isa<ConstantAggregateZero>(CPV))
1912
AggBuffer->addZeros(Bytes);
1913
else
1914
llvm_unreachable("Unexpected Constant type");
1915
break;
1916
}
1917
1918
default:
1919
llvm_unreachable("unsupported type");
1920
}
1921
}
1922
1923
void NVPTXAsmPrinter::bufferAggregateConstant(const Constant *CPV,
1924
AggBuffer *aggBuffer) {
1925
const DataLayout &DL = getDataLayout();
1926
int Bytes;
1927
1928
// Integers of arbitrary width
1929
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CPV)) {
1930
APInt Val = CI->getValue();
1931
for (unsigned I = 0, E = DL.getTypeAllocSize(CPV->getType()); I < E; ++I) {
1932
uint8_t Byte = Val.getLoBits(8).getZExtValue();
1933
aggBuffer->addBytes(&Byte, 1, 1);
1934
Val.lshrInPlace(8);
1935
}
1936
return;
1937
}
1938
1939
// Old constants
1940
if (isa<ConstantArray>(CPV) || isa<ConstantVector>(CPV)) {
1941
if (CPV->getNumOperands())
1942
for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i)
1943
bufferLEByte(cast<Constant>(CPV->getOperand(i)), 0, aggBuffer);
1944
return;
1945
}
1946
1947
if (const ConstantDataSequential *CDS =
1948
dyn_cast<ConstantDataSequential>(CPV)) {
1949
if (CDS->getNumElements())
1950
for (unsigned i = 0; i < CDS->getNumElements(); ++i)
1951
bufferLEByte(cast<Constant>(CDS->getElementAsConstant(i)), 0,
1952
aggBuffer);
1953
return;
1954
}
1955
1956
if (isa<ConstantStruct>(CPV)) {
1957
if (CPV->getNumOperands()) {
1958
StructType *ST = cast<StructType>(CPV->getType());
1959
for (unsigned i = 0, e = CPV->getNumOperands(); i != e; ++i) {
1960
if (i == (e - 1))
1961
Bytes = DL.getStructLayout(ST)->getElementOffset(0) +
1962
DL.getTypeAllocSize(ST) -
1963
DL.getStructLayout(ST)->getElementOffset(i);
1964
else
1965
Bytes = DL.getStructLayout(ST)->getElementOffset(i + 1) -
1966
DL.getStructLayout(ST)->getElementOffset(i);
1967
bufferLEByte(cast<Constant>(CPV->getOperand(i)), Bytes, aggBuffer);
1968
}
1969
}
1970
return;
1971
}
1972
llvm_unreachable("unsupported constant type in printAggregateConstant()");
1973
}
1974
1975
/// lowerConstantForGV - Return an MCExpr for the given Constant. This is mostly
1976
/// a copy from AsmPrinter::lowerConstant, except customized to only handle
1977
/// expressions that are representable in PTX and create
1978
/// NVPTXGenericMCSymbolRefExpr nodes for addrspacecast instructions.
1979
const MCExpr *
1980
NVPTXAsmPrinter::lowerConstantForGV(const Constant *CV, bool ProcessingGeneric) {
1981
MCContext &Ctx = OutContext;
1982
1983
if (CV->isNullValue() || isa<UndefValue>(CV))
1984
return MCConstantExpr::create(0, Ctx);
1985
1986
if (const ConstantInt *CI = dyn_cast<ConstantInt>(CV))
1987
return MCConstantExpr::create(CI->getZExtValue(), Ctx);
1988
1989
if (const GlobalValue *GV = dyn_cast<GlobalValue>(CV)) {
1990
const MCSymbolRefExpr *Expr =
1991
MCSymbolRefExpr::create(getSymbol(GV), Ctx);
1992
if (ProcessingGeneric) {
1993
return NVPTXGenericMCSymbolRefExpr::create(Expr, Ctx);
1994
} else {
1995
return Expr;
1996
}
1997
}
1998
1999
const ConstantExpr *CE = dyn_cast<ConstantExpr>(CV);
2000
if (!CE) {
2001
llvm_unreachable("Unknown constant value to lower!");
2002
}
2003
2004
switch (CE->getOpcode()) {
2005
default:
2006
break; // Error
2007
2008
case Instruction::AddrSpaceCast: {
2009
// Strip the addrspacecast and pass along the operand
2010
PointerType *DstTy = cast<PointerType>(CE->getType());
2011
if (DstTy->getAddressSpace() == 0)
2012
return lowerConstantForGV(cast<const Constant>(CE->getOperand(0)), true);
2013
2014
break; // Error
2015
}
2016
2017
case Instruction::GetElementPtr: {
2018
const DataLayout &DL = getDataLayout();
2019
2020
// Generate a symbolic expression for the byte address
2021
APInt OffsetAI(DL.getPointerTypeSizeInBits(CE->getType()), 0);
2022
cast<GEPOperator>(CE)->accumulateConstantOffset(DL, OffsetAI);
2023
2024
const MCExpr *Base = lowerConstantForGV(CE->getOperand(0),
2025
ProcessingGeneric);
2026
if (!OffsetAI)
2027
return Base;
2028
2029
int64_t Offset = OffsetAI.getSExtValue();
2030
return MCBinaryExpr::createAdd(Base, MCConstantExpr::create(Offset, Ctx),
2031
Ctx);
2032
}
2033
2034
case Instruction::Trunc:
2035
// We emit the value and depend on the assembler to truncate the generated
2036
// expression properly. This is important for differences between
2037
// blockaddress labels. Since the two labels are in the same function, it
2038
// is reasonable to treat their delta as a 32-bit value.
2039
[[fallthrough]];
2040
case Instruction::BitCast:
2041
return lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2042
2043
case Instruction::IntToPtr: {
2044
const DataLayout &DL = getDataLayout();
2045
2046
// Handle casts to pointers by changing them into casts to the appropriate
2047
// integer type. This promotes constant folding and simplifies this code.
2048
Constant *Op = CE->getOperand(0);
2049
Op = ConstantFoldIntegerCast(Op, DL.getIntPtrType(CV->getType()),
2050
/*IsSigned*/ false, DL);
2051
if (Op)
2052
return lowerConstantForGV(Op, ProcessingGeneric);
2053
2054
break; // Error
2055
}
2056
2057
case Instruction::PtrToInt: {
2058
const DataLayout &DL = getDataLayout();
2059
2060
// Support only foldable casts to/from pointers that can be eliminated by
2061
// changing the pointer to the appropriately sized integer type.
2062
Constant *Op = CE->getOperand(0);
2063
Type *Ty = CE->getType();
2064
2065
const MCExpr *OpExpr = lowerConstantForGV(Op, ProcessingGeneric);
2066
2067
// We can emit the pointer value into this slot if the slot is an
2068
// integer slot equal to the size of the pointer.
2069
if (DL.getTypeAllocSize(Ty) == DL.getTypeAllocSize(Op->getType()))
2070
return OpExpr;
2071
2072
// Otherwise the pointer is smaller than the resultant integer, mask off
2073
// the high bits so we are sure to get a proper truncation if the input is
2074
// a constant expr.
2075
unsigned InBits = DL.getTypeAllocSizeInBits(Op->getType());
2076
const MCExpr *MaskExpr = MCConstantExpr::create(~0ULL >> (64-InBits), Ctx);
2077
return MCBinaryExpr::createAnd(OpExpr, MaskExpr, Ctx);
2078
}
2079
2080
// The MC library also has a right-shift operator, but it isn't consistently
2081
// signed or unsigned between different targets.
2082
case Instruction::Add: {
2083
const MCExpr *LHS = lowerConstantForGV(CE->getOperand(0), ProcessingGeneric);
2084
const MCExpr *RHS = lowerConstantForGV(CE->getOperand(1), ProcessingGeneric);
2085
switch (CE->getOpcode()) {
2086
default: llvm_unreachable("Unknown binary operator constant cast expr");
2087
case Instruction::Add: return MCBinaryExpr::createAdd(LHS, RHS, Ctx);
2088
}
2089
}
2090
}
2091
2092
// If the code isn't optimized, there may be outstanding folding
2093
// opportunities. Attempt to fold the expression using DataLayout as a
2094
// last resort before giving up.
2095
Constant *C = ConstantFoldConstant(CE, getDataLayout());
2096
if (C != CE)
2097
return lowerConstantForGV(C, ProcessingGeneric);
2098
2099
// Otherwise report the problem to the user.
2100
std::string S;
2101
raw_string_ostream OS(S);
2102
OS << "Unsupported expression in static initializer: ";
2103
CE->printAsOperand(OS, /*PrintType=*/false,
2104
!MF ? nullptr : MF->getFunction().getParent());
2105
report_fatal_error(Twine(OS.str()));
2106
}
2107
2108
// Copy of MCExpr::print customized for NVPTX
2109
void NVPTXAsmPrinter::printMCExpr(const MCExpr &Expr, raw_ostream &OS) {
2110
switch (Expr.getKind()) {
2111
case MCExpr::Target:
2112
return cast<MCTargetExpr>(&Expr)->printImpl(OS, MAI);
2113
case MCExpr::Constant:
2114
OS << cast<MCConstantExpr>(Expr).getValue();
2115
return;
2116
2117
case MCExpr::SymbolRef: {
2118
const MCSymbolRefExpr &SRE = cast<MCSymbolRefExpr>(Expr);
2119
const MCSymbol &Sym = SRE.getSymbol();
2120
Sym.print(OS, MAI);
2121
return;
2122
}
2123
2124
case MCExpr::Unary: {
2125
const MCUnaryExpr &UE = cast<MCUnaryExpr>(Expr);
2126
switch (UE.getOpcode()) {
2127
case MCUnaryExpr::LNot: OS << '!'; break;
2128
case MCUnaryExpr::Minus: OS << '-'; break;
2129
case MCUnaryExpr::Not: OS << '~'; break;
2130
case MCUnaryExpr::Plus: OS << '+'; break;
2131
}
2132
printMCExpr(*UE.getSubExpr(), OS);
2133
return;
2134
}
2135
2136
case MCExpr::Binary: {
2137
const MCBinaryExpr &BE = cast<MCBinaryExpr>(Expr);
2138
2139
// Only print parens around the LHS if it is non-trivial.
2140
if (isa<MCConstantExpr>(BE.getLHS()) || isa<MCSymbolRefExpr>(BE.getLHS()) ||
2141
isa<NVPTXGenericMCSymbolRefExpr>(BE.getLHS())) {
2142
printMCExpr(*BE.getLHS(), OS);
2143
} else {
2144
OS << '(';
2145
printMCExpr(*BE.getLHS(), OS);
2146
OS<< ')';
2147
}
2148
2149
switch (BE.getOpcode()) {
2150
case MCBinaryExpr::Add:
2151
// Print "X-42" instead of "X+-42".
2152
if (const MCConstantExpr *RHSC = dyn_cast<MCConstantExpr>(BE.getRHS())) {
2153
if (RHSC->getValue() < 0) {
2154
OS << RHSC->getValue();
2155
return;
2156
}
2157
}
2158
2159
OS << '+';
2160
break;
2161
default: llvm_unreachable("Unhandled binary operator");
2162
}
2163
2164
// Only print parens around the LHS if it is non-trivial.
2165
if (isa<MCConstantExpr>(BE.getRHS()) || isa<MCSymbolRefExpr>(BE.getRHS())) {
2166
printMCExpr(*BE.getRHS(), OS);
2167
} else {
2168
OS << '(';
2169
printMCExpr(*BE.getRHS(), OS);
2170
OS << ')';
2171
}
2172
return;
2173
}
2174
}
2175
2176
llvm_unreachable("Invalid expression kind!");
2177
}
2178
2179
/// PrintAsmOperand - Print out an operand for an inline asm expression.
2180
///
2181
bool NVPTXAsmPrinter::PrintAsmOperand(const MachineInstr *MI, unsigned OpNo,
2182
const char *ExtraCode, raw_ostream &O) {
2183
if (ExtraCode && ExtraCode[0]) {
2184
if (ExtraCode[1] != 0)
2185
return true; // Unknown modifier.
2186
2187
switch (ExtraCode[0]) {
2188
default:
2189
// See if this is a generic print operand
2190
return AsmPrinter::PrintAsmOperand(MI, OpNo, ExtraCode, O);
2191
case 'r':
2192
break;
2193
}
2194
}
2195
2196
printOperand(MI, OpNo, O);
2197
2198
return false;
2199
}
2200
2201
bool NVPTXAsmPrinter::PrintAsmMemoryOperand(const MachineInstr *MI,
2202
unsigned OpNo,
2203
const char *ExtraCode,
2204
raw_ostream &O) {
2205
if (ExtraCode && ExtraCode[0])
2206
return true; // Unknown modifier
2207
2208
O << '[';
2209
printMemOperand(MI, OpNo, O);
2210
O << ']';
2211
2212
return false;
2213
}
2214
2215
void NVPTXAsmPrinter::printOperand(const MachineInstr *MI, unsigned OpNum,
2216
raw_ostream &O) {
2217
const MachineOperand &MO = MI->getOperand(OpNum);
2218
switch (MO.getType()) {
2219
case MachineOperand::MO_Register:
2220
if (MO.getReg().isPhysical()) {
2221
if (MO.getReg() == NVPTX::VRDepot)
2222
O << DEPOTNAME << getFunctionNumber();
2223
else
2224
O << NVPTXInstPrinter::getRegisterName(MO.getReg());
2225
} else {
2226
emitVirtualRegister(MO.getReg(), O);
2227
}
2228
break;
2229
2230
case MachineOperand::MO_Immediate:
2231
O << MO.getImm();
2232
break;
2233
2234
case MachineOperand::MO_FPImmediate:
2235
printFPConstant(MO.getFPImm(), O);
2236
break;
2237
2238
case MachineOperand::MO_GlobalAddress:
2239
PrintSymbolOperand(MO, O);
2240
break;
2241
2242
case MachineOperand::MO_MachineBasicBlock:
2243
MO.getMBB()->getSymbol()->print(O, MAI);
2244
break;
2245
2246
default:
2247
llvm_unreachable("Operand type not supported.");
2248
}
2249
}
2250
2251
void NVPTXAsmPrinter::printMemOperand(const MachineInstr *MI, unsigned OpNum,
2252
raw_ostream &O, const char *Modifier) {
2253
printOperand(MI, OpNum, O);
2254
2255
if (Modifier && strcmp(Modifier, "add") == 0) {
2256
O << ", ";
2257
printOperand(MI, OpNum + 1, O);
2258
} else {
2259
if (MI->getOperand(OpNum + 1).isImm() &&
2260
MI->getOperand(OpNum + 1).getImm() == 0)
2261
return; // don't print ',0' or '+0'
2262
O << "+";
2263
printOperand(MI, OpNum + 1, O);
2264
}
2265
}
2266
2267
// Force static initialization.
2268
extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeNVPTXAsmPrinter() {
2269
RegisterAsmPrinter<NVPTXAsmPrinter> X(getTheNVPTXTarget32());
2270
RegisterAsmPrinter<NVPTXAsmPrinter> Y(getTheNVPTXTarget64());
2271
}
2272
2273