// Path: blob/main/contrib/llvm-project/llvm/lib/CodeGen/ComplexDeinterleavingPass.cpp
// 35234 views
//===- ComplexDeinterleavingPass.cpp --------------------------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// Identification:9// This step is responsible for finding the patterns that can be lowered to10// complex instructions, and building a graph to represent the complex11// structures. Starting from the "Converging Shuffle" (a shuffle that12// reinterleaves the complex components, with a mask of <0, 2, 1, 3>), the13// operands are evaluated and identified as "Composite Nodes" (collections of14// instructions that can potentially be lowered to a single complex15// instruction). This is performed by checking the real and imaginary components16// and tracking the data flow for each component while following the operand17// pairs. Validity of each node is expected to be done upon creation, and any18// validation errors should halt traversal and prevent further graph19// construction.20// Instead of relying on Shuffle operations, vector interleaving and21// deinterleaving can be represented by vector.interleave2 and22// vector.deinterleave2 intrinsics. Scalable vectors can be represented only by23// these intrinsics, whereas, fixed-width vectors are recognized for both24// shufflevector instruction and intrinsics.25//26// Replacement:27// This step traverses the graph built up by identification, delegating to the28// target to validate and generate the correct intrinsics, and plumbs them29// together connecting each end of the new intrinsics graph to the existing30// use-def chain. 
This step is assumed to finish successfully, as all31// information is expected to be correct by this point.32//33//34// Internal data structure:35// ComplexDeinterleavingGraph:36// Keeps references to all the valid CompositeNodes formed as part of the37// transformation, and every Instruction contained within said nodes. It also38// holds onto a reference to the root Instruction, and the root node that should39// replace it.40//41// ComplexDeinterleavingCompositeNode:42// A CompositeNode represents a single transformation point; each node should43// transform into a single complex instruction (ignoring vector splitting, which44// would generate more instructions per node). They are identified in a45// depth-first manner, traversing and identifying the operands of each46// instruction in the order they appear in the IR.47// Each node maintains a reference to its Real and Imaginary instructions,48// as well as any additional instructions that make up the identified operation49// (Internal instructions should only have uses within their containing node).50// A Node also contains the rotation and operation type that it represents.51// Operands contains pointers to other CompositeNodes, acting as the edges in52// the graph. 
ReplacementValue is the transformed Value* that has been emitted53// to the IR.54//55// Note: If the operation of a Node is Shuffle, only the Real, Imaginary, and56// ReplacementValue fields of that Node are relevant, where the ReplacementValue57// should be pre-populated.58//59//===----------------------------------------------------------------------===//6061#include "llvm/CodeGen/ComplexDeinterleavingPass.h"62#include "llvm/ADT/MapVector.h"63#include "llvm/ADT/Statistic.h"64#include "llvm/Analysis/TargetLibraryInfo.h"65#include "llvm/Analysis/TargetTransformInfo.h"66#include "llvm/CodeGen/TargetLowering.h"67#include "llvm/CodeGen/TargetPassConfig.h"68#include "llvm/CodeGen/TargetSubtargetInfo.h"69#include "llvm/IR/IRBuilder.h"70#include "llvm/IR/PatternMatch.h"71#include "llvm/InitializePasses.h"72#include "llvm/Target/TargetMachine.h"73#include "llvm/Transforms/Utils/Local.h"74#include <algorithm>7576using namespace llvm;77using namespace PatternMatch;7879#define DEBUG_TYPE "complex-deinterleaving"8081STATISTIC(NumComplexTransformations, "Amount of complex patterns transformed");8283static cl::opt<bool> ComplexDeinterleavingEnabled(84"enable-complex-deinterleaving",85cl::desc("Enable generation of complex instructions"), cl::init(true),86cl::Hidden);8788/// Checks the given mask, and determines whether said mask is interleaving.89///90/// To be interleaving, a mask must alternate between `i` and `i + (Length /91/// 2)`, and must contain all numbers within the range of `[0..Length)` (e.g. a92/// 4x vector interleaving mask would be <0, 2, 1, 3>).93static bool isInterleavingMask(ArrayRef<int> Mask);9495/// Checks the given mask, and determines whether said mask is deinterleaving.96///97/// To be deinterleaving, a mask must increment in steps of 2, and either start98/// with 0 or 1.99/// (e.g. 
an 8x vector deinterleaving mask would be either <0, 2, 4, 6> or100/// <1, 3, 5, 7>).101static bool isDeinterleavingMask(ArrayRef<int> Mask);102103/// Returns true if the operation is a negation of V, and it works for both104/// integers and floats.105static bool isNeg(Value *V);106107/// Returns the operand for negation operation.108static Value *getNegOperand(Value *V);109110namespace {111112class ComplexDeinterleavingLegacyPass : public FunctionPass {113public:114static char ID;115116ComplexDeinterleavingLegacyPass(const TargetMachine *TM = nullptr)117: FunctionPass(ID), TM(TM) {118initializeComplexDeinterleavingLegacyPassPass(119*PassRegistry::getPassRegistry());120}121122StringRef getPassName() const override {123return "Complex Deinterleaving Pass";124}125126bool runOnFunction(Function &F) override;127void getAnalysisUsage(AnalysisUsage &AU) const override {128AU.addRequired<TargetLibraryInfoWrapperPass>();129AU.setPreservesCFG();130}131132private:133const TargetMachine *TM;134};135136class ComplexDeinterleavingGraph;137struct ComplexDeinterleavingCompositeNode {138139ComplexDeinterleavingCompositeNode(ComplexDeinterleavingOperation Op,140Value *R, Value *I)141: Operation(Op), Real(R), Imag(I) {}142143private:144friend class ComplexDeinterleavingGraph;145using NodePtr = std::shared_ptr<ComplexDeinterleavingCompositeNode>;146using RawNodePtr = ComplexDeinterleavingCompositeNode *;147148public:149ComplexDeinterleavingOperation Operation;150Value *Real;151Value *Imag;152153// This two members are required exclusively for generating154// ComplexDeinterleavingOperation::Symmetric operations.155unsigned Opcode;156std::optional<FastMathFlags> Flags;157158ComplexDeinterleavingRotation Rotation =159ComplexDeinterleavingRotation::Rotation_0;160SmallVector<RawNodePtr> Operands;161Value *ReplacementNode = nullptr;162163void addOperand(NodePtr Node) { Operands.push_back(Node.get()); }164165void dump() { dump(dbgs()); }166void dump(raw_ostream &OS) {167auto PrintValue = 
[&](Value *V) {168if (V) {169OS << "\"";170V->print(OS, true);171OS << "\"\n";172} else173OS << "nullptr\n";174};175auto PrintNodeRef = [&](RawNodePtr Ptr) {176if (Ptr)177OS << Ptr << "\n";178else179OS << "nullptr\n";180};181182OS << "- CompositeNode: " << this << "\n";183OS << " Real: ";184PrintValue(Real);185OS << " Imag: ";186PrintValue(Imag);187OS << " ReplacementNode: ";188PrintValue(ReplacementNode);189OS << " Operation: " << (int)Operation << "\n";190OS << " Rotation: " << ((int)Rotation * 90) << "\n";191OS << " Operands: \n";192for (const auto &Op : Operands) {193OS << " - ";194PrintNodeRef(Op);195}196}197};198199class ComplexDeinterleavingGraph {200public:201struct Product {202Value *Multiplier;203Value *Multiplicand;204bool IsPositive;205};206207using Addend = std::pair<Value *, bool>;208using NodePtr = ComplexDeinterleavingCompositeNode::NodePtr;209using RawNodePtr = ComplexDeinterleavingCompositeNode::RawNodePtr;210211// Helper struct for holding info about potential partial multiplication212// candidates213struct PartialMulCandidate {214Value *Common;215NodePtr Node;216unsigned RealIdx;217unsigned ImagIdx;218bool IsNodeInverted;219};220221explicit ComplexDeinterleavingGraph(const TargetLowering *TL,222const TargetLibraryInfo *TLI)223: TL(TL), TLI(TLI) {}224225private:226const TargetLowering *TL = nullptr;227const TargetLibraryInfo *TLI = nullptr;228SmallVector<NodePtr> CompositeNodes;229DenseMap<std::pair<Value *, Value *>, NodePtr> CachedResult;230231SmallPtrSet<Instruction *, 16> FinalInstructions;232233/// Root instructions are instructions from which complex computation starts234std::map<Instruction *, NodePtr> RootToNode;235236/// Topologically sorted root instructions237SmallVector<Instruction *, 1> OrderedRoots;238239/// When examining a basic block for complex deinterleaving, if it is a simple240/// one-block loop, then the only incoming block is 'Incoming' and the241/// 'BackEdge' block is the block itself."242BasicBlock *BackEdge = 
nullptr;243BasicBlock *Incoming = nullptr;244245/// ReductionInfo maps from %ReductionOp to %PHInode and Instruction246/// %OutsideUser as it is shown in the IR:247///248/// vector.body:249/// %PHInode = phi <vector type> [ zeroinitializer, %entry ],250/// [ %ReductionOp, %vector.body ]251/// ...252/// %ReductionOp = fadd i64 ...253/// ...254/// br i1 %condition, label %vector.body, %middle.block255///256/// middle.block:257/// %OutsideUser = llvm.vector.reduce.fadd(..., %ReductionOp)258///259/// %OutsideUser can be `llvm.vector.reduce.fadd` or `fadd` preceding260/// `llvm.vector.reduce.fadd` when unroll factor isn't one.261MapVector<Instruction *, std::pair<PHINode *, Instruction *>> ReductionInfo;262263/// In the process of detecting a reduction, we consider a pair of264/// %ReductionOP, which we refer to as real and imag (or vice versa), and265/// traverse the use-tree to detect complex operations. As this is a reduction266/// operation, it will eventually reach RealPHI and ImagPHI, which corresponds267/// to the %ReductionOPs that we suspect to be complex.268/// RealPHI and ImagPHI are used by the identifyPHINode method.269PHINode *RealPHI = nullptr;270PHINode *ImagPHI = nullptr;271272/// Set this flag to true if RealPHI and ImagPHI were reached during reduction273/// detection.274bool PHIsFound = false;275276/// OldToNewPHI maps the original real PHINode to a new, double-sized PHINode.277/// The new PHINode corresponds to a vector of deinterleaved complex numbers.278/// This mapping is populated during279/// ComplexDeinterleavingOperation::ReductionPHI node replacement. 
It is then280/// used in the ComplexDeinterleavingOperation::ReductionOperation node281/// replacement process.282std::map<PHINode *, PHINode *> OldToNewPHI;283284NodePtr prepareCompositeNode(ComplexDeinterleavingOperation Operation,285Value *R, Value *I) {286assert(((Operation != ComplexDeinterleavingOperation::ReductionPHI &&287Operation != ComplexDeinterleavingOperation::ReductionOperation) ||288(R && I)) &&289"Reduction related nodes must have Real and Imaginary parts");290return std::make_shared<ComplexDeinterleavingCompositeNode>(Operation, R,291I);292}293294NodePtr submitCompositeNode(NodePtr Node) {295CompositeNodes.push_back(Node);296if (Node->Real && Node->Imag)297CachedResult[{Node->Real, Node->Imag}] = Node;298return Node;299}300301/// Identifies a complex partial multiply pattern and its rotation, based on302/// the following patterns303///304/// 0: r: cr + ar * br305/// i: ci + ar * bi306/// 90: r: cr - ai * bi307/// i: ci + ai * br308/// 180: r: cr - ar * br309/// i: ci - ar * bi310/// 270: r: cr + ai * bi311/// i: ci - ai * br312NodePtr identifyPartialMul(Instruction *Real, Instruction *Imag);313314/// Identify the other branch of a Partial Mul, taking the CommonOperandI that315/// is partially known from identifyPartialMul, filling in the other half of316/// the complex pair.317NodePtr318identifyNodeWithImplicitAdd(Instruction *I, Instruction *J,319std::pair<Value *, Value *> &CommonOperandI);320321/// Identifies a complex add pattern and its rotation, based on the following322/// patterns.323///324/// 90: r: ar - bi325/// i: ai + br326/// 270: r: ar + bi327/// i: ai - br328NodePtr identifyAdd(Instruction *Real, Instruction *Imag);329NodePtr identifySymmetricOperation(Instruction *Real, Instruction *Imag);330331NodePtr identifyNode(Value *R, Value *I);332333/// Determine if a sum of complex numbers can be formed from \p RealAddends334/// and \p ImagAddens. 
If \p Accumulator is not null, add the result to it.335/// Return nullptr if it is not possible to construct a complex number.336/// \p Flags are needed to generate symmetric Add and Sub operations.337NodePtr identifyAdditions(std::list<Addend> &RealAddends,338std::list<Addend> &ImagAddends,339std::optional<FastMathFlags> Flags,340NodePtr Accumulator);341342/// Extract one addend that have both real and imaginary parts positive.343NodePtr extractPositiveAddend(std::list<Addend> &RealAddends,344std::list<Addend> &ImagAddends);345346/// Determine if sum of multiplications of complex numbers can be formed from347/// \p RealMuls and \p ImagMuls. If \p Accumulator is not null, add the result348/// to it. Return nullptr if it is not possible to construct a complex number.349NodePtr identifyMultiplications(std::vector<Product> &RealMuls,350std::vector<Product> &ImagMuls,351NodePtr Accumulator);352353/// Go through pairs of multiplication (one Real and one Imag) and find all354/// possible candidates for partial multiplication and put them into \p355/// Candidates. Returns true if all Product has pair with common operand356bool collectPartialMuls(const std::vector<Product> &RealMuls,357const std::vector<Product> &ImagMuls,358std::vector<PartialMulCandidate> &Candidates);359360/// If the code is compiled with -Ofast or expressions have `reassoc` flag,361/// the order of complex computation operations may be significantly altered,362/// and the real and imaginary parts may not be executed in parallel. This363/// function takes this into consideration and employs a more general approach364/// to identify complex computations. Initially, it gathers all the addends365/// and multiplicands and then constructs a complex expression from them.366NodePtr identifyReassocNodes(Instruction *I, Instruction *J);367368NodePtr identifyRoot(Instruction *I);369370/// Identifies the Deinterleave operation applied to a vector containing371/// complex numbers. 
There are two ways to represent the Deinterleave372/// operation:373/// * Using two shufflevectors with even indices for /pReal instruction and374/// odd indices for /pImag instructions (only for fixed-width vectors)375/// * Using two extractvalue instructions applied to `vector.deinterleave2`376/// intrinsic (for both fixed and scalable vectors)377NodePtr identifyDeinterleave(Instruction *Real, Instruction *Imag);378379/// identifying the operation that represents a complex number repeated in a380/// Splat vector. There are two possible types of splats: ConstantExpr with381/// the opcode ShuffleVector and ShuffleVectorInstr. Both should have an382/// initialization mask with all values set to zero.383NodePtr identifySplat(Value *Real, Value *Imag);384385NodePtr identifyPHINode(Instruction *Real, Instruction *Imag);386387/// Identifies SelectInsts in a loop that has reduction with predication masks388/// and/or predicated tail folding389NodePtr identifySelectNode(Instruction *Real, Instruction *Imag);390391Value *replaceNode(IRBuilderBase &Builder, RawNodePtr Node);392393/// Complete IR modifications after producing new reduction operation:394/// * Populate the PHINode generated for395/// ComplexDeinterleavingOperation::ReductionPHI396/// * Deinterleave the final value outside of the loop and repurpose original397/// reduction users398void processReductionOperation(Value *OperationReplacement, RawNodePtr Node);399400public:401void dump() { dump(dbgs()); }402void dump(raw_ostream &OS) {403for (const auto &Node : CompositeNodes)404Node->dump(OS);405}406407/// Returns false if the deinterleaving operation should be cancelled for the408/// current graph.409bool identifyNodes(Instruction *RootI);410411/// In case \pB is one-block loop, this function seeks potential reductions412/// and populates ReductionInfo. 
Returns true if any reductions were413/// identified.414bool collectPotentialReductions(BasicBlock *B);415416void identifyReductionNodes();417418/// Check that every instruction, from the roots to the leaves, has internal419/// uses.420bool checkNodes();421422/// Perform the actual replacement of the underlying instruction graph.423void replaceNodes();424};425426class ComplexDeinterleaving {427public:428ComplexDeinterleaving(const TargetLowering *tl, const TargetLibraryInfo *tli)429: TL(tl), TLI(tli) {}430bool runOnFunction(Function &F);431432private:433bool evaluateBasicBlock(BasicBlock *B);434435const TargetLowering *TL = nullptr;436const TargetLibraryInfo *TLI = nullptr;437};438439} // namespace440441char ComplexDeinterleavingLegacyPass::ID = 0;442443INITIALIZE_PASS_BEGIN(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,444"Complex Deinterleaving", false, false)445INITIALIZE_PASS_END(ComplexDeinterleavingLegacyPass, DEBUG_TYPE,446"Complex Deinterleaving", false, false)447448PreservedAnalyses ComplexDeinterleavingPass::run(Function &F,449FunctionAnalysisManager &AM) {450const TargetLowering *TL = TM->getSubtargetImpl(F)->getTargetLowering();451auto &TLI = AM.getResult<llvm::TargetLibraryAnalysis>(F);452if (!ComplexDeinterleaving(TL, &TLI).runOnFunction(F))453return PreservedAnalyses::all();454455PreservedAnalyses PA;456PA.preserve<FunctionAnalysisManagerModuleProxy>();457return PA;458}459460FunctionPass *llvm::createComplexDeinterleavingPass(const TargetMachine *TM) {461return new ComplexDeinterleavingLegacyPass(TM);462}463464bool ComplexDeinterleavingLegacyPass::runOnFunction(Function &F) {465const auto *TL = TM->getSubtargetImpl(F)->getTargetLowering();466auto TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);467return ComplexDeinterleaving(TL, &TLI).runOnFunction(F);468}469470bool ComplexDeinterleaving::runOnFunction(Function &F) {471if (!ComplexDeinterleavingEnabled) {472LLVM_DEBUG(473dbgs() << "Complex deinterleaving has been explicitly 
disabled.\n");474return false;475}476477if (!TL->isComplexDeinterleavingSupported()) {478LLVM_DEBUG(479dbgs() << "Complex deinterleaving has been disabled, target does "480"not support lowering of complex number operations.\n");481return false;482}483484bool Changed = false;485for (auto &B : F)486Changed |= evaluateBasicBlock(&B);487488return Changed;489}490491static bool isInterleavingMask(ArrayRef<int> Mask) {492// If the size is not even, it's not an interleaving mask493if ((Mask.size() & 1))494return false;495496int HalfNumElements = Mask.size() / 2;497for (int Idx = 0; Idx < HalfNumElements; ++Idx) {498int MaskIdx = Idx * 2;499if (Mask[MaskIdx] != Idx || Mask[MaskIdx + 1] != (Idx + HalfNumElements))500return false;501}502503return true;504}505506static bool isDeinterleavingMask(ArrayRef<int> Mask) {507int Offset = Mask[0];508int HalfNumElements = Mask.size() / 2;509510for (int Idx = 1; Idx < HalfNumElements; ++Idx) {511if (Mask[Idx] != (Idx * 2) + Offset)512return false;513}514515return true;516}517518bool isNeg(Value *V) {519return match(V, m_FNeg(m_Value())) || match(V, m_Neg(m_Value()));520}521522Value *getNegOperand(Value *V) {523assert(isNeg(V));524auto *I = cast<Instruction>(V);525if (I->getOpcode() == Instruction::FNeg)526return I->getOperand(0);527528return I->getOperand(1);529}530531bool ComplexDeinterleaving::evaluateBasicBlock(BasicBlock *B) {532ComplexDeinterleavingGraph Graph(TL, TLI);533if (Graph.collectPotentialReductions(B))534Graph.identifyReductionNodes();535536for (auto &I : *B)537Graph.identifyNodes(&I);538539if (Graph.checkNodes()) {540Graph.replaceNodes();541return true;542}543544return false;545}546547ComplexDeinterleavingGraph::NodePtr548ComplexDeinterleavingGraph::identifyNodeWithImplicitAdd(549Instruction *Real, Instruction *Imag,550std::pair<Value *, Value *> &PartialMatch) {551LLVM_DEBUG(dbgs() << "identifyNodeWithImplicitAdd " << *Real << " / " << *Imag552<< "\n");553554if (!Real->hasOneUse() || !Imag->hasOneUse()) 
{555LLVM_DEBUG(dbgs() << " - Mul operand has multiple uses.\n");556return nullptr;557}558559if ((Real->getOpcode() != Instruction::FMul &&560Real->getOpcode() != Instruction::Mul) ||561(Imag->getOpcode() != Instruction::FMul &&562Imag->getOpcode() != Instruction::Mul)) {563LLVM_DEBUG(564dbgs() << " - Real or imaginary instruction is not fmul or mul\n");565return nullptr;566}567568Value *R0 = Real->getOperand(0);569Value *R1 = Real->getOperand(1);570Value *I0 = Imag->getOperand(0);571Value *I1 = Imag->getOperand(1);572573// A +/+ has a rotation of 0. If any of the operands are fneg, we flip the574// rotations and use the operand.575unsigned Negs = 0;576Value *Op;577if (match(R0, m_Neg(m_Value(Op)))) {578Negs |= 1;579R0 = Op;580} else if (match(R1, m_Neg(m_Value(Op)))) {581Negs |= 1;582R1 = Op;583}584585if (isNeg(I0)) {586Negs |= 2;587Negs ^= 1;588I0 = Op;589} else if (match(I1, m_Neg(m_Value(Op)))) {590Negs |= 2;591Negs ^= 1;592I1 = Op;593}594595ComplexDeinterleavingRotation Rotation = (ComplexDeinterleavingRotation)Negs;596597Value *CommonOperand;598Value *UncommonRealOp;599Value *UncommonImagOp;600601if (R0 == I0 || R0 == I1) {602CommonOperand = R0;603UncommonRealOp = R1;604} else if (R1 == I0 || R1 == I1) {605CommonOperand = R1;606UncommonRealOp = R0;607} else {608LLVM_DEBUG(dbgs() << " - No equal operand\n");609return nullptr;610}611612UncommonImagOp = (CommonOperand == I0) ? 
I1 : I0;613if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||614Rotation == ComplexDeinterleavingRotation::Rotation_270)615std::swap(UncommonRealOp, UncommonImagOp);616617// Between identifyPartialMul and here we need to have found a complete valid618// pair from the CommonOperand of each part.619if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||620Rotation == ComplexDeinterleavingRotation::Rotation_180)621PartialMatch.first = CommonOperand;622else623PartialMatch.second = CommonOperand;624625if (!PartialMatch.first || !PartialMatch.second) {626LLVM_DEBUG(dbgs() << " - Incomplete partial match\n");627return nullptr;628}629630NodePtr CommonNode = identifyNode(PartialMatch.first, PartialMatch.second);631if (!CommonNode) {632LLVM_DEBUG(dbgs() << " - No CommonNode identified\n");633return nullptr;634}635636NodePtr UncommonNode = identifyNode(UncommonRealOp, UncommonImagOp);637if (!UncommonNode) {638LLVM_DEBUG(dbgs() << " - No UncommonNode identified\n");639return nullptr;640}641642NodePtr Node = prepareCompositeNode(643ComplexDeinterleavingOperation::CMulPartial, Real, Imag);644Node->Rotation = Rotation;645Node->addOperand(CommonNode);646Node->addOperand(UncommonNode);647return submitCompositeNode(Node);648}649650ComplexDeinterleavingGraph::NodePtr651ComplexDeinterleavingGraph::identifyPartialMul(Instruction *Real,652Instruction *Imag) {653LLVM_DEBUG(dbgs() << "identifyPartialMul " << *Real << " / " << *Imag654<< "\n");655// Determine rotation656auto IsAdd = [](unsigned Op) {657return Op == Instruction::FAdd || Op == Instruction::Add;658};659auto IsSub = [](unsigned Op) {660return Op == Instruction::FSub || Op == Instruction::Sub;661};662ComplexDeinterleavingRotation Rotation;663if (IsAdd(Real->getOpcode()) && IsAdd(Imag->getOpcode()))664Rotation = ComplexDeinterleavingRotation::Rotation_0;665else if (IsSub(Real->getOpcode()) && IsAdd(Imag->getOpcode()))666Rotation = ComplexDeinterleavingRotation::Rotation_90;667else if (IsSub(Real->getOpcode()) && 
IsSub(Imag->getOpcode()))668Rotation = ComplexDeinterleavingRotation::Rotation_180;669else if (IsAdd(Real->getOpcode()) && IsSub(Imag->getOpcode()))670Rotation = ComplexDeinterleavingRotation::Rotation_270;671else {672LLVM_DEBUG(dbgs() << " - Unhandled rotation.\n");673return nullptr;674}675676if (isa<FPMathOperator>(Real) &&677(!Real->getFastMathFlags().allowContract() ||678!Imag->getFastMathFlags().allowContract())) {679LLVM_DEBUG(dbgs() << " - Contract is missing from the FastMath flags.\n");680return nullptr;681}682683Value *CR = Real->getOperand(0);684Instruction *RealMulI = dyn_cast<Instruction>(Real->getOperand(1));685if (!RealMulI)686return nullptr;687Value *CI = Imag->getOperand(0);688Instruction *ImagMulI = dyn_cast<Instruction>(Imag->getOperand(1));689if (!ImagMulI)690return nullptr;691692if (!RealMulI->hasOneUse() || !ImagMulI->hasOneUse()) {693LLVM_DEBUG(dbgs() << " - Mul instruction has multiple uses\n");694return nullptr;695}696697Value *R0 = RealMulI->getOperand(0);698Value *R1 = RealMulI->getOperand(1);699Value *I0 = ImagMulI->getOperand(0);700Value *I1 = ImagMulI->getOperand(1);701702Value *CommonOperand;703Value *UncommonRealOp;704Value *UncommonImagOp;705706if (R0 == I0 || R0 == I1) {707CommonOperand = R0;708UncommonRealOp = R1;709} else if (R1 == I0 || R1 == I1) {710CommonOperand = R1;711UncommonRealOp = R0;712} else {713LLVM_DEBUG(dbgs() << " - No equal operand\n");714return nullptr;715}716717UncommonImagOp = (CommonOperand == I0) ? I1 : I0;718if (Rotation == ComplexDeinterleavingRotation::Rotation_90 ||719Rotation == ComplexDeinterleavingRotation::Rotation_270)720std::swap(UncommonRealOp, UncommonImagOp);721722std::pair<Value *, Value *> PartialMatch(723(Rotation == ComplexDeinterleavingRotation::Rotation_0 ||724Rotation == ComplexDeinterleavingRotation::Rotation_180)725? CommonOperand726: nullptr,727(Rotation == ComplexDeinterleavingRotation::Rotation_90 ||728Rotation == ComplexDeinterleavingRotation::Rotation_270)729? 
CommonOperand730: nullptr);731732auto *CRInst = dyn_cast<Instruction>(CR);733auto *CIInst = dyn_cast<Instruction>(CI);734735if (!CRInst || !CIInst) {736LLVM_DEBUG(dbgs() << " - Common operands are not instructions.\n");737return nullptr;738}739740NodePtr CNode = identifyNodeWithImplicitAdd(CRInst, CIInst, PartialMatch);741if (!CNode) {742LLVM_DEBUG(dbgs() << " - No cnode identified\n");743return nullptr;744}745746NodePtr UncommonRes = identifyNode(UncommonRealOp, UncommonImagOp);747if (!UncommonRes) {748LLVM_DEBUG(dbgs() << " - No UncommonRes identified\n");749return nullptr;750}751752assert(PartialMatch.first && PartialMatch.second);753NodePtr CommonRes = identifyNode(PartialMatch.first, PartialMatch.second);754if (!CommonRes) {755LLVM_DEBUG(dbgs() << " - No CommonRes identified\n");756return nullptr;757}758759NodePtr Node = prepareCompositeNode(760ComplexDeinterleavingOperation::CMulPartial, Real, Imag);761Node->Rotation = Rotation;762Node->addOperand(CommonRes);763Node->addOperand(UncommonRes);764Node->addOperand(CNode);765return submitCompositeNode(Node);766}767768ComplexDeinterleavingGraph::NodePtr769ComplexDeinterleavingGraph::identifyAdd(Instruction *Real, Instruction *Imag) {770LLVM_DEBUG(dbgs() << "identifyAdd " << *Real << " / " << *Imag << "\n");771772// Determine rotation773ComplexDeinterleavingRotation Rotation;774if ((Real->getOpcode() == Instruction::FSub &&775Imag->getOpcode() == Instruction::FAdd) ||776(Real->getOpcode() == Instruction::Sub &&777Imag->getOpcode() == Instruction::Add))778Rotation = ComplexDeinterleavingRotation::Rotation_90;779else if ((Real->getOpcode() == Instruction::FAdd &&780Imag->getOpcode() == Instruction::FSub) ||781(Real->getOpcode() == Instruction::Add &&782Imag->getOpcode() == Instruction::Sub))783Rotation = ComplexDeinterleavingRotation::Rotation_270;784else {785LLVM_DEBUG(dbgs() << " - Unhandled case, rotation is not assigned.\n");786return nullptr;787}788789auto *AR = dyn_cast<Instruction>(Real->getOperand(0));790auto 
*BI = dyn_cast<Instruction>(Real->getOperand(1));791auto *AI = dyn_cast<Instruction>(Imag->getOperand(0));792auto *BR = dyn_cast<Instruction>(Imag->getOperand(1));793794if (!AR || !AI || !BR || !BI) {795LLVM_DEBUG(dbgs() << " - Not all operands are instructions.\n");796return nullptr;797}798799NodePtr ResA = identifyNode(AR, AI);800if (!ResA) {801LLVM_DEBUG(dbgs() << " - AR/AI is not identified as a composite node.\n");802return nullptr;803}804NodePtr ResB = identifyNode(BR, BI);805if (!ResB) {806LLVM_DEBUG(dbgs() << " - BR/BI is not identified as a composite node.\n");807return nullptr;808}809810NodePtr Node =811prepareCompositeNode(ComplexDeinterleavingOperation::CAdd, Real, Imag);812Node->Rotation = Rotation;813Node->addOperand(ResA);814Node->addOperand(ResB);815return submitCompositeNode(Node);816}817818static bool isInstructionPairAdd(Instruction *A, Instruction *B) {819unsigned OpcA = A->getOpcode();820unsigned OpcB = B->getOpcode();821822return (OpcA == Instruction::FSub && OpcB == Instruction::FAdd) ||823(OpcA == Instruction::FAdd && OpcB == Instruction::FSub) ||824(OpcA == Instruction::Sub && OpcB == Instruction::Add) ||825(OpcA == Instruction::Add && OpcB == Instruction::Sub);826}827828static bool isInstructionPairMul(Instruction *A, Instruction *B) {829auto Pattern =830m_BinOp(m_FMul(m_Value(), m_Value()), m_FMul(m_Value(), m_Value()));831832return match(A, Pattern) && match(B, Pattern);833}834835static bool isInstructionPotentiallySymmetric(Instruction *I) {836switch (I->getOpcode()) {837case Instruction::FAdd:838case Instruction::FSub:839case Instruction::FMul:840case Instruction::FNeg:841case Instruction::Add:842case Instruction::Sub:843case Instruction::Mul:844return true;845default:846return false;847}848}849850ComplexDeinterleavingGraph::NodePtr851ComplexDeinterleavingGraph::identifySymmetricOperation(Instruction *Real,852Instruction *Imag) {853if (Real->getOpcode() != Imag->getOpcode())854return nullptr;855856if 
(!isInstructionPotentiallySymmetric(Real) ||857!isInstructionPotentiallySymmetric(Imag))858return nullptr;859860auto *R0 = Real->getOperand(0);861auto *I0 = Imag->getOperand(0);862863NodePtr Op0 = identifyNode(R0, I0);864NodePtr Op1 = nullptr;865if (Op0 == nullptr)866return nullptr;867868if (Real->isBinaryOp()) {869auto *R1 = Real->getOperand(1);870auto *I1 = Imag->getOperand(1);871Op1 = identifyNode(R1, I1);872if (Op1 == nullptr)873return nullptr;874}875876if (isa<FPMathOperator>(Real) &&877Real->getFastMathFlags() != Imag->getFastMathFlags())878return nullptr;879880auto Node = prepareCompositeNode(ComplexDeinterleavingOperation::Symmetric,881Real, Imag);882Node->Opcode = Real->getOpcode();883if (isa<FPMathOperator>(Real))884Node->Flags = Real->getFastMathFlags();885886Node->addOperand(Op0);887if (Real->isBinaryOp())888Node->addOperand(Op1);889890return submitCompositeNode(Node);891}892893ComplexDeinterleavingGraph::NodePtr894ComplexDeinterleavingGraph::identifyNode(Value *R, Value *I) {895LLVM_DEBUG(dbgs() << "identifyNode on " << *R << " / " << *I << "\n");896assert(R->getType() == I->getType() &&897"Real and imaginary parts should not have different types");898899auto It = CachedResult.find({R, I});900if (It != CachedResult.end()) {901LLVM_DEBUG(dbgs() << " - Folding to existing node\n");902return It->second;903}904905if (NodePtr CN = identifySplat(R, I))906return CN;907908auto *Real = dyn_cast<Instruction>(R);909auto *Imag = dyn_cast<Instruction>(I);910if (!Real || !Imag)911return nullptr;912913if (NodePtr CN = identifyDeinterleave(Real, Imag))914return CN;915916if (NodePtr CN = identifyPHINode(Real, Imag))917return CN;918919if (NodePtr CN = identifySelectNode(Real, Imag))920return CN;921922auto *VTy = cast<VectorType>(Real->getType());923auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);924925bool HasCMulSupport = TL->isComplexDeinterleavingOperationSupported(926ComplexDeinterleavingOperation::CMulPartial, NewVTy);927bool HasCAddSupport = 
TL->isComplexDeinterleavingOperationSupported(928ComplexDeinterleavingOperation::CAdd, NewVTy);929930if (HasCMulSupport && isInstructionPairMul(Real, Imag)) {931if (NodePtr CN = identifyPartialMul(Real, Imag))932return CN;933}934935if (HasCAddSupport && isInstructionPairAdd(Real, Imag)) {936if (NodePtr CN = identifyAdd(Real, Imag))937return CN;938}939940if (HasCMulSupport && HasCAddSupport) {941if (NodePtr CN = identifyReassocNodes(Real, Imag))942return CN;943}944945if (NodePtr CN = identifySymmetricOperation(Real, Imag))946return CN;947948LLVM_DEBUG(dbgs() << " - Not recognised as a valid pattern.\n");949CachedResult[{R, I}] = nullptr;950return nullptr;951}952953ComplexDeinterleavingGraph::NodePtr954ComplexDeinterleavingGraph::identifyReassocNodes(Instruction *Real,955Instruction *Imag) {956auto IsOperationSupported = [](unsigned Opcode) -> bool {957return Opcode == Instruction::FAdd || Opcode == Instruction::FSub ||958Opcode == Instruction::FNeg || Opcode == Instruction::Add ||959Opcode == Instruction::Sub;960};961962if (!IsOperationSupported(Real->getOpcode()) ||963!IsOperationSupported(Imag->getOpcode()))964return nullptr;965966std::optional<FastMathFlags> Flags;967if (isa<FPMathOperator>(Real)) {968if (Real->getFastMathFlags() != Imag->getFastMathFlags()) {969LLVM_DEBUG(dbgs() << "The flags in Real and Imaginary instructions are "970"not identical\n");971return nullptr;972}973974Flags = Real->getFastMathFlags();975if (!Flags->allowReassoc()) {976LLVM_DEBUG(977dbgs()978<< "the 'Reassoc' attribute is missing in the FastMath flags\n");979return nullptr;980}981}982983// Collect multiplications and addend instructions from the given instruction984// while traversing it operands. 
Additionally, verify that all instructions985// have the same fast math flags.986auto Collect = [&Flags](Instruction *Insn, std::vector<Product> &Muls,987std::list<Addend> &Addends) -> bool {988SmallVector<PointerIntPair<Value *, 1, bool>> Worklist = {{Insn, true}};989SmallPtrSet<Value *, 8> Visited;990while (!Worklist.empty()) {991auto [V, IsPositive] = Worklist.back();992Worklist.pop_back();993if (!Visited.insert(V).second)994continue;995996Instruction *I = dyn_cast<Instruction>(V);997if (!I) {998Addends.emplace_back(V, IsPositive);999continue;1000}10011002// If an instruction has more than one user, it indicates that it either1003// has an external user, which will be later checked by the checkNodes1004// function, or it is a subexpression utilized by multiple expressions. In1005// the latter case, we will attempt to separately identify the complex1006// operation from here in order to create a shared1007// ComplexDeinterleavingCompositeNode.1008if (I != Insn && I->getNumUses() > 1) {1009LLVM_DEBUG(dbgs() << "Found potential sub-expression: " << *I << "\n");1010Addends.emplace_back(I, IsPositive);1011continue;1012}1013switch (I->getOpcode()) {1014case Instruction::FAdd:1015case Instruction::Add:1016Worklist.emplace_back(I->getOperand(1), IsPositive);1017Worklist.emplace_back(I->getOperand(0), IsPositive);1018break;1019case Instruction::FSub:1020Worklist.emplace_back(I->getOperand(1), !IsPositive);1021Worklist.emplace_back(I->getOperand(0), IsPositive);1022break;1023case Instruction::Sub:1024if (isNeg(I)) {1025Worklist.emplace_back(getNegOperand(I), !IsPositive);1026} else {1027Worklist.emplace_back(I->getOperand(1), !IsPositive);1028Worklist.emplace_back(I->getOperand(0), IsPositive);1029}1030break;1031case Instruction::FMul:1032case Instruction::Mul: {1033Value *A, *B;1034if (isNeg(I->getOperand(0))) {1035A = getNegOperand(I->getOperand(0));1036IsPositive = !IsPositive;1037} else {1038A = I->getOperand(0);1039}10401041if (isNeg(I->getOperand(1))) {1042B = 
getNegOperand(I->getOperand(1));1043IsPositive = !IsPositive;1044} else {1045B = I->getOperand(1);1046}1047Muls.push_back(Product{A, B, IsPositive});1048break;1049}1050case Instruction::FNeg:1051Worklist.emplace_back(I->getOperand(0), !IsPositive);1052break;1053default:1054Addends.emplace_back(I, IsPositive);1055continue;1056}10571058if (Flags && I->getFastMathFlags() != *Flags) {1059LLVM_DEBUG(dbgs() << "The instruction's fast math flags are "1060"inconsistent with the root instructions' flags: "1061<< *I << "\n");1062return false;1063}1064}1065return true;1066};10671068std::vector<Product> RealMuls, ImagMuls;1069std::list<Addend> RealAddends, ImagAddends;1070if (!Collect(Real, RealMuls, RealAddends) ||1071!Collect(Imag, ImagMuls, ImagAddends))1072return nullptr;10731074if (RealAddends.size() != ImagAddends.size())1075return nullptr;10761077NodePtr FinalNode;1078if (!RealMuls.empty() || !ImagMuls.empty()) {1079// If there are multiplicands, extract positive addend and use it as an1080// accumulator1081FinalNode = extractPositiveAddend(RealAddends, ImagAddends);1082FinalNode = identifyMultiplications(RealMuls, ImagMuls, FinalNode);1083if (!FinalNode)1084return nullptr;1085}10861087// Identify and process remaining additions1088if (!RealAddends.empty() || !ImagAddends.empty()) {1089FinalNode = identifyAdditions(RealAddends, ImagAddends, Flags, FinalNode);1090if (!FinalNode)1091return nullptr;1092}1093assert(FinalNode && "FinalNode can not be nullptr here");1094// Set the Real and Imag fields of the final node and submit it1095FinalNode->Real = Real;1096FinalNode->Imag = Imag;1097submitCompositeNode(FinalNode);1098return FinalNode;1099}11001101bool ComplexDeinterleavingGraph::collectPartialMuls(1102const std::vector<Product> &RealMuls, const std::vector<Product> &ImagMuls,1103std::vector<PartialMulCandidate> &PartialMulCandidates) {1104// Helper function to extract a common operand from two products1105auto FindCommonInstruction = [](const Product &Real,1106const 
Product &Imag) -> Value * {1107if (Real.Multiplicand == Imag.Multiplicand ||1108Real.Multiplicand == Imag.Multiplier)1109return Real.Multiplicand;11101111if (Real.Multiplier == Imag.Multiplicand ||1112Real.Multiplier == Imag.Multiplier)1113return Real.Multiplier;11141115return nullptr;1116};11171118// Iterating over real and imaginary multiplications to find common operands1119// If a common operand is found, a partial multiplication candidate is created1120// and added to the candidates vector The function returns false if no common1121// operands are found for any product1122for (unsigned i = 0; i < RealMuls.size(); ++i) {1123bool FoundCommon = false;1124for (unsigned j = 0; j < ImagMuls.size(); ++j) {1125auto *Common = FindCommonInstruction(RealMuls[i], ImagMuls[j]);1126if (!Common)1127continue;11281129auto *A = RealMuls[i].Multiplicand == Common ? RealMuls[i].Multiplier1130: RealMuls[i].Multiplicand;1131auto *B = ImagMuls[j].Multiplicand == Common ? ImagMuls[j].Multiplier1132: ImagMuls[j].Multiplicand;11331134auto Node = identifyNode(A, B);1135if (Node) {1136FoundCommon = true;1137PartialMulCandidates.push_back({Common, Node, i, j, false});1138}11391140Node = identifyNode(B, A);1141if (Node) {1142FoundCommon = true;1143PartialMulCandidates.push_back({Common, Node, i, j, true});1144}1145}1146if (!FoundCommon)1147return false;1148}1149return true;1150}11511152ComplexDeinterleavingGraph::NodePtr1153ComplexDeinterleavingGraph::identifyMultiplications(1154std::vector<Product> &RealMuls, std::vector<Product> &ImagMuls,1155NodePtr Accumulator = nullptr) {1156if (RealMuls.size() != ImagMuls.size())1157return nullptr;11581159std::vector<PartialMulCandidate> Info;1160if (!collectPartialMuls(RealMuls, ImagMuls, Info))1161return nullptr;11621163// Map to store common instruction to node pointers1164std::map<Value *, NodePtr> CommonToNode;1165std::vector<bool> Processed(Info.size(), false);1166for (unsigned I = 0; I < Info.size(); ++I) {1167if 
(Processed[I])1168continue;11691170PartialMulCandidate &InfoA = Info[I];1171for (unsigned J = I + 1; J < Info.size(); ++J) {1172if (Processed[J])1173continue;11741175PartialMulCandidate &InfoB = Info[J];1176auto *InfoReal = &InfoA;1177auto *InfoImag = &InfoB;11781179auto NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);1180if (!NodeFromCommon) {1181std::swap(InfoReal, InfoImag);1182NodeFromCommon = identifyNode(InfoReal->Common, InfoImag->Common);1183}1184if (!NodeFromCommon)1185continue;11861187CommonToNode[InfoReal->Common] = NodeFromCommon;1188CommonToNode[InfoImag->Common] = NodeFromCommon;1189Processed[I] = true;1190Processed[J] = true;1191}1192}11931194std::vector<bool> ProcessedReal(RealMuls.size(), false);1195std::vector<bool> ProcessedImag(ImagMuls.size(), false);1196NodePtr Result = Accumulator;1197for (auto &PMI : Info) {1198if (ProcessedReal[PMI.RealIdx] || ProcessedImag[PMI.ImagIdx])1199continue;12001201auto It = CommonToNode.find(PMI.Common);1202// TODO: Process independent complex multiplications. Cases like this:1203// A.real() * B where both A and B are complex numbers.1204if (It == CommonToNode.end()) {1205LLVM_DEBUG({1206dbgs() << "Unprocessed independent partial multiplication:\n";1207for (auto *Mul : {&RealMuls[PMI.RealIdx], &RealMuls[PMI.RealIdx]})1208dbgs().indent(4) << (Mul->IsPositive ? "+" : "-") << *Mul->Multiplier1209<< " multiplied by " << *Mul->Multiplicand << "\n";1210});1211return nullptr;1212}12131214auto &RealMul = RealMuls[PMI.RealIdx];1215auto &ImagMul = ImagMuls[PMI.ImagIdx];12161217auto NodeA = It->second;1218auto NodeB = PMI.Node;1219auto IsMultiplicandReal = PMI.Common == NodeA->Real;1220// The following table illustrates the relationship between multiplications1221// and rotations. 
If we consider the multiplication (X + iY) * (U + iV), we1222// can see:1223//1224// Rotation | Real | Imag |1225// ---------+--------+--------+1226// 0 | x * u | x * v |1227// 90 | -y * v | y * u |1228// 180 | -x * u | -x * v |1229// 270 | y * v | -y * u |1230//1231// Check if the candidate can indeed be represented by partial1232// multiplication1233// TODO: Add support for multiplication by complex one1234if ((IsMultiplicandReal && PMI.IsNodeInverted) ||1235(!IsMultiplicandReal && !PMI.IsNodeInverted))1236continue;12371238// Determine the rotation based on the multiplications1239ComplexDeinterleavingRotation Rotation;1240if (IsMultiplicandReal) {1241// Detect 0 and 180 degrees rotation1242if (RealMul.IsPositive && ImagMul.IsPositive)1243Rotation = llvm::ComplexDeinterleavingRotation::Rotation_0;1244else if (!RealMul.IsPositive && !ImagMul.IsPositive)1245Rotation = llvm::ComplexDeinterleavingRotation::Rotation_180;1246else1247continue;12481249} else {1250// Detect 90 and 270 degrees rotation1251if (!RealMul.IsPositive && ImagMul.IsPositive)1252Rotation = llvm::ComplexDeinterleavingRotation::Rotation_90;1253else if (RealMul.IsPositive && !ImagMul.IsPositive)1254Rotation = llvm::ComplexDeinterleavingRotation::Rotation_270;1255else1256continue;1257}12581259LLVM_DEBUG({1260dbgs() << "Identified partial multiplication (X, Y) * (U, V):\n";1261dbgs().indent(4) << "X: " << *NodeA->Real << "\n";1262dbgs().indent(4) << "Y: " << *NodeA->Imag << "\n";1263dbgs().indent(4) << "U: " << *NodeB->Real << "\n";1264dbgs().indent(4) << "V: " << *NodeB->Imag << "\n";1265dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";1266});12671268NodePtr NodeMul = prepareCompositeNode(1269ComplexDeinterleavingOperation::CMulPartial, nullptr, nullptr);1270NodeMul->Rotation = Rotation;1271NodeMul->addOperand(NodeA);1272NodeMul->addOperand(NodeB);1273if (Result)1274NodeMul->addOperand(Result);1275submitCompositeNode(NodeMul);1276Result = NodeMul;1277ProcessedReal[PMI.RealIdx] = 
true;1278ProcessedImag[PMI.ImagIdx] = true;1279}12801281// Ensure all products have been processed, if not return nullptr.1282if (!all_of(ProcessedReal, [](bool V) { return V; }) ||1283!all_of(ProcessedImag, [](bool V) { return V; })) {12841285// Dump debug information about which partial multiplications are not1286// processed.1287LLVM_DEBUG({1288dbgs() << "Unprocessed products (Real):\n";1289for (size_t i = 0; i < ProcessedReal.size(); ++i) {1290if (!ProcessedReal[i])1291dbgs().indent(4) << (RealMuls[i].IsPositive ? "+" : "-")1292<< *RealMuls[i].Multiplier << " multiplied by "1293<< *RealMuls[i].Multiplicand << "\n";1294}1295dbgs() << "Unprocessed products (Imag):\n";1296for (size_t i = 0; i < ProcessedImag.size(); ++i) {1297if (!ProcessedImag[i])1298dbgs().indent(4) << (ImagMuls[i].IsPositive ? "+" : "-")1299<< *ImagMuls[i].Multiplier << " multiplied by "1300<< *ImagMuls[i].Multiplicand << "\n";1301}1302});1303return nullptr;1304}13051306return Result;1307}13081309ComplexDeinterleavingGraph::NodePtr1310ComplexDeinterleavingGraph::identifyAdditions(1311std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends,1312std::optional<FastMathFlags> Flags, NodePtr Accumulator = nullptr) {1313if (RealAddends.size() != ImagAddends.size())1314return nullptr;13151316NodePtr Result;1317// If we have accumulator use it as first addend1318if (Accumulator)1319Result = Accumulator;1320// Otherwise find an element with both positive real and imaginary parts.1321else1322Result = extractPositiveAddend(RealAddends, ImagAddends);13231324if (!Result)1325return nullptr;13261327while (!RealAddends.empty()) {1328auto ItR = RealAddends.begin();1329auto [R, IsPositiveR] = *ItR;13301331bool FoundImag = false;1332for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {1333auto [I, IsPositiveI] = *ItI;1334ComplexDeinterleavingRotation Rotation;1335if (IsPositiveR && IsPositiveI)1336Rotation = ComplexDeinterleavingRotation::Rotation_0;1337else if (!IsPositiveR && 
IsPositiveI)1338Rotation = ComplexDeinterleavingRotation::Rotation_90;1339else if (!IsPositiveR && !IsPositiveI)1340Rotation = ComplexDeinterleavingRotation::Rotation_180;1341else1342Rotation = ComplexDeinterleavingRotation::Rotation_270;13431344NodePtr AddNode;1345if (Rotation == ComplexDeinterleavingRotation::Rotation_0 ||1346Rotation == ComplexDeinterleavingRotation::Rotation_180) {1347AddNode = identifyNode(R, I);1348} else {1349AddNode = identifyNode(I, R);1350}1351if (AddNode) {1352LLVM_DEBUG({1353dbgs() << "Identified addition:\n";1354dbgs().indent(4) << "X: " << *R << "\n";1355dbgs().indent(4) << "Y: " << *I << "\n";1356dbgs().indent(4) << "Rotation - " << (int)Rotation * 90 << "\n";1357});13581359NodePtr TmpNode;1360if (Rotation == llvm::ComplexDeinterleavingRotation::Rotation_0) {1361TmpNode = prepareCompositeNode(1362ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);1363if (Flags) {1364TmpNode->Opcode = Instruction::FAdd;1365TmpNode->Flags = *Flags;1366} else {1367TmpNode->Opcode = Instruction::Add;1368}1369} else if (Rotation ==1370llvm::ComplexDeinterleavingRotation::Rotation_180) {1371TmpNode = prepareCompositeNode(1372ComplexDeinterleavingOperation::Symmetric, nullptr, nullptr);1373if (Flags) {1374TmpNode->Opcode = Instruction::FSub;1375TmpNode->Flags = *Flags;1376} else {1377TmpNode->Opcode = Instruction::Sub;1378}1379} else {1380TmpNode = prepareCompositeNode(ComplexDeinterleavingOperation::CAdd,1381nullptr, nullptr);1382TmpNode->Rotation = Rotation;1383}13841385TmpNode->addOperand(Result);1386TmpNode->addOperand(AddNode);1387submitCompositeNode(TmpNode);1388Result = TmpNode;1389RealAddends.erase(ItR);1390ImagAddends.erase(ItI);1391FoundImag = true;1392break;1393}1394}1395if (!FoundImag)1396return nullptr;1397}1398return Result;1399}14001401ComplexDeinterleavingGraph::NodePtr1402ComplexDeinterleavingGraph::extractPositiveAddend(1403std::list<Addend> &RealAddends, std::list<Addend> &ImagAddends) {1404for (auto ItR = RealAddends.begin(); 
ItR != RealAddends.end(); ++ItR) {1405for (auto ItI = ImagAddends.begin(); ItI != ImagAddends.end(); ++ItI) {1406auto [R, IsPositiveR] = *ItR;1407auto [I, IsPositiveI] = *ItI;1408if (IsPositiveR && IsPositiveI) {1409auto Result = identifyNode(R, I);1410if (Result) {1411RealAddends.erase(ItR);1412ImagAddends.erase(ItI);1413return Result;1414}1415}1416}1417}1418return nullptr;1419}14201421bool ComplexDeinterleavingGraph::identifyNodes(Instruction *RootI) {1422// This potential root instruction might already have been recognized as1423// reduction. Because RootToNode maps both Real and Imaginary parts to1424// CompositeNode we should choose only one either Real or Imag instruction to1425// use as an anchor for generating complex instruction.1426auto It = RootToNode.find(RootI);1427if (It != RootToNode.end()) {1428auto RootNode = It->second;1429assert(RootNode->Operation ==1430ComplexDeinterleavingOperation::ReductionOperation);1431// Find out which part, Real or Imag, comes later, and only if we come to1432// the latest part, add it to OrderedRoots.1433auto *R = cast<Instruction>(RootNode->Real);1434auto *I = cast<Instruction>(RootNode->Imag);1435auto *ReplacementAnchor = R->comesBefore(I) ? 
I : R;1436if (ReplacementAnchor != RootI)1437return false;1438OrderedRoots.push_back(RootI);1439return true;1440}14411442auto RootNode = identifyRoot(RootI);1443if (!RootNode)1444return false;14451446LLVM_DEBUG({1447Function *F = RootI->getFunction();1448BasicBlock *B = RootI->getParent();1449dbgs() << "Complex deinterleaving graph for " << F->getName()1450<< "::" << B->getName() << ".\n";1451dump(dbgs());1452dbgs() << "\n";1453});1454RootToNode[RootI] = RootNode;1455OrderedRoots.push_back(RootI);1456return true;1457}14581459bool ComplexDeinterleavingGraph::collectPotentialReductions(BasicBlock *B) {1460bool FoundPotentialReduction = false;14611462auto *Br = dyn_cast<BranchInst>(B->getTerminator());1463if (!Br || Br->getNumSuccessors() != 2)1464return false;14651466// Identify simple one-block loop1467if (Br->getSuccessor(0) != B && Br->getSuccessor(1) != B)1468return false;14691470SmallVector<PHINode *> PHIs;1471for (auto &PHI : B->phis()) {1472if (PHI.getNumIncomingValues() != 2)1473continue;14741475if (!PHI.getType()->isVectorTy())1476continue;14771478auto *ReductionOp = dyn_cast<Instruction>(PHI.getIncomingValueForBlock(B));1479if (!ReductionOp)1480continue;14811482// Check if final instruction is reduced outside of current block1483Instruction *FinalReduction = nullptr;1484auto NumUsers = 0u;1485for (auto *U : ReductionOp->users()) {1486++NumUsers;1487if (U == &PHI)1488continue;1489FinalReduction = dyn_cast<Instruction>(U);1490}14911492if (NumUsers != 2 || !FinalReduction || FinalReduction->getParent() == B ||1493isa<PHINode>(FinalReduction))1494continue;14951496ReductionInfo[ReductionOp] = {&PHI, FinalReduction};1497BackEdge = B;1498auto BackEdgeIdx = PHI.getBasicBlockIndex(B);1499auto IncomingIdx = BackEdgeIdx == 0 ? 
1 : 0;1500Incoming = PHI.getIncomingBlock(IncomingIdx);1501FoundPotentialReduction = true;15021503// If the initial value of PHINode is an Instruction, consider it a leaf1504// value of a complex deinterleaving graph.1505if (auto *InitPHI =1506dyn_cast<Instruction>(PHI.getIncomingValueForBlock(Incoming)))1507FinalInstructions.insert(InitPHI);1508}1509return FoundPotentialReduction;1510}15111512void ComplexDeinterleavingGraph::identifyReductionNodes() {1513SmallVector<bool> Processed(ReductionInfo.size(), false);1514SmallVector<Instruction *> OperationInstruction;1515for (auto &P : ReductionInfo)1516OperationInstruction.push_back(P.first);15171518// Identify a complex computation by evaluating two reduction operations that1519// potentially could be involved1520for (size_t i = 0; i < OperationInstruction.size(); ++i) {1521if (Processed[i])1522continue;1523for (size_t j = i + 1; j < OperationInstruction.size(); ++j) {1524if (Processed[j])1525continue;15261527auto *Real = OperationInstruction[i];1528auto *Imag = OperationInstruction[j];1529if (Real->getType() != Imag->getType())1530continue;15311532RealPHI = ReductionInfo[Real].first;1533ImagPHI = ReductionInfo[Imag].first;1534PHIsFound = false;1535auto Node = identifyNode(Real, Imag);1536if (!Node) {1537std::swap(Real, Imag);1538std::swap(RealPHI, ImagPHI);1539Node = identifyNode(Real, Imag);1540}15411542// If a node is identified and reduction PHINode is used in the chain of1543// operations, mark its operation instructions as used to prevent1544// re-identification and attach the node to the real part1545if (Node && PHIsFound) {1546LLVM_DEBUG(dbgs() << "Identified reduction starting from instructions: "1547<< *Real << " / " << *Imag << "\n");1548Processed[i] = true;1549Processed[j] = true;1550auto RootNode = prepareCompositeNode(1551ComplexDeinterleavingOperation::ReductionOperation, Real, Imag);1552RootNode->addOperand(Node);1553RootToNode[Real] = RootNode;1554RootToNode[Imag] = 
RootNode;1555submitCompositeNode(RootNode);1556break;1557}1558}1559}15601561RealPHI = nullptr;1562ImagPHI = nullptr;1563}15641565bool ComplexDeinterleavingGraph::checkNodes() {1566// Collect all instructions from roots to leaves1567SmallPtrSet<Instruction *, 16> AllInstructions;1568SmallVector<Instruction *, 8> Worklist;1569for (auto &Pair : RootToNode)1570Worklist.push_back(Pair.first);15711572// Extract all instructions that are used by all XCMLA/XCADD/ADD/SUB/NEG1573// chains1574while (!Worklist.empty()) {1575auto *I = Worklist.back();1576Worklist.pop_back();15771578if (!AllInstructions.insert(I).second)1579continue;15801581for (Value *Op : I->operands()) {1582if (auto *OpI = dyn_cast<Instruction>(Op)) {1583if (!FinalInstructions.count(I))1584Worklist.emplace_back(OpI);1585}1586}1587}15881589// Find instructions that have users outside of chain1590SmallVector<Instruction *, 2> OuterInstructions;1591for (auto *I : AllInstructions) {1592// Skip root nodes1593if (RootToNode.count(I))1594continue;15951596for (User *U : I->users()) {1597if (AllInstructions.count(cast<Instruction>(U)))1598continue;15991600// Found an instruction that is not used by XCMLA/XCADD chain1601Worklist.emplace_back(I);1602break;1603}1604}16051606// If any instructions are found to be used outside, find and remove roots1607// that somehow connect to those instructions.1608SmallPtrSet<Instruction *, 16> Visited;1609while (!Worklist.empty()) {1610auto *I = Worklist.back();1611Worklist.pop_back();1612if (!Visited.insert(I).second)1613continue;16141615// Found an impacted root node. 
Removing it from the nodes to be1616// deinterleaved1617if (RootToNode.count(I)) {1618LLVM_DEBUG(dbgs() << "Instruction " << *I1619<< " could be deinterleaved but its chain of complex "1620"operations have an outside user\n");1621RootToNode.erase(I);1622}16231624if (!AllInstructions.count(I) || FinalInstructions.count(I))1625continue;16261627for (User *U : I->users())1628Worklist.emplace_back(cast<Instruction>(U));16291630for (Value *Op : I->operands()) {1631if (auto *OpI = dyn_cast<Instruction>(Op))1632Worklist.emplace_back(OpI);1633}1634}1635return !RootToNode.empty();1636}16371638ComplexDeinterleavingGraph::NodePtr1639ComplexDeinterleavingGraph::identifyRoot(Instruction *RootI) {1640if (auto *Intrinsic = dyn_cast<IntrinsicInst>(RootI)) {1641if (Intrinsic->getIntrinsicID() != Intrinsic::vector_interleave2)1642return nullptr;16431644auto *Real = dyn_cast<Instruction>(Intrinsic->getOperand(0));1645auto *Imag = dyn_cast<Instruction>(Intrinsic->getOperand(1));1646if (!Real || !Imag)1647return nullptr;16481649return identifyNode(Real, Imag);1650}16511652auto *SVI = dyn_cast<ShuffleVectorInst>(RootI);1653if (!SVI)1654return nullptr;16551656// Look for a shufflevector that takes separate vectors of the real and1657// imaginary components and recombines them into a single vector.1658if (!isInterleavingMask(SVI->getShuffleMask()))1659return nullptr;16601661Instruction *Real;1662Instruction *Imag;1663if (!match(RootI, m_Shuffle(m_Instruction(Real), m_Instruction(Imag))))1664return nullptr;16651666return identifyNode(Real, Imag);1667}16681669ComplexDeinterleavingGraph::NodePtr1670ComplexDeinterleavingGraph::identifyDeinterleave(Instruction *Real,1671Instruction *Imag) {1672Instruction *I = nullptr;1673Value *FinalValue = nullptr;1674if (match(Real, m_ExtractValue<0>(m_Instruction(I))) &&1675match(Imag, m_ExtractValue<1>(m_Specific(I))) &&1676match(I, m_Intrinsic<Intrinsic::vector_deinterleave2>(1677m_Value(FinalValue)))) {1678NodePtr PlaceholderNode = 
prepareCompositeNode(1679llvm::ComplexDeinterleavingOperation::Deinterleave, Real, Imag);1680PlaceholderNode->ReplacementNode = FinalValue;1681FinalInstructions.insert(Real);1682FinalInstructions.insert(Imag);1683return submitCompositeNode(PlaceholderNode);1684}16851686auto *RealShuffle = dyn_cast<ShuffleVectorInst>(Real);1687auto *ImagShuffle = dyn_cast<ShuffleVectorInst>(Imag);1688if (!RealShuffle || !ImagShuffle) {1689if (RealShuffle || ImagShuffle)1690LLVM_DEBUG(dbgs() << " - There's a shuffle where there shouldn't be.\n");1691return nullptr;1692}16931694Value *RealOp1 = RealShuffle->getOperand(1);1695if (!isa<UndefValue>(RealOp1) && !isa<ConstantAggregateZero>(RealOp1)) {1696LLVM_DEBUG(dbgs() << " - RealOp1 is not undef or zero.\n");1697return nullptr;1698}1699Value *ImagOp1 = ImagShuffle->getOperand(1);1700if (!isa<UndefValue>(ImagOp1) && !isa<ConstantAggregateZero>(ImagOp1)) {1701LLVM_DEBUG(dbgs() << " - ImagOp1 is not undef or zero.\n");1702return nullptr;1703}17041705Value *RealOp0 = RealShuffle->getOperand(0);1706Value *ImagOp0 = ImagShuffle->getOperand(0);17071708if (RealOp0 != ImagOp0) {1709LLVM_DEBUG(dbgs() << " - Shuffle operands are not equal.\n");1710return nullptr;1711}17121713ArrayRef<int> RealMask = RealShuffle->getShuffleMask();1714ArrayRef<int> ImagMask = ImagShuffle->getShuffleMask();1715if (!isDeinterleavingMask(RealMask) || !isDeinterleavingMask(ImagMask)) {1716LLVM_DEBUG(dbgs() << " - Masks are not deinterleaving.\n");1717return nullptr;1718}17191720if (RealMask[0] != 0 || ImagMask[0] != 1) {1721LLVM_DEBUG(dbgs() << " - Masks do not have the correct initial value.\n");1722return nullptr;1723}17241725// Type checking, the shuffle type should be a vector type of the same1726// scalar type, but half the size1727auto CheckType = [&](ShuffleVectorInst *Shuffle) {1728Value *Op = Shuffle->getOperand(0);1729auto *ShuffleTy = cast<FixedVectorType>(Shuffle->getType());1730auto *OpTy = cast<FixedVectorType>(Op->getType());17311732if 
(OpTy->getScalarType() != ShuffleTy->getScalarType())1733return false;1734if ((ShuffleTy->getNumElements() * 2) != OpTy->getNumElements())1735return false;17361737return true;1738};17391740auto CheckDeinterleavingShuffle = [&](ShuffleVectorInst *Shuffle) -> bool {1741if (!CheckType(Shuffle))1742return false;17431744ArrayRef<int> Mask = Shuffle->getShuffleMask();1745int Last = *Mask.rbegin();17461747Value *Op = Shuffle->getOperand(0);1748auto *OpTy = cast<FixedVectorType>(Op->getType());1749int NumElements = OpTy->getNumElements();17501751// Ensure that the deinterleaving shuffle only pulls from the first1752// shuffle operand.1753return Last < NumElements;1754};17551756if (RealShuffle->getType() != ImagShuffle->getType()) {1757LLVM_DEBUG(dbgs() << " - Shuffle types aren't equal.\n");1758return nullptr;1759}1760if (!CheckDeinterleavingShuffle(RealShuffle)) {1761LLVM_DEBUG(dbgs() << " - RealShuffle is invalid type.\n");1762return nullptr;1763}1764if (!CheckDeinterleavingShuffle(ImagShuffle)) {1765LLVM_DEBUG(dbgs() << " - ImagShuffle is invalid type.\n");1766return nullptr;1767}17681769NodePtr PlaceholderNode =1770prepareCompositeNode(llvm::ComplexDeinterleavingOperation::Deinterleave,1771RealShuffle, ImagShuffle);1772PlaceholderNode->ReplacementNode = RealShuffle->getOperand(0);1773FinalInstructions.insert(RealShuffle);1774FinalInstructions.insert(ImagShuffle);1775return submitCompositeNode(PlaceholderNode);1776}17771778ComplexDeinterleavingGraph::NodePtr1779ComplexDeinterleavingGraph::identifySplat(Value *R, Value *I) {1780auto IsSplat = [](Value *V) -> bool {1781// Fixed-width vector with constants1782if (isa<ConstantDataVector>(V))1783return true;17841785VectorType *VTy;1786ArrayRef<int> Mask;1787// Splats are represented differently depending on whether the repeated1788// value is a constant or an Instruction1789if (auto *Const = dyn_cast<ConstantExpr>(V)) {1790if (Const->getOpcode() != Instruction::ShuffleVector)1791return false;1792VTy = 
cast<VectorType>(Const->getType());1793Mask = Const->getShuffleMask();1794} else if (auto *Shuf = dyn_cast<ShuffleVectorInst>(V)) {1795VTy = Shuf->getType();1796Mask = Shuf->getShuffleMask();1797} else {1798return false;1799}18001801// When the data type is <1 x Type>, it's not possible to differentiate1802// between the ComplexDeinterleaving::Deinterleave and1803// ComplexDeinterleaving::Splat operations.1804if (!VTy->isScalableTy() && VTy->getElementCount().getKnownMinValue() == 1)1805return false;18061807return all_equal(Mask) && Mask[0] == 0;1808};18091810if (!IsSplat(R) || !IsSplat(I))1811return nullptr;18121813auto *Real = dyn_cast<Instruction>(R);1814auto *Imag = dyn_cast<Instruction>(I);1815if ((!Real && Imag) || (Real && !Imag))1816return nullptr;18171818if (Real && Imag) {1819// Non-constant splats should be in the same basic block1820if (Real->getParent() != Imag->getParent())1821return nullptr;18221823FinalInstructions.insert(Real);1824FinalInstructions.insert(Imag);1825}1826NodePtr PlaceholderNode =1827prepareCompositeNode(ComplexDeinterleavingOperation::Splat, R, I);1828return submitCompositeNode(PlaceholderNode);1829}18301831ComplexDeinterleavingGraph::NodePtr1832ComplexDeinterleavingGraph::identifyPHINode(Instruction *Real,1833Instruction *Imag) {1834if (Real != RealPHI || Imag != ImagPHI)1835return nullptr;18361837PHIsFound = true;1838NodePtr PlaceholderNode = prepareCompositeNode(1839ComplexDeinterleavingOperation::ReductionPHI, Real, Imag);1840return submitCompositeNode(PlaceholderNode);1841}18421843ComplexDeinterleavingGraph::NodePtr1844ComplexDeinterleavingGraph::identifySelectNode(Instruction *Real,1845Instruction *Imag) {1846auto *SelectReal = dyn_cast<SelectInst>(Real);1847auto *SelectImag = dyn_cast<SelectInst>(Imag);1848if (!SelectReal || !SelectImag)1849return nullptr;18501851Instruction *MaskA, *MaskB;1852Instruction *AR, *AI, *RA, *BI;1853if (!match(Real, m_Select(m_Instruction(MaskA), m_Instruction(AR),1854m_Instruction(RA))) 
||1855!match(Imag, m_Select(m_Instruction(MaskB), m_Instruction(AI),1856m_Instruction(BI))))1857return nullptr;18581859if (MaskA != MaskB && !MaskA->isIdenticalTo(MaskB))1860return nullptr;18611862if (!MaskA->getType()->isVectorTy())1863return nullptr;18641865auto NodeA = identifyNode(AR, AI);1866if (!NodeA)1867return nullptr;18681869auto NodeB = identifyNode(RA, BI);1870if (!NodeB)1871return nullptr;18721873NodePtr PlaceholderNode = prepareCompositeNode(1874ComplexDeinterleavingOperation::ReductionSelect, Real, Imag);1875PlaceholderNode->addOperand(NodeA);1876PlaceholderNode->addOperand(NodeB);1877FinalInstructions.insert(MaskA);1878FinalInstructions.insert(MaskB);1879return submitCompositeNode(PlaceholderNode);1880}18811882static Value *replaceSymmetricNode(IRBuilderBase &B, unsigned Opcode,1883std::optional<FastMathFlags> Flags,1884Value *InputA, Value *InputB) {1885Value *I;1886switch (Opcode) {1887case Instruction::FNeg:1888I = B.CreateFNeg(InputA);1889break;1890case Instruction::FAdd:1891I = B.CreateFAdd(InputA, InputB);1892break;1893case Instruction::Add:1894I = B.CreateAdd(InputA, InputB);1895break;1896case Instruction::FSub:1897I = B.CreateFSub(InputA, InputB);1898break;1899case Instruction::Sub:1900I = B.CreateSub(InputA, InputB);1901break;1902case Instruction::FMul:1903I = B.CreateFMul(InputA, InputB);1904break;1905case Instruction::Mul:1906I = B.CreateMul(InputA, InputB);1907break;1908default:1909llvm_unreachable("Incorrect symmetric opcode");1910}1911if (Flags)1912cast<Instruction>(I)->setFastMathFlags(*Flags);1913return I;1914}19151916Value *ComplexDeinterleavingGraph::replaceNode(IRBuilderBase &Builder,1917RawNodePtr Node) {1918if (Node->ReplacementNode)1919return Node->ReplacementNode;19201921auto ReplaceOperandIfExist = [&](RawNodePtr &Node, unsigned Idx) -> Value * {1922return Node->Operands.size() > Idx1923? 
replaceNode(Builder, Node->Operands[Idx])1924: nullptr;1925};19261927Value *ReplacementNode;1928switch (Node->Operation) {1929case ComplexDeinterleavingOperation::CAdd:1930case ComplexDeinterleavingOperation::CMulPartial:1931case ComplexDeinterleavingOperation::Symmetric: {1932Value *Input0 = ReplaceOperandIfExist(Node, 0);1933Value *Input1 = ReplaceOperandIfExist(Node, 1);1934Value *Accumulator = ReplaceOperandIfExist(Node, 2);1935assert(!Input1 || (Input0->getType() == Input1->getType() &&1936"Node inputs need to be of the same type"));1937assert(!Accumulator ||1938(Input0->getType() == Accumulator->getType() &&1939"Accumulator and input need to be of the same type"));1940if (Node->Operation == ComplexDeinterleavingOperation::Symmetric)1941ReplacementNode = replaceSymmetricNode(Builder, Node->Opcode, Node->Flags,1942Input0, Input1);1943else1944ReplacementNode = TL->createComplexDeinterleavingIR(1945Builder, Node->Operation, Node->Rotation, Input0, Input1,1946Accumulator);1947break;1948}1949case ComplexDeinterleavingOperation::Deinterleave:1950llvm_unreachable("Deinterleave node should already have ReplacementNode");1951break;1952case ComplexDeinterleavingOperation::Splat: {1953auto *NewTy = VectorType::getDoubleElementsVectorType(1954cast<VectorType>(Node->Real->getType()));1955auto *R = dyn_cast<Instruction>(Node->Real);1956auto *I = dyn_cast<Instruction>(Node->Imag);1957if (R && I) {1958// Splats that are not constant are interleaved where they are located1959Instruction *InsertPoint = (I->comesBefore(R) ? 
R : I)->getNextNode();1960IRBuilder<> IRB(InsertPoint);1961ReplacementNode = IRB.CreateIntrinsic(Intrinsic::vector_interleave2,1962NewTy, {Node->Real, Node->Imag});1963} else {1964ReplacementNode = Builder.CreateIntrinsic(1965Intrinsic::vector_interleave2, NewTy, {Node->Real, Node->Imag});1966}1967break;1968}1969case ComplexDeinterleavingOperation::ReductionPHI: {1970// If Operation is ReductionPHI, a new empty PHINode is created.1971// It is filled later when the ReductionOperation is processed.1972auto *VTy = cast<VectorType>(Node->Real->getType());1973auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);1974auto *NewPHI = PHINode::Create(NewVTy, 0, "", BackEdge->getFirstNonPHIIt());1975OldToNewPHI[dyn_cast<PHINode>(Node->Real)] = NewPHI;1976ReplacementNode = NewPHI;1977break;1978}1979case ComplexDeinterleavingOperation::ReductionOperation:1980ReplacementNode = replaceNode(Builder, Node->Operands[0]);1981processReductionOperation(ReplacementNode, Node);1982break;1983case ComplexDeinterleavingOperation::ReductionSelect: {1984auto *MaskReal = cast<Instruction>(Node->Real)->getOperand(0);1985auto *MaskImag = cast<Instruction>(Node->Imag)->getOperand(0);1986auto *A = replaceNode(Builder, Node->Operands[0]);1987auto *B = replaceNode(Builder, Node->Operands[1]);1988auto *NewMaskTy = VectorType::getDoubleElementsVectorType(1989cast<VectorType>(MaskReal->getType()));1990auto *NewMask = Builder.CreateIntrinsic(Intrinsic::vector_interleave2,1991NewMaskTy, {MaskReal, MaskImag});1992ReplacementNode = Builder.CreateSelect(NewMask, A, B);1993break;1994}1995}19961997assert(ReplacementNode && "Target failed to create Intrinsic call.");1998NumComplexTransformations += 1;1999Node->ReplacementNode = ReplacementNode;2000return ReplacementNode;2001}20022003void ComplexDeinterleavingGraph::processReductionOperation(2004Value *OperationReplacement, RawNodePtr Node) {2005auto *Real = cast<Instruction>(Node->Real);2006auto *Imag = cast<Instruction>(Node->Imag);2007auto *OldPHIReal 
= ReductionInfo[Real].first;2008auto *OldPHIImag = ReductionInfo[Imag].first;2009auto *NewPHI = OldToNewPHI[OldPHIReal];20102011auto *VTy = cast<VectorType>(Real->getType());2012auto *NewVTy = VectorType::getDoubleElementsVectorType(VTy);20132014// We have to interleave initial origin values coming from IncomingBlock2015Value *InitReal = OldPHIReal->getIncomingValueForBlock(Incoming);2016Value *InitImag = OldPHIImag->getIncomingValueForBlock(Incoming);20172018IRBuilder<> Builder(Incoming->getTerminator());2019auto *NewInit = Builder.CreateIntrinsic(Intrinsic::vector_interleave2, NewVTy,2020{InitReal, InitImag});20212022NewPHI->addIncoming(NewInit, Incoming);2023NewPHI->addIncoming(OperationReplacement, BackEdge);20242025// Deinterleave complex vector outside of loop so that it can be finally2026// reduced2027auto *FinalReductionReal = ReductionInfo[Real].second;2028auto *FinalReductionImag = ReductionInfo[Imag].second;20292030Builder.SetInsertPoint(2031&*FinalReductionReal->getParent()->getFirstInsertionPt());2032auto *Deinterleave = Builder.CreateIntrinsic(Intrinsic::vector_deinterleave2,2033OperationReplacement->getType(),2034OperationReplacement);20352036auto *NewReal = Builder.CreateExtractValue(Deinterleave, (uint64_t)0);2037FinalReductionReal->replaceUsesOfWith(Real, NewReal);20382039Builder.SetInsertPoint(FinalReductionImag);2040auto *NewImag = Builder.CreateExtractValue(Deinterleave, 1);2041FinalReductionImag->replaceUsesOfWith(Imag, NewImag);2042}20432044void ComplexDeinterleavingGraph::replaceNodes() {2045SmallVector<Instruction *, 16> DeadInstrRoots;2046for (auto *RootInstruction : OrderedRoots) {2047// Check if this potential root went through check process and we can2048// deinterleave it2049if (!RootToNode.count(RootInstruction))2050continue;20512052IRBuilder<> Builder(RootInstruction);2053auto RootNode = RootToNode[RootInstruction];2054Value *R = replaceNode(Builder, RootNode.get());20552056if (RootNode->Operation 
==2057ComplexDeinterleavingOperation::ReductionOperation) {2058auto *RootReal = cast<Instruction>(RootNode->Real);2059auto *RootImag = cast<Instruction>(RootNode->Imag);2060ReductionInfo[RootReal].first->removeIncomingValue(BackEdge);2061ReductionInfo[RootImag].first->removeIncomingValue(BackEdge);2062DeadInstrRoots.push_back(cast<Instruction>(RootReal));2063DeadInstrRoots.push_back(cast<Instruction>(RootImag));2064} else {2065assert(R && "Unable to find replacement for RootInstruction");2066DeadInstrRoots.push_back(RootInstruction);2067RootInstruction->replaceAllUsesWith(R);2068}2069}20702071for (auto *I : DeadInstrRoots)2072RecursivelyDeleteTriviallyDeadInstructions(I, TLI);2073}207420752076