Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/VE/VVPISelLowering.cpp
35294 views
1
//===-- VVPISelLowering.cpp - VE DAG Lowering Implementation --------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file implements the lowering and legalization of vector instructions to
10
// VVP_* layer SDNodes.
11
//
12
//===----------------------------------------------------------------------===//
13
14
#include "VECustomDAG.h"
15
#include "VEISelLowering.h"
16
17
using namespace llvm;
18
19
#define DEBUG_TYPE "ve-lower"
20
21
/// Split a two-operand operation producing a v512i1 mask into two operations
/// on v256i1 halves (one for the lower, one for the upper half of each
/// operand) and pack the two partial results back into a v512i1 value.
SDValue VETargetLowering::splitMaskArithmetic(SDValue Op,
                                              SelectionDAG &DAG) const {
  VECustomDAG CDAG(DAG, Op);
  // The AVL spans every element of the full (packed) operation.
  SDValue FullAVL =
      CDAG.getConstant(Op.getValueType().getVectorNumElements(), MVT::i32);

  // Extract one 256-element half of a packed mask operand.
  auto UnpackHalf = [&](SDValue PackedV, PackElem Half) -> SDValue {
    return CDAG.getUnpack(MVT::v256i1, PackedV, Half, FullAVL);
  };

  SDValue LHS = Op->getOperand(0);
  SDValue RHS = Op->getOperand(1);
  SDValue LHSLo = UnpackHalf(LHS, PackElem::Lo);
  SDValue LHSHi = UnpackHalf(LHS, PackElem::Hi);
  SDValue RHSLo = UnpackHalf(RHS, PackElem::Lo);
  SDValue RHSHi = UnpackHalf(RHS, PackElem::Hi);

  // Apply the original opcode independently to each half.
  const unsigned Opcode = Op.getOpcode();
  SDValue ResLo = CDAG.getNode(Opcode, MVT::v256i1, {LHSLo, RHSLo});
  SDValue ResHi = CDAG.getNode(Opcode, MVT::v256i1, {LHSHi, RHSHi});
  return CDAG.getPack(MVT::v512i1, ResLo, ResHi, FullAVL);
}
37
38
/// Lower \p Op to its VVP_* counterpart.
///
/// Returns an empty SDValue when \p Op has no VVP equivalent. Memory
/// operations are delegated to lowerVVP_LOAD_STORE and
/// lowerVVP_GATHER_SCATTER. For every other operation the mask and AVL are
/// taken from the VP operands when \p Op is a VP node whose opcode defines
/// them; anything missing defaults to an all-true mask and an AVL equal to the
/// element count of the operation's idiomatic vector type.
SDValue VETargetLowering::lowerToVVP(SDValue Op, SelectionDAG &DAG) const {
  // Can we represent this as a VVP node?
  const unsigned Opcode = Op->getOpcode();
  auto VVPOpcodeOpt = getVVPOpcode(Opcode);
  if (!VVPOpcodeOpt)
    return SDValue();
  unsigned VVPOpcode = *VVPOpcodeOpt;
  const bool FromVP = ISD::isVPOpcode(Opcode);

  VECustomDAG CDAG(DAG, Op);
  // Dispatch to complex lowering functions.
  switch (VVPOpcode) {
  case VEISD::VVP_LOAD:
  case VEISD::VVP_STORE:
    return lowerVVP_LOAD_STORE(Op, CDAG);
  case VEISD::VVP_GATHER:
  case VEISD::VVP_SCATTER:
    return lowerVVP_GATHER_SCATTER(Op, CDAG);
  }

  // The representative and legalized vector type of this operation.
  EVT OpVecVT = *getIdiomaticVectorType(Op.getNode());
  EVT LegalVecVT = getTypeToTransformTo(*DAG.getContext(), OpVecVT);
  auto Packing = getTypePacking(LegalVecVT.getSimpleVT());

  SDValue AVL;
  SDValue Mask;

  if (FromVP) {
    // Reuse the VP node's own mask and AVL. Both positions are optional: a VP
    // opcode without a mask (or AVL) operand falls back to the defaults below.
    auto MaskIdx = ISD::getVPMaskIdx(Opcode);
    auto AVLIdx = ISD::getVPExplicitVectorLengthIdx(Opcode);
    if (MaskIdx)
      Mask = Op->getOperand(*MaskIdx);
    if (AVLIdx)
      AVL = Op->getOperand(*AVLIdx);
  }

  // Materialize default mask and avl.
  if (!AVL)
    AVL = CDAG.getConstant(OpVecVT.getVectorNumElements(), MVT::i32);
  if (!Mask)
    Mask = CDAG.getConstantMask(Packing, true);

  assert(LegalVecVT.isSimple());
  if (isVVPUnaryOp(VVPOpcode))
    return CDAG.getNode(VVPOpcode, LegalVecVT, {Op->getOperand(0), Mask, AVL});
  if (isVVPBinaryOp(VVPOpcode))
    return CDAG.getNode(VVPOpcode, LegalVecVT,
                        {Op->getOperand(0), Op->getOperand(1), Mask, AVL});
  if (isVVPReductionOp(VVPOpcode)) {
    // Some reductions carry a start value as operand 0, shifting the vector
    // operand to position 1.
    auto SrcHasStart = hasReductionStartParam(Op->getOpcode());
    SDValue StartV = SrcHasStart ? Op->getOperand(0) : SDValue();
    SDValue VectorV = Op->getOperand(SrcHasStart ? 1 : 0);
    return CDAG.getLegalReductionOpVVP(VVPOpcode, Op.getValueType(), StartV,
                                       VectorV, Mask, AVL, Op->getFlags());
  }

  switch (VVPOpcode) {
  default:
    llvm_unreachable("lowerToVVP called for unexpected SDNode.");
  case VEISD::VVP_FFMA: {
    // VE has a swizzled operand order in FMA (compared to LLVM IR and
    // SDNodes).
    auto X = Op->getOperand(2);
    auto Y = Op->getOperand(0);
    auto Z = Op->getOperand(1);
    return CDAG.getNode(VVPOpcode, LegalVecVT, {X, Y, Z, Mask, AVL});
  }
  case VEISD::VVP_SELECT: {
    // The selection condition (operand 0) occupies VVP_SELECT's mask slot; the
    // predication mask computed above is not used here. Named `Cond` so it
    // does not shadow `Mask`.
    auto Cond = Op->getOperand(0);
    auto OnTrue = Op->getOperand(1);
    auto OnFalse = Op->getOperand(2);
    return CDAG.getNode(VVPOpcode, LegalVecVT, {OnTrue, OnFalse, Cond, AVL});
  }
  case VEISD::VVP_SETCC: {
    // The result (mask) type is legalized separately from the operand type.
    EVT LegalResVT = getTypeToTransformTo(*DAG.getContext(), Op.getValueType());
    auto LHS = Op->getOperand(0);
    auto RHS = Op->getOperand(1);
    auto Pred = Op->getOperand(2);
    return CDAG.getNode(VVPOpcode, LegalResVT, {LHS, RHS, Pred, Mask, AVL});
  }
  }
}
122
123
/// Lower a vector load or store to VVP_LOAD / VVP_STORE.
///
/// A missing AVL defaults to the element count of the data type, and a missing
/// mask defaults to all-true. A load with a non-undef pass-through value is
/// lowered to VVP_LOAD followed by a VVP_SELECT. The select takes the loaded
/// value as its "true" operand and the pass-through as its "false" operand,
/// with the load mask as the condition. The result is merged with the load's
/// output chain.
///
/// A store whose data operand type is not yet legal is not lowered: an empty
/// SDValue is returned so the store is revisited once its operands are legal.
SDValue VETargetLowering::lowerVVP_LOAD_STORE(SDValue Op,
                                              VECustomDAG &CDAG) const {
  auto VVPOpc = *getVVPOpcode(Op->getOpcode());
  const bool IsLoad = (VVPOpc == VEISD::VVP_LOAD);

  // Shares.
  SDValue BasePtr = getMemoryPtr(Op);
  SDValue Mask = getNodeMask(Op);
  SDValue Chain = getNodeChain(Op);
  SDValue AVL = getNodeAVL(Op);
  // Store specific.
  SDValue Data = getStoredValue(Op);
  // Load specific.
  SDValue PassThru = getNodePassthru(Op);

  SDValue StrideV = getLoadStoreStride(Op, CDAG);

  auto DataVT = *getIdiomaticVectorType(Op.getNode());
  auto Packing = getTypePacking(DataVT);

  // TODO: Infer lower AVL from mask.
  if (!AVL)
    AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);

  // Default to the all-true mask.
  if (!Mask)
    Mask = CDAG.getConstantMask(Packing, true);

  if (IsLoad) {
    MVT LegalDataVT = getLegalVectorType(
        Packing, DataVT.getVectorElementType().getSimpleVT());

    auto NewLoadV = CDAG.getNode(VEISD::VVP_LOAD, {LegalDataVT, MVT::Other},
                                 {Chain, BasePtr, StrideV, Mask, AVL});

    if (!PassThru || PassThru->isUndef())
      return NewLoadV;

    // Convert passthru to an explicit select node. The select consumes the
    // load result, so it is given the same (legal) type as that result. This
    // matches the pass-through handling in lowerVVP_GATHER_SCATTER.
    SDValue DataV = CDAG.getNode(VEISD::VVP_SELECT, LegalDataVT,
                                 {NewLoadV, PassThru, Mask, AVL});
    SDValue NewLoadChainV = SDValue(NewLoadV.getNode(), 1);

    // Merge them back into one node.
    return CDAG.getMergeValues({DataV, NewLoadChainV});
  }

  // VVP_STORE
  assert(VVPOpc == VEISD::VVP_STORE);
  if (getTypeAction(*CDAG.getDAG()->getContext(), Data.getValueType()) !=
      TargetLowering::TypeLegal)
    // Doesn't lower store instruction if an operand is not lowered yet.
    // If it isn't, return SDValue(). In this way, LLVM will try to lower
    // store instruction again after lowering all operands.
    return SDValue();
  return CDAG.getNode(VEISD::VVP_STORE, Op.getNode()->getVTList(),
                      {Chain, Data, BasePtr, StrideV, Mask, AVL});
}
181
182
/// Split a densely packed VVP_LOAD or VVP_STORE into one operation on the
/// upper half and one on the lower half.
///
/// The output chains of both halves are joined with a TokenFactor. A store
/// returns only that joint chain. A load packs the two partial results back
/// into the full packed vector type and returns them merged with the joint
/// chain.
SDValue VETargetLowering::splitPackedLoadStore(SDValue Op,
                                               VECustomDAG &CDAG) const {
  auto VVPOC = *getVVPOpcode(Op.getOpcode());
  assert((VVPOC == VEISD::VVP_LOAD) || (VVPOC == VEISD::VVP_STORE));

  MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
  assert(getTypePacking(DataVT) == Packing::Dense &&
         "Can only split packed load/store");
  MVT SplitDataVT = splitVectorType(DataVT);

  assert(!getNodePassthru(Op) &&
         "Should have been folded in lowering to VVP layer");

  // Operands of the packed operation.
  SDValue PackedMask = getNodeMask(Op);
  SDValue PackedAVL = getAnnotatedNodeAVL(Op).first;
  SDValue PackPtr = getMemoryPtr(Op);
  SDValue PackData = getStoredValue(Op);
  SDValue PackStride = getLoadStoreStride(Op, CDAG);
  SDValue InChain = getNodeChain(Op);

  // Only the store has a data operand. The store's sole result is its chain;
  // the load yields (value, chain).
  const bool IsStore = static_cast<bool>(PackData);
  const unsigned ChainResIdx = IsStore ? 0 : 1;

  SDValue PartOps[2];
  // AVL of the upper half, needed to pack the load results back together.
  SDValue UpperPartAVL;

  for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
    // Derive this half's mask and AVL from the packed ones.
    auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);
    if (Part == PackElem::Hi)
      UpperPartAVL = SplitTM.AVL;

    SmallVector<SDValue, 6> PartOperands;
    PartOperands.push_back(InChain);
    if (IsStore)
      PartOperands.push_back(
          CDAG.getUnpack(SplitDataVT, PackData, Part, SplitTM.AVL));
    // Address of this half plus its stride:
    // (ptr + ElemBytes * <Part>, 2 * ElemBytes).
    PartOperands.push_back(CDAG.getSplitPtrOffset(PackPtr, PackStride, Part));
    PartOperands.push_back(CDAG.getSplitPtrStride(PackStride));
    // Predication operands.
    PartOperands.push_back(SplitTM.Mask);
    PartOperands.push_back(SplitTM.AVL);

    PartOps[static_cast<int>(Part)] =
        IsStore ? CDAG.getNode(VVPOC, MVT::Other, PartOperands)
                : CDAG.getNode(VVPOC, {SplitDataVT, MVT::Other}, PartOperands);
  }

  // Join the memory chains of both halves.
  SDValue LoChain(PartOps[static_cast<int>(PackElem::Lo)].getNode(),
                  ChainResIdx);
  SDValue HiChain(PartOps[static_cast<int>(PackElem::Hi)].getNode(),
                  ChainResIdx);
  SDValue FusedChains =
      CDAG.getNode(ISD::TokenFactor, MVT::Other, {LoChain, HiChain});

  if (IsStore)
    return FusedChains;

  // Re-pack the two loaded halves into the full packed vector.
  MVT PackedVT =
      getLegalVectorType(Packing::Dense, DataVT.getVectorElementType());
  SDValue PackedVals =
      CDAG.getPack(PackedVT, PartOps[static_cast<int>(PackElem::Lo)],
                   PartOps[static_cast<int>(PackElem::Hi)], UpperPartAVL);
  return CDAG.getMergeValues({PackedVals, FusedChains});
}
268
269
/// Lower a gather or scatter to VVP_GATHER / VVP_SCATTER.
///
/// The base pointer, scale and index are combined into a vector of addresses
/// first. A missing AVL defaults to the element count and a missing mask to
/// all-true. A gather with a non-undef pass-through value is followed by a
/// VVP_SELECT that blends the pass-through into the masked-off lanes, and the
/// result is merged with the gather's output chain.
SDValue VETargetLowering::lowerVVP_GATHER_SCATTER(SDValue Op,
                                                  VECustomDAG &CDAG) const {
  EVT DataVT = *getIdiomaticVectorType(Op.getNode());
  auto Packing = getTypePacking(DataVT);
  MVT LegalDataVT =
      getLegalVectorType(Packing, DataVT.getVectorElementType().getSimpleVT());

  SDValue AVL = getAnnotatedNodeAVL(Op).first;
  SDValue Index = getGatherScatterIndex(Op);
  SDValue BasePtr = getMemoryPtr(Op);
  SDValue Mask = getNodeMask(Op);
  SDValue Chain = getNodeChain(Op);
  SDValue Scale = getGatherScatterScale(Op);
  SDValue PassThru = getNodePassthru(Op);
  SDValue StoredValue = getStoredValue(Op);

  // An undef pass-through places no requirement on the masked-off lanes, so it
  // is treated the same as having none.
  const bool HasPassThru = PassThru && !PassThru->isUndef();

  // TODO: Infer lower AVL from mask.
  if (!AVL)
    AVL = CDAG.getConstant(DataVT.getVectorNumElements(), MVT::i32);

  // Default to the all-true mask.
  if (!Mask)
    Mask = CDAG.getConstantMask(Packing, true);

  SDValue AddressVec =
      CDAG.getGatherScatterAddress(BasePtr, Scale, Index, Mask, AVL);

  // Only a scatter has a stored value.
  if (StoredValue)
    return CDAG.getNode(VEISD::VVP_SCATTER, MVT::Other,
                        {Chain, StoredValue, AddressVec, Mask, AVL});

  SDValue GatherV = CDAG.getNode(VEISD::VVP_GATHER, {LegalDataVT, MVT::Other},
                                 {Chain, AddressVec, Mask, AVL});
  if (!HasPassThru)
    return GatherV;

  // Gathered lanes where Mask is set, pass-through lanes elsewhere; then
  // re-attach the gather's output chain.
  SDValue BlendedV = CDAG.getNode(VEISD::VVP_SELECT, LegalDataVT,
                                  {GatherV, PassThru, Mask, AVL});
  SDValue GatherChainV(GatherV.getNode(), 1);
  return CDAG.getMergeValues({BlendedV, GatherChainV});
}
316
317
/// Legalize an internal VVP_LOAD / VVP_STORE. Operations on packed vector
/// types are split into two halves; all others only get their AVL legalized.
SDValue VETargetLowering::legalizeInternalLoadStoreOp(SDValue Op,
                                                      VECustomDAG &CDAG) const {
  LLVM_DEBUG(dbgs() << "::legalizeInternalLoadStoreOp\n";);
  const MVT DataVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();

  // TODO: Recognize packable load,store.
  return isPackedVectorType(DataVT) ? splitPackedLoadStore(Op, CDAG)
                                    : legalizePackedAVL(Op, CDAG);
}
328
329
/// Legalize an internal VVP/VEC vector operation.
///
/// Loads and stores use their dedicated path. Packed operations that have no
/// packed-mode support are split into two halves. Everything else only needs
/// its AVL legalized.
SDValue VETargetLowering::legalizeInternalVectorOp(SDValue Op,
                                                   SelectionDAG &DAG) const {
  LLVM_DEBUG(dbgs() << "::legalizeInternalVectorOp\n";);
  VECustomDAG CDAG(DAG, Op);

  const unsigned Opc = Op->getOpcode();
  if (Opc == VEISD::VVP_LOAD || Opc == VEISD::VVP_STORE)
    return legalizeInternalLoadStoreOp(Op, CDAG);

  EVT IdiomVT = Op.getValueType();
  const bool NeedsSplit = isPackedVectorType(IdiomVT) &&
                          !supportsPackedMode(Op.getOpcode(), IdiomVT);
  if (NeedsSplit)
    return splitVectorOp(Op, CDAG);

  // TODO: Implement odd/even splitting.
  return legalizePackedAVL(Op, CDAG);
}
349
350
/// Split a packed VVP operation into two operations on the upper and lower
/// halves and pack the two results back into the original result type.
///
/// Vector value operands are unpacked into the half each part works on.
/// Non-vector operands cannot be unpacked and are forwarded unchanged to both
/// halves; an example is the predicate operand that lowerToVVP attaches to
/// VVP_SETCC. The mask and AVL operands are replaced by the per-half mask and
/// AVL from getTargetSplitMask.
SDValue VETargetLowering::splitVectorOp(SDValue Op, VECustomDAG &CDAG) const {
  MVT ResVT = splitVectorType(Op.getValue(0).getSimpleValueType());

  auto AVLPos = getAVLPos(Op->getOpcode());
  auto MaskPos = getMaskPos(Op->getOpcode());

  SDValue PackedMask = getNodeMask(Op);
  auto AVLPair = getAnnotatedNodeAVL(Op);
  SDValue PackedAVL = AVLPair.first;
  assert(!AVLPair.second && "Expecting non pack-legalized oepration");

  // request the parts
  SDValue PartOps[2];

  SDValue UpperPartAVL; // we will use this for packing things back together
  for (PackElem Part : {PackElem::Hi, PackElem::Lo}) {
    // Derive this half's mask and AVL from the packed ones.
    auto SplitTM = CDAG.getTargetSplitMask(PackedMask, PackedAVL, Part);

    if (Part == PackElem::Hi)
      UpperPartAVL = SplitTM.AVL;

    // Attach non-predicating value operands
    SmallVector<SDValue, 4> OpVec;
    for (unsigned i = 0, e = Op.getNumOperands(); i != e; ++i) {
      if (AVLPos && ((int)i) == *AVLPos)
        continue;
      if (MaskPos && ((int)i) == *MaskPos)
        continue;

      SDValue PackedOperand = Op.getOperand(i);

      // Non-vector operands have no halves to extract; both parts share them
      // as-is instead of receiving a meaningless unpack node.
      if (!PackedOperand.getValueType().isVector()) {
        OpVec.push_back(PackedOperand);
        continue;
      }

      // Vector value operand: extract the half this part operates on.
      MVT UnpackedOpVT = splitVectorType(PackedOperand.getSimpleValueType());
      SDValue PartV =
          CDAG.getUnpack(UnpackedOpVT, PackedOperand, Part, SplitTM.AVL);
      OpVec.push_back(PartV);
    }

    // Add predicating args and generate part node.
    OpVec.push_back(SplitTM.Mask);
    OpVec.push_back(SplitTM.AVL);
    // Emit legal VVP nodes.
    PartOps[(int)Part] =
        CDAG.getNode(Op.getOpcode(), ResVT, OpVec, Op->getFlags());
  }

  // Re-package vectors.
  return CDAG.getPack(Op.getValueType(), PartOps[(int)PackElem::Lo],
                      PartOps[(int)PackElem::Hi], UpperPartAVL);
}
401
402
/// Rewrite the AVL operand of a VVP/VEC node into its legal, annotated form.
///
/// Nodes that are not VVP/VEC ops, or whose AVL is already legal, are returned
/// unchanged. For packed vector types the AVL is first halved with rounding
/// up, i.e. ceil(AVL / 2). The node is then rebuilt with the annotated AVL in
/// the AVL slot, keeping all other operands, its value types and its flags.
SDValue VETargetLowering::legalizePackedAVL(SDValue Op,
                                            VECustomDAG &CDAG) const {
  LLVM_DEBUG(dbgs() << "::legalizePackedAVL\n";);
  // Only required for VEC and VVP ops.
  if (!isVVPOrVEC(Op->getOpcode()))
    return Op;

  // Nothing to do if the AVL is already legal.
  SDValue AVL = getNodeAVL(Op);
  if (isLegalAVL(AVL))
    return Op;

  SDValue LegalAVL = AVL;
  MVT IdiomVT = getIdiomaticVectorType(Op.getNode())->getSimpleVT();
  if (isPackedVectorType(IdiomVT)) {
    assert(maySafelyIgnoreMask(Op) &&
           "TODO Shift predication from EVL into Mask");

    // ceil(AVL / 2): folded for constants, otherwise emitted as (AVL + 1) >> 1.
    if (const auto *ConstAVL = dyn_cast<ConstantSDNode>(AVL)) {
      LegalAVL =
          CDAG.getConstant((ConstAVL->getZExtValue() + 1) / 2, MVT::i32);
    } else {
      SDValue One = CDAG.getConstant(1, MVT::i32);
      SDValue AVLPlusOne = CDAG.getNode(ISD::ADD, MVT::i32, {AVL, One});
      LegalAVL = CDAG.getNode(ISD::SRL, MVT::i32, {AVLPlusOne, One});
    }
  }

  SDValue AnnotatedLegalAVL = CDAG.annotateLegalAVL(LegalAVL);

  // Same operand list, except the AVL slot now holds the annotated AVL.
  auto AVLPos = getAVLPos(Op->getOpcode());
  const unsigned NumOps = Op->getNumOperands();
  SmallVector<SDValue, 8> FixedOperands;
  FixedOperands.reserve(NumOps);
  for (unsigned Idx = 0; Idx != NumOps; ++Idx) {
    const bool IsAVLSlot = AVLPos && static_cast<int>(Idx) == *AVLPos;
    FixedOperands.push_back(IsAVLSlot ? AnnotatedLegalAVL
                                      : Op->getOperand(Idx));
  }

  // Clone the operation with the fixed operands.
  return CDAG.getNode(Op->getOpcode(), Op->getVTList(), FixedOperands,
                      Op->getFlags());
}
450
451