Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
35234 views
1
//===- ComplexDeinterleavingPass.cpp --------------------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// Identification:
10
// This step is responsible for finding the patterns that can be lowered to
11
// complex instructions, and building a graph to represent the complex
12
// structures. Starting from the "Converging Shuffle" (a shuffle that
13
// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the
14
// operands are evaluated and identified as "Composite Nodes" (collections of
15
// instructions that can potentially be lowered to a single complex
16
// instruction). This is performed by checking the real and imaginary components
17
// and tracking the data flow for each component while following the operand
18
// pairs. Validity of each node is expected to be done upon creation, and any
19
// validation errors should halt traversal and prevent further graph
20
// construction.
21
// Instead of relying on Shuffle operations, vector interleaving and
22
// deinterleaving can be represented by vector.interleave2 and
23
// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by
24
// these intrinsics, whereas, fixed-width vectors are recognized for both
25
// shufflevector instruction and intrinsics.
26
//
27
// Replacement:
28
// This step traverses the graph built up by identification, delegating to the
29
// target to validate and generate the correct intrinsics, and plumbs them
30
// together connecting each end of the new intrinsics graph to the existing
31
// use-def chain. This step is assumed to finish successfully, as all
32
// information is expected to be correct by this point.
33
//
34
//
35
// Internal data structure:
36
// ComplexDeinterleavingGraph:
37
// Keeps references to all the valid CompositeNodes formed as part of the
38
// transformation, and every Instruction contained within said nodes. It also
39
// holds onto a reference to the root Instruction, and the root node that should
40
// replace it.
41
//
42
// ComplexDeinterleavingCompositeNode:
43
// A CompositeNode represents a single transformation point; each node should
44
// transform into a single complex instruction (ignoring vector splitting, which
45
// would generate more instructions per node). They are identified in a
46
// depth-first manner, traversing and identifying the operands of each
47
// instruction in the order they appear in the IR.
48
// Each node maintains a reference to its Real and Imaginary instructions,
49
// as well as any additional instructions that make up the identified operation
50
// (Internal instructions should only have uses within their containing node).
51
// A Node also contains the rotation and operation type that it represents.
52
// Operands contains pointers to other CompositeNodes, acting as the edges in
53
// the graph. ReplacementValue is the transformed Value* that has been emitted
54
// to the IR.
55
//
56
// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and
57
// ReplacementValue fields of that Node are relevant, where the ReplacementValue
58
// should be pre-populated.
59
//
60
//===----------------------------------------------------------------------===//
61
62
#include "llvm/CodeGen/ComplexDeinterleavingPass.h"
63
#include "llvm/ADT/MapVector.h"
64
#include "llvm/ADT/Statistic.h"
65
#include "llvm/Analysis/TargetLibraryInfo.h"
66
#include "llvm/Analysis/TargetTransformInfo.h"
67
#include "llvm/CodeGen/TargetLowering.h"
68
#include "llvm/CodeGen/TargetPassConfig.h"
69
#include "llvm/CodeGen/TargetSubtargetInfo.h"
70
#include "llvm/IR/IRBuilder.h"
71
#include "llvm/IR/PatternMatch.h"
72
#include "llvm/InitializePasses.h"
73
#include "llvm/Target/TargetMachine.h"
74
#include "llvm/Transforms/Utils/Local.h"
75
#include <algorithm>
76
77
using namespace llvm;
78
using namespace PatternMatch;
79
80
#define DEBUG_TYPE "complex-deinterleaving"
81
82
STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");
83
84
/// Hidden command-line switch, on by default, that allows the complex
/// deinterleaving transformation to be turned off.
static cl::opt<bool> ComplexDeinterleavingEnabled(
    "enable-complex-deinterleaving", cl::Hidden, cl::init(true),
    cl::desc("Enable generation of complex instructions"));
