GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
//===- RISCVGatherScatterLowering.cpp - Gather/Scatter lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass custom lowers llvm.masked.gather and llvm.masked.scatter
// intrinsics to RISC-V strided load/store intrinsics.
//
//===----------------------------------------------------------------------===//
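
// An illustrative example (hypothetical IR, not part of this file): a gather
// whose pointer operand is a GEP with a constant strided vector index,
//
//   %ptrs = getelementptr i32, ptr %base, <4 x i64> <i64 0, i64 2, i64 4, i64 6>
//   %v = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %ptrs, ...)
//
// can be rewritten as a single strided load from %base with a byte stride of
// 8 (element stride 2 scaled by the size of i32).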

#include "RISCV.h"
#include "RISCVTargetMachine.h"
#include "llvm/Analysis/InstSimplifyFolder.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Transforms/Utils/Local.h"
#include <optional>

using namespace llvm;
using namespace PatternMatch;

#define DEBUG_TYPE "riscv-gather-scatter-lowering"

namespace {

class RISCVGatherScatterLowering : public FunctionPass {
  const RISCVSubtarget *ST = nullptr;
  const RISCVTargetLowering *TLI = nullptr;
  LoopInfo *LI = nullptr;
  const DataLayout *DL = nullptr;

  SmallVector<WeakTrackingVH> MaybeDeadPHIs;

  // Cache of the BasePtr and Stride determined from this GEP. When a GEP is
  // used by multiple gathers/scatters, this allows us to reuse the scalar
  // instructions we created for the first gather/scatter for the others.
  DenseMap<GetElementPtrInst *, std::pair<Value *, Value *>> StridedAddrs;

public:
  static char ID; // Pass identification, replacement for typeid

  RISCVGatherScatterLowering() : FunctionPass(ID) {}

  bool runOnFunction(Function &F) override;

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesCFG();
    AU.addRequired<TargetPassConfig>();
    AU.addRequired<LoopInfoWrapperPass>();
  }

  StringRef getPassName() const override {
    return "RISC-V gather/scatter lowering";
  }

private:
  bool tryCreateStridedLoadStore(IntrinsicInst *II, Type *DataType, Value *Ptr,
                                 Value *AlignOp);

  std::pair<Value *, Value *> determineBaseAndStride(Instruction *Ptr,
                                                     IRBuilderBase &Builder);

  bool matchStridedRecurrence(Value *Index, Loop *L, Value *&Stride,
                              PHINode *&BasePtr, BinaryOperator *&Inc,
                              IRBuilderBase &Builder);
};

} // end anonymous namespace

char RISCVGatherScatterLowering::ID = 0;

INITIALIZE_PASS(RISCVGatherScatterLowering, DEBUG_TYPE,
                "RISC-V gather/scatter lowering pass", false, false)

FunctionPass *llvm::createRISCVGatherScatterLoweringPass() {
  return new RISCVGatherScatterLowering();
}
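
// Usage note: this is a legacy-pass-manager codegen pass. It is created via
// createRISCVGatherScatterLoweringPass() from the RISC-V target machine's
// IR-pass pipeline rather than scheduled by hand, and runOnFunction below
// bails out early when the subtarget lacks vector instructions.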

// TODO: Should we consider the mask when looking for a stride?
static std::pair<Value *, Value *> matchStridedConstant(Constant *StartC) {
  if (!isa<FixedVectorType>(StartC->getType()))
    return std::make_pair(nullptr, nullptr);

  unsigned NumElts = cast<FixedVectorType>(StartC->getType())->getNumElements();

  // Check that the start value is a strided constant.
  auto *StartVal =
      dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement((unsigned)0));
  if (!StartVal)
    return std::make_pair(nullptr, nullptr);
  APInt StrideVal(StartVal->getValue().getBitWidth(), 0);
  ConstantInt *Prev = StartVal;
  for (unsigned i = 1; i != NumElts; ++i) {
    auto *C = dyn_cast_or_null<ConstantInt>(StartC->getAggregateElement(i));
    if (!C)
      return std::make_pair(nullptr, nullptr);

    APInt LocalStride = C->getValue() - Prev->getValue();
    if (i == 1)
      StrideVal = LocalStride;
    else if (StrideVal != LocalStride)
      return std::make_pair(nullptr, nullptr);

    Prev = C;
  }

  Value *Stride = ConstantInt::get(StartVal->getType(), StrideVal);

  return std::make_pair(StartVal, Stride);
}
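
// For illustration (hypothetical values): matchStridedConstant maps the
// constant <4 x i64> <i64 10, i64 12, i64 14, i64 16> to the pair
// (start = 10, stride = 2), and rejects <4 x i64> <i64 0, i64 1, i64 3, i64 7>
// because the element-to-element difference is not uniform.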

static std::pair<Value *, Value *> matchStridedStart(Value *Start,
                                                     IRBuilderBase &Builder) {
  // Base case, start is a strided constant.
  auto *StartC = dyn_cast<Constant>(Start);
  if (StartC)
    return matchStridedConstant(StartC);

  // Base case, start is a stepvector.
  if (match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
    auto *Ty = Start->getType()->getScalarType();
    return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
  }

  // Not a constant, maybe it's a strided constant with a splat added or
  // multiplied.
  auto *BO = dyn_cast<BinaryOperator>(Start);
  if (!BO || (BO->getOpcode() != Instruction::Add &&
              BO->getOpcode() != Instruction::Or &&
              BO->getOpcode() != Instruction::Shl &&
              BO->getOpcode() != Instruction::Mul))
    return std::make_pair(nullptr, nullptr);

  if (BO->getOpcode() == Instruction::Or &&
      !cast<PossiblyDisjointInst>(BO)->isDisjoint())
    return std::make_pair(nullptr, nullptr);

  // Look for an operand that is splatted.
  unsigned OtherIndex = 0;
  Value *Splat = getSplatValue(BO->getOperand(1));
  if (!Splat && Instruction::isCommutative(BO->getOpcode())) {
    Splat = getSplatValue(BO->getOperand(0));
    OtherIndex = 1;
  }
  if (!Splat)
    return std::make_pair(nullptr, nullptr);

  Value *Stride;
  std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
                                              Builder);
  if (!Start)
    return std::make_pair(nullptr, nullptr);

  Builder.SetInsertPoint(BO);
  Builder.SetCurrentDebugLocation(DebugLoc());
  // Add the splat value to the start or multiply the start and stride by the
  // splat.
  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode");
  case Instruction::Or:
    // TODO: We'd be better off creating disjoint or here, but we don't yet
    // have an IRBuilder API for that.
    [[fallthrough]];
  case Instruction::Add:
    Start = Builder.CreateAdd(Start, Splat);
    break;
  case Instruction::Mul:
    Start = Builder.CreateMul(Start, Splat);
    Stride = Builder.CreateMul(Stride, Splat);
    break;
  case Instruction::Shl:
    Start = Builder.CreateShl(Start, Splat);
    Stride = Builder.CreateShl(Stride, Splat);
    break;
  }

  return std::make_pair(Start, Stride);
}
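
// For illustration (hypothetical IR): an index computed as
//
//   %step = call <4 x i64> @llvm.experimental.stepvector.v4i64()
//   %scaled = shl <4 x i64> %step, <i64 2, i64 2, i64 2, i64 2>
//   %idx = add <4 x i64> %scaled, <i64 5, i64 5, i64 5, i64 5>
//
// unwinds through matchStridedStart to (start = 5, stride = 4): the
// stepvector contributes (0, 1), the shl scales both by a factor of 4, and
// the add offsets only the start.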

// Recursively walk back up the use-def chain until we find a Phi with a
// strided start value. Build and update a scalar recurrence as we unwind the
// recursion. We also update the Stride as we unwind. Our goal is to move all
// of the arithmetic out of the loop.
bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
                                                        Value *&Stride,
                                                        PHINode *&BasePtr,
                                                        BinaryOperator *&Inc,
                                                        IRBuilderBase &Builder) {
  // Our base case is a Phi.
  if (auto *Phi = dyn_cast<PHINode>(Index)) {
    // A phi node we want to perform this function on should be from the
    // loop header.
    if (Phi->getParent() != L->getHeader())
      return false;

    Value *Step, *Start;
    if (!matchSimpleRecurrence(Phi, Inc, Start, Step) ||
        Inc->getOpcode() != Instruction::Add)
      return false;
    assert(Phi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
    unsigned IncrementingBlock = Phi->getIncomingValue(0) == Inc ? 0 : 1;
    assert(Phi->getIncomingValue(IncrementingBlock) == Inc &&
           "Expected one operand of phi to be Inc");

    // Only proceed if the step is loop invariant.
    if (!L->isLoopInvariant(Step))
      return false;

    // Step should be a splat.
    Step = getSplatValue(Step);
    if (!Step)
      return false;

    std::tie(Start, Stride) = matchStridedStart(Start, Builder);
    if (!Start)
      return false;
    assert(Stride != nullptr);

    // Build scalar phi and increment.
    BasePtr = PHINode::Create(Start->getType(), 2, Phi->getName() + ".scalar",
                              Phi->getIterator());
    Inc = BinaryOperator::CreateAdd(BasePtr, Step, Inc->getName() + ".scalar",
                                    Inc->getIterator());
    BasePtr->addIncoming(Start, Phi->getIncomingBlock(1 - IncrementingBlock));
    BasePtr->addIncoming(Inc, Phi->getIncomingBlock(IncrementingBlock));

    // Note that this Phi might be eligible for removal.
    MaybeDeadPHIs.push_back(Phi);
    return true;
  }

  // Otherwise look for binary operator.
  auto *BO = dyn_cast<BinaryOperator>(Index);
  if (!BO)
    return false;

  switch (BO->getOpcode()) {
  default:
    return false;
  case Instruction::Or:
    // We need to be able to treat Or as Add.
    if (!cast<PossiblyDisjointInst>(BO)->isDisjoint())
      return false;
    break;
  case Instruction::Add:
    break;
  case Instruction::Shl:
    break;
  case Instruction::Mul:
    break;
  }

  // We should have one operand in the loop and one splat.
  Value *OtherOp;
  if (isa<Instruction>(BO->getOperand(0)) &&
      L->contains(cast<Instruction>(BO->getOperand(0)))) {
    Index = cast<Instruction>(BO->getOperand(0));
    OtherOp = BO->getOperand(1);
  } else if (isa<Instruction>(BO->getOperand(1)) &&
             L->contains(cast<Instruction>(BO->getOperand(1))) &&
             Instruction::isCommutative(BO->getOpcode())) {
    Index = cast<Instruction>(BO->getOperand(1));
    OtherOp = BO->getOperand(0);
  } else {
    return false;
  }

  // Make sure other op is loop invariant.
  if (!L->isLoopInvariant(OtherOp))
    return false;

  // Make sure we have a splat.
  Value *SplatOp = getSplatValue(OtherOp);
  if (!SplatOp)
    return false;

  // Recurse up the use-def chain.
  if (!matchStridedRecurrence(Index, L, Stride, BasePtr, Inc, Builder))
    return false;

  // Locate the Step and Start values from the recurrence.
  unsigned StepIndex = Inc->getOperand(0) == BasePtr ? 1 : 0;
  unsigned StartBlock = BasePtr->getOperand(0) == Inc ? 1 : 0;
  Value *Step = Inc->getOperand(StepIndex);
  Value *Start = BasePtr->getOperand(StartBlock);

  // We need to adjust the start value in the preheader.
  Builder.SetInsertPoint(
      BasePtr->getIncomingBlock(StartBlock)->getTerminator());
  Builder.SetCurrentDebugLocation(DebugLoc());

  switch (BO->getOpcode()) {
  default:
    llvm_unreachable("Unexpected opcode!");
  case Instruction::Add:
  case Instruction::Or: {
    // An add only affects the start value. It's ok to do this for Or because
    // we already checked that there are no common set bits.
    Start = Builder.CreateAdd(Start, SplatOp, "start");
    break;
  }
  case Instruction::Mul: {
    Start = Builder.CreateMul(Start, SplatOp, "start");
    Step = Builder.CreateMul(Step, SplatOp, "step");
    Stride = Builder.CreateMul(Stride, SplatOp, "stride");
    break;
  }
  case Instruction::Shl: {
    Start = Builder.CreateShl(Start, SplatOp, "start");
    Step = Builder.CreateShl(Step, SplatOp, "step");
    Stride = Builder.CreateShl(Stride, SplatOp, "stride");
    break;
  }
  }

  Inc->setOperand(StepIndex, Step);
  BasePtr->setIncomingValue(StartBlock, Start);
  return true;
}
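
// For illustration (hypothetical IR): given a vector induction variable
//
//   loop:
//     %vec.iv = phi <4 x i64> [ <i64 0, i64 1, i64 2, i64 3>, %preheader ],
//                             [ %vec.iv.next, %loop ]
//     ...
//     %vec.iv.next = add <4 x i64> %vec.iv, <i64 4, i64 4, i64 4, i64 4>
//
// matchStridedRecurrence builds the scalar recurrence
//
//   %vec.iv.scalar = phi i64 [ 0, %preheader ], [ %vec.iv.next.scalar, %loop ]
//   %vec.iv.next.scalar = add i64 %vec.iv.scalar, 4
//
// and reports stride 1, since lane i of %vec.iv always equals
// %vec.iv.scalar + i.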

std::pair<Value *, Value *>
RISCVGatherScatterLowering::determineBaseAndStride(Instruction *Ptr,
                                                   IRBuilderBase &Builder) {

  // A gather/scatter of a splat is a zero strided load/store.
  if (auto *BasePtr = getSplatValue(Ptr)) {
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    return std::make_pair(BasePtr, ConstantInt::get(IntPtrTy, 0));
  }

  auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
  if (!GEP)
    return std::make_pair(nullptr, nullptr);

  auto I = StridedAddrs.find(GEP);
  if (I != StridedAddrs.end())
    return I->second;

  SmallVector<Value *, 2> Ops(GEP->operands());

  // If the base pointer is a vector, check if it's strided.
  Value *Base = GEP->getPointerOperand();
  if (auto *BaseInst = dyn_cast<Instruction>(Base);
      BaseInst && BaseInst->getType()->isVectorTy()) {
    // If the GEP's offset is scalar then we can add it to the base pointer's
    // base.
    auto IsScalar = [](Value *Idx) { return !Idx->getType()->isVectorTy(); };
    if (all_of(GEP->indices(), IsScalar)) {
      auto [BaseBase, Stride] = determineBaseAndStride(BaseInst, Builder);
      if (BaseBase) {
        Builder.SetInsertPoint(GEP);
        SmallVector<Value *> Indices(GEP->indices());
        Value *OffsetBase =
            Builder.CreateGEP(GEP->getSourceElementType(), BaseBase, Indices,
                              GEP->getName() + "offset", GEP->isInBounds());
        return {OffsetBase, Stride};
      }
    }
  }

  // Base pointer needs to be a scalar.
  Value *ScalarBase = Base;
  if (ScalarBase->getType()->isVectorTy()) {
    ScalarBase = getSplatValue(ScalarBase);
    if (!ScalarBase)
      return std::make_pair(nullptr, nullptr);
  }

  std::optional<unsigned> VecOperand;
  unsigned TypeScale = 0;

  // Look for a vector operand and scale.
  gep_type_iterator GTI = gep_type_begin(GEP);
  for (unsigned i = 1, e = GEP->getNumOperands(); i != e; ++i, ++GTI) {
    if (!Ops[i]->getType()->isVectorTy())
      continue;

    if (VecOperand)
      return std::make_pair(nullptr, nullptr);

    VecOperand = i;

    TypeSize TS = GTI.getSequentialElementStride(*DL);
    if (TS.isScalable())
      return std::make_pair(nullptr, nullptr);

    TypeScale = TS.getFixedValue();
  }

  // We need to find a vector index to simplify.
  if (!VecOperand)
    return std::make_pair(nullptr, nullptr);

  // We can't extract the stride if the arithmetic is done at a different size
  // than the pointer type. Adding the stride later may not wrap correctly.
  // Technically we could handle wider indices, but I don't expect that in
  // practice. Handle one special case here - constants. This simplifies
  // writing test cases.
  Value *VecIndex = Ops[*VecOperand];
  Type *VecIntPtrTy = DL->getIntPtrType(GEP->getType());
  if (VecIndex->getType() != VecIntPtrTy) {
    auto *VecIndexC = dyn_cast<Constant>(VecIndex);
    if (!VecIndexC)
      return std::make_pair(nullptr, nullptr);
    if (VecIndex->getType()->getScalarSizeInBits() >
        VecIntPtrTy->getScalarSizeInBits())
      VecIndex = ConstantFoldCastInstruction(Instruction::Trunc, VecIndexC,
                                             VecIntPtrTy);
    else
      VecIndex = ConstantFoldCastInstruction(Instruction::SExt, VecIndexC,
                                             VecIntPtrTy);
  }

  // Handle the non-recursive case. This is what we see if the vectorizer
  // decides to use a scalar IV + vid on demand instead of a vector IV.
  auto [Start, Stride] = matchStridedStart(VecIndex, Builder);
  if (Start) {
    assert(Stride);
    Builder.SetInsertPoint(GEP);

    // Replace the vector index with the scalar start and build a scalar GEP.
    Ops[*VecOperand] = Start;
    Type *SourceTy = GEP->getSourceElementType();
    Value *BasePtr =
        Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

    // Convert stride to pointer size if needed.
    Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
    assert(Stride->getType() == IntPtrTy && "Unexpected type");

    // Scale the stride by the size of the indexed type.
    if (TypeScale != 1)
      Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

    auto P = std::make_pair(BasePtr, Stride);
    StridedAddrs[GEP] = P;
    return P;
  }

  // Make sure we're in a loop and that it has a preheader and a single latch.
  Loop *L = LI->getLoopFor(GEP->getParent());
  if (!L || !L->getLoopPreheader() || !L->getLoopLatch())
    return std::make_pair(nullptr, nullptr);

  BinaryOperator *Inc;
  PHINode *BasePhi;
  if (!matchStridedRecurrence(VecIndex, L, Stride, BasePhi, Inc, Builder))
    return std::make_pair(nullptr, nullptr);

  assert(BasePhi->getNumIncomingValues() == 2 && "Expected 2 operand phi.");
  unsigned IncrementingBlock = BasePhi->getOperand(0) == Inc ? 0 : 1;
  assert(BasePhi->getIncomingValue(IncrementingBlock) == Inc &&
         "Expected one operand of phi to be Inc");

  Builder.SetInsertPoint(GEP);

  // Replace the vector index with the scalar phi and build a scalar GEP.
  Ops[*VecOperand] = BasePhi;
  Type *SourceTy = GEP->getSourceElementType();
  Value *BasePtr =
      Builder.CreateGEP(SourceTy, ScalarBase, ArrayRef(Ops).drop_front());

  // Final adjustments to stride should go in the start block.
  Builder.SetInsertPoint(
      BasePhi->getIncomingBlock(1 - IncrementingBlock)->getTerminator());

  // Convert stride to pointer size if needed.
  Type *IntPtrTy = DL->getIntPtrType(BasePtr->getType());
  assert(Stride->getType() == IntPtrTy && "Unexpected type");

  // Scale the stride by the size of the indexed type.
  if (TypeScale != 1)
    Stride = Builder.CreateMul(Stride, ConstantInt::get(IntPtrTy, TypeScale));

  auto P = std::make_pair(BasePtr, Stride);
  StridedAddrs[GEP] = P;
  return P;
}
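
// For illustration (hypothetical IR): given
//
//   %ptrs = getelementptr i32, ptr %base, <4 x i64> %idx
//
// where %idx matches (start = 0, stride = 2), the code above replaces the
// vector index with the scalar start, yielding a scalar base address (a GEP
// of %base, which may fold away), and returns a byte stride of 8: the element
// stride 2 multiplied by TypeScale = 4, the size of i32.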

bool RISCVGatherScatterLowering::tryCreateStridedLoadStore(IntrinsicInst *II,
                                                           Type *DataType,
                                                           Value *Ptr,
                                                           Value *AlignOp) {
  // Make sure the operation will be supported by the backend.
  MaybeAlign MA = cast<ConstantInt>(AlignOp)->getMaybeAlignValue();
  EVT DataTypeVT = TLI->getValueType(*DL, DataType);
  if (!MA || !TLI->isLegalStridedLoadStore(DataTypeVT, *MA))
    return false;

  // FIXME: Let the backend type legalize by splitting/widening?
  if (!TLI->isTypeLegal(DataTypeVT))
    return false;

  // Pointer should be an instruction.
  auto *PtrI = dyn_cast<Instruction>(Ptr);
  if (!PtrI)
    return false;

  LLVMContext &Ctx = PtrI->getContext();
  IRBuilder<InstSimplifyFolder> Builder(Ctx, *DL);
  Builder.SetInsertPoint(PtrI);

  Value *BasePtr, *Stride;
  std::tie(BasePtr, Stride) = determineBaseAndStride(PtrI, Builder);
  if (!BasePtr)
    return false;
  assert(Stride != nullptr);

  Builder.SetInsertPoint(II);

  CallInst *Call;
  if (II->getIntrinsicID() == Intrinsic::masked_gather)
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_load,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(3), BasePtr, Stride, II->getArgOperand(2)});
  else
    Call = Builder.CreateIntrinsic(
        Intrinsic::riscv_masked_strided_store,
        {DataType, BasePtr->getType(), Stride->getType()},
        {II->getArgOperand(0), BasePtr, Stride, II->getArgOperand(3)});

  Call->takeName(II);
  II->replaceAllUsesWith(Call);
  II->eraseFromParent();

  if (PtrI->use_empty())
    RecursivelyDeleteTriviallyDeadInstructions(PtrI);

  return true;
}
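
// In other words (operand order as used above): llvm.masked.gather's
// (ptrs, align, mask, passthru) becomes riscv.masked.strided.load's
// (passthru, base, stride, mask), and llvm.masked.scatter's
// (value, ptrs, align, mask) becomes riscv.masked.strided.store's
// (value, base, stride, mask).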

bool RISCVGatherScatterLowering::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<RISCVTargetMachine>();
  ST = &TM.getSubtarget<RISCVSubtarget>(F);
  if (!ST->hasVInstructions() || !ST->useRVVForFixedLengthVectors())
    return false;

  TLI = ST->getTargetLowering();
  DL = &F.getDataLayout();
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();

  StridedAddrs.clear();

  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
        Gathers.push_back(II);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
        Scatters.push_back(II);
      }
    }
  }

  // Rewrite gather/scatter to form strided load/store if possible.
  for (auto *II : Gathers)
    Changed |= tryCreateStridedLoadStore(
        II, II->getType(), II->getArgOperand(0), II->getArgOperand(1));
  for (auto *II : Scatters)
    Changed |=
        tryCreateStridedLoadStore(II, II->getArgOperand(0)->getType(),
                                  II->getArgOperand(1), II->getArgOperand(2));

  // Remove any dead phis.
  while (!MaybeDeadPHIs.empty()) {
    if (auto *Phi = dyn_cast_or_null<PHINode>(MaybeDeadPHIs.pop_back_val()))
      RecursivelyDeleteDeadPHINode(Phi);
  }

  return Changed;
}