Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp
35271 views
1
//===-- NVPTXLowerArgs.cpp - Lower arguments ------------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
//
10
// Arguments to kernel and device functions are passed via param space,
11
// which imposes certain restrictions:
12
// http://docs.nvidia.com/cuda/parallel-thread-execution/#state-spaces
13
//
14
// Kernel parameters are read-only and accessible only via ld.param
15
// instruction, directly or via a pointer.
16
//
17
// Device function parameters are directly accessible via
18
// ld.param/st.param, but taking the address of one returns a pointer
19
// to a copy created in local space which *can't* be used with
20
// ld.param/st.param.
21
//
22
// Copying a byval struct into local memory in IR allows us to enforce
23
// the param space restrictions, gives the rest of IR a pointer w/o
24
// param space restrictions, and gives us an opportunity to eliminate
25
// the copy.
26
//
27
// Pointer arguments to kernel functions need more work to be lowered:
28
//
29
// 1. Convert non-byval pointer arguments of CUDA kernels to pointers in the
30
// global address space. This allows later optimizations to emit
31
// ld.global.*/st.global.* for accessing these pointer arguments. For
32
// example,
33
//
34
// define void @foo(float* %input) {
35
// %v = load float, float* %input, align 4
36
// ...
37
// }
38
//
39
// becomes
40
//
41
// define void @foo(float* %input) {
42
// %input2 = addrspacecast float* %input to float addrspace(1)*
43
// %input3 = addrspacecast float addrspace(1)* %input2 to float*
44
// %v = load float, float* %input3, align 4
45
// ...
46
// }
47
//
48
// Later, NVPTXInferAddressSpaces will optimize it to
49
//
50
// define void @foo(float* %input) {
51
// %input2 = addrspacecast float* %input to float addrspace(1)*
52
// %v = load float, float addrspace(1)* %input2, align 4
53
// ...
54
// }
55
//
56
// 2. Convert byval kernel parameters to pointers in the param address space
57
// (so that NVPTX emits ld/st.param). Convert pointers *within* a byval
58
// kernel parameter to pointers in the global address space. This allows
59
// NVPTX to emit ld/st.global.
60
//
61
// struct S {
62
// int *x;
63
// int *y;
64
// };
65
// __global__ void foo(S s) {
66
// int *b = s.y;
67
// // use b
68
// }
69
//
70
// "b" points to the global address space. In the IR level,
71
//
72
// define void @foo(ptr byval %input) {
73
// %b_ptr = getelementptr {ptr, ptr}, ptr %input, i64 0, i32 1
74
// %b = load ptr, ptr %b_ptr
75
// ; use %b
76
// }
77
//
78
// becomes
79
//
80
// define void @foo(ptr byval %input) {
81
// %b_param = addrspacecast ptr %input to ptr addrspace(101)
82
// %b_ptr = getelementptr {ptr, ptr}, ptr addrspace(101) %b_param, i64 0, i32 1
83
// %b = load ptr, ptr addrspace(101) %b_ptr
84
// %b_global = addrspacecast ptr %b to ptr addrspace(1)
85
// ; use %b_global
86
// }
87
//
88
// Create a local copy of kernel byval parameters used in a way that *might* mutate
89
// the parameter, by storing it in an alloca. Mutations to "grid_constant" parameters
90
// are undefined behaviour, and don't require local copies.
91
//
92
// define void @foo(ptr byval(%struct.s) align 4 %input) {
93
// store i32 42, ptr %input
94
// ret void
95
// }
96
//
97
// becomes
98
//
99
// define void @foo(ptr byval(%struct.s) align 4 %input) #1 {
100
// %input1 = alloca %struct.s, align 4
101
// %input2 = addrspacecast ptr %input to ptr addrspace(101)
102
// %input3 = load %struct.s, ptr addrspace(101) %input2, align 4
103
// store %struct.s %input3, ptr %input1, align 4
104
// store i32 42, ptr %input1, align 4
105
// ret void
106
// }
107
//
108
// If %input were passed to a device function, or written to memory,
109
// conservatively assume that %input gets mutated, and create a local copy.
110
//
111
// Convert param pointers to grid_constant byval kernel parameters that are
112
// passed into calls (device functions, intrinsics, inline asm), or otherwise
113
// "escape" (into stores/ptrtoints) to the generic address space, using the
114
// `nvvm.ptr.param.to.gen` intrinsic, so that NVPTX emits cvta.param
115
// (available for sm70+)
116
//
117
// define void @foo(ptr byval(%struct.s) %input) {
118
// ; %input is a grid_constant
119
// %call = call i32 @escape(ptr %input)
120
// ret void
121
// }
122
//
123
// becomes
124
//
125
// define void @foo(ptr byval(%struct.s) %input) {
126
// %input1 = addrspacecast ptr %input to ptr addrspace(101)
127
// ; the following intrinsic converts pointer to generic. We don't use an addrspacecast
128
// ; to prevent generic -> param -> generic from getting cancelled out
129
// %input1.gen = call ptr @llvm.nvvm.ptr.param.to.gen.p0.p101(ptr addrspace(101) %input1)
130
// %call = call i32 @escape(ptr %input1.gen)
131
// ret void
132
// }
133
//
134
// TODO: merge this pass with NVPTXInferAddressSpaces so that other passes don't
135
// cancel the addrspacecast pair this pass emits.
136
//===----------------------------------------------------------------------===//
137
138
#include "MCTargetDesc/NVPTXBaseInfo.h"
139
#include "NVPTX.h"
140
#include "NVPTXTargetMachine.h"
141
#include "NVPTXUtilities.h"
142
#include "llvm/Analysis/ValueTracking.h"
143
#include "llvm/CodeGen/TargetPassConfig.h"
144
#include "llvm/IR/Function.h"
145
#include "llvm/IR/IRBuilder.h"
146
#include "llvm/IR/Instructions.h"
147
#include "llvm/IR/IntrinsicsNVPTX.h"
148
#include "llvm/IR/Module.h"
149
#include "llvm/IR/Type.h"
150
#include "llvm/InitializePasses.h"
151
#include "llvm/Pass.h"
152
#include <numeric>
153
#include <queue>
154
155
#define DEBUG_TYPE "nvptx-lower-args"
156
157
using namespace llvm;
158
159
namespace llvm {
160
void initializeNVPTXLowerArgsPass(PassRegistry &);
161
}
162
163
namespace {
164
class NVPTXLowerArgs : public FunctionPass {
165
bool runOnFunction(Function &F) override;
166
167
bool runOnKernelFunction(const NVPTXTargetMachine &TM, Function &F);
168
bool runOnDeviceFunction(const NVPTXTargetMachine &TM, Function &F);
169
170
// handle byval parameters
171
void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg);
172
// Knowing Ptr must point to the global address space, this function
173
// addrspacecasts Ptr to global and then back to generic. This allows
174
// NVPTXInferAddressSpaces to fold the global-to-generic cast into
175
// loads/stores that appear later.
176
void markPointerAsGlobal(Value *Ptr);
177
178
public:
179
static char ID; // Pass identification, replacement for typeid
180
NVPTXLowerArgs() : FunctionPass(ID) {}
181
StringRef getPassName() const override {
182
return "Lower pointer arguments of CUDA kernels";
183
}
184
void getAnalysisUsage(AnalysisUsage &AU) const override {
185
AU.addRequired<TargetPassConfig>();
186
}
187
};
188
} // namespace
189
190
char NVPTXLowerArgs::ID = 1;
191
192
INITIALIZE_PASS_BEGIN(NVPTXLowerArgs, "nvptx-lower-args",
193
"Lower arguments (NVPTX)", false, false)
194
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
195
INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
196
"Lower arguments (NVPTX)", false, false)
197
198
// =============================================================================
199
// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
200
// and we can't guarantee that the only accesses are loads,
201
// then add the following instructions to the first basic block:
202
//
203
// %temp = alloca %struct.x, align 8
204
// %tempd = addrspacecast %struct.x* %d to %struct.x addrspace(101)*
205
// %tv = load %struct.x addrspace(101)* %tempd
206
// store %struct.x %tv, %struct.x* %temp, align 8
207
//
208
// The above code allocates some space in the stack and copies the incoming
209
// struct from param space to local space.
210
// Then replace all occurrences of %d by %temp.
211
//
212
// In case we know that all users are GEPs or Loads, replace them with the same
213
// ones in parameter AS, so we can access them using ld.param.
214
// =============================================================================
215
216
// For Loads, replaces the \p OldUse of the pointer with a Use of the same
217
// pointer in parameter AS.
218
// For "escapes" (to memory, a function call, or a ptrtoint), cast the OldUse to
219
// generic using cvta.param.
220
static void convertToParamAS(Use *OldUse, Value *Param, bool GridConstant) {
221
Instruction *I = dyn_cast<Instruction>(OldUse->getUser());
222
assert(I && "OldUse must be in an instruction");
223
struct IP {
224
Use *OldUse;
225
Instruction *OldInstruction;
226
Value *NewParam;
227
};
228
SmallVector<IP> ItemsToConvert = {{OldUse, I, Param}};
229
SmallVector<Instruction *> InstructionsToDelete;
230
231
auto CloneInstInParamAS = [GridConstant](const IP &I) -> Value * {
232
if (auto *LI = dyn_cast<LoadInst>(I.OldInstruction)) {
233
LI->setOperand(0, I.NewParam);
234
return LI;
235
}
236
if (auto *GEP = dyn_cast<GetElementPtrInst>(I.OldInstruction)) {
237
SmallVector<Value *, 4> Indices(GEP->indices());
238
auto *NewGEP = GetElementPtrInst::Create(
239
GEP->getSourceElementType(), I.NewParam, Indices, GEP->getName(),
240
GEP->getIterator());
241
NewGEP->setIsInBounds(GEP->isInBounds());
242
return NewGEP;
243
}
244
if (auto *BC = dyn_cast<BitCastInst>(I.OldInstruction)) {
245
auto *NewBCType = PointerType::get(BC->getContext(), ADDRESS_SPACE_PARAM);
246
return BitCastInst::Create(BC->getOpcode(), I.NewParam, NewBCType,
247
BC->getName(), BC->getIterator());
248
}
249
if (auto *ASC = dyn_cast<AddrSpaceCastInst>(I.OldInstruction)) {
250
assert(ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM);
251
(void)ASC;
252
// Just pass through the argument, the old ASC is no longer needed.
253
return I.NewParam;
254
}
255
256
if (GridConstant) {
257
auto GetParamAddrCastToGeneric =
258
[](Value *Addr, Instruction *OriginalUser) -> Value * {
259
PointerType *ReturnTy =
260
PointerType::get(OriginalUser->getContext(), ADDRESS_SPACE_GENERIC);
261
Function *CvtToGen = Intrinsic::getDeclaration(
262
OriginalUser->getModule(), Intrinsic::nvvm_ptr_param_to_gen,
263
{ReturnTy, PointerType::get(OriginalUser->getContext(),
264
ADDRESS_SPACE_PARAM)});
265
266
// Cast param address to generic address space
267
Value *CvtToGenCall =
268
CallInst::Create(CvtToGen, Addr, Addr->getName() + ".gen",
269
OriginalUser->getIterator());
270
return CvtToGenCall;
271
};
272
273
if (auto *CI = dyn_cast<CallInst>(I.OldInstruction)) {
274
I.OldUse->set(GetParamAddrCastToGeneric(I.NewParam, CI));
275
return CI;
276
}
277
if (auto *SI = dyn_cast<StoreInst>(I.OldInstruction)) {
278
// byval address is being stored, cast it to generic
279
if (SI->getValueOperand() == I.OldUse->get())
280
SI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, SI));
281
return SI;
282
}
283
if (auto *PI = dyn_cast<PtrToIntInst>(I.OldInstruction)) {
284
if (PI->getPointerOperand() == I.OldUse->get())
285
PI->setOperand(0, GetParamAddrCastToGeneric(I.NewParam, PI));
286
return PI;
287
}
288
llvm_unreachable(
289
"Instruction unsupported even for grid_constant argument");
290
}
291
292
llvm_unreachable("Unsupported instruction");
293
};
294
295
while (!ItemsToConvert.empty()) {
296
IP I = ItemsToConvert.pop_back_val();
297
Value *NewInst = CloneInstInParamAS(I);
298
299
if (NewInst && NewInst != I.OldInstruction) {
300
// We've created a new instruction. Queue users of the old instruction to
301
// be converted and the instruction itself to be deleted. We can't delete
302
// the old instruction yet, because it's still in use by a load somewhere.
303
for (Use &U : I.OldInstruction->uses())
304
ItemsToConvert.push_back({&U, cast<Instruction>(U.getUser()), NewInst});
305
306
InstructionsToDelete.push_back(I.OldInstruction);
307
}
308
}
309
310
// Now we know that all argument loads are using addresses in parameter space
311
// and we can finally remove the old instructions in generic AS. Instructions
312
// scheduled for removal should be processed in reverse order so the ones
313
// closest to the load are deleted first. Otherwise they may still be in use.
314
// E.g if we have Value = Load(BitCast(GEP(arg))), InstructionsToDelete will
315
// have {GEP,BitCast}. GEP can't be deleted first, because it's still used by
316
// the BitCast.
317
for (Instruction *I : llvm::reverse(InstructionsToDelete))
318
I->eraseFromParent();
319
}
320
321
// Adjust alignment of arguments passed byval in .param address space. We can
322
// increase alignment of such arguments in a way that ensures that we can
323
// effectively vectorize their loads. We should also traverse all loads from
324
// byval pointer and adjust their alignment, if those were using known offset.
325
// Such alignment changes must be conformed with parameter store and load in
326
// NVPTXTargetLowering::LowerCall.
327
static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
328
const NVPTXTargetLowering *TLI) {
329
Function *Func = Arg->getParent();
330
Type *StructType = Arg->getParamByValType();
331
const DataLayout DL(Func->getParent());
332
333
uint64_t NewArgAlign =
334
TLI->getFunctionParamOptimizedAlign(Func, StructType, DL).value();
335
uint64_t CurArgAlign =
336
Arg->getAttribute(Attribute::Alignment).getValueAsInt();
337
338
if (CurArgAlign >= NewArgAlign)
339
return;
340
341
LLVM_DEBUG(dbgs() << "Try to use alignment " << NewArgAlign << " instead of "
342
<< CurArgAlign << " for " << *Arg << '\n');
343
344
auto NewAlignAttr =
345
Attribute::get(Func->getContext(), Attribute::Alignment, NewArgAlign);
346
Arg->removeAttr(Attribute::Alignment);
347
Arg->addAttr(NewAlignAttr);
348
349
struct Load {
350
LoadInst *Inst;
351
uint64_t Offset;
352
};
353
354
struct LoadContext {
355
Value *InitialVal;
356
uint64_t Offset;
357
};
358
359
SmallVector<Load> Loads;
360
std::queue<LoadContext> Worklist;
361
Worklist.push({ArgInParamAS, 0});
362
bool IsGridConstant = isParamGridConstant(*Arg);
363
364
while (!Worklist.empty()) {
365
LoadContext Ctx = Worklist.front();
366
Worklist.pop();
367
368
for (User *CurUser : Ctx.InitialVal->users()) {
369
if (auto *I = dyn_cast<LoadInst>(CurUser)) {
370
Loads.push_back({I, Ctx.Offset});
371
continue;
372
}
373
374
if (auto *I = dyn_cast<BitCastInst>(CurUser)) {
375
Worklist.push({I, Ctx.Offset});
376
continue;
377
}
378
379
if (auto *I = dyn_cast<GetElementPtrInst>(CurUser)) {
380
APInt OffsetAccumulated =
381
APInt::getZero(DL.getIndexSizeInBits(ADDRESS_SPACE_PARAM));
382
383
if (!I->accumulateConstantOffset(DL, OffsetAccumulated))
384
continue;
385
386
uint64_t OffsetLimit = -1;
387
uint64_t Offset = OffsetAccumulated.getLimitedValue(OffsetLimit);
388
assert(Offset != OffsetLimit && "Expect Offset less than UINT64_MAX");
389
390
Worklist.push({I, Ctx.Offset + Offset});
391
continue;
392
}
393
394
// supported for grid_constant
395
if (IsGridConstant &&
396
(isa<CallInst>(CurUser) || isa<StoreInst>(CurUser) ||
397
isa<PtrToIntInst>(CurUser)))
398
continue;
399
400
llvm_unreachable("All users must be one of: load, "
401
"bitcast, getelementptr, call, store, ptrtoint");
402
}
403
}
404
405
for (Load &CurLoad : Loads) {
406
Align NewLoadAlign(std::gcd(NewArgAlign, CurLoad.Offset));
407
Align CurLoadAlign(CurLoad.Inst->getAlign());
408
CurLoad.Inst->setAlignment(std::max(NewLoadAlign, CurLoadAlign));
409
}
410
}
411
412
void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
413
Argument *Arg) {
414
bool IsGridConstant = isParamGridConstant(*Arg);
415
Function *Func = Arg->getParent();
416
BasicBlock::iterator FirstInst = Func->getEntryBlock().begin();
417
Type *StructType = Arg->getParamByValType();
418
assert(StructType && "Missing byval type");
419
420
auto AreSupportedUsers = [&](Value *Start) {
421
SmallVector<Value *, 16> ValuesToCheck = {Start};
422
auto IsSupportedUse = [IsGridConstant](Value *V) -> bool {
423
if (isa<GetElementPtrInst>(V) || isa<BitCastInst>(V) || isa<LoadInst>(V))
424
return true;
425
// ASC to param space are OK, too -- we'll just strip them.
426
if (auto *ASC = dyn_cast<AddrSpaceCastInst>(V)) {
427
if (ASC->getDestAddressSpace() == ADDRESS_SPACE_PARAM)
428
return true;
429
}
430
// Simple calls and stores are supported for grid_constants
431
// writes to these pointers are undefined behaviour
432
if (IsGridConstant &&
433
(isa<CallInst>(V) || isa<StoreInst>(V) || isa<PtrToIntInst>(V)))
434
return true;
435
return false;
436
};
437
438
while (!ValuesToCheck.empty()) {
439
Value *V = ValuesToCheck.pop_back_val();
440
if (!IsSupportedUse(V)) {
441
LLVM_DEBUG(dbgs() << "Need a "
442
<< (isParamGridConstant(*Arg) ? "cast " : "copy ")
443
<< "of " << *Arg << " because of " << *V << "\n");
444
(void)Arg;
445
return false;
446
}
447
if (!isa<LoadInst>(V) && !isa<CallInst>(V) && !isa<StoreInst>(V) &&
448
!isa<PtrToIntInst>(V))
449
llvm::append_range(ValuesToCheck, V->users());
450
}
451
return true;
452
};
453
454
if (llvm::all_of(Arg->users(), AreSupportedUsers)) {
455
// Convert all loads and intermediate operations to use parameter AS and
456
// skip creation of a local copy of the argument.
457
SmallVector<Use *, 16> UsesToUpdate;
458
for (Use &U : Arg->uses())
459
UsesToUpdate.push_back(&U);
460
461
Value *ArgInParamAS = new AddrSpaceCastInst(
462
Arg, PointerType::get(StructType, ADDRESS_SPACE_PARAM), Arg->getName(),
463
FirstInst);
464
for (Use *U : UsesToUpdate)
465
convertToParamAS(U, ArgInParamAS, IsGridConstant);
466
LLVM_DEBUG(dbgs() << "No need to copy or cast " << *Arg << "\n");
467
468
const auto *TLI =
469
cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
470
471
adjustByValArgAlignment(Arg, ArgInParamAS, TLI);
472
473
return;
474
}
475
476
const DataLayout &DL = Func->getDataLayout();
477
unsigned AS = DL.getAllocaAddrSpace();
478
if (isParamGridConstant(*Arg)) {
479
// Writes to a grid constant are undefined behaviour. We do not need a
480
// temporary copy. When a pointer might have escaped, conservatively replace
481
// all of its uses (which might include a device function call) with a cast
482
// to the generic address space.
483
IRBuilder<> IRB(&Func->getEntryBlock().front());
484
485
// Cast argument to param address space
486
auto *CastToParam = cast<AddrSpaceCastInst>(IRB.CreateAddrSpaceCast(
487
Arg, IRB.getPtrTy(ADDRESS_SPACE_PARAM), Arg->getName() + ".param"));
488
489
// Cast param address to generic address space. We do not use an
490
// addrspacecast to generic here, because, LLVM considers `Arg` to be in the
491
// generic address space, and a `generic -> param` cast followed by a `param
492
// -> generic` cast will be folded away. The `param -> generic` intrinsic
493
// will be correctly lowered to `cvta.param`.
494
Value *CvtToGenCall = IRB.CreateIntrinsic(
495
IRB.getPtrTy(ADDRESS_SPACE_GENERIC), Intrinsic::nvvm_ptr_param_to_gen,
496
CastToParam, nullptr, CastToParam->getName() + ".gen");
497
498
Arg->replaceAllUsesWith(CvtToGenCall);
499
500
// Do not replace Arg in the cast to param space
501
CastToParam->setOperand(0, Arg);
502
} else {
503
// Otherwise we have to create a temporary copy.
504
AllocaInst *AllocA =
505
new AllocaInst(StructType, AS, Arg->getName(), FirstInst);
506
// Set the alignment to alignment of the byval parameter. This is because,
507
// later load/stores assume that alignment, and we are going to replace
508
// the use of the byval parameter with this alloca instruction.
509
AllocA->setAlignment(Func->getParamAlign(Arg->getArgNo())
510
.value_or(DL.getPrefTypeAlign(StructType)));
511
Arg->replaceAllUsesWith(AllocA);
512
513
Value *ArgInParam = new AddrSpaceCastInst(
514
Arg, PointerType::get(Arg->getContext(), ADDRESS_SPACE_PARAM),
515
Arg->getName(), FirstInst);
516
// Be sure to propagate alignment to this load; LLVM doesn't know that NVPTX
517
// addrspacecast preserves alignment. Since params are constant, this load
518
// is definitely not volatile.
519
LoadInst *LI =
520
new LoadInst(StructType, ArgInParam, Arg->getName(),
521
/*isVolatile=*/false, AllocA->getAlign(), FirstInst);
522
new StoreInst(LI, AllocA, FirstInst);
523
}
524
}
525
526
// Casts a generic pointer known to address global memory to the global address
// space and immediately back to generic, then routes every other user of Ptr
// through the round-tripped value. Pointers that are not in the generic
// address space are left unchanged.
void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
  if (Ptr->getType()->getPointerAddressSpace() != ADDRESS_SPACE_GENERIC)
    return;

  // Pick the insertion point for the cast pair: the very start of the entry
  // block for a function argument, otherwise the instruction following Ptr.
  BasicBlock::iterator InsertPt;
  if (auto *Arg = dyn_cast<Argument>(Ptr)) {
    InsertPt = Arg->getParent()->getEntryBlock().begin();
  } else {
    InsertPt = cast<Instruction>(Ptr)->getIterator();
    ++InsertPt;
    assert(InsertPt != InsertPt->getParent()->end() &&
           "We don't call this function with Ptr being a terminator.");
  }

  Type *GenericPtrTy = Ptr->getType();
  auto *GlobalPtr = new AddrSpaceCastInst(
      Ptr, PointerType::get(Ptr->getContext(), ADDRESS_SPACE_GLOBAL),
      Ptr->getName(), InsertPt);
  Value *GenericPtr =
      new AddrSpaceCastInst(GlobalPtr, GenericPtrTy, Ptr->getName(), InsertPt);

  // RAUW also rewrites GlobalPtr's own operand, so point that operand back at
  // the original pointer afterwards.
  Ptr->replaceAllUsesWith(GenericPtr);
  GlobalPtr->setOperand(0, Ptr);
}
551
552
// =============================================================================
553
// Main function for this pass.
554
// =============================================================================
555
bool NVPTXLowerArgs::runOnKernelFunction(const NVPTXTargetMachine &TM,
556
Function &F) {
557
// Copying of byval aggregates + SROA may result in pointers being loaded as
558
// integers, followed by intotoptr. We may want to mark those as global, too,
559
// but only if the loaded integer is used exclusively for conversion to a
560
// pointer with inttoptr.
561
auto HandleIntToPtr = [this](Value &V) {
562
if (llvm::all_of(V.users(), [](User *U) { return isa<IntToPtrInst>(U); })) {
563
SmallVector<User *, 16> UsersToUpdate(V.users());
564
for (User *U : UsersToUpdate)
565
markPointerAsGlobal(U);
566
}
567
};
568
if (TM.getDrvInterface() == NVPTX::CUDA) {
569
// Mark pointers in byval structs as global.
570
for (auto &B : F) {
571
for (auto &I : B) {
572
if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
573
if (LI->getType()->isPointerTy() || LI->getType()->isIntegerTy()) {
574
Value *UO = getUnderlyingObject(LI->getPointerOperand());
575
if (Argument *Arg = dyn_cast<Argument>(UO)) {
576
if (Arg->hasByValAttr()) {
577
// LI is a load from a pointer within a byval kernel parameter.
578
if (LI->getType()->isPointerTy())
579
markPointerAsGlobal(LI);
580
else
581
HandleIntToPtr(*LI);
582
}
583
}
584
}
585
}
586
}
587
}
588
}
589
590
LLVM_DEBUG(dbgs() << "Lowering kernel args of " << F.getName() << "\n");
591
for (Argument &Arg : F.args()) {
592
if (Arg.getType()->isPointerTy()) {
593
if (Arg.hasByValAttr())
594
handleByValParam(TM, &Arg);
595
else if (TM.getDrvInterface() == NVPTX::CUDA)
596
markPointerAsGlobal(&Arg);
597
} else if (Arg.getType()->isIntegerTy() &&
598
TM.getDrvInterface() == NVPTX::CUDA) {
599
HandleIntToPtr(Arg);
600
}
601
}
602
return true;
603
}
604
605
// Device functions only need to copy byval args into local memory.
606
bool NVPTXLowerArgs::runOnDeviceFunction(const NVPTXTargetMachine &TM,
607
Function &F) {
608
LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
609
for (Argument &Arg : F.args())
610
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
611
handleByValParam(TM, &Arg);
612
return true;
613
}
614
615
// Pass entry point: dispatches to kernel or device-function lowering using the
// NVPTX target machine obtained from TargetPassConfig.
bool NVPTXLowerArgs::runOnFunction(Function &F) {
  auto &TM = getAnalysis<TargetPassConfig>().getTM<NVPTXTargetMachine>();
  if (isKernelFunction(F))
    return runOnKernelFunction(TM, F);
  return runOnDeviceFunction(TM, F);
}
621
622
// Creates a new instance of the NVPTXLowerArgs legacy function pass.
FunctionPass *llvm::createNVPTXLowerArgsPass() {
  return new NVPTXLowerArgs();
}
623
624