Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/AArch64/AArch64PBQPRegAlloc.cpp
213799 views
1
//===-- AArch64PBQPRegAlloc.cpp - AArch64 specific PBQP constraints -------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
// This file contains the AArch64 / Cortex-A57 specific register allocation
9
// constraints for use by the PBQP register allocator.
10
//
11
// It is essentially a transcription of what is contained in
12
// AArch64A57FPLoadBalancing, which tries to use a balanced
13
// mix of odd and even D-registers when performing a critical sequence of
14
// independent, non-quadword FP/ASIMD floating-point multiply-accumulates.
15
//===----------------------------------------------------------------------===//
16
17
#include "AArch64PBQPRegAlloc.h"
18
#include "AArch64InstrInfo.h"
19
#include "AArch64RegisterInfo.h"
20
#include "llvm/CodeGen/LiveIntervals.h"
21
#include "llvm/CodeGen/MachineBasicBlock.h"
22
#include "llvm/CodeGen/MachineFunction.h"
23
#include "llvm/CodeGen/RegAllocPBQP.h"
24
#include "llvm/Support/Debug.h"
25
#include "llvm/Support/ErrorHandling.h"
26
#include "llvm/Support/raw_ostream.h"
27
28
#define DEBUG_TYPE "aarch64-pbqp"
29
30
using namespace llvm;
31
32
namespace {
33
34
bool isOdd(unsigned reg) {
35
switch (reg) {
36
default:
37
llvm_unreachable("Register is not from the expected class !");
38
case AArch64::S1:
39
case AArch64::S3:
40
case AArch64::S5:
41
case AArch64::S7:
42
case AArch64::S9:
43
case AArch64::S11:
44
case AArch64::S13:
45
case AArch64::S15:
46
case AArch64::S17:
47
case AArch64::S19:
48
case AArch64::S21:
49
case AArch64::S23:
50
case AArch64::S25:
51
case AArch64::S27:
52
case AArch64::S29:
53
case AArch64::S31:
54
case AArch64::D1:
55
case AArch64::D3:
56
case AArch64::D5:
57
case AArch64::D7:
58
case AArch64::D9:
59
case AArch64::D11:
60
case AArch64::D13:
61
case AArch64::D15:
62
case AArch64::D17:
63
case AArch64::D19:
64
case AArch64::D21:
65
case AArch64::D23:
66
case AArch64::D25:
67
case AArch64::D27:
68
case AArch64::D29:
69
case AArch64::D31:
70
case AArch64::Q1:
71
case AArch64::Q3:
72
case AArch64::Q5:
73
case AArch64::Q7:
74
case AArch64::Q9:
75
case AArch64::Q11:
76
case AArch64::Q13:
77
case AArch64::Q15:
78
case AArch64::Q17:
79
case AArch64::Q19:
80
case AArch64::Q21:
81
case AArch64::Q23:
82
case AArch64::Q25:
83
case AArch64::Q27:
84
case AArch64::Q29:
85
case AArch64::Q31:
86
return true;
87
case AArch64::S0:
88
case AArch64::S2:
89
case AArch64::S4:
90
case AArch64::S6:
91
case AArch64::S8:
92
case AArch64::S10:
93
case AArch64::S12:
94
case AArch64::S14:
95
case AArch64::S16:
96
case AArch64::S18:
97
case AArch64::S20:
98
case AArch64::S22:
99
case AArch64::S24:
100
case AArch64::S26:
101
case AArch64::S28:
102
case AArch64::S30:
103
case AArch64::D0:
104
case AArch64::D2:
105
case AArch64::D4:
106
case AArch64::D6:
107
case AArch64::D8:
108
case AArch64::D10:
109
case AArch64::D12:
110
case AArch64::D14:
111
case AArch64::D16:
112
case AArch64::D18:
113
case AArch64::D20:
114
case AArch64::D22:
115
case AArch64::D24:
116
case AArch64::D26:
117
case AArch64::D28:
118
case AArch64::D30:
119
case AArch64::Q0:
120
case AArch64::Q2:
121
case AArch64::Q4:
122
case AArch64::Q6:
123
case AArch64::Q8:
124
case AArch64::Q10:
125
case AArch64::Q12:
126
case AArch64::Q14:
127
case AArch64::Q16:
128
case AArch64::Q18:
129
case AArch64::Q20:
130
case AArch64::Q22:
131
case AArch64::Q24:
132
case AArch64::Q26:
133
case AArch64::Q28:
134
case AArch64::Q30:
135
return false;
136
137
}
138
}
139
140
bool haveSameParity(unsigned reg1, unsigned reg2) {
141
assert(AArch64InstrInfo::isFpOrNEON(reg1) &&
142
"Expecting an FP register for reg1");
143
assert(AArch64InstrInfo::isFpOrNEON(reg2) &&
144
"Expecting an FP register for reg2");
145
146
return isOdd(reg1) == isOdd(reg2);
147
}
148
149
}
150
151
/// Adds or refines the PBQP edge between the nodes of \p Rd and \p Ra so that
/// assigning them physical registers of different parity costs more than
/// assigning them registers of the same parity.
///
/// \returns false, leaving the graph untouched, when \p Rd and \p Ra are the
/// same register or when either of them is a physical register; true once an
/// edge has been created or its cost matrix updated.
bool A57ChainingConstraint::addIntraChainConstraint(PBQPRAGraph &G, unsigned Rd,
                                                    unsigned Ra) {
  using AllowedRegVector = PBQPRAGraph::NodeMetadata::AllowedRegVector;

  if (Rd == Ra)
    return false;

  const bool RdIsPhysical = Register::isPhysicalRegister(Rd);
  const bool RaIsPhysical = Register::isPhysicalRegister(Ra);
  if (RdIsPhysical || RaIsPhysical) {
    LLVM_DEBUG(dbgs() << "Rd is a physical reg:" << RdIsPhysical << '\n');
    LLVM_DEBUG(dbgs() << "Ra is a physical reg:" << RaIsPhysical << '\n');
    return false;
  }

  const PBQP::PBQPNum Infinity =
      std::numeric_limits<PBQP::PBQPNum>::infinity();

  const PBQPRAGraph::NodeId RdNode = G.getMetadata().getNodeIdForVReg(Rd);
  const PBQPRAGraph::NodeId RaNode = G.getMetadata().getNodeIdForVReg(Ra);
  const AllowedRegVector &RdRegs = G.getNodeMetadata(RdNode).getAllowedRegs();
  const AllowedRegVector &RaRegs = G.getNodeMetadata(RaNode).getAllowedRegs();

  const PBQPRAGraph::EdgeId Edge = G.findEdge(RdNode, RaNode);

  // No edge yet: build a fresh cost matrix. Row/column 0 is left at the
  // initial value 0; allowed register k lives at index k + 1.
  if (Edge == G.invalidEdgeId()) {
    LiveIntervals &LIs = G.getMetadata().LIS;
    const bool LiveRangesOverlap =
        LIs.getInterval(Rd).overlaps(LIs.getInterval(Ra));

    PBQPRAGraph::RawMatrix NewCosts(RdRegs.size() + 1, RaRegs.size() + 1, 0);
    for (unsigned Row = 0, NumRows = RdRegs.size(); Row != NumRows; ++Row) {
      const unsigned RdPhys = RdRegs[Row];
      for (unsigned Col = 0, NumCols = RaRegs.size(); Col != NumCols; ++Col) {
        const unsigned RaPhys = RaRegs[Col];
        auto &Cell = NewCosts[Row + 1][Col + 1];
        if (LiveRangesOverlap && TRI->regsOverlap(RdPhys, RaPhys)) {
          // Interfering assignment: forbidden.
          Cell = Infinity;
        } else if (haveSameParity(RdPhys, RaPhys)) {
          Cell = 0.0;
        } else {
          Cell = 1.0;
        }
      }
    }
    G.addEdge(RdNode, RaNode, std::move(NewCosts));
    return true;
  }

  // The edge already exists. Its matrix rows follow the edge's first node, so
  // pick the row/column register lists accordingly.
  const bool RaIsFirstNode = G.getEdgeNode1Id(Edge) == RaNode;
  const AllowedRegVector &RowRegs = RaIsFirstNode ? RaRegs : RdRegs;
  const AllowedRegVector &ColRegs = RaIsFirstNode ? RdRegs : RaRegs;

  PBQPRAGraph::RawMatrix Costs(G.getEdgeCosts(Edge));
  for (unsigned Row = 0, NumRows = RowRegs.size(); Row != NumRows; ++Row) {
    const unsigned RowPhys = RowRegs[Row];

    // Largest finite cost among the same-parity pairings of this row.
    // NOTE(review): numeric_limits<T>::min() is the smallest positive value
    // for floating-point T, not the most negative one -- presumably PBQPNum
    // is floating point; confirm whether lowest() was intended.
    PBQP::PBQPNum SameParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
    for (unsigned Col = 0, NumCols = ColRegs.size(); Col != NumCols; ++Col) {
      if (!haveSameParity(RowPhys, ColRegs[Col]))
        continue;
      const PBQP::PBQPNum Cost = Costs[Row + 1][Col + 1];
      if (Cost != Infinity && Cost > SameParityMax)
        SameParityMax = Cost;
    }

    // Raise every cheaper different-parity pairing above that maximum.
    for (unsigned Col = 0, NumCols = ColRegs.size(); Col != NumCols; ++Col) {
      if (haveSameParity(RowPhys, ColRegs[Col]))
        continue;
      auto &Cell = Costs[Row + 1][Col + 1];
      if (SameParityMax > Cell)
        Cell = SameParityMax + 1.0;
    }
  }
  G.updateEdgeCosts(Edge, std::move(Costs));

  return true;
}
234
235
/// Updates the set of live accumulator chains for an instruction that writes
/// \p Rd and accumulates into \p Ra, then penalises giving \p Rd a register of
/// the same parity as any other tracked chain whose live interval overlaps
/// with the one of \p Rd.
///
/// If \p Ra is already tracked in Chains and differs from \p Rd, the chain is
/// moved from \p Ra to \p Rd; if \p Ra is not tracked, a new chain is started
/// at \p Rd. For every other overlapping chain the existing PBQP edge (asserted
/// to exist) gets its same-parity costs raised above the largest finite
/// different-parity cost of the same row.
void A57ChainingConstraint::addInterChainConstraint(PBQPRAGraph &G, unsigned Rd,
                                                    unsigned Ra) {
  using AllowedRegVector = PBQPRAGraph::NodeMetadata::AllowedRegVector;

  LiveIntervals &LIs = G.getMetadata().LIS;

  // Do some Chain management
  if (Chains.count(Ra)) {
    if (Rd != Ra) {
      LLVM_DEBUG(dbgs() << "Moving acc chain from " << printReg(Ra, TRI)
                        << " to " << printReg(Rd, TRI) << '\n');
      Chains.remove(Ra);
      Chains.insert(Rd);
    }
  } else {
    LLVM_DEBUG(dbgs() << "Creating new acc chain for " << printReg(Rd, TRI)
                      << '\n');
    Chains.insert(Rd);
  }

  // The node of Rd must stay fixed for the whole loop below: it is used in
  // every iteration to look up the allowed registers and the edge. It is
  // therefore const and never swapped with the other chain's node.
  const PBQPRAGraph::NodeId RdNode = G.getMetadata().getNodeIdForVReg(Rd);

  const LiveInterval &ld = LIs.getInterval(Rd);
  for (auto r : Chains) {
    // Skip self
    if (r == Rd)
      continue;

    // Only chains that are live at the same time as Rd compete with it.
    const LiveInterval &lr = LIs.getInterval(r);
    if (!ld.overlaps(lr))
      continue;

    const AllowedRegVector *RowRegs =
        &G.getNodeMetadata(RdNode).getAllowedRegs();

    const PBQPRAGraph::NodeId ChainNode = G.getMetadata().getNodeIdForVReg(r);
    const AllowedRegVector *ColRegs =
        &G.getNodeMetadata(ChainNode).getAllowedRegs();

    PBQPRAGraph::EdgeId edge = G.findEdge(RdNode, ChainNode);
    assert(edge != G.invalidEdgeId() &&
           "PBQP error ! The edge should exist !");

    LLVM_DEBUG(dbgs() << "Refining constraint !\n");

    // The matrix rows follow the edge's first node. Only the register views
    // are swapped; the node ids are not needed past this point.
    if (G.getEdgeNode1Id(edge) == ChainNode)
      std::swap(RowRegs, ColRegs);

    // Enforce that cost is higher with all other Chains of the same parity
    PBQP::Matrix costs(G.getEdgeCosts(edge));
    for (unsigned i = 0, ie = RowRegs->size(); i != ie; ++i) {
      unsigned RowPhys = (*RowRegs)[i];

      // Largest finite cost among the different-parity pairings of this row.
      // NOTE(review): numeric_limits<T>::min() is the smallest positive value
      // for floating-point T, not the most negative one -- presumably PBQPNum
      // is floating point; confirm whether lowest() was intended.
      PBQP::PBQPNum OtherParityMax = std::numeric_limits<PBQP::PBQPNum>::min();
      for (unsigned j = 0, je = ColRegs->size(); j != je; ++j) {
        unsigned ColPhys = (*ColRegs)[j];
        if (!haveSameParity(RowPhys, ColPhys))
          if (costs[i + 1][j + 1] !=
                  std::numeric_limits<PBQP::PBQPNum>::infinity() &&
              costs[i + 1][j + 1] > OtherParityMax)
            OtherParityMax = costs[i + 1][j + 1];
      }

      // Raise every cheaper same-parity pairing above that maximum.
      for (unsigned j = 0, je = ColRegs->size(); j != je; ++j) {
        unsigned ColPhys = (*ColRegs)[j];
        if (haveSameParity(RowPhys, ColPhys))
          if (OtherParityMax > costs[i + 1][j + 1])
            costs[i + 1][j + 1] = OtherParityMax + 1.0;
      }
    }
    G.updateEdgeCosts(edge, std::move(costs));
  }
}
311
312
static bool regJustKilledBefore(const LiveIntervals &LIs, unsigned reg,
313
const MachineInstr &MI) {
314
const LiveInterval &LI = LIs.getInterval(reg);
315
SlotIndex SI = LIs.getInstructionIndex(MI);
316
return LI.expiredAt(SI);
317
}
318
319
/// Walks every instruction of the machine function attached to \p G and adds
/// the chaining constraints to the PBQP graph.
///
/// The set of tracked chains is reset at the start of each basic block. Before
/// an instruction is examined, chains whose register has expired at that
/// instruction are dropped. The scalar FMADD/FMSUB/FNMADD/FNMSUB forms (S and
/// D) get an intra-chain constraint between operand 0 and operand 3 and, when
/// that succeeds, an inter-chain constraint. FMLAv2f32/FMLSv2f32 only get an
/// inter-chain constraint on operand 0.
void A57ChainingConstraint::apply(PBQPRAGraph &G) {
  const MachineFunction &MF = G.getMetadata().MF;
  LiveIntervals &LIs = G.getMetadata().LIS;

  TRI = MF.getSubtarget().getRegisterInfo();
  LLVM_DEBUG(MF.dump());

  for (const auto &MBB: MF) {
    Chains.clear(); // FIXME: really needed ? Could not work at MF level ?

    for (const auto &MI: MBB) {

      // Forget Chains which have expired. The expired registers are collected
      // first and removed only once the scan is over, so that Chains is never
      // modified while it is being iterated.
      SmallVector<unsigned, 8> toDel;
      for (auto r : Chains) {
        if (regJustKilledBefore(LIs, r, MI)) {
          LLVM_DEBUG(dbgs() << "Killing chain " << printReg(r, TRI) << " at ";
                     MI.print(dbgs()));
          toDel.push_back(r);
        }
      }

      while (!toDel.empty()) {
        Chains.remove(toDel.back());
        toDel.pop_back();
      }

      switch (MI.getOpcode()) {
      case AArch64::FMSUBSrrr:
      case AArch64::FMADDSrrr:
      case AArch64::FNMSUBSrrr:
      case AArch64::FNMADDSrrr:
      case AArch64::FMSUBDrrr:
      case AArch64::FMADDDrrr:
      case AArch64::FNMSUBDrrr:
      case AArch64::FNMADDDrrr: {
        // Operand 0 is the destination, operand 3 the accumulator input.
        Register Rd = MI.getOperand(0).getReg();
        Register Ra = MI.getOperand(3).getReg();

        if (addIntraChainConstraint(G, Rd, Ra))
          addInterChainConstraint(G, Rd, Ra);
        break;
      }

      case AArch64::FMLAv2f32:
      case AArch64::FMLSv2f32: {
        // Only operand 0 is read here; it is passed as both Rd and Ra.
        Register Rd = MI.getOperand(0).getReg();
        addInterChainConstraint(G, Rd, Rd);
        break;
      }

      default:
        break;
      }
    }
  }
}
376
377