GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
//===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
///
/// This file implements the OpenMPIRBuilder class, which is used as a
/// convenient way to create LLVM instructions for OpenMP directives.
///
//===----------------------------------------------------------------------===//

#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Frontend/Offloading/Utility.h"
#include "llvm/Frontend/OpenMP/OMPGridValues.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/CallingConv.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/ReplaceConstant.h"
#include "llvm/IR/Value.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"

#include <cstdint>
#include <optional>
#include <stack>

#define DEBUG_TYPE "openmp-ir-builder"

using namespace llvm;
using namespace omp;

static cl::opt<bool>
    OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                         cl::desc("Use optimistic attributes describing "
                                  "'as-if' properties of runtime calls."),
                         cl::init(false));

static cl::opt<double> UnrollThresholdFactor(
    "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
    cl::desc("Factor for the unroll threshold to account for code "
             "simplifications still taking place"),
    cl::init(1.5));

#ifndef NDEBUG
/// Return whether IP1 and IP2 are ambiguous, i.e. whether inserting
/// instructions at position IP1 may change the meaning of IP2, or vice versa.
/// This is because an InsertPoint stores the instruction before which
/// something is inserted. For instance, if both point to the same
/// instruction, two IRBuilders alternately creating instructions will cause
/// the instructions to be interleaved.
static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
                         IRBuilder<>::InsertPoint IP2) {
  if (!IP1.isSet() || !IP2.isSet())
    return false;
  return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
}

static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
  // Valid ordered/unordered and base algorithm combinations.
  switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
  case OMPScheduleType::UnorderedStaticChunked:
  case OMPScheduleType::UnorderedStatic:
  case OMPScheduleType::UnorderedDynamicChunked:
  case OMPScheduleType::UnorderedGuidedChunked:
  case OMPScheduleType::UnorderedRuntime:
  case OMPScheduleType::UnorderedAuto:
  case OMPScheduleType::UnorderedTrapezoidal:
  case OMPScheduleType::UnorderedGreedy:
  case OMPScheduleType::UnorderedBalanced:
  case OMPScheduleType::UnorderedGuidedIterativeChunked:
  case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::UnorderedSteal:
  case OMPScheduleType::UnorderedStaticBalancedChunked:
  case OMPScheduleType::UnorderedGuidedSimd:
  case OMPScheduleType::UnorderedRuntimeSimd:
  case OMPScheduleType::OrderedStaticChunked:
  case OMPScheduleType::OrderedStatic:
  case OMPScheduleType::OrderedDynamicChunked:
  case OMPScheduleType::OrderedGuidedChunked:
  case OMPScheduleType::OrderedRuntime:
  case OMPScheduleType::OrderedAuto:
  case OMPScheduleType::OrderdTrapezoidal:
  case OMPScheduleType::NomergeUnorderedStaticChunked:
  case OMPScheduleType::NomergeUnorderedStatic:
  case OMPScheduleType::NomergeUnorderedDynamicChunked:
  case OMPScheduleType::NomergeUnorderedGuidedChunked:
  case OMPScheduleType::NomergeUnorderedRuntime:
  case OMPScheduleType::NomergeUnorderedAuto:
  case OMPScheduleType::NomergeUnorderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedGreedy:
  case OMPScheduleType::NomergeUnorderedBalanced:
  case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
  case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::NomergeUnorderedSteal:
  case OMPScheduleType::NomergeOrderedStaticChunked:
  case OMPScheduleType::NomergeOrderedStatic:
  case OMPScheduleType::NomergeOrderedDynamicChunked:
  case OMPScheduleType::NomergeOrderedGuidedChunked:
  case OMPScheduleType::NomergeOrderedRuntime:
  case OMPScheduleType::NomergeOrderedAuto:
  case OMPScheduleType::NomergeOrderedTrapezoidal:
    break;
  default:
    return false;
  }

  // Must not set both monotonicity modifiers at the same time.
  OMPScheduleType MonotonicityFlags =
      SchedType & OMPScheduleType::MonotonicityMask;
  if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
    return false;

  return true;
}
#endif

static const omp::GV &getGridValue(const Triple &T, Function *Kernel) {
  if (T.isAMDGPU()) {
    StringRef Features =
        Kernel->getFnAttribute("target-features").getValueAsString();
    if (Features.count("+wavefrontsize64"))
      return omp::getAMDGPUGridValues<64>();
    return omp::getAMDGPUGridValues<32>();
  }
  if (T.isNVPTX())
    return omp::NVPTXGridValues;
  llvm_unreachable("No grid value available for this architecture!");
}

/// Determine which scheduling algorithm to use from the schedule clause
/// arguments.
static OMPScheduleType
getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier) {
  // Currently, the default schedule is static.
  switch (ClauseKind) {
  case OMP_SCHEDULE_Default:
  case OMP_SCHEDULE_Static:
    return HasChunks ? OMPScheduleType::BaseStaticChunked
                     : OMPScheduleType::BaseStatic;
  case OMP_SCHEDULE_Dynamic:
    return OMPScheduleType::BaseDynamicChunked;
  case OMP_SCHEDULE_Guided:
    return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
                           : OMPScheduleType::BaseGuidedChunked;
  case OMP_SCHEDULE_Auto:
    return llvm::omp::OMPScheduleType::BaseAuto;
  case OMP_SCHEDULE_Runtime:
    return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                           : OMPScheduleType::BaseRuntime;
  }
  llvm_unreachable("unhandled schedule clause argument");
}

/// Adds ordering modifier flags to schedule type.
static OMPScheduleType
getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
                              bool HasOrderedClause) {
  assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
             OMPScheduleType::None &&
         "Must not have ordering or monotonicity flags already set");

  OMPScheduleType OrderingModifier = HasOrderedClause
                                         ? OMPScheduleType::ModifierOrdered
                                         : OMPScheduleType::ModifierUnordered;
  OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;

  // Unsupported combinations are mapped onto supported alternatives.
  if (OrderingScheduleType ==
      (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedGuidedChunked;
  else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
                                    OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedRuntime;

  return OrderingScheduleType;
}

/// Adds monotonicity modifier flags to schedule type.
static OMPScheduleType
getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
                                  bool HasSimdModifier, bool HasMonotonic,
                                  bool HasNonmonotonic, bool HasOrderedClause) {
  assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
             OMPScheduleType::None &&
         "Must not have monotonicity flags already set");
  assert((!HasMonotonic || !HasNonmonotonic) &&
         "Monotonic and Nonmonotonic are contradicting each other");

  if (HasMonotonic) {
    return ScheduleType | OMPScheduleType::ModifierMonotonic;
  } else if (HasNonmonotonic) {
    return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
  } else {
    // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
    // If the static schedule kind is specified or if the ordered clause is
    // specified, and if the nonmonotonic modifier is not specified, the
    // effect is as if the monotonic modifier is specified. Otherwise, unless
    // the monotonic modifier is specified, the effect is as if the
    // nonmonotonic modifier is specified.
    OMPScheduleType BaseScheduleType =
        ScheduleType & ~OMPScheduleType::ModifierMask;
    if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
        (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
        HasOrderedClause) {
      // Monotonic is the default in the OpenMP runtime library, so there is
      // no need to set it.
      return ScheduleType;
    } else {
      return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
    }
  }
}

/// Determine the schedule type using schedule and ordering clause arguments.
static OMPScheduleType
computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier, bool HasMonotonicModifier,
                          bool HasNonmonotonicModifier, bool HasOrderedClause) {
  OMPScheduleType BaseSchedule =
      getOpenMPBaseScheduleType(ClauseKind, HasChunks, HasSimdModifier);
  OMPScheduleType OrderedSchedule =
      getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
  OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
      OrderedSchedule, HasSimdModifier, HasMonotonicModifier,
      HasNonmonotonicModifier, HasOrderedClause);

  assert(isValidWorkshareLoopScheduleType(Result));
  return Result;
}
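
// Worked example (editor's illustration, not part of the upstream source):
// for '#pragma omp for schedule(nonmonotonic: dynamic, 4)' without an
// ordered clause, the helpers above compose as
//   getOpenMPBaseScheduleType(...)         -> BaseDynamicChunked
//   getOpenMPOrderingScheduleType(...)     -> UnorderedDynamicChunked
//   getOpenMPMonotonicityScheduleType(...) -> UnorderedDynamicChunked
//                                             | ModifierNonmonotonic
// so a single bitmask carries the base algorithm, ordering, and monotonicity.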

/// Make \p Source branch to \p Target.
///
/// Handles two situations:
/// * \p Source already has an unconditional branch.
/// * \p Source is a degenerate block (no terminator because the BB is
///   the current head of the IR construction).
static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
  if (Instruction *Term = Source->getTerminator()) {
    auto *Br = cast<BranchInst>(Term);
    assert(!Br->isConditional() &&
           "BB's terminator must be an unconditional branch (or degenerate)");
    BasicBlock *Succ = Br->getSuccessor(0);
    Succ->removePredecessor(Source, /*KeepOneInputPHIs=*/true);
    Br->setSuccessor(0, Target);
    return;
  }

  auto *NewBr = BranchInst::Create(Target, Source);
  NewBr->setDebugLoc(DL);
}

void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
                    bool CreateBranch) {
  assert(New->getFirstInsertionPt() == New->begin() &&
         "Target BB must not have PHI nodes");

  // Move instructions to new block.
  BasicBlock *Old = IP.getBlock();
  New->splice(New->begin(), Old, IP.getPoint(), Old->end());

  if (CreateBranch)
    BranchInst::Create(New, Old);
}

void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *Old = Builder.GetInsertBlock();

  spliceBB(Builder.saveIP(), New, CreateBranch);
  if (CreateBranch)
    Builder.SetInsertPoint(Old->getTerminator());
  else
    Builder.SetInsertPoint(Old);

  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
}

BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
                          llvm::Twine Name) {
  BasicBlock *Old = IP.getBlock();
  BasicBlock *New = BasicBlock::Create(
      Old->getContext(), Name.isTriviallyEmpty() ? Old->getName() : Name,
      Old->getParent(), Old->getNextNode());
  spliceBB(IP, New, CreateBranch);
  New->replaceSuccessorsPhiUsesWith(Old, New);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
                                    llvm::Twine Suffix) {
  BasicBlock *Old = Builder.GetInsertBlock();
  return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
}
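
// Usage sketch (illustrative): with the builder positioned in a block named
// "omp.par.region", splitBBWithSuffix(Builder, /*CreateBranch=*/true,
// ".split") creates "omp.par.region.split", moves the remaining instructions
// there, and leaves the builder inserting before the new terminator of the
// original block.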

// This function creates a fake integer value and a fake use for the integer
// value. It returns the fake value created. This is useful in modeling the
// extra arguments to the outlined functions.
Value *createFakeIntVal(IRBuilderBase &Builder,
                        OpenMPIRBuilder::InsertPointTy OuterAllocaIP,
                        llvm::SmallVectorImpl<Instruction *> &ToBeDeleted,
                        OpenMPIRBuilder::InsertPointTy InnerAllocaIP,
                        const Twine &Name = "", bool AsPtr = true) {
  Builder.restoreIP(OuterAllocaIP);
  Instruction *FakeVal;
  AllocaInst *FakeValAddr =
      Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, Name + ".addr");
  ToBeDeleted.push_back(FakeValAddr);

  if (AsPtr) {
    FakeVal = FakeValAddr;
  } else {
    FakeVal =
        Builder.CreateLoad(Builder.getInt32Ty(), FakeValAddr, Name + ".val");
    ToBeDeleted.push_back(FakeVal);
  }

  // Generate a fake use of this value.
  Builder.restoreIP(InnerAllocaIP);
  Instruction *UseFakeVal;
  if (AsPtr) {
    UseFakeVal =
        Builder.CreateLoad(Builder.getInt32Ty(), FakeVal, Name + ".use");
  } else {
    UseFakeVal =
        cast<BinaryOperator>(Builder.CreateAdd(FakeVal, Builder.getInt32(10)));
  }
  ToBeDeleted.push_back(UseFakeVal);
  return FakeVal;
}
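
// Usage sketch (hypothetical call site, for exposition only):
//   SmallVector<Instruction *, 4> ToBeDeleted;
//   Value *FakeTID = createFakeIntVal(Builder, OuterAllocaIP, ToBeDeleted,
//                                     InnerAllocaIP, "gtid");
// The fake alloca/load/use pair makes the value live into the to-be-outlined
// region, so CodeExtractor materializes a matching parameter; the recorded
// instructions are erased once outlining has completed.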

//===----------------------------------------------------------------------===//
// OpenMPIRBuilderConfig
//===----------------------------------------------------------------------===//

namespace {
LLVM_ENABLE_BITMASK_ENUMS_IN_NAMESPACE();
/// Values for bit flags for marking which requires clauses have been used.
enum OpenMPOffloadingRequiresDirFlags {
  /// flag undefined.
  OMP_REQ_UNDEFINED = 0x000,
  /// no requires directive present.
  OMP_REQ_NONE = 0x001,
  /// reverse_offload clause.
  OMP_REQ_REVERSE_OFFLOAD = 0x002,
  /// unified_address clause.
  OMP_REQ_UNIFIED_ADDRESS = 0x004,
  /// unified_shared_memory clause.
  OMP_REQ_UNIFIED_SHARED_MEMORY = 0x008,
  /// dynamic_allocators clause.
  OMP_REQ_DYNAMIC_ALLOCATORS = 0x010,
  LLVM_MARK_AS_BITMASK_ENUM(/*LargestValue=*/OMP_REQ_DYNAMIC_ALLOCATORS)
};

} // anonymous namespace
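
// For example, a translation unit containing
//   #pragma omp requires reverse_offload, unified_shared_memory
// yields RequiresFlags == (OMP_REQ_REVERSE_OFFLOAD |
// OMP_REQ_UNIFIED_SHARED_MEMORY) == 0x00A (editor's illustration of the
// encoding above).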

OpenMPIRBuilderConfig::OpenMPIRBuilderConfig()
    : RequiresFlags(OMP_REQ_UNDEFINED) {}

OpenMPIRBuilderConfig::OpenMPIRBuilderConfig(
    bool IsTargetDevice, bool IsGPU, bool OpenMPOffloadMandatory,
    bool HasRequiresReverseOffload, bool HasRequiresUnifiedAddress,
    bool HasRequiresUnifiedSharedMemory, bool HasRequiresDynamicAllocators)
    : IsTargetDevice(IsTargetDevice), IsGPU(IsGPU),
      OpenMPOffloadMandatory(OpenMPOffloadMandatory),
      RequiresFlags(OMP_REQ_UNDEFINED) {
  if (HasRequiresReverseOffload)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  if (HasRequiresUnifiedAddress)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  if (HasRequiresUnifiedSharedMemory)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  if (HasRequiresDynamicAllocators)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
}

bool OpenMPIRBuilderConfig::hasRequiresReverseOffload() const {
  return RequiresFlags & OMP_REQ_REVERSE_OFFLOAD;
}

bool OpenMPIRBuilderConfig::hasRequiresUnifiedAddress() const {
  return RequiresFlags & OMP_REQ_UNIFIED_ADDRESS;
}

bool OpenMPIRBuilderConfig::hasRequiresUnifiedSharedMemory() const {
  return RequiresFlags & OMP_REQ_UNIFIED_SHARED_MEMORY;
}

bool OpenMPIRBuilderConfig::hasRequiresDynamicAllocators() const {
  return RequiresFlags & OMP_REQ_DYNAMIC_ALLOCATORS;
}

int64_t OpenMPIRBuilderConfig::getRequiresFlags() const {
  return hasRequiresFlags() ? RequiresFlags
                            : static_cast<int64_t>(OMP_REQ_NONE);
}

void OpenMPIRBuilderConfig::setHasRequiresReverseOffload(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_REVERSE_OFFLOAD;
  else
    RequiresFlags &= ~OMP_REQ_REVERSE_OFFLOAD;
}

void OpenMPIRBuilderConfig::setHasRequiresUnifiedAddress(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_ADDRESS;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_ADDRESS;
}

void OpenMPIRBuilderConfig::setHasRequiresUnifiedSharedMemory(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_UNIFIED_SHARED_MEMORY;
  else
    RequiresFlags &= ~OMP_REQ_UNIFIED_SHARED_MEMORY;
}

void OpenMPIRBuilderConfig::setHasRequiresDynamicAllocators(bool Value) {
  if (Value)
    RequiresFlags |= OMP_REQ_DYNAMIC_ALLOCATORS;
  else
    RequiresFlags &= ~OMP_REQ_DYNAMIC_ALLOCATORS;
}

//===----------------------------------------------------------------------===//
// OpenMPIRBuilder
//===----------------------------------------------------------------------===//

void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
                                          IRBuilderBase &Builder,
                                          SmallVector<Value *> &ArgsVector) {
  Value *Version = Builder.getInt32(OMP_KERNEL_ARG_VERSION);
  Value *PointerNum = Builder.getInt32(KernelArgs.NumTargetItems);
  auto Int32Ty = Type::getInt32Ty(Builder.getContext());
  Value *ZeroArray = Constant::getNullValue(ArrayType::get(Int32Ty, 3));
  Value *Flags = Builder.getInt64(KernelArgs.HasNoWait);

  Value *NumTeams3D =
      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams, {0});
  Value *NumThreads3D =
      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads, {0});

  ArgsVector = {Version,
                PointerNum,
                KernelArgs.RTArgs.BasePointersArray,
                KernelArgs.RTArgs.PointersArray,
                KernelArgs.RTArgs.SizesArray,
                KernelArgs.RTArgs.MapTypesArray,
                KernelArgs.RTArgs.MapNamesArray,
                KernelArgs.RTArgs.MappersArray,
                KernelArgs.NumIterations,
                Flags,
                NumTeams3D,
                NumThreads3D,
                KernelArgs.DynCGGroupMem};
}
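
// Reading of the code above (editor's note): the vector mirrors the argument
// struct consumed by __tgt_target_kernel, and NumTeams/NumThreads are stored
// as element 0 of a zero-initialized [3 x i32], leaving room for full 3-D
// launch bounds in later versions of the interface.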

void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
  LLVMContext &Ctx = Fn.getContext();

  // Get the function's current attributes.
  auto Attrs = Fn.getAttributes();
  auto FnAttrs = Attrs.getFnAttrs();
  auto RetAttrs = Attrs.getRetAttrs();
  SmallVector<AttributeSet, 4> ArgAttrs;
  for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
    ArgAttrs.emplace_back(Attrs.getParamAttrs(ArgNo));

  // Add AS to FnAS while taking special care with integer extensions.
  auto addAttrSet = [&](AttributeSet &FnAS, const AttributeSet &AS,
                        bool Param = true) -> void {
    bool HasSignExt = AS.hasAttribute(Attribute::SExt);
    bool HasZeroExt = AS.hasAttribute(Attribute::ZExt);
    if (HasSignExt || HasZeroExt) {
      assert(AS.getNumAttributes() == 1 &&
             "Currently not handling extension attr combined with others.");
      if (Param) {
        if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, HasSignExt))
          FnAS = FnAS.addAttribute(Ctx, AK);
      } else if (auto AK =
                     TargetLibraryInfo::getExtAttrForI32Return(T, HasSignExt))
        FnAS = FnAS.addAttribute(Ctx, AK);
    } else {
      FnAS = FnAS.addAttributes(Ctx, AS);
    }
  };

#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
#include "llvm/Frontend/OpenMP/OMPKinds.def"

  // Add attributes to the function declaration.
  switch (FnID) {
#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets) \
  case Enum: \
    FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet); \
    addAttrSet(RetAttrs, RetAttrSet, /*Param*/ false); \
    for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo) \
      addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]); \
    Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs)); \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    // Attributes are optional.
    break;
  }
}

FunctionCallee
OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
  FunctionType *FnTy = nullptr;
  Function *Fn = nullptr;

  // Try to find the declaration in the module first.
  switch (FnID) {
#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...) \
  case Enum: \
    FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__}, \
                             IsVarArg); \
    Fn = M.getFunction(Str); \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  if (!Fn) {
    // Create a new declaration if we need one.
    switch (FnID) {
#define OMP_RTL(Enum, Str, ...) \
  case Enum: \
    Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M); \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
    }

    // Add information if the runtime function takes a callback function.
    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
      if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
        LLVMContext &Ctx = Fn->getContext();
        MDBuilder MDB(Ctx);
        // Annotate the callback behavior of the runtime function:
        // - The callback callee is argument number 2 (microtask).
        // - The first two arguments of the callback callee are unknown (-1).
        // - All variadic arguments to the runtime function are passed to the
        //   callback callee.
        Fn->addMetadata(
            LLVMContext::MD_callback,
            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                  2, {-1, -1}, /* VarArgsArePassed */ true)}));
      }
    }

    LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
    addAttributes(FnID, *Fn);

  } else {
    LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
  }

  assert(Fn && "Failed to create OpenMP runtime function");

  return {FnTy, Fn};
}
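
// The !callback annotation added above renders roughly as the following IR
// (assumed shape, shown for exposition only):
//   declare !callback !0 void @__kmpc_fork_call(ptr, i32, ptr, ...)
//   !0 = !{!1}
//   !1 = !{i64 2, i64 -1, i64 -1, i1 true}
// i.e. the callee is operand 2, its first two parameters are unknown, and the
// variadic arguments are forwarded to it.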

Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
  FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
  auto *Fn = dyn_cast<llvm::Function>(RTLFn.getCallee());
  assert(Fn && "Failed to create OpenMP runtime function pointer");
  return Fn;
}

void OpenMPIRBuilder::initialize() { initializeTypes(M); }

static void raiseUserConstantDataAllocasToEntryBlock(IRBuilderBase &Builder,
                                                     Function *Function) {
  BasicBlock &EntryBlock = Function->getEntryBlock();
  Instruction *MoveLocInst = EntryBlock.getFirstNonPHI();

  // Loop over blocks looking for constant allocas, skipping the entry block
  // as any allocas there are already in the desired location.
  for (auto Block = std::next(Function->begin(), 1); Block != Function->end();
       Block++) {
    for (auto Inst = Block->getReverseIterator()->begin();
         Inst != Block->getReverseIterator()->end();) {
      if (auto *AllocaInst = dyn_cast_if_present<llvm::AllocaInst>(Inst)) {
        Inst++;
        if (!isa<ConstantData>(AllocaInst->getArraySize()))
          continue;
        AllocaInst->moveBeforePreserving(MoveLocInst);
      } else {
        Inst++;
      }
    }
  }
}

void OpenMPIRBuilder::finalize(Function *Fn) {
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  SmallVector<OutlineInfo, 16> DeferredOutlines;
  for (OutlineInfo &OI : OutlineInfos) {
    // Skip functions that have not been finalized yet; this may happen with
    // nested function generation.
    if (Fn && OI.getFunction() != Fn) {
      DeferredOutlines.push_back(OI);
      continue;
    }

    ParallelRegionBlockSet.clear();
    Blocks.clear();
    OI.collectBlocks(ParallelRegionBlockSet, Blocks);

    Function *OuterFn = OI.getFunction();
    CodeExtractorAnalysisCache CEAC(*OuterFn);
    // If we generate code for the target device, we need to allocate the
    // struct for aggregate params in the device default alloca address space.
    // The OpenMP runtime requires that the params of the extracted functions
    // are passed as zero address space pointers. This flag ensures that
    // CodeExtractor generates correct code for extracted functions
    // which are used by the OpenMP runtime.
    bool ArgsInZeroAddressSpace = Config.isTargetDevice();
    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                            /* AggregateArgs */ true,
                            /* BlockFrequencyInfo */ nullptr,
                            /* BranchProbabilityInfo */ nullptr,
                            /* AssumptionCache */ nullptr,
                            /* AllowVarArgs */ true,
                            /* AllowAlloca */ true,
                            /* AllocaBlock */ OI.OuterAllocaBB,
                            /* Suffix */ ".omp_par", ArgsInZeroAddressSpace);

    LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                      << " Exit: " << OI.ExitBB->getName() << "\n");
    assert(Extractor.isEligible() &&
           "Expected OpenMP outlining to be possible!");

    for (auto *V : OI.ExcludeArgsFromAggregate)
      Extractor.excludeArgFromAggregate(V);

    Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);

    // Forward target-cpu, target-features attributes to the outlined function.
    auto TargetCpuAttr = OuterFn->getFnAttribute("target-cpu");
    if (TargetCpuAttr.isStringAttribute())
      OutlinedFn->addFnAttr(TargetCpuAttr);

    auto TargetFeaturesAttr = OuterFn->getFnAttribute("target-features");
    if (TargetFeaturesAttr.isStringAttribute())
      OutlinedFn->addFnAttr(TargetFeaturesAttr);

    LLVM_DEBUG(dbgs() << "After outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << " Outlined function: " << *OutlinedFn << "\n");
    assert(OutlinedFn->getReturnType()->isVoidTy() &&
           "OpenMP outlined functions should not return a value!");

    // For compatibility with the clang CG we move the outlined function after
    // the one with the parallel region.
    OutlinedFn->removeFromParent();
    M.getFunctionList().insertAfter(OuterFn->getIterator(), OutlinedFn);

    // Remove the artificial entry introduced by the extractor right away; we
    // made our own entry block after all.
    {
      BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
      // Move instructions from the to-be-deleted ArtificialEntry to the entry
      // basic block of the parallel region. CodeExtractor generates
      // instructions to unwrap the aggregate argument and may sink
      // allocas/bitcasts for values that are solely used in the outlined
      // region and do not escape.
      assert(!ArtificialEntry.empty() &&
             "Expected instructions to add in the outlined region entry");
      for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
                                        End = ArtificialEntry.rend();
           It != End;) {
        Instruction &I = *It;
        It++;

        if (I.isTerminator())
          continue;

        I.moveBeforePreserving(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
      }

      OI.EntryBB->moveBefore(&ArtificialEntry);
      ArtificialEntry.eraseFromParent();
    }
    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
    assert(OutlinedFn && OutlinedFn->getNumUses() == 1);

    // Run a user callback, e.g. to add attributes.
    if (OI.PostOutlineCB)
      OI.PostOutlineCB(*OutlinedFn);
  }

  // Remove work items that have been completed.
  OutlineInfos = std::move(DeferredOutlines);

  // The createTarget functions embed user-written code into the target region,
  // which may inject allocas that then need to be moved to the entry block of
  // our target, or we risk malformed optimisations by later passes. This is
  // only relevant for the device pass, which appears to be a little more
  // delicate when it comes to optimisations (however, we do not block on that
  // here; it is up to whoever inserts into the list to do so).
  // This notably has to occur after the OutlinedInfo candidates have been
  // extracted, so we have an end product that will not be implicitly adversely
  // affected by any raises unless intentionally appended to the list.
  // NOTE: This only does so for ConstantData; it could be extended to
  // ConstantExprs with further effort, however, they should largely be folded
  // by the time they get here. Extending it to runtime-defined/read+writeable
  // allocation sizes would be non-trivial (we would need to factor in movement
  // of any stores to variables the allocation size depends on, as well as the
  // usual loads, otherwise it will yield the wrong result after movement) and
  // would likely be more suitable as an LLVM optimisation pass.
  for (Function *F : ConstantAllocaRaiseCandidates)
    raiseUserConstantDataAllocasToEntryBlock(Builder, F);

  EmitMetadataErrorReportFunctionTy &&ErrorReportFn =
      [](EmitMetadataErrorKind Kind,
         const TargetRegionEntryInfo &EntryInfo) -> void {
    errs() << "Error of kind: " << Kind
           << " when emitting offload entries and metadata during "
              "OMPIRBuilder finalization \n";
  };

  if (!OffloadInfoManager.empty())
    createOffloadEntriesAndInfoMetadata(ErrorReportFn);

  if (Config.EmitLLVMUsedMetaInfo.value_or(false)) {
    std::vector<WeakTrackingVH> LLVMCompilerUsed = {
        M.getGlobalVariable("__openmp_nvptx_data_transfer_temporary_storage")};
    emitUsed("llvm.compiler.used", LLVMCompilerUsed);
  }
}

OpenMPIRBuilder::~OpenMPIRBuilder() {
  assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
}

GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
  IntegerType *I32Ty = Type::getInt32Ty(M.getContext());
  auto *GV =
      new GlobalVariable(M, I32Ty,
                         /* isConstant = */ true, GlobalValue::WeakODRLinkage,
                         ConstantInt::get(I32Ty, Value), Name);
  GV->setVisibility(GlobalValue::HiddenVisibility);

  return GV;
}

Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                            uint32_t SrcLocStrSize,
                                            IdentFlag LocFlags,
                                            unsigned Reserve2Flags) {
  // Enable "C-mode".
  LocFlags |= OMP_IDENT_FLAG_KMPC;

  Constant *&Ident =
      IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
  if (!Ident) {
    Constant *I32Null = ConstantInt::getNullValue(Int32);
    Constant *IdentData[] = {I32Null,
                             ConstantInt::get(Int32, uint32_t(LocFlags)),
                             ConstantInt::get(Int32, Reserve2Flags),
                             ConstantInt::get(Int32, SrcLocStrSize), SrcLocStr};
    Constant *Initializer =
        ConstantStruct::get(OpenMPIRBuilder::Ident, IdentData);

    // Look for existing encoding of the location + flags, not needed but
    // minimizes the difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
        if (GV.getInitializer() == Initializer)
          Ident = &GV;

    if (!Ident) {
      auto *GV = new GlobalVariable(
          M, OpenMPIRBuilder::Ident,
          /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
          nullptr, GlobalValue::NotThreadLocal,
          M.getDataLayout().getDefaultGlobalsAddressSpace());
      GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
      GV->setAlignment(Align(8));
      Ident = GV;
    }
  }

  return ConstantExpr::getPointerBitCastOrAddrSpaceCast(Ident, IdentPtr);
}
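
// Illustrative result (assumed IR shape): a source-location string @.str plus
// the OMP_IDENT_FLAG_KMPC flag typically materialize as
//   @0 = private unnamed_addr constant %struct.ident_t
//            { i32 0, i32 <flags>, i32 0, i32 <SrcLocStrSize>, ptr @.str },
//            align 8
// which is then pointer/address-space cast to IdentPtr for use in runtime
// calls.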

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
                                                uint32_t &SrcLocStrSize) {
  SrcLocStrSize = LocStr.size();
  Constant *&SrcLocStr = SrcLocStrMap[LocStr];
  if (!SrcLocStr) {
    Constant *Initializer =
        ConstantDataArray::getString(M.getContext(), LocStr);

    // Look for existing encoding of the location, not needed but minimizes the
    // difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.globals())
      if (GV.isConstant() && GV.hasInitializer() &&
          GV.getInitializer() == Initializer)
        return SrcLocStr = ConstantExpr::getPointerCast(&GV, Int8Ptr);

    SrcLocStr = Builder.CreateGlobalStringPtr(LocStr, /* Name */ "",
                                              /* AddressSpace */ 0, &M);
  }
  return SrcLocStr;
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
                                                StringRef FileName,
                                                unsigned Line, unsigned Column,
                                                uint32_t &SrcLocStrSize) {
  SmallString<128> Buffer;
  Buffer.push_back(';');
  Buffer.append(FileName);
  Buffer.push_back(';');
  Buffer.append(FunctionName);
  Buffer.push_back(';');
  Buffer.append(std::to_string(Line));
  Buffer.push_back(';');
  Buffer.append(std::to_string(Column));
  Buffer.push_back(';');
  Buffer.push_back(';');
  return getOrCreateSrcLocStr(Buffer.str(), SrcLocStrSize);
}
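
// For example, function "foo" at file.c:3:7 produces the string
// ";file.c;foo;3;7;;", matching the ";unknown;unknown;0;0;;" default used
// below when no debug location is available.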

Constant *
OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
  StringRef UnknownLoc = ";unknown;unknown;0;0;;";
  return getOrCreateSrcLocStr(UnknownLoc, SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
                                                uint32_t &SrcLocStrSize,
                                                Function *F) {
  DILocation *DIL = DL.get();
  if (!DIL)
    return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
  StringRef FileName = M.getName();
  if (DIFile *DIF = DIL->getFile())
    if (std::optional<StringRef> Source = DIF->getSource())
      FileName = *Source;
  StringRef Function = DIL->getScope()->getSubprogram()->getName();
  if (Function.empty() && F)
    Function = F->getName();
  return getOrCreateSrcLocStr(Function, FileName, DIL->getLine(),
                              DIL->getColumn(), SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
                                                uint32_t &SrcLocStrSize) {
  return getOrCreateSrcLocStr(Loc.DL, SrcLocStrSize,
                              Loc.IP.getBlock()->getParent());
}

Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
  return Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num), Ident,
      "omp_global_thread_num");
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive Kind,
                               bool ForceSimpleCall, bool CheckCancelFlag) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // Build call __kmpc_cancel_barrier(loc, thread_id) or
  //            __kmpc_barrier(loc, thread_id);

  IdentFlag BarrierLocFlags;
  switch (Kind) {
  case OMPD_for:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
    break;
  case OMPD_sections:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
    break;
  case OMPD_single:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
    break;
  case OMPD_barrier:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
    break;
  default:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
    break;
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Args[] = {
      getOrCreateIdent(SrcLocStr, SrcLocStrSize, BarrierLocFlags),
      getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize))};

  // If we are in a cancellable parallel region, barriers are cancellation
  // points.
  // TODO: Check why we would force simple calls or to ignore the cancel flag.
  bool UseCancelBarrier =
      !ForceSimpleCall && isLastFinalizationInfoCancellable(OMPD_parallel);

  Value *Result =
      Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                             UseCancelBarrier ? OMPRTL___kmpc_cancel_barrier
                                              : OMPRTL___kmpc_barrier),
                         Args);

  if (UseCancelBarrier && CheckCancelFlag)
    emitCancelationCheckImpl(Result, OMPD_parallel);

  return Builder.saveIP();
}
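
// Sketch of the emitted IR (editor's illustration): inside a cancellable
// parallel region this becomes
//   %r = call i32 @__kmpc_cancel_barrier(ptr @loc, i32 %gtid)
// followed by a cancellation check on %r; otherwise it is the plain
//   call void @__kmpc_barrier(ptr @loc, i32 %gtid)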

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
                              Value *IfCondition,
                              omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities like blocks with terminators.
  auto *UI = Builder.CreateUnreachable();

  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition)
    SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
  Builder.SetInsertPoint(ThenTI);

  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value) \
  case DirectiveEnum: \
    CancelKind = Builder.getInt32(Value); \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancel), Args);
  auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) {
    if (CanceledDirective == OMPD_parallel) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                    omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                    /* CheckCancelFlag */ false);
    }
  };

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  emitCancelationCheckImpl(Result, CanceledDirective, ExitCB);

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
    const LocationDescription &Loc, InsertPointTy AllocaIP, Value *&Return,
    Value *Ident, Value *DeviceID, Value *NumTeams, Value *NumThreads,
    Value *HostPtr, ArrayRef<Value *> KernelArgs) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Builder.restoreIP(AllocaIP);
  auto *KernelArgsPtr =
      Builder.CreateAlloca(OpenMPIRBuilder::KernelArgs, nullptr, "kernel_args");
  Builder.restoreIP(Loc.IP);

  for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
    llvm::Value *Arg =
        Builder.CreateStructGEP(OpenMPIRBuilder::KernelArgs, KernelArgsPtr, I);
    Builder.CreateAlignedStore(
        KernelArgs[I], Arg,
        M.getDataLayout().getPrefTypeAlign(KernelArgs[I]->getType()));
  }

  SmallVector<Value *> OffloadingArgs{Ident,      DeviceID, NumTeams,
                                      NumThreads, HostPtr,  KernelArgsPtr};

  Return = Builder.CreateCall(
      getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel),
      OffloadingArgs);

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
    const LocationDescription &Loc, Function *OutlinedFn, Value *OutlinedFnID,
    EmitFallbackCallbackTy emitTargetCallFallbackCB, TargetKernelArgs &Args,
    Value *DeviceID, Value *RTLoc, InsertPointTy AllocaIP) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  Builder.restoreIP(Loc.IP);
  // On top of the arrays that were filled up, the target offloading call
  // takes as arguments the device id as well as the host pointer. The host
  // pointer is used by the runtime library to identify the current target
  // region, so it only has to be unique and not necessarily point to
  // anything. It could be the pointer to the outlined function that
  // implements the target region, but we aren't using that so that the
  // compiler doesn't need to keep that, and could therefore inline the host
  // function if proven worthwhile during optimization.

  // From this point on, we need to have an ID of the target region defined.
  assert(OutlinedFnID && "Invalid outlined function ID!");
  (void)OutlinedFnID;

  // Return value of the runtime offloading call.
  Value *Return = nullptr;

  // Arguments for the target kernel.
  SmallVector<Value *> ArgsVector;
  getKernelArgsVector(Args, Builder, ArgsVector);

  // The target region is an outlined function launched by the runtime
  // via calls to __tgt_target_kernel().
  //
  // Note that on the host and CPU targets, the runtime implementation of
  // these calls simply calls the outlined function without forking threads.
  // The outlined functions themselves have runtime calls to
  // __kmpc_fork_teams() and __kmpc_fork() for this purpose, codegen'd by
  // the compiler in emitTeamsCall() and emitParallelCall().
  //
  // In contrast, on the NVPTX target, the implementation of
  // __tgt_target_teams() launches a GPU kernel with the requested number
  // of teams and threads, so no additional calls to the runtime are required.
  // Check the error code and execute the host version if required.
  Builder.restoreIP(emitTargetKernel(Builder, AllocaIP, Return, RTLoc, DeviceID,
                                     Args.NumTeams, Args.NumThreads,
                                     OutlinedFnID, ArgsVector));

  BasicBlock *OffloadFailedBlock =
      BasicBlock::Create(Builder.getContext(), "omp_offload.failed");
  BasicBlock *OffloadContBlock =
      BasicBlock::Create(Builder.getContext(), "omp_offload.cont");
  Value *Failed = Builder.CreateIsNotNull(Return);
  Builder.CreateCondBr(Failed, OffloadFailedBlock, OffloadContBlock);

  auto CurFn = Builder.GetInsertBlock()->getParent();
  emitBlock(OffloadFailedBlock, CurFn);
  Builder.restoreIP(emitTargetCallFallbackCB(Builder.saveIP()));
  emitBranch(OffloadContBlock);
  emitBlock(OffloadContBlock, CurFn, /*IsFinished=*/true);
  return Builder.saveIP();
}
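
// Sketch of the control flow emitted above (editor's illustration):
//   %rc = call i32 @__tgt_target_kernel(...)
//   %failed = icmp ne i32 %rc, 0
//   br i1 %failed, label %omp_offload.failed, label %omp_offload.cont
// where %omp_offload.failed runs the host fallback produced by
// emitTargetCallFallbackCB before branching to %omp_offload.cont.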

void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
                                               omp::Directive CanceledDirective,
                                               FinalizeCallbackTy ExitCB) {
  assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
         "Unexpected cancellation!");

  // For a cancel barrier we create two new blocks.
  BasicBlock *BB = Builder.GetInsertBlock();
  BasicBlock *NonCancellationBlock;
  if (Builder.GetInsertPoint() == BB->end()) {
    // TODO: This branch will not be needed once we moved to the
    // OpenMPIRBuilder codegen completely.
    NonCancellationBlock = BasicBlock::Create(
        BB->getContext(), BB->getName() + ".cont", BB->getParent());
  } else {
    NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
    BB->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(BB);
  }
  BasicBlock *CancellationBlock = BasicBlock::Create(
      BB->getContext(), BB->getName() + ".cncl", BB->getParent());

  // Jump to them based on the return value.
  Value *Cmp = Builder.CreateIsNull(CancelFlag);
  Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
                       /* TODO weight */ nullptr, nullptr);

  // From the cancellation block we finalize all variables and go to the
  // post finalization block that is known to the FiniCB callback.
  Builder.SetInsertPoint(CancellationBlock);
  if (ExitCB)
    ExitCB(Builder.saveIP());
  auto &FI = FinalizationStack.back();
  FI.FiniCB(Builder.saveIP());

  // The continuation block is where code generation continues.
  Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
}
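
// Resulting CFG, sketched (editor's illustration): the current block ends in
//   %is.zero = icmp eq i32 %cancel_flag, 0
//   br i1 %is.zero, label %<bb>.cont, label %<bb>.cncl
// where <bb>.cncl runs the finalization callbacks before jumping to the
// post-finalization block, and <bb>.cont continues normal code generation.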

// Callback used to create OpenMP runtime calls to support
// omp parallel clause for the device.
// We need to use this callback to replace the call to OutlinedFn in OuterFn
// with a call to the OpenMP DeviceRTL runtime function (__kmpc_parallel_51).
static void targetParallelCallback(
    OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn, Function *OuterFn,
    BasicBlock *OuterAllocaBB, Value *Ident, Value *IfCondition,
    Value *NumThreads, Instruction *PrivTID, AllocaInst *PrivTIDAddr,
    Value *ThreadID, const SmallVector<Instruction *, 4> &ToBeDeleted) {
  // Add some known attributes.
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  OutlinedFn.addParamAttr(0, Attribute::NoAlias);
  OutlinedFn.addParamAttr(1, Attribute::NoAlias);
  OutlinedFn.addParamAttr(0, Attribute::NoUndef);
  OutlinedFn.addParamAttr(1, Attribute::NoUndef);
  OutlinedFn.addFnAttr(Attribute::NoUnwind);

  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
  assert(CI && "Expected call instruction to outlined function");
  CI->getParent()->setName("omp_parallel");

  Builder.SetInsertPoint(CI);
  Type *PtrTy = OMPIRBuilder->VoidPtr;
  Value *NullPtrValue = Constant::getNullValue(PtrTy);

  // Add alloca for kernel args.
  OpenMPIRBuilder::InsertPointTy CurrentIP = Builder.saveIP();
  Builder.SetInsertPoint(OuterAllocaBB, OuterAllocaBB->getFirstInsertionPt());
  AllocaInst *ArgsAlloca =
      Builder.CreateAlloca(ArrayType::get(PtrTy, NumCapturedVars));
  Value *Args = ArgsAlloca;
  // Add an address space cast if the array for storing arguments is not
  // allocated in address space 0.
  if (ArgsAlloca->getAddressSpace())
    Args = Builder.CreatePointerCast(ArgsAlloca, PtrTy);
  Builder.restoreIP(CurrentIP);

  // Store captured vars which are used by kmpc_parallel_51.
  for (unsigned Idx = 0; Idx < NumCapturedVars; Idx++) {
    Value *V = *(CI->arg_begin() + 2 + Idx);
    Value *StoreAddress = Builder.CreateConstInBoundsGEP2_64(
        ArrayType::get(PtrTy, NumCapturedVars), Args, 0, Idx);
    Builder.CreateStore(V, StoreAddress);
  }

  Value *Cond =
      IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32)
                  : Builder.getInt32(1);

  // Build the kmpc_parallel_51 call.
  Value *Parallel51CallArgs[] = {
      /* identifier */ Ident,
      /* global thread num */ ThreadID,
      /* if expression */ Cond,
      /* number of threads */ NumThreads ? NumThreads : Builder.getInt32(-1),
      /* proc bind */ Builder.getInt32(-1),
      /* outlined function */
      Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr),
      /* wrapper function */ NullPtrValue,
      /* arguments of the outlined function */ Args,
      /* number of arguments */ Builder.getInt64(NumCapturedVars)};

  FunctionCallee RTLFn =
      OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51);

  Builder.CreateCall(RTLFn, Parallel51CallArgs);

  LLVM_DEBUG(dbgs() << "With kmpc_parallel_51 placed: "
                    << *Builder.GetInsertBlock()->getParent() << "\n");

  // Initialize the local TID stack location with the argument value.
  Builder.SetInsertPoint(PrivTID);
  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
  Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
                      PrivTIDAddr);

  // Remove the redundant call to the outlined function.
  CI->eraseFromParent();

  for (Instruction *I : ToBeDeleted) {
    I->eraseFromParent();
  }
}
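
// Illustrative device-side call produced by this callback (runtime signature
// inferred from the argument list above, not authoritative):
//   call void @__kmpc_parallel_51(ptr %ident, i32 %gtid, i32 %if.cond,
//                                 i32 -1, i32 -1, ptr @outlined.omp_par,
//                                 ptr null, ptr %args, i64 <nargs>)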

// Callback used to create OpenMP runtime calls to support
// omp parallel clause for the host.
// We need to use this callback to replace the call to OutlinedFn in OuterFn
// with a call to the OpenMP host runtime function (__kmpc_fork_call[_if]).
static void
hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
                     Function *OuterFn, Value *Ident, Value *IfCondition,
                     Instruction *PrivTID, AllocaInst *PrivTIDAddr,
                     const SmallVector<Instruction *, 4> &ToBeDeleted) {
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  FunctionCallee RTLFn;
  if (IfCondition) {
    RTLFn =
        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
  } else {
    RTLFn =
        OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
  }
  if (auto *F = dyn_cast<Function>(RTLFn.getCallee())) {
    if (!F->hasMetadata(LLVMContext::MD_callback)) {
      LLVMContext &Ctx = F->getContext();
      MDBuilder MDB(Ctx);
      // Annotate the callback behavior of the __kmpc_fork_call:
      // - The callback callee is argument number 2 (microtask).
      // - The first two arguments of the callback callee are unknown (-1).
      // - All variadic arguments to the __kmpc_fork_call are passed to the
      //   callback callee.
      F->addMetadata(LLVMContext::MD_callback,
                     *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                           2, {-1, -1},
                                           /* VarArgsArePassed */ true)}));
    }
  }
  // Add some known attributes.
  OutlinedFn.addParamAttr(0, Attribute::NoAlias);
  OutlinedFn.addParamAttr(1, Attribute::NoAlias);
  OutlinedFn.addFnAttr(Attribute::NoUnwind);

  assert(OutlinedFn.arg_size() >= 2 &&
         "Expected at least tid and bounded tid as arguments");
  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

  CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
  CI->getParent()->setName("omp_parallel");
  Builder.SetInsertPoint(CI);

  // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
  Value *ForkCallArgs[] = {
      Ident, Builder.getInt32(NumCapturedVars),
      Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr)};

  SmallVector<Value *, 16> RealArgs;
  RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
  if (IfCondition) {
    Value *Cond = Builder.CreateSExtOrTrunc(IfCondition, OMPIRBuilder->Int32);
    RealArgs.push_back(Cond);
  }
  RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());

  // __kmpc_fork_call_if always expects a void ptr as the last argument.
  // If there are no arguments, pass a null pointer.
  auto PtrTy = OMPIRBuilder->VoidPtr;
  if (IfCondition && NumCapturedVars == 0) {
    Value *NullPtrValue = Constant::getNullValue(PtrTy);
    RealArgs.push_back(NullPtrValue);
  }
  if (IfCondition && RealArgs.back()->getType() != PtrTy)
    RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy);

  Builder.CreateCall(RTLFn, RealArgs);

  LLVM_DEBUG(dbgs() << "With fork_call placed: "
                    << *Builder.GetInsertBlock()->getParent() << "\n");

  // Initialize the local TID stack location with the argument value.
  Builder.SetInsertPoint(PrivTID);
  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
  Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
                      PrivTIDAddr);

  // Remove the redundant call to the outlined function.
  CI->eraseFromParent();

  for (Instruction *I : ToBeDeleted) {
    I->eraseFromParent();
  }
}
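
// Illustrative host-side call (assumed shape): with two captured variables
// this emits
//   call void (ptr, i32, ptr, ...) @__kmpc_fork_call(ptr %ident, i32 2,
//       ptr @outlined.omp_par, ptr %cap0, ptr %cap1)
// while __kmpc_fork_call_if additionally receives the i32 'if' condition and
// a trailing pointer argument.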

IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
    const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
    BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
    FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
    omp::ProcBindKind ProcBind, bool IsCancellable) {
  assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");

  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadID = getOrCreateThreadID(Ident);
  // If we generate code for the target device, we need to allocate the struct
  // for aggregate params in the device default alloca address space. The
  // OpenMP runtime requires that the params of the extracted functions are
  // passed as zero address space pointers. This flag ensures that extracted
  // function arguments are declared in zero address space.
  bool ArgsInZeroAddressSpace = Config.isTargetDevice();

  // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
  // only if we compile for the host side.
  if (NumThreads && !Config.isTargetDevice()) {
    Value *Args[] = {
        Ident, ThreadID,
        Builder.CreateIntCast(NumThreads, Int32, /*isSigned*/ false)};
    Builder.CreateCall(
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_threads), Args);
  }

  if (ProcBind != OMP_PROC_BIND_default) {
    // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
    Value *Args[] = {
        Ident, ThreadID,
        ConstantInt::get(Int32, unsigned(ProcBind), /*isSigned=*/true)};
    Builder.CreateCall(
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_proc_bind), Args);
  }

  BasicBlock *InsertBB = Builder.GetInsertBlock();
  Function *OuterFn = InsertBB->getParent();

  // Save the outer alloca block because the insertion iterator may get
  // invalidated and we still need this later.
  BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();

  // Vector to remember instructions we used only during the modeling but which
  // we want to delete at the end.
  SmallVector<Instruction *, 4> ToBeDeleted;

  // Change the location to the outer alloca insertion point to create and
  // initialize the allocas we pass into the parallel region.
  InsertPointTy NewOuter(OuterAllocaBlock, OuterAllocaBlock->begin());
  Builder.restoreIP(NewOuter);
  AllocaInst *TIDAddrAlloca = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
  AllocaInst *ZeroAddrAlloca =
      Builder.CreateAlloca(Int32, nullptr, "zero.addr");
  Instruction *TIDAddr = TIDAddrAlloca;
  Instruction *ZeroAddr = ZeroAddrAlloca;
  if (ArgsInZeroAddressSpace && M.getDataLayout().getAllocaAddrSpace() != 0) {
    // Add additional casts to enforce pointers in zero address space.
    TIDAddr = new AddrSpaceCastInst(
        TIDAddrAlloca, PointerType::get(M.getContext(), 0), "tid.addr.ascast");
    TIDAddr->insertAfter(TIDAddrAlloca);
    ToBeDeleted.push_back(TIDAddr);
    ZeroAddr = new AddrSpaceCastInst(ZeroAddrAlloca,
                                     PointerType::get(M.getContext(), 0),
                                     "zero.addr.ascast");
    ZeroAddr->insertAfter(ZeroAddrAlloca);
    ToBeDeleted.push_back(ZeroAddr);
  }

  // We only need TIDAddr and ZeroAddr for modeling purposes to get the
  // associated arguments in the outlined function, so we delete them later.
  ToBeDeleted.push_back(TIDAddrAlloca);
  ToBeDeleted.push_back(ZeroAddrAlloca);

  // Create an artificial insertion point that will also ensure the blocks we
  // are about to split are not degenerated.
  auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);

  BasicBlock *EntryBB = UI->getParent();
  BasicBlock *PRegEntryBB = EntryBB->splitBasicBlock(UI, "omp.par.entry");
  BasicBlock *PRegBodyBB = PRegEntryBB->splitBasicBlock(UI, "omp.par.region");
  BasicBlock *PRegPreFiniBB =
      PRegBodyBB->splitBasicBlock(UI, "omp.par.pre_finalize");
  BasicBlock *PRegExitBB = PRegPreFiniBB->splitBasicBlock(UI, "omp.par.exit");

  auto FiniCBWrapper = [&](InsertPointTy IP) {
    // Hide "open-ended" blocks from the given FiniCB by setting the right jump
    // target to the region exit block.
    if (IP.getBlock()->end() == IP.getPoint()) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      Instruction *I = Builder.CreateBr(PRegExitBB);
      IP = InsertPointTy(I->getParent(), I->getIterator());
    }
    assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
           IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
           "Unexpected insertion point for finalization call!");
    return FiniCB(IP);
  };

  FinalizationStack.push_back({FiniCBWrapper, OMPD_parallel, IsCancellable});

  // Generate the privatization allocas in the block that will become the entry
  // of the outlined function.
  Builder.SetInsertPoint(PRegEntryBB->getTerminator());
  InsertPointTy InnerAllocaIP = Builder.saveIP();

  AllocaInst *PrivTIDAddr =
      Builder.CreateAlloca(Int32, nullptr, "tid.addr.local");
  Instruction *PrivTID = Builder.CreateLoad(Int32, PrivTIDAddr, "tid");

  // Add some fake uses for OpenMP provided arguments.
  ToBeDeleted.push_back(Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use"));
  Instruction *ZeroAddrUse =
      Builder.CreateLoad(Int32, ZeroAddr, "zero.addr.use");
  ToBeDeleted.push_back(ZeroAddrUse);

  // EntryBB
  //   |
  //   V
  // PRegionEntryBB     <- Privatization allocas are placed here.
  //   |
  //   V
  // PRegionBodyBB      <- BodyGen is invoked here.
  //   |
  //   V
  // PRegPreFiniBB      <- The block we will start finalization from.
  //   |
  //   V
  // PRegionExitBB      <- A common exit to simplify block collection.
  //

  LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");

  // Let the caller create the body.
  assert(BodyGenCB && "Expected body generation callback!");
  InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
  BodyGenCB(InnerAllocaIP, CodeGenIP);

  LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n");

  OutlineInfo OI;
  if (Config.isTargetDevice()) {
    // Generate the OpenMP target-specific runtime call.
    OI.PostOutlineCB = [=, ToBeDeletedVec =
                               std::move(ToBeDeleted)](Function &OutlinedFn) {
      targetParallelCallback(this, OutlinedFn, OuterFn, OuterAllocaBlock, Ident,
                             IfCondition, NumThreads, PrivTID, PrivTIDAddr,
                             ThreadID, ToBeDeletedVec);
    };
  } else {
    // Generate the OpenMP host runtime call.
    OI.PostOutlineCB = [=, ToBeDeletedVec =
                               std::move(ToBeDeleted)](Function &OutlinedFn) {
      hostParallelCallback(this, OutlinedFn, OuterFn, Ident, IfCondition,
                           PrivTID, PrivTIDAddr, ToBeDeletedVec);
    };
  }

  OI.OuterAllocaBB = OuterAllocaBlock;
  OI.EntryBB = PRegEntryBB;
  OI.ExitBB = PRegExitBB;
1511
1512
SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
1513
SmallVector<BasicBlock *, 32> Blocks;
1514
OI.collectBlocks(ParallelRegionBlockSet, Blocks);
1515
1516
// Ensure a single exit node for the outlined region by creating one.
1517
// We might have multiple incoming edges to the exit now due to finalizations,
1518
// e.g., cancel calls that cause the control flow to leave the region.
1519
BasicBlock *PRegOutlinedExitBB = PRegExitBB;
1520
PRegExitBB = SplitBlock(PRegExitBB, &*PRegExitBB->getFirstInsertionPt());
1521
PRegOutlinedExitBB->setName("omp.par.outlined.exit");
1522
Blocks.push_back(PRegOutlinedExitBB);
1523
1524
CodeExtractorAnalysisCache CEAC(*OuterFn);
1525
CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
1526
/* AggregateArgs */ false,
1527
/* BlockFrequencyInfo */ nullptr,
1528
/* BranchProbabilityInfo */ nullptr,
1529
/* AssumptionCache */ nullptr,
1530
/* AllowVarArgs */ true,
1531
/* AllowAlloca */ true,
1532
/* AllocationBlock */ OuterAllocaBlock,
1533
/* Suffix */ ".omp_par", ArgsInZeroAddressSpace);
1534
1535
// Find inputs to, outputs from the code region.
1536
BasicBlock *CommonExit = nullptr;
1537
SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
1538
Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
1539
Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);
1540
1541
LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");
1542
1543
FunctionCallee TIDRTLFn =
1544
getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);
1545
1546
auto PrivHelper = [&](Value &V) {
1547
if (&V == TIDAddr || &V == ZeroAddr) {
1548
OI.ExcludeArgsFromAggregate.push_back(&V);
1549
return;
1550
}
1551
1552
SetVector<Use *> Uses;
1553
for (Use &U : V.uses())
1554
if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
1555
if (ParallelRegionBlockSet.count(UserI->getParent()))
1556
Uses.insert(&U);
1557
1558
// __kmpc_fork_call expects extra arguments as pointers. If the input
1559
// already has a pointer type, everything is fine. Otherwise, store the
1560
// value onto stack and load it back inside the to-be-outlined region. This
1561
// will ensure only the pointer will be passed to the function.
1562
// FIXME: if there are more than 15 trailing arguments, they must be
1563
// additionally packed in a struct.
1564
Value *Inner = &V;
1565
if (!V.getType()->isPointerTy()) {
1566
IRBuilder<>::InsertPointGuard Guard(Builder);
1567
LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");
1568
1569
Builder.restoreIP(OuterAllocaIP);
1570
Value *Ptr =
1571
Builder.CreateAlloca(V.getType(), nullptr, V.getName() + ".reloaded");
1572
1573
// Store to stack at end of the block that currently branches to the entry
1574
// block of the to-be-outlined region.
1575
Builder.SetInsertPoint(InsertBB,
1576
InsertBB->getTerminator()->getIterator());
1577
Builder.CreateStore(&V, Ptr);
1578
1579
// Load back next to allocations in the to-be-outlined region.
1580
Builder.restoreIP(InnerAllocaIP);
1581
Inner = Builder.CreateLoad(V.getType(), Ptr);
1582
}
1583
1584
Value *ReplacementValue = nullptr;
1585
CallInst *CI = dyn_cast<CallInst>(&V);
1586
if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
1587
ReplacementValue = PrivTID;
1588
} else {
1589
Builder.restoreIP(
1590
PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue));
1591
InnerAllocaIP = {
1592
InnerAllocaIP.getBlock(),
1593
InnerAllocaIP.getBlock()->getTerminator()->getIterator()};
1594
1595
assert(ReplacementValue &&
1596
"Expected copy/create callback to set replacement value!");
1597
if (ReplacementValue == &V)
1598
return;
1599
}
1600
1601
for (Use *UPtr : Uses)
1602
UPtr->set(ReplacementValue);
1603
};
1604
1605
// Reset the inner alloca insertion as it will be used for loading the values
1606
// wrapped into pointers before passing them into the to-be-outlined region.
1607
// Configure it to insert immediately after the fake use of zero address so
1608
// that they are available in the generated body and so that the
1609
// OpenMP-related values (thread ID and zero address pointers) remain leading
1610
// in the argument list.
1611
InnerAllocaIP = IRBuilder<>::InsertPoint(
1612
ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());
1613
1614
// Reset the outer alloca insertion point to the entry of the relevant block
1615
// in case it was invalidated.
1616
OuterAllocaIP = IRBuilder<>::InsertPoint(
1617
OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());
1618
1619
for (Value *Input : Inputs) {
1620
LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
1621
PrivHelper(*Input);
1622
}
1623
LLVM_DEBUG({
1624
for (Value *Output : Outputs)
1625
LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
1626
});
1627
assert(Outputs.empty() &&
1628
"OpenMP outlining should not produce live-out values!");
1629
1630
LLVM_DEBUG(dbgs() << "After privatization: " << *OuterFn << "\n");
1631
LLVM_DEBUG({
1632
for (auto *BB : Blocks)
1633
dbgs() << " PBR: " << BB->getName() << "\n";
1634
});
1635
1636
// Adjust the finalization stack, verify the adjustment, and call the
1637
// finalize function a last time to finalize values between the pre-fini
1638
// block and the exit block if we left the parallel "the normal way".
1639
auto FiniInfo = FinalizationStack.pop_back_val();
1640
(void)FiniInfo;
1641
assert(FiniInfo.DK == OMPD_parallel &&
1642
"Unexpected finalization stack state!");
1643
1644
Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();
1645
1646
InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
1647
FiniCB(PreFiniIP);
1648
1649
// Register the outlined info.
1650
addOutlineInfo(std::move(OI));
1651
1652
InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
1653
UI->eraseFromParent();
1654
1655
return AfterIP;
1656
}
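
// Illustrative usage sketch (not part of the upstream sources; names such as
// OMPBuilder, DL, and the three callbacks are assumptions): a frontend could
// emit an unconditional parallel region with default proc_bind roughly as:
//
// \code{cpp}
//   OpenMPIRBuilder::LocationDescription Loc(Builder.saveIP(), DL);
//   InsertPointTy AfterIP = OMPBuilder.createParallel(
//       Loc, OuterAllocaIP, BodyGenCB, PrivCB, FiniCB,
//       /*IfCondition=*/nullptr, /*NumThreads=*/nullptr,
//       OMP_PROC_BIND_default, /*IsCancellable=*/false);
//   Builder.restoreIP(AfterIP); // Continue emitting code after the region.
// \endcode
//
// The actual outlining and the __kmpc_fork_call emission happen later, when
// the OutlineInfo registered above is processed during finalize().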

void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
  // Build call void __kmpc_flush(ident_t *loc)
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};

  Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_flush), Args);
}

void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
  if (!updateToLocation(Loc))
    return;
  emitFlush(Loc);
}

void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
  // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
  // global_tid);
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident)};

  // Ignore return result until untied tasks are supported.
  Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskwait),
                     Args);
}

void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
  if (!updateToLocation(Loc))
    return;
  emitTaskwaitImpl(Loc);
}

void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
  // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Constant *I32Null = ConstantInt::getNullValue(Int32);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};

  Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskyield),
                     Args);
}

void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
  if (!updateToLocation(Loc))
    return;
  emitTaskyieldImpl(Loc);
}
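
// For reference, a sketch of the runtime calls the helpers above emit at the
// current insertion point (value names are illustrative):
//
// \code
//   %gtid = call i32 @__kmpc_global_thread_num(ptr @loc)
//   %ret  = call i32 @__kmpc_omp_taskwait(ptr @loc, i32 %gtid)
//   %ret2 = call i32 @__kmpc_omp_taskyield(ptr @loc, i32 %gtid, i32 0)
// \endcode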

// Processes the dependencies in Dependencies and does the following
// - Allocates stack space for an array of DependInfo objects
// - Populates each DependInfo object with relevant information about
//   the corresponding dependence.
// - All code is inserted in the entry block of the current function.
static Value *emitTaskDependencies(
    OpenMPIRBuilder &OMPBuilder,
    SmallVectorImpl<OpenMPIRBuilder::DependData> &Dependencies) {
  // Early return if we have no dependencies to process
  if (Dependencies.empty())
    return nullptr;

  // Given a vector of DependData objects, in this function we create an
  // array on the stack that holds kmp_depend_info objects corresponding
  // to each dependency. This is then passed to the OpenMP runtime.
  // For example, if there are 'n' dependencies then the following pseudo
  // code is generated. Assume the first dependence is on a variable 'a'
  //
  // \code{c}
  // DepArray = alloc(n x sizeof(kmp_depend_info));
  // idx = 0;
  // DepArray[idx].base_addr = ptrtoint(&a);
  // DepArray[idx].len = 8;
  // DepArray[idx].flags = Dep.DepKind; /*(See OMPConstants.h for DepKind)*/
  // ++idx;
  // DepArray[idx].base_addr = ...;
  // \endcode

  IRBuilderBase &Builder = OMPBuilder.Builder;
  Type *DependInfo = OMPBuilder.DependInfo;
  Module &M = OMPBuilder.M;

  Value *DepArray = nullptr;
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  Builder.SetInsertPoint(
      OldIP.getBlock()->getParent()->getEntryBlock().getTerminator());

  Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
  DepArray = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");

  for (const auto &[DepIdx, Dep] : enumerate(Dependencies)) {
    Value *Base =
        Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, DepIdx);
    // Store the pointer to the variable
    Value *Addr = Builder.CreateStructGEP(
        DependInfo, Base,
        static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
    Value *DepValPtr = Builder.CreatePtrToInt(Dep.DepVal, Builder.getInt64Ty());
    Builder.CreateStore(DepValPtr, Addr);
    // Store the size of the variable
    Value *Size = Builder.CreateStructGEP(
        DependInfo, Base, static_cast<unsigned int>(RTLDependInfoFields::Len));
    Builder.CreateStore(
        Builder.getInt64(M.getDataLayout().getTypeStoreSize(Dep.DepValueType)),
        Size);
    // Store the dependency kind
    Value *Flags = Builder.CreateStructGEP(
        DependInfo, Base,
        static_cast<unsigned int>(RTLDependInfoFields::Flags));
    Builder.CreateStore(
        ConstantInt::get(Builder.getInt8Ty(),
                         static_cast<unsigned int>(Dep.DepKind)),
        Flags);
  }
  Builder.restoreIP(OldIP);
  return DepArray;
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTask(const LocationDescription &Loc,
                            InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
                            bool Tied, Value *Final, Value *IfCondition,
                            SmallVector<DependData> Dependencies) {

  if (!updateToLocation(Loc))
    return InsertPointTy();

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  // The current basic block is split into four basic blocks. After outlining,
  // they will be mapped as follows:
  // ```
  // def current_fn() {
  //   current_basic_block:
  //     br label %task.exit
  //   task.exit:
  //     ; instructions after task
  // }
  // def outlined_fn() {
  //   task.alloca:
  //     br label %task.body
  //   task.body:
  //     ret void
  // }
  // ```
  BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, "task.exit");
  BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, "task.body");
  BasicBlock *TaskAllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, "task.alloca");

  InsertPointTy TaskAllocaIP =
      InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
  InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
  BodyGenCB(TaskAllocaIP, TaskBodyIP);

  OutlineInfo OI;
  OI.EntryBB = TaskAllocaBB;
  OI.OuterAllocaBB = AllocaIP.getBlock();
  OI.ExitBB = TaskExitBB;

  // Add the thread ID argument.
  SmallVector<Instruction *, 4> ToBeDeleted;
  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
      Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));

  OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
                      TaskAllocaBB, ToBeDeleted](Function &OutlinedFn) mutable {
    // Replace the stale CI with an appropriate RTL function call.
    assert(OutlinedFn.getNumUses() == 1 &&
           "there must be a single user for the outlined function");
    CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());

    // HasShareds is true if any variables are captured in the outlined region,
    // false otherwise.
    bool HasShareds = StaleCI->arg_size() > 1;
    Builder.SetInsertPoint(StaleCI);

    // Gather the arguments for emitting the runtime call for
    // @__kmpc_omp_task_alloc
    Function *TaskAllocFn =
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);

    // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID) for the runtime
    // call.
    Value *ThreadID = getOrCreateThreadID(Ident);

    // Argument - `flags`
    // Task is tied iff (Flags & 1) == 1.
    // Task is untied iff (Flags & 1) == 0.
    // Task is final iff (Flags & 2) == 2.
    // Task is not final iff (Flags & 2) == 0.
    // TODO: Handle the other flags.
    Value *Flags = Builder.getInt32(Tied);
    if (Final) {
      Value *FinalFlag =
          Builder.CreateSelect(Final, Builder.getInt32(2), Builder.getInt32(0));
      Flags = Builder.CreateOr(FinalFlag, Flags);
    }

    // Argument - `sizeof_kmp_task_t` (TaskSize)
    // TaskSize refers to the size in bytes of the kmp_task_t data structure,
    // including private vars accessed in the task.
    // TODO: add kmp_task_t_with_privates (privates)
    Value *TaskSize = Builder.getInt64(
        divideCeil(M.getDataLayout().getTypeSizeInBits(Task), 8));

    // Argument - `sizeof_shareds` (SharedsSize)
    // SharedsSize refers to the shareds array size in the kmp_task_t data
    // structure.
    Value *SharedsSize = Builder.getInt64(0);
    if (HasShareds) {
      AllocaInst *ArgStructAlloca =
          dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
      assert(ArgStructAlloca &&
             "Unable to find the alloca instruction corresponding to arguments "
             "for extracted function");
      StructType *ArgStructType =
          dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
      assert(ArgStructType && "Unable to find struct type corresponding to "
                              "arguments for extracted function");
      SharedsSize =
          Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
    }
    // Emit the @__kmpc_omp_task_alloc runtime call
    // The runtime call returns a pointer to an area where the task captured
    // variables must be copied before the task is run (TaskData)
    CallInst *TaskData = Builder.CreateCall(
        TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
                      /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
                      /*task_func=*/&OutlinedFn});

    // Copy the arguments for the outlined function
    if (HasShareds) {
      Value *Shareds = StaleCI->getArgOperand(1);
      Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
      Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
      Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
                           SharedsSize);
    }

    Value *DepArray = nullptr;
    if (Dependencies.size()) {
      InsertPointTy OldIP = Builder.saveIP();
      Builder.SetInsertPoint(
          &OldIP.getBlock()->getParent()->getEntryBlock().back());

      Type *DepArrayTy = ArrayType::get(DependInfo, Dependencies.size());
      DepArray = Builder.CreateAlloca(DepArrayTy, nullptr, ".dep.arr.addr");

      unsigned P = 0;
      for (const DependData &Dep : Dependencies) {
        Value *Base =
            Builder.CreateConstInBoundsGEP2_64(DepArrayTy, DepArray, 0, P);
        // Store the pointer to the variable
        Value *Addr = Builder.CreateStructGEP(
            DependInfo, Base,
            static_cast<unsigned int>(RTLDependInfoFields::BaseAddr));
        Value *DepValPtr =
            Builder.CreatePtrToInt(Dep.DepVal, Builder.getInt64Ty());
        Builder.CreateStore(DepValPtr, Addr);
        // Store the size of the variable
        Value *Size = Builder.CreateStructGEP(
            DependInfo, Base,
            static_cast<unsigned int>(RTLDependInfoFields::Len));
        Builder.CreateStore(Builder.getInt64(M.getDataLayout().getTypeStoreSize(
                                Dep.DepValueType)),
                            Size);
        // Store the dependency kind
        Value *Flags = Builder.CreateStructGEP(
            DependInfo, Base,
            static_cast<unsigned int>(RTLDependInfoFields::Flags));
        Builder.CreateStore(
            ConstantInt::get(Builder.getInt8Ty(),
                             static_cast<unsigned int>(Dep.DepKind)),
            Flags);
        ++P;
      }

      Builder.restoreIP(OldIP);
    }

    // In the presence of the `if` clause, the following IR is generated:
    //    ...
    //    %data = call @__kmpc_omp_task_alloc(...)
    //    br i1 %if_condition, label %then, label %else
    //  then:
    //    call @__kmpc_omp_task(...)
    //    br label %exit
    //  else:
    //    ;; Wait for resolution of dependencies, if any, before
    //    ;; beginning the task
    //    call @__kmpc_omp_wait_deps(...)
    //    call @__kmpc_omp_task_begin_if0(...)
    //    call @outlined_fn(...)
    //    call @__kmpc_omp_task_complete_if0(...)
    //    br label %exit
    //  exit:
    //    ...
    if (IfCondition) {
      // `SplitBlockAndInsertIfThenElse` requires the block to have a
      // terminator.
      splitBB(Builder, /*CreateBranch=*/true, "if.end");
      Instruction *IfTerminator =
          Builder.GetInsertPoint()->getParent()->getTerminator();
      Instruction *ThenTI = IfTerminator, *ElseTI = nullptr;
      Builder.SetInsertPoint(IfTerminator);
      SplitBlockAndInsertIfThenElse(IfCondition, IfTerminator, &ThenTI,
                                    &ElseTI);
      Builder.SetInsertPoint(ElseTI);

      if (Dependencies.size()) {
        Function *TaskWaitFn =
            getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
        Builder.CreateCall(
            TaskWaitFn,
            {Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepArray,
             ConstantInt::get(Builder.getInt32Ty(), 0),
             ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
      }
      Function *TaskBeginFn =
          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
      Function *TaskCompleteFn =
          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
      Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
      CallInst *CI = nullptr;
      if (HasShareds)
        CI = Builder.CreateCall(&OutlinedFn, {ThreadID, TaskData});
      else
        CI = Builder.CreateCall(&OutlinedFn, {ThreadID});
      CI->setDebugLoc(StaleCI->getDebugLoc());
      Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
      Builder.SetInsertPoint(ThenTI);
    }

    if (Dependencies.size()) {
      Function *TaskFn =
          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
      Builder.CreateCall(
          TaskFn,
          {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
           DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
           ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});

    } else {
      // Emit the @__kmpc_omp_task runtime call to spawn the task
      Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
      Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
    }

    StaleCI->eraseFromParent();

    Builder.SetInsertPoint(TaskAllocaBB, TaskAllocaBB->begin());
    if (HasShareds) {
      LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1));
      OutlinedFn.getArg(1)->replaceUsesWithIf(
          Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; });
    }

    llvm::for_each(llvm::reverse(ToBeDeleted),
                   [](Instruction *I) { I->eraseFromParent(); });
  };

  addOutlineInfo(std::move(OI));
  Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());

  return Builder.saveIP();
}
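
// Illustrative usage sketch (assumed names; check OMPIRBuilder.h for the
// exact DependData interface in your tree): emitting a tied task with a
// single 'in' dependence on an i32 variable addressed by APtr might look
// like:
//
// \code{cpp}
//   OpenMPIRBuilder::DependData Dep(RTLDependenceKindTy::DepIn,
//                                   Builder.getInt32Ty(), APtr);
//   Builder.restoreIP(OMPBuilder.createTask(
//       Loc, AllocaIP, BodyGenCB, /*Tied=*/true, /*Final=*/nullptr,
//       /*IfCondition=*/nullptr, {Dep}));
// \endcode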

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
                                 InsertPointTy AllocaIP,
                                 BodyGenCallbackTy BodyGenCB) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadID = getOrCreateThreadID(Ident);

  // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
  Function *TaskgroupFn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup);
  Builder.CreateCall(TaskgroupFn, {Ident, ThreadID});

  BasicBlock *TaskgroupExitBB = splitBB(Builder, true, "taskgroup.exit");
  BodyGenCB(AllocaIP, Builder.saveIP());

  Builder.SetInsertPoint(TaskgroupExitBB);
  // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
  Function *EndTaskgroupFn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup);
  Builder.CreateCall(EndTaskgroupFn, {Ident, ThreadID});

  return Builder.saveIP();
}
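
// The net effect, sketched as IR (names illustrative):
//
// \code
//   call void @__kmpc_taskgroup(ptr @loc, i32 %gtid)
//   ; ... body emitted by BodyGenCB, typically containing createTask calls ...
//   call void @__kmpc_end_taskgroup(ptr @loc, i32 %gtid)
// \endcode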

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSections(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
    FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
  assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");

  if (!updateToLocation(Loc))
    return Loc.IP;

  auto FiniCBWrapper = [&](InsertPointTy IP) {
    if (IP.getBlock()->end() != IP.getPoint())
      return FiniCB(IP);
    // This must be done otherwise any nested constructs using FinalizeOMPRegion
    // will fail because that function requires the Finalization Basic Block to
    // have a terminator, which is already removed by EmitOMPRegionBody.
    // IP is currently at the cancellation block.
    // We need to backtrack to the condition block to fetch
    // the exit block and create a branch from the cancellation
    // block to the exit block.
    IRBuilder<>::InsertPointGuard IPG(Builder);
    Builder.restoreIP(IP);
    auto *CaseBB = IP.getBlock()->getSinglePredecessor();
    auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
    auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
    Instruction *I = Builder.CreateBr(ExitBB);
    IP = InsertPointTy(I->getParent(), I->getIterator());
    return FiniCB(IP);
  };

  FinalizationStack.push_back({FiniCBWrapper, OMPD_sections, IsCancellable});

  // Each section is emitted as a switch case.
  // Each finalization callback is handled from clang.EmitOMPSectionDirective()
  // -> OMP.createSection() which generates the IR for each section.
  // Iterate through all sections and emit a switch construct:
  // switch (IV) {
  //   case 0:
  //     <SectionStmt[0]>;
  //     break;
  // ...
  //   case <NumSection> - 1:
  //     <SectionStmt[<NumSection> - 1]>;
  //     break;
  // }
  // ...
  // section_loop.after:
  // <FiniCB>;
  auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) {
    Builder.restoreIP(CodeGenIP);
    BasicBlock *Continue =
        splitBBWithSuffix(Builder, /*CreateBranch=*/false, ".sections.after");
    Function *CurFn = Continue->getParent();
    SwitchInst *SwitchStmt = Builder.CreateSwitch(IndVar, Continue);

    unsigned CaseNumber = 0;
    for (auto SectionCB : SectionCBs) {
      BasicBlock *CaseBB = BasicBlock::Create(
          M.getContext(), "omp_section_loop.body.case", CurFn, Continue);
      SwitchStmt->addCase(Builder.getInt32(CaseNumber), CaseBB);
      Builder.SetInsertPoint(CaseBB);
      BranchInst *CaseEndBr = Builder.CreateBr(Continue);
      SectionCB(InsertPointTy(),
                {CaseEndBr->getParent(), CaseEndBr->getIterator()});
      CaseNumber++;
    }
    // remove the existing terminator from body BB since there can be no
    // terminators after switch/case
  };
  // Loop body ends here
  // LowerBound, UpperBound, and Stride for createCanonicalLoop
  Type *I32Ty = Type::getInt32Ty(M.getContext());
  Value *LB = ConstantInt::get(I32Ty, 0);
  Value *UB = ConstantInt::get(I32Ty, SectionCBs.size());
  Value *ST = ConstantInt::get(I32Ty, 1);
  llvm::CanonicalLoopInfo *LoopInfo = createCanonicalLoop(
      Loc, LoopBodyGenCB, LB, UB, ST, true, false, AllocaIP, "section_loop");
  InsertPointTy AfterIP =
      applyStaticWorkshareLoop(Loc.DL, LoopInfo, AllocaIP, !IsNowait);

  // Apply the finalization callback in LoopAfterBB
  auto FiniInfo = FinalizationStack.pop_back_val();
  assert(FiniInfo.DK == OMPD_sections &&
         "Unexpected finalization stack state!");
  if (FinalizeCallbackTy &CB = FiniInfo.FiniCB) {
    Builder.restoreIP(AfterIP);
    BasicBlock *FiniBB =
        splitBBWithSuffix(Builder, /*CreateBranch=*/true, "sections.fini");
    CB(Builder.saveIP());
    AfterIP = {FiniBB, FiniBB->begin()};
  }

  return AfterIP;
}
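
// Illustrative usage sketch (callbacks are assumptions): lowering a
// `#pragma omp sections` construct with two section bodies hands two
// callbacks to createSections; the emitted canonical loop then runs the
// induction variable over [0, 2) and dispatches through the switch above:
//
// \code{cpp}
//   SmallVector<OpenMPIRBuilder::StorableBodyGenCallbackTy, 2> CBs = {Sec0CB,
//                                                                     Sec1CB};
//   Builder.restoreIP(OMPBuilder.createSections(
//       Loc, AllocaIP, CBs, PrivCB, FiniCB,
//       /*IsCancellable=*/false, /*IsNowait=*/false));
// \endcode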

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createSection(const LocationDescription &Loc,
                               BodyGenCallbackTy BodyGenCB,
                               FinalizeCallbackTy FiniCB) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  auto FiniCBWrapper = [&](InsertPointTy IP) {
    if (IP.getBlock()->end() != IP.getPoint())
      return FiniCB(IP);
    // This must be done otherwise any nested constructs using FinalizeOMPRegion
    // will fail because that function requires the Finalization Basic Block to
    // have a terminator, which is already removed by EmitOMPRegionBody.
    // IP is currently at the cancellation block.
    // We need to backtrack to the condition block to fetch
    // the exit block and create a branch from the cancellation
    // block to the exit block.
    IRBuilder<>::InsertPointGuard IPG(Builder);
    Builder.restoreIP(IP);
    auto *CaseBB = Loc.IP.getBlock();
    auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
    auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
    Instruction *I = Builder.CreateBr(ExitBB);
    IP = InsertPointTy(I->getParent(), I->getIterator());
    return FiniCB(IP);
  };

  Directive OMPD = Directive::OMPD_sections;
  // Since we are using the Finalization Callback here, HasFinalize
  // and IsCancellable have to be true
  return EmitOMPInlinedRegion(OMPD, nullptr, nullptr, BodyGenCB, FiniCBWrapper,
                              /*Conditional*/ false, /*hasFinalize*/ true,
                              /*IsCancellable*/ true);
}

static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
  BasicBlock::iterator IT(I);
  IT++;
  return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
}

void OpenMPIRBuilder::emitUsed(StringRef Name,
                               std::vector<WeakTrackingVH> &List) {
  if (List.empty())
    return;

  // Convert List to what ConstantArray needs.
  SmallVector<Constant *, 8> UsedArray;
  UsedArray.resize(List.size());
  for (unsigned I = 0, E = List.size(); I != E; ++I)
    UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
        cast<Constant>(&*List[I]), Builder.getPtrTy());

  if (UsedArray.empty())
    return;
  ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());

  auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
                                ConstantArray::get(ATy, UsedArray), Name);

  GV->setSection("llvm.metadata");
}

Value *OpenMPIRBuilder::getGPUThreadID() {
  return Builder.CreateCall(
      getOrCreateRuntimeFunction(M,
                                 OMPRTL___kmpc_get_hardware_thread_id_in_block),
      {});
}

Value *OpenMPIRBuilder::getGPUWarpSize() {
  return Builder.CreateCall(
      getOrCreateRuntimeFunction(M, OMPRTL___kmpc_get_warp_size), {});
}

Value *OpenMPIRBuilder::getNVPTXWarpID() {
  unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
  return Builder.CreateAShr(getGPUThreadID(), LaneIDBits, "nvptx_warp_id");
}

Value *OpenMPIRBuilder::getNVPTXLaneID() {
  unsigned LaneIDBits = Log2_32(Config.getGridValue().GV_Warp_Size);
  assert(LaneIDBits < 32 && "Invalid LaneIDBits size in NVPTX device.");
  unsigned LaneIDMask = ~0u >> (32u - LaneIDBits);
  return Builder.CreateAnd(getGPUThreadID(), Builder.getInt32(LaneIDMask),
                           "nvptx_lane_id");
}
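
// Worked example: with the common warp size of 32, LaneIDBits is 5 and
// LaneIDMask is 0x1f, so hardware thread 37 maps to warp_id 37 >> 5 == 1 and
// lane_id 37 & 0x1f == 5.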

Value *OpenMPIRBuilder::castValueToType(InsertPointTy AllocaIP, Value *From,
                                        Type *ToType) {
  Type *FromType = From->getType();
  uint64_t FromSize = M.getDataLayout().getTypeStoreSize(FromType);
  uint64_t ToSize = M.getDataLayout().getTypeStoreSize(ToType);
  assert(FromSize > 0 && "From size must be greater than zero");
  assert(ToSize > 0 && "To size must be greater than zero");
  if (FromType == ToType)
    return From;
  if (FromSize == ToSize)
    return Builder.CreateBitCast(From, ToType);
  if (ToType->isIntegerTy() && FromType->isIntegerTy())
    return Builder.CreateIntCast(From, ToType, /*isSigned*/ true);
  InsertPointTy SaveIP = Builder.saveIP();
  Builder.restoreIP(AllocaIP);
  Value *CastItem = Builder.CreateAlloca(ToType);
  Builder.restoreIP(SaveIP);

  Value *ValCastItem = Builder.CreatePointerBitCastOrAddrSpaceCast(
      CastItem, FromType->getPointerTo());
  Builder.CreateStore(From, ValCastItem);
  return Builder.CreateLoad(ToType, CastItem);
}
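
// For example, an i32 -> float cast of equal store size takes the bitcast
// path, an i16 -> i64 widening takes the signed integer cast path, and only
// mixed-kind casts of differing sizes fall through to the stack temporary.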

Value *OpenMPIRBuilder::createRuntimeShuffleFunction(InsertPointTy AllocaIP,
                                                     Value *Element,
                                                     Type *ElementType,
                                                     Value *Offset) {
  uint64_t Size = M.getDataLayout().getTypeStoreSize(ElementType);
  assert(Size <= 8 && "Unsupported bitwidth in shuffle instruction");

  // Cast all types to 32- or 64-bit values before calling shuffle routines.
  Type *CastTy = Builder.getIntNTy(Size <= 4 ? 32 : 64);
  Value *ElemCast = castValueToType(AllocaIP, Element, CastTy);
  Value *WarpSize =
      Builder.CreateIntCast(getGPUWarpSize(), Builder.getInt16Ty(), true);
  Function *ShuffleFunc = getOrCreateRuntimeFunctionPtr(
      Size <= 4 ? RuntimeFunction::OMPRTL___kmpc_shuffle_int32
                : RuntimeFunction::OMPRTL___kmpc_shuffle_int64);
  Value *WarpSizeCast =
      Builder.CreateIntCast(WarpSize, Builder.getInt16Ty(), /*isSigned=*/true);
  Value *ShuffleCall =
      Builder.CreateCall(ShuffleFunc, {ElemCast, Offset, WarpSizeCast});
  return castValueToType(AllocaIP, ShuffleCall, CastTy);
}
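
// E.g. a 2-byte element is first widened to an i32 via castValueToType, moved
// across lanes with __kmpc_shuffle_int32(elem, offset, warp_size), and
// returned as an i32; callers such as shuffleAndStore truncate it back to the
// element type.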

void OpenMPIRBuilder::shuffleAndStore(InsertPointTy AllocaIP, Value *SrcAddr,
                                      Value *DstAddr, Type *ElemType,
                                      Value *Offset, Type *ReductionArrayTy) {
  uint64_t Size = M.getDataLayout().getTypeStoreSize(ElemType);
  // Create the loop over the big sized data.
  // ptr = (void*)Elem;
  // ptrEnd = (void*) Elem + 1;
  // Step = 8;
  // while (ptr + Step < ptrEnd)
  //   shuffle((int64_t)*ptr);
  // Step = 4;
  // while (ptr + Step < ptrEnd)
  //   shuffle((int32_t)*ptr);
  // ...
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  Value *ElemPtr = DstAddr;
  Value *Ptr = SrcAddr;
  for (unsigned IntSize = 8; IntSize >= 1; IntSize /= 2) {
    if (Size < IntSize)
      continue;
    Type *IntType = Builder.getIntNTy(IntSize * 8);
    Ptr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        Ptr, IntType->getPointerTo(), Ptr->getName() + ".ascast");
    Value *SrcAddrGEP =
        Builder.CreateGEP(ElemType, SrcAddr, {ConstantInt::get(IndexTy, 1)});
    ElemPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        ElemPtr, IntType->getPointerTo(), ElemPtr->getName() + ".ascast");

    Function *CurFunc = Builder.GetInsertBlock()->getParent();
    if ((Size / IntSize) > 1) {
      Value *PtrEnd = Builder.CreatePointerBitCastOrAddrSpaceCast(
          SrcAddrGEP, Builder.getPtrTy());
      BasicBlock *PreCondBB =
          BasicBlock::Create(M.getContext(), ".shuffle.pre_cond");
      BasicBlock *ThenBB = BasicBlock::Create(M.getContext(), ".shuffle.then");
      BasicBlock *ExitBB = BasicBlock::Create(M.getContext(), ".shuffle.exit");
      BasicBlock *CurrentBB = Builder.GetInsertBlock();
      emitBlock(PreCondBB, CurFunc);
      PHINode *PhiSrc =
          Builder.CreatePHI(Ptr->getType(), /*NumReservedValues=*/2);
      PhiSrc->addIncoming(Ptr, CurrentBB);
      PHINode *PhiDest =
          Builder.CreatePHI(ElemPtr->getType(), /*NumReservedValues=*/2);
      PhiDest->addIncoming(ElemPtr, CurrentBB);
      Ptr = PhiSrc;
      ElemPtr = PhiDest;
      Value *PtrDiff = Builder.CreatePtrDiff(
          Builder.getInt8Ty(), PtrEnd,
          Builder.CreatePointerBitCastOrAddrSpaceCast(Ptr, Builder.getPtrTy()));
      Builder.CreateCondBr(
          Builder.CreateICmpSGT(PtrDiff, Builder.getInt64(IntSize - 1)), ThenBB,
          ExitBB);
      emitBlock(ThenBB, CurFunc);
      Value *Res = createRuntimeShuffleFunction(
          AllocaIP,
          Builder.CreateAlignedLoad(
              IntType, Ptr, M.getDataLayout().getPrefTypeAlign(ElemType)),
          IntType, Offset);
      Builder.CreateAlignedStore(Res, ElemPtr,
                                 M.getDataLayout().getPrefTypeAlign(ElemType));
      Value *LocalPtr =
          Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
      Value *LocalElemPtr =
          Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
      PhiSrc->addIncoming(LocalPtr, ThenBB);
      PhiDest->addIncoming(LocalElemPtr, ThenBB);
      emitBranch(PreCondBB);
      emitBlock(ExitBB, CurFunc);
    } else {
      Value *Res = createRuntimeShuffleFunction(
          AllocaIP, Builder.CreateLoad(IntType, Ptr), IntType, Offset);
      if (ElemType->isIntegerTy() && ElemType->getScalarSizeInBits() <
                                         Res->getType()->getScalarSizeInBits())
        Res = Builder.CreateTrunc(Res, ElemType);
      Builder.CreateStore(Res, ElemPtr);
      Ptr = Builder.CreateGEP(IntType, Ptr, {ConstantInt::get(IndexTy, 1)});
      ElemPtr =
          Builder.CreateGEP(IntType, ElemPtr, {ConstantInt::get(IndexTy, 1)});
    }
    Size = Size % IntSize;
  }
}
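
// Worked example: for a 12-byte element the loop above performs one i64
// shuffle (IntSize == 8, leaving 12 % 8 == 4 bytes) followed by one i32
// shuffle (IntSize == 4, leaving 0 bytes); the 2- and 1-byte rounds are then
// skipped.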

void OpenMPIRBuilder::emitReductionListCopy(
    InsertPointTy AllocaIP, CopyAction Action, Type *ReductionArrayTy,
    ArrayRef<ReductionInfo> ReductionInfos, Value *SrcBase, Value *DestBase,
    CopyOptionsTy CopyOptions) {
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  Value *RemoteLaneOffset = CopyOptions.RemoteLaneOffset;

  // Iterate, element by element, through the source Reduce list and
  // make a copy.
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *SrcElementAddr = nullptr;
    Value *DestElementAddr = nullptr;
    Value *DestElementPtrAddr = nullptr;
    // Should we shuffle in an element from a remote lane?
    bool ShuffleInElement = false;
    // Set to true to update the pointer in the dest Reduce list to a
    // newly created element.
    bool UpdateDestListPtr = false;

    // Step 1.1: Get the address for the src element in the Reduce list.
    Value *SrcElementPtrAddr = Builder.CreateInBoundsGEP(
        ReductionArrayTy, SrcBase,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    SrcElementAddr = Builder.CreateLoad(Builder.getPtrTy(), SrcElementPtrAddr);

    // Step 1.2: Create a temporary to store the element in the destination
    // Reduce list.
    DestElementPtrAddr = Builder.CreateInBoundsGEP(
        ReductionArrayTy, DestBase,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    switch (Action) {
    case CopyAction::RemoteLaneToThread: {
      InsertPointTy CurIP = Builder.saveIP();
      Builder.restoreIP(AllocaIP);
      AllocaInst *DestAlloca = Builder.CreateAlloca(RI.ElementType, nullptr,
                                                    ".omp.reduction.element");
      DestAlloca->setAlignment(
          M.getDataLayout().getPrefTypeAlign(RI.ElementType));
      DestElementAddr = DestAlloca;
      DestElementAddr =
          Builder.CreateAddrSpaceCast(DestElementAddr, Builder.getPtrTy(),
                                      DestElementAddr->getName() + ".ascast");
      Builder.restoreIP(CurIP);
      ShuffleInElement = true;
      UpdateDestListPtr = true;
      break;
    }
    case CopyAction::ThreadCopy: {
      DestElementAddr =
          Builder.CreateLoad(Builder.getPtrTy(), DestElementPtrAddr);
      break;
    }
    }

    // Now that all active lanes have read the element in the
    // Reduce list, shuffle over the value from the remote lane.
    if (ShuffleInElement) {
      shuffleAndStore(AllocaIP, SrcElementAddr, DestElementAddr, RI.ElementType,
                      RemoteLaneOffset, ReductionArrayTy);
    } else {
      switch (RI.EvaluationKind) {
      case EvalKind::Scalar: {
        Value *Elem = Builder.CreateLoad(RI.ElementType, SrcElementAddr);
        // Store the source element value to the dest element address.
        Builder.CreateStore(Elem, DestElementAddr);
        break;
      }
      case EvalKind::Complex: {
        Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
            RI.ElementType, SrcElementAddr, 0, 0, ".realp");
        Value *SrcReal = Builder.CreateLoad(
            RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
        Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
            RI.ElementType, SrcElementAddr, 0, 1, ".imagp");
        Value *SrcImg = Builder.CreateLoad(
            RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");

        Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
            RI.ElementType, DestElementAddr, 0, 0, ".realp");
        Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
            RI.ElementType, DestElementAddr, 0, 1, ".imagp");
        Builder.CreateStore(SrcReal, DestRealPtr);
        Builder.CreateStore(SrcImg, DestImgPtr);
        break;
      }
      case EvalKind::Aggregate: {
        Value *SizeVal = Builder.getInt64(
            M.getDataLayout().getTypeStoreSize(RI.ElementType));
        Builder.CreateMemCpy(
            DestElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
            SrcElementAddr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
            SizeVal, false);
        break;
      }
      };
    }

    // Step 3.1: Modify reference in dest Reduce list as needed.
    // Modifying the reference in Reduce list to point to the newly
    // created element. The element is live in the current function
    // scope and that of functions it invokes (i.e., reduce_function).
    // RemoteReduceData[i] = (void*)&RemoteElem
    if (UpdateDestListPtr) {
      Value *CastDestAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
          DestElementAddr, Builder.getPtrTy(),
          DestElementAddr->getName() + ".ascast");
      Builder.CreateStore(CastDestAddr, DestElementPtrAddr);
    }
  }
}

Function *OpenMPIRBuilder::emitInterWarpCopyFunction(
    const LocationDescription &Loc, ArrayRef<ReductionInfo> ReductionInfos,
    AttributeList FuncAttrs) {
  InsertPointTy SavedIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy = FunctionType::get(
      Builder.getVoidTy(), {Builder.getPtrTy(), Builder.getInt32Ty()},
      /* IsVarArg */ false);
  Function *WcFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                       "_omp_reduction_inter_warp_copy_func", &M);
  WcFunc->setAttributes(FuncAttrs);
  WcFunc->addParamAttr(0, Attribute::NoUndef);
  WcFunc->addParamAttr(1, Attribute::NoUndef);
  BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", WcFunc);
  Builder.SetInsertPoint(EntryBB);

  // ReduceList: thread local Reduce list.
  // At the stage of the computation when this function is called, partially
  // aggregated values reside in the first lane of every active warp.
  Argument *ReduceListArg = WcFunc->getArg(0);
  // NumWarps: number of warps active in the parallel region. This could
  // be smaller than 32 (max warps in a CTA) for partial block reduction.
  Argument *NumWarpsArg = WcFunc->getArg(1);

  // This array is used as a medium to transfer, one reduce element at a time,
  // the data from the first lane of every warp to lanes in the first warp
  // in order to perform the final step of a reduction in a parallel region
  // (reduction across warps). The array is placed in NVPTX __shared__ memory
  // for reduced latency, as well as to have a distinct copy for concurrently
  // executing target regions. The array is declared with common linkage so
  // as to be shared across compilation units.
  StringRef TransferMediumName =
      "__openmp_nvptx_data_transfer_temporary_storage";
  GlobalVariable *TransferMedium = M.getGlobalVariable(TransferMediumName);
  unsigned WarpSize = Config.getGridValue().GV_Warp_Size;
  ArrayType *ArrayTy = ArrayType::get(Builder.getInt32Ty(), WarpSize);
  if (!TransferMedium) {
    TransferMedium = new GlobalVariable(
        M, ArrayTy, /*isConstant=*/false, GlobalVariable::WeakAnyLinkage,
        UndefValue::get(ArrayTy), TransferMediumName,
        /*InsertBefore=*/nullptr, GlobalVariable::NotThreadLocal,
        /*AddressSpace=*/3);
  }

  // Get the CUDA thread id of the current OpenMP thread on the GPU.
  Value *GPUThreadID = getGPUThreadID();
  // nvptx_lane_id = nvptx_id % warpsize
  Value *LaneID = getNVPTXLaneID();
  // nvptx_warp_id = nvptx_id / warpsize
  Value *WarpID = getNVPTXWarpID();

  InsertPointTy AllocaIP =
      InsertPointTy(Builder.GetInsertBlock(),
                    Builder.GetInsertBlock()->getFirstInsertionPt());
  Type *Arg0Type = ReduceListArg->getType();
  Type *Arg1Type = NumWarpsArg->getType();
  Builder.restoreIP(AllocaIP);
  AllocaInst *ReduceListAlloca = Builder.CreateAlloca(
      Arg0Type, nullptr, ReduceListArg->getName() + ".addr");
  AllocaInst *NumWarpsAlloca =
      Builder.CreateAlloca(Arg1Type, nullptr, NumWarpsArg->getName() + ".addr");
  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListAlloca, Arg0Type, ReduceListAlloca->getName() + ".ascast");
  Value *NumWarpsAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      NumWarpsAlloca, Arg1Type->getPointerTo(),
      NumWarpsAlloca->getName() + ".ascast");
  Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
  Builder.CreateStore(NumWarpsArg, NumWarpsAddrCast);
  AllocaIP = getInsertPointAfterInstr(NumWarpsAlloca);
  InsertPointTy CodeGenIP =
      getInsertPointAfterInstr(&Builder.GetInsertBlock()->back());
  Builder.restoreIP(CodeGenIP);

  Value *ReduceList =
      Builder.CreateLoad(Builder.getPtrTy(), ReduceListAddrCast);

  for (auto En : enumerate(ReductionInfos)) {
    //
    // Warp master copies reduce element to transfer medium in __shared__
    // memory.
    //
    const ReductionInfo &RI = En.value();
    unsigned RealTySize = M.getDataLayout().getTypeAllocSize(RI.ElementType);
    for (unsigned TySize = 4; TySize > 0 && RealTySize > 0; TySize /= 2) {
      Type *CType = Builder.getIntNTy(TySize * 8);

      unsigned NumIters = RealTySize / TySize;
      if (NumIters == 0)
        continue;
      Value *Cnt = nullptr;
      Value *CntAddr = nullptr;
      BasicBlock *PrecondBB = nullptr;
      BasicBlock *ExitBB = nullptr;
      if (NumIters > 1) {
        CodeGenIP = Builder.saveIP();
        Builder.restoreIP(AllocaIP);
        CntAddr =
            Builder.CreateAlloca(Builder.getInt32Ty(), nullptr, ".cnt.addr");

        CntAddr = Builder.CreateAddrSpaceCast(CntAddr, Builder.getPtrTy(),
                                              CntAddr->getName() + ".ascast");
        Builder.restoreIP(CodeGenIP);
        Builder.CreateStore(Constant::getNullValue(Builder.getInt32Ty()),
                            CntAddr,
                            /*Volatile=*/false);
        PrecondBB = BasicBlock::Create(Ctx, "precond");
        ExitBB = BasicBlock::Create(Ctx, "exit");
        BasicBlock *BodyBB = BasicBlock::Create(Ctx, "body");
        emitBlock(PrecondBB, Builder.GetInsertBlock()->getParent());
        Cnt = Builder.CreateLoad(Builder.getInt32Ty(), CntAddr,
                                 /*Volatile=*/false);
        Value *Cmp = Builder.CreateICmpULT(
            Cnt, ConstantInt::get(Builder.getInt32Ty(), NumIters));
        Builder.CreateCondBr(Cmp, BodyBB, ExitBB);
        emitBlock(BodyBB, Builder.GetInsertBlock()->getParent());
      }

      // kmpc_barrier.
      createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                    omp::Directive::OMPD_unknown,
                    /* ForceSimpleCall */ false,
                    /* CheckCancelFlag */ true);
      BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
      BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
      BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");

      // if (lane_id == 0)
      Value *IsWarpMaster = Builder.CreateIsNull(LaneID, "warp_master");
      Builder.CreateCondBr(IsWarpMaster, ThenBB, ElseBB);
      emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());

      // Reduce element = LocalReduceList[i]
      auto *RedListArrayTy =
          ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
      Type *IndexTy = Builder.getIndexTy(
          M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
      Value *ElemPtrPtr =
          Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
                                    {ConstantInt::get(IndexTy, 0),
                                     ConstantInt::get(IndexTy, En.index())});
      // elemptr = ((CopyType*)(elemptrptr)) + I
      Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
      if (NumIters > 1)
        ElemPtr = Builder.CreateGEP(Builder.getInt32Ty(), ElemPtr, Cnt);

      // Get pointer to location in transfer medium.
      // MediumPtr = &medium[warp_id]
      Value *MediumPtr = Builder.CreateInBoundsGEP(
          ArrayTy, TransferMedium, {Builder.getInt64(0), WarpID});
      // elem = *elemptr
      // *MediumPtr = elem
      Value *Elem = Builder.CreateLoad(CType, ElemPtr);
      // Store the source element value to the dest element address.
      Builder.CreateStore(Elem, MediumPtr,
                          /*IsVolatile*/ true);
      Builder.CreateBr(MergeBB);

      // else
      emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
      Builder.CreateBr(MergeBB);

      // endif
      emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());
      createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                    omp::Directive::OMPD_unknown,
                    /* ForceSimpleCall */ false,
                    /* CheckCancelFlag */ true);

      // Warp 0 copies reduce element from transfer medium
      BasicBlock *W0ThenBB = BasicBlock::Create(Ctx, "then");
      BasicBlock *W0ElseBB = BasicBlock::Create(Ctx, "else");
      BasicBlock *W0MergeBB = BasicBlock::Create(Ctx, "ifcont");

      Value *NumWarpsVal =
          Builder.CreateLoad(Builder.getInt32Ty(), NumWarpsAddrCast);
      // Up to 32 threads in warp 0 are active.
      Value *IsActiveThread =
          Builder.CreateICmpULT(GPUThreadID, NumWarpsVal, "is_active_thread");
      Builder.CreateCondBr(IsActiveThread, W0ThenBB, W0ElseBB);

      emitBlock(W0ThenBB, Builder.GetInsertBlock()->getParent());

      // SrcMediumPtr = &medium[tid]
      // SrcMediumVal = *SrcMediumPtr
      Value *SrcMediumPtrVal = Builder.CreateInBoundsGEP(
          ArrayTy, TransferMedium, {Builder.getInt64(0), GPUThreadID});
      // TargetElemPtr = (CopyType*)(SrcDataAddr[i]) + I
      Value *TargetElemPtrPtr =
          Builder.CreateInBoundsGEP(RedListArrayTy, ReduceList,
                                    {ConstantInt::get(IndexTy, 0),
                                     ConstantInt::get(IndexTy, En.index())});
      Value *TargetElemPtrVal =
          Builder.CreateLoad(Builder.getPtrTy(), TargetElemPtrPtr);
      Value *TargetElemPtr = TargetElemPtrVal;
      if (NumIters > 1)
        TargetElemPtr =
            Builder.CreateGEP(Builder.getInt32Ty(), TargetElemPtr, Cnt);

      // *TargetElemPtr = SrcMediumVal;
      Value *SrcMediumValue =
          Builder.CreateLoad(CType, SrcMediumPtrVal, /*IsVolatile*/ true);
      Builder.CreateStore(SrcMediumValue, TargetElemPtr);
      Builder.CreateBr(W0MergeBB);

      emitBlock(W0ElseBB, Builder.GetInsertBlock()->getParent());
      Builder.CreateBr(W0MergeBB);

      emitBlock(W0MergeBB, Builder.GetInsertBlock()->getParent());

      if (NumIters > 1) {
        Cnt = Builder.CreateNSWAdd(
            Cnt, ConstantInt::get(Builder.getInt32Ty(), /*V=*/1));
        Builder.CreateStore(Cnt, CntAddr, /*Volatile=*/false);

        auto *CurFn = Builder.GetInsertBlock()->getParent();
        emitBranch(PrecondBB);
        emitBlock(ExitBB, CurFn);
      }
      RealTySize %= TySize;
    }
  }

  Builder.CreateRetVoid();
  Builder.restoreIP(SavedIP);

  return WcFunc;
}
2711
2712
Function *OpenMPIRBuilder::emitShuffleAndReduceFunction(
2713
ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
2714
AttributeList FuncAttrs) {
2715
LLVMContext &Ctx = M.getContext();
2716
FunctionType *FuncTy =
2717
FunctionType::get(Builder.getVoidTy(),
2718
{Builder.getPtrTy(), Builder.getInt16Ty(),
2719
Builder.getInt16Ty(), Builder.getInt16Ty()},
2720
/* IsVarArg */ false);
2721
Function *SarFunc =
2722
Function::Create(FuncTy, GlobalVariable::InternalLinkage,
2723
"_omp_reduction_shuffle_and_reduce_func", &M);
2724
SarFunc->setAttributes(FuncAttrs);
2725
SarFunc->addParamAttr(0, Attribute::NoUndef);
2726
SarFunc->addParamAttr(1, Attribute::NoUndef);
2727
SarFunc->addParamAttr(2, Attribute::NoUndef);
2728
SarFunc->addParamAttr(3, Attribute::NoUndef);
2729
SarFunc->addParamAttr(1, Attribute::SExt);
2730
SarFunc->addParamAttr(2, Attribute::SExt);
2731
SarFunc->addParamAttr(3, Attribute::SExt);
  BasicBlock *EntryBB = BasicBlock::Create(M.getContext(), "entry", SarFunc);
  Builder.SetInsertPoint(EntryBB);

  // Thread local Reduce list used to host the values of data to be reduced.
  Argument *ReduceListArg = SarFunc->getArg(0);
  // Current lane id; could be logical.
  Argument *LaneIDArg = SarFunc->getArg(1);
  // Offset of the remote source lane relative to the current lane.
  Argument *RemoteLaneOffsetArg = SarFunc->getArg(2);
  // Algorithm version. This is expected to be known at compile time.
  Argument *AlgoVerArg = SarFunc->getArg(3);

  Type *ReduceListArgType = ReduceListArg->getType();
  Type *LaneIDArgType = LaneIDArg->getType();
  Type *LaneIDArgPtrType = LaneIDArg->getType()->getPointerTo();
  Value *ReduceListAlloca = Builder.CreateAlloca(
      ReduceListArgType, nullptr, ReduceListArg->getName() + ".addr");
  Value *LaneIdAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
                                             LaneIDArg->getName() + ".addr");
  Value *RemoteLaneOffsetAlloca = Builder.CreateAlloca(
      LaneIDArgType, nullptr, RemoteLaneOffsetArg->getName() + ".addr");
  Value *AlgoVerAlloca = Builder.CreateAlloca(LaneIDArgType, nullptr,
                                              AlgoVerArg->getName() + ".addr");
  ArrayType *RedListArrayTy =
      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());

  // Create a local thread-private variable to host the Reduce list
  // from a remote lane.
  Instruction *RemoteReductionListAlloca = Builder.CreateAlloca(
      RedListArrayTy, nullptr, ".omp.reduction.remote_reduce_list");

  Value *ReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListAlloca, ReduceListArgType,
      ReduceListAlloca->getName() + ".ascast");
  Value *LaneIdAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      LaneIdAlloca, LaneIDArgPtrType, LaneIdAlloca->getName() + ".ascast");
  Value *RemoteLaneOffsetAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      RemoteLaneOffsetAlloca, LaneIDArgPtrType,
      RemoteLaneOffsetAlloca->getName() + ".ascast");
  Value *AlgoVerAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      AlgoVerAlloca, LaneIDArgPtrType, AlgoVerAlloca->getName() + ".ascast");
  Value *RemoteListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      RemoteReductionListAlloca, Builder.getPtrTy(),
      RemoteReductionListAlloca->getName() + ".ascast");

  Builder.CreateStore(ReduceListArg, ReduceListAddrCast);
  Builder.CreateStore(LaneIDArg, LaneIdAddrCast);
  Builder.CreateStore(RemoteLaneOffsetArg, RemoteLaneOffsetAddrCast);
  Builder.CreateStore(AlgoVerArg, AlgoVerAddrCast);

  Value *ReduceList = Builder.CreateLoad(ReduceListArgType, ReduceListAddrCast);
  Value *LaneId = Builder.CreateLoad(LaneIDArgType, LaneIdAddrCast);
  Value *RemoteLaneOffset =
      Builder.CreateLoad(LaneIDArgType, RemoteLaneOffsetAddrCast);
  Value *AlgoVer = Builder.CreateLoad(LaneIDArgType, AlgoVerAddrCast);

  InsertPointTy AllocaIP = getInsertPointAfterInstr(RemoteReductionListAlloca);

  // This loop iterates through the list of reduce elements and copies,
  // element by element, from a remote lane in the warp to RemoteReduceList,
  // hosted on the thread's stack.
  emitReductionListCopy(
      AllocaIP, CopyAction::RemoteLaneToThread, RedListArrayTy, ReductionInfos,
      ReduceList, RemoteListAddrCast, {RemoteLaneOffset, nullptr, nullptr});

  // The actions to be performed on the remote Reduce list depend on the
  // algorithm version.
  //
  // if (AlgoVer==0) || (AlgoVer==1 && (LaneId < Offset)) || (AlgoVer==2 &&
  // LaneId % 2 == 0 && Offset > 0):
  //   do the reduction value aggregation
  //
  // The thread local variable Reduce list is mutated in place to host the
  // reduced data, which is the aggregated value produced from local and
  // remote lanes.
  //
  // Note that AlgoVer is expected to be a constant integer known at compile
  // time.
  // When AlgoVer==0, the first conjunction evaluates to true, making
  // the entire predicate true at compile time.
  // When AlgoVer==1, only the second part of the second conjunction needs
  // to be evaluated at runtime; the other conjunctions fold to false at
  // compile time.
  // When AlgoVer==2, only the second part of the third conjunction needs
  // to be evaluated at runtime; the other conjunctions fold to false at
  // compile time.
  Value *CondAlgo0 = Builder.CreateIsNull(AlgoVer);
  Value *Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
  Value *LaneComp = Builder.CreateICmpULT(LaneId, RemoteLaneOffset);
  Value *CondAlgo1 = Builder.CreateAnd(Algo1, LaneComp);
  Value *Algo2 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(2));
  Value *LaneIdAnd1 = Builder.CreateAnd(LaneId, Builder.getInt16(1));
  Value *LaneIdComp = Builder.CreateIsNull(LaneIdAnd1);
  Value *Algo2AndLaneIdComp = Builder.CreateAnd(Algo2, LaneIdComp);
  Value *RemoteOffsetComp =
      Builder.CreateICmpSGT(RemoteLaneOffset, Builder.getInt16(0));
  Value *CondAlgo2 = Builder.CreateAnd(Algo2AndLaneIdComp, RemoteOffsetComp);
  Value *CA0OrCA1 = Builder.CreateOr(CondAlgo0, CondAlgo1);
  Value *CondReduce = Builder.CreateOr(CA0OrCA1, CondAlgo2);
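  // Because AlgoVer is a compile-time constant, all but one of the
  // comparisons above are expected to fold away during later simplification,
  // leaving at most a single runtime test.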

  BasicBlock *ThenBB = BasicBlock::Create(Ctx, "then");
  BasicBlock *ElseBB = BasicBlock::Create(Ctx, "else");
  BasicBlock *MergeBB = BasicBlock::Create(Ctx, "ifcont");

  Builder.CreateCondBr(CondReduce, ThenBB, ElseBB);
  emitBlock(ThenBB, Builder.GetInsertBlock()->getParent());
  Value *LocalReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceList, Builder.getPtrTy());
  Value *RemoteReduceListPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
      RemoteListAddrCast, Builder.getPtrTy());
  Builder.CreateCall(ReduceFn, {LocalReduceListPtr, RemoteReduceListPtr})
      ->addFnAttr(Attribute::NoUnwind);
  Builder.CreateBr(MergeBB);

  emitBlock(ElseBB, Builder.GetInsertBlock()->getParent());
  Builder.CreateBr(MergeBB);

  emitBlock(MergeBB, Builder.GetInsertBlock()->getParent());

  // if (AlgoVer==1 && (LaneId >= Offset)) copy Remote Reduce list to local
  // Reduce list.
  Algo1 = Builder.CreateICmpEQ(AlgoVer, Builder.getInt16(1));
  Value *LaneIdGtOffset = Builder.CreateICmpUGE(LaneId, RemoteLaneOffset);
  Value *CondCopy = Builder.CreateAnd(Algo1, LaneIdGtOffset);

  BasicBlock *CpyThenBB = BasicBlock::Create(Ctx, "then");
  BasicBlock *CpyElseBB = BasicBlock::Create(Ctx, "else");
  BasicBlock *CpyMergeBB = BasicBlock::Create(Ctx, "ifcont");
  Builder.CreateCondBr(CondCopy, CpyThenBB, CpyElseBB);

  emitBlock(CpyThenBB, Builder.GetInsertBlock()->getParent());
  emitReductionListCopy(AllocaIP, CopyAction::ThreadCopy, RedListArrayTy,
                        ReductionInfos, RemoteListAddrCast, ReduceList);
  Builder.CreateBr(CpyMergeBB);

  emitBlock(CpyElseBB, Builder.GetInsertBlock()->getParent());
  Builder.CreateBr(CpyMergeBB);

  emitBlock(CpyMergeBB, Builder.GetInsertBlock()->getParent());

  Builder.CreateRetVoid();

  return SarFunc;
}

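// Shape of the generated helper (sketch):
//   void _omp_reduction_list_to_global_copy_func(void *buffer, i32 idx,
//                                                void *reduce_list);
// It copies every element of the thread-local Reduce list into the slot of
// the global reduction buffer selected by idx.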
Function *OpenMPIRBuilder::emitListToGlobalCopyFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
    AttributeList FuncAttrs) {
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy = FunctionType::get(
      Builder.getVoidTy(),
      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ false);
  Function *LtGCFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                       "_omp_reduction_list_to_global_copy_func", &M);
  LtGCFunc->setAttributes(FuncAttrs);
  LtGCFunc->addParamAttr(0, Attribute::NoUndef);
  LtGCFunc->addParamAttr(1, Attribute::NoUndef);
  LtGCFunc->addParamAttr(2, Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = LtGCFunc->getArg(0);
  // Idx: index of the buffer.
  Argument *IdxArg = LtGCFunc->getArg(1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = LtGCFunc->getArg(2);

  Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
                                                BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
                                             IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      BufferArgAlloca, Builder.getPtrTy(),
      BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListArgAlloca, Builder.getPtrTy(),
      ReduceListArgAlloca->getName() + ".ascast");

  Builder.CreateStore(BufferArg, BufferArgAddrCast);
  Builder.CreateStore(IdxArg, IdxArgAddrCast);
  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);

  Value *LocalReduceList =
      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
  Value *BufferArgVal =
      Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    auto *RedListArrayTy =
        ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
    // Reduce element = LocalReduceList[i]
    Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
        RedListArrayTy, LocalReduceList,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    // elemptr = ((CopyType*)(elemptrptr)) + I
    Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);

    // Global = Buffer.VD[Idx];
    Value *BufferVD =
        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferArgVal, Idxs);
    Value *GlobVal = Builder.CreateConstInBoundsGEP2_32(
        ReductionsBufferTy, BufferVD, 0, En.index());

    switch (RI.EvaluationKind) {
    case EvalKind::Scalar: {
      Value *TargetElement = Builder.CreateLoad(RI.ElementType, ElemPtr);
      Builder.CreateStore(TargetElement, GlobVal);
      break;
    }
    case EvalKind::Complex: {
      Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, ElemPtr, 0, 0, ".realp");
      Value *SrcReal = Builder.CreateLoad(
          RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
      Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, ElemPtr, 0, 1, ".imagp");
      Value *SrcImg = Builder.CreateLoad(
          RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");

      Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, GlobVal, 0, 0, ".realp");
      Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, GlobVal, 0, 1, ".imagp");
      Builder.CreateStore(SrcReal, DestRealPtr);
      Builder.CreateStore(SrcImg, DestImgPtr);
      break;
    }
    case EvalKind::Aggregate: {
      Value *SizeVal =
          Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
      Builder.CreateMemCpy(
          GlobVal, M.getDataLayout().getPrefTypeAlign(RI.ElementType), ElemPtr,
          M.getDataLayout().getPrefTypeAlign(RI.ElementType), SizeVal, false);
      break;
    }
    }
  }

  Builder.CreateRetVoid();
  Builder.restoreIP(OldIP);
  return LtGCFunc;
}

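// Shape of the generated helper (sketch): same signature as the copy helper
// above, but instead of copying it builds a temporary list of pointers into
// the global buffer slot and reduces the thread-local Reduce list into it via
// the outlined reduce function.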
Function *OpenMPIRBuilder::emitListToGlobalReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    Type *ReductionsBufferTy, AttributeList FuncAttrs) {
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy = FunctionType::get(
      Builder.getVoidTy(),
      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ false);
  Function *LtGRFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                       "_omp_reduction_list_to_global_reduce_func", &M);
  LtGRFunc->setAttributes(FuncAttrs);
  LtGRFunc->addParamAttr(0, Attribute::NoUndef);
  LtGRFunc->addParamAttr(1, Attribute::NoUndef);
  LtGRFunc->addParamAttr(2, Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = LtGRFunc->getArg(0);
  // Idx: index of the buffer.
  Argument *IdxArg = LtGRFunc->getArg(1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = LtGRFunc->getArg(2);

  Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
                                                BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
                                             IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
  auto *RedListArrayTy =
      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());

  // 1. Build a list of reduction variables.
  // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
  Value *LocalReduceList =
      Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");

  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      BufferArgAlloca, Builder.getPtrTy(),
      BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListArgAlloca, Builder.getPtrTy(),
      ReduceListArgAlloca->getName() + ".ascast");
  Value *LocalReduceListAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      LocalReduceList, Builder.getPtrTy(),
      LocalReduceList->getName() + ".ascast");

  Builder.CreateStore(BufferArg, BufferArgAddrCast);
  Builder.CreateStore(IdxArg, IdxArgAddrCast);
  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);

  Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  for (auto En : enumerate(ReductionInfos)) {
    Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
        RedListArrayTy, LocalReduceListAddrCast,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    Value *BufferVD =
        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
    // Global = Buffer.VD[Idx];
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        ReductionsBufferTy, BufferVD, 0, En.index());
    Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
  }

  // Call reduce_function(GlobalReduceList, ReduceList)
  Value *ReduceList =
      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
  Builder.CreateCall(ReduceFn, {LocalReduceListAddrCast, ReduceList})
      ->addFnAttr(Attribute::NoUnwind);
  Builder.CreateRetVoid();
  Builder.restoreIP(OldIP);
  return LtGRFunc;
}

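// Shape of the generated helper (sketch): the mirror of the list-to-global
// copy above; it copies each element out of the global buffer slot selected
// by idx back into the thread-local Reduce list.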
Function *OpenMPIRBuilder::emitGlobalToListCopyFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Type *ReductionsBufferTy,
    AttributeList FuncAttrs) {
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  FunctionType *FuncTy = FunctionType::get(
      Builder.getVoidTy(),
      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ false);
  Function *LtGCFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                       "_omp_reduction_global_to_list_copy_func", &M);
  LtGCFunc->setAttributes(FuncAttrs);
  LtGCFunc->addParamAttr(0, Attribute::NoUndef);
  LtGCFunc->addParamAttr(1, Attribute::NoUndef);
  LtGCFunc->addParamAttr(2, Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGCFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = LtGCFunc->getArg(0);
  // Idx: index of the buffer.
  Argument *IdxArg = LtGCFunc->getArg(1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = LtGCFunc->getArg(2);

  Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
                                                BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
                                             IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      BufferArgAlloca, Builder.getPtrTy(),
      BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListArgAlloca, Builder.getPtrTy(),
      ReduceListArgAlloca->getName() + ".ascast");
  Builder.CreateStore(BufferArg, BufferArgAddrCast);
  Builder.CreateStore(IdxArg, IdxArgAddrCast);
  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);

  Value *LocalReduceList =
      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
  Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  for (auto En : enumerate(ReductionInfos)) {
    const OpenMPIRBuilder::ReductionInfo &RI = En.value();
    auto *RedListArrayTy =
        ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
    // Reduce element = LocalReduceList[i]
    Value *ElemPtrPtr = Builder.CreateInBoundsGEP(
        RedListArrayTy, LocalReduceList,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    // elemptr = ((CopyType*)(elemptrptr)) + I
    Value *ElemPtr = Builder.CreateLoad(Builder.getPtrTy(), ElemPtrPtr);
    // Global = Buffer.VD[Idx];
    Value *BufferVD =
        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        ReductionsBufferTy, BufferVD, 0, En.index());

    switch (RI.EvaluationKind) {
    case EvalKind::Scalar: {
      Value *TargetElement = Builder.CreateLoad(RI.ElementType, GlobValPtr);
      Builder.CreateStore(TargetElement, ElemPtr);
      break;
    }
    case EvalKind::Complex: {
      Value *SrcRealPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, GlobValPtr, 0, 0, ".realp");
      Value *SrcReal = Builder.CreateLoad(
          RI.ElementType->getStructElementType(0), SrcRealPtr, ".real");
      Value *SrcImgPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, GlobValPtr, 0, 1, ".imagp");
      Value *SrcImg = Builder.CreateLoad(
          RI.ElementType->getStructElementType(1), SrcImgPtr, ".imag");

      Value *DestRealPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, ElemPtr, 0, 0, ".realp");
      Value *DestImgPtr = Builder.CreateConstInBoundsGEP2_32(
          RI.ElementType, ElemPtr, 0, 1, ".imagp");
      Builder.CreateStore(SrcReal, DestRealPtr);
      Builder.CreateStore(SrcImg, DestImgPtr);
      break;
    }
    case EvalKind::Aggregate: {
      Value *SizeVal =
          Builder.getInt64(M.getDataLayout().getTypeStoreSize(RI.ElementType));
      Builder.CreateMemCpy(
          ElemPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
          GlobValPtr, M.getDataLayout().getPrefTypeAlign(RI.ElementType),
          SizeVal, false);
      break;
    }
    }
  }

  Builder.CreateRetVoid();
  Builder.restoreIP(OldIP);
  return LtGCFunc;
}

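// Shape of the generated helper (sketch): the mirror of the list-to-global
// reduce above, with the operand order swapped so that the global buffer
// slot is reduced into the thread-local Reduce list.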
Function *OpenMPIRBuilder::emitGlobalToListReduceFunction(
    ArrayRef<ReductionInfo> ReductionInfos, Function *ReduceFn,
    Type *ReductionsBufferTy, AttributeList FuncAttrs) {
  OpenMPIRBuilder::InsertPointTy OldIP = Builder.saveIP();
  LLVMContext &Ctx = M.getContext();
  auto *FuncTy = FunctionType::get(
      Builder.getVoidTy(),
      {Builder.getPtrTy(), Builder.getInt32Ty(), Builder.getPtrTy()},
      /* IsVarArg */ false);
  Function *LtGRFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                       "_omp_reduction_global_to_list_reduce_func", &M);
  LtGRFunc->setAttributes(FuncAttrs);
  LtGRFunc->addParamAttr(0, Attribute::NoUndef);
  LtGRFunc->addParamAttr(1, Attribute::NoUndef);
  LtGRFunc->addParamAttr(2, Attribute::NoUndef);

  BasicBlock *EntryBlock = BasicBlock::Create(Ctx, "entry", LtGRFunc);
  Builder.SetInsertPoint(EntryBlock);

  // Buffer: global reduction buffer.
  Argument *BufferArg = LtGRFunc->getArg(0);
  // Idx: index of the buffer.
  Argument *IdxArg = LtGRFunc->getArg(1);
  // ReduceList: thread local Reduce list.
  Argument *ReduceListArg = LtGRFunc->getArg(2);

  Value *BufferArgAlloca = Builder.CreateAlloca(Builder.getPtrTy(), nullptr,
                                                BufferArg->getName() + ".addr");
  Value *IdxArgAlloca = Builder.CreateAlloca(Builder.getInt32Ty(), nullptr,
                                             IdxArg->getName() + ".addr");
  Value *ReduceListArgAlloca = Builder.CreateAlloca(
      Builder.getPtrTy(), nullptr, ReduceListArg->getName() + ".addr");
  ArrayType *RedListArrayTy =
      ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());

  // 1. Build a list of reduction variables.
  // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
  Value *LocalReduceList =
      Builder.CreateAlloca(RedListArrayTy, nullptr, ".omp.reduction.red_list");

  Value *BufferArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      BufferArgAlloca, Builder.getPtrTy(),
      BufferArgAlloca->getName() + ".ascast");
  Value *IdxArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      IdxArgAlloca, Builder.getPtrTy(), IdxArgAlloca->getName() + ".ascast");
  Value *ReduceListArgAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReduceListArgAlloca, Builder.getPtrTy(),
      ReduceListArgAlloca->getName() + ".ascast");
  Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
      LocalReduceList, Builder.getPtrTy(),
      LocalReduceList->getName() + ".ascast");

  Builder.CreateStore(BufferArg, BufferArgAddrCast);
  Builder.CreateStore(IdxArg, IdxArgAddrCast);
  Builder.CreateStore(ReduceListArg, ReduceListArgAddrCast);

  Value *BufferVal = Builder.CreateLoad(Builder.getPtrTy(), BufferArgAddrCast);
  Value *Idxs[] = {Builder.CreateLoad(Builder.getInt32Ty(), IdxArgAddrCast)};
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  for (auto En : enumerate(ReductionInfos)) {
    Value *TargetElementPtrPtr = Builder.CreateInBoundsGEP(
        RedListArrayTy, ReductionList,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    // Global = Buffer.VD[Idx];
    Value *BufferVD =
        Builder.CreateInBoundsGEP(ReductionsBufferTy, BufferVal, Idxs);
    Value *GlobValPtr = Builder.CreateConstInBoundsGEP2_32(
        ReductionsBufferTy, BufferVD, 0, En.index());
    Builder.CreateStore(GlobValPtr, TargetElementPtrPtr);
  }

  // Call reduce_function(ReduceList, GlobalReduceList)
  Value *ReduceList =
      Builder.CreateLoad(Builder.getPtrTy(), ReduceListArgAddrCast);
  Builder.CreateCall(ReduceFn, {ReduceList, ReductionList})
      ->addFnAttr(Attribute::NoUnwind);
  Builder.CreateRetVoid();
  Builder.restoreIP(OldIP);
  return LtGRFunc;
}

std::string OpenMPIRBuilder::getReductionFuncName(StringRef Name) const {
  std::string Suffix =
      createPlatformSpecificName({"omp", "reduction", "reduction_func"});
  return (Name + Suffix).str();
}

Function *OpenMPIRBuilder::createReductionFunction(
    StringRef ReducerName, ArrayRef<ReductionInfo> ReductionInfos,
    ReductionGenCBKind ReductionGenCBKind, AttributeList FuncAttrs) {
  auto *FuncTy = FunctionType::get(Builder.getVoidTy(),
                                   {Builder.getPtrTy(), Builder.getPtrTy()},
                                   /* IsVarArg */ false);
  std::string Name = getReductionFuncName(ReducerName);
  Function *ReductionFunc =
      Function::Create(FuncTy, GlobalVariable::InternalLinkage, Name, &M);
  ReductionFunc->setAttributes(FuncAttrs);
  ReductionFunc->addParamAttr(0, Attribute::NoUndef);
  ReductionFunc->addParamAttr(1, Attribute::NoUndef);
  BasicBlock *EntryBB =
      BasicBlock::Create(M.getContext(), "entry", ReductionFunc);
  Builder.SetInsertPoint(EntryBB);

  // Need to alloca memory here and deal with the pointers before getting
  // the LHS/RHS pointers out.
  Value *LHSArrayPtr = nullptr;
  Value *RHSArrayPtr = nullptr;
  Argument *Arg0 = ReductionFunc->getArg(0);
  Argument *Arg1 = ReductionFunc->getArg(1);
  Type *Arg0Type = Arg0->getType();
  Type *Arg1Type = Arg1->getType();

  Value *LHSAlloca =
      Builder.CreateAlloca(Arg0Type, nullptr, Arg0->getName() + ".addr");
  Value *RHSAlloca =
      Builder.CreateAlloca(Arg1Type, nullptr, Arg1->getName() + ".addr");
  Value *LHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      LHSAlloca, Arg0Type, LHSAlloca->getName() + ".ascast");
  Value *RHSAddrCast = Builder.CreatePointerBitCastOrAddrSpaceCast(
      RHSAlloca, Arg1Type, RHSAlloca->getName() + ".ascast");
  Builder.CreateStore(Arg0, LHSAddrCast);
  Builder.CreateStore(Arg1, RHSAddrCast);
  LHSArrayPtr = Builder.CreateLoad(Arg0Type, LHSAddrCast);
  RHSArrayPtr = Builder.CreateLoad(Arg1Type, RHSAddrCast);

  Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), ReductionInfos.size());
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  SmallVector<Value *> LHSPtrs, RHSPtrs;
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *RHSI8PtrPtr = Builder.CreateInBoundsGEP(
        RedArrayTy, RHSArrayPtr,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
    Value *RHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        RHSI8Ptr, RI.PrivateVariable->getType(),
        RHSI8Ptr->getName() + ".ascast");

    Value *LHSI8PtrPtr = Builder.CreateInBoundsGEP(
        RedArrayTy, LHSArrayPtr,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
    Value *LHSPtr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        LHSI8Ptr, RI.Variable->getType(), LHSI8Ptr->getName() + ".ascast");

    if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
      LHSPtrs.emplace_back(LHSPtr);
      RHSPtrs.emplace_back(RHSPtr);
    } else {
      Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
      Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
      Value *Reduced;
      RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced);
      if (!Builder.GetInsertBlock())
        return ReductionFunc;
      Builder.CreateStore(Reduced, LHSPtr);
    }
  }

  if (ReductionGenCBKind == ReductionGenCBKind::Clang)
    for (auto En : enumerate(ReductionInfos)) {
      unsigned Index = En.index();
      const ReductionInfo &RI = En.value();
      Value *LHSFixupPtr, *RHSFixupPtr;
      Builder.restoreIP(RI.ReductionGenClang(
          Builder.saveIP(), Index, &LHSFixupPtr, &RHSFixupPtr, ReductionFunc));

      // Fix the callback code generated to use the correct Values for the
      // LHS and RHS.
      LHSFixupPtr->replaceUsesWithIf(
          LHSPtrs[Index], [ReductionFunc](const Use &U) {
            return cast<Instruction>(U.getUser())->getParent()->getParent() ==
                   ReductionFunc;
          });
      RHSFixupPtr->replaceUsesWithIf(
          RHSPtrs[Index], [ReductionFunc](const Use &U) {
            return cast<Instruction>(U.getUser())->getParent()->getParent() ==
                   ReductionFunc;
          });
    }

  Builder.CreateRetVoid();
  return ReductionFunc;
}

static void
checkReductionInfos(ArrayRef<OpenMPIRBuilder::ReductionInfo> ReductionInfos,
                    bool IsGPU) {
  for (const OpenMPIRBuilder::ReductionInfo &RI : ReductionInfos) {
    (void)RI;
    assert(RI.Variable && "expected non-null variable");
    assert(RI.PrivateVariable && "expected non-null private variable");
    assert((RI.ReductionGen || RI.ReductionGenClang) &&
           "expected non-null reduction generator callback");
    if (!IsGPU) {
      assert(
          RI.Variable->getType() == RI.PrivateVariable->getType() &&
          "expected variables and their private equivalents to have the same "
          "type");
    }
    assert(RI.Variable->getType()->isPointerTy() &&
           "expected variables to be pointers");
  }
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductionsGPU(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    InsertPointTy CodeGenIP, ArrayRef<ReductionInfo> ReductionInfos,
    bool IsNoWait, bool IsTeamsReduction, bool HasDistribute,
    ReductionGenCBKind ReductionGenCBKind, std::optional<omp::GV> GridValue,
    unsigned ReductionBufNum, Value *SrcLocInfo) {
  if (!updateToLocation(Loc))
    return InsertPointTy();
  Builder.restoreIP(CodeGenIP);
  checkReductionInfos(ReductionInfos, /*IsGPU*/ true);
  LLVMContext &Ctx = M.getContext();

  // Source location for the ident struct.
  if (!SrcLocInfo) {
    uint32_t SrcLocStrSize;
    Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
    SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  }

  if (ReductionInfos.size() == 0)
    return Builder.saveIP();

  Function *CurFunc = Builder.GetInsertBlock()->getParent();
  AttributeList FuncAttrs;
  AttrBuilder AttrBldr(Ctx);
  for (auto Attr : CurFunc->getAttributes().getFnAttrs())
    AttrBldr.addAttribute(Attr);
  AttrBldr.removeAttribute(Attribute::OptimizeNone);
  FuncAttrs = FuncAttrs.addFnAttributes(Ctx, AttrBldr);

  Function *ReductionFunc = nullptr;
  CodeGenIP = Builder.saveIP();
  ReductionFunc =
      createReductionFunction(Builder.GetInsertBlock()->getParent()->getName(),
                              ReductionInfos, ReductionGenCBKind, FuncAttrs);
  Builder.restoreIP(CodeGenIP);

  // Set the grid value in the config needed for lowering later on.
  if (GridValue.has_value())
    Config.setGridValue(GridValue.value());
  else
    Config.setGridValue(getGridValue(T, ReductionFunc));

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateDefaultSrcLocStr(SrcLocStrSize);
  Value *RTLoc =
      getOrCreateIdent(SrcLocStr, SrcLocStrSize, omp::IdentFlag(0), 0);

  // Build res = __kmpc_reduce{_nowait}(<gtid>, <n>, sizeof(RedList),
  // RedList, shuffle_reduce_func, interwarp_copy_func);
  // or
  // Build res = __kmpc_reduce_teams_nowait_simple(<loc>, <gtid>, <lck>);
  Value *Res;

  // 1. Build a list of reduction variables.
  // void *RedList[<n>] = {<ReductionVars>[0], ..., <ReductionVars>[<n>-1]};
  auto Size = ReductionInfos.size();
  Type *PtrTy = PointerType::getUnqual(Ctx);
  Type *RedArrayTy = ArrayType::get(PtrTy, Size);
  CodeGenIP = Builder.saveIP();
  Builder.restoreIP(AllocaIP);
  Value *ReductionListAlloca =
      Builder.CreateAlloca(RedArrayTy, nullptr, ".omp.reduction.red_list");
  Value *ReductionList = Builder.CreatePointerBitCastOrAddrSpaceCast(
      ReductionListAlloca, PtrTy, ReductionListAlloca->getName() + ".ascast");
  Builder.restoreIP(CodeGenIP);
  Type *IndexTy = Builder.getIndexTy(
      M.getDataLayout(), M.getDataLayout().getDefaultGlobalsAddressSpace());
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *ElemPtr = Builder.CreateInBoundsGEP(
        RedArrayTy, ReductionList,
        {ConstantInt::get(IndexTy, 0), ConstantInt::get(IndexTy, En.index())});
    Value *CastElem =
        Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);
    Builder.CreateStore(CastElem, ElemPtr);
  }
  CodeGenIP = Builder.saveIP();
  Function *SarFunc =
      emitShuffleAndReduceFunction(ReductionInfos, ReductionFunc, FuncAttrs);
  Function *WcFunc = emitInterWarpCopyFunction(Loc, ReductionInfos, FuncAttrs);
  Builder.restoreIP(CodeGenIP);

  Value *RL = Builder.CreatePointerBitCastOrAddrSpaceCast(ReductionList, PtrTy);

  unsigned MaxDataSize = 0;
  SmallVector<Type *> ReductionTypeArgs;
  for (auto En : enumerate(ReductionInfos)) {
    auto Size = M.getDataLayout().getTypeStoreSize(En.value().ElementType);
    if (Size > MaxDataSize)
      MaxDataSize = Size;
    ReductionTypeArgs.emplace_back(En.value().ElementType);
  }
  Value *ReductionDataSize =
      Builder.getInt64(MaxDataSize * ReductionInfos.size());
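  // Note: this is a conservative payload size (the largest element size times
  // the number of elements) rather than the exact sum of the element sizes.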
  if (!IsTeamsReduction) {
    Value *SarFuncCast =
        Builder.CreatePointerBitCastOrAddrSpaceCast(SarFunc, PtrTy);
    Value *WcFuncCast =
        Builder.CreatePointerBitCastOrAddrSpaceCast(WcFunc, PtrTy);
    Value *Args[] = {RTLoc, ReductionDataSize, RL, SarFuncCast, WcFuncCast};
    Function *Pv2Ptr = getOrCreateRuntimeFunctionPtr(
        RuntimeFunction::OMPRTL___kmpc_nvptx_parallel_reduce_nowait_v2);
    Res = Builder.CreateCall(Pv2Ptr, Args);
  } else {
    CodeGenIP = Builder.saveIP();
    StructType *ReductionsBufferTy = StructType::create(
        Ctx, ReductionTypeArgs, "struct._globalized_locals_ty");
    Function *RedFixedBufferFn = getOrCreateRuntimeFunctionPtr(
        RuntimeFunction::OMPRTL___kmpc_reduction_get_fixed_buffer);
    Function *LtGCFunc = emitListToGlobalCopyFunction(
        ReductionInfos, ReductionsBufferTy, FuncAttrs);
    Function *LtGRFunc = emitListToGlobalReduceFunction(
        ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
    Function *GtLCFunc = emitGlobalToListCopyFunction(
        ReductionInfos, ReductionsBufferTy, FuncAttrs);
    Function *GtLRFunc = emitGlobalToListReduceFunction(
        ReductionInfos, ReductionFunc, ReductionsBufferTy, FuncAttrs);
    Builder.restoreIP(CodeGenIP);

    Value *KernelTeamsReductionPtr = Builder.CreateCall(
        RedFixedBufferFn, {}, "_openmp_teams_reductions_buffer_$_$ptr");

    Value *Args3[] = {RTLoc,
                      KernelTeamsReductionPtr,
                      Builder.getInt32(ReductionBufNum),
                      ReductionDataSize,
                      RL,
                      SarFunc,
                      WcFunc,
                      LtGCFunc,
                      LtGRFunc,
                      GtLCFunc,
                      GtLRFunc};

    Function *TeamsReduceFn = getOrCreateRuntimeFunctionPtr(
        RuntimeFunction::OMPRTL___kmpc_nvptx_teams_reduce_nowait_v2);
    Res = Builder.CreateCall(TeamsReduceFn, Args3);
  }

  // 5. Build if (res == 1)
  BasicBlock *ExitBB = BasicBlock::Create(Ctx, ".omp.reduction.done");
  BasicBlock *ThenBB = BasicBlock::Create(Ctx, ".omp.reduction.then");
  Value *Cond = Builder.CreateICmpEQ(Res, Builder.getInt32(1));
  Builder.CreateCondBr(Cond, ThenBB, ExitBB);

  // 6. Build then branch: where we have reduced values in the master
  // thread in each team.
  // __kmpc_end_reduce{_nowait}(<gtid>);
  // break;
  emitBlock(ThenBB, CurFunc);

  // Add emission of __kmpc_end_reduce{_nowait}(<gtid>);
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *LHS = RI.Variable;
    Value *RHS =
        Builder.CreatePointerBitCastOrAddrSpaceCast(RI.PrivateVariable, PtrTy);

    if (ReductionGenCBKind == ReductionGenCBKind::Clang) {
      Value *LHSPtr, *RHSPtr;
      Builder.restoreIP(RI.ReductionGenClang(Builder.saveIP(), En.index(),
                                             &LHSPtr, &RHSPtr, CurFunc));

      // Fix the callback code generated to use the correct Values for the
      // LHS and RHS.
      LHSPtr->replaceUsesWithIf(LHS, [ReductionFunc](const Use &U) {
        return cast<Instruction>(U.getUser())->getParent()->getParent() ==
               ReductionFunc;
      });
      RHSPtr->replaceUsesWithIf(RHS, [ReductionFunc](const Use &U) {
        return cast<Instruction>(U.getUser())->getParent()->getParent() ==
               ReductionFunc;
      });
    } else {
      assert(false && "Unhandled ReductionGenCBKind");
    }
  }
  emitBlock(ExitBB, CurFunc);

  Config.setEmitLLVMUsed();

  return Builder.saveIP();
}

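// The fresh function has the reduce_func signature that __kmpc_reduce{_nowait}
// expects, void(void **lhs_array, void **rhs_array); its body is populated at
// the end of createReductions below.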
static Function *getFreshReductionFunc(Module &M) {
  Type *VoidTy = Type::getVoidTy(M.getContext());
  Type *Int8PtrTy = PointerType::getUnqual(M.getContext());
  auto *FuncTy =
      FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false);
  return Function::Create(FuncTy, GlobalVariable::InternalLinkage,
                          ".omp.reduction.func", &M);
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createReductions(const LocationDescription &Loc,
                                  InsertPointTy AllocaIP,
                                  ArrayRef<ReductionInfo> ReductionInfos,
                                  ArrayRef<bool> IsByRef, bool IsNoWait) {
  assert(ReductionInfos.size() == IsByRef.size());
  for (const ReductionInfo &RI : ReductionInfos) {
    (void)RI;
    assert(RI.Variable && "expected non-null variable");
    assert(RI.PrivateVariable && "expected non-null private variable");
    assert(RI.ReductionGen && "expected non-null reduction generator callback");
    assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
           "expected variables and their private equivalents to have the same "
           "type");
    assert(RI.Variable->getType()->isPointerTy() &&
           "expected variables to be pointers");
  }

  if (!updateToLocation(Loc))
    return InsertPointTy();

  BasicBlock *InsertBlock = Loc.IP.getBlock();
  BasicBlock *ContinuationBlock =
      InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
  InsertBlock->getTerminator()->eraseFromParent();

  // Create and populate array of type-erased pointers to private reduction
  // values.
  unsigned NumReductions = ReductionInfos.size();
  Type *RedArrayTy = ArrayType::get(Builder.getPtrTy(), NumReductions);
  Builder.SetInsertPoint(AllocaIP.getBlock()->getTerminator());
  Value *RedArray = Builder.CreateAlloca(RedArrayTy, nullptr, "red.array");

  Builder.SetInsertPoint(InsertBlock, InsertBlock->end());

  for (auto En : enumerate(ReductionInfos)) {
    unsigned Index = En.index();
    const ReductionInfo &RI = En.value();
    Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
        RedArrayTy, RedArray, 0, Index, "red.array.elem." + Twine(Index));
    Builder.CreateStore(RI.PrivateVariable, RedArrayElemPtr);
  }

  // Emit a call to the runtime function that orchestrates the reduction.
  // Declare the reduction function in the process.
  Function *Func = Builder.GetInsertBlock()->getParent();
  Module *Module = Func->getParent();
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  bool CanGenerateAtomic = all_of(ReductionInfos, [](const ReductionInfo &RI) {
    return RI.AtomicReductionGen;
  });
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                  CanGenerateAtomic
                                      ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
                                      : IdentFlag(0));
  Value *ThreadId = getOrCreateThreadID(Ident);
  Constant *NumVariables = Builder.getInt32(NumReductions);
  const DataLayout &DL = Module->getDataLayout();
  unsigned RedArrayByteSize = DL.getTypeStoreSize(RedArrayTy);
  Constant *RedArraySize = Builder.getInt64(RedArrayByteSize);
  Function *ReductionFunc = getFreshReductionFunc(*Module);
  Value *Lock = getOMPCriticalRegionLock(".reduction");
  Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
      IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
               : RuntimeFunction::OMPRTL___kmpc_reduce);
  CallInst *ReduceCall =
      Builder.CreateCall(ReduceFunc,
                         {Ident, ThreadId, NumVariables, RedArraySize, RedArray,
                          ReductionFunc, Lock},
                         "reduce");

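  // __kmpc_reduce{_nowait} returns 1 when this thread should perform the
  // non-atomic reduction, 2 when the atomic path may be used, and 0 when no
  // reduction work is needed here; the switch below sends the 0 case straight
  // to the continuation block via the default edge.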
  // Create final reduction entry blocks for the atomic and non-atomic case.
  // Emit IR that dispatches control flow to one of the blocks based on the
  // reduction supporting the atomic mode.
  BasicBlock *NonAtomicRedBlock =
      BasicBlock::Create(Module->getContext(), "reduce.switch.nonatomic", Func);
  BasicBlock *AtomicRedBlock =
      BasicBlock::Create(Module->getContext(), "reduce.switch.atomic", Func);
  SwitchInst *Switch =
      Builder.CreateSwitch(ReduceCall, ContinuationBlock, /* NumCases */ 2);
  Switch->addCase(Builder.getInt32(1), NonAtomicRedBlock);
  Switch->addCase(Builder.getInt32(2), AtomicRedBlock);

  // Populate the non-atomic reduction using the elementwise reduction function.
  // This loads the elements from the global and private variables and reduces
  // them before storing back the result to the global variable.
  Builder.SetInsertPoint(NonAtomicRedBlock);
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Type *ValueType = RI.ElementType;
    // We have one fewer load in the by-ref case because that load is now
    // inside the reduction region.
    Value *RedValue = nullptr;
    if (!IsByRef[En.index()]) {
      RedValue = Builder.CreateLoad(ValueType, RI.Variable,
                                    "red.value." + Twine(En.index()));
    }
    Value *PrivateRedValue =
        Builder.CreateLoad(ValueType, RI.PrivateVariable,
                           "red.private.value." + Twine(En.index()));
    Value *Reduced;
    if (IsByRef[En.index()]) {
      Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), RI.Variable,
                                        PrivateRedValue, Reduced));
    } else {
      Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), RedValue,
                                        PrivateRedValue, Reduced));
    }
    if (!Builder.GetInsertBlock())
      return InsertPointTy();
    // For the by-ref case, the load is inside the reduction region.
    if (!IsByRef[En.index()])
      Builder.CreateStore(Reduced, RI.Variable);
  }
  Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
      IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
               : RuntimeFunction::OMPRTL___kmpc_end_reduce);
  Builder.CreateCall(EndReduceFunc, {Ident, ThreadId, Lock});
  Builder.CreateBr(ContinuationBlock);

  // Populate the atomic reduction using the atomic elementwise reduction
  // function. There are no loads/stores here because they will be happening
  // inside the atomic elementwise reduction.
  Builder.SetInsertPoint(AtomicRedBlock);
  if (CanGenerateAtomic && llvm::none_of(IsByRef, [](bool P) { return P; })) {
    for (const ReductionInfo &RI : ReductionInfos) {
      Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.ElementType,
                                              RI.Variable, RI.PrivateVariable));
      if (!Builder.GetInsertBlock())
        return InsertPointTy();
    }
    Builder.CreateBr(ContinuationBlock);
  } else {
    Builder.CreateUnreachable();
  }

  // Populate the outlined reduction function using the elementwise reduction
  // function. Partial values are extracted from the type-erased array of
  // pointers to private variables.
  BasicBlock *ReductionFuncBlock =
      BasicBlock::Create(Module->getContext(), "", ReductionFunc);
  Builder.SetInsertPoint(ReductionFuncBlock);
  Value *LHSArrayPtr = ReductionFunc->getArg(0);
  Value *RHSArrayPtr = ReductionFunc->getArg(1);

  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        RedArrayTy, LHSArrayPtr, 0, En.index());
    Value *LHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), LHSI8PtrPtr);
    Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
    Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
    Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        RedArrayTy, RHSArrayPtr, 0, En.index());
    Value *RHSI8Ptr = Builder.CreateLoad(Builder.getPtrTy(), RHSI8PtrPtr);
    Value *RHSPtr =
        Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
    Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
    Value *Reduced;
    Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
    if (!Builder.GetInsertBlock())
      return InsertPointTy();
    // The store is inside the reduction region when using by-ref.
    if (!IsByRef[En.index()])
      Builder.CreateStore(Reduced, LHSPtr);
  }
  Builder.CreateRetVoid();

  Builder.SetInsertPoint(ContinuationBlock);
  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
                              BodyGenCallbackTy BodyGenCB,
                              FinalizeCallbackTy FiniCB) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  Directive OMPD = Directive::OMPD_master;
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId};

  Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_master);
  Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
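  // __kmpc_master returns nonzero only on the team's master thread, so the
  // inlined region below is guarded by this call (Conditional=true).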

  Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_master);
  Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);

  return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
                              /*Conditional*/ true, /*hasFinalize*/ true);
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
                              BodyGenCallbackTy BodyGenCB,
                              FinalizeCallbackTy FiniCB, Value *Filter) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Directive OMPD = Directive::OMPD_masked;
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId, Filter};
  Value *ArgsEnd[] = {Ident, ThreadId};

  Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_masked);
  Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);

  Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_masked);
  Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, ArgsEnd);

  return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
                              /*Conditional*/ true, /*hasFinalize*/ true);
}

CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
    DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
    BasicBlock *PostInsertBefore, const Twine &Name) {
  Module *M = F->getParent();
  LLVMContext &Ctx = M->getContext();
  Type *IndVarTy = TripCount->getType();

  // Create the basic block structure.
  BasicBlock *Preheader =
      BasicBlock::Create(Ctx, "omp_" + Name + ".preheader", F, PreInsertBefore);
  BasicBlock *Header =
      BasicBlock::Create(Ctx, "omp_" + Name + ".header", F, PreInsertBefore);
  BasicBlock *Cond =
      BasicBlock::Create(Ctx, "omp_" + Name + ".cond", F, PreInsertBefore);
  BasicBlock *Body =
      BasicBlock::Create(Ctx, "omp_" + Name + ".body", F, PreInsertBefore);
  BasicBlock *Latch =
      BasicBlock::Create(Ctx, "omp_" + Name + ".inc", F, PostInsertBefore);
  BasicBlock *Exit =
      BasicBlock::Create(Ctx, "omp_" + Name + ".exit", F, PostInsertBefore);
  BasicBlock *After =
      BasicBlock::Create(Ctx, "omp_" + Name + ".after", F, PostInsertBefore);

  // Use specified DebugLoc for new instructions.
  Builder.SetCurrentDebugLocation(DL);

  Builder.SetInsertPoint(Preheader);
  Builder.CreateBr(Header);

  Builder.SetInsertPoint(Header);
  PHINode *IndVarPHI = Builder.CreatePHI(IndVarTy, 2, "omp_" + Name + ".iv");
  IndVarPHI->addIncoming(ConstantInt::get(IndVarTy, 0), Preheader);
  Builder.CreateBr(Cond);

  Builder.SetInsertPoint(Cond);
  Value *Cmp =
      Builder.CreateICmpULT(IndVarPHI, TripCount, "omp_" + Name + ".cmp");
  Builder.CreateCondBr(Cmp, Body, Exit);

  Builder.SetInsertPoint(Body);
  Builder.CreateBr(Latch);

  Builder.SetInsertPoint(Latch);
  Value *Next = Builder.CreateAdd(IndVarPHI, ConstantInt::get(IndVarTy, 1),
                                  "omp_" + Name + ".next", /*HasNUW=*/true);
  Builder.CreateBr(Header);
  IndVarPHI->addIncoming(Next, Latch);

  Builder.SetInsertPoint(Exit);
  Builder.CreateBr(After);

  // Remember and return the canonical control flow.
  LoopInfos.emplace_front();
  CanonicalLoopInfo *CL = &LoopInfos.front();

  CL->Header = Header;
  CL->Cond = Cond;
  CL->Latch = Latch;
  CL->Exit = Exit;

#ifndef NDEBUG
  CL->assertOK();
#endif
  return CL;
}

CanonicalLoopInfo *
OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
                                     LoopBodyGenCallbackTy BodyGenCB,
                                     Value *TripCount, const Twine &Name) {
  BasicBlock *BB = Loc.IP.getBlock();
  BasicBlock *NextBB = BB->getNextNode();

  CanonicalLoopInfo *CL = createLoopSkeleton(Loc.DL, TripCount, BB->getParent(),
                                             NextBB, NextBB, Name);
  BasicBlock *After = CL->getAfter();

  // If location is not set, don't connect the loop.
  if (updateToLocation(Loc)) {
    // Split the loop at the insertion point: Branch to the preheader and move
    // every following instruction to after the loop (the After BB). Also, the
    // new successor is the loop's after block.
    spliceBB(Builder, After, /*CreateBranch=*/false);
    Builder.CreateBr(CL->getPreheader());
  }

  // Emit the body content. We do it after connecting the loop to the CFG to
  // avoid the callback encountering degenerate BBs.
  BodyGenCB(CL->getBodyIP(), CL->getIndVar());

#ifndef NDEBUG
  CL->assertOK();
#endif
  return CL;
}

CanonicalLoopInfo *OpenMPIRBuilder::createCanonicalLoop(
    const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
    Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
    InsertPointTy ComputeIP, const Twine &Name) {

  // Consider the following difficulties (assuming 8-bit signed integers):
  // * Adding \p Step to the loop counter which passes \p Stop may overflow:
  //     DO I = 1, 100, 50
  // * A \p Step of INT_MIN cannot be normalized to a positive direction:
  //     DO I = 100, 0, -128

  // Start, Stop and Step must be of the same integer type.
  auto *IndVarTy = cast<IntegerType>(Start->getType());
  assert(IndVarTy == Stop->getType() && "Stop type mismatch");
  assert(IndVarTy == Step->getType() && "Step type mismatch");

  LocationDescription ComputeLoc =
      ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
  updateToLocation(ComputeLoc);

  ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
  ConstantInt *One = ConstantInt::get(IndVarTy, 1);

  // Like Step, but always positive.
  Value *Incr = Step;

  // Distance between Start and Stop; always positive.
  Value *Span;

  // Condition for whether no iterations are executed at all, e.g., because
  // UB < LB.
  Value *ZeroCmp;

  if (IsSigned) {
    // Ensure that increment is positive. If not, negate and invert LB and UB.
    Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
    Incr = Builder.CreateSelect(IsNeg, Builder.CreateNeg(Step), Step);
    Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
    Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
    Span = Builder.CreateSub(UB, LB, "", false, true);
    ZeroCmp = Builder.CreateICmp(
        InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, UB, LB);
  } else {
    Span = Builder.CreateSub(Stop, Start, "", true);
    ZeroCmp = Builder.CreateICmp(
        InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Stop, Start);
  }

  Value *CountIfLooping;
  if (InclusiveStop) {
    CountIfLooping = Builder.CreateAdd(Builder.CreateUDiv(Span, Incr), One);
  } else {
    // Avoid incrementing past stop since it could overflow.
    Value *CountIfTwo = Builder.CreateAdd(
        Builder.CreateUDiv(Builder.CreateSub(Span, One), Incr), One);
    Value *OneCmp = Builder.CreateICmp(CmpInst::ICMP_ULE, Span, Incr);
    CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
  }
  Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
                                          "omp_" + Name + ".tripcount");

  auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
    Builder.restoreIP(CodeGenIP);
    Value *Span = Builder.CreateMul(IV, Step);
    Value *IndVar = Builder.CreateAdd(Span, Start);
    BodyGenCB(Builder.saveIP(), IndVar);
  };
  LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP();
  return createCanonicalLoop(LoopLoc, BodyGen, TripCount, Name);
}

// Returns an LLVM function to call for initializing loop bounds using OpenMP
// static scheduling depending on `type`. Only i32 and i64 are supported by the
// runtime. Always interpret integers as unsigned similarly to
// CanonicalLoopInfo.
static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
                                                  OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::applyStaticWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
                                          InsertPointTy AllocaIP,
                                          bool NeedsBarrier) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");

  // Set up the source location value for OpenMP runtime.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee StaticInit = getKmpcForStaticInitForType(IVTy, M, *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());

  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Constant *Zero = ConstantInt::get(IVTy, 0);
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(Zero, PLowerBound);
  Value *UpperBound = Builder.CreateSub(CLI->getTripCount(), One);
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  Constant *SchedulingType = ConstantInt::get(
      I32Type, static_cast<int>(OMPScheduleType::UnorderedStatic));

  // Call the "init" function and update the trip count of the loop with the
  // value it produced.
  Builder.CreateCall(StaticInit,
                     {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound,
                      PUpperBound, PStride, One, Zero});
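  // The argument order matches the chunked variant below: loc, global_tid,
  // schedtype, plastiter, plower, pupper, pstride, incr, chunk; the chunk
  // argument is Zero here, which the non-chunked static schedule is expected
  // to ignore.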
  Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
  Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
  Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
  Value *TripCount = Builder.CreateAdd(TripCountMinusOne, One);
  CLI->setTripCount(TripCount);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.

  CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
    Builder.SetInsertPoint(CLI->getBody(),
                           CLI->getBody()->getFirstInsertionPt());
    Builder.SetCurrentDebugLocation(DL);
    return Builder.CreateAdd(OldIV, LowerBound);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(CLI->getExit(),
                         CLI->getExit()->getTerminator()->getIterator());
  Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier)
    createBarrier(LocationDescription(Builder.saveIP(), DL),
                  omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                  /* CheckCancelFlag */ false);

  InsertPointTy AfterIP = CLI->getAfterIP();
  CLI->invalidate();

  return AfterIP;
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    bool NeedsBarrier, Value *ChunkSize) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(ChunkSize && "Chunk size is required");

  LLVMContext &Ctx = CLI->getFunction()->getContext();
  Value *IV = CLI->getIndVar();
  Value *OrigTripCount = CLI->getTripCount();
  Type *IVTy = IV->getType();
  assert(IVTy->getIntegerBitWidth() <= 64 &&
         "Max supported tripcount bitwidth is 64 bits");
  Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(Ctx)
                                                        : Type::getInt64Ty(Ctx);
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Constant *Zero = ConstantInt::get(InternalIVTy, 0);
  Constant *One = ConstantInt::get(InternalIVTy, 1);

  // Declare useful OpenMP runtime functions.
  FunctionCallee StaticInit =
      getKmpcForStaticInitForType(InternalIVTy, M, *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.restoreIP(AllocaIP);
  Builder.SetCurrentDebugLocation(DL);
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound =
      Builder.CreateAlloca(InternalIVTy, nullptr, "p.lowerbound");
  Value *PUpperBound =
      Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");

  // Set up the source location value for the OpenMP runtime.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  // TODO: Detect overflow in ubsan or max-out with current tripcount.
  Value *CastedChunkSize =
      Builder.CreateZExtOrTrunc(ChunkSize, InternalIVTy, "chunksize");
  Value *CastedTripCount =
      Builder.CreateZExt(OrigTripCount, InternalIVTy, "tripcount");

  Constant *SchedulingType = ConstantInt::get(
      I32Type, static_cast<int>(OMPScheduleType::UnorderedStaticChunked));
  Builder.CreateStore(Zero, PLowerBound);
  Value *OrigUpperBound = Builder.CreateSub(CastedTripCount, One);
  Builder.CreateStore(OrigUpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  // Call the "init" function and update the trip count of the loop with the
  // value it produced.
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadNum = getOrCreateThreadID(SrcLoc);
  Builder.CreateCall(StaticInit,
                     {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
                      /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
                      /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
                      /*pstride=*/PStride, /*incr=*/One,
                      /*chunk=*/CastedChunkSize});

  // Load values written by the "init" function.
  Value *FirstChunkStart =
      Builder.CreateLoad(InternalIVTy, PLowerBound, "omp_firstchunk.lb");
  Value *FirstChunkStop =
      Builder.CreateLoad(InternalIVTy, PUpperBound, "omp_firstchunk.ub");
  Value *FirstChunkEnd = Builder.CreateAdd(FirstChunkStop, One);
  Value *ChunkRange =
      Builder.CreateSub(FirstChunkEnd, FirstChunkStart, "omp_chunk.range");
  Value *NextChunkStride =
      Builder.CreateLoad(InternalIVTy, PStride, "omp_dispatch.stride");

  // Create outer "dispatch" loop for enumerating the chunks.
  BasicBlock *DispatchEnter = splitBB(Builder, true);
  Value *DispatchCounter;
  CanonicalLoopInfo *DispatchCLI = createCanonicalLoop(
      {Builder.saveIP(), DL},
      [&](InsertPointTy BodyIP, Value *Counter) { DispatchCounter = Counter; },
      FirstChunkStart, CastedTripCount, NextChunkStride,
      /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
      "dispatch");

  // Remember the BasicBlocks of the dispatch loop we need, then invalidate it
  // so we do not have to preserve the canonical invariant.
  BasicBlock *DispatchBody = DispatchCLI->getBody();
  BasicBlock *DispatchLatch = DispatchCLI->getLatch();
  BasicBlock *DispatchExit = DispatchCLI->getExit();
  BasicBlock *DispatchAfter = DispatchCLI->getAfter();
  DispatchCLI->invalidate();

  // Rewire the original loop to become the chunk loop inside the dispatch loop.
  redirectTo(DispatchAfter, CLI->getAfter(), DL);
  redirectTo(CLI->getExit(), DispatchLatch, DL);
  redirectTo(DispatchBody, DispatchEnter, DL);
4171
4172
// Prepare the prolog of the chunk loop.
4173
Builder.restoreIP(CLI->getPreheaderIP());
4174
Builder.SetCurrentDebugLocation(DL);
4175
4176
// Compute the number of iterations of the chunk loop.
4177
Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
4178
Value *ChunkEnd = Builder.CreateAdd(DispatchCounter, ChunkRange);
4179
Value *IsLastChunk =
4180
Builder.CreateICmpUGE(ChunkEnd, CastedTripCount, "omp_chunk.is_last");
4181
Value *CountUntilOrigTripCount =
4182
Builder.CreateSub(CastedTripCount, DispatchCounter);
4183
Value *ChunkTripCount = Builder.CreateSelect(
4184
IsLastChunk, CountUntilOrigTripCount, ChunkRange, "omp_chunk.tripcount");
4185
Value *BackcastedChunkTC =
4186
Builder.CreateTrunc(ChunkTripCount, IVTy, "omp_chunk.tripcount.trunc");
4187
CLI->setTripCount(BackcastedChunkTC);
4188
4189
// Update all uses of the induction variable except the one in the condition
4190
// block that compares it with the actual upper bound, and the increment in
4191
// the latch block.
4192
Value *BackcastedDispatchCounter =
4193
Builder.CreateTrunc(DispatchCounter, IVTy, "omp_dispatch.iv.trunc");
4194
CLI->mapIndVar([&](Instruction *) -> Value * {
4195
Builder.restoreIP(CLI->getBodyIP());
4196
return Builder.CreateAdd(IV, BackcastedDispatchCounter);
4197
});
4198
4199
// In the "exit" block, call the "fini" function.
4200
Builder.SetInsertPoint(DispatchExit, DispatchExit->getFirstInsertionPt());
4201
Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});
4202
4203
// Add the barrier if requested.
4204
if (NeedsBarrier)
4205
createBarrier(LocationDescription(Builder.saveIP(), DL), OMPD_for,
4206
/*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);
4207
4208
#ifndef NDEBUG
4209
// Even though we currently do not support applying additional methods to it,
4210
// the chunk loop should remain a canonical loop.
4211
CLI->assertOK();
4212
#endif
4213
4214
return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
4215
}
4216
4217
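// Illustrative sketch (simplified, not verbatim output) of the loop nest
// produced above for schedule(static, chunksize):
//
//   __kmpc_for_static_init(..., /*sched=*/static_chunked, ..., chunksize);
//   for (dispatch = firstchunk.lb; dispatch < tripcount;
//        dispatch += dispatch.stride) {          // outer "dispatch" loop
//     chunk.tc = min(chunk.range, tripcount - dispatch);
//     for (iv = 0; iv < chunk.tc; ++iv)          // original loop, rewired
//       body(iv + dispatch);
//   }
//   __kmpc_for_static_fini(loc, tid);
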
// Returns an LLVM function to call for executing an OpenMP static worksharing
// for loop depending on `type`. Only i32 and i64 are supported by the runtime.
// Always interpret integers as unsigned similarly to CanonicalLoopInfo.
static FunctionCallee
getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
                            WorksharingLoopType LoopType) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  Module &M = OMPBuilder->M;
  switch (LoopType) {
  case WorksharingLoopType::ForStaticLoop:
    if (Bitwidth == 32)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
    if (Bitwidth == 64)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
    break;
  case WorksharingLoopType::DistributeStaticLoop:
    if (Bitwidth == 32)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
    if (Bitwidth == 64)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
    break;
  case WorksharingLoopType::DistributeForStaticLoop:
    if (Bitwidth == 32)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
    if (Bitwidth == 64)
      return OMPBuilder->getOrCreateRuntimeFunction(
          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
    break;
  }
  if (Bitwidth != 32 && Bitwidth != 64) {
    llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
  }
  llvm_unreachable("Unknown type of OpenMP worksharing loop");
}
// Inserts a call to the proper OpenMP device RTL function that handles
// loop worksharing.
static void createTargetLoopWorkshareCall(
    OpenMPIRBuilder *OMPBuilder, WorksharingLoopType LoopType,
    BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
    Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
  Type *TripCountTy = TripCount->getType();
  Module &M = OMPBuilder->M;
  IRBuilder<> &Builder = OMPBuilder->Builder;
  FunctionCallee RTLFn =
      getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
  SmallVector<Value *, 8> RealArgs;
  RealArgs.push_back(Ident);
  RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr));
  RealArgs.push_back(LoopBodyArg);
  RealArgs.push_back(TripCount);
  if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
    RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
    Builder.CreateCall(RTLFn, RealArgs);
    return;
  }
  FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
      M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
  Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
  Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});

  RealArgs.push_back(
      Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
  RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
  if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
    RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
  }

  Builder.CreateCall(RTLFn, RealArgs);
}
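// Illustrative sketch: for a 32-bit trip count and a plain worksharing loop,
// the call emitted above is morally equivalent to (argument names are
// assumptions for illustration, not the exact IR):
//
//   __kmpc_for_static_loop_4u(ident, (void *)outlined_body, body_arg,
//                             tripcount, omp_get_num_threads(), /*chunk=*/0);
//
// i.e. the device RTL, not emitted IR, iterates the loop and invokes the
// outlined body once per owned iteration.
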
static void
workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
                            CanonicalLoopInfo *CLI, Value *Ident,
                            Function &OutlinedFn, Type *ParallelTaskPtr,
                            const SmallVector<Instruction *, 4> &ToBeDeleted,
                            WorksharingLoopType LoopType) {
  IRBuilder<> &Builder = OMPIRBuilder->Builder;
  BasicBlock *Preheader = CLI->getPreheader();
  Value *TripCount = CLI->getTripCount();

  // After loop body outlining, the loop body contains only the setup of the
  // loop body argument structure and the call to the outlined loop body
  // function. First, move the setup of the loop body arguments into the
  // loop preheader.
  Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
                    CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));

  // The next step is to remove the whole loop; we do not need it anymore.
  // That's why we make an unconditional branch from the loop preheader to the
  // loop exit block.
  Builder.restoreIP({Preheader, Preheader->end()});
  Preheader->getTerminator()->eraseFromParent();
  Builder.CreateBr(CLI->getExit());

  // Delete the dead loop blocks.
  OpenMPIRBuilder::OutlineInfo CleanUpInfo;
  SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
  SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
  CleanUpInfo.EntryBB = CLI->getHeader();
  CleanUpInfo.ExitBB = CLI->getExit();
  CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
  DeleteDeadBlocks(BlocksToBeRemoved);

  // Find the instruction which corresponds to the loop body argument structure
  // and remove the call to the loop body function.
  Value *LoopBodyArg;
  User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
  assert(OutlinedFnUser &&
         "Expected unique undroppable user of outlined function");
  CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(OutlinedFnUser);
  assert(OutlinedFnCallInstruction && "Expected outlined function call");
  assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
         "Expected outlined function call to be located in loop preheader");
  // Check in case no argument structure has been passed.
  if (OutlinedFnCallInstruction->arg_size() > 1)
    LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(1);
  else
    LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
  OutlinedFnCallInstruction->eraseFromParent();

  createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
                                LoopBodyArg, ParallelTaskPtr, TripCount,
                                OutlinedFn);

  for (auto &ToBeDeletedItem : ToBeDeleted)
    ToBeDeletedItem->eraseFromParent();
  CLI->invalidate();
}
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
                                          InsertPointTy AllocaIP,
                                          WorksharingLoopType LoopType) {
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  OutlineInfo OI;
  OI.OuterAllocaBB = CLI->getPreheader();
  Function *OuterFn = CLI->getPreheader()->getParent();

  // Instructions which need to be deleted at the end of code generation.
  SmallVector<Instruction *, 4> ToBeDeleted;

  OI.OuterAllocaBB = AllocaIP.getBlock();

  // Mark the loop body as the region which needs to be extracted.
  OI.EntryBB = CLI->getBody();
  OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
                                               "omp.prelatch", true);

  // Prepare the loop body for extraction.
  Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});

  // Insert a new loop counter variable which will be used only in the loop
  // body.
  AllocaInst *NewLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
  Instruction *NewLoopCntLoad =
      Builder.CreateLoad(CLI->getIndVarType(), NewLoopCnt);
  // The new loop counter instructions are redundant in the loop preheader once
  // code generation for the workshare loop is finished. That's why we mark
  // them as ready for deletion.
  ToBeDeleted.push_back(NewLoopCntLoad);
  ToBeDeleted.push_back(NewLoopCnt);

  // Analyse the loop body region. Find all input variables which are used
  // inside the loop body region.
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  OI.collectBlocks(ParallelRegionBlockSet, Blocks);
  SmallVector<BasicBlock *, 32> BlocksT(ParallelRegionBlockSet.begin(),
                                        ParallelRegionBlockSet.end());

  CodeExtractorAnalysisCache CEAC(*OuterFn);
  CodeExtractor Extractor(Blocks,
                          /* DominatorTree */ nullptr,
                          /* AggregateArgs */ true,
                          /* BlockFrequencyInfo */ nullptr,
                          /* BranchProbabilityInfo */ nullptr,
                          /* AssumptionCache */ nullptr,
                          /* AllowVarArgs */ true,
                          /* AllowAlloca */ true,
                          /* AllocationBlock */ CLI->getPreheader(),
                          /* Suffix */ ".omp_wsloop",
                          /* AggrArgsIn0AddrSpace */ true);

  BasicBlock *CommonExit = nullptr;
  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;

  // Find allocas outside the loop body region which are used inside the loop
  // body.
  Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);

  // We need to model the loop body region as the function f(cnt, loop_arg).
  // That's why we replace the loop induction variable with the new counter,
  // which will be one of the loop body function's arguments.
  SmallVector<User *> Users(CLI->getIndVar()->user_begin(),
                            CLI->getIndVar()->user_end());
  for (auto Use : Users) {
    if (Instruction *Inst = dyn_cast<Instruction>(Use)) {
      if (ParallelRegionBlockSet.count(Inst->getParent())) {
        Inst->replaceUsesOfWith(CLI->getIndVar(), NewLoopCntLoad);
      }
    }
  }
  // Make sure that the loop counter variable is not merged into the loop body
  // function's argument structure and is passed as a separate variable.
  OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);

  // The PostOutline callback is invoked when the loop body function has been
  // outlined and the loop body replaced by a call to it. We need to add a
  // call to the OpenMP device RTL inside the loop preheader; the device RTL
  // function will handle the loop control logic.
  //
  OI.PostOutlineCB = [=, ToBeDeletedVec =
                             std::move(ToBeDeleted)](Function &OutlinedFn) {
    workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ParallelTaskPtr,
                                ToBeDeletedVec, LoopType);
  };
  addOutlineInfo(std::move(OI));
  return CLI->getAfterIP();
}
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
    bool HasSimdModifier, bool HasMonotonicModifier,
    bool HasNonmonotonicModifier, bool HasOrderedClause,
    WorksharingLoopType LoopType) {
  if (Config.isTargetDevice())
    return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
  OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
      SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
      HasNonmonotonicModifier, HasOrderedClause);

  bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
                   OMPScheduleType::ModifierOrdered;
  switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
  case OMPScheduleType::BaseStatic:
    assert(!ChunkSize && "No chunk size with static-chunked schedule");
    if (IsOrdered)
      return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                       NeedsBarrier, ChunkSize);
    // FIXME: Monotonicity ignored?
    return applyStaticWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier);

  case OMPScheduleType::BaseStaticChunked:
    if (IsOrdered)
      return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                       NeedsBarrier, ChunkSize);
    // FIXME: Monotonicity ignored?
    return applyStaticChunkedWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier,
                                           ChunkSize);

  case OMPScheduleType::BaseRuntime:
  case OMPScheduleType::BaseAuto:
  case OMPScheduleType::BaseGreedy:
  case OMPScheduleType::BaseBalanced:
  case OMPScheduleType::BaseSteal:
  case OMPScheduleType::BaseGuidedSimd:
  case OMPScheduleType::BaseRuntimeSimd:
    assert(!ChunkSize &&
           "schedule type does not support user-defined chunk sizes");
    [[fallthrough]];
  case OMPScheduleType::BaseDynamicChunked:
  case OMPScheduleType::BaseGuidedChunked:
  case OMPScheduleType::BaseGuidedIterativeChunked:
  case OMPScheduleType::BaseGuidedAnalyticalChunked:
  case OMPScheduleType::BaseStaticBalancedChunked:
    return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
                                     NeedsBarrier, ChunkSize);

  default:
    llvm_unreachable("Unknown/unimplemented schedule kind");
  }
}
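// Informal summary of the host-side dispatch above (a reading aid, not part
// of the algorithm):
//   static (no chunk)                  -> applyStaticWorkshareLoop
//   static, chunksize                  -> applyStaticChunkedWorkshareLoop
//   dynamic/guided/auto/runtime, and
//   any schedule with ordered(...)     -> applyDynamicWorkshareLoop
// An "ordered" modifier forces the dynamic lowering even for static
// schedules because the runtime must sequence the chunks.
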
/// Returns an LLVM function to call for initializing loop bounds using OpenMP
/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
/// the runtime. Always interpret integers as unsigned similarly to
/// CanonicalLoopInfo.
static FunctionCallee
getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}

/// Returns an LLVM function to call for updating the next loop using OpenMP
/// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
/// the runtime. Always interpret integers as unsigned similarly to
/// CanonicalLoopInfo.
static FunctionCallee
getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}

/// Returns an LLVM function to call for finalizing the dynamic loop, depending
/// on `type`. Only i32 and i64 are supported by the runtime. Always interpret
/// integers as unsigned similarly to CanonicalLoopInfo.
static FunctionCallee
getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
  unsigned Bitwidth = Ty->getIntegerBitWidth();
  if (Bitwidth == 32)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
  if (Bitwidth == 64)
    return OMPBuilder.getOrCreateRuntimeFunction(
        M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
  llvm_unreachable("unknown OpenMP loop iterator bitwidth");
}
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyDynamicWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    OMPScheduleType SchedType, bool NeedsBarrier, Value *Chunk) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");
  assert(isValidWorkshareLoopScheduleType(SchedType) &&
         "Require valid schedule type");

  bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
                 OMPScheduleType::ModifierOrdered;

  // Set up the source location value for the OpenMP runtime.
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee DynamicInit = getKmpcForDynamicInitForType(IVTy, M, *this);
  FunctionCallee DynamicNext = getKmpcForDynamicNextForType(IVTy, M, *this);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.SetInsertPoint(AllocaIP.getBlock()->getFirstNonPHIOrDbgOrAlloca());
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  BasicBlock *PreHeader = CLI->getPreheader();
  Builder.SetInsertPoint(PreHeader->getTerminator());
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(One, PLowerBound);
  Value *UpperBound = CLI->getTripCount();
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  BasicBlock *Header = CLI->getHeader();
  BasicBlock *Exit = CLI->getExit();
  BasicBlock *Cond = CLI->getCond();
  BasicBlock *Latch = CLI->getLatch();
  InsertPointTy AfterIP = CLI->getAfterIP();

  // The CLI will be "broken" in the code below, as the loop is no longer
  // a valid canonical loop.

  if (!Chunk)
    Chunk = One;

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  Constant *SchedulingType =
      ConstantInt::get(I32Type, static_cast<int>(SchedType));

  // Call the "init" function.
  Builder.CreateCall(DynamicInit,
                     {SrcLoc, ThreadNum, SchedulingType, /* LowerBound */ One,
                      UpperBound, /* step */ One, Chunk});

  // An outer loop around the existing one.
  BasicBlock *OuterCond = BasicBlock::Create(
      PreHeader->getContext(), Twine(PreHeader->getName()) + ".outer.cond",
      PreHeader->getParent());
  // This needs to be 32-bit always, so can't use the IVTy Zero above.
  Builder.SetInsertPoint(OuterCond, OuterCond->getFirstInsertionPt());
  Value *Res =
      Builder.CreateCall(DynamicNext, {SrcLoc, ThreadNum, PLastIter,
                                       PLowerBound, PUpperBound, PStride});
  Constant *Zero32 = ConstantInt::get(I32Type, 0);
  Value *MoreWork = Builder.CreateCmp(CmpInst::ICMP_NE, Res, Zero32);
  Value *LowerBound =
      Builder.CreateSub(Builder.CreateLoad(IVTy, PLowerBound), One, "lb");
  Builder.CreateCondBr(MoreWork, Header, Exit);

  // Change the PHI node in the loop header to use the outer cond rather than
  // the preheader, and set the IV to the LowerBound.
  Instruction *Phi = &Header->front();
  auto *PI = cast<PHINode>(Phi);
  PI->setIncomingBlock(0, OuterCond);
  PI->setIncomingValue(0, LowerBound);

  // Then set the pre-header to jump to the OuterCond.
  Instruction *Term = PreHeader->getTerminator();
  auto *Br = cast<BranchInst>(Term);
  Br->setSuccessor(0, OuterCond);

  // Modify the inner condition:
  // * Use the UpperBound returned from the DynamicNext call.
  // * Jump to the outer loop when done with one of the inner loops.
  Builder.SetInsertPoint(Cond, Cond->getFirstInsertionPt());
  UpperBound = Builder.CreateLoad(IVTy, PUpperBound, "ub");
  Instruction *Comp = &*Builder.GetInsertPoint();
  auto *CI = cast<CmpInst>(Comp);
  CI->setOperand(1, UpperBound);
  // Redirect the inner exit to branch to the outer condition.
  Instruction *Branch = &Cond->back();
  auto *BI = cast<BranchInst>(Branch);
  assert(BI->getSuccessor(1) == Exit);
  BI->setSuccessor(1, OuterCond);

  // Call the "fini" function if "ordered" is present in the wsloop directive.
  if (Ordered) {
    Builder.SetInsertPoint(&Latch->back());
    FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(IVTy, M, *this);
    Builder.CreateCall(DynamicFini, {SrcLoc, ThreadNum});
  }

  // Add the barrier if requested.
  if (NeedsBarrier) {
    Builder.SetInsertPoint(&Exit->back());
    createBarrier(LocationDescription(Builder.saveIP(), DL),
                  omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                  /* CheckCancelFlag */ false);
  }

  CLI->invalidate();
  return AfterIP;
}
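// Illustrative sketch (simplified) of the control flow built above. The
// runtime hands out chunks in 1-based inclusive [lb, ub] form, which is why
// the IV is rebased by lb - 1:
//
//   __kmpc_dispatch_init(loc, tid, sched, /*lb=*/1, /*ub=*/tc, /*st=*/1,
//                        chunk);
//   while (__kmpc_dispatch_next(loc, tid, &last, &lb, &ub, &stride)) {
//     for (iv = lb - 1; iv < ub; ++iv)   // inner loop reuses the CLI blocks
//       body(iv);
//     // with "ordered": __kmpc_dispatch_fini(loc, tid) in the latch
//   }
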
/// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
/// after this \p OldTarget will be orphaned.
static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
                                      BasicBlock *NewTarget, DebugLoc DL) {
  for (BasicBlock *Pred : make_early_inc_range(predecessors(OldTarget)))
    redirectTo(Pred, NewTarget, DL);
}

/// Determine which blocks in \p BBs are reachable from outside and remove the
/// ones that are not reachable from the function.
static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
  SmallPtrSet<BasicBlock *, 6> BBsToErase{BBs.begin(), BBs.end()};
  auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
    for (Use &U : BB->uses()) {
      auto *UseInst = dyn_cast<Instruction>(U.getUser());
      if (!UseInst)
        continue;
      if (BBsToErase.count(UseInst->getParent()))
        continue;
      return true;
    }
    return false;
  };

  while (BBsToErase.remove_if(HasRemainingUses)) {
    // Try again if anything was removed.
  }

  SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
  DeleteDeadBlocks(BBVec);
}
CanonicalLoopInfo *
OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
                               InsertPointTy ComputeIP) {
  assert(Loops.size() >= 1 && "At least one loop required");
  size_t NumLoops = Loops.size();

  // Nothing to do if there is already just one loop.
  if (NumLoops == 1)
    return Loops.front();

  CanonicalLoopInfo *Outermost = Loops.front();
  CanonicalLoopInfo *Innermost = Loops.back();
  BasicBlock *OrigPreheader = Outermost->getPreheader();
  BasicBlock *OrigAfter = Outermost->getAfter();
  Function *F = OrigPreheader->getParent();

  // Loop control blocks that may become orphaned later.
  SmallVector<BasicBlock *, 12> OldControlBBs;
  OldControlBBs.reserve(6 * Loops.size());
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(OldControlBBs);

  // Set up the IRBuilder for inserting the trip count computation.
  Builder.SetCurrentDebugLocation(DL);
  if (ComputeIP.isSet())
    Builder.restoreIP(ComputeIP);
  else
    Builder.restoreIP(Outermost->getPreheaderIP());

  // Derive the collapsed loop's trip count.
  // TODO: Find common/largest indvar type.
  Value *CollapsedTripCount = nullptr;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() &&
           "All loops to collapse must be valid canonical loops");
    Value *OrigTripCount = L->getTripCount();
    if (!CollapsedTripCount) {
      CollapsedTripCount = OrigTripCount;
      continue;
    }

    // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
    CollapsedTripCount = Builder.CreateMul(CollapsedTripCount, OrigTripCount,
                                           {}, /*HasNUW=*/true);
  }

  // Create the collapsed loop control flow.
  CanonicalLoopInfo *Result =
      createLoopSkeleton(DL, CollapsedTripCount, F,
                         OrigPreheader->getNextNode(), OrigAfter, "collapsed");

  // Build the collapsed loop body code.
  // Start with deriving the input loop induction variables from the collapsed
  // one, using a divmod scheme. To preserve the original loops' order, the
  // innermost loop uses the least significant bits.
  Builder.restoreIP(Result->getBodyIP());

  Value *Leftover = Result->getIndVar();
  SmallVector<Value *> NewIndVars;
  NewIndVars.resize(NumLoops);
  for (int i = NumLoops - 1; i >= 1; --i) {
    Value *OrigTripCount = Loops[i]->getTripCount();

    Value *NewIndVar = Builder.CreateURem(Leftover, OrigTripCount);
    NewIndVars[i] = NewIndVar;

    Leftover = Builder.CreateUDiv(Leftover, OrigTripCount);
  }
  // The outermost loop gets all the remaining bits.
  NewIndVars[0] = Leftover;

  // Construct the loop body control flow.
  // We progressively construct the branch structure following the direction of
  // the control flow: from the leading in-between code, through the loop nest
  // body and the trailing in-between code, to rejoining the collapsed loop's
  // latch.
  // ContinueBlock and ContinuePred keep track of the source(s) of the next
  // edge. If ContinueBlock is set, continue with that block. If ContinuePred,
  // use its predecessors as sources.
  BasicBlock *ContinueBlock = Result->getBody();
  BasicBlock *ContinuePred = nullptr;
  auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
                                                          BasicBlock *NextSrc) {
    if (ContinueBlock)
      redirectTo(ContinueBlock, Dest, DL);
    else
      redirectAllPredecessorsTo(ContinuePred, Dest, DL);

    ContinueBlock = nullptr;
    ContinuePred = NextSrc;
  };

  // The code before the nested loop of each level.
  // Because we are sinking it into the nest, it will be executed more often
  // than in the original loop. More sophisticated schemes could keep track of
  // what the in-between code is and instantiate it only once per thread.
  for (size_t i = 0; i < NumLoops - 1; ++i)
    ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());

  // Connect the loop nest body.
  ContinueWith(Innermost->getBody(), Innermost->getLatch());

  // The code after the nested loop at each level.
  for (size_t i = NumLoops - 1; i > 0; --i)
    ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());

  // Connect the finished loop to the collapsed loop latch.
  ContinueWith(Result->getLatch(), nullptr);

  // Replace the input loops with the new collapsed loop.
  redirectTo(Outermost->getPreheader(), Result->getPreheader(), DL);
  redirectTo(Result->getAfter(), Outermost->getAfter(), DL);

  // Replace the input loop indvars with the derived ones.
  for (size_t i = 0; i < NumLoops; ++i)
    Loops[i]->getIndVar()->replaceAllUsesWith(NewIndVars[i]);

  // Remove unused parts of the input loops.
  removeUnusedBlocksFromParent(OldControlBBs);

  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  Result->assertOK();
#endif
  return Result;
}
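// Illustrative example: collapsing a 2-deep nest with trip counts tc0 and tc1
// yields one loop with trip count tc0 * tc1 whose body recovers the original
// induction variables via the divmod scheme above:
//
//   for (iv = 0; iv < tc0 * tc1; ++iv) {
//     i = iv / tc1;   // outermost indvar: remaining (most significant) bits
//     j = iv % tc1;   // innermost indvar: least significant bits
//     body(i, j);
//   }
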
std::vector<CanonicalLoopInfo *>
OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
                           ArrayRef<Value *> TileSizes) {
  assert(TileSizes.size() == Loops.size() &&
         "Must pass as many tile sizes as there are loops");
  int NumLoops = Loops.size();
  assert(NumLoops >= 1 && "At least one loop to tile required");

  CanonicalLoopInfo *OutermostLoop = Loops.front();
  CanonicalLoopInfo *InnermostLoop = Loops.back();
  Function *F = OutermostLoop->getBody()->getParent();
  BasicBlock *InnerEnter = InnermostLoop->getBody();
  BasicBlock *InnerLatch = InnermostLoop->getLatch();

  // Loop control blocks that may become orphaned later.
  SmallVector<BasicBlock *, 12> OldControlBBs;
  OldControlBBs.reserve(6 * Loops.size());
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(OldControlBBs);

  // Collect the original trip counts and induction variables to be accessible
  // by index. Also, the structure of the original loops is not preserved
  // during the construction of the tiled loops, so do it before we scavenge
  // the BBs of any original CanonicalLoopInfo.
  SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() && "All input loops must be valid canonical loops");
    OrigTripCounts.push_back(L->getTripCount());
    OrigIndVars.push_back(L->getIndVar());
  }

  // Collect the code between loop headers. These may contain SSA definitions
  // that are used in the loop nest body. To be usable within the innermost
  // body, these BasicBlocks will be sunk into the loop nest body. That is,
  // these instructions may be executed more often than before the tiling.
  // TODO: It would be sufficient to only sink them into the body of the
  // corresponding tile loop.
  SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
  for (int i = 0; i < NumLoops - 1; ++i) {
    CanonicalLoopInfo *Surrounding = Loops[i];
    CanonicalLoopInfo *Nested = Loops[i + 1];

    BasicBlock *EnterBB = Surrounding->getBody();
    BasicBlock *ExitBB = Nested->getHeader();
    InbetweenCode.emplace_back(EnterBB, ExitBB);
  }

  // Compute the trip counts of the floor loops.
  Builder.SetCurrentDebugLocation(DL);
  Builder.restoreIP(OutermostLoop->getPreheaderIP());
  SmallVector<Value *, 4> FloorCount, FloorRems;
  for (int i = 0; i < NumLoops; ++i) {
    Value *TileSize = TileSizes[i];
    Value *OrigTripCount = OrigTripCounts[i];
    Type *IVType = OrigTripCount->getType();

    Value *FloorTripCount = Builder.CreateUDiv(OrigTripCount, TileSize);
    Value *FloorTripRem = Builder.CreateURem(OrigTripCount, TileSize);

    // 0 if the tilesize divides the tripcount, 1 otherwise.
    // 1 means we need an additional iteration for a partial tile.
    //
    // Unfortunately we cannot just use the roundup formula
    //   (tripcount + tilesize - 1) / tilesize
    // because the summation might overflow. We do not want to introduce
    // undefined behavior when the untiled loop nest did not.
    Value *FloorTripOverflow =
        Builder.CreateICmpNE(FloorTripRem, ConstantInt::get(IVType, 0));

    FloorTripOverflow = Builder.CreateZExt(FloorTripOverflow, IVType);
    FloorTripCount =
        Builder.CreateAdd(FloorTripCount, FloorTripOverflow,
                          "omp_floor" + Twine(i) + ".tripcount", true);

    // Remember some values for later use.
    FloorCount.push_back(FloorTripCount);
    FloorRems.push_back(FloorTripRem);
  }

  // Generate the new loop nest, from the outermost to the innermost.
  std::vector<CanonicalLoopInfo *> Result;
  Result.reserve(NumLoops * 2);

  // The basic block of the surrounding loop that enters the generated loop
  // nest.
  BasicBlock *Enter = OutermostLoop->getPreheader();

  // The basic block of the surrounding loop where the inner code should
  // continue.
  BasicBlock *Continue = OutermostLoop->getAfter();

  // Where the next loop basic block should be inserted.
  BasicBlock *OutroInsertBefore = InnermostLoop->getExit();

  auto EmbeddNewLoop =
      [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
          Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
    CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
        DL, TripCount, F, InnerEnter, OutroInsertBefore, Name);
    redirectTo(Enter, EmbeddedLoop->getPreheader(), DL);
    redirectTo(EmbeddedLoop->getAfter(), Continue, DL);

    // Set up the position where the next embedded loop connects to this loop.
    Enter = EmbeddedLoop->getBody();
    Continue = EmbeddedLoop->getLatch();
    OutroInsertBefore = EmbeddedLoop->getLatch();
    return EmbeddedLoop;
  };

  auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
                                                  const Twine &NameBase) {
    for (auto P : enumerate(TripCounts)) {
      CanonicalLoopInfo *EmbeddedLoop =
          EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
      Result.push_back(EmbeddedLoop);
    }
  };

  EmbeddNewLoops(FloorCount, "floor");

  // Within the innermost floor loop, emit the code that computes the tile
  // sizes.
  Builder.SetInsertPoint(Enter->getTerminator());
  SmallVector<Value *, 4> TileCounts;
  for (int i = 0; i < NumLoops; ++i) {
    CanonicalLoopInfo *FloorLoop = Result[i];
    Value *TileSize = TileSizes[i];

    Value *FloorIsEpilogue =
        Builder.CreateICmpEQ(FloorLoop->getIndVar(), FloorCount[i]);
    Value *TileTripCount =
        Builder.CreateSelect(FloorIsEpilogue, FloorRems[i], TileSize);

    TileCounts.push_back(TileTripCount);
  }

  // Create the tile loops.
  EmbeddNewLoops(TileCounts, "tile");

  // Insert the in-between code into the body.
  BasicBlock *BodyEnter = Enter;
  BasicBlock *BodyEntered = nullptr;
  for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
    BasicBlock *EnterBB = P.first;
    BasicBlock *ExitBB = P.second;

    if (BodyEnter)
      redirectTo(BodyEnter, EnterBB, DL);
    else
      redirectAllPredecessorsTo(BodyEntered, EnterBB, DL);

    BodyEnter = nullptr;
    BodyEntered = ExitBB;
  }

  // Append the original loop nest body into the generated loop nest body.
  if (BodyEnter)
    redirectTo(BodyEnter, InnerEnter, DL);
  else
    redirectAllPredecessorsTo(BodyEntered, InnerEnter, DL);
  redirectAllPredecessorsTo(InnerLatch, Continue, DL);

  // Replace the original induction variable with an induction variable
  // computed from the tile and floor induction variables.
  Builder.restoreIP(Result.back()->getBodyIP());
  for (int i = 0; i < NumLoops; ++i) {
    CanonicalLoopInfo *FloorLoop = Result[i];
    CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
    Value *OrigIndVar = OrigIndVars[i];
    Value *Size = TileSizes[i];

    Value *Scale =
        Builder.CreateMul(Size, FloorLoop->getIndVar(), {}, /*HasNUW=*/true);
    Value *Shift =
        Builder.CreateAdd(Scale, TileLoop->getIndVar(), {}, /*HasNUW=*/true);
    OrigIndVar->replaceAllUsesWith(Shift);
  }

  // Remove unused parts of the original loops.
  removeUnusedBlocksFromParent(OldControlBBs);

  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  for (CanonicalLoopInfo *GenL : Result)
    GenL->assertOK();
#endif
  return Result;
}
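// Illustrative example: tiling a single loop of trip count tc with tile size
// ts produces a "floor" loop over the tiles and a "tile" loop within each
// tile; the original induction variable is recomputed as Scale + Shift above
// (is_epilogue() abstracts the partial-tile select):
//
//   floor.tc = tc / ts + (tc % ts != 0);
//   for (f = 0; f < floor.tc; ++f) {
//     tile.tc = is_epilogue(f) ? tc % ts : ts;   // partial last tile
//     for (t = 0; t < tile.tc; ++t)
//       body(f * ts + t);
//   }
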
/// Attach metadata \p Properties to the basic block described by \p BB. If the
/// basic block already has metadata, the basic block properties are appended.
static void addBasicBlockMetadata(BasicBlock *BB,
                                  ArrayRef<Metadata *> Properties) {
  // Nothing to do if no property to attach.
  if (Properties.empty())
    return;

  LLVMContext &Ctx = BB->getContext();
  SmallVector<Metadata *> NewProperties;
  NewProperties.push_back(nullptr);

  // If the basic block already has metadata, prepend it to the new metadata.
  MDNode *Existing = BB->getTerminator()->getMetadata(LLVMContext::MD_loop);
  if (Existing)
    append_range(NewProperties, drop_begin(Existing->operands(), 1));

  append_range(NewProperties, Properties);
  MDNode *BasicBlockID = MDNode::getDistinct(Ctx, NewProperties);
  BasicBlockID->replaceOperandWith(0, BasicBlockID);

  BB->getTerminator()->setMetadata(LLVMContext::MD_loop, BasicBlockID);
}

/// Attach loop metadata \p Properties to the loop described by \p Loop. If the
/// loop already has metadata, the loop properties are appended.
static void addLoopMetadata(CanonicalLoopInfo *Loop,
                            ArrayRef<Metadata *> Properties) {
  assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");

  // Attach metadata to the loop's latch.
  BasicBlock *Latch = Loop->getLatch();
  assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
  addBasicBlockMetadata(Latch, Properties);
}
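// Illustrative example (node numbers arbitrary) of the self-referential loop
// metadata shape these helpers produce on the latch terminator:
//
//   br label %header, !llvm.loop !0
//   !0 = distinct !{!0, !1}                ; operand 0 points back at !0
//   !1 = !{!"llvm.loop.unroll.enable"}
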
/// Attach llvm.access.group metadata to the memref instructions of \p Block.
static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
                            LoopInfo &LI) {
  for (Instruction &I : *Block) {
    if (I.mayReadOrWriteMemory()) {
      // TODO: This instruction may already have an access group from other
      // pragmas, e.g. #pragma clang loop vectorize. Append so that the
      // existing metadata is not overwritten.
      I.setMetadata(LLVMContext::MD_access_group, AccessGroup);
    }
  }
}

void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
  LLVMContext &Ctx = Builder.getContext();
  addLoopMetadata(
      Loop, {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
             MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.full"))});
}

void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
  LLVMContext &Ctx = Builder.getContext();
  addLoopMetadata(
      Loop, {
                MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
            });
}
void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
                                      Value *IfCond, ValueToValueMapTy &VMap,
                                      const Twine &NamePrefix) {
  Function *F = CanonicalLoop->getFunction();

  // Define where the if branch should be inserted.
  Instruction *SplitBefore;
  if (Instruction::classof(IfCond)) {
    SplitBefore = dyn_cast<Instruction>(IfCond);
  } else {
    SplitBefore = CanonicalLoop->getPreheader()->getTerminator();
  }

  // TODO: We should not rely on the pass manager. Currently we use the pass
  // manager only for getting the llvm::Loop which corresponds to the given
  // CanonicalLoopInfo object. We should have a method which returns all blocks
  // between CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter().
  FunctionAnalysisManager FAM;
  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
  FAM.registerPass([]() { return LoopAnalysis(); });
  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });

  // Get the loop which needs to be cloned.
  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(*F, FAM);
  Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());

  // Create additional blocks for the if statement.
  BasicBlock *Head = SplitBefore->getParent();
  Instruction *HeadOldTerm = Head->getTerminator();
  llvm::LLVMContext &C = Head->getContext();
  llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
      C, NamePrefix + ".if.then", Head->getParent(), Head->getNextNode());
  llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
      C, NamePrefix + ".if.else", Head->getParent(), CanonicalLoop->getExit());

  // Create the if condition branch.
  Builder.SetInsertPoint(HeadOldTerm);
  Instruction *BrInstr =
      Builder.CreateCondBr(IfCond, ThenBlock, /*ifFalse*/ ElseBlock);
  InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
  // The then block contains the branch to the OMP loop which needs to be
  // vectorized.
  spliceBB(IP, ThenBlock, false);
  ThenBlock->replaceSuccessorsPhiUsesWith(Head, ThenBlock);

  Builder.SetInsertPoint(ElseBlock);

  // Clone the loop for the else branch.
  SmallVector<BasicBlock *, 8> NewBlocks;

  VMap[CanonicalLoop->getPreheader()] = ElseBlock;
  for (BasicBlock *Block : L->getBlocks()) {
    BasicBlock *NewBB = CloneBasicBlock(Block, VMap, "", F);
    NewBB->moveBefore(CanonicalLoop->getExit());
    VMap[Block] = NewBB;
    NewBlocks.push_back(NewBB);
  }
  remapInstructionsInBlocks(NewBlocks, VMap);
  Builder.CreateBr(NewBlocks.front());
}
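// Illustrative sketch of the CFG shape produced by the if-versioning above:
//
//              head
//             /    \
//   <prefix>.if.then  <prefix>.if.else
//            |             |
//     original loop   cloned loop (via VMap)
//             \           /
//               loop exit
//
// The caller is expected to attach vectorization-enabling metadata to the
// original version and disabling metadata to the clone (see applySimd below).
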
unsigned
OpenMPIRBuilder::getOpenMPDefaultSimdAlign(const Triple &TargetTriple,
                                           const StringMap<bool> &Features) {
  if (TargetTriple.isX86()) {
    if (Features.lookup("avx512f"))
      return 512;
    else if (Features.lookup("avx"))
      return 256;
    return 128;
  }
  if (TargetTriple.isPPC())
    return 128;
  if (TargetTriple.isWasm())
    return 128;
  return 0;
}
void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
                                MapVector<Value *, Value *> AlignedVars,
                                Value *IfCond, OrderKind Order,
                                ConstantInt *Simdlen, ConstantInt *Safelen) {
  LLVMContext &Ctx = Builder.getContext();

  Function *F = CanonicalLoop->getFunction();

  // TODO: We should not rely on the pass manager. Currently we use the pass
  // manager only for getting the llvm::Loop which corresponds to the given
  // CanonicalLoopInfo object. We should have a method which returns all blocks
  // between CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter().
  FunctionAnalysisManager FAM;
  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
  FAM.registerPass([]() { return LoopAnalysis(); });
  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });

  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(*F, FAM);

  Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
  if (AlignedVars.size()) {
    InsertPointTy IP = Builder.saveIP();
    Builder.SetInsertPoint(CanonicalLoop->getPreheader()->getTerminator());
    for (auto &AlignedItem : AlignedVars) {
      Value *AlignedPtr = AlignedItem.first;
      Value *Alignment = AlignedItem.second;
      Builder.CreateAlignmentAssumption(F->getDataLayout(),
                                        AlignedPtr, Alignment);
    }
    Builder.restoreIP(IP);
  }

  if (IfCond) {
    ValueToValueMapTy VMap;
    createIfVersion(CanonicalLoop, IfCond, VMap, "simd");
    // Add metadata to the cloned loop which disables vectorization.
    Value *MappedLatch = VMap.lookup(CanonicalLoop->getLatch());
    assert(MappedLatch &&
           "Cannot find value which corresponds to original loop latch");
    assert(isa<BasicBlock>(MappedLatch) &&
           "Cannot cast mapped latch block value to BasicBlock");
    BasicBlock *NewLatchBlock = dyn_cast<BasicBlock>(MappedLatch);
    ConstantAsMetadata *BoolConst =
        ConstantAsMetadata::get(ConstantInt::getFalse(Type::getInt1Ty(Ctx)));
    addBasicBlockMetadata(
        NewLatchBlock,
        {MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"),
                           BoolConst})});
  }

  SmallSet<BasicBlock *, 8> Reachable;

  // Get the basic blocks from the loop in which memref instructions
  // can be found.
  // TODO: Generalize getting all blocks inside a CanonicalLoopInfo,
  // preferably without running any passes.
  for (BasicBlock *Block : L->getBlocks()) {
    if (Block == CanonicalLoop->getCond() ||
        Block == CanonicalLoop->getHeader())
      continue;
    Reachable.insert(Block);
  }

  SmallVector<Metadata *> LoopMDList;

  // In the presence of a finite 'safelen', it may be unsafe to mark all
  // the memory instructions parallel, because loop-carried
  // dependences of up to 'safelen' iterations are possible.
  // If the clause order(concurrent) is specified then the memory instructions
  // are marked parallel even if 'safelen' is finite.
  if ((Safelen == nullptr) || (Order == OrderKind::OMP_ORDER_concurrent)) {
    // Add access group metadata to memory-access instructions.
    MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
    for (BasicBlock *BB : Reachable)
      addSimdMetadata(BB, AccessGroup, LI);
    // TODO: If the loop has existing parallel access metadata, we have
    // to combine the two lists.
    LoopMDList.push_back(MDNode::get(
        Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
  }

  // Use the above access group metadata to create loop-level
  // metadata, which should be distinct for each loop.
  ConstantAsMetadata *BoolConst =
      ConstantAsMetadata::get(ConstantInt::getTrue(Type::getInt1Ty(Ctx)));
  LoopMDList.push_back(MDNode::get(
      Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"), BoolConst}));

  if (Simdlen || Safelen) {
    // If both the simdlen and safelen clauses are specified, the value of the
    // simdlen parameter must be less than or equal to the value of the safelen
    // parameter. Therefore, use safelen only in the absence of simdlen.
    ConstantInt *VectorizeWidth = Simdlen == nullptr ? Safelen : Simdlen;
    LoopMDList.push_back(
        MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.width"),
                          ConstantAsMetadata::get(VectorizeWidth)}));
  }

  addLoopMetadata(CanonicalLoop, LoopMDList);
}
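// Illustrative example (simplified; node numbers arbitrary) of the loop
// metadata emitted for '#pragma omp simd simdlen(8)' when the memory
// instructions can be marked parallel:
//
//   !0 = distinct !{!0, !1, !2, !3}
//   !1 = !{!"llvm.loop.parallel_accesses", !4}
//   !2 = !{!"llvm.loop.vectorize.enable", i1 true}
//   !3 = !{!"llvm.loop.vectorize.width", i32 8}
//   !4 = distinct !{}        ; access group placed on memory instructions
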
/// Create the TargetMachine object to query the backend for optimization
/// preferences.
///
/// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
/// e.g. Clang does not pass it to its CodeGen layer and creates it only when
/// needed for the LLVM pass pipeline. We use some default options to avoid
/// having to pass too many settings from the frontend that probably do not
/// matter.
///
/// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
/// method. If we are going to use TargetMachine for more purposes, especially
/// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
/// might be worth requiring front-ends to pass on their TargetMachine, or at
/// least cache it between methods. Note that while frontends such as Clang
/// have just a single main TargetMachine per translation unit, "target-cpu"
/// and "target-features" that determine the TargetMachine are per-function and
/// can be overridden using __attribute__((target("OPTIONS"))).
static std::unique_ptr<TargetMachine>
createTargetMachine(Function *F, CodeGenOptLevel OptLevel) {
  Module *M = F->getParent();

  StringRef CPU = F->getFnAttribute("target-cpu").getValueAsString();
  StringRef Features = F->getFnAttribute("target-features").getValueAsString();
  const std::string &Triple = M->getTargetTriple();

  std::string Error;
  const llvm::Target *TheTarget = TargetRegistry::lookupTarget(Triple, Error);
  if (!TheTarget)
    return {};

  llvm::TargetOptions Options;
  return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
      Triple, CPU, Features, Options, /*RelocModel=*/std::nullopt,
      /*CodeModel=*/std::nullopt, OptLevel));
}
/// Heuristically determine the best-performing unroll factor for \p CLI. This
/// depends on the target processor. We are re-using the same heuristics as the
/// LoopUnrollPass.
static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
  Function *F = CLI->getFunction();

  // Assume the user requests the most aggressive unrolling, even if the rest
  // of the code is optimized using a lower setting.
  CodeGenOptLevel OptLevel = CodeGenOptLevel::Aggressive;
  std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);

  FunctionAnalysisManager FAM;
  FAM.registerPass([]() { return TargetLibraryAnalysis(); });
  FAM.registerPass([]() { return AssumptionAnalysis(); });
  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
  FAM.registerPass([]() { return LoopAnalysis(); });
  FAM.registerPass([]() { return ScalarEvolutionAnalysis(); });
  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
  TargetIRAnalysis TIRA;
  if (TM)
    TIRA = TargetIRAnalysis(
        [&](const Function &F) { return TM->getTargetTransformInfo(F); });
  FAM.registerPass([&]() { return TIRA; });

  TargetIRAnalysis::Result &&TTI = TIRA.run(*F, FAM);
  ScalarEvolutionAnalysis SEA;
  ScalarEvolution &&SE = SEA.run(*F, FAM);
  DominatorTreeAnalysis DTA;
  DominatorTree &&DT = DTA.run(*F, FAM);
  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(*F, FAM);
  AssumptionAnalysis ACT;
  AssumptionCache &&AC = ACT.run(*F, FAM);
  OptimizationRemarkEmitter ORE{F};

  Loop *L = LI.getLoopFor(CLI->getHeader());
  assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");

  TargetTransformInfo::UnrollingPreferences UP =
      gatherUnrollingPreferences(L, SE, TTI,
                                 /*BlockFrequencyInfo=*/nullptr,
                                 /*ProfileSummaryInfo=*/nullptr, ORE,
                                 static_cast<int>(OptLevel),
                                 /*UserThreshold=*/std::nullopt,
                                 /*UserCount=*/std::nullopt,
                                 /*UserAllowPartial=*/true,
                                 /*UserAllowRuntime=*/true,
                                 /*UserUpperBound=*/std::nullopt,
                                 /*UserFullUnrollMaxCount=*/std::nullopt);

  UP.Force = true;

  // Account for additional optimizations taking place before the
  // LoopUnrollPass would unroll the loop.
  UP.Threshold *= UnrollThresholdFactor;
  UP.PartialThreshold *= UnrollThresholdFactor;

  // Use normal unroll factors even if the rest of the code is optimized for
  // size.
  UP.OptSizeThreshold = UP.Threshold;
  UP.PartialOptSizeThreshold = UP.PartialThreshold;

  LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
                    << " Threshold=" << UP.Threshold << "\n"
                    << " PartialThreshold=" << UP.PartialThreshold << "\n"
                    << " OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
                    << " PartialOptSizeThreshold="
                    << UP.PartialOptSizeThreshold << "\n");

  // Disable peeling.
  TargetTransformInfo::PeelingPreferences PP =
      gatherPeelingPreferences(L, SE, TTI,
                               /*UserAllowPeeling=*/false,
                               /*UserAllowProfileBasedPeeling=*/false,
                               /*UnrollingSpecficValues=*/false);

  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, &AC, EphValues);

  // Assume that reads and writes to stack variables can be eliminated by
  // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
  // size.
  for (BasicBlock *BB : L->blocks()) {
    for (Instruction &I : *BB) {
      Value *Ptr;
      if (auto *Load = dyn_cast<LoadInst>(&I)) {
        Ptr = Load->getPointerOperand();
      } else if (auto *Store = dyn_cast<StoreInst>(&I)) {
        Ptr = Store->getPointerOperand();
      } else
        continue;

      Ptr = Ptr->stripPointerCasts();

      if (auto *Alloca = dyn_cast<AllocaInst>(Ptr)) {
        if (Alloca->getParent() == &F->getEntryBlock())
          EphValues.insert(&I);
      }
    }
  }

  UnrollCostEstimator UCE(L, TTI, EphValues, UP.BEInsns);

  // The loop is not unrollable if it contains certain instructions.
  if (!UCE.canUnroll()) {
    LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
    return 1;
  }

  LLVM_DEBUG(dbgs() << "Estimated loop size is " << UCE.getRolledLoopSize()
                    << "\n");

  // TODO: Determine the trip count of \p CLI if constant; computeUnrollCount
  // might be able to use it.
  int TripCount = 0;
  int MaxTripCount = 0;
  bool MaxOrZero = false;
  unsigned TripMultiple = 0;

  bool UseUpperBound = false;
  computeUnrollCount(L, TTI, DT, &LI, &AC, SE, EphValues, &ORE, TripCount,
                     MaxTripCount, MaxOrZero, TripMultiple, UCE, UP, PP,
                     UseUpperBound);
  unsigned Factor = UP.Count;
  LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");

  // This function returns 1 to signal that the loop should not be unrolled.
  if (Factor == 0)
    return 1;
  return Factor;
}
void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
                                        int32_t Factor,
                                        CanonicalLoopInfo **UnrolledCLI) {
  assert(Factor >= 0 && "Unroll factor must not be negative");

  Function *F = Loop->getFunction();
  LLVMContext &Ctx = F->getContext();

  // If the unrolled loop is not used for another loop-associated directive, it
  // is sufficient to add metadata for the LoopUnrollPass.
  if (!UnrolledCLI) {
    SmallVector<Metadata *, 2> LoopMetadata;
    LoopMetadata.push_back(
        MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")));

    if (Factor >= 1) {
      ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
          ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
      LoopMetadata.push_back(MDNode::get(
          Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst}));
    }

    addLoopMetadata(Loop, LoopMetadata);
    return;
  }

  // Heuristically determine the unroll factor.
  if (Factor == 0)
    Factor = computeHeuristicUnrollFactor(Loop);

  // No change required with unroll factor 1.
  if (Factor == 1) {
    *UnrolledCLI = Loop;
    return;
  }

  assert(Factor >= 2 &&
         "unrolling only makes sense with a factor of 2 or larger");

  Type *IndVarTy = Loop->getIndVarType();

  // Apply partial unrolling by tiling the loop by the unroll-factor, then
  // fully unroll the inner loop.
  Value *FactorVal =
      ConstantInt::get(IndVarTy, APInt(IndVarTy->getIntegerBitWidth(), Factor,
                                       /*isSigned=*/false));
  std::vector<CanonicalLoopInfo *> LoopNest =
      tileLoops(DL, {Loop}, {FactorVal});
  assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
  *UnrolledCLI = LoopNest[0];
  CanonicalLoopInfo *InnerLoop = LoopNest[1];

  // LoopUnrollPass can only fully unroll loops with constant trip count.
  // Unroll by the unroll factor with a fallback epilog for the remainder
  // iterations if necessary.
  ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
      ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
  addLoopMetadata(
      InnerLoop,
      {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
       MDNode::get(
           Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst})});

#ifndef NDEBUG
  (*UnrolledCLI)->assertOK();
#endif
}
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
                                   llvm::Value *BufSize, llvm::Value *CpyBuf,
                                   llvm::Value *CpyFn, llvm::Value *DidIt) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);

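  // The emitted sequence is, roughly (names illustrative):
  //   %did_it = load i32, ptr %DidIt
  //   call void @__kmpc_copyprivate(ptr %ident, i32 %tid, i64 %BufSize,
  //                                 ptr %CpyBuf, ptr %CpyFn, i32 %did_it)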
  llvm::Value *DidItLD = Builder.CreateLoad(Builder.getInt32Ty(), DidIt);

  Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};

  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_copyprivate);
  Builder.CreateCall(Fn, Args);

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
    const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
    FinalizeCallbackTy FiniCB, bool IsNowait, ArrayRef<llvm::Value *> CPVars,
    ArrayRef<llvm::Function *> CPFuncs) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  // If needed, allocate and initialize `DidIt` with 0.
  // DidIt: flag variable: 1=single thread; 0=not single thread.
  llvm::Value *DidIt = nullptr;
  if (!CPVars.empty()) {
    DidIt = Builder.CreateAlloca(llvm::Type::getInt32Ty(Builder.getContext()));
    Builder.CreateStore(Builder.getInt32(0), DidIt);
  }

  Directive OMPD = Directive::OMPD_single;
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId};

  Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_single);
  Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);

  Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_single);
  Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);

  auto FiniCBWrapper = [&](InsertPointTy IP) {
    FiniCB(IP);

    // The thread that executes the single region must set `DidIt` to 1.
    // This is used by __kmpc_copyprivate to know whether the caller is the
    // single thread or not.
    if (DidIt)
      Builder.CreateStore(Builder.getInt32(1), DidIt);
  };

  // Generates the following:
  // if (__kmpc_single()) {
  //   .... single region ...
  //   __kmpc_end_single
  // }
  // __kmpc_copyprivate
  // __kmpc_barrier

  EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCBWrapper,
                       /*Conditional*/ true,
                       /*hasFinalize*/ true);

  if (DidIt) {
    for (size_t I = 0, E = CPVars.size(); I < E; ++I)
      // NOTE BufSize is currently unused, so just pass 0.
      createCopyPrivate(LocationDescription(Builder.saveIP(), Loc.DL),
                        /*BufSize=*/ConstantInt::get(Int64, 0), CPVars[I],
                        CPFuncs[I], DidIt);
    // NOTE __kmpc_copyprivate already inserts a barrier
  } else if (!IsNowait)
    createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                  omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                  /* CheckCancelFlag */ false);
  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCritical(
    const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
    FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  Directive OMPD = Directive::OMPD_critical;
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *LockVar = getOMPCriticalRegionLock(CriticalName);
  Value *Args[] = {Ident, ThreadId, LockVar};

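  // The region is bracketed by runtime calls of roughly this shape:
  //   call void @__kmpc_critical(ptr %ident, i32 %tid, ptr %lock)
  //   ... critical region body ...
  //   call void @__kmpc_end_critical(ptr %ident, i32 %tid, ptr %lock)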
  SmallVector<llvm::Value *, 4> EnterArgs(std::begin(Args), std::end(Args));
  Function *RTFn = nullptr;
  if (HintInst) {
    // Add Hint to entry Args and create call
    EnterArgs.push_back(HintInst);
    RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical_with_hint);
  } else {
    RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical);
  }
  Instruction *EntryCall = Builder.CreateCall(RTFn, EnterArgs);

  Function *ExitRTLFn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_critical);
  Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);

  return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
                              /*Conditional*/ false, /*hasFinalize*/ true);
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
                                     InsertPointTy AllocaIP, unsigned NumLoops,
                                     ArrayRef<llvm::Value *> StoreValues,
                                     const Twine &Name, bool IsDependSource) {
  assert(
      llvm::all_of(StoreValues,
                   [](Value *SV) { return SV->getType()->isIntegerTy(64); }) &&
      "OpenMP runtime requires depend vec with i64 type");

  if (!updateToLocation(Loc))
    return Loc.IP;

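  // For NumLoops == 2 and a depend(source) clause the emitted code is roughly
  // (names illustrative):
  //   %vec = alloca [2 x i64], align 8
  //   store i64 %iv0, ptr %vec[0]; store i64 %iv1, ptr %vec[1]
  //   call void @__kmpc_doacross_post(ptr %ident, i32 %tid, ptr %vec)
  // A depend(sink: ...) clause calls @__kmpc_doacross_wait instead.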
  // Allocate space for vector and generate alloc instruction.
  auto *ArrI64Ty = ArrayType::get(Int64, NumLoops);
  Builder.restoreIP(AllocaIP);
  AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI64Ty, nullptr, Name);
  ArgsBase->setAlignment(Align(8));
  Builder.restoreIP(Loc.IP);

  // Store the index value with offset in depend vector.
  for (unsigned I = 0; I < NumLoops; ++I) {
    Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
        ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(I)});
    StoreInst *STInst = Builder.CreateStore(StoreValues[I], DependAddrGEPIter);
    STInst->setAlignment(Align(8));
  }

  Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
      ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(0)});

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};

  Function *RTLFn = nullptr;
  if (IsDependSource)
    RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_post);
  else
    RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_wait);
  Builder.CreateCall(RTLFn, Args);

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createOrderedThreadsSimd(
    const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
    FinalizeCallbackTy FiniCB, bool IsThreads) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Directive OMPD = Directive::OMPD_ordered;
  Instruction *EntryCall = nullptr;
  Instruction *ExitCall = nullptr;

  if (IsThreads) {
    uint32_t SrcLocStrSize;
    Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
    Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
    Value *ThreadId = getOrCreateThreadID(Ident);
    Value *Args[] = {Ident, ThreadId};

    Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_ordered);
    EntryCall = Builder.CreateCall(EntryRTLFn, Args);

    Function *ExitRTLFn =
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_ordered);
    ExitCall = Builder.CreateCall(ExitRTLFn, Args);
  }

  return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
                              /*Conditional*/ false, /*hasFinalize*/ true);
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::EmitOMPInlinedRegion(
    Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
    BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
    bool HasFinalize, bool IsCancellable) {

  if (HasFinalize)
    FinalizationStack.push_back({FiniCB, OMPD, IsCancellable});

  // Create inlined region's entry and body blocks, in preparation
  // for conditional creation.
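  // After the two splits below the block layout is, roughly:
  //   EntryBB:              ...; EntryCall; br (maybe conditional)
  //   omp_region.finalize:  the FiniCB code and ExitCall end up here
  //   omp_region.end:       code following the region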
  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Instruction *SplitPos = EntryBB->getTerminator();
  if (!isa_and_nonnull<BranchInst>(SplitPos))
    SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
  BasicBlock *ExitBB = EntryBB->splitBasicBlock(SplitPos, "omp_region.end");
  BasicBlock *FiniBB =
      EntryBB->splitBasicBlock(EntryBB->getTerminator(), "omp_region.finalize");

  Builder.SetInsertPoint(EntryBB->getTerminator());
  emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);

  // Generate the body.
  BodyGenCB(/* AllocaIP */ InsertPointTy(),
            /* CodeGenIP */ Builder.saveIP());

  // Emit the exit call and do any needed finalization.
  auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
  assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
         FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
         "Unexpected control flow graph state!!");
  emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
  assert(FiniBB->getUniquePredecessor()->getUniqueSuccessor() == FiniBB &&
         "Unexpected Control Flow State!");
  MergeBlockIntoPredecessor(FiniBB);

  // If we are skipping the region of a non-conditional, remove the exit
  // block, and clear the builder's insertion point.
  assert(SplitPos->getParent() == ExitBB &&
         "Unexpected Insertion point location!");
  auto merged = MergeBlockIntoPredecessor(ExitBB);
  BasicBlock *ExitPredBB = SplitPos->getParent();
  auto InsertBB = merged ? ExitPredBB : ExitBB;
  if (!isa_and_nonnull<BranchInst>(SplitPos))
    SplitPos->eraseFromParent();
  Builder.SetInsertPoint(InsertBB);

  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
    Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
  // If there is nothing to do, return the current insertion point.
  if (!Conditional || !EntryCall)
    return Builder.saveIP();

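  // For a conditional region such as 'single', the emitted check is roughly:
  //   %call = call i32 @__kmpc_single(ptr %ident, i32 %tid)
  //   %cond = icmp ne i32 %call, 0
  //   br i1 %cond, label %omp_region.body, label %omp_region.end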
  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Value *CallBool = Builder.CreateIsNotNull(EntryCall);
  auto *ThenBB = BasicBlock::Create(M.getContext(), "omp_region.body");
  auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);

  // Emit ThenBB and set the Builder's insertion point there for
  // body generation next. Place the block after the current block.
  Function *CurFn = EntryBB->getParent();
  CurFn->insert(std::next(EntryBB->getIterator()), ThenBB);

  // Move the Entry branch to the end of ThenBB, and replace it with a
  // conditional branch (if-stmt).
  Instruction *EntryBBTI = EntryBB->getTerminator();
  Builder.CreateCondBr(CallBool, ThenBB, ExitBB);
  EntryBBTI->removeFromParent();
  Builder.SetInsertPoint(UI);
  Builder.Insert(EntryBBTI);
  UI->eraseFromParent();
  Builder.SetInsertPoint(ThenBB->getTerminator());

  // Return an insertion point to ExitBB.
  return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveExit(
    omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
    bool HasFinalize) {

  Builder.restoreIP(FinIP);

  // If there is finalization to do, emit it before the exit call.
  if (HasFinalize) {
    assert(!FinalizationStack.empty() &&
           "Unexpected finalization stack state!");

    FinalizationInfo Fi = FinalizationStack.pop_back_val();
    assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");

    Fi.FiniCB(FinIP);

    BasicBlock *FiniBB = FinIP.getBlock();
    Instruction *FiniBBTI = FiniBB->getTerminator();

    // Set the Builder IP for call creation.
    Builder.SetInsertPoint(FiniBBTI);
  }

  if (!ExitCall)
    return Builder.saveIP();

  // Place the ExitCall as the last instruction before the finalization block
  // terminator.
  ExitCall->removeFromParent();
  Builder.Insert(ExitCall);

  return IRBuilder<>::InsertPoint(ExitCall->getParent(),
                                  ExitCall->getIterator());
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
    InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
    llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
  if (!IP.isSet())
    return IP;

  IRBuilder<>::InsertPointGuard IPG(Builder);

  // Creates the following CFG structure:
  //   OMP_Entry : (MasterAddr != PrivateAddr)?
  //       F     T
  //       |      \
  //       |   copyin.not.master
  //       |      /
  //       v     /
  //   copyin.not.master.end
  //       |
  //       v
  //   OMP.Entry.Next

  BasicBlock *OMP_Entry = IP.getBlock();
  Function *CurFn = OMP_Entry->getParent();
  BasicBlock *CopyBegin =
      BasicBlock::Create(M.getContext(), "copyin.not.master", CurFn);
  BasicBlock *CopyEnd = nullptr;

  // If the entry block is terminated, split it to preserve the branch to the
  // following basic block (i.e. OMP.Entry.Next); otherwise, leave everything
  // as is.
  if (isa_and_nonnull<BranchInst>(OMP_Entry->getTerminator())) {
    CopyEnd = OMP_Entry->splitBasicBlock(OMP_Entry->getTerminator(),
                                         "copyin.not.master.end");
    OMP_Entry->getTerminator()->eraseFromParent();
  } else {
    CopyEnd =
        BasicBlock::Create(M.getContext(), "copyin.not.master.end", CurFn);
  }

  Builder.SetInsertPoint(OMP_Entry);
  Value *MasterPtr = Builder.CreatePtrToInt(MasterAddr, IntPtrTy);
  Value *PrivatePtr = Builder.CreatePtrToInt(PrivateAddr, IntPtrTy);
  Value *cmp = Builder.CreateICmpNE(MasterPtr, PrivatePtr);
  Builder.CreateCondBr(cmp, CopyBegin, CopyEnd);

  Builder.SetInsertPoint(CopyBegin);
  if (BranchtoEnd)
    Builder.SetInsertPoint(Builder.CreateBr(CopyEnd));

  return Builder.saveIP();
}

CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
                                          Value *Size, Value *Allocator,
                                          std::string Name) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  updateToLocation(Loc);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {ThreadId, Size, Allocator};

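  // The result is a single runtime call of roughly this shape:
  //   %ptr = call ptr @__kmpc_alloc(i32 %tid, i64 %size, ptr %allocator)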
  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc);

  return Builder.CreateCall(Fn, Args, Name);
}

CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
                                         Value *Addr, Value *Allocator,
                                         std::string Name) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  updateToLocation(Loc);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Value *Args[] = {ThreadId, Addr, Allocator};
  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free);
  return Builder.CreateCall(Fn, Args, Name);
}

CallInst *OpenMPIRBuilder::createOMPInteropInit(
    const LocationDescription &Loc, Value *InteropVar,
    omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
    Value *DependenceAddress, bool HaveNowaitClause) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  updateToLocation(Loc);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  if (Device == nullptr)
    Device = ConstantInt::get(Int32, -1);
  Constant *InteropTypeVal = ConstantInt::get(Int32, (int)InteropType);
  if (NumDependences == nullptr) {
    NumDependences = ConstantInt::get(Int32, 0);
    PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
    DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
  }
  Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
  Value *Args[] = {
      Ident, ThreadId, InteropVar, InteropTypeVal,
      Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};

  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_init);

  return Builder.CreateCall(Fn, Args);
}

CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
    const LocationDescription &Loc, Value *InteropVar, Value *Device,
    Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  updateToLocation(Loc);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  if (Device == nullptr)
    Device = ConstantInt::get(Int32, -1);
  if (NumDependences == nullptr) {
    NumDependences = ConstantInt::get(Int32, 0);
    PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
    DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
  }
  Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
  Value *Args[] = {
      Ident, ThreadId, InteropVar, Device,
      NumDependences, DependenceAddress, HaveNowaitClauseVal};

  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_destroy);

  return Builder.CreateCall(Fn, Args);
}

CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
                                               Value *InteropVar, Value *Device,
                                               Value *NumDependences,
                                               Value *DependenceAddress,
                                               bool HaveNowaitClause) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  updateToLocation(Loc);
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  if (Device == nullptr)
    Device = ConstantInt::get(Int32, -1);
  if (NumDependences == nullptr) {
    NumDependences = ConstantInt::get(Int32, 0);
    PointerType *PointerTypeVar = PointerType::getUnqual(M.getContext());
    DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
  }
  Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
  Value *Args[] = {
      Ident, ThreadId, InteropVar, Device,
      NumDependences, DependenceAddress, HaveNowaitClauseVal};

  Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_use);

  return Builder.CreateCall(Fn, Args);
}

CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
    const LocationDescription &Loc, llvm::Value *Pointer,
    llvm::ConstantInt *Size, const llvm::Twine &Name) {
  IRBuilder<>::InsertPointGuard IPG(Builder);
  updateToLocation(Loc);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadId = getOrCreateThreadID(Ident);
  Constant *ThreadPrivateCache =
      getOrCreateInternalVariable(Int8PtrPtr, Name.str());
  llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};

  Function *Fn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_threadprivate_cached);

  return Builder.CreateCall(Fn, Args);
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
                                  int32_t MinThreadsVal, int32_t MaxThreadsVal,
                                  int32_t MinTeamsVal, int32_t MaxTeamsVal) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Constant *IsSPMDVal = ConstantInt::getSigned(
      Int8, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
  Constant *UseGenericStateMachineVal = ConstantInt::getSigned(Int8, !IsSPMD);
  Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
  Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);

  Function *Kernel = Builder.GetInsertBlock()->getParent();

  // Manifest the launch configuration in the metadata matching the kernel
  // environment.
  if (MinTeamsVal > 1 || MaxTeamsVal > 0)
    writeTeamsForKernel(T, *Kernel, MinTeamsVal, MaxTeamsVal);

  // For max values, < 0 means unset, == 0 means set but unknown.
  if (MaxThreadsVal < 0)
    MaxThreadsVal = std::max(
        int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), MinThreadsVal);

  if (MaxThreadsVal > 0)
    writeThreadBoundsForKernel(T, *Kernel, MinThreadsVal, MaxThreadsVal);

  Constant *MinThreads = ConstantInt::getSigned(Int32, MinThreadsVal);
  Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
  Constant *MinTeams = ConstantInt::getSigned(Int32, MinTeamsVal);
  Constant *MaxTeams = ConstantInt::getSigned(Int32, MaxTeamsVal);
  Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
  Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);

  // We need to strip the debug prefix to get the correct kernel name.
  StringRef KernelName = Kernel->getName();
  const std::string DebugPrefix = "_debug__";
  if (KernelName.ends_with(DebugPrefix))
    KernelName = KernelName.drop_back(DebugPrefix.length());

  Function *Fn = getOrCreateRuntimeFunctionPtr(
      omp::RuntimeFunction::OMPRTL___kmpc_target_init);
  const DataLayout &DL = Fn->getDataLayout();

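  // For a kernel named "foo" the code below emits, roughly, two globals:
  //   @foo_dynamic_environment = weak_odr protected global ...
  //   @foo_kernel_environment  = weak_odr protected constant
  //       { <configuration env>, ptr @ident, ptr @foo_dynamic_environment }
  // which __kmpc_target_init consumes at kernel startup.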
  Twine DynamicEnvironmentName = KernelName + "_dynamic_environment";
  Constant *DynamicEnvironmentInitializer =
      ConstantStruct::get(DynamicEnvironment, {DebugIndentionLevelVal});
  GlobalVariable *DynamicEnvironmentGV = new GlobalVariable(
      M, DynamicEnvironment, /*IsConstant=*/false, GlobalValue::WeakODRLinkage,
      DynamicEnvironmentInitializer, DynamicEnvironmentName,
      /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
      DL.getDefaultGlobalsAddressSpace());
  DynamicEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);

  Constant *DynamicEnvironment =
      DynamicEnvironmentGV->getType() == DynamicEnvironmentPtr
          ? DynamicEnvironmentGV
          : ConstantExpr::getAddrSpaceCast(DynamicEnvironmentGV,
                                           DynamicEnvironmentPtr);

  Constant *ConfigurationEnvironmentInitializer = ConstantStruct::get(
      ConfigurationEnvironment, {
                                    UseGenericStateMachineVal,
                                    MayUseNestedParallelismVal,
                                    IsSPMDVal,
                                    MinThreads,
                                    MaxThreads,
                                    MinTeams,
                                    MaxTeams,
                                    ReductionDataSize,
                                    ReductionBufferLength,
                                });
  Constant *KernelEnvironmentInitializer = ConstantStruct::get(
      KernelEnvironment, {
                             ConfigurationEnvironmentInitializer,
                             Ident,
                             DynamicEnvironment,
                         });
  std::string KernelEnvironmentName =
      (KernelName + "_kernel_environment").str();
  GlobalVariable *KernelEnvironmentGV = new GlobalVariable(
      M, KernelEnvironment, /*IsConstant=*/true, GlobalValue::WeakODRLinkage,
      KernelEnvironmentInitializer, KernelEnvironmentName,
      /*InsertBefore=*/nullptr, GlobalValue::NotThreadLocal,
      DL.getDefaultGlobalsAddressSpace());
  KernelEnvironmentGV->setVisibility(GlobalValue::ProtectedVisibility);

  Constant *KernelEnvironment =
      KernelEnvironmentGV->getType() == KernelEnvironmentPtr
          ? KernelEnvironmentGV
          : ConstantExpr::getAddrSpaceCast(KernelEnvironmentGV,
                                           KernelEnvironmentPtr);
  Value *KernelLaunchEnvironment = Kernel->getArg(0);
  CallInst *ThreadKind =
      Builder.CreateCall(Fn, {KernelEnvironment, KernelLaunchEnvironment});

  Value *ExecUserCode = Builder.CreateICmpEQ(
      ThreadKind, ConstantInt::get(ThreadKind->getType(), -1),
      "exec_user_code");

  // ThreadKind = __kmpc_target_init(...)
  // if (ThreadKind == -1)
  //   user_code
  // else
  //   return;

  auto *UI = Builder.CreateUnreachable();
  BasicBlock *CheckBB = UI->getParent();
  BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry");

  BasicBlock *WorkerExitBB = BasicBlock::Create(
      CheckBB->getContext(), "worker.exit", CheckBB->getParent());
  Builder.SetInsertPoint(WorkerExitBB);
  Builder.CreateRetVoid();

  auto *CheckBBTI = CheckBB->getTerminator();
  Builder.SetInsertPoint(CheckBBTI);
  Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB);

  CheckBBTI->eraseFromParent();
  UI->eraseFromParent();

  // Continue in the "user_code" block, see diagram above and in
  // openmp/libomptarget/deviceRTLs/common/include/target.h .
  return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
}

void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
                                         int32_t TeamsReductionDataSize,
                                         int32_t TeamsReductionBufferLength) {
  if (!updateToLocation(Loc))
    return;

  Function *Fn = getOrCreateRuntimeFunctionPtr(
      omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);

  Builder.CreateCall(Fn, {});

  if (!TeamsReductionBufferLength || !TeamsReductionDataSize)
    return;

  Function *Kernel = Builder.GetInsertBlock()->getParent();
  // We need to strip the debug prefix to get the correct kernel name.
  StringRef KernelName = Kernel->getName();
  const std::string DebugPrefix = "_debug__";
  if (KernelName.ends_with(DebugPrefix))
    KernelName = KernelName.drop_back(DebugPrefix.length());
  auto *KernelEnvironmentGV =
      M.getNamedGlobal((KernelName + "_kernel_environment").str());
  assert(KernelEnvironmentGV && "Expected kernel environment global\n");
  auto *KernelEnvironmentInitializer = KernelEnvironmentGV->getInitializer();
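  // Indices {0, 7} and {0, 8} address the ReductionDataSize and
  // ReductionBufferLength fields of the configuration environment, i.e.
  // member 0 of the kernel environment (see the initializer built in
  // createTargetInit).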
  auto *NewInitializer = ConstantFoldInsertValueInstruction(
      KernelEnvironmentInitializer,
      ConstantInt::get(Int32, TeamsReductionDataSize), {0, 7});
  NewInitializer = ConstantFoldInsertValueInstruction(
      NewInitializer, ConstantInt::get(Int32, TeamsReductionBufferLength),
      {0, 8});
  KernelEnvironmentGV->setInitializer(NewInitializer);
}

static MDNode *getNVPTXMDNode(Function &Kernel, StringRef Name) {
  Module &M = *Kernel.getParent();
  NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
  for (auto *Op : MD->operands()) {
    if (Op->getNumOperands() != 3)
      continue;
    auto *KernelOp = dyn_cast<ConstantAsMetadata>(Op->getOperand(0));
    if (!KernelOp || KernelOp->getValue() != &Kernel)
      continue;
    auto *Prop = dyn_cast<MDString>(Op->getOperand(1));
    if (!Prop || Prop->getString() != Name)
      continue;
    return Op;
  }
  return nullptr;
}

static void updateNVPTXMetadata(Function &Kernel, StringRef Name, int32_t Value,
                                bool Min) {
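  // Each nvvm.annotations entry relating a kernel to a launch-bound property
  // has the shape (sketch):
  //   !{ptr @kernel, !"maxntidx", i32 <value>}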
  // Update the "maxntidx" metadata for NVIDIA, or add it.
  MDNode *ExistingOp = getNVPTXMDNode(Kernel, Name);
  if (ExistingOp) {
    auto *OldVal = cast<ConstantAsMetadata>(ExistingOp->getOperand(2));
    int32_t OldLimit = cast<ConstantInt>(OldVal->getValue())->getZExtValue();
    ExistingOp->replaceOperandWith(
        2, ConstantAsMetadata::get(ConstantInt::get(
               OldVal->getValue()->getType(),
               Min ? std::min(OldLimit, Value) : std::max(OldLimit, Value))));
  } else {
    LLVMContext &Ctx = Kernel.getContext();
    Metadata *MDVals[] = {ConstantAsMetadata::get(&Kernel),
                          MDString::get(Ctx, Name),
                          ConstantAsMetadata::get(
                              ConstantInt::get(Type::getInt32Ty(Ctx), Value))};
    // Append metadata to nvvm.annotations
    Module &M = *Kernel.getParent();
    NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
    MD->addOperand(MDNode::get(Ctx, MDVals));
  }
}

std::pair<int32_t, int32_t>
OpenMPIRBuilder::readThreadBoundsForKernel(const Triple &T, Function &Kernel) {
  int32_t ThreadLimit =
      Kernel.getFnAttributeAsParsedInteger("omp_target_thread_limit");

  if (T.isAMDGPU()) {
    const auto &Attr = Kernel.getFnAttribute("amdgpu-flat-work-group-size");
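    // The attribute value encodes "<LB>,<UB>", e.g. "1,256"; the upper bound
    // parsed below is clamped against the OpenMP thread limit.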
    if (!Attr.isValid() || !Attr.isStringAttribute())
      return {0, ThreadLimit};
    auto [LBStr, UBStr] = Attr.getValueAsString().split(',');
    int32_t LB, UB;
    if (!llvm::to_integer(UBStr, UB, 10))
      return {0, ThreadLimit};
    UB = ThreadLimit ? std::min(ThreadLimit, UB) : UB;
    if (!llvm::to_integer(LBStr, LB, 10))
      return {0, UB};
    return {LB, UB};
  }

  if (MDNode *ExistingOp = getNVPTXMDNode(Kernel, "maxntidx")) {
    auto *OldVal = cast<ConstantAsMetadata>(ExistingOp->getOperand(2));
    int32_t UB = cast<ConstantInt>(OldVal->getValue())->getZExtValue();
    return {0, ThreadLimit ? std::min(ThreadLimit, UB) : UB};
  }
  return {0, ThreadLimit};
}

void OpenMPIRBuilder::writeThreadBoundsForKernel(const Triple &T,
                                                 Function &Kernel, int32_t LB,
                                                 int32_t UB) {
  Kernel.addFnAttr("omp_target_thread_limit", std::to_string(UB));

  if (T.isAMDGPU()) {
    Kernel.addFnAttr("amdgpu-flat-work-group-size",
                     llvm::utostr(LB) + "," + llvm::utostr(UB));
    return;
  }

  updateNVPTXMetadata(Kernel, "maxntidx", UB, true);
}

std::pair<int32_t, int32_t>
OpenMPIRBuilder::readTeamBoundsForKernel(const Triple &, Function &Kernel) {
  // TODO: Read from backend annotations if available.
  return {0, Kernel.getFnAttributeAsParsedInteger("omp_target_num_teams")};
}

void OpenMPIRBuilder::writeTeamsForKernel(const Triple &T, Function &Kernel,
                                          int32_t LB, int32_t UB) {
  if (T.isNVPTX())
    if (UB > 0)
      updateNVPTXMetadata(Kernel, "maxclusterrank", UB, true);
  if (T.isAMDGPU())
    Kernel.addFnAttr("amdgpu-max-num-workgroups", llvm::utostr(LB) + ",1,1");

  Kernel.addFnAttr("omp_target_num_teams", std::to_string(LB));
}

void OpenMPIRBuilder::setOutlinedTargetRegionFunctionAttributes(
    Function *OutlinedFn) {
  if (Config.isTargetDevice()) {
    OutlinedFn->setLinkage(GlobalValue::WeakODRLinkage);
    // TODO: Determine if DSO local can be set to true.
    OutlinedFn->setDSOLocal(false);
    OutlinedFn->setVisibility(GlobalValue::ProtectedVisibility);
    if (T.isAMDGCN())
      OutlinedFn->setCallingConv(CallingConv::AMDGPU_KERNEL);
  }
}

Constant *OpenMPIRBuilder::createOutlinedFunctionID(Function *OutlinedFn,
                                                    StringRef EntryFnIDName) {
  if (Config.isTargetDevice()) {
    assert(OutlinedFn && "The outlined function must exist if embedded");
    return OutlinedFn;
  }

  return new GlobalVariable(
      M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
      Constant::getNullValue(Builder.getInt8Ty()), EntryFnIDName);
}

Constant *OpenMPIRBuilder::createTargetRegionEntryAddr(Function *OutlinedFn,
                                                       StringRef EntryFnName) {
  if (OutlinedFn)
    return OutlinedFn;

  assert(!M.getGlobalVariable(EntryFnName, true) &&
         "Named kernel already exists?");
  return new GlobalVariable(
      M, Builder.getInt8Ty(), /*isConstant=*/true, GlobalValue::InternalLinkage,
      Constant::getNullValue(Builder.getInt8Ty()), EntryFnName);
}

void OpenMPIRBuilder::emitTargetRegionFunction(
    TargetRegionEntryInfo &EntryInfo,
    FunctionGenCallback &GenerateFunctionCallback, bool IsOffloadEntry,
    Function *&OutlinedFn, Constant *&OutlinedFnID) {

  SmallString<64> EntryFnName;
  OffloadInfoManager.getTargetRegionEntryFnName(EntryFnName, EntryInfo);

  OutlinedFn = Config.isTargetDevice() || !Config.openMPOffloadMandatory()
                   ? GenerateFunctionCallback(EntryFnName)
                   : nullptr;

  // If this target outline function is not an offload entry, we don't need to
  // register it. This may be the case for a false 'if' clause, or if there are
  // no OpenMP targets.
  if (!IsOffloadEntry)
    return;

  std::string EntryFnIDName =
      Config.isTargetDevice()
          ? std::string(EntryFnName)
          : createPlatformSpecificName({EntryFnName, "region_id"});

  OutlinedFnID = registerTargetRegionFunction(EntryInfo, OutlinedFn,
                                              EntryFnName, EntryFnIDName);
}

Constant *OpenMPIRBuilder::registerTargetRegionFunction(
    TargetRegionEntryInfo &EntryInfo, Function *OutlinedFn,
    StringRef EntryFnName, StringRef EntryFnIDName) {
  if (OutlinedFn)
    setOutlinedTargetRegionFunctionAttributes(OutlinedFn);
  auto OutlinedFnID = createOutlinedFunctionID(OutlinedFn, EntryFnIDName);
  auto EntryAddr = createTargetRegionEntryAddr(OutlinedFn, EntryFnName);
  OffloadInfoManager.registerTargetRegionEntryInfo(
      EntryInfo, EntryAddr, OutlinedFnID,
      OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion);
  return OutlinedFnID;
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetData(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    InsertPointTy CodeGenIP, Value *DeviceID, Value *IfCond,
    TargetDataInfo &Info, GenMapInfoCallbackTy GenMapInfoCB,
    omp::RuntimeFunction *MapperFunc,
    function_ref<InsertPointTy(InsertPointTy CodeGenIP, BodyGenTy BodyGenType)>
        BodyGenCB,
    function_ref<void(unsigned int, Value *)> DeviceAddrCB,
    function_ref<Value *(unsigned int)> CustomMapperCB, Value *SrcLocInfo) {
  if (!updateToLocation(Loc))
    return InsertPointTy();

  // Disable TargetData CodeGen on Device pass.
  if (Config.IsTargetDevice.value_or(false)) {
    if (BodyGenCB)
      Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));
    return Builder.saveIP();
  }

  Builder.restoreIP(CodeGenIP);
  bool IsStandAlone = !BodyGenCB;
  MapInfosTy *MapInfo;
  // Generate the code for the opening of the data environment. Capture all the
  // arguments of the runtime call by reference because they are used in the
  // closing of the region.
  auto BeginThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
    MapInfo = &GenMapInfoCB(Builder.saveIP());
    emitOffloadingArrays(AllocaIP, Builder.saveIP(), *MapInfo, Info,
                         /*IsNonContiguous=*/true, DeviceAddrCB,
                         CustomMapperCB);

    TargetDataRTArgs RTArgs;
    emitOffloadingArraysArgument(Builder, RTArgs, Info,
                                 !MapInfo->Names.empty());

    // Emit the number of elements in the offloading arrays.
    Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);

    // Source location for the ident struct
    if (!SrcLocInfo) {
      uint32_t SrcLocStrSize;
      Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
      SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
    }

    Value *OffloadingArgs[] = {SrcLocInfo,           DeviceID,
                               PointerNum,           RTArgs.BasePointersArray,
                               RTArgs.PointersArray, RTArgs.SizesArray,
                               RTArgs.MapTypesArray, RTArgs.MapNamesArray,
                               RTArgs.MappersArray};

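    // For a standalone construct (no body callback) this reduces to a single
    // runtime call through the caller-provided MapperFunc, e.g. for
    // 'target enter data' roughly:
    //   call void @__tgt_target_data_begin_mapper(ptr %loc, i64 %dev, i32 %n,
    //       ptr %bases, ptr %ptrs, ptr %sizes, ptr %types, ptr %names,
    //       ptr %mappers)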
    if (IsStandAlone) {
      assert(MapperFunc && "MapperFunc missing for standalone target data");
      Builder.CreateCall(getOrCreateRuntimeFunctionPtr(*MapperFunc),
                         OffloadingArgs);
    } else {
      Function *BeginMapperFunc = getOrCreateRuntimeFunctionPtr(
          omp::OMPRTL___tgt_target_data_begin_mapper);

      Builder.CreateCall(BeginMapperFunc, OffloadingArgs);

      for (auto DeviceMap : Info.DevicePtrInfoMap) {
        if (isa<AllocaInst>(DeviceMap.second.second)) {
          auto *LI =
              Builder.CreateLoad(Builder.getPtrTy(), DeviceMap.second.first);
          Builder.CreateStore(LI, DeviceMap.second.second);
        }
      }

      // If device pointer privatization is required, emit the body of the
      // region here. It will have to be duplicated: with and without
      // privatization.
      Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::Priv));
    }
  };

  // If we need device pointer privatization, we need to emit the body of the
  // region with no privatization in the 'else' branch of the conditional.
  // Otherwise, we don't have to do anything.
  auto BeginElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
    Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::DupNoPriv));
  };

  // Generate code for the closing of the data region.
  auto EndThenGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
    TargetDataRTArgs RTArgs;
    emitOffloadingArraysArgument(Builder, RTArgs, Info, !MapInfo->Names.empty(),
                                 /*ForEndCall=*/true);

    // Emit the number of elements in the offloading arrays.
    Value *PointerNum = Builder.getInt32(Info.NumberOfPtrs);

    // Source location for the ident struct
    if (!SrcLocInfo) {
      uint32_t SrcLocStrSize;
      Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
      SrcLocInfo = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
    }

    Value *OffloadingArgs[] = {SrcLocInfo,           DeviceID,
                               PointerNum,           RTArgs.BasePointersArray,
                               RTArgs.PointersArray, RTArgs.SizesArray,
                               RTArgs.MapTypesArray, RTArgs.MapNamesArray,
                               RTArgs.MappersArray};
    Function *EndMapperFunc =
        getOrCreateRuntimeFunctionPtr(omp::OMPRTL___tgt_target_data_end_mapper);

    Builder.CreateCall(EndMapperFunc, OffloadingArgs);
  };

  // We don't have to do anything to close the region if the 'if' clause
  // evaluates to false.
  auto EndElseGen = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {};

  if (BodyGenCB) {
    if (IfCond) {
      emitIfClause(IfCond, BeginThenGen, BeginElseGen, AllocaIP);
    } else {
      BeginThenGen(AllocaIP, Builder.saveIP());
    }

    // If we don't require privatization of device pointers, we emit the body
    // in between the runtime calls. This avoids duplicating the body code.
    Builder.restoreIP(BodyGenCB(Builder.saveIP(), BodyGenTy::NoPriv));

    if (IfCond) {
      emitIfClause(IfCond, EndThenGen, EndElseGen, AllocaIP);
    } else {
      EndThenGen(AllocaIP, Builder.saveIP());
    }
  } else {
    if (IfCond) {
      emitIfClause(IfCond, BeginThenGen, EndElseGen, AllocaIP);
    } else {
      BeginThenGen(AllocaIP, Builder.saveIP());
    }
  }

  return Builder.saveIP();
}

FunctionCallee
OpenMPIRBuilder::createForStaticInitFunction(unsigned IVSize, bool IVSigned,
                                             bool IsGPUDistribute) {
  assert((IVSize == 32 || IVSize == 64) &&
         "IV size is not compatible with the omp runtime");
  RuntimeFunction Name;
  if (IsGPUDistribute)
    Name = IVSize == 32
               ? (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_4
                           : omp::OMPRTL___kmpc_distribute_static_init_4u)
               : (IVSigned ? omp::OMPRTL___kmpc_distribute_static_init_8
                           : omp::OMPRTL___kmpc_distribute_static_init_8u);
  else
    Name = IVSize == 32 ? (IVSigned ? omp::OMPRTL___kmpc_for_static_init_4
                                    : omp::OMPRTL___kmpc_for_static_init_4u)
                        : (IVSigned ? omp::OMPRTL___kmpc_for_static_init_8
                                    : omp::OMPRTL___kmpc_for_static_init_8u);

  return getOrCreateRuntimeFunction(M, Name);
}

FunctionCallee OpenMPIRBuilder::createDispatchInitFunction(unsigned IVSize,
                                                           bool IVSigned) {
  assert((IVSize == 32 || IVSize == 64) &&
         "IV size is not compatible with the omp runtime");
  RuntimeFunction Name = IVSize == 32
                             ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_4
                                         : omp::OMPRTL___kmpc_dispatch_init_4u)
                             : (IVSigned ? omp::OMPRTL___kmpc_dispatch_init_8
                                         : omp::OMPRTL___kmpc_dispatch_init_8u);

  return getOrCreateRuntimeFunction(M, Name);
}

FunctionCallee OpenMPIRBuilder::createDispatchNextFunction(unsigned IVSize,
                                                           bool IVSigned) {
  assert((IVSize == 32 || IVSize == 64) &&
         "IV size is not compatible with the omp runtime");
  RuntimeFunction Name = IVSize == 32
                             ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_4
                                         : omp::OMPRTL___kmpc_dispatch_next_4u)
                             : (IVSigned ? omp::OMPRTL___kmpc_dispatch_next_8
                                         : omp::OMPRTL___kmpc_dispatch_next_8u);

  return getOrCreateRuntimeFunction(M, Name);
}

FunctionCallee OpenMPIRBuilder::createDispatchFiniFunction(unsigned IVSize,
                                                           bool IVSigned) {
  assert((IVSize == 32 || IVSize == 64) &&
         "IV size is not compatible with the omp runtime");
  RuntimeFunction Name = IVSize == 32
                             ? (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_4
                                         : omp::OMPRTL___kmpc_dispatch_fini_4u)
                             : (IVSigned ? omp::OMPRTL___kmpc_dispatch_fini_8
                                         : omp::OMPRTL___kmpc_dispatch_fini_8u);

  return getOrCreateRuntimeFunction(M, Name);
}

FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
  return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
}

static Function *createOutlinedFunction(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
    SmallVectorImpl<Value *> &Inputs,
    OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
  SmallVector<Type *> ParameterTypes;
  if (OMPBuilder.Config.isTargetDevice()) {
    // Add the "implicit" runtime argument we use to provide launch specific
    // information for target devices.
    auto *Int8PtrTy = PointerType::getUnqual(Builder.getContext());
    ParameterTypes.push_back(Int8PtrTy);

    // All parameters to target devices are passed as pointers
    // or i64. This assumes 64-bit address spaces/pointers.
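    // For example, a host region capturing (ptr %a, i32 %n) is outlined on
    // the device as, roughly: void @kernel(ptr %dyn_ptr, ptr %a, i64 %n).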
    for (auto &Arg : Inputs)
      ParameterTypes.push_back(Arg->getType()->isPointerTy()
                                   ? Arg->getType()
                                   : Type::getInt64Ty(Builder.getContext()));
  } else {
    for (auto &Arg : Inputs)
      ParameterTypes.push_back(Arg->getType());
  }

  auto FuncType = FunctionType::get(Builder.getVoidTy(), ParameterTypes,
                                    /*isVarArg*/ false);
  auto Func = Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName,
                               Builder.GetInsertBlock()->getModule());

  // Save insert point.
  auto OldInsertPoint = Builder.saveIP();

  // Generate the region into the function.
  BasicBlock *EntryBB = BasicBlock::Create(Builder.getContext(), "entry", Func);
  Builder.SetInsertPoint(EntryBB);

  // Insert target init call in the device compilation pass.
  if (OMPBuilder.Config.isTargetDevice())
    Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));

  BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();

  // As we embed the user code in the middle of our target region after we
  // generate entry code, we must move what allocas we can into the entry
  // block to avoid breaking optimizations for the device.
  if (OMPBuilder.Config.isTargetDevice())
    OMPBuilder.ConstantAllocaRaiseCandidates.emplace_back(Func);

  // Insert target deinit call in the device compilation pass.
  Builder.restoreIP(CBFunc(Builder.saveIP(), Builder.saveIP()));
  if (OMPBuilder.Config.isTargetDevice())
    OMPBuilder.createTargetDeinit(Builder);

  // Insert return instruction.
  Builder.CreateRetVoid();

  // New Alloca IP at entry point of created device function.
  Builder.SetInsertPoint(EntryBB->getFirstNonPHI());
  auto AllocaIP = Builder.saveIP();

  Builder.SetInsertPoint(UserCodeEntryBB->getFirstNonPHIOrDbg());

  // Skip the artificial dyn_ptr on the device.
  const auto &ArgRange =
      OMPBuilder.Config.isTargetDevice()
          ? make_range(Func->arg_begin() + 1, Func->arg_end())
          : Func->args();

  auto ReplaceValue = [](Value *Input, Value *InputCopy, Function *Func) {
    // Things like GEPs can come in the form of Constants. Constants and
    // ConstantExprs do not have access to the knowledge of what they're
    // contained in, so we must dig a little to find an instruction so we
    // can tell if they're used inside of the function we're outlining. We
    // also replace the original constant expression with a new instruction
    // equivalent; an instruction allows easy modification in the following
    // loop, as we can now know the constant (instruction) is owned by our
    // target function and replaceUsesOfWith can be invoked on it (this does
    // not appear possible with constants). A brand new one also allows us to
    // be cautious, as it is perhaps possible the old expression was used
    // inside of the function but exists and is used externally (unlikely by
    // the nature of a Constant, but still).
    // NOTE: We cannot remove dead constants that have been rewritten to
    // instructions at this stage; we run the risk of breaking later lowering
    // by doing so, as we could still be in the process of lowering the module
    // from MLIR to LLVM-IR and the MLIR lowering may still require the
    // original constants we have created rewritten versions of.
    if (auto *Const = dyn_cast<Constant>(Input))
      convertUsersOfConstantsToInstructions(Const, Func, false);

    // Replace all uses of Input inside of the outlined function.
    for (User *User : make_early_inc_range(Input->users()))
      if (auto *Instr = dyn_cast<Instruction>(User))
        if (Instr->getFunction() == Func)
          Instr->replaceUsesOfWith(Input, InputCopy);
  };

  SmallVector<std::pair<Value *, Value *>> DeferredReplacement;

  // Rewrite uses of input values to parameters.
  for (auto InArg : zip(Inputs, ArgRange)) {
    Value *Input = std::get<0>(InArg);
    Argument &Arg = std::get<1>(InArg);
    Value *InputCopy = nullptr;

    Builder.restoreIP(
        ArgAccessorFuncCB(Arg, Input, InputCopy, AllocaIP, Builder.saveIP()));

    // In certain cases a Global may be set up for replacement; however, this
    // Global may be used in multiple arguments to the kernel, just segmented
    // apart. For example, if we have a global array that is sectioned into
    // multiple mappings (technically not legal in OpenMP, but there is a case
    // in Fortran for Common Blocks where this is necessary), we will end up
    // with GEPs into this array inside the kernel that refer to the Global
    // but are technically separate arguments to the kernel for all intents
    // and purposes. If we have mapped a segment that requires a GEP into the
    // 0-th index, it will fold into a reference to the Global; if we then
    // encounter this folded GEP during replacement, all of the references to
    // the Global in the kernel will be replaced with the argument we have
    // generated that corresponds to it, including any other GEPs that refer
    // to the Global that may be other arguments. This would invalidate all of
    // the other preceding mapped arguments that refer to the same global that
    // may be separate segments. To prevent this, we defer global processing
    // until all other processing has been performed.
    if (llvm::isa<llvm::GlobalValue>(std::get<0>(InArg)) ||
        llvm::isa<llvm::GlobalObject>(std::get<0>(InArg)) ||
        llvm::isa<llvm::GlobalVariable>(std::get<0>(InArg))) {
      DeferredReplacement.push_back(std::make_pair(Input, InputCopy));
      continue;
    }

    ReplaceValue(Input, InputCopy, Func);
  }

  // Replace all of our deferred Input values, currently just Globals.
  for (auto Deferred : DeferredReplacement)
    ReplaceValue(std::get<0>(Deferred), std::get<1>(Deferred), Func);

  // Restore insert point.
  Builder.restoreIP(OldInsertPoint);

  return Func;
}

/// Create an entry point for a target task, i.e. the function called from
/// emitTargetTask once the code to launch the target kernel has been
/// outlined. It has the following signature:
/// void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
                                             IRBuilderBase &Builder,
                                             CallInst *StaleCI) {
  Module &M = OMPBuilder.M;
  // KernelLaunchFunction is the target launch function, i.e.
  // the function that sets up kernel arguments and calls
  // __tgt_target_kernel to launch the kernel on the device.
  //
  Function *KernelLaunchFunction = StaleCI->getCalledFunction();

  // StaleCI is the CallInst which is the call to the outlined
  // target kernel launch function. If there are values that the
  // outlined function uses then these are aggregated into a structure
  // which is passed as the second argument. If not, then there's
  // only one argument, the threadID. So, StaleCI can be
  //
  //   %structArg = alloca { ptr, ptr }, align 8
  //   %gep_ = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 0
  //   store ptr %20, ptr %gep_, align 8
  //   %gep_8 = getelementptr { ptr, ptr }, ptr %structArg, i32 0, i32 1
  //   store ptr %21, ptr %gep_8, align 8
  //   call void @_QQmain..omp_par.1(i32 %global.tid.val6, ptr %structArg)
  //
  // OR
  //
  //   call void @_QQmain..omp_par.1(i32 %global.tid.val6)
  OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
                                    StaleCI->getIterator());
  LLVMContext &Ctx = StaleCI->getParent()->getContext();
  Type *ThreadIDTy = Type::getInt32Ty(Ctx);
  Type *TaskPtrTy = OMPBuilder.TaskPtr;
  Type *TaskTy = OMPBuilder.Task;
  auto ProxyFnTy =
      FunctionType::get(Builder.getVoidTy(), {ThreadIDTy, TaskPtrTy},
                        /* isVarArg */ false);
  auto ProxyFn = Function::Create(ProxyFnTy, GlobalValue::InternalLinkage,
                                  ".omp_target_task_proxy_func",
                                  Builder.GetInsertBlock()->getModule());
  ProxyFn->getArg(0)->setName("thread.id");
  ProxyFn->getArg(1)->setName("task");

  BasicBlock *EntryBB =
      BasicBlock::Create(Builder.getContext(), "entry", ProxyFn);
  Builder.SetInsertPoint(EntryBB);

  bool HasShareds = StaleCI->arg_size() > 1;
  // TODO: This is a temporary assert to prove to ourselves that
  // the outlined target launch function is always going to have
  // at most two arguments if there is any data shared between
  // host and device.
  assert((!HasShareds || (StaleCI->arg_size() == 2)) &&
         "StaleCI with shareds should have exactly two arguments.");
  if (HasShareds) {
    auto *ArgStructAlloca = dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
    assert(ArgStructAlloca &&
           "Unable to find the alloca instruction corresponding to arguments "
           "for extracted function");
    auto *ArgStructType =
        dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());

    AllocaInst *NewArgStructAlloca =
        Builder.CreateAlloca(ArgStructType, nullptr, "structArg");
    Value *TaskT = ProxyFn->getArg(1);
    Value *ThreadId = ProxyFn->getArg(0);
    Value *SharedsSize =
        Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));

    Value *Shareds = Builder.CreateStructGEP(TaskTy, TaskT, 0);
    LoadInst *LoadShared =
        Builder.CreateLoad(PointerType::getUnqual(Ctx), Shareds);

    Builder.CreateMemCpy(
        NewArgStructAlloca, NewArgStructAlloca->getAlign(), LoadShared,
        LoadShared->getPointerAlignment(M.getDataLayout()), SharedsSize);

    Builder.CreateCall(KernelLaunchFunction, {ThreadId, NewArgStructAlloca});
  }
  Builder.CreateRetVoid();
  return ProxyFn;
}
static void emitTargetOutlinedFunction(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
    TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
    Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
    OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {

  OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
      [&OMPBuilder, &Builder, &Inputs, &CBFunc,
       &ArgAccessorFuncCB](StringRef EntryFnName) {
        return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
                                      CBFunc, ArgAccessorFuncCB);
      };

  OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction, true,
                                      OutlinedFn, OutlinedFnID);
}
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
    Function *OutlinedFn, Value *OutlinedFnID,
    EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
    Value *DeviceID, Value *RTLoc, OpenMPIRBuilder::InsertPointTy AllocaIP,
    SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
    bool HasNoWait) {

  // When we arrive at this function, the target region itself has been
  // outlined into the function OutlinedFn.
  // So at this point, for
  // --------------------------------------------------
  //   void user_code_that_offloads(...) {
  //     omp target depend(..) map(from:a) map(to:b, c)
  //        a = b + c
  //   }
  //
  // --------------------------------------------------
  //
  // we have
  //
  // --------------------------------------------------
  //
  //   void user_code_that_offloads(...) {
  //     %.offload_baseptrs = alloca [3 x ptr], align 8
  //     %.offload_ptrs = alloca [3 x ptr], align 8
  //     %.offload_mappers = alloca [3 x ptr], align 8
  //     ;; target region has been outlined and now we need to
  //     ;; offload to it via a target task.
  //   }
  //   void outlined_device_function(ptr a, ptr b, ptr c) {
  //     *a = *b + *c
  //   }
  //
  // We now have to do the following:
  //   (i)   Make an offloading call to outlined_device_function using the
  //         OpenMP RTL. See 'kernel_launch_function' in the pseudo code below.
  //         This is emitted by emitKernelLaunch.
  //   (ii)  Create a task entry point function that calls
  //         kernel_launch_function and is the entry point for the target task.
  //         See '@.omp_target_task_proxy_func' in the pseudo code below.
  //   (iii) Create a task with the task entry point created in (ii).
  //
  // That is, we create the following
  //
  //   void user_code_that_offloads(...) {
  //     %.offload_baseptrs = alloca [3 x ptr], align 8
  //     %.offload_ptrs = alloca [3 x ptr], align 8
  //     %.offload_mappers = alloca [3 x ptr], align 8
  //
  //     %structArg = alloca { ptr, ptr, ptr }, align 8
  //     %structArg[0] = %.offload_baseptrs
  //     %structArg[1] = %.offload_ptrs
  //     %structArg[2] = %.offload_mappers
  //     proxy_target_task = @__kmpc_omp_task_alloc(...,
  //                                                @.omp_target_task_proxy_func)
  //     memcpy(proxy_target_task->shareds, %structArg, sizeof(structArg))
  //     dependencies_array = ...
  //     ;; if nowait not present
  //     call @__kmpc_omp_wait_deps(..., dependencies_array)
  //     call @__kmpc_omp_task_begin_if0(...)
  //     call @.omp_target_task_proxy_func(i32 thread_id, ptr %proxy_target_task)
  //     call @__kmpc_omp_task_complete_if0(...)
  //   }
  //
  //   define internal void @.omp_target_task_proxy_func(i32 %thread.id,
  //                                                     ptr %task) {
  //     %structArg = alloca {ptr, ptr, ptr}
  //     %shared_data = load (getelementptr %task, 0, 0)
  //     memcpy(%structArg, %shared_data, sizeof(structArg))
  //     kernel_launch_function(%thread.id, %structArg)
  //   }
  //
  // We need the proxy function because the signature of the task entry point
  // expected by kmpc_omp_task is always the same and will be different from
  // that of the kernel_launch function.
  //
  // kernel_launch_function is generated by emitKernelLaunch and has the
  // always_inline attribute.
  //   void kernel_launch_function(thread_id,
  //                               structArg) alwaysinline {
  //     %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
  //     offload_baseptrs = load(getelementptr structArg, 0, 0)
  //     offload_ptrs = load(getelementptr structArg, 0, 1)
  //     offload_mappers = load(getelementptr structArg, 0, 2)
  //     ; setup kernel_args using offload_baseptrs, offload_ptrs and
  //     ; offload_mappers
  //     call i32 @__tgt_target_kernel(...,
  //                                   outlined_device_function,
  //                                   ptr %kernel_args)
  //   }
  //   void outlined_device_function(ptr a, ptr b, ptr c) {
  //     *a = *b + *c
  //   }
  //
  BasicBlock *TargetTaskBodyBB =
      splitBB(Builder, /*CreateBranch=*/true, "target.task.body");
  BasicBlock *TargetTaskAllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, "target.task.alloca");

  InsertPointTy TargetTaskAllocaIP(TargetTaskAllocaBB,
                                   TargetTaskAllocaBB->begin());
  InsertPointTy TargetTaskBodyIP(TargetTaskBodyBB, TargetTaskBodyBB->begin());

  OutlineInfo OI;
  OI.EntryBB = TargetTaskAllocaBB;
  OI.OuterAllocaBB = AllocaIP.getBlock();

  // Add the thread ID argument.
  SmallVector<Instruction *, 4> ToBeDeleted;
  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
      Builder, AllocaIP, ToBeDeleted, TargetTaskAllocaIP, "global.tid", false));

  Builder.restoreIP(TargetTaskBodyIP);

  // emitKernelLaunch makes the necessary runtime call to offload the kernel.
  // We then outline all that code into a separate function
  // ('kernel_launch_function' in the pseudo code above). This function is then
  // called by the target task proxy function (see
  // '@.omp_target_task_proxy_func' in the pseudo code above), which is
  // generated by emitTargetTaskProxyFunction.
  Builder.restoreIP(emitKernelLaunch(Builder, OutlinedFn, OutlinedFnID,
                                     EmitTargetCallFallbackCB, Args, DeviceID,
                                     RTLoc, TargetTaskAllocaIP));

  OI.ExitBB = Builder.saveIP().getBlock();
  OI.PostOutlineCB = [this, ToBeDeleted, Dependencies,
                      HasNoWait](Function &OutlinedFn) mutable {
    assert(OutlinedFn.getNumUses() == 1 &&
           "there must be a single user for the outlined function");

    CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
    bool HasShareds = StaleCI->arg_size() > 1;

    Function *ProxyFn = emitTargetTaskProxyFunction(*this, Builder, StaleCI);

    LLVM_DEBUG(dbgs() << "Proxy task entry function created: " << *ProxyFn
                      << "\n");

    Builder.SetInsertPoint(StaleCI);

    // Gather the arguments for emitting the runtime call.
    uint32_t SrcLocStrSize;
    Constant *SrcLocStr =
        getOrCreateSrcLocStr(LocationDescription(Builder), SrcLocStrSize);
    Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

    // @__kmpc_omp_task_alloc
    Function *TaskAllocFn =
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);

    // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID) for the task
    // allocation call.
    Value *ThreadID = getOrCreateThreadID(Ident);

    // Argument - `sizeof_kmp_task_t` (TaskSize)
    // TaskSize refers to the size in bytes of the kmp_task_t data structure,
    // including private vars accessed in the task.
    // TODO: add kmp_task_t_with_privates (privates)
    Value *TaskSize =
        Builder.getInt64(M.getDataLayout().getTypeStoreSize(Task));

    // Argument - `sizeof_shareds` (SharedsSize)
    // SharedsSize refers to the shareds array size in the kmp_task_t data
    // structure.
    Value *SharedsSize = Builder.getInt64(0);
    if (HasShareds) {
      auto *ArgStructAlloca = dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
      assert(ArgStructAlloca &&
             "Unable to find the alloca instruction corresponding to arguments "
             "for extracted function");
      auto *ArgStructType =
          dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
      assert(ArgStructType && "Unable to find struct type corresponding to "
                              "arguments for extracted function");
      SharedsSize =
          Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
    }

    // Argument - `flags`
    // Task is tied iff (Flags & 1) == 1.
    // Task is untied iff (Flags & 1) == 0.
    // Task is final iff (Flags & 2) == 2.
    // Task is not final iff (Flags & 2) == 0.
    // A target task is not final and is untied.
    Value *Flags = Builder.getInt32(0);

    // Emit the @__kmpc_omp_task_alloc runtime call.
    // The runtime call returns a pointer to an area where the task captured
    // variables must be copied before the task is run (TaskData).
    CallInst *TaskData = Builder.CreateCall(
        TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
                      /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
                      /*task_func=*/ProxyFn});

    if (HasShareds) {
      Value *Shareds = StaleCI->getArgOperand(1);
      Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
      Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
      Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
                           SharedsSize);
    }

    Value *DepArray = emitTaskDependencies(*this, Dependencies);

    // ---------------------------------------------------------------
    // V5.2 13.8 target construct
    // If the nowait clause is present, execution of the target task
    // may be deferred. If the nowait clause is not present, the target task is
    // an included task.
    // ---------------------------------------------------------------
    // The above means that the lack of a nowait on the target construct
    // translates to '#pragma omp task if(0)'
    if (!HasNoWait) {
      if (DepArray) {
        Function *TaskWaitFn =
            getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
        Builder.CreateCall(
            TaskWaitFn,
            {/*loc_ref=*/Ident, /*gtid=*/ThreadID,
             /*ndeps=*/Builder.getInt32(Dependencies.size()),
             /*dep_list=*/DepArray,
             /*ndeps_noalias=*/ConstantInt::get(Builder.getInt32Ty(), 0),
             /*noalias_dep_list=*/
             ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
      }
      // Included task.
      Function *TaskBeginFn =
          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
      Function *TaskCompleteFn =
          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
      Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
      CallInst *CI = nullptr;
      if (HasShareds)
        CI = Builder.CreateCall(ProxyFn, {ThreadID, TaskData});
      else
        CI = Builder.CreateCall(ProxyFn, {ThreadID});
      CI->setDebugLoc(StaleCI->getDebugLoc());
      Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
    } else if (DepArray) {
      // HasNoWait - meaning the task may be deferred. Call
      // __kmpc_omp_task_with_deps if there are dependencies,
      // else call __kmpc_omp_task.
      Function *TaskFn =
          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
      Builder.CreateCall(
          TaskFn,
          {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
           DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
           ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
    } else {
      // Emit the @__kmpc_omp_task runtime call to spawn the task.
      Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
      Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
    }

    StaleCI->eraseFromParent();
    llvm::for_each(llvm::reverse(ToBeDeleted),
                   [](Instruction *I) { I->eraseFromParent(); });
  };
  addOutlineInfo(std::move(OI));

  LLVM_DEBUG(dbgs() << "Insert block after emitKernelLaunch = \n"
                    << *(Builder.GetInsertBlock()) << "\n");
  LLVM_DEBUG(dbgs() << "Module after emitKernelLaunch = \n"
                    << *(Builder.GetInsertBlock()->getParent()->getParent())
                    << "\n");
  return Builder.saveIP();
}

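/// Emit the host-side code that launches the outlined target region: build
/// the offloading argument arrays, then either launch the kernel directly via
/// emitKernelLaunch or, when dependencies (or nowait, once wired up here)
/// require it, wrap the launch in an outer target task via emitTargetTask.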
static void emitTargetCall(
    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
    OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
    Constant *OutlinedFnID, int32_t NumTeams, int32_t NumThreads,
    SmallVectorImpl<Value *> &Args,
    OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
    SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {

  OpenMPIRBuilder::TargetDataInfo Info(
      /*RequiresDevicePointerInfo=*/false,
      /*SeparateBeginEndCalls=*/true);

  OpenMPIRBuilder::MapInfosTy &MapInfo = GenMapInfoCB(Builder.saveIP());
  OMPBuilder.emitOffloadingArrays(AllocaIP, Builder.saveIP(), MapInfo, Info,
                                  /*IsNonContiguous=*/true);

  OpenMPIRBuilder::TargetDataRTArgs RTArgs;
  OMPBuilder.emitOffloadingArraysArgument(Builder, RTArgs, Info,
                                          !MapInfo.Names.empty());

  // Fallback passed to emitKernelLaunch: call the host version of the
  // outlined function directly.
  auto &&EmitTargetCallFallbackCB =
      [&](OpenMPIRBuilder::InsertPointTy IP) -> OpenMPIRBuilder::InsertPointTy {
    Builder.restoreIP(IP);
    Builder.CreateCall(OutlinedFn, Args);
    return Builder.saveIP();
  };

  unsigned NumTargetItems = MapInfo.BasePointers.size();
  // TODO: Use correct device ID
  Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
  Value *NumTeamsVal = Builder.getInt32(NumTeams);
  Value *NumThreadsVal = Builder.getInt32(NumThreads);
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
  Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                             llvm::omp::IdentFlag(0), 0);
  // TODO: Use correct NumIterations
  Value *NumIterations = Builder.getInt64(0);
  // TODO: Use correct DynCGGroupMem
  Value *DynCGGroupMem = Builder.getInt32(0);

  bool HasNoWait = false;
  bool HasDependencies = Dependencies.size() > 0;
  bool RequiresOuterTargetTask = HasNoWait || HasDependencies;

  OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
                                          NumTeamsVal, NumThreadsVal,
                                          DynCGGroupMem, HasNoWait);

  // The presence of certain clauses on the target directive requires the
  // explicit generation of the target task.
  if (RequiresOuterTargetTask) {
    Builder.restoreIP(OMPBuilder.emitTargetTask(
        OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs, DeviceID,
        RTLoc, AllocaIP, Dependencies, HasNoWait));
  } else {
    Builder.restoreIP(OMPBuilder.emitKernelLaunch(
        Builder, OutlinedFn, OutlinedFnID, EmitTargetCallFallbackCB, KArgs,
        DeviceID, RTLoc, AllocaIP));
  }
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo, int32_t NumTeams,
    int32_t NumThreads, SmallVectorImpl<Value *> &Args,
    GenMapInfoCallbackTy GenMapInfoCB,
    OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
    OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
    SmallVector<DependData> Dependencies) {

  if (!updateToLocation(Loc))
    return InsertPointTy();

  Builder.restoreIP(CodeGenIP);

  Function *OutlinedFn;
  Constant *OutlinedFnID;
  // The target region is outlined into its own function. The LLVM IR for
  // the target region itself is generated using the callbacks CBFunc
  // and ArgAccessorFuncCB.
  emitTargetOutlinedFunction(*this, Builder, EntryInfo, OutlinedFn,
                             OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB);

  // If we are not on the target device, then we need to generate code
  // to make a remote call (offload) to the previously outlined function
  // that represents the target region. Do that now.
  if (!Config.isTargetDevice())
    emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
                   NumThreads, Args, GenMapInfoCB, Dependencies);
  return Builder.saveIP();
}

std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
                                                   StringRef FirstSeparator,
                                                   StringRef Separator) {
  SmallString<128> Buffer;
  llvm::raw_svector_ostream OS(Buffer);
  StringRef Sep = FirstSeparator;
  for (StringRef Part : Parts) {
    OS << Sep << Part;
    Sep = Separator;
  }
  return OS.str().str();
}

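// For example, getNameWithSeparators({"a", "b", "c"}, "__", "$") produces
// "__a$b$c": the first separator prefixes the first part, and the regular
// separator joins the remaining parts.
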
std::string
OpenMPIRBuilder::createPlatformSpecificName(ArrayRef<StringRef> Parts) const {
  return OpenMPIRBuilder::getNameWithSeparators(Parts, Config.firstSeparator(),
                                                Config.separator());
}

GlobalVariable *
OpenMPIRBuilder::getOrCreateInternalVariable(Type *Ty, const StringRef &Name,
                                             unsigned AddressSpace) {
  auto &Elem = *InternalVars.try_emplace(Name, nullptr).first;
  if (Elem.second) {
    assert(Elem.second->getValueType() == Ty &&
           "OMP internal variable has different type than requested");
  } else {
    // TODO: investigate the appropriate linkage type used for the global
    // variable for possibly changing that to internal or private, or maybe
    // creating different versions of the function for different OMP internal
    // variables.
    auto Linkage = this->M.getTargetTriple().rfind("wasm32") == 0
                       ? GlobalValue::ExternalLinkage
                       : GlobalValue::CommonLinkage;
    auto *GV = new GlobalVariable(M, Ty, /*IsConstant=*/false, Linkage,
                                  Constant::getNullValue(Ty), Elem.first(),
                                  /*InsertBefore=*/nullptr,
                                  GlobalValue::NotThreadLocal, AddressSpace);
    const DataLayout &DL = M.getDataLayout();
    const llvm::Align TypeAlign = DL.getABITypeAlign(Ty);
    const llvm::Align PtrAlign = DL.getPointerABIAlignment(AddressSpace);
    GV->setAlignment(std::max(TypeAlign, PtrAlign));
    Elem.second = GV;
  }

  return Elem.second;
}

Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
  std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
  std::string Name = getNameWithSeparators({Prefix, "var"}, ".", ".");
  return getOrCreateInternalVariable(KmpCriticalNameTy, Name);
}

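// For a critical region named "foo" this yields the module-level lock
// variable ".gomp_critical_user_foo.var". Because getOrCreateInternalVariable
// caches by name, all critical regions with the same (possibly empty) name
// share a single lock variable.
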
Value *OpenMPIRBuilder::getSizeInBytes(Value *BasePtr) {
  LLVMContext &Ctx = Builder.getContext();
  Value *Null =
      Constant::getNullValue(PointerType::getUnqual(BasePtr->getContext()));
  Value *SizeGep =
      Builder.CreateGEP(BasePtr->getType(), Null, Builder.getInt32(1));
  Value *SizePtrToInt = Builder.CreatePtrToInt(SizeGep, Type::getInt64Ty(Ctx));
  return SizePtrToInt;
}

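// The null-pointer GEP above is the classic IR idiom for computing a type's
// store size without a DataLayout query; the emitted expression
//   ptrtoint (getelementptr <ty>, ptr null, i32 1) to i64
// constant-folds to sizeof(<ty>) as an i64.
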
GlobalVariable *
OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
                                       std::string VarName) {
  llvm::Constant *MaptypesArrayInit =
      llvm::ConstantDataArray::get(M.getContext(), Mappings);
  auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
      M, MaptypesArrayInit->getType(),
      /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
      VarName);
  MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
  return MaptypesArrayGlobal;
}

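// For N map clauses this materializes a module-level constant such as
//   @.offload_maptypes = private unnamed_addr constant [N x i64] [...]
// where each i64 element is the OpenMPOffloadMappingFlags bitmask for one
// mapped operand.
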
void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
                                          InsertPointTy AllocaIP,
                                          unsigned NumOperands,
                                          struct MapperAllocas &MapperAllocas) {
  if (!updateToLocation(Loc))
    return;

  auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
  auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
  Builder.restoreIP(AllocaIP);
  AllocaInst *ArgsBase = Builder.CreateAlloca(
      ArrI8PtrTy, /* ArraySize = */ nullptr, ".offload_baseptrs");
  AllocaInst *Args = Builder.CreateAlloca(ArrI8PtrTy, /* ArraySize = */ nullptr,
                                          ".offload_ptrs");
  AllocaInst *ArgSizes = Builder.CreateAlloca(
      ArrI64Ty, /* ArraySize = */ nullptr, ".offload_sizes");
  Builder.restoreIP(Loc.IP);
  MapperAllocas.ArgsBase = ArgsBase;
  MapperAllocas.Args = Args;
  MapperAllocas.ArgSizes = ArgSizes;
}

void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
                                     Function *MapperFunc, Value *SrcLocInfo,
                                     Value *MaptypesArg, Value *MapnamesArg,
                                     struct MapperAllocas &MapperAllocas,
                                     int64_t DeviceID, unsigned NumOperands) {
  if (!updateToLocation(Loc))
    return;

  auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
  auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
  Value *ArgsBaseGEP =
      Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.ArgsBase,
                                {Builder.getInt32(0), Builder.getInt32(0)});
  Value *ArgsGEP =
      Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.Args,
                                {Builder.getInt32(0), Builder.getInt32(0)});
  Value *ArgSizesGEP =
      Builder.CreateInBoundsGEP(ArrI64Ty, MapperAllocas.ArgSizes,
                                {Builder.getInt32(0), Builder.getInt32(0)});
  Value *NullPtr =
      Constant::getNullValue(PointerType::getUnqual(Int8Ptr->getContext()));
  Builder.CreateCall(MapperFunc,
                     {SrcLocInfo, Builder.getInt64(DeviceID),
                      Builder.getInt32(NumOperands), ArgsBaseGEP, ArgsGEP,
                      ArgSizesGEP, MaptypesArg, MapnamesArg, NullPtr});
}

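// The GEPs above decay each stack array to a pointer to its first element, so
// the call to a mapper runtime entry (e.g. __tgt_target_data_begin_mapper)
// has the shape:
//   call void @<MapperFunc>(ptr %srcloc, i64 <device_id>, i32 <num_operands>,
//                           ptr %.offload_baseptrs, ptr %.offload_ptrs,
//                           ptr %.offload_sizes, ptr <maptypes>,
//                           ptr <mapnames>, ptr null)
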
void OpenMPIRBuilder::emitOffloadingArraysArgument(IRBuilderBase &Builder,
                                                   TargetDataRTArgs &RTArgs,
                                                   TargetDataInfo &Info,
                                                   bool EmitDebug,
                                                   bool ForEndCall) {
  assert((!ForEndCall || Info.separateBeginEndCalls()) &&
         "expected region end call to runtime only when end call is separate");
  auto UnqualPtrTy = PointerType::getUnqual(M.getContext());
  auto VoidPtrTy = UnqualPtrTy;
  auto VoidPtrPtrTy = UnqualPtrTy;
  auto Int64Ty = Type::getInt64Ty(M.getContext());
  auto Int64PtrTy = UnqualPtrTy;

  if (!Info.NumberOfPtrs) {
    RTArgs.BasePointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
    RTArgs.PointersArray = ConstantPointerNull::get(VoidPtrPtrTy);
    RTArgs.SizesArray = ConstantPointerNull::get(Int64PtrTy);
    RTArgs.MapTypesArray = ConstantPointerNull::get(Int64PtrTy);
    RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
    RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
    return;
  }

  RTArgs.BasePointersArray = Builder.CreateConstInBoundsGEP2_32(
      ArrayType::get(VoidPtrTy, Info.NumberOfPtrs),
      Info.RTArgs.BasePointersArray,
      /*Idx0=*/0, /*Idx1=*/0);
  RTArgs.PointersArray = Builder.CreateConstInBoundsGEP2_32(
      ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray,
      /*Idx0=*/0,
      /*Idx1=*/0);
  RTArgs.SizesArray = Builder.CreateConstInBoundsGEP2_32(
      ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
      /*Idx0=*/0, /*Idx1=*/0);
  RTArgs.MapTypesArray = Builder.CreateConstInBoundsGEP2_32(
      ArrayType::get(Int64Ty, Info.NumberOfPtrs),
      ForEndCall && Info.RTArgs.MapTypesArrayEnd ? Info.RTArgs.MapTypesArrayEnd
                                                 : Info.RTArgs.MapTypesArray,
      /*Idx0=*/0,
      /*Idx1=*/0);

  // Only emit the mapper information arrays if debug information is
  // requested.
  if (!EmitDebug)
    RTArgs.MapNamesArray = ConstantPointerNull::get(VoidPtrPtrTy);
  else
    RTArgs.MapNamesArray = Builder.CreateConstInBoundsGEP2_32(
        ArrayType::get(VoidPtrTy, Info.NumberOfPtrs), Info.RTArgs.MapNamesArray,
        /*Idx0=*/0,
        /*Idx1=*/0);
  // If there is no user-defined mapper, set the mapper array to nullptr to
  // avoid an unnecessary data privatization.
  if (!Info.HasMapper)
    RTArgs.MappersArray = ConstantPointerNull::get(VoidPtrPtrTy);
  else
    RTArgs.MappersArray =
        Builder.CreatePointerCast(Info.RTArgs.MappersArray, VoidPtrPtrTy);
}

void OpenMPIRBuilder::emitNonContiguousDescriptor(InsertPointTy AllocaIP,
                                                  InsertPointTy CodeGenIP,
                                                  MapInfosTy &CombinedInfo,
                                                  TargetDataInfo &Info) {
  MapInfosTy::StructNonContiguousInfo &NonContigInfo =
      CombinedInfo.NonContigInfo;

  // Build an array of struct descriptor_dim and then assign it to
  // offload_args.
  //
  // struct descriptor_dim {
  //   uint64_t offset;
  //   uint64_t count;
  //   uint64_t stride;
  // };
  Type *Int64Ty = Builder.getInt64Ty();
  StructType *DimTy = StructType::create(
      M.getContext(), ArrayRef<Type *>({Int64Ty, Int64Ty, Int64Ty}),
      "struct.descriptor_dim");

  enum { OffsetFD = 0, CountFD, StrideFD };
  // We need two index variables here since the size of "Dims" is the same as
  // the size of Components; however, the size of offset, count, and stride is
  // equal to the size of the base declaration that is non-contiguous.
  for (unsigned I = 0, L = 0, E = NonContigInfo.Dims.size(); I < E; ++I) {
    // Skip emitting IR if the dimension size is 1 since it cannot be
    // non-contiguous.
    if (NonContigInfo.Dims[I] == 1)
      continue;
    Builder.restoreIP(AllocaIP);
    ArrayType *ArrayTy = ArrayType::get(DimTy, NonContigInfo.Dims[I]);
    AllocaInst *DimsAddr =
        Builder.CreateAlloca(ArrayTy, /* ArraySize = */ nullptr, "dims");
    Builder.restoreIP(CodeGenIP);
    for (unsigned II = 0, EE = NonContigInfo.Dims[I]; II < EE; ++II) {
      unsigned RevIdx = EE - II - 1;
      Value *DimsLVal = Builder.CreateInBoundsGEP(
          DimsAddr->getAllocatedType(), DimsAddr,
          {Builder.getInt64(0), Builder.getInt64(II)});
      // Offset
      Value *OffsetLVal = Builder.CreateStructGEP(DimTy, DimsLVal, OffsetFD);
      Builder.CreateAlignedStore(
          NonContigInfo.Offsets[L][RevIdx], OffsetLVal,
          M.getDataLayout().getPrefTypeAlign(OffsetLVal->getType()));
      // Count
      Value *CountLVal = Builder.CreateStructGEP(DimTy, DimsLVal, CountFD);
      Builder.CreateAlignedStore(
          NonContigInfo.Counts[L][RevIdx], CountLVal,
          M.getDataLayout().getPrefTypeAlign(CountLVal->getType()));
      // Stride
      Value *StrideLVal = Builder.CreateStructGEP(DimTy, DimsLVal, StrideFD);
      Builder.CreateAlignedStore(
          NonContigInfo.Strides[L][RevIdx], StrideLVal,
          M.getDataLayout().getPrefTypeAlign(StrideLVal->getType()));
    }
    // args[I] = &dims
    Builder.restoreIP(CodeGenIP);
    Value *DAddr = Builder.CreatePointerBitCastOrAddrSpaceCast(
        DimsAddr, Builder.getPtrTy());
    Value *P = Builder.CreateConstInBoundsGEP2_32(
        ArrayType::get(Builder.getPtrTy(), Info.NumberOfPtrs),
        Info.RTArgs.PointersArray, 0, I);
    Builder.CreateAlignedStore(
        DAddr, P, M.getDataLayout().getPrefTypeAlign(Builder.getPtrTy()));
    ++L;
  }
}

void OpenMPIRBuilder::emitOffloadingArrays(
    InsertPointTy AllocaIP, InsertPointTy CodeGenIP, MapInfosTy &CombinedInfo,
    TargetDataInfo &Info, bool IsNonContiguous,
    function_ref<void(unsigned int, Value *)> DeviceAddrCB,
    function_ref<Value *(unsigned int)> CustomMapperCB) {

  // Reset the array information.
  Info.clearArrayInfo();
  Info.NumberOfPtrs = CombinedInfo.BasePointers.size();

  if (Info.NumberOfPtrs == 0)
    return;

  Builder.restoreIP(AllocaIP);
  // Detect whether any capture size requires runtime evaluation; if none
  // does, a constant array can eventually be used for the sizes.
  ArrayType *PointerArrayType =
      ArrayType::get(Builder.getPtrTy(), Info.NumberOfPtrs);

  Info.RTArgs.BasePointersArray = Builder.CreateAlloca(
      PointerArrayType, /* ArraySize = */ nullptr, ".offload_baseptrs");

  Info.RTArgs.PointersArray = Builder.CreateAlloca(
      PointerArrayType, /* ArraySize = */ nullptr, ".offload_ptrs");
  AllocaInst *MappersArray = Builder.CreateAlloca(
      PointerArrayType, /* ArraySize = */ nullptr, ".offload_mappers");
  Info.RTArgs.MappersArray = MappersArray;

  // If we don't have any VLA types or other types that require runtime
  // evaluation, we can use a constant array for the map sizes; otherwise we
  // need to fill up the arrays as we do for the pointers.
  Type *Int64Ty = Builder.getInt64Ty();
  SmallVector<Constant *> ConstSizes(CombinedInfo.Sizes.size(),
                                     ConstantInt::get(Int64Ty, 0));
  SmallBitVector RuntimeSizes(CombinedInfo.Sizes.size());
  for (unsigned I = 0, E = CombinedInfo.Sizes.size(); I < E; ++I) {
    if (auto *CI = dyn_cast<Constant>(CombinedInfo.Sizes[I])) {
      if (!isa<ConstantExpr>(CI) && !isa<GlobalValue>(CI)) {
        if (IsNonContiguous &&
            static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                CombinedInfo.Types[I] &
                OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG))
          ConstSizes[I] =
              ConstantInt::get(Int64Ty, CombinedInfo.NonContigInfo.Dims[I]);
        else
          ConstSizes[I] = CI;
        continue;
      }
    }
    RuntimeSizes.set(I);
  }

  if (RuntimeSizes.all()) {
    ArrayType *SizeArrayType = ArrayType::get(Int64Ty, Info.NumberOfPtrs);
    Info.RTArgs.SizesArray = Builder.CreateAlloca(
        SizeArrayType, /* ArraySize = */ nullptr, ".offload_sizes");
    Builder.restoreIP(CodeGenIP);
  } else {
    auto *SizesArrayInit = ConstantArray::get(
        ArrayType::get(Int64Ty, ConstSizes.size()), ConstSizes);
    std::string Name = createPlatformSpecificName({"offload_sizes"});
    auto *SizesArrayGbl =
        new GlobalVariable(M, SizesArrayInit->getType(), /*isConstant=*/true,
                           GlobalValue::PrivateLinkage, SizesArrayInit, Name);
    SizesArrayGbl->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);

    if (!RuntimeSizes.any()) {
      Info.RTArgs.SizesArray = SizesArrayGbl;
    } else {
      unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
      Align OffloadSizeAlign = M.getDataLayout().getABIIntegerTypeAlignment(64);
      ArrayType *SizeArrayType = ArrayType::get(Int64Ty, Info.NumberOfPtrs);
      AllocaInst *Buffer = Builder.CreateAlloca(
          SizeArrayType, /* ArraySize = */ nullptr, ".offload_sizes");
      Buffer->setAlignment(OffloadSizeAlign);
      Builder.restoreIP(CodeGenIP);
      Builder.CreateMemCpy(
          Buffer, M.getDataLayout().getPrefTypeAlign(Buffer->getType()),
          SizesArrayGbl, OffloadSizeAlign,
          Builder.getIntN(
              IndexSize,
              Buffer->getAllocationSize(M.getDataLayout())->getFixedValue()));

      Info.RTArgs.SizesArray = Buffer;
    }
    Builder.restoreIP(CodeGenIP);
  }

  // The map types are always constant so we don't need to generate code to
  // fill arrays. Instead, we create an array constant.
  SmallVector<uint64_t, 4> Mapping;
  for (auto mapFlag : CombinedInfo.Types)
    Mapping.push_back(
        static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
            mapFlag));
  std::string MaptypesName = createPlatformSpecificName({"offload_maptypes"});
  auto *MapTypesArrayGbl = createOffloadMaptypes(Mapping, MaptypesName);
  Info.RTArgs.MapTypesArray = MapTypesArrayGbl;

  // The information types are only built if provided.
  if (!CombinedInfo.Names.empty()) {
    std::string MapnamesName = createPlatformSpecificName({"offload_mapnames"});
    auto *MapNamesArrayGbl =
        createOffloadMapnames(CombinedInfo.Names, MapnamesName);
    Info.RTArgs.MapNamesArray = MapNamesArrayGbl;
  } else {
    Info.RTArgs.MapNamesArray =
        Constant::getNullValue(PointerType::getUnqual(Builder.getContext()));
  }

  // If there's a present map type modifier, it must not be applied to the end
  // of a region, so generate a separate map type array in that case.
  if (Info.separateBeginEndCalls()) {
    bool EndMapTypesDiffer = false;
    for (uint64_t &Type : Mapping) {
      if (Type & static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                     OpenMPOffloadMappingFlags::OMP_MAP_PRESENT)) {
        Type &= ~static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
            OpenMPOffloadMappingFlags::OMP_MAP_PRESENT);
        EndMapTypesDiffer = true;
      }
    }
    if (EndMapTypesDiffer) {
      MapTypesArrayGbl = createOffloadMaptypes(Mapping, MaptypesName);
      Info.RTArgs.MapTypesArrayEnd = MapTypesArrayGbl;
    }
  }

  PointerType *PtrTy = Builder.getPtrTy();
  for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
    Value *BPVal = CombinedInfo.BasePointers[I];
    Value *BP = Builder.CreateConstInBoundsGEP2_32(
        ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.BasePointersArray,
        0, I);
    Builder.CreateAlignedStore(BPVal, BP,
                               M.getDataLayout().getPrefTypeAlign(PtrTy));

    if (Info.requiresDevicePointerInfo()) {
      if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
        CodeGenIP = Builder.saveIP();
        Builder.restoreIP(AllocaIP);
        Info.DevicePtrInfoMap[BPVal] = {BP, Builder.CreateAlloca(PtrTy)};
        Builder.restoreIP(CodeGenIP);
        if (DeviceAddrCB)
          DeviceAddrCB(I, Info.DevicePtrInfoMap[BPVal].second);
      } else if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Address) {
        Info.DevicePtrInfoMap[BPVal] = {BP, BP};
        if (DeviceAddrCB)
          DeviceAddrCB(I, BP);
      }
    }

    Value *PVal = CombinedInfo.Pointers[I];
    Value *P = Builder.CreateConstInBoundsGEP2_32(
        ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray, 0,
        I);
    // TODO: Check whether this alignment is correct.
    Builder.CreateAlignedStore(PVal, P,
                               M.getDataLayout().getPrefTypeAlign(PtrTy));

    if (RuntimeSizes.test(I)) {
      Value *S = Builder.CreateConstInBoundsGEP2_32(
          ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
          /*Idx0=*/0,
          /*Idx1=*/I);
      Builder.CreateAlignedStore(Builder.CreateIntCast(CombinedInfo.Sizes[I],
                                                       Int64Ty,
                                                       /*isSigned=*/true),
                                 S, M.getDataLayout().getPrefTypeAlign(PtrTy));
    }
    // Fill up the mapper array.
    unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
    Value *MFunc = ConstantPointerNull::get(PtrTy);
    if (CustomMapperCB)
      if (Value *CustomMFunc = CustomMapperCB(I))
        MFunc = Builder.CreatePointerCast(CustomMFunc, PtrTy);
    Value *MAddr = Builder.CreateInBoundsGEP(
        MappersArray->getAllocatedType(), MappersArray,
        {Builder.getIntN(IndexSize, 0), Builder.getIntN(IndexSize, I)});
    Builder.CreateAlignedStore(
        MFunc, MAddr, M.getDataLayout().getPrefTypeAlign(MAddr->getType()));
  }

  if (!IsNonContiguous || CombinedInfo.NonContigInfo.Offsets.empty() ||
      Info.NumberOfPtrs == 0)
    return;
  emitNonContiguousDescriptor(AllocaIP, CodeGenIP, CombinedInfo, Info);
}

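// In summary, emitOffloadingArrays creates the per-region stack arrays
// %.offload_baseptrs, %.offload_ptrs and %.offload_mappers (plus
// %.offload_sizes when some size is only known at runtime), the constant
// globals @.offload_maptypes and, when names are provided,
// @.offload_mapnames, and records them all in Info.RTArgs for the offloading
// runtime calls.
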
void OpenMPIRBuilder::emitBranch(BasicBlock *Target) {
  BasicBlock *CurBB = Builder.GetInsertBlock();

  if (!CurBB || CurBB->getTerminator()) {
    // If there is no insert point or the previous block is already
    // terminated, don't touch it.
  } else {
    // Otherwise, create a fall-through branch.
    Builder.CreateBr(Target);
  }

  Builder.ClearInsertionPoint();
}

void OpenMPIRBuilder::emitBlock(BasicBlock *BB, Function *CurFn,
                                bool IsFinished) {
  BasicBlock *CurBB = Builder.GetInsertBlock();

  // Fall out of the current block (if necessary).
  emitBranch(BB);

  if (IsFinished && BB->use_empty()) {
    BB->eraseFromParent();
    return;
  }

  // Place the block after the current block, if possible, or else at
  // the end of the function.
  if (CurBB && CurBB->getParent())
    CurFn->insert(std::next(CurBB->getIterator()), BB);
  else
    CurFn->insert(CurFn->end(), BB);
  Builder.SetInsertPoint(BB);
}

void OpenMPIRBuilder::emitIfClause(Value *Cond, BodyGenCallbackTy ThenGen,
                                   BodyGenCallbackTy ElseGen,
                                   InsertPointTy AllocaIP) {
  // If the condition constant folds and can be elided, try to avoid emitting
  // the condition and the dead arm of the if/else.
  if (auto *CI = dyn_cast<ConstantInt>(Cond)) {
    auto CondConstant = CI->getSExtValue();
    if (CondConstant)
      ThenGen(AllocaIP, Builder.saveIP());
    else
      ElseGen(AllocaIP, Builder.saveIP());
    return;
  }

  Function *CurFn = Builder.GetInsertBlock()->getParent();

  // Otherwise, the condition did not fold, or we couldn't elide it. Just
  // emit the conditional branch.
  BasicBlock *ThenBlock = BasicBlock::Create(M.getContext(), "omp_if.then");
  BasicBlock *ElseBlock = BasicBlock::Create(M.getContext(), "omp_if.else");
  BasicBlock *ContBlock = BasicBlock::Create(M.getContext(), "omp_if.end");
  Builder.CreateCondBr(Cond, ThenBlock, ElseBlock);
  // Emit the 'then' code.
  emitBlock(ThenBlock, CurFn);
  ThenGen(AllocaIP, Builder.saveIP());
  emitBranch(ContBlock);
  // Emit the 'else' code if present.
  // There is no need to emit line number for unconditional branch.
  emitBlock(ElseBlock, CurFn);
  ElseGen(AllocaIP, Builder.saveIP());
  // There is no need to emit line number for unconditional branch.
  emitBranch(ContBlock);
  // Emit the continuation block for code after the if.
  emitBlock(ContBlock, CurFn, /*IsFinished=*/true);
}

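// For a non-constant condition, emitIfClause produces the diamond:
//
//     br i1 %cond, label %omp_if.then, label %omp_if.else
//   omp_if.then:                       ; filled in by ThenGen
//     br label %omp_if.end
//   omp_if.else:                       ; filled in by ElseGen
//     br label %omp_if.end
//   omp_if.end:
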
bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
    const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
  assert(!(AO == AtomicOrdering::NotAtomic ||
           AO == llvm::AtomicOrdering::Unordered) &&
         "Unexpected Atomic Ordering.");

  bool Flush = false;
  llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;

  switch (AK) {
  case Read:
    if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
        AO == AtomicOrdering::SequentiallyConsistent) {
      FlushAO = AtomicOrdering::Acquire;
      Flush = true;
    }
    break;
  case Write:
  case Compare:
  case Update:
    if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
        AO == AtomicOrdering::SequentiallyConsistent) {
      FlushAO = AtomicOrdering::Release;
      Flush = true;
    }
    break;
  case Capture:
    switch (AO) {
    case AtomicOrdering::Acquire:
      FlushAO = AtomicOrdering::Acquire;
      Flush = true;
      break;
    case AtomicOrdering::Release:
      FlushAO = AtomicOrdering::Release;
      Flush = true;
      break;
    case AtomicOrdering::AcquireRelease:
    case AtomicOrdering::SequentiallyConsistent:
      FlushAO = AtomicOrdering::AcquireRelease;
      Flush = true;
      break;
    default:
      // Do nothing - leave silently.
      break;
    }
  }

  if (Flush) {
    // The flush runtime call does not accept a memory ordering yet. We still
    // resolve which ordering the flush would need (FlushAO) so it can be
    // passed along once the runtime supports it; for now we only issue the
    // flush call.
    // TODO: pass `FlushAO` after memory ordering support is added
    (void)FlushAO;
    emitFlush(Loc);
  }

  // For AO == AtomicOrdering::Monotonic and all other case combinations,
  // do nothing.
  return Flush;
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
                                  AtomicOpValue &X, AtomicOpValue &V,
                                  AtomicOrdering AO) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP Atomic expects a pointer to target memory");
  Type *XElemTy = X.ElemTy;
  assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
          XElemTy->isPointerTy()) &&
         "OMP atomic read expected a scalar type");

  Value *XRead = nullptr;

  if (XElemTy->isIntegerTy()) {
    LoadInst *XLD =
        Builder.CreateLoad(XElemTy, X.Var, X.IsVolatile, "omp.atomic.read");
    XLD->setAtomic(AO);
    XRead = cast<Value>(XLD);
  } else {
    // We need to perform the atomic operation as an integer.
    IntegerType *IntCastTy =
        IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
    LoadInst *XLoad =
        Builder.CreateLoad(IntCastTy, X.Var, X.IsVolatile, "omp.atomic.load");
    XLoad->setAtomic(AO);
    if (XElemTy->isFloatingPointTy()) {
      XRead = Builder.CreateBitCast(XLoad, XElemTy, "atomic.flt.cast");
    } else {
      XRead = Builder.CreateIntToPtr(XLoad, XElemTy, "atomic.ptr.cast");
    }
  }
  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
  Builder.CreateStore(XRead, V.Var, V.IsVolatile);
  return Builder.saveIP();
}

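// For example, for an i32 X with AO == AtomicOrdering::Monotonic the emitted
// sequence is simply:
//   %omp.atomic.read = load atomic i32, ptr %x monotonic, align 4
//   store i32 %omp.atomic.read, ptr %v
// Float and pointer element types are read through an equally sized integer
// load and converted back (bitcast/inttoptr) before the store into V.
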
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
                                   AtomicOpValue &X, Value *Expr,
                                   AtomicOrdering AO) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP Atomic expects a pointer to target memory");
  Type *XElemTy = X.ElemTy;
  assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
          XElemTy->isPointerTy()) &&
         "OMP atomic write expected a scalar type");

  if (XElemTy->isIntegerTy()) {
    StoreInst *XSt = Builder.CreateStore(Expr, X.Var, X.IsVolatile);
    XSt->setAtomic(AO);
  } else {
    // We need to bitcast and perform the atomic operation as an integer.
    IntegerType *IntCastTy =
        IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
    Value *ExprCast =
        Builder.CreateBitCast(Expr, IntCastTy, "atomic.src.int.cast");
    StoreInst *XSt = Builder.CreateStore(ExprCast, X.Var, X.IsVolatile);
    XSt->setAtomic(AO);
  }

  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Write);
  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicUpdate(
    const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
    Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
    AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr) {
  assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
  if (!updateToLocation(Loc))
    return Loc.IP;

  LLVM_DEBUG({
    Type *XTy = X.Var->getType();
    assert(XTy->isPointerTy() &&
           "OMP Atomic expects a pointer to target memory");
    Type *XElemTy = X.ElemTy;
    assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
            XElemTy->isPointerTy()) &&
           "OMP atomic update expected a scalar type");
    assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
           (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
           "OpenMP atomic does not support LT or GT operations");
  });

  emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, RMWOp, UpdateOp,
                   X.IsVolatile, IsXBinopExpr);
  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Update);
  return Builder.saveIP();
}

// FIXME: Duplicating AtomicExpand
Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
                                               AtomicRMWInst::BinOp RMWOp) {
  switch (RMWOp) {
  case AtomicRMWInst::Add:
    return Builder.CreateAdd(Src1, Src2);
  case AtomicRMWInst::Sub:
    return Builder.CreateSub(Src1, Src2);
  case AtomicRMWInst::And:
    return Builder.CreateAnd(Src1, Src2);
  case AtomicRMWInst::Nand:
    return Builder.CreateNeg(Builder.CreateAnd(Src1, Src2));
  case AtomicRMWInst::Or:
    return Builder.CreateOr(Src1, Src2);
  case AtomicRMWInst::Xor:
    return Builder.CreateXor(Src1, Src2);
  case AtomicRMWInst::Xchg:
  case AtomicRMWInst::FAdd:
  case AtomicRMWInst::FSub:
  case AtomicRMWInst::BAD_BINOP:
  case AtomicRMWInst::Max:
  case AtomicRMWInst::Min:
  case AtomicRMWInst::UMax:
  case AtomicRMWInst::UMin:
  case AtomicRMWInst::FMax:
  case AtomicRMWInst::FMin:
  case AtomicRMWInst::UIncWrap:
  case AtomicRMWInst::UDecWrap:
    llvm_unreachable("Unsupported atomic update operation");
  }
  llvm_unreachable("Unsupported atomic update operation");
}

std::pair<Value *, Value *> OpenMPIRBuilder::emitAtomicUpdate(
    InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
    AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
    AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr) {
  // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
  // or a complex datatype.
  bool emitRMWOp = false;
  switch (RMWOp) {
  case AtomicRMWInst::Add:
  case AtomicRMWInst::And:
  case AtomicRMWInst::Nand:
  case AtomicRMWInst::Or:
  case AtomicRMWInst::Xor:
  case AtomicRMWInst::Xchg:
    emitRMWOp = XElemTy;
    break;
  case AtomicRMWInst::Sub:
    emitRMWOp = (IsXBinopExpr && XElemTy);
    break;
  default:
    emitRMWOp = false;
  }
  emitRMWOp &= XElemTy->isIntegerTy();

  std::pair<Value *, Value *> Res;
  if (emitRMWOp) {
    Res.first = Builder.CreateAtomicRMW(RMWOp, X, Expr, llvm::MaybeAlign(), AO);
    // Res.second is not needed except in case of postfix captures. Generate it
    // anyway for consistency with the else branch; any DCE pass will remove
    // it. AtomicRMWInst::Xchg does not have a corresponding instruction.
    if (RMWOp == AtomicRMWInst::Xchg)
      Res.second = Res.first;
    else
      Res.second = emitRMWOpAsInstruction(Res.first, Expr, RMWOp);
  } else {
    IntegerType *IntCastTy =
        IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
    LoadInst *OldVal =
        Builder.CreateLoad(IntCastTy, X, X->getName() + ".atomic.load");
    OldVal->setAtomic(AO);
    // CurBB
    // |     /---\
    // ContBB    |
    // |     \---/
    // ExitBB
    BasicBlock *CurBB = Builder.GetInsertBlock();
    Instruction *CurBBTI = CurBB->getTerminator();
    CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
    BasicBlock *ExitBB =
        CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
    BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
                                                X->getName() + ".atomic.cont");
    ContBB->getTerminator()->eraseFromParent();
    Builder.restoreIP(AllocaIP);
    AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
    NewAtomicAddr->setName(X->getName() + "x.new.val");
    Builder.SetInsertPoint(ContBB);
    llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
    PHI->addIncoming(OldVal, CurBB);
    bool IsIntTy = XElemTy->isIntegerTy();
    Value *OldExprVal = PHI;
    if (!IsIntTy) {
      if (XElemTy->isFloatingPointTy()) {
        OldExprVal = Builder.CreateBitCast(PHI, XElemTy,
                                           X->getName() + ".atomic.fltCast");
      } else {
        OldExprVal = Builder.CreateIntToPtr(PHI, XElemTy,
                                            X->getName() + ".atomic.ptrCast");
      }
    }

    Value *Upd = UpdateOp(OldExprVal, Builder);
    Builder.CreateStore(Upd, NewAtomicAddr);
    LoadInst *DesiredVal = Builder.CreateLoad(IntCastTy, NewAtomicAddr);
    AtomicOrdering Failure =
        llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
    AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
        X, PHI, DesiredVal, llvm::MaybeAlign(), AO, Failure);
    Result->setVolatile(VolatileX);
    Value *PreviousVal = Builder.CreateExtractValue(Result, /*Idxs=*/0);
    Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
    PHI->addIncoming(PreviousVal, Builder.GetInsertBlock());
    Builder.CreateCondBr(SuccessFailureVal, ExitBB, ContBB);

    Res.first = OldExprVal;
    Res.second = Upd;

    // Set insertion point in the exit block.
    if (UnreachableInst *ExitTI =
            dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
      CurBBTI->eraseFromParent();
      Builder.SetInsertPoint(ExitBB);
    } else {
      Builder.SetInsertPoint(ExitTI);
    }
  }

  return Res;
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCapture(
    const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
    AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
    AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
    bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  LLVM_DEBUG({
    Type *XTy = X.Var->getType();
    assert(XTy->isPointerTy() &&
           "OMP Atomic expects a pointer to target memory");
    Type *XElemTy = X.ElemTy;
    assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
            XElemTy->isPointerTy()) &&
           "OMP atomic capture expected a scalar type");
    assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
           "OpenMP atomic does not support LT or GT operations");
  });

  // If UpdateExpr is 'x' updated with some `expr` not based on 'x',
  // 'x' is simply atomically rewritten with 'expr'.
  AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
  std::pair<Value *, Value *> Result =
      emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, AtomicOp, UpdateOp,
                       X.IsVolatile, IsXBinopExpr);

  Value *CapturedVal = (IsPostfixUpdate ? Result.first : Result.second);
  Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);

  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);
  return Builder.saveIP();
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
    const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
    AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
    omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
    bool IsFailOnly) {

  AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
  return createAtomicCompare(Loc, X, V, R, E, D, AO, Op, IsXBinopExpr,
                             IsPostfixUpdate, IsFailOnly, Failure);
}

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
    const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
    AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
    omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
    bool IsFailOnly, AtomicOrdering Failure) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP atomic expects a pointer to target memory");
  // Compare capture.
  if (V.Var) {
    assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
    assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
  }

  bool IsInteger = E->getType()->isIntegerTy();

  if (Op == OMPAtomicCompareOp::EQ) {
    AtomicCmpXchgInst *Result = nullptr;
    if (!IsInteger) {
      IntegerType *IntCastTy =
          IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
      Value *EBCast = Builder.CreateBitCast(E, IntCastTy);
      Value *DBCast = Builder.CreateBitCast(D, IntCastTy);
      Result = Builder.CreateAtomicCmpXchg(X.Var, EBCast, DBCast, MaybeAlign(),
                                           AO, Failure);
    } else {
      Result =
          Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
    }

    if (V.Var) {
      Value *OldValue = Builder.CreateExtractValue(Result, /*Idxs=*/0);
      if (!IsInteger)
        OldValue = Builder.CreateBitCast(OldValue, X.ElemTy);
      assert(OldValue->getType() == V.ElemTy &&
             "OldValue and V must be of same type");
      if (IsPostfixUpdate) {
        Builder.CreateStore(OldValue, V.Var, V.IsVolatile);
      } else {
        Value *SuccessOrFail = Builder.CreateExtractValue(Result, /*Idxs=*/1);
        if (IsFailOnly) {
          // CurBB----
          //   |     |
          //   v     |
          // ContBB  |
          //   |     |
          //   v     |
          // ExitBB <-
          //
          // where ContBB only contains the store of old value to 'v'.
          BasicBlock *CurBB = Builder.GetInsertBlock();
          Instruction *CurBBTI = CurBB->getTerminator();
          CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
          BasicBlock *ExitBB = CurBB->splitBasicBlock(
              CurBBTI, X.Var->getName() + ".atomic.exit");
          BasicBlock *ContBB = CurBB->splitBasicBlock(
              CurBB->getTerminator(), X.Var->getName() + ".atomic.cont");
          ContBB->getTerminator()->eraseFromParent();
          CurBB->getTerminator()->eraseFromParent();

          Builder.CreateCondBr(SuccessOrFail, ExitBB, ContBB);

          Builder.SetInsertPoint(ContBB);
          Builder.CreateStore(OldValue, V.Var);
          Builder.CreateBr(ExitBB);

          if (UnreachableInst *ExitTI =
                  dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
            CurBBTI->eraseFromParent();
            Builder.SetInsertPoint(ExitBB);
          } else {
            Builder.SetInsertPoint(ExitTI);
          }
        } else {
          Value *CapturedValue =
              Builder.CreateSelect(SuccessOrFail, E, OldValue);
          Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
        }
      }
    }
    // The comparison result has to be stored.
    if (R.Var) {
      assert(R.Var->getType()->isPointerTy() &&
             "r.var must be of pointer type");
      assert(R.ElemTy->isIntegerTy() && "r must be of integral type");

      Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
      Value *ResultCast = R.IsSigned
                              ? Builder.CreateSExt(SuccessFailureVal, R.ElemTy)
                              : Builder.CreateZExt(SuccessFailureVal, R.ElemTy);
      Builder.CreateStore(ResultCast, R.Var, R.IsVolatile);
    }
  } else {
    assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
           "Op should be either max or min at this point");
    assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");

    // Reverse the ordop as the OpenMP forms are different from LLVM forms.
    // Let's take max as example.
    // OpenMP form:
    //   x = x > expr ? expr : x;
    // LLVM form:
    //   *ptr = *ptr > val ? *ptr : val;
    // We need to transform to LLVM form.
    //   x = x <= expr ? x : expr;
    AtomicRMWInst::BinOp NewOp;
    if (IsXBinopExpr) {
      if (IsInteger) {
        if (X.IsSigned)
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
                                                : AtomicRMWInst::Max;
        else
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
                                                : AtomicRMWInst::UMax;
      } else {
        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
                                              : AtomicRMWInst::FMax;
      }
    } else {
      if (IsInteger) {
        if (X.IsSigned)
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
                                                : AtomicRMWInst::Min;
        else
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
                                                : AtomicRMWInst::UMin;
      } else {
        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
                                              : AtomicRMWInst::FMin;
      }
    }

    AtomicRMWInst *OldValue =
        Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
    if (V.Var) {
      Value *CapturedValue = nullptr;
      if (IsPostfixUpdate) {
        CapturedValue = OldValue;
      } else {
        CmpInst::Predicate Pred;
        switch (NewOp) {
        case AtomicRMWInst::Max:
          Pred = CmpInst::ICMP_SGT;
          break;
        case AtomicRMWInst::UMax:
          Pred = CmpInst::ICMP_UGT;
          break;
        case AtomicRMWInst::FMax:
          Pred = CmpInst::FCMP_OGT;
          break;
        case AtomicRMWInst::Min:
          Pred = CmpInst::ICMP_SLT;
          break;
        case AtomicRMWInst::UMin:
          Pred = CmpInst::ICMP_ULT;
          break;
        case AtomicRMWInst::FMin:
          Pred = CmpInst::FCMP_OLT;
          break;
        default:
          llvm_unreachable("unexpected comparison op");
        }
        Value *NonAtomicCmp = Builder.CreateCmp(Pred, OldValue, E);
        CapturedValue = Builder.CreateSelect(NonAtomicCmp, E, OldValue);
      }
      Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
    }
  }

  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);

  return Builder.saveIP();
}

8175
OpenMPIRBuilder::InsertPointTy
8176
OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
8177
BodyGenCallbackTy BodyGenCB, Value *NumTeamsLower,
8178
Value *NumTeamsUpper, Value *ThreadLimit,
8179
Value *IfExpr) {
8180
if (!updateToLocation(Loc))
8181
return InsertPointTy();
8182
8183
uint32_t SrcLocStrSize;
8184
Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
8185
Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
8186
Function *CurrentFunction = Builder.GetInsertBlock()->getParent();
8187
8188
// Outer allocation basicblock is the entry block of the current function.
8189
BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
8190
if (&OuterAllocaBB == Builder.GetInsertBlock()) {
8191
BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
8192
Builder.SetInsertPoint(BodyBB, BodyBB->begin());
8193
}
8194
8195
  // The current basic block is split into four basic blocks. After outlining,
  // they will be mapped as follows:
  // ```
  // def current_fn() {
  //   current_basic_block:
  //     br label %teams.exit
  //   teams.exit:
  //     ; instructions after teams
  // }
  //
  // def outlined_fn() {
  //   teams.alloca:
  //     br label %teams.body
  //   teams.body:
  //     ; instructions within teams body
  // }
  // ```
  BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, "teams.exit");
  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.body");
  BasicBlock *AllocaBB =
      splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");

  bool SubClausesPresent =
      (NumTeamsLower || NumTeamsUpper || ThreadLimit || IfExpr);
  // Push num_teams
  if (!Config.isTargetDevice() && SubClausesPresent) {
    assert((NumTeamsLower == nullptr || NumTeamsUpper != nullptr) &&
           "if lowerbound is non-null, then upperbound must also be non-null "
           "for bounds on num_teams");

    if (NumTeamsUpper == nullptr)
      NumTeamsUpper = Builder.getInt32(0);

    if (NumTeamsLower == nullptr)
      NumTeamsLower = NumTeamsUpper;

    if (IfExpr) {
      assert(IfExpr->getType()->isIntegerTy() &&
             "argument to if clause must be an integer value");

      // upper = ifexpr ? upper : 1
      if (IfExpr->getType() != Int1)
        IfExpr = Builder.CreateICmpNE(IfExpr,
                                      ConstantInt::get(IfExpr->getType(), 0));
      NumTeamsUpper = Builder.CreateSelect(
          IfExpr, NumTeamsUpper, Builder.getInt32(1), "numTeamsUpper");

      // lower = ifexpr ? lower : 1
      NumTeamsLower = Builder.CreateSelect(
          IfExpr, NumTeamsLower, Builder.getInt32(1), "numTeamsLower");
    }

    if (ThreadLimit == nullptr)
      ThreadLimit = Builder.getInt32(0);

    Value *ThreadNum = getOrCreateThreadID(Ident);
    Builder.CreateCall(
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_teams_51),
        {Ident, ThreadNum, NumTeamsLower, NumTeamsUpper, ThreadLimit});
  }
  // Generate the body of teams.
  InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
  InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
  BodyGenCB(AllocaIP, CodeGenIP);

  OutlineInfo OI;
  OI.EntryBB = AllocaBB;
  OI.ExitBB = ExitBB;
  OI.OuterAllocaBB = &OuterAllocaBB;

  // Insert fake values for global tid and bound tid.
  SmallVector<Instruction *, 8> ToBeDeleted;
  InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
      Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
      Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));

  auto HostPostOutlineCB = [this, Ident,
                            ToBeDeleted](Function &OutlinedFn) mutable {
    // The stale call instruction will be replaced with a new call to the
    // teams runtime function, passing the outlined function.

    assert(OutlinedFn.getNumUses() == 1 &&
           "there must be a single user for the outlined function");
    CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
    ToBeDeleted.push_back(StaleCI);

    assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
           "Outlined function must have two or three arguments only");

    bool HasShared = OutlinedFn.arg_size() == 3;

    OutlinedFn.getArg(0)->setName("global.tid.ptr");
    OutlinedFn.getArg(1)->setName("bound.tid.ptr");
    if (HasShared)
      OutlinedFn.getArg(2)->setName("data");

    // Call to the runtime function for teams in the current function.
    assert(StaleCI && "Error while outlining - no CallInst user found for the "
                      "outlined function.");
    Builder.SetInsertPoint(StaleCI);
    SmallVector<Value *> Args = {
        Ident, Builder.getInt32(StaleCI->arg_size() - 2), &OutlinedFn};
    if (HasShared)
      Args.push_back(StaleCI->getArgOperand(2));
    Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                           omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
                       Args);

    llvm::for_each(llvm::reverse(ToBeDeleted),
                   [](Instruction *I) { I->eraseFromParent(); });
  };

  if (!Config.isTargetDevice())
    OI.PostOutlineCB = HostPostOutlineCB;

  addOutlineInfo(std::move(OI));

  Builder.SetInsertPoint(ExitBB, ExitBB->begin());

  return Builder.saveIP();
}
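
// Illustrative host-side result (a sketch, not verbatim IR): once the region
// between the insert point and teams.exit has been outlined into
// @outlined_fn, the post-outline callback above reduces the host code to a
// runtime call roughly of the form
//   call void @__kmpc_fork_teams(ptr @ident, i32 <n_shared>, ptr @outlined_fn,
//                                ptr <shared_data>)
// where <n_shared> is StaleCI->arg_size() - 2, i.e. the number of captured
// arguments beyond the global/bound thread-id pointers.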

GlobalVariable *
OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
                                       std::string VarName) {
  llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
      llvm::ArrayType::get(llvm::PointerType::getUnqual(M.getContext()),
                           Names.size()),
      Names);
  auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
      M, MapNamesArrayInit->getType(),
      /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
      VarName);
  return MapNamesArrayGlobal;
}
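
// Illustratively (names are hypothetical), mapping two names may produce:
//   @.offload_mapnames = private constant [2 x ptr] [ptr @.str.0, ptr @.str.1]
// The global's actual name is the VarName argument and its elements are the
// string constants passed in via Names.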

// Create all simple and struct types exposed by the runtime and remember
// the llvm::PointerTypes of them for easy access later.
void OpenMPIRBuilder::initializeTypes(Module &M) {
  LLVMContext &Ctx = M.getContext();
  StructType *T;
#define OMP_TYPE(VarName, InitValue) VarName = InitValue;
#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize) \
  VarName##Ty = ArrayType::get(ElemTy, ArraySize); \
  VarName##PtrTy = PointerType::getUnqual(VarName##Ty);
#define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...) \
  VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg); \
  VarName##Ptr = PointerType::getUnqual(VarName);
#define OMP_STRUCT_TYPE(VarName, StructName, Packed, ...) \
  T = StructType::getTypeByName(Ctx, StructName); \
  if (!T) \
    T = StructType::create(Ctx, {__VA_ARGS__}, StructName, Packed); \
  VarName = T; \
  VarName##Ptr = PointerType::getUnqual(T);
#include "llvm/Frontend/OpenMP/OMPKinds.def"
}
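
// As an illustration (assuming an OMPKinds.def entry along the lines of the
// ident_t struct), the OMP_STRUCT_TYPE expansion above behaves roughly like:
//   T = StructType::getTypeByName(Ctx, "struct.ident_t");
//   if (!T)
//     T = StructType::create(Ctx, {Int32, Int32, Int32, Int32, Int8Ptr},
//                            "struct.ident_t", /*Packed=*/false);
//   Ident = T;
//   IdentPtr = PointerType::getUnqual(T);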

void OpenMPIRBuilder::OutlineInfo::collectBlocks(
    SmallPtrSetImpl<BasicBlock *> &BlockSet,
    SmallVectorImpl<BasicBlock *> &BlockVector) {
  SmallVector<BasicBlock *, 32> Worklist;
  BlockSet.insert(EntryBB);
  BlockSet.insert(ExitBB);

  Worklist.push_back(EntryBB);
  while (!Worklist.empty()) {
    BasicBlock *BB = Worklist.pop_back_val();
    BlockVector.push_back(BB);
    for (BasicBlock *SuccBB : successors(BB))
      if (BlockSet.insert(SuccBB).second)
        Worklist.push_back(SuccBB);
  }
}
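
// Note: this is a worklist walk over successors starting at EntryBB. ExitBB
// is seeded into BlockSet up front so the traversal stops at the region
// boundary; consequently ExitBB itself never lands in BlockVector, only the
// blocks strictly inside the outlined region do.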

void OpenMPIRBuilder::createOffloadEntry(Constant *ID, Constant *Addr,
                                         uint64_t Size, int32_t Flags,
                                         GlobalValue::LinkageTypes,
                                         StringRef Name) {
  if (!Config.isGPU()) {
    llvm::offloading::emitOffloadingEntry(
        M, ID, Name.empty() ? Addr->getName() : Name, Size, Flags, /*Data=*/0,
        "omp_offloading_entries");
    return;
  }
  // TODO: Add support for global variables on the device after declare target
  // support.
  Function *Fn = dyn_cast<Function>(Addr);
  if (!Fn)
    return;

  Module &M = *(Fn->getParent());
  LLVMContext &Ctx = M.getContext();

  // Get "nvvm.annotations" metadata node.
  NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");

  Metadata *MDVals[] = {
      ConstantAsMetadata::get(Fn), MDString::get(Ctx, "kernel"),
      ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Ctx), 1))};
  // Append metadata to nvvm.annotations.
  MD->addOperand(MDNode::get(Ctx, MDVals));

  // Add a function attribute for the kernel.
  Fn->addFnAttr(Attribute::get(Ctx, "kernel"));
  if (T.isAMDGCN())
    Fn->addFnAttr("uniform-work-group-size", "true");
  Fn->addFnAttr(Attribute::MustProgress);
}
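
// On the GPU path above, the emitted annotation looks like (illustrative):
//   !nvvm.annotations = !{!0}
//   !0 = !{ptr @kernel_fn, !"kernel", i32 1}
// which is how the NVPTX toolchain recognizes @kernel_fn as a kernel entry.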

// We only generate metadata for functions that contain target regions.
void OpenMPIRBuilder::createOffloadEntriesAndInfoMetadata(
    EmitMetadataErrorReportFunctionTy &ErrorFn) {

  // If there are no entries, we don't need to do anything.
  if (OffloadInfoManager.empty())
    return;

  LLVMContext &C = M.getContext();
  SmallVector<std::pair<const OffloadEntriesInfoManager::OffloadEntryInfo *,
                        TargetRegionEntryInfo>,
              16>
      OrderedEntries(OffloadInfoManager.size());

  // Auxiliary methods to create metadata values and strings.
  auto &&GetMDInt = [this](unsigned V) {
    return ConstantAsMetadata::get(ConstantInt::get(Builder.getInt32Ty(), V));
  };

  auto &&GetMDString = [&C](StringRef V) { return MDString::get(C, V); };

  // Create the offloading info metadata node.
  NamedMDNode *MD = M.getOrInsertNamedMetadata("omp_offload.info");
  auto &&TargetRegionMetadataEmitter =
      [&C, MD, &OrderedEntries, &GetMDInt, &GetMDString](
          const TargetRegionEntryInfo &EntryInfo,
          const OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion &E) {
        // Generate metadata for target regions. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (0).
        // - Entry 1 -> Device ID of the file where the entry was identified.
        // - Entry 2 -> File ID of the file where the entry was identified.
        // - Entry 3 -> Mangled name of the function where the entry was
        //   identified.
        // - Entry 4 -> Line in the file where the entry was identified.
        // - Entry 5 -> Count of regions at this DeviceID/FileID/Line.
        // - Entry 6 -> Order the entry was created.
        // The first element of the metadata node is the kind.
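        // For instance (illustrative), one such node might read:
        //   !{i32 0, i32 <DeviceID>, i32 <FileID>, !"parent_fn", i32 <Line>,
        //     i32 <Count>, i32 <Order>}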
        Metadata *Ops[] = {
            GetMDInt(E.getKind()), GetMDInt(EntryInfo.DeviceID),
            GetMDInt(EntryInfo.FileID), GetMDString(EntryInfo.ParentName),
            GetMDInt(EntryInfo.Line), GetMDInt(EntryInfo.Count),
            GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        OrderedEntries[E.getOrder()] = std::make_pair(&E, EntryInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(MDNode::get(C, Ops));
      };

  OffloadInfoManager.actOnTargetRegionEntriesInfo(TargetRegionMetadataEmitter);

  // Create a function that emits metadata for each device global variable
  // entry.
  auto &&DeviceGlobalVarMetadataEmitter =
      [&C, &OrderedEntries, &GetMDInt, &GetMDString, MD](
          StringRef MangledName,
          const OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar &E) {
        // Generate metadata for global variables. Each entry of this metadata
        // contains:
        // - Entry 0 -> Kind of this type of metadata (1).
        // - Entry 1 -> Mangled name of the variable.
        // - Entry 2 -> Declare target kind.
        // - Entry 3 -> Order the entry was created.
        // The first element of the metadata node is the kind.
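        // For instance (illustrative), one such node might read:
        //   !{i32 1, !"mangled_var_name", i32 <DeclareTargetKind>, i32 <Order>}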
        Metadata *Ops[] = {GetMDInt(E.getKind()), GetMDString(MangledName),
                           GetMDInt(E.getFlags()), GetMDInt(E.getOrder())};

        // Save this entry in the right position of the ordered entries array.
        TargetRegionEntryInfo varInfo(MangledName, 0, 0, 0);
        OrderedEntries[E.getOrder()] = std::make_pair(&E, varInfo);

        // Add metadata to the named metadata node.
        MD->addOperand(MDNode::get(C, Ops));
      };

  OffloadInfoManager.actOnDeviceGlobalVarEntriesInfo(
      DeviceGlobalVarMetadataEmitter);

  for (const auto &E : OrderedEntries) {
    assert(E.first && "All ordered entries must exist!");
    if (const auto *CE =
            dyn_cast<OffloadEntriesInfoManager::OffloadEntryInfoTargetRegion>(
                E.first)) {
      if (!CE->getID() || !CE->getAddress()) {
        // Do not blame the entry if the parent function is not emitted.
        TargetRegionEntryInfo EntryInfo = E.second;
        StringRef FnName = EntryInfo.ParentName;
        if (!M.getNamedValue(FnName))
          continue;
        ErrorFn(EMIT_MD_TARGET_REGION_ERROR, EntryInfo);
        continue;
      }
      createOffloadEntry(CE->getID(), CE->getAddress(),
                         /*Size=*/0, CE->getFlags(),
                         GlobalValue::WeakAnyLinkage);
    } else if (const auto *CE = dyn_cast<
                   OffloadEntriesInfoManager::OffloadEntryInfoDeviceGlobalVar>(
                   E.first)) {
      OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags =
          static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
              CE->getFlags());
      switch (Flags) {
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter:
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo:
        if (Config.isTargetDevice() && Config.hasRequiresUnifiedSharedMemory())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_DECLARE_TARGET_ERROR, E.second);
          continue;
        }
        // The variable has no definition - no need to add the entry.
        if (CE->getVarSize() == 0)
          continue;
        break;
      case OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink:
        assert(((Config.isTargetDevice() && !CE->getAddress()) ||
                (!Config.isTargetDevice() && CE->getAddress())) &&
               "Declare target link address is set.");
        if (Config.isTargetDevice())
          continue;
        if (!CE->getAddress()) {
          ErrorFn(EMIT_MD_GLOBAL_VAR_LINK_ERROR, TargetRegionEntryInfo());
          continue;
        }
        break;
      default:
        break;
      }

      // Hidden or internal symbols on the device are not externally visible.
      // We should not attempt to register them by creating an offloading
      // entry. Indirect variables are handled separately on the device.
      if (auto *GV = dyn_cast<GlobalValue>(CE->getAddress()))
        if ((GV->hasLocalLinkage() || GV->hasHiddenVisibility()) &&
            Flags != OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
          continue;

      // Indirect globals need to use a special name that doesn't match the
      // name of the associated host global.
      if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
        createOffloadEntry(CE->getAddress(), CE->getAddress(), CE->getVarSize(),
                           Flags, CE->getLinkage(), CE->getVarName());
      else
        createOffloadEntry(CE->getAddress(), CE->getAddress(), CE->getVarSize(),
                           Flags, CE->getLinkage());

    } else {
      llvm_unreachable("Unsupported entry kind.");
    }
  }

  // Emit requires directive globals to a special entry so the runtime can
  // register them when the device image is loaded.
  // TODO: This reduces the offloading entries to a 32-bit integer. Offloading
  // entries should be redesigned to better suit this use-case.
  if (Config.hasRequiresFlags() && !Config.isTargetDevice())
    offloading::emitOffloadingEntry(
        M, Constant::getNullValue(PointerType::getUnqual(M.getContext())),
        /*Name=*/"",
        /*Size=*/0, OffloadEntriesInfoManager::OMPTargetGlobalRegisterRequires,
        Config.getRequiresFlags(), "omp_offloading_entries");
}

void TargetRegionEntryInfo::getTargetRegionEntryFnName(
    SmallVectorImpl<char> &Name, StringRef ParentName, unsigned DeviceID,
    unsigned FileID, unsigned Line, unsigned Count) {
  raw_svector_ostream OS(Name);
  OS << "__omp_offloading" << llvm::format("_%x", DeviceID)
     << llvm::format("_%x_", FileID) << ParentName << "_l" << Line;
  if (Count)
    OS << "_" << Count;
}
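
// For example (illustrative values), a region in function "foo" at line 42 of
// a file with DeviceID 0x1234 and FileID 0xabcd yields:
//   __omp_offloading_1234_abcd_foo_l42      (Count == 0)
//   __omp_offloading_1234_abcd_foo_l42_3    (Count == 3)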

void OffloadEntriesInfoManager::getTargetRegionEntryFnName(
    SmallVectorImpl<char> &Name, const TargetRegionEntryInfo &EntryInfo) {
  unsigned NewCount = getTargetRegionEntryInfoCount(EntryInfo);
  TargetRegionEntryInfo::getTargetRegionEntryFnName(
      Name, EntryInfo.ParentName, EntryInfo.DeviceID, EntryInfo.FileID,
      EntryInfo.Line, NewCount);
}

TargetRegionEntryInfo
OpenMPIRBuilder::getTargetEntryUniqueInfo(FileIdentifierInfoCallbackTy CallBack,
                                          StringRef ParentName) {
  sys::fs::UniqueID ID;
  auto FileIDInfo = CallBack();
  if (auto EC = sys::fs::getUniqueID(std::get<0>(FileIDInfo), ID)) {
    report_fatal_error(("Unable to get unique ID for file, during "
                        "getTargetEntryUniqueInfo, error message: " +
                        EC.message())
                           .c_str());
  }

  return TargetRegionEntryInfo(ParentName, ID.getDevice(), ID.getFile(),
                               std::get<1>(FileIDInfo));
}

unsigned OpenMPIRBuilder::getFlagMemberOffset() {
  unsigned Offset = 0;
  for (uint64_t Remain =
           static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
               omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF);
       !(Remain & 1); Remain = Remain >> 1)
    Offset++;
  return Offset;
}
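
// getFlagMemberOffset() counts the trailing zero bits of OMP_MAP_MEMBER_OF,
// i.e. the bit index where the MEMBER_OF field starts. For instance, if
// OMP_MAP_MEMBER_OF were 0xFFFF000000000000, the loop above would return 48.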

omp::OpenMPOffloadMappingFlags
OpenMPIRBuilder::getMemberOfFlag(unsigned Position) {
  // Shift by getFlagMemberOffset() bits into the MEMBER_OF bit-field.
  return static_cast<omp::OpenMPOffloadMappingFlags>(((uint64_t)Position + 1)
                                                     << getFlagMemberOffset());
}
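
// Positions are zero-based while the MEMBER_OF encoding is one-based, so
// e.g. Position 0 encodes MEMBER_OF = 1 (the first component of a map
// clause); with a 48-bit offset that would be the raw flag value 1ULL << 48.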

void OpenMPIRBuilder::setCorrectMemberOfFlag(
    omp::OpenMPOffloadMappingFlags &Flags,
    omp::OpenMPOffloadMappingFlags MemberOfFlag) {
  // If the entry is PTR_AND_OBJ but has not been marked with the special
  // placeholder value 0xFFFF in the MEMBER_OF field, then it should not be
  // marked as MEMBER_OF.
  if (static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
          Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_PTR_AND_OBJ) &&
      static_cast<std::underlying_type_t<omp::OpenMPOffloadMappingFlags>>(
          (Flags & omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF) !=
          omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF))
    return;

  // Reset the placeholder value to prepare the flag for the assignment of the
  // proper MEMBER_OF value.
  Flags &= ~omp::OpenMPOffloadMappingFlags::OMP_MAP_MEMBER_OF;
  Flags |= MemberOfFlag;
}
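
// In other words (summarizing the logic above): a PTR_AND_OBJ entry whose
// MEMBER_OF field still holds the all-ones placeholder gets the placeholder
// replaced by MemberOfFlag, while a PTR_AND_OBJ entry already carrying a
// concrete MEMBER_OF value is left untouched; non-PTR_AND_OBJ entries always
// receive MemberOfFlag.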

Constant *OpenMPIRBuilder::getAddrOfDeclareTargetVar(
    OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
    OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
    bool IsDeclaration, bool IsExternallyVisible,
    TargetRegionEntryInfo EntryInfo, StringRef MangledName,
    std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
    std::vector<Triple> TargetTriple, Type *LlvmPtrTy,
    std::function<Constant *()> GlobalInitializer,
    std::function<GlobalValue::LinkageTypes()> VariableLinkage) {
  // TODO: convert this to utilise the IRBuilder Config rather than
  // a passed down argument.
  if (OpenMPSIMD)
    return nullptr;

  if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink ||
      ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
        CaptureClause ==
            OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
       Config.hasRequiresUnifiedSharedMemory())) {
    SmallString<64> PtrName;
    {
      raw_svector_ostream OS(PtrName);
      OS << MangledName;
      if (!IsExternallyVisible)
        OS << format("_%x", EntryInfo.FileID);
      OS << "_decl_tgt_ref_ptr";
    }

    Value *Ptr = M.getNamedValue(PtrName);

    if (!Ptr) {
      GlobalValue *GlobalValue = M.getNamedValue(MangledName);
      Ptr = getOrCreateInternalVariable(LlvmPtrTy, PtrName);

      auto *GV = cast<GlobalVariable>(Ptr);
      GV->setLinkage(GlobalValue::WeakAnyLinkage);

      if (!Config.isTargetDevice()) {
        if (GlobalInitializer)
          GV->setInitializer(GlobalInitializer());
        else
          GV->setInitializer(GlobalValue);
      }

      registerTargetGlobalVariable(
          CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
          EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
          GlobalInitializer, VariableLinkage, LlvmPtrTy, cast<Constant>(Ptr));
    }

    return cast<Constant>(Ptr);
  }

  return nullptr;
}
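
// The generated reference pointer is named "<MangledName>_decl_tgt_ref_ptr",
// with "_%x" of the FileID spliced in before the suffix for symbols that are
// not externally visible, e.g. (illustratively) "myvar_1a2b_decl_tgt_ref_ptr".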

void OpenMPIRBuilder::registerTargetGlobalVariable(
    OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind CaptureClause,
    OffloadEntriesInfoManager::OMPTargetDeviceClauseKind DeviceClause,
    bool IsDeclaration, bool IsExternallyVisible,
    TargetRegionEntryInfo EntryInfo, StringRef MangledName,
    std::vector<GlobalVariable *> &GeneratedRefs, bool OpenMPSIMD,
    std::vector<Triple> TargetTriple,
    std::function<Constant *()> GlobalInitializer,
    std::function<GlobalValue::LinkageTypes()> VariableLinkage, Type *LlvmPtrTy,
    Constant *Addr) {
  if (DeviceClause != OffloadEntriesInfoManager::OMPTargetDeviceClauseAny ||
      (TargetTriple.empty() && !Config.isTargetDevice()))
    return;

  OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind Flags;
  StringRef VarName;
  int64_t VarSize;
  GlobalValue::LinkageTypes Linkage;

  if ((CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo ||
       CaptureClause ==
           OffloadEntriesInfoManager::OMPTargetGlobalVarEntryEnter) &&
      !Config.hasRequiresUnifiedSharedMemory()) {
    Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;
    VarName = MangledName;
    GlobalValue *LlvmVal = M.getNamedValue(VarName);

    if (!IsDeclaration)
      VarSize = divideCeil(
          M.getDataLayout().getTypeSizeInBits(LlvmVal->getValueType()), 8);
    else
      VarSize = 0;
    Linkage = (VariableLinkage) ? VariableLinkage() : LlvmVal->getLinkage();

    // This is a workaround carried over from Clang which prevents undesired
    // optimisation of internal variables.
    if (Config.isTargetDevice() &&
        (!IsExternallyVisible || Linkage == GlobalValue::LinkOnceODRLinkage)) {
      // Do not create a "ref-variable" if the original is not also available
      // on the host.
      if (!OffloadInfoManager.hasDeviceGlobalVarEntryInfo(VarName))
        return;

      std::string RefName = createPlatformSpecificName({VarName, "ref"});

      if (!M.getNamedValue(RefName)) {
        Constant *AddrRef =
            getOrCreateInternalVariable(Addr->getType(), RefName);
        auto *GvAddrRef = cast<GlobalVariable>(AddrRef);
        GvAddrRef->setConstant(true);
        GvAddrRef->setLinkage(GlobalValue::InternalLinkage);
        GvAddrRef->setInitializer(Addr);
        GeneratedRefs.push_back(GvAddrRef);
      }
    }
  } else {
    if (CaptureClause == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink)
      Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryLink;
    else
      Flags = OffloadEntriesInfoManager::OMPTargetGlobalVarEntryTo;

    if (Config.isTargetDevice()) {
      VarName = (Addr) ? Addr->getName() : "";
      Addr = nullptr;
    } else {
      Addr = getAddrOfDeclareTargetVar(
          CaptureClause, DeviceClause, IsDeclaration, IsExternallyVisible,
          EntryInfo, MangledName, GeneratedRefs, OpenMPSIMD, TargetTriple,
          LlvmPtrTy, GlobalInitializer, VariableLinkage);
      VarName = (Addr) ? Addr->getName() : "";
    }
    VarSize = M.getDataLayout().getPointerSize();
    Linkage = GlobalValue::WeakAnyLinkage;
  }

  OffloadInfoManager.registerDeviceGlobalVarEntryInfo(VarName, Addr, VarSize,
                                                      Flags, Linkage);
}

/// Loads all the offload entries information from the host IR
/// metadata.
void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
  // If we are in target mode, load the metadata from the host IR. This code
  // has to match the metadata creation in
  // createOffloadEntriesAndInfoMetadata().

  NamedMDNode *MD = M.getNamedMetadata(ompOffloadInfoName);
  if (!MD)
    return;

  for (MDNode *MN : MD->operands()) {
    auto &&GetMDInt = [MN](unsigned Idx) {
      auto *V = cast<ConstantAsMetadata>(MN->getOperand(Idx));
      return cast<ConstantInt>(V->getValue())->getZExtValue();
    };

    auto &&GetMDString = [MN](unsigned Idx) {
      auto *V = cast<MDString>(MN->getOperand(Idx));
      return V->getString();
    };

    switch (GetMDInt(0)) {
    default:
      llvm_unreachable("Unexpected metadata!");
      break;
    case OffloadEntriesInfoManager::OffloadEntryInfo::
        OffloadingEntryInfoTargetRegion: {
      TargetRegionEntryInfo EntryInfo(/*ParentName=*/GetMDString(3),
                                      /*DeviceID=*/GetMDInt(1),
                                      /*FileID=*/GetMDInt(2),
                                      /*Line=*/GetMDInt(4),
                                      /*Count=*/GetMDInt(5));
      OffloadInfoManager.initializeTargetRegionEntryInfo(EntryInfo,
                                                         /*Order=*/GetMDInt(6));
      break;
    }
    case OffloadEntriesInfoManager::OffloadEntryInfo::
        OffloadingEntryInfoDeviceGlobalVar:
      OffloadInfoManager.initializeDeviceGlobalVarEntryInfo(
          /*MangledName=*/GetMDString(1),
          static_cast<OffloadEntriesInfoManager::OMPTargetGlobalVarEntryKind>(
              /*Flags=*/GetMDInt(2)),
          /*Order=*/GetMDInt(3));
      break;
    }
  }
}

void OpenMPIRBuilder::loadOffloadInfoMetadata(StringRef HostFilePath) {
  if (HostFilePath.empty())
    return;

  auto Buf = MemoryBuffer::getFile(HostFilePath);
  if (std::error_code Err = Buf.getError()) {
    report_fatal_error(("error opening host file from host file path inside of "
                        "OpenMPIRBuilder: " +
                        Err.message())
                           .c_str());
  }

  LLVMContext Ctx;
  auto M = expectedToErrorOrAndEmitErrors(
      Ctx, parseBitcodeFile(Buf.get()->getMemBufferRef(), Ctx));
  if (std::error_code Err = M.getError()) {
    report_fatal_error(
        ("error parsing host file inside of OpenMPIRBuilder: " + Err.message())
            .c_str());
  }

  loadOffloadInfoMetadata(*M.get());
}

//===----------------------------------------------------------------------===//
// OffloadEntriesInfoManager
//===----------------------------------------------------------------------===//

bool OffloadEntriesInfoManager::empty() const {
  return OffloadEntriesTargetRegion.empty() &&
         OffloadEntriesDeviceGlobalVar.empty();
}

unsigned OffloadEntriesInfoManager::getTargetRegionEntryInfoCount(
    const TargetRegionEntryInfo &EntryInfo) const {
  auto It = OffloadEntriesTargetRegionCount.find(
      getTargetRegionEntryCountKey(EntryInfo));
  if (It == OffloadEntriesTargetRegionCount.end())
    return 0;
  return It->second;
}

void OffloadEntriesInfoManager::incrementTargetRegionEntryInfoCount(
    const TargetRegionEntryInfo &EntryInfo) {
  OffloadEntriesTargetRegionCount[getTargetRegionEntryCountKey(EntryInfo)] =
      EntryInfo.Count + 1;
}

/// Initialize target region entry.
void OffloadEntriesInfoManager::initializeTargetRegionEntryInfo(
    const TargetRegionEntryInfo &EntryInfo, unsigned Order) {
  OffloadEntriesTargetRegion[EntryInfo] =
      OffloadEntryInfoTargetRegion(Order, /*Addr=*/nullptr, /*ID=*/nullptr,
                                   OMPTargetRegionEntryTargetRegion);
  ++OffloadingEntriesNum;
}

void OffloadEntriesInfoManager::registerTargetRegionEntryInfo(
    TargetRegionEntryInfo EntryInfo, Constant *Addr, Constant *ID,
    OMPTargetRegionEntryKind Flags) {
  assert(EntryInfo.Count == 0 && "expected default EntryInfo");

  // Update the EntryInfo with the next available count for this location.
  EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);

  // If we are emitting code for a target, the entry is already initialized,
  // and only has to be registered.
  if (OMPBuilder->Config.isTargetDevice()) {
    // This could happen if the device compilation is invoked standalone.
    if (!hasTargetRegionEntryInfo(EntryInfo)) {
      return;
    }
    auto &Entry = OffloadEntriesTargetRegion[EntryInfo];
    Entry.setAddress(Addr);
    Entry.setID(ID);
    Entry.setFlags(Flags);
  } else {
    if (Flags == OffloadEntriesInfoManager::OMPTargetRegionEntryTargetRegion &&
        hasTargetRegionEntryInfo(EntryInfo, /*IgnoreAddressId*/ true))
      return;
    assert(!hasTargetRegionEntryInfo(EntryInfo) &&
           "Target region entry already registered!");
    OffloadEntryInfoTargetRegion Entry(OffloadingEntriesNum, Addr, ID, Flags);
    OffloadEntriesTargetRegion[EntryInfo] = Entry;
    ++OffloadingEntriesNum;
  }
  incrementTargetRegionEntryInfoCount(EntryInfo);
}

bool OffloadEntriesInfoManager::hasTargetRegionEntryInfo(
    TargetRegionEntryInfo EntryInfo, bool IgnoreAddressId) const {

  // Update the EntryInfo with the next available count for this location.
  EntryInfo.Count = getTargetRegionEntryInfoCount(EntryInfo);

  auto It = OffloadEntriesTargetRegion.find(EntryInfo);
  if (It == OffloadEntriesTargetRegion.end()) {
    return false;
  }
  // Fail if this entry is already registered.
  if (!IgnoreAddressId && (It->second.getAddress() || It->second.getID()))
    return false;
  return true;
}

void OffloadEntriesInfoManager::actOnTargetRegionEntriesInfo(
    const OffloadTargetRegionEntryInfoActTy &Action) {
  // Scan all target region entries and perform the provided action.
  for (const auto &It : OffloadEntriesTargetRegion) {
    Action(It.first, It.second);
  }
}

void OffloadEntriesInfoManager::initializeDeviceGlobalVarEntryInfo(
    StringRef Name, OMPTargetGlobalVarEntryKind Flags, unsigned Order) {
  OffloadEntriesDeviceGlobalVar.try_emplace(Name, Order, Flags);
  ++OffloadingEntriesNum;
}

void OffloadEntriesInfoManager::registerDeviceGlobalVarEntryInfo(
    StringRef VarName, Constant *Addr, int64_t VarSize,
    OMPTargetGlobalVarEntryKind Flags, GlobalValue::LinkageTypes Linkage) {
  if (OMPBuilder->Config.isTargetDevice()) {
    // This could happen if the device compilation is invoked standalone.
    if (!hasDeviceGlobalVarEntryInfo(VarName))
      return;
    auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
    if (Entry.getAddress() && hasDeviceGlobalVarEntryInfo(VarName)) {
      if (Entry.getVarSize() == 0) {
        Entry.setVarSize(VarSize);
        Entry.setLinkage(Linkage);
      }
      return;
    }
    Entry.setVarSize(VarSize);
    Entry.setLinkage(Linkage);
    Entry.setAddress(Addr);
  } else {
    if (hasDeviceGlobalVarEntryInfo(VarName)) {
      auto &Entry = OffloadEntriesDeviceGlobalVar[VarName];
      assert(Entry.isValid() && Entry.getFlags() == Flags &&
             "Entry not initialized!");
      if (Entry.getVarSize() == 0) {
        Entry.setVarSize(VarSize);
        Entry.setLinkage(Linkage);
      }
      return;
    }
    if (Flags == OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect)
      OffloadEntriesDeviceGlobalVar.try_emplace(VarName, OffloadingEntriesNum,
                                                Addr, VarSize, Flags, Linkage,
                                                VarName.str());
    else
      OffloadEntriesDeviceGlobalVar.try_emplace(
          VarName, OffloadingEntriesNum, Addr, VarSize, Flags, Linkage, "");
    ++OffloadingEntriesNum;
  }
}

void OffloadEntriesInfoManager::actOnDeviceGlobalVarEntriesInfo(
    const OffloadDeviceGlobalVarEntryInfoActTy &Action) {
  // Scan all device global variable entries and perform the provided action.
  for (const auto &E : OffloadEntriesDeviceGlobalVar)
    Action(E.getKey(), E.getValue());
}

//===----------------------------------------------------------------------===//
// CanonicalLoopInfo
//===----------------------------------------------------------------------===//

void CanonicalLoopInfo::collectControlBlocks(
    SmallVectorImpl<BasicBlock *> &BBs) {
  // We only count as control blocks those BBs for which we do not need to
  // reverse the CFG, i.e. not the loop body, which can contain arbitrary
  // control flow. For consistency, this also means we do not add the Body
  // block, which is just the entry to the body code.
  BBs.reserve(BBs.size() + 6);
  BBs.append({getPreheader(), Header, Cond, Latch, Exit, getAfter()});
}

BasicBlock *CanonicalLoopInfo::getPreheader() const {
  assert(isValid() && "Requires a valid canonical loop");
  for (BasicBlock *Pred : predecessors(Header)) {
    if (Pred != Latch)
      return Pred;
  }
  llvm_unreachable("Missing preheader");
}

void CanonicalLoopInfo::setTripCount(Value *TripCount) {
  assert(isValid() && "Requires a valid canonical loop");

  Instruction *CmpI = &getCond()->front();
  assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
  CmpI->setOperand(1, TripCount);

#ifndef NDEBUG
  assertOK();
#endif
}

void CanonicalLoopInfo::mapIndVar(
    llvm::function_ref<Value *(Instruction *)> Updater) {
  assert(isValid() && "Requires a valid canonical loop");

  Instruction *OldIV = getIndVar();

  // Record all uses excluding those introduced by the updater. Uses by the
  // CanonicalLoopInfo itself to keep track of the number of iterations are
  // excluded.
  SmallVector<Use *> ReplacableUses;
  for (Use &U : OldIV->uses()) {
    auto *User = dyn_cast<Instruction>(U.getUser());
    if (!User)
      continue;
    if (User->getParent() == getCond())
      continue;
    if (User->getParent() == getLatch())
      continue;
    ReplacableUses.push_back(&U);
  }

  // Run the updater that may introduce new uses.
  Value *NewIV = Updater(OldIV);

  // Replace the old uses with the value returned by the updater.
  for (Use *U : ReplacableUses)
    U->set(NewIV);

#ifndef NDEBUG
  assertOK();
#endif
}
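
// A typical use of mapIndVar (an illustrative sketch; LowerBound is a
// hypothetical loop-invariant value) shifts the logical induction variable,
// e.g. to the start of a chunk:
//   CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
//     Builder.SetInsertPoint(CLI->getBody(),
//                            CLI->getBody()->getFirstInsertionPt());
//     return Builder.CreateAdd(OldIV, LowerBound, "iv.shifted");
//   });
// The add is itself a new use of OldIV, but it was created after the use
// list above was recorded, so it is not rewritten to refer to itself.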

void CanonicalLoopInfo::assertOK() const {
#ifndef NDEBUG
  // No constraints if this object currently does not describe a loop.
  if (!isValid())
    return;

  BasicBlock *Preheader = getPreheader();
  BasicBlock *Body = getBody();
  BasicBlock *After = getAfter();
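
  // For orientation (an informal sketch), the asserts below verify exactly
  // this shape:
  //
  //   preheader -> header -> cond --true--> body -> ... -> latch -> header
  //                            \---false--> exit -> after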

  // Verify standard control-flow we use for OpenMP loops.
  assert(Preheader);
  assert(isa<BranchInst>(Preheader->getTerminator()) &&
         "Preheader must terminate with unconditional branch");
  assert(Preheader->getSingleSuccessor() == Header &&
         "Preheader must jump to header");

  assert(Header);
  assert(isa<BranchInst>(Header->getTerminator()) &&
         "Header must terminate with unconditional branch");
  assert(Header->getSingleSuccessor() == Cond &&
         "Header must jump to exiting block");

  assert(Cond);
  assert(Cond->getSinglePredecessor() == Header &&
         "Exiting block only reachable from header");

  assert(isa<BranchInst>(Cond->getTerminator()) &&
         "Exiting block must terminate with conditional branch");
  assert(size(successors(Cond)) == 2 &&
         "Exiting block must have two successors");
  assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
         "Exiting block's first successor must jump to the body");
  assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
         "Exiting block's second successor must exit the loop");

  assert(Body);
  assert(Body->getSinglePredecessor() == Cond &&
         "Body only reachable from exiting block");
  assert(!isa<PHINode>(Body->front()));

  assert(Latch);
  assert(isa<BranchInst>(Latch->getTerminator()) &&
         "Latch must terminate with unconditional branch");
  assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
  // TODO: To support simple redirecting of the end of body code that has
  // multiple exits, introduce another auxiliary basic block like preheader
  // and after.
  assert(Latch->getSinglePredecessor() != nullptr);
  assert(!isa<PHINode>(Latch->front()));

  assert(Exit);
  assert(isa<BranchInst>(Exit->getTerminator()) &&
         "Exit block must terminate with unconditional branch");
  assert(Exit->getSingleSuccessor() == After &&
         "Exit block must jump to after block");

  assert(After);
  assert(After->getSinglePredecessor() == Exit &&
         "After block only reachable from exit block");
  assert(After->empty() || !isa<PHINode>(After->front()));

  Instruction *IndVar = getIndVar();
  assert(IndVar && "Canonical induction variable not found?");
  assert(isa<IntegerType>(IndVar->getType()) &&
         "Induction variable must be an integer");
  assert(cast<PHINode>(IndVar)->getParent() == Header &&
         "Induction variable must be a PHI in the loop header");
  assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
  assert(
      cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
  assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);

  auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
  assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
  assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
  assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
  assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
             ->isOne());

  Value *TripCount = getTripCount();
  assert(TripCount && "Loop trip count not found?");
  assert(IndVar->getType() == TripCount->getType() &&
         "Trip count and induction variable must have the same type");

  auto *CmpI = cast<CmpInst>(&Cond->front());
  assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
         "Exit condition must be an unsigned less-than comparison");
  assert(CmpI->getOperand(0) == IndVar &&
         "Exit condition must compare the induction variable");
  assert(CmpI->getOperand(1) == TripCount &&
         "Exit condition must compare with the trip count");
#endif
}

void CanonicalLoopInfo::invalidate() {
  Header = nullptr;
  Cond = nullptr;
  Latch = nullptr;
  Exit = nullptr;
}