Path: blob/main/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64PBQPRegAlloc.cpp
213799 views
//===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7// This file contains the AArch64 / Cortex-A57 specific register allocation8// constraints for use by the PBQP register allocator.9//10// It is essentially a transcription of what is contained in11// AArch64A57FPLoadBalancing, which tries to use a balanced12// mix of odd and even D-registers when performing a critical sequence of13// independent, non-quadword FP/ASIMD floating-point multiply-accumulates.14//===----------------------------------------------------------------------===//1516#include "AArch64PBQPRegAlloc.h"17#include "AArch64InstrInfo.h"18#include "AArch64RegisterInfo.h"19#include "llvm/CodeGen/LiveIntervals.h"20#include "llvm/CodeGen/MachineBasicBlock.h"21#include "llvm/CodeGen/MachineFunction.h"22#include "llvm/CodeGen/RegAllocPBQP.h"23#include "llvm/Support/Debug.h"24#include "llvm/Support/ErrorHandling.h"25#include "llvm/Support/raw_ostream.h"2627#define DEBUG_TYPE "aarch64-pbqp"2829using namespace llvm;3031namespace {3233bool isOdd(unsigned reg) {34switch (reg) {35default:36llvm_unreachable("Register is not from the expected class !");37case AArch64::S1:38case AArch64::S3:39case AArch64::S5:40case AArch64::S7:41case AArch64::S9:42case AArch64::S11:43case AArch64::S13:44case AArch64::S15:45case AArch64::S17:46case AArch64::S19:47case AArch64::S21:48case AArch64::S23:49case AArch64::S25:50case AArch64::S27:51case AArch64::S29:52case AArch64::S31:53case AArch64::D1:54case AArch64::D3:55case AArch64::D5:56case AArch64::D7:57case AArch64::D9:58case AArch64::D11:59case AArch64::D13:60case AArch64::D15:61case AArch64::D17:62case AArch64::D19:63case AArch64::D21:64case AArch64::D23:65case AArch64::D25:66case AArch64::D27:67case AArch64::D29:68case AArch64::D31:69case AArch64::Q1:70case AArch64::Q3:71case AArch64::Q5:72case AArch64::Q7:73case AArch64::Q9:74case AArch64::Q11:75case AArch64::Q13:76case AArch64::Q15:77case AArch64::Q17:78case AArch64::Q19:79case AArch64::Q21:80case AArch64::Q23:81case AArch64::Q25:82case AArch64::Q27:83case AArch64::Q29:84case AArch64::Q31:85return true;86case AArch64::S0:87case AArch64::S2:88case AArch64::S4:89case AArch64::S6:90case AArch64::S8:91case AArch64::S10:92case AArch64::S12:93case AArch64::S14:94case AArch64::S16:95case AArch64::S18:96case AArch64::S20:97case AArch64::S22:98case AArch64::S24:99case AArch64::S26:100case AArch64::S28:101case AArch64::S30:102case AArch64::D0:103case AArch64::D2:104case AArch64::D4:105case AArch64::D6:106case AArch64::D8:107case AArch64::D10:108case AArch64::D12:109case AArch64::D14:110case AArch64::D16:111case AArch64::D18:112case AArch64::D20:113case AArch64::D22:114case AArch64::D24:115case AArch64::D26:116case AArch64::D28:117case AArch64::D30:118case AArch64::Q0:119case AArch64::Q2:120case AArch64::Q4:121case AArch64::Q6:122case AArch64::Q8:123case AArch64::Q10:124case AArch64::Q12:125case AArch64::Q14:126case AArch64::Q16:127case AArch64::Q18:128case AArch64::Q20:129case AArch64::Q22:130case AArch64::Q24:131case AArch64::Q26:132case AArch64::Q28:133case AArch64::Q30:134return false;135136}137}138139bool haveSameParity(unsigned reg1, unsigned reg2) {140assert(AArch64InstrInfo::isFpOrNEON(reg1) &&141"Expecting an FP register for reg1");142assert(AArch64InstrInfo::isFpOrNEON(reg2) &&143"Expecting an FP register for reg2");144145return isOdd(reg1) == isOdd(reg2);146}147148}149150bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,151unsigned Ra) {152if (Rd == Ra)153return false;154155LiveIntervals &LIs = G.getMetadata().LIS;156157if (Register::isPhysicalRegister(Rd) || Register::isPhysicalRegister(Ra)) {158LLVM_DEBUG(dbgs() << "Rd is a physical reg:"159<< Register::isPhysicalRegister(Rd) << '\n');160LLVM_DEBUG(dbgs() << "Ra is a physical reg:"161<< Register::isPhysicalRegister(Ra) << '\n');162return false;163}164165PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);166PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(Ra);167168const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =169&G.getNodeMetadata(node1).getAllowedRegs();170const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRaAllowed =171&G.getNodeMetadata(node2).getAllowedRegs();172173PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);174175// The edge does not exist. Create one with the appropriate interference176// costs.177if (edge == G.invalidEdgeId()) {178const LiveInterval &ld = LIs.getInterval(Rd);179const LiveInterval &la = LIs.getInterval(Ra);180bool livesOverlap = ld.overlaps(la);181182PBQPRAGraph::RawMatrix costs(vRdAllowed->size() + 1,183vRaAllowed->size() + 1, 0);184for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {185unsigned pRd = (*vRdAllowed)[i];186for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {187unsigned pRa = (*vRaAllowed)[j];188if (livesOverlap && TRI->regsOverlap(pRd, pRa))189costs[i + 1][j + 1] = std::numeric_limits<PBQP::PBQPNum>::infinity();190else191costs[i + 1][j + 1] = haveSameParity(pRd, pRa) ? 0.0 : 1.0;192}193}194G.addEdge(node1, node2, std::move(costs));195return true;196}197198if (G.getEdgeNode1Id(edge) == node2) {199std::swap(node1, node2);200std::swap(vRdAllowed, vRaAllowed);201}202203// Enforce minCost(sameParity(RaClass)) > maxCost(otherParity(RdClass))204PBQPRAGraph::RawMatrix costs(G.getEdgeCosts(edge));205for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {206unsigned pRd = (*vRdAllowed)[i];207208// Get the maximum cost (excluding unallocatable reg) for same parity209// registers210PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();211for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {212unsigned pRa = (*vRaAllowed)[j];213if (haveSameParity(pRd, pRa))214if (costs[i + 1][j + 1] !=215std::numeric_limits<PBQP::PBQPNum>::infinity() &&216costs[i + 1][j + 1] > sameParityMax)217sameParityMax = costs[i + 1][j + 1];218}219220// Ensure all registers with a different parity have a higher cost221// than sameParityMax222for (unsigned j = 0, je = vRaAllowed->size(); j != je; ++j) {223unsigned pRa = (*vRaAllowed)[j];224if (!haveSameParity(pRd, pRa))225if (sameParityMax > costs[i + 1][j + 1])226costs[i + 1][j + 1] = sameParityMax + 1.0;227}228}229G.updateEdgeCosts(edge, std::move(costs));230231return true;232}233234void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,235unsigned Ra) {236LiveIntervals &LIs = G.getMetadata().LIS;237238// Do some Chain management239if (Chains.count(Ra)) {240if (Rd != Ra) {241LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI)242<< " to " << printReg(Rd, TRI) << '\n');243Chains.remove(Ra);244Chains.insert(Rd);245}246} else {247LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI)248<< '\n');249Chains.insert(Rd);250}251252PBQPRAGraph::NodeId node1 = G.getMetadata().getNodeIdForVReg(Rd);253254const LiveInterval &ld = LIs.getInterval(Rd);255for (auto r : Chains) {256// Skip self257if (r == Rd)258continue;259260const LiveInterval &lr = LIs.getInterval(r);261if (ld.overlaps(lr)) {262const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRdAllowed =263&G.getNodeMetadata(node1).getAllowedRegs();264265PBQPRAGraph::NodeId node2 = G.getMetadata().getNodeIdForVReg(r);266const PBQPRAGraph::NodeMetadata::AllowedRegVector *vRrAllowed =267&G.getNodeMetadata(node2).getAllowedRegs();268269PBQPRAGraph::EdgeId edge = G.findEdge(node1, node2);270assert(edge != G.invalidEdgeId() &&271"PBQP error ! The edge should exist !");272273LLVM_DEBUG(dbgs() << "Refining constraint !\n");274275if (G.getEdgeNode1Id(edge) == node2) {276std::swap(node1, node2);277std::swap(vRdAllowed, vRrAllowed);278}279280// Enforce that cost is higher with all other Chains of the same parity281PBQP::Matrix costs(G.getEdgeCosts(edge));282for (unsigned i = 0, ie = vRdAllowed->size(); i != ie; ++i) {283unsigned pRd = (*vRdAllowed)[i];284285// Get the maximum cost (excluding unallocatable reg) for all other286// parity registers287PBQP::PBQPNum sameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();288for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {289unsigned pRa = (*vRrAllowed)[j];290if (!haveSameParity(pRd, pRa))291if (costs[i + 1][j + 1] !=292std::numeric_limits<PBQP::PBQPNum>::infinity() &&293costs[i + 1][j + 1] > sameParityMax)294sameParityMax = costs[i + 1][j + 1];295}296297// Ensure all registers with same parity have a higher cost298// than sameParityMax299for (unsigned j = 0, je = vRrAllowed->size(); j != je; ++j) {300unsigned pRa = (*vRrAllowed)[j];301if (haveSameParity(pRd, pRa))302if (sameParityMax > costs[i + 1][j + 1])303costs[i + 1][j + 1] = sameParityMax + 1.0;304}305}306G.updateEdgeCosts(edge, std::move(costs));307}308}309}310311static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,312const MachineInstr &MI) {313const LiveInterval &LI = LIs.getInterval(reg);314SlotIndex SI = LIs.getInstructionIndex(MI);315return LI.expiredAt(SI);316}317318void A57ChainingConstraint::apply(PBQPRAGraph &G) {319const MachineFunction &MF = G.getMetadata().MF;320LiveIntervals &LIs = G.getMetadata().LIS;321322TRI = MF.getSubtarget().getRegisterInfo();323LLVM_DEBUG(MF.dump());324325for (const auto &MBB: MF) {326Chains.clear(); // FIXME: really needed ? Could not work at MF level ?327328for (const auto &MI: MBB) {329330// Forget Chains which have expired331for (auto r : Chains) {332SmallVector<unsigned, 8> toDel;333if(regJustKilledBefore(LIs, r, MI)) {334LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at ";335MI.print(dbgs()));336toDel.push_back(r);337}338339while (!toDel.empty()) {340Chains.remove(toDel.back());341toDel.pop_back();342}343}344345switch (MI.getOpcode()) {346case AArch64::FMSUBSrrr:347case AArch64::FMADDSrrr:348case AArch64::FNMSUBSrrr:349case AArch64::FNMADDSrrr:350case AArch64::FMSUBDrrr:351case AArch64::FMADDDrrr:352case AArch64::FNMSUBDrrr:353case AArch64::FNMADDDrrr: {354Register Rd = MI.getOperand(0).getReg();355Register Ra = MI.getOperand(3).getReg();356357if (addIntraChainConstraint(G, Rd, Ra))358addInterChainConstraint(G, Rd, Ra);359break;360}361362case AArch64::FMLAv2f32:363case AArch64::FMLSv2f32: {364Register Rd = MI.getOperand(0).getReg();365addInterChainConstraint(G, Rd, Rd);366break;367}368369default:370break;371}372}373}374}375376377