88
89
/// Checks the given mask, and determines whether said mask is interleaving.
90
///
91
/// To be interleaving, a mask must alternate between `i` and `i + (Length /
92
/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a
93
/// 4x vector interleaving mask would be <0, 2, 1, 3>).
94
static bool isInterleavingMask(ArrayRef<int> Mask);
95
96
/// Checks the given mask, and determines whether said mask is deinterleaving.
97
///
98
/// To be deinterleaving, a mask must increment in steps of 2, and either start
99
/// with 0 or 1.
100
/// (e.g. an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or
101
/// <1, 3, 5, 7>).
102
static bool isDeinterleavingMask(ArrayRef<int> Mask);
103
104
/// Returns true if the operation is a negation of V, and it works for both
105
/// integers and floats.
106
static bool isNeg(Value *V);
107
108
/// Returns the operand for negation operation.
109
static Value *getNegOperand(Value *V);
110
111
namespace {
112
113
class ComplexDeinterleavingLegacyPass : public FunctionPass {
114
public:
115
static char ID;
116
117
ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)
118
: FunctionPass(ID), TM(TM) {
119
initializeComplexDeinterleavingLegacyPassPass(
120
*PassRegistry::getPassRegistry());
121
}
122
123
StringRef getPassName() const override {
124
return "Complex Deinterleaving Pass";
125
}
126
127
bool runOnFunction(Function &F) override;
128
void getAnalysisUsage(AnalysisUsage &AU) const override {
129
AU.addRequired<TargetLibraryInfoWrapperPass>();
130
AU.setPreservesCFG();
131
}
132
133
private:
134
const TargetMachine *TM;
135
};
136
137
class ComplexDeinterleavingGraph;
138
struct ComplexDeinterleavingCompositeNode {
139
140
ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,
141
Value *R, Value *I)
142
: Operation(Op), Real(R), Imag(I) {}
143
144
private:
145
friend class ComplexDeinterleavingGraph;
146
using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;
147
using RawNodePtr = ComplexDeinterleavingCompositeNode *;
148
149
public:
150
ComplexDeinterleavingOperation Operation;
151
Value *Real;
152
Value *Imag;
153
154
// This two members are required exclusively for generating
155
// ComplexDeinterleavingOperation::Symmetric operations.
156
unsigned Opcode;
157
std::optional<FastMathFlags> Flags;
158
159
ComplexDeinterleavingRotation Rotation =
160
ComplexDeinterleavingRotation::Rotation_0;
161
SmallVector<RawNodePtr> Operands;
162
Value *ReplacementNode = nullptr;
163
164
void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }
165
166
void dump() { dump(dbgs()); }
167
void dump(raw_ostream &OS) {
168
auto PrintValue = [&](Value *V) {
169
if (V) {
170
OS << "\"";
171
V->print(OS, true);
172
OS << "\"\n";
173
} else
174
OS << "nullptr\n";
175
};
176
auto PrintNodeRef = [&](RawNodePtr Ptr) {
177
if (Ptr)
178
OS << Ptr << "\n";
179
else
180
OS << "nullptr\n";
181
};
182
183
OS << "- CompositeNode: " << this << "\n";
184
OS << " Real: ";
185
PrintValue(Real);
186
OS << " Imag: ";
187
PrintValue(Imag);
188
OS << " ReplacementNode: ";
189
PrintValue(ReplacementNode);
190
OS << " Operation: " << (int)Operation << "\n";
191
OS << " Rotation: " << ((int)Rotation * 90) << "\n";
192
OS << " Operands: \n";
193
for (const auto &Op : Operands) {
194
OS << " - ";
195
PrintNodeRef(Op);
196
}
197
}
198
};
199
200
class ComplexDeinterleavingGraph {
201
public:
202
struct Product {
203
Value *Multiplier;
204
Value *Multiplicand;
205
bool IsPositive;
206
};
207
208
using Addend = std::pair<Value *, bool>;
209
using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;
210
using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;
211
212
// Helper struct for holding info about potential partial multiplication
213
// candidates
214
struct PartialMulCandidate {
215
Value *Common;
216
NodePtr Node;
217
unsigned RealIdx;
218
unsigned ImagIdx;
219
bool IsNodeInverted;
220
};
221
222
explicit ComplexDeinterleavingGraph(const TargetLowering *TL,
223
const TargetLibraryInfo *TLI)
224
: TL(TL), TLI(TLI) {}
225
226
private:
227
const TargetLowering *TL = nullptr;
228
const TargetLibraryInfo *TLI = nullptr;
229
SmallVector<NodePtr> CompositeNodes;
230
DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;
231
232
SmallPtrSet<Instruction *, 16> FinalInstructions;
233
234
/// Root instructions are instructions from which complex computation starts
235
std::map<Instruction *, NodePtr> RootToNode;
236
237
/// Topologically sorted root instructions
238
SmallVector<Instruction *, 1> OrderedRoots;
239
240
/// When examining a basic block for complex deinterleaving, if it is a simple
241
/// one-block loop, then the only incoming block is 'Incoming' and the
242
/// 'BackEdge' block is the block itself."
243
BasicBlock *BackEdge = nullptr;
244
BasicBlock *Incoming = nullptr;
245
246
/// ReductionInfo maps from %ReductionOp to %PHInode and Instruction
247
/// %OutsideUser as it is shown in the IR:
248
///
249
/// vector.body:
250
/// %PHInode = phi <vector type> [ zeroinitializer, %entry ],
251
/// [ %ReductionOp, %vector.body ]
252
/// ...
253
/// %ReductionOp = fadd i64 ...
254
/// ...
255
/// br i1 %condition, label %vector.body, %middle.block
256
///
257
/// middle.block:
258
/// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)
259
///
260
/// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding
261
/// `llvm.vector.reduce.fadd` when unroll factor isn't one.
262
MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;
263
264
/// In the process of detecting a reduction, we consider a pair of
265
/// %ReductionOP, which we refer to as real and imag (or vice versa), and
266
/// traverse the use-tree to detect complex operations. As this is a reduction
267
/// operation, it will eventually reach RealPHI and ImagPHI, which corresponds
268
/// to the %ReductionOPs that we suspect to be complex.
269
/// RealPHI and ImagPHI are used by the identifyPHINode method.
270
PHINode *RealPHI = nullptr;
271
PHINode *ImagPHI = nullptr;
272
273
/// Set this flag to true if RealPHI and ImagPHI were reached during reduction
274
/// detection.
275
bool PHIsFound = false;
276
277
/// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.
278
/// The new PHINode corresponds to a vector of deinterleaved complex numbers.
279
/// This mapping is populated during
280
/// ComplexDeinterleavingOperation::ReductionPHI node replacement. It is then
281
/// used in the ComplexDeinterleavingOperation::ReductionOperation node
282
/// replacement process.
283
std::map<PHINode *, PHINode *> OldToNewPHI;
284
285
NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,
286
Value *R, Value *I) {
287
assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&
288
Operation != ComplexDeinterleavingOperation::ReductionOperation) ||
289
(R && I)) &&
290
"Reduction related nodes must have Real and Imaginary parts");
291
return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,
292
I);
293
}
294
295
NodePtr submitCompositeNode(NodePtr Node) {
296
CompositeNodes.push_back(Node);
297
if (Node->Real && Node->Imag)
298
CachedResult[{Node->Real, Node->Imag}] = Node;
299
return Node;
300
}
301
302
/// Identifies a complex partial multiply pattern and its rotation, based on
303
/// the following patterns
304
///
305
/// 0: r: cr + ar * br
306
/// i: ci + ar * bi
307
/// 90: r: cr - ai * bi
308
/// i: ci + ai * br
309
/// 180: r: cr - ar * br
310
/// i: ci - ar * bi
311
/// 270: r: cr + ai * bi
312
/// i: ci - ai * br
313
NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);
314
315
/// Identify the other branch of a Partial Mul, taking the CommonOperandI that
316
/// is partially known from identifyPartialMul, filling in the other half of
317
/// the complex pair.
318
NodePtr
319
identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,
320
std::pair<Value *, Value *> &CommonOperandI);
321
322
/// Identifies a complex add pattern and its rotation, based on the following
323
/// patterns.
324
///
325
/// 90: r: ar - bi
326
/// i: ai + br
327
/// 270: r: ar + bi
328
/// i: ai - br
329
NodePtr identifyAdd(Instruction *Real, Instruction *Imag);
330
NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);
331
332
NodePtr identifyNode(Value *R, Value *I);
333
334
/// Determine if a sum of complex numbers can be formed from \p RealAddends
335
/// and \p ImagAddens. If \p Accumulator is not null, add the result to it.
336
/// Return nullptr if it is not possible to construct a complex number.
337
/// \p Flags are needed to generate symmetric Add and Sub operations.
338
NodePtr identifyAdditions(std::list<Addend> &RealAddends,
339
std::list<Addend> &ImagAddends,
340
std::optional<FastMathFlags> Flags,
341
NodePtr Accumulator);
342
343
/// Extract one addend that have both real and imaginary parts positive.
344
NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,
345
std::list<Addend> &ImagAddends);
346
347
/// Determine if sum of multiplications of complex numbers can be formed from
348
/// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result
349
/// to it. Return nullptr if it is not possible to construct a complex number.
350
NodePtr identifyMultiplications(std::vector<Product> &RealMuls,
351
std::vector<Product> &ImagMuls,
352
NodePtr Accumulator);
353
354
/// Go through pairs of multiplication (one Real and one Imag) and find all
355
/// possible candidates for partial multiplication and put them into \p
356
/// Candidates. Returns true if all Product has pair with common operand
357
bool collectPartialMuls(const std::vector<Product> &RealMuls,
358
const std::vector<Product> &ImagMuls,
359
std::vector<PartialMulCandidate> &Candidates);
360
361
/// If the code is compiled with -Ofast or expressions have `reassoc` flag,
362
/// the order of complex computation operations may be significantly altered,
363
/// and the real and imaginary parts may not be executed in parallel. This
364
/// function takes this into consideration and employs a more general approach
365
/// to identify complex computations. Initially, it gathers all the addends
366
/// and multiplicands and then constructs a complex expression from them.
367
NodePtr identifyReassocNodes(Instruction *I, Instruction *J);
368
369
NodePtr identifyRoot(Instruction *I);
370
371
/// Identifies the Deinterleave operation applied to a vector containing
372
/// complex numbers. There are two ways to represent the Deinterleave
373
/// operation:
374
/// * Using two shufflevectors with even indices for /pReal instruction and
375
/// odd indices for /pImag instructions (only for fixed-width vectors)
376
/// * Using two extractvalue instructions applied to `vector.deinterleave2`
377
/// intrinsic (for both fixed and scalable vectors)
378
NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);
379
380
/// identifying the operation that represents a complex number repeated in a
381
/// Splat vector. There are two possible types of splats: ConstantExpr with
382
/// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an
383
/// initialization mask with all values set to zero.
384
NodePtr identifySplat(Value *Real, Value *Imag);
385
386
NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);
387
388
/// Identifies SelectInsts in a loop that has reduction with predication masks
389
/// and/or predicated tail folding
390
NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);
391
392
Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);
393
394
/// Complete IR modifications after producing new reduction operation:
395
/// * Populate the PHINode generated for
396
/// ComplexDeinterleavingOperation::ReductionPHI
397
/// * Deinterleave the final value outside of the loop and repurpose original
398
/// reduction users
399
void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);
400
401
public:
402
void dump() { dump(dbgs()); }
403
void dump(raw_ostream &OS) {
404
for (const auto &Node : CompositeNodes)
405
Node->dump(OS);
406
}
407
408
/// Returns false if the deinterleaving operation should be cancelled for the
409
/// current graph.
410
bool identifyNodes(Instruction *RootI);
411
412
/// In case \pB is one-block loop, this function seeks potential reductions
413
/// and populates ReductionInfo. Returns true if any reductions were
414
/// identified.
415
bool collectPotentialReductions(BasicBlock *B);
416
417
void identifyReductionNodes();
418
419
/// Check that every instruction, from the roots to the leaves, has internal
420
/// uses.
421
bool checkNodes();
422
423
/// Perform the actual replacement of the underlying instruction graph.
424
void replaceNodes();
425
};
426
427
class ComplexDeinterleaving {
428
public:
429
ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)
430
: TL(tl), TLI(tli) {}
431
bool runOnFunction(Function &F);
432
433
private:
434
bool evaluateBasicBlock(BasicBlock *B);
435
436
const TargetLowering *TL = nullptr;
437
const TargetLibraryInfo *TLI = nullptr;
438
};
439
440
} // namespace
441
442
char ComplexDeinterleavingLegacyPass::ID = 0;
443
444
INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
445
"Complex Deinterleaving", false, false)
446
INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,
447
"Complex Deinterleaving", false, false)
448
449
/// New pass manager entry point. Reports every analysis as preserved when the
/// function was left untouched; otherwise only the module proxy is preserved.
PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  const TargetLowering *Lowering = TM->getSubtargetImpl(F)->getTargetLowering();
  auto &LibInfo = AM.getResult<llvm::TargetLibraryAnalysis>(F);

  ComplexDeinterleaving Impl(Lowering, &LibInfo);
  const bool Modified = Impl.runOnFunction(F);
  if (Modified) {
    PreservedAnalyses Preserved;
    Preserved.preserve<FunctionAnalysisManagerModuleProxy>();
    return Preserved;
  }

  return PreservedAnalyses::all();
}
460
461
/// Factory for the legacy pass; the caller receives a newly allocated pass
/// object bound to \p TM.
FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {
  FunctionPass *LegacyPass = new ComplexDeinterleavingLegacyPass(TM);
  return LegacyPass;
}
464
465
/// Legacy pass manager hook: fetches the target lowering and library info for
/// \p F and hands them to the shared driver. Returns true if \p F changed.
bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {
  const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();
  // Bind to the analysis result instead of copying it; only its address is
  // needed, and only for the duration of this call.
  auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
  return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);
}
470
471
/// Runs the transformation over every basic block of \p F, unless it was
/// disabled on the command line or the target reports no support for it.
/// Returns true if any block was modified.
bool ComplexDeinterleaving::runOnFunction(Function &F) {
  if (!ComplexDeinterleavingEnabled) {
    LLVM_DEBUG(
        dbgs() << "Complex deinterleaving has been explicitly disabled.\n");
    return false;
  }

  if (!TL->isComplexDeinterleavingSupported()) {
    LLVM_DEBUG(
        dbgs() << "Complex deinterleaving has been disabled, target does "
                  "not support lowering of complex number operations.\n");
    return false;
  }

  // Every block is visited, even after an earlier one has already changed.
  bool AnyBlockChanged = false;
  for (BasicBlock &Block : F) {
    if (evaluateBasicBlock(&Block))
      AnyBlockChanged = true;
  }
  return AnyBlockChanged;
}
491
492
/// Returns true when \p Mask has an even length and its element pairs are
/// (0, Half), (1, Half + 1), ... where Half is half of the mask length.
static bool isInterleavingMask(ArrayRef<int> Mask) {
  // A mask with an odd number of elements cannot be split into pairs.
  if (Mask.size() % 2 != 0)
    return false;

  const int Half = Mask.size() / 2;
  int Lane = 0;
  while (Lane < Half) {
    const int Pos = 2 * Lane;
    // Even position selects from the low half ...
    if (Mask[Pos] != Lane)
      return false;
    // ... and the following odd position from the high half.
    if (Mask[Pos + 1] != Lane + Half)
      return false;
    ++Lane;
  }

  return true;
}
506
507
/// Returns true when \p Mask starts at 0 or 1 and every following element
/// grows by exactly 2 (e.g. <0, 2, 4, 6> or <1, 3, 5, 7>).
///
/// An empty mask is rejected. The whole mask is validated, not only its first
/// half, so a mask whose tail breaks the stride is no longer accepted.
static bool isDeinterleavingMask(ArrayRef<int> Mask) {
  // Nothing to inspect; also avoids reading Mask[0] out of bounds.
  if (Mask.empty())
    return false;

  // A deinterleaving mask selects either the even or the odd lanes.
  const int Offset = Mask[0];
  if (Offset != 0 && Offset != 1)
    return false;

  const int NumElements = Mask.size();
  for (int Idx = 1; Idx < NumElements; ++Idx) {
    if (Mask[Idx] != (Idx * 2) + Offset)
      return false;
  }

  return true;
}
518
519
/// Returns true if \p V matches either the m_FNeg or the m_Neg pattern.
bool isNeg(Value *V) {
  if (match(V, m_FNeg(m_Value())))
    return true;
  return match(V, m_Neg(m_Value()));
}
522
523
/// Returns the value being negated by \p V, which must satisfy isNeg().
Value *getNegOperand(Value *V) {
  assert(isNeg(V));
  auto *NegInst = cast<Instruction>(V);
  // An FNeg instruction carries its input as operand 0; any other accepted
  // form is read from operand 1.
  // NOTE(review): presumably those are the "0 - X" style subtractions matched
  // by isNeg -- confirm against the PatternMatch helpers.
  const unsigned OperandIdx =
      NegInst->getOpcode() == Instruction::FNeg ? 0 : 1;
  return NegInst->getOperand(OperandIdx);
}
531
532
/// Builds a graph for the block \p B, identifies candidate nodes for every
/// instruction in it and, if the graph passes validation, performs the
/// replacement. Returns true only when the replacement was carried out.
bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {
  ComplexDeinterleavingGraph Graph(TL, TLI);

  // Reduction nodes are identified first, and only when candidates exist.
  const bool HasReductions = Graph.collectPotentialReductions(B);
  if (HasReductions)
    Graph.identifyReductionNodes();

  for (Instruction &Inst : *B)
    Graph.identifyNodes(&Inst);

  if (!Graph.checkNodes())
    return false;

  Graph.replaceNodes();
  return true;
}
547
548
/// Tries to match the pair of multiplications \p Real / \p Imag as a partial
/// complex multiply and to complete \p PartialMatch, whose other half was
/// filled in by identifyPartialMul.
///
/// Both instructions must be single-use fmul/mul instructions sharing one
/// operand. Negated operands select the rotation. On success a CMulPartial
/// node with two operands (common node, uncommon node) is submitted and
/// returned; on any failure nullptr is returned.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(
    Instruction *Real, Instruction *Imag,
    std::pair<Value *, Value *> &PartialMatch) {
  LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag
                    << "\n");

  if (!Real->hasOneUse() || !Imag->hasOneUse()) {
    LLVM_DEBUG(dbgs() << "  - Mul operand has multiple uses.\n");
    return nullptr;
  }

  if ((Real->getOpcode() != Instruction::FMul &&
       Real->getOpcode() != Instruction::Mul) ||
      (Imag->getOpcode() != Instruction::FMul &&
       Imag->getOpcode() != Instruction::Mul)) {
    LLVM_DEBUG(
        dbgs() << "  - Real or imaginary instruction is not fmul or mul\n");
    return nullptr;
  }

  Value *R0 = Real->getOperand(0);
  Value *R1 = Real->getOperand(1);
  Value *I0 = Imag->getOperand(0);
  Value *I1 = Imag->getOperand(1);

  // A +/+ has a rotation of 0. If any of the operands are negated, we flip the
  // rotations and use the operand.
  unsigned Negs = 0;
  Value *Op = nullptr;
  if (match(R0, m_Neg(m_Value(Op)))) {
    Negs |= 1;
    R0 = Op;
  } else if (match(R1, m_Neg(m_Value(Op)))) {
    Negs |= 1;
    R1 = Op;
  }

  if (isNeg(I0)) {
    Negs |= 2;
    Negs ^= 1;
    // Strip the negation from I0 itself. Op cannot be used here: isNeg does
    // not capture anything, so Op is either unset or still refers to the
    // operand captured while inspecting R0/R1.
    I0 = getNegOperand(I0);
  } else if (match(I1, m_Neg(m_Value(Op)))) {
    Negs |= 2;
    Negs ^= 1;
    I1 = Op;
  }
  // NOTE(review): R0, R1 and I1 are tested with m_Neg only, while I0 goes
  // through isNeg (m_FNeg or m_Neg) -- confirm whether the other three
  // operands should accept the same set of negations.

  // The accumulated bits are reinterpreted directly as the rotation value.
  ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;

  Value *CommonOperand;
  Value *UncommonRealOp;
  Value *UncommonImagOp;

  if (R0 == I0 || R0 == I1) {
    CommonOperand = R0;
    UncommonRealOp = R1;
  } else if (R1 == I0 || R1 == I1) {
    CommonOperand = R1;
    UncommonRealOp = R0;
  } else {
    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
    return nullptr;
  }

  UncommonImagOp = (CommonOperand == I0) ? I1 : I0;
  if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_270)
    std::swap(UncommonRealOp, UncommonImagOp);

  // Between identifyPartialMul and here we need to have found a complete valid
  // pair from the CommonOperand of each part.
  if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_180)
    PartialMatch.first = CommonOperand;
  else
    PartialMatch.second = CommonOperand;

  if (!PartialMatch.first || !PartialMatch.second) {
    LLVM_DEBUG(dbgs() << "  - Incomplete partial match\n");
    return nullptr;
  }

  NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonNode) {
    LLVM_DEBUG(dbgs() << "  - No CommonNode identified\n");
    return nullptr;
  }

  NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);
  if (!UncommonNode) {
    LLVM_DEBUG(dbgs() << "  - No UncommonNode identified\n");
    return nullptr;
  }

  NodePtr Node = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Node->Rotation = Rotation;
  Node->addOperand(CommonNode);
  Node->addOperand(UncommonNode);
  return submitCompositeNode(Node);
}
650
651
/// Matches \p Real / \p Imag as the add/sub halves of a partial complex
/// multiply with an accumulator. The add/sub combination of the two
/// instructions fixes the rotation; operand 1 of each must be a single-use
/// instruction, and the two must share one operand. On success a CMulPartial
/// node with three operands (common, uncommon, accumulator) is submitted and
/// returned; otherwise nullptr.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,
                                               Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag
                    << "\n");

  // Classify both opcodes once; the combination selects the rotation.
  const unsigned RealOpc = Real->getOpcode();
  const unsigned ImagOpc = Imag->getOpcode();
  const bool RealIsAdd =
      RealOpc == Instruction::FAdd || RealOpc == Instruction::Add;
  const bool RealIsSub =
      RealOpc == Instruction::FSub || RealOpc == Instruction::Sub;
  const bool ImagIsAdd =
      ImagOpc == Instruction::FAdd || ImagOpc == Instruction::Add;
  const bool ImagIsSub =
      ImagOpc == Instruction::FSub || ImagOpc == Instruction::Sub;

  ComplexDeinterleavingRotation Rotation;
  if (RealIsAdd && ImagIsAdd) {
    Rotation = ComplexDeinterleavingRotation::Rotation_0;
  } else if (RealIsSub && ImagIsAdd) {
    Rotation = ComplexDeinterleavingRotation::Rotation_90;
  } else if (RealIsSub && ImagIsSub) {
    Rotation = ComplexDeinterleavingRotation::Rotation_180;
  } else if (RealIsAdd && ImagIsSub) {
    Rotation = ComplexDeinterleavingRotation::Rotation_270;
  } else {
    LLVM_DEBUG(dbgs() << "  - Unhandled rotation.\n");
    return nullptr;
  }

  // For rotations 0 and 180 the shared multiplicand is the real half of the
  // common operand; for 90 and 270 it is the imaginary half.
  const bool CommonIsRealHalf =
      Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
      Rotation == ComplexDeinterleavingRotation::Rotation_180;

  // Floating-point inputs need the contract flag on both halves.
  if (isa<FPMathOperator>(Real)) {
    const bool BothContract = Real->getFastMathFlags().allowContract() &&
                              Imag->getFastMathFlags().allowContract();
    if (!BothContract) {
      LLVM_DEBUG(dbgs() << "  - Contract is missing from the FastMath flags.\n");
      return nullptr;
    }
  }

  // Operand 0 is the accumulator half, operand 1 the multiplication.
  Value *AccReal = Real->getOperand(0);
  auto *MulReal = dyn_cast<Instruction>(Real->getOperand(1));
  if (!MulReal)
    return nullptr;
  Value *AccImag = Imag->getOperand(0);
  auto *MulImag = dyn_cast<Instruction>(Imag->getOperand(1));
  if (!MulImag)
    return nullptr;

  if (!(MulReal->hasOneUse() && MulImag->hasOneUse())) {
    LLVM_DEBUG(dbgs() << "  - Mul instruction has multiple uses\n");
    return nullptr;
  }

  Value *RealLHS = MulReal->getOperand(0);
  Value *RealRHS = MulReal->getOperand(1);
  Value *ImagLHS = MulImag->getOperand(0);
  Value *ImagRHS = MulImag->getOperand(1);

  // Find the operand that both multiplications have in common.
  Value *Shared = nullptr;
  Value *OtherReal = nullptr;
  if (RealLHS == ImagLHS || RealLHS == ImagRHS) {
    Shared = RealLHS;
    OtherReal = RealRHS;
  } else if (RealRHS == ImagLHS || RealRHS == ImagRHS) {
    Shared = RealRHS;
    OtherReal = RealLHS;
  } else {
    LLVM_DEBUG(dbgs() << "  - No equal operand\n");
    return nullptr;
  }

  Value *OtherImag = (Shared == ImagLHS) ? ImagRHS : ImagLHS;
  if (!CommonIsRealHalf)
    std::swap(OtherReal, OtherImag);

  // Only one half of the common operand is known at this point; the other is
  // expected to be supplied by identifyNodeWithImplicitAdd.
  std::pair<Value *, Value *> PartialMatch(CommonIsRealHalf ? Shared : nullptr,
                                           CommonIsRealHalf ? nullptr : Shared);

  auto *AccRealInst = dyn_cast<Instruction>(AccReal);
  auto *AccImagInst = dyn_cast<Instruction>(AccImag);
  if (!AccRealInst || !AccImagInst) {
    LLVM_DEBUG(dbgs() << "  - Common operands are not instructions.\n");
    return nullptr;
  }

  NodePtr AccNode =
      identifyNodeWithImplicitAdd(AccRealInst, AccImagInst, PartialMatch);
  if (!AccNode) {
    LLVM_DEBUG(dbgs() << "  - No cnode identified\n");
    return nullptr;
  }

  NodePtr UncommonNode = identifyNode(OtherReal, OtherImag);
  if (!UncommonNode) {
    LLVM_DEBUG(dbgs() << "  - No UncommonRes identified\n");
    return nullptr;
  }

  assert(PartialMatch.first && PartialMatch.second);
  NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);
  if (!CommonNode) {
    LLVM_DEBUG(dbgs() << "  - No CommonRes identified\n");
    return nullptr;
  }

  NodePtr Result = prepareCompositeNode(
      ComplexDeinterleavingOperation::CMulPartial, Real, Imag);
  Result->Rotation = Rotation;
  Result->addOperand(CommonNode);
  Result->addOperand(UncommonNode);
  Result->addOperand(AccNode);
  return submitCompositeNode(Result);
}
768
769
/// Matches \p Real / \p Imag as a rotated complex addition: sub/add gives a
/// rotation of 90, add/sub a rotation of 270. All four operands must be
/// instructions that themselves resolve to composite nodes. Returns the
/// submitted CAdd node, or nullptr when the pattern does not apply.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {
  LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");

  const unsigned RealOpc = Real->getOpcode();
  const unsigned ImagOpc = Imag->getOpcode();

  // Real subtracts while Imag adds.
  const bool IsSubThenAdd =
      (RealOpc == Instruction::FSub && ImagOpc == Instruction::FAdd) ||
      (RealOpc == Instruction::Sub && ImagOpc == Instruction::Add);
  // Real adds while Imag subtracts.
  const bool IsAddThenSub =
      (RealOpc == Instruction::FAdd && ImagOpc == Instruction::FSub) ||
      (RealOpc == Instruction::Add && ImagOpc == Instruction::Sub);

  if (!IsSubThenAdd && !IsAddThenSub) {
    LLVM_DEBUG(dbgs() << "  - Unhandled case, rotation is not assigned.\n");
    return nullptr;
  }

  const ComplexDeinterleavingRotation Rotation =
      IsSubThenAdd ? ComplexDeinterleavingRotation::Rotation_90
                   : ComplexDeinterleavingRotation::Rotation_270;

  // Real is "ar op bi" and Imag is "ai op br".
  auto *RealOfA = dyn_cast<Instruction>(Real->getOperand(0));
  auto *ImagOfB = dyn_cast<Instruction>(Real->getOperand(1));
  auto *ImagOfA = dyn_cast<Instruction>(Imag->getOperand(0));
  auto *RealOfB = dyn_cast<Instruction>(Imag->getOperand(1));

  const bool AllInstructions = RealOfA && ImagOfA && RealOfB && ImagOfB;
  if (!AllInstructions) {
    LLVM_DEBUG(dbgs() << "  - Not all operands are instructions.\n");
    return nullptr;
  }

  NodePtr NodeA = identifyNode(RealOfA, ImagOfA);
  if (!NodeA) {
    LLVM_DEBUG(dbgs() << "  - AR/AI is not identified as a composite node.\n");
    return nullptr;
  }

  NodePtr NodeB = identifyNode(RealOfB, ImagOfB);
  if (!NodeB) {
    LLVM_DEBUG(dbgs() << "  - BR/BI is not identified as a composite node.\n");
    return nullptr;
  }

  NodePtr Result =
      prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);
  Result->Rotation = Rotation;
  Result->addOperand(NodeA);
  Result->addOperand(NodeB);
  return submitCompositeNode(Result);
}
818
819
/// Returns true when \p A and \p B form an add/sub pair of the same kind
/// (both floating-point or both integer), in either order.
static bool isInstructionPairAdd(Instruction *A, Instruction *B) {
  const unsigned Second = B->getOpcode();

  // Each accepted opcode of A has exactly one accepted partner opcode for B.
  switch (A->getOpcode()) {
  case Instruction::FSub:
    return Second == Instruction::FAdd;
  case Instruction::FAdd:
    return Second == Instruction::FSub;
  case Instruction::Sub:
    return Second == Instruction::Add;
  case Instruction::Add:
    return Second == Instruction::Sub;
  default:
    return false;
  }
}
828
829
/// Returns true when both \p A and \p B match the pattern "binary operator
/// whose two operands are each an fmul".
static bool isInstructionPairMul(Instruction *A, Instruction *B) {
  auto IsBinOpOfFMuls = [](Instruction *Inst) {
    return match(Inst, m_BinOp(m_FMul(m_Value(), m_Value()),
                               m_FMul(m_Value(), m_Value())));
  };

  return IsBinOpOfFMuls(A) && IsBinOpOfFMuls(B);
}
835
836
/// Returns true when the opcode of \p I is one of the arithmetic operations
/// accepted as a symmetric operation: fadd, fsub, fmul, fneg, add, sub or mul.
static bool isInstructionPotentiallySymmetric(Instruction *I) {
  const unsigned Opc = I->getOpcode();

  const bool IsFloatOp = Opc == Instruction::FAdd ||
                         Opc == Instruction::FSub ||
                         Opc == Instruction::FMul || Opc == Instruction::FNeg;
  const bool IsIntOp = Opc == Instruction::Add || Opc == Instruction::Sub ||
                       Opc == Instruction::Mul;
  return IsFloatOp || IsIntOp;
}
850
851
/// Matches \p Real / \p Imag as the same arithmetic operation applied
/// independently to the real and imaginary halves. The operands are resolved
/// pairwise (operand 0 with operand 0, and operand 1 with operand 1 for binary
/// operators). Floating-point operations must carry identical fast-math flags.
/// Returns the submitted Symmetric node, or nullptr when the pattern does not
/// apply.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,
                                                       Instruction *Imag) {
  const unsigned Opc = Real->getOpcode();
  if (Opc != Imag->getOpcode() || !isInstructionPotentiallySymmetric(Real) ||
      !isInstructionPotentiallySymmetric(Imag))
    return nullptr;

  NodePtr FirstOperand =
      identifyNode(Real->getOperand(0), Imag->getOperand(0));
  if (!FirstOperand)
    return nullptr;

  // Unary operations only have the first operand.
  const bool IsBinary = Real->isBinaryOp();
  NodePtr SecondOperand;
  if (IsBinary) {
    SecondOperand = identifyNode(Real->getOperand(1), Imag->getOperand(1));
    if (!SecondOperand)
      return nullptr;
  }

  // The flags are compared only after the operands have been resolved.
  const bool IsFPOp = isa<FPMathOperator>(Real);
  if (IsFPOp && Real->getFastMathFlags() != Imag->getFastMathFlags())
    return nullptr;

  NodePtr Result = prepareCompositeNode(
      ComplexDeinterleavingOperation::Symmetric, Real, Imag);
  Result->Opcode = Opc;
  if (IsFPOp)
    Result->Flags = Real->getFastMathFlags();

  Result->addOperand(FirstOperand);
  if (IsBinary)
    Result->addOperand(SecondOperand);

  return submitCompositeNode(Result);
}
893
894
/// Identifies the complex node formed by the real part \p R and the imaginary
/// part \p I, trying each known pattern in turn. A result already present in
/// CachedResult is returned directly. Returns nullptr when nothing matches.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {
  LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");
  assert(R->getType() == I->getType() &&
         "Real and imaginary parts should not have different types");

  auto Cached = CachedResult.find({R, I});
  if (Cached != CachedResult.end()) {
    LLVM_DEBUG(dbgs() << " - Folding to existing node\n");
    return Cached->second;
  }

  // identifySplat accepts arbitrary values, so it runs before the operands
  // are required to be instructions.
  if (NodePtr Splat = identifySplat(R, I))
    return Splat;

  auto *Real = dyn_cast<Instruction>(R);
  auto *Imag = dyn_cast<Instruction>(I);
  if (!Real || !Imag)
    return nullptr;

  if (NodePtr Deinterleave = identifyDeinterleave(Real, Imag))
    return Deinterleave;

  if (NodePtr PHI = identifyPHINode(Real, Imag))
    return PHI;

  if (NodePtr Select = identifySelectNode(Real, Imag))
    return Select;

  // Ask the target which complex operations it supports on the vector type
  // with twice as many elements as each part.
  auto *PartTy = cast<VectorType>(Real->getType());
  auto *ComplexTy = VectorType::getDoubleElementsVectorType(PartTy);

  bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(
      ComplexDeinterleavingOperation::CMulPartial, ComplexTy);
  bool HasCAddSupport = TL->isComplexDeinterleavingOperationSupported(
      ComplexDeinterleavingOperation::CAdd, ComplexTy);

  if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {
    if (NodePtr PartialMul = identifyPartialMul(Real, Imag))
      return PartialMul;
  }

  if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {
    if (NodePtr Add = identifyAdd(Real, Imag))
      return Add;
  }

  if (HasCMulSupport && HasCAddSupport) {
    if (NodePtr Reassoc = identifyReassocNodes(Real, Imag))
      return Reassoc;
  }

  if (NodePtr Symmetric = identifySymmetricOperation(Real, Imag))
    return Symmetric;

  // Remember the failure so the same pair is not analysed again.
  LLVM_DEBUG(dbgs() << "  - Not recognised as a valid pattern.\n");
  CachedResult[{R, I}] = nullptr;
  return nullptr;
}
953
954
/// Attempts to identify a complex operation from chains of additions,
/// subtractions, negations and multiplications rooted at \p Real and \p Imag.
/// Returns the resulting node, or nullptr when identification fails.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,
                                                 Instruction *Imag) {
  // Only additive opcodes are accepted as roots.
  auto IsSupportedRoot = [](unsigned Opcode) -> bool {
    switch (Opcode) {
    case Instruction::FAdd:
    case Instruction::FSub:
    case Instruction::FNeg:
    case Instruction::Add:
    case Instruction::Sub:
      return true;
    default:
      return false;
    }
  };

  if (!IsSupportedRoot(Real->getOpcode()) ||
      !IsSupportedRoot(Imag->getOpcode()))
    return nullptr;

  // Floating-point roots must carry identical fast-math flags that allow
  // reassociation. Flags stays empty when Real is not an FPMathOperator.
  std::optional<FastMathFlags> Flags;
  if (isa<FPMathOperator>(Real)) {
    if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {
      LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "
                           "not identical\n");
      return nullptr;
    }

    Flags = Real->getFastMathFlags();
    if (!Flags->allowReassoc()) {
      LLVM_DEBUG(
          dbgs()
          << "the 'Reassoc' attribute is missing in the FastMath flags\n");
      return nullptr;
    }
  }

  // Walk the operands of the given instruction, gathering signed products and
  // signed addends. Fails when a traversed instruction's fast-math flags
  // differ from the roots' flags.
  auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,
                          std::list<Addend> &Addends) -> bool {
    SmallVector<PointerIntPair<Value *, 1, bool>> Pending = {{Insn, true}};
    SmallPtrSet<Value *, 8> Seen;
    while (!Pending.empty()) {
      auto [V, IsPositive] = Pending.back();
      Pending.pop_back();
      if (!Seen.insert(V).second)
        continue;

      // Anything that is not an instruction is recorded as a plain addend.
      Instruction *I = dyn_cast<Instruction>(V);
      if (!I) {
        Addends.emplace_back(V, IsPositive);
        continue;
      }

      // An instruction with several uses either has an external user (checked
      // later by checkNodes) or is a sub-expression shared between several
      // expressions. It is kept as an addend so that it can be identified
      // separately and become a shared ComplexDeinterleavingCompositeNode.
      if (I != Insn && I->getNumUses() > 1) {
        LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");
        Addends.emplace_back(I, IsPositive);
        continue;
      }

      // Operand 1 is pushed before operand 0 so that operand 0 is popped
      // first.
      switch (I->getOpcode()) {
      case Instruction::FAdd:
      case Instruction::Add:
        Pending.emplace_back(I->getOperand(1), IsPositive);
        Pending.emplace_back(I->getOperand(0), IsPositive);
        break;
      case Instruction::FSub:
        Pending.emplace_back(I->getOperand(1), !IsPositive);
        Pending.emplace_back(I->getOperand(0), IsPositive);
        break;
      case Instruction::Sub:
        if (isNeg(I)) {
          Pending.emplace_back(getNegOperand(I), !IsPositive);
        } else {
          Pending.emplace_back(I->getOperand(1), !IsPositive);
          Pending.emplace_back(I->getOperand(0), IsPositive);
        }
        break;
      case Instruction::FMul:
      case Instruction::Mul: {
        // A negated factor is stripped and folded into the product's sign.
        Value *A = I->getOperand(0);
        if (isNeg(A)) {
          A = getNegOperand(A);
          IsPositive = !IsPositive;
        }

        Value *B = I->getOperand(1);
        if (isNeg(B)) {
          B = getNegOperand(B);
          IsPositive = !IsPositive;
        }
        Muls.push_back(Product{A, B, IsPositive});
        break;
      }
      case Instruction::FNeg:
        Pending.emplace_back(I->getOperand(0), !IsPositive);
        break;
      default:
        Addends.emplace_back(I, IsPositive);
        continue;
      }

      if (Flags && I->getFastMathFlags() != *Flags) {
        LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "
                             "inconsistent with the root instructions' flags: "
                          << *I << "\n");
        return false;
      }
    }
    return true;
  };

  std::vector<Product> RealMuls, ImagMuls;
  std::list<Addend> RealAddends, ImagAddends;
  if (!Collect(Real, RealMuls, RealAddends) ||
      !Collect(Imag, ImagMuls, ImagAddends))
    return nullptr;

  if (RealAddends.size() != ImagAddends.size())
    return nullptr;

  NodePtr FinalNode;
  const bool HasProducts = !RealMuls.empty() || !ImagMuls.empty();
  if (HasProducts) {
    // With products present, a positive addend (if any) is taken out and
    // passed to the multiplications as their accumulator.
    FinalNode = extractPositiveAddend(RealAddends, ImagAddends);
    FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);
    if (!FinalNode)
      return nullptr;
  }

  // Whatever addends remain are chained on top of the result so far.
  const bool HasAddends = !RealAddends.empty() || !ImagAddends.empty();
  if (HasAddends) {
    FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);
    if (!FinalNode)
      return nullptr;
  }
  assert(FinalNode && "FinalNode can not be nullptr here");

  // The final node stands for the root instructions themselves.
  FinalNode->Real = Real;
  FinalNode->Imag = Imag;
  submitCompositeNode(FinalNode);
  return FinalNode;
}
1101
1102
bool ComplexDeinterleavingGraph::collectPartialMuls(
1103
const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,
1104
std::vector<PartialMulCandidate> &PartialMulCandidates) {
1105
// Helper function to extract a common operand from two products
1106
auto FindCommonInstruction = [](const Product &Real,
1107
const Product &Imag) -> Value * {
1108
if (Real.Multiplicand == Imag.Multiplicand ||
1109
Real.Multiplicand == Imag.Multiplier)
1110
return Real.Multiplicand;
1111
1112
if (Real.Multiplier == Imag.Multiplicand ||
1113
Real.Multiplier == Imag.Multiplier)
1114
return Real.Multiplier;
1115
1116
return nullptr;
1117
};
1118
1119
// Iterating over real and imaginary multiplications to find common operands
1120
// If a common operand is found, a partial multiplication candidate is created
1121
// and added to the candidates vector The function returns false if no common
1122
// operands are found for any product
1123
for (unsigned i = 0; i < RealMuls.size(); ++i) {
1124
bool FoundCommon = false;
1125
for (unsigned j = 0; j < ImagMuls.size(); ++j) {
1126
auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);
1127
if (!Common)
1128
continue;
1129
1130
auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier
1131
: RealMuls[i].Multiplicand;
1132
auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier
1133
: ImagMuls[j].Multiplicand;
1134
1135
auto Node = identifyNode(A, B);
1136
if (Node) {
1137
FoundCommon = true;
1138
PartialMulCandidates.push_back({Common, Node, i, j, false});
1139
}
1140
1141
Node = identifyNode(B, A);
1142
if (Node) {
1143
FoundCommon = true;
1144
PartialMulCandidates.push_back({Common, Node, i, j, true});
1145
}
1146
}
1147
if (!FoundCommon)
1148
return false;
1149
}
1150
return true;
1151
}
1152
1153
/// Combines the collected real and imaginary products into a chain of
/// CMulPartial nodes.
///
/// \param RealMuls products gathered from the real expression.
/// \param ImagMuls products gathered from the imaginary expression.
/// \param Accumulator optional node that becomes the extra operand of the
///        first partial multiplication of the chain.
/// \returns the last node of the chain (\p Accumulator itself when there are
///          no products), or nullptr when the product counts differ or some
///          product could not be turned into a partial multiplication.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyMultiplications(
    std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,
    NodePtr Accumulator = nullptr) {
  if (RealMuls.size() != ImagMuls.size())
    return nullptr;

  std::vector<PartialMulCandidate> Info;
  if (!collectPartialMuls(RealMuls, ImagMuls, Info))
    return nullptr;

  // Map to store common instruction to node pointers
  std::map<Value *, NodePtr> CommonToNode;
  std::vector<bool> Processed(Info.size(), false);
  for (unsigned I = 0; I < Info.size(); ++I) {
    if (Processed[I])
      continue;

    PartialMulCandidate &InfoA = Info[I];
    for (unsigned J = I + 1; J < Info.size(); ++J) {
      if (Processed[J])
        continue;

      PartialMulCandidate &InfoB = Info[J];
      auto *InfoReal = &InfoA;
      auto *InfoImag = &InfoB;

      // Try the two common operands as (real, imag) and, failing that, with
      // the roles exchanged.
      auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
      if (!NodeFromCommon) {
        std::swap(InfoReal, InfoImag);
        NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);
      }
      if (!NodeFromCommon)
        continue;

      CommonToNode[InfoReal->Common] = NodeFromCommon;
      CommonToNode[InfoImag->Common] = NodeFromCommon;
      Processed[I] = true;
      Processed[J] = true;
    }
  }

  std::vector<bool> ProcessedReal(RealMuls.size(), false);
  std::vector<bool> ProcessedImag(ImagMuls.size(), false);
  NodePtr Result = Accumulator;
  for (auto &PMI : Info) {
    if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])
      continue;

    auto It = CommonToNode.find(PMI.Common);
    // TODO: Process independent complex multiplications. Cases like this:
    //  A.real() * B where both A and B are complex numbers.
    if (It == CommonToNode.end()) {
      LLVM_DEBUG({
        dbgs() << "Unprocessed independent partial multiplication:\n";
        // Dump both products of the candidate: the real one and the
        // imaginary one.
        for (auto *Mul : {&RealMuls[PMI.RealIdx], &ImagMuls[PMI.ImagIdx]})
          dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier
                           << " multiplied by " << *Mul->Multiplicand << "\n";
      });
      return nullptr;
    }

    auto &RealMul = RealMuls[PMI.RealIdx];
    auto &ImagMul = ImagMuls[PMI.ImagIdx];

    auto NodeA = It->second;
    auto NodeB = PMI.Node;
    auto IsMultiplicandReal = PMI.Common == NodeA->Real;
    // The following table illustrates the relationship between multiplications
    // and rotations. If we consider the multiplication (X + iY) * (U + iV), we
    // can see:
    //
    // Rotation |   Real |   Imag |
    // ---------+--------+--------+
    //        0 |  x * u |  x * v |
    //       90 | -y * v |  y * u |
    //      180 | -x * u | -x * v |
    //      270 |  y * v | -y * u |
    //
    // Check if the candidate can indeed be represented by partial
    // multiplication
    // TODO: Add support for multiplication by complex one
    if ((IsMultiplicandReal && PMI.IsNodeInverted) ||
        (!IsMultiplicandReal && !PMI.IsNodeInverted))
      continue;

    // Determine the rotation based on the multiplications
    ComplexDeinterleavingRotation Rotation;
    if (IsMultiplicandReal) {
      // Detect 0 and 180 degrees rotation
      if (RealMul.IsPositive && ImagMul.IsPositive)
        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;
      else if (!RealMul.IsPositive && !ImagMul.IsPositive)
        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;
      else
        continue;

    } else {
      // Detect 90 and 270 degrees rotation
      if (!RealMul.IsPositive && ImagMul.IsPositive)
        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;
      else if (RealMul.IsPositive && !ImagMul.IsPositive)
        Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;
      else
        continue;
    }

    LLVM_DEBUG({
      dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";
      dbgs().indent(4) << "X: " << *NodeA->Real << "\n";
      dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";
      dbgs().indent(4) << "U: " << *NodeB->Real << "\n";
      dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";
      dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
    });

    // Chain the new partial multiplication onto the result so far, which is
    // added as a third operand when present.
    NodePtr NodeMul = prepareCompositeNode(
        ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);
    NodeMul->Rotation = Rotation;
    NodeMul->addOperand(NodeA);
    NodeMul->addOperand(NodeB);
    if (Result)
      NodeMul->addOperand(Result);
    submitCompositeNode(NodeMul);
    Result = NodeMul;
    ProcessedReal[PMI.RealIdx] = true;
    ProcessedImag[PMI.ImagIdx] = true;
  }

  // Ensure all products have been processed, if not return nullptr.
  if (!all_of(ProcessedReal, [](bool V) { return V; }) ||
      !all_of(ProcessedImag, [](bool V) { return V; })) {

    // Dump debug information about which partial multiplications are not
    // processed.
    LLVM_DEBUG({
      dbgs() << "Unprocessed products (Real):\n";
      for (size_t i = 0; i < ProcessedReal.size(); ++i) {
        if (!ProcessedReal[i])
          dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")
                           << *RealMuls[i].Multiplier << " multiplied by "
                           << *RealMuls[i].Multiplicand << "\n";
      }
      dbgs() << "Unprocessed products (Imag):\n";
      for (size_t i = 0; i < ProcessedImag.size(); ++i) {
        if (!ProcessedImag[i])
          dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")
                           << *ImagMuls[i].Multiplier << " multiplied by "
                           << *ImagMuls[i].Multiplicand << "\n";
      }
    });
    return nullptr;
  }

  return Result;
}
1309
1310
/// Chains the remaining real and imaginary addends into Symmetric add/sub
/// nodes or rotated CAdd nodes, starting from \p Accumulator when one is given
/// and from an extracted positive addend pair otherwise.
/// Returns the last node of the chain, or nullptr when the addend counts
/// differ, no starting node exists, or some real addend finds no partner.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyAdditions(
    std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,
    std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {
  if (RealAddends.size() != ImagAddends.size())
    return nullptr;

  // Use the accumulator as the first addend when there is one; otherwise look
  // for an element whose real and imaginary parts are both positive.
  NodePtr Result = Accumulator
                       ? Accumulator
                       : extractPositiveAddend(RealAddends, ImagAddends);
  if (!Result)
    return nullptr;

  while (!RealAddends.empty()) {
    auto ItR = RealAddends.begin();
    auto [R, IsPositiveR] = *ItR;

    bool FoundImag = false;
    for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {
      auto [I, IsPositiveI] = *ItI;

      // The signs of the two parts select the rotation.
      ComplexDeinterleavingRotation Rotation;
      if (IsPositiveI)
        Rotation = IsPositiveR ? ComplexDeinterleavingRotation::Rotation_0
                               : ComplexDeinterleavingRotation::Rotation_90;
      else
        Rotation = IsPositiveR ? ComplexDeinterleavingRotation::Rotation_270
                               : ComplexDeinterleavingRotation::Rotation_180;

      // For 0 and 180 degrees the pair is identified as (R, I); for 90 and
      // 270 degrees the parts are exchanged.
      const bool IsStraight =
          Rotation == ComplexDeinterleavingRotation::Rotation_0 ||
          Rotation == ComplexDeinterleavingRotation::Rotation_180;
      NodePtr AddNode = IsStraight ? identifyNode(R, I) : identifyNode(I, R);
      if (!AddNode)
        continue;

      LLVM_DEBUG({
        dbgs() << "Identified addition:\n";
        dbgs().indent(4) << "X: " << *R << "\n";
        dbgs().indent(4) << "Y: " << *I << "\n";
        dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";
      });

      NodePtr TmpNode;
      if (IsStraight) {
        // 0 degrees is a plain addition and 180 degrees a plain subtraction.
        // The floating-point opcode is used when fast-math flags are given.
        const bool IsAddition =
            Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0;
        TmpNode = prepareCompositeNode(
            ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);
        if (Flags) {
          TmpNode->Opcode = IsAddition ? Instruction::FAdd : Instruction::FSub;
          TmpNode->Flags = *Flags;
        } else {
          TmpNode->Opcode = IsAddition ? Instruction::Add : Instruction::Sub;
        }
      } else {
        TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,
                                       nullptr, nullptr);
        TmpNode->Rotation = Rotation;
      }

      TmpNode->addOperand(Result);
      TmpNode->addOperand(AddNode);
      submitCompositeNode(TmpNode);
      Result = TmpNode;

      // Both addends are consumed; move on to the next real addend.
      RealAddends.erase(ItR);
      ImagAddends.erase(ItI);
      FoundImag = true;
      break;
    }

    if (!FoundImag)
      return nullptr;
  }
  return Result;
}
1401
1402
/// Searches for a real addend and an imaginary addend that are both positive
/// and together form a complex node. On success both addends are removed from
/// their lists and the node is returned; otherwise returns nullptr.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::extractPositiveAddend(
    std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {
  for (auto RealIt = RealAddends.begin(); RealIt != RealAddends.end();
       ++RealIt) {
    auto [R, IsPositiveR] = *RealIt;
    if (!IsPositiveR)
      continue;

    for (auto ImagIt = ImagAddends.begin(); ImagIt != ImagAddends.end();
         ++ImagIt) {
      auto [I, IsPositiveI] = *ImagIt;
      if (!IsPositiveI)
        continue;

      NodePtr Candidate = identifyNode(R, I);
      if (!Candidate)
        continue;

      RealAddends.erase(RealIt);
      ImagAddends.erase(ImagIt);
      return Candidate;
    }
  }
  return nullptr;
}
1421
1422
/// Tries to register \p RootI as a root of the deinterleaving graph.
/// Returns true when \p RootI was appended to OrderedRoots.
bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {
  auto Known = RootToNode.find(RootI);
  if (Known == RootToNode.end()) {
    // Not seen before: try to identify a graph rooted at this instruction.
    auto RootNode = identifyRoot(RootI);
    if (!RootNode)
      return false;

    LLVM_DEBUG({
      Function *F = RootI->getFunction();
      BasicBlock *B = RootI->getParent();
      dbgs() << "Complex deinterleaving graph for " << F->getName()
             << "::" << B->getName() << ".\n";
      dump(dbgs());
      dbgs() << "\n";
    });
    RootToNode[RootI] = RootNode;
    OrderedRoots.push_back(RootI);
    return true;
  }

  // This root was already recognized as a reduction. RootToNode maps both the
  // Real and the Imag instruction to the same CompositeNode, so only one of
  // them may serve as the anchor for generating the complex instruction.
  auto RootNode = Known->second;
  assert(RootNode->Operation ==
         ComplexDeinterleavingOperation::ReductionOperation);

  // The anchor is whichever part comes later; the earlier part is rejected.
  auto *R = cast<Instruction>(RootNode->Real);
  auto *I = cast<Instruction>(RootNode->Imag);
  auto *ReplacementAnchor = R->comesBefore(I) ? I : R;
  if (ReplacementAnchor != RootI)
    return false;

  OrderedRoots.push_back(RootI);
  return true;
}
1459
1460
/// Scans block \p B for vector PHI nodes that may take part in a complex
/// reduction and records them in ReductionInfo.
///
/// \p B must end in a two-successor branch with itself as one successor. A
/// PHI qualifies when it has two incoming values, is of vector type, and its
/// value incoming from \p B is an instruction with exactly two users: the PHI
/// itself and a non-PHI instruction outside of \p B.
///
/// \returns true if at least one potential reduction was recorded.
bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {
  bool FoundPotentialReduction = false;

  auto *Br = dyn_cast<BranchInst>(B->getTerminator());
  if (!Br || Br->getNumSuccessors() != 2)
    return false;

  // Identify simple one-block loop
  if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)
    return false;

  for (auto &PHI : B->phis()) {
    if (PHI.getNumIncomingValues() != 2)
      continue;

    if (!PHI.getType()->isVectorTy())
      continue;

    auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));
    if (!ReductionOp)
      continue;

    // Check if final instruction is reduced outside of current block
    Instruction *FinalReduction = nullptr;
    auto NumUsers = 0u;
    for (auto *U : ReductionOp->users()) {
      ++NumUsers;
      if (U == &PHI)
        continue;
      FinalReduction = dyn_cast<Instruction>(U);
    }

    if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||
        isa<PHINode>(FinalReduction))
      continue;

    ReductionInfo[ReductionOp] = {&PHI, FinalReduction};
    BackEdge = B;
    // The PHI has exactly two incoming blocks, so the one that is not the
    // back edge is the block entering the loop.
    auto BackEdgeIdx = PHI.getBasicBlockIndex(B);
    auto IncomingIdx = BackEdgeIdx == 0 ? 1 : 0;
    Incoming = PHI.getIncomingBlock(IncomingIdx);
    FoundPotentialReduction = true;

    // If the initial value of PHINode is an Instruction, consider it a leaf
    // value of a complex deinterleaving graph.
    if (auto *InitPHI =
            dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))
      FinalInstructions.insert(InitPHI);
  }
  return FoundPotentialReduction;
}
1512
1513
void ComplexDeinterleavingGraph::identifyReductionNodes() {
1514
SmallVector<bool> Processed(ReductionInfo.size(), false);
1515
SmallVector<Instruction *> OperationInstruction;
1516
for (auto &P : ReductionInfo)
1517
OperationInstruction.push_back(P.first);
1518
1519
// Identify a complex computation by evaluating two reduction operations that
1520
// potentially could be involved
1521
for (size_t i = 0; i < OperationInstruction.size(); ++i) {
1522
if (Processed[i])
1523
continue;
1524
for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {
1525
if (Processed[j])
1526
continue;
1527
1528
auto *Real = OperationInstruction[i];
1529
auto *Imag = OperationInstruction[j];
1530
if (Real->getType() != Imag->getType())
1531
continue;
1532
1533
RealPHI = ReductionInfo[Real].first;
1534
ImagPHI = ReductionInfo[Imag].first;
1535
PHIsFound = false;
1536
auto Node = identifyNode(Real, Imag);
1537
if (!Node) {
1538
std::swap(Real, Imag);
1539
std::swap(RealPHI, ImagPHI);
1540
Node = identifyNode(Real, Imag);
1541
}
1542
1543
// If a node is identified and reduction PHINode is used in the chain of
1544
// operations, mark its operation instructions as used to prevent
1545
// re-identification and attach the node to the real part
1546
if (Node && PHIsFound) {
1547
LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "
1548
<< *Real << " / " << *Imag << "\n");
1549
Processed[i] = true;
1550
Processed[j] = true;
1551
auto RootNode = prepareCompositeNode(
1552
ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);
1553
RootNode->addOperand(Node);
1554
RootToNode[Real] = RootNode;
1555
RootToNode[Imag] = RootNode;
1556
submitCompositeNode(RootNode);
1557
break;
1558
}
1559
}
1560
}
1561
1562
RealPHI = nullptr;
1563
ImagPHI = nullptr;
1564
}
1565
1566
bool ComplexDeinterleavingGraph::checkNodes() {
1567
// Collect all instructions from roots to leaves
1568
SmallPtrSet<Instruction *, 16> AllInstructions;
1569
SmallVector<Instruction *, 8> Worklist;
1570
for (auto &Pair : RootToNode)
1571
Worklist.push_back(Pair.first);
1572
1573
// Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG
1574
// chains
1575
while (!Worklist.empty()) {
1576
auto *I = Worklist.back();
1577
Worklist.pop_back();
1578
1579
if (!AllInstructions.insert(I).second)
1580
continue;
1581
1582
for (Value *Op : I->operands()) {
1583
if (auto *OpI = dyn_cast<Instruction>(Op)) {
1584
if (!FinalInstructions.count(I))
1585
Worklist.emplace_back(OpI);
1586
}
1587
}
1588
}
1589
1590
// Find instructions that have users outside of chain
1591
SmallVector<Instruction *, 2> OuterInstructions;
1592
for (auto *I : AllInstructions) {
1593
// Skip root nodes
1594
if (RootToNode.count(I))
1595
continue;
1596
1597
for (User *U : I->users()) {
1598
if (AllInstructions.count(cast<Instruction>(U)))
1599
continue;
1600
1601
// Found an instruction that is not used by XCMLA/XCADD chain
1602
Worklist.emplace_back(I);
1603
break;
1604
}
1605
}
1606
1607
// If any instructions are found to be used outside, find and remove roots
1608
// that somehow connect to those instructions.
1609
SmallPtrSet<Instruction *, 16> Visited;
1610
while (!Worklist.empty()) {
1611
auto *I = Worklist.back();
1612
Worklist.pop_back();
1613
if (!Visited.insert(I).second)
1614
continue;
1615
1616
// Found an impacted root node. Removing it from the nodes to be
1617
// deinterleaved
1618
if (RootToNode.count(I)) {
1619
LLVM_DEBUG(dbgs() << "Instruction " << *I
1620
<< " could be deinterleaved but its chain of complex "
1621
"operations have an outside user\n");
1622
RootToNode.erase(I);
1623
}
1624
1625
if (!AllInstructions.count(I) || FinalInstructions.count(I))
1626
continue;
1627
1628
for (User *U : I->users())
1629
Worklist.emplace_back(cast<Instruction>(U));
1630
1631
for (Value *Op : I->operands()) {
1632
if (auto *OpI = dyn_cast<Instruction>(Op))
1633
Worklist.emplace_back(OpI);
1634
}
1635
}
1636
return !RootToNode.empty();
1637
}
1638
1639
/// Identifies the graph rooted at \p RootI, which must be either a
/// vector.interleave2 intrinsic call or a shufflevector with an interleaving
/// mask, and whose two operands must both be instructions.
/// Returns nullptr when \p RootI does not have that shape.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {
  Instruction *Real = nullptr;
  Instruction *Imag = nullptr;

  if (auto *Call = dyn_cast<IntrinsicInst>(RootI)) {
    // Of all intrinsics, only vector.interleave2 can be a root.
    if (Call->getIntrinsicID() != Intrinsic::vector_interleave2)
      return nullptr;

    Real = dyn_cast<Instruction>(Call->getOperand(0));
    Imag = dyn_cast<Instruction>(Call->getOperand(1));
    if (!Real || !Imag)
      return nullptr;
  } else {
    auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);
    if (!SVI)
      return nullptr;

    // Look for a shufflevector that takes separate vectors of the real and
    // imaginary components and recombines them into a single vector.
    if (!isInterleavingMask(SVI->getShuffleMask()))
      return nullptr;

    if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))
      return nullptr;
  }

  return identifyNode(Real, Imag);
}
1669
1670
/// Recognises a pair of instructions that split one interleaved value into its
/// real and imaginary parts, either as the two extractvalue results of a
/// vector.deinterleave2 call or as two deinterleaving shufflevectors of the
/// same operand. Both instructions are recorded in FinalInstructions.
/// Returns a Deinterleave node whose ReplacementNode is the interleaved value,
/// or nullptr when the pair does not match.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,
                                                 Instruction *Imag) {
  // Intrinsic form: extractvalue 0 and extractvalue 1 of the same
  // vector.deinterleave2 call.
  Instruction *I = nullptr;
  Value *FinalValue = nullptr;
  const bool IsIntrinsicForm =
      match(Real, m_ExtractValue<0>(m_Instruction(I))) &&
      match(Imag, m_ExtractValue<1>(m_Specific(I))) &&
      match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>(
                   m_Value(FinalValue)));
  if (IsIntrinsicForm) {
    NodePtr PlaceholderNode = prepareCompositeNode(
        llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);
    PlaceholderNode->ReplacementNode = FinalValue;
    FinalInstructions.insert(Real);
    FinalInstructions.insert(Imag);
    return submitCompositeNode(PlaceholderNode);
  }

  // Shuffle form: both parts have to be shufflevector instructions.
  auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);
  auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);
  if (!RealShuffle || !ImagShuffle) {
    if (RealShuffle || ImagShuffle)
      LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");
    return nullptr;
  }

  // The second shuffle operand must be undef or zero.
  auto IsUndefOrZero = [](Value *V) {
    return isa<UndefValue>(V) || isa<ConstantAggregateZero>(V);
  };

  if (!IsUndefOrZero(RealShuffle->getOperand(1))) {
    LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");
    return nullptr;
  }
  if (!IsUndefOrZero(ImagShuffle->getOperand(1))) {
    LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");
    return nullptr;
  }

  // Both shuffles must read from the very same value.
  Value *Interleaved = RealShuffle->getOperand(0);
  if (Interleaved != ImagShuffle->getOperand(0)) {
    LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");
    return nullptr;
  }

  ArrayRef<int> RealMask = RealShuffle->getShuffleMask();
  ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();
  if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {
    LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");
    return nullptr;
  }

  // The real part starts at element 0, the imaginary part at element 1.
  if (RealMask[0] != 0 || ImagMask[0] != 1) {
    LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");
    return nullptr;
  }

  // A valid deinterleaving shuffle produces a vector with the same scalar
  // type as its first operand but half as many elements, and only pulls
  // elements from that first operand.
  auto IsValidDeinterleavingShuffle = [](ShuffleVectorInst *Shuffle) -> bool {
    auto *ResultTy = cast<FixedVectorType>(Shuffle->getType());
    auto *SourceTy = cast<FixedVectorType>(Shuffle->getOperand(0)->getType());

    if (SourceTy->getScalarType() != ResultTy->getScalarType())
      return false;
    if ((ResultTy->getNumElements() * 2) != SourceTy->getNumElements())
      return false;

    ArrayRef<int> Mask = Shuffle->getShuffleMask();
    int Last = *Mask.rbegin();
    int NumElements = SourceTy->getNumElements();
    return Last < NumElements;
  };

  if (RealShuffle->getType() != ImagShuffle->getType()) {
    LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");
    return nullptr;
  }
  if (!IsValidDeinterleavingShuffle(RealShuffle)) {
    LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");
    return nullptr;
  }
  if (!IsValidDeinterleavingShuffle(ImagShuffle)) {
    LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");
    return nullptr;
  }

  NodePtr PlaceholderNode =
      prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,
                           RealShuffle, ImagShuffle);
  PlaceholderNode->ReplacementNode = Interleaved;
  FinalInstructions.insert(RealShuffle);
  FinalInstructions.insert(ImagShuffle);
  return submitCompositeNode(PlaceholderNode);
}
1778
1779
/// Recognises the case where both \p R and \p I are splats and wraps them in
/// a Splat node. Returns nullptr when either value is not a splat, when only
/// one of them is an instruction, or when the two instructions live in
/// different basic blocks.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {
  auto IsSplat = [](Value *V) -> bool {
    // Fixed-width vector with constants
    if (isa<ConstantDataVector>(V))
      return true;

    // Otherwise the value has to be a shufflevector, either as a constant
    // expression or as an instruction.
    VectorType *VTy = nullptr;
    ArrayRef<int> Mask;
    if (auto *Const = dyn_cast<ConstantExpr>(V)) {
      if (Const->getOpcode() != Instruction::ShuffleVector)
        return false;
      VTy = cast<VectorType>(Const->getType());
      Mask = Const->getShuffleMask();
    } else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {
      VTy = Shuf->getType();
      Mask = Shuf->getShuffleMask();
    } else {
      return false;
    }

    // When the data type is <1 x Type>, it's not possible to differentiate
    // between the ComplexDeinterleaving::Deinterleave and
    // ComplexDeinterleaving::Splat operations.
    const bool IsSingleElementFixed =
        !VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1;
    if (IsSingleElementFixed)
      return false;

    // Every mask element must select element 0.
    return all_equal(Mask) && Mask[0] == 0;
  };

  if (!IsSplat(R) || !IsSplat(I))
    return nullptr;

  // Either both parts are instructions or neither of them is.
  auto *Real = dyn_cast<Instruction>(R);
  auto *Imag = dyn_cast<Instruction>(I);
  if ((Real == nullptr) != (Imag == nullptr))
    return nullptr;

  if (Real && Imag) {
    // Non-constant splats should be in the same basic block
    if (Real->getParent() != Imag->getParent())
      return nullptr;

    FinalInstructions.insert(Real);
    FinalInstructions.insert(Imag);
  }

  NodePtr PlaceholderNode =
      prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);
  return submitCompositeNode(PlaceholderNode);
}
1831
1832
/// Recognise the (\p Real, \p Imag) pair as the reduction PHIs recorded in
/// RealPHI / ImagPHI. On a match, PHIsFound is set and a ReductionPHI
/// composite node is submitted; otherwise nullptr is returned.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,
                                            Instruction *Imag) {
  const bool MatchesRecordedPHIs = Real == RealPHI && Imag == ImagPHI;
  if (!MatchesRecordedPHIs)
    return nullptr;

  PHIsFound = true;
  NodePtr PHIPairNode = prepareCompositeNode(
      ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);
  return submitCompositeNode(PHIPairNode);
}
1843
1844
/// Recognise a pair of vector selects (\p Real, \p Imag) that share the same
/// (or an identical) mask instruction and whose true/false operands each form
/// an identifiable complex node. On success a ReductionSelect composite node
/// with those two operand nodes is submitted; otherwise nullptr is returned.
ComplexDeinterleavingGraph::NodePtr
ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,
                                               Instruction *Imag) {
  if (!isa<SelectInst>(Real) || !isa<SelectInst>(Imag))
    return nullptr;

  // Condition, true value and false value must all be instructions.
  Instruction *MaskA, *AR, *BR;
  if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),
                            m_Instruction(BR))))
    return nullptr;

  Instruction *MaskB, *AI, *BI;
  if (!match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),
                            m_Instruction(BI))))
    return nullptr;

  // Both selects have to be steered by the same mask, or by two instructions
  // that are identical to each other.
  const bool MasksAgree = MaskA == MaskB || MaskA->isIdenticalTo(MaskB);
  if (!MasksAgree)
    return nullptr;

  // Only vector masks are handled.
  if (!MaskA->getType()->isVectorTy())
    return nullptr;

  auto TrueNode = identifyNode(AR, AI);
  if (!TrueNode)
    return nullptr;

  auto FalseNode = identifyNode(BR, BI);
  if (!FalseNode)
    return nullptr;

  NodePtr SelectNode = prepareCompositeNode(
      ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);
  SelectNode->addOperand(TrueNode);
  SelectNode->addOperand(FalseNode);
  FinalInstructions.insert(MaskA);
  FinalInstructions.insert(MaskB);
  return submitCompositeNode(SelectNode);
}
1882
1883
/// Emit, through \p B, the operation selected by \p Opcode.
///
/// \param B       Builder used to create the operation.
/// \param Opcode  One of FNeg, FAdd, Add, FSub, Sub, FMul or Mul; any other
///                opcode hits llvm_unreachable.
/// \param Flags   Optional fast-math flags to attach to the result.
/// \param InputA  First operand (the only operand used for FNeg).
/// \param InputB  Second operand; not used for FNeg.
/// \returns the value produced by the builder.
static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,
                                   std::optional<FastMathFlags> Flags,
                                   Value *InputA, Value *InputB) {
  Value *I = nullptr;
  switch (Opcode) {
  case Instruction::FNeg:
    I = B.CreateFNeg(InputA);
    break;
  case Instruction::FAdd:
    I = B.CreateFAdd(InputA, InputB);
    break;
  case Instruction::Add:
    I = B.CreateAdd(InputA, InputB);
    break;
  case Instruction::FSub:
    I = B.CreateFSub(InputA, InputB);
    break;
  case Instruction::Sub:
    I = B.CreateSub(InputA, InputB);
    break;
  case Instruction::FMul:
    I = B.CreateFMul(InputA, InputB);
    break;
  case Instruction::Mul:
    I = B.CreateMul(InputA, InputB);
    break;
  default:
    llvm_unreachable("Incorrect symmetric opcode");
  }

  // The builder is allowed to hand back a value that is not an Instruction
  // (for example a folded constant), so only attach the flags when there is
  // an instruction to carry them.
  if (Flags)
    if (auto *Inst = dyn_cast<Instruction>(I))
      Inst->setFastMathFlags(*Flags);
  return I;
}
1916
1917
/// Produce the replacement value for \p Node, recursing into its operands as
/// needed, and cache it in Node->ReplacementNode.
///
/// \param Builder  Builder used for most of the newly created IR. Non-constant
///                 splats and reduction PHIs are placed at their own insertion
///                 points instead.
/// \param Node     Composite node to replace.
/// \returns the cached replacement if one already exists, otherwise the newly
///          created value.
Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,
                                               RawNodePtr Node) {
  if (Node->ReplacementNode)
    return Node->ReplacementNode;

  // Replace operand number Idx of Node, or yield nullptr if Node has fewer
  // operands than that.
  auto ReplaceOperandIfExist = [&](unsigned Idx) -> Value * {
    return Node->Operands.size() > Idx
               ? replaceNode(Builder, Node->Operands[Idx])
               : nullptr;
  };

  // Start from nullptr so the assertion after the switch is meaningful even
  // if some path fails to assign a replacement.
  Value *ReplacementNode = nullptr;
  switch (Node->Operation) {
  case ComplexDeinterleavingOperation::CAdd:
  case ComplexDeinterleavingOperation::CMulPartial:
  case ComplexDeinterleavingOperation::Symmetric: {
    Value *Input0 = ReplaceOperandIfExist(0);
    Value *Input1 = ReplaceOperandIfExist(1);
    Value *Accumulator = ReplaceOperandIfExist(2);
    assert((!Input1 || Input0->getType() == Input1->getType()) &&
           "Node inputs need to be of the same type");
    assert((!Accumulator || Input0->getType() == Accumulator->getType()) &&
           "Accumulator and input need to be of the same type");
    if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)
      ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,
                                             Input0, Input1);
    else
      ReplacementNode = TL->createComplexDeinterleavingIR(
          Builder, Node->Operation, Node->Rotation, Input0, Input1,
          Accumulator);
    break;
  }
  case ComplexDeinterleavingOperation::Deinterleave:
    llvm_unreachable("Deinterleave node should already have ReplacementNode");
  case ComplexDeinterleavingOperation::Splat: {
    auto *NewTy = VectorType::getDoubleElementsVectorType(
        cast<VectorType>(Node->Real->getType()));
    auto *R = dyn_cast<Instruction>(Node->Real);
    auto *I = dyn_cast<Instruction>(Node->Imag);
    if (R && I) {
      // Splats that are not constant are interleaved where they are located:
      // right after whichever of the two instructions comes last.
      Instruction *InsertPoint = (I->comesBefore(R) ? R : I)->getNextNode();
      IRBuilder<> IRB(InsertPoint);
      ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2,
                                            NewTy, {Node->Real, Node->Imag});
    } else {
      ReplacementNode = Builder.CreateIntrinsic(
          Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag});
    }
    break;
  }
  case ComplexDeinterleavingOperation::ReductionPHI: {
    // If Operation is ReductionPHI, a new empty PHINode is created.
    // It is filled later when the ReductionOperation is processed.
    auto *VTy = cast<VectorType>(Node->Real->getType());
    auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
    auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());
    // The real part of a ReductionPHI node is expected to be a PHINode; use
    // cast<> so a violation asserts instead of silently keying on nullptr.
    OldToNewPHI[cast<PHINode>(Node->Real)] = NewPHI;
    ReplacementNode = NewPHI;
    break;
  }
  case ComplexDeinterleavingOperation::ReductionOperation:
    ReplacementNode = replaceNode(Builder, Node->Operands[0]);
    processReductionOperation(ReplacementNode, Node);
    break;
  case ComplexDeinterleavingOperation::ReductionSelect: {
    auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);
    auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);
    auto *A = replaceNode(Builder, Node->Operands[0]);
    auto *B = replaceNode(Builder, Node->Operands[1]);
    // The select operates on interleaved data, so its mask is built by
    // interleaving the masks of the two original selects.
    auto *NewMaskTy = VectorType::getDoubleElementsVectorType(
        cast<VectorType>(MaskReal->getType()));
    auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2,
                                            NewMaskTy, {MaskReal, MaskImag});
    ReplacementNode = Builder.CreateSelect(NewMask, A, B);
    break;
  }
  }

  assert(ReplacementNode && "Target failed to create Intrinsic call.");
  NumComplexTransformations += 1;
  Node->ReplacementNode = ReplacementNode;
  return ReplacementNode;
}
2003
2004
void ComplexDeinterleavingGraph::processReductionOperation(
2005
Value *OperationReplacement, RawNodePtr Node) {
2006
auto *Real = cast<Instruction>(Node->Real);
2007
auto *Imag = cast<Instruction>(Node->Imag);
2008
auto *OldPHIReal = ReductionInfo[Real].first;
2009
auto *OldPHIImag = ReductionInfo[Imag].first;
2010
auto *NewPHI = OldToNewPHI[OldPHIReal];
2011
2012
auto *VTy = cast<VectorType>(Real->getType());
2013
auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);
2014
2015
// We have to interleave initial origin values coming from IncomingBlock
2016
Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);
2017
Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);
2018
2019
IRBuilder<> Builder(Incoming->getTerminator());
2020
auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,
2021
{InitReal, InitImag});
2022
2023
NewPHI->addIncoming(NewInit, Incoming);
2024
NewPHI->addIncoming(OperationReplacement, BackEdge);
2025
2026
// Deinterleave complex vector outside of loop so that it can be finally
2027
// reduced
2028
auto *FinalReductionReal = ReductionInfo[Real].second;
2029
auto *FinalReductionImag = ReductionInfo[Imag].second;
2030
2031
Builder.SetInsertPoint(
2032
&*FinalReductionReal->getParent()->getFirstInsertionPt());
2033
auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,
2034
OperationReplacement->getType(),
2035
OperationReplacement);
2036
2037
auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);
2038
FinalReductionReal->replaceUsesOfWith(Real, NewReal);
2039
2040
Builder.SetInsertPoint(FinalReductionImag);
2041
auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);
2042
FinalReductionImag->replaceUsesOfWith(Imag, NewImag);
2043
}
2044
2045
void ComplexDeinterleavingGraph::replaceNodes() {
2046
SmallVector<Instruction *, 16> DeadInstrRoots;
2047
for (auto *RootInstruction : OrderedRoots) {
2048
// Check if this potential root went through check process and we can
2049
// deinterleave it
2050
if (!RootToNode.count(RootInstruction))
2051
continue;
2052
2053
IRBuilder<> Builder(RootInstruction);
2054
auto RootNode = RootToNode[RootInstruction];
2055
Value *R = replaceNode(Builder, RootNode.get());
2056
2057
if (RootNode->Operation ==
2058
ComplexDeinterleavingOperation::ReductionOperation) {
2059
auto *RootReal = cast<Instruction>(RootNode->Real);
2060
auto *RootImag = cast<Instruction>(RootNode->Imag);
2061
ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);
2062
ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);
2063
DeadInstrRoots.push_back(cast<Instruction>(RootReal));
2064
DeadInstrRoots.push_back(cast<Instruction>(RootImag));
2065
} else {
2066
assert(R && "Unable to find replacement for RootInstruction");
2067
DeadInstrRoots.push_back(RootInstruction);
2068
RootInstruction->replaceAllUsesWith(R);
2069
}
2070
}
2071
2072
for (auto *I : DeadInstrRoots)
2073
RecursivelyDeleteTriviallyDeadInstructions(I, TLI);
2074
}
2075
2076