Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/VE/VECustomDAG.cpp
35294 views
1
//===-- VECustomDAG.h - VE Custom DAG Nodes ------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file defines the interfaces that VE uses to lower LLVM code into a
10
// selection DAG.
11
//
12
//===----------------------------------------------------------------------===//
13
14
#include "VECustomDAG.h"
15
16
#ifndef DEBUG_TYPE
17
#define DEBUG_TYPE "vecustomdag"
18
#endif
19
20
namespace llvm {
21
22
/// \returns true if \p SomeVT is a vector type with more elements than fit a
/// standard-width VE vector register (i.e. it needs packed mode).
bool isPackedVectorType(EVT SomeVT) {
  return SomeVT.isVector() &&
         SomeVT.getVectorNumElements() > StandardVectorWidth;
}
27
28
/// Reduce \p VT to a vector type with StandardVectorWidth elements of the
/// same element type. Non-vector types pass through unchanged.
MVT splitVectorType(MVT VT) {
  return VT.isVector()
             ? MVT::getVectorVT(VT.getVectorElementType(), StandardVectorWidth)
             : VT;
}
33
34
/// Build the legal vector MVT for element type \p ElemVT at packing \p P:
/// StandardVectorWidth elements for Normal, PackedVectorWidth for packed.
MVT getLegalVectorType(Packing P, MVT ElemVT) {
  const unsigned NumElems =
      (P == Packing::Normal) ? StandardVectorWidth : PackedVectorWidth;
  return MVT::getVectorVT(ElemVT, NumElems);
}
38
39
/// Classify vector type \p VT as Dense (packed) or Normal.
Packing getTypePacking(EVT VT) {
  assert(VT.isVector());
  if (isPackedVectorType(VT))
    return Packing::Dense;
  return Packing::Normal;
}
43
44
/// \returns true if \p SomeVT is a vector of i1, i.e. a mask type.
bool isMaskType(EVT SomeVT) {
  return SomeVT.isVector() && SomeVT.getVectorElementType() == MVT::i1;
}
49
50
/// \returns true if \p Op is a plain logical op (AND/OR/XOR) whose result
/// type is a mask vector.
bool isMaskArithmetic(SDValue Op) {
  const unsigned OC = Op.getOpcode();
  const bool IsLogicOp = OC == ISD::AND || OC == ISD::OR || OC == ISD::XOR;
  return IsLogicOp && isMaskType(Op.getValueType());
}
60
61
/// \returns the VVP_* SDNode opcode corresponding to \p OC, or std::nullopt
/// if \p OC has no VVP counterpart.
std::optional<unsigned> getVVPOpcode(unsigned Opcode) {
  switch (Opcode) {
  // Masked loads/stores map onto the generic VVP memory ops; the mask
  // becomes an explicit operand of those nodes.
  case ISD::MLOAD:
    return VEISD::VVP_LOAD;
  case ISD::MSTORE:
    return VEISD::VVP_STORE;
  // VP opcodes map to their VVP counterpart; VVP opcodes map to themselves,
  // and the matching standard SD opcode maps to the same VVP opcode.
#define HANDLE_VP_TO_VVP(VPOPC, VVPNAME)                                       \
  case ISD::VPOPC:                                                             \
    return VEISD::VVPNAME;
#define ADD_VVP_OP(VVPNAME, SDNAME)                                            \
  case VEISD::VVPNAME:                                                         \
  case ISD::SDNAME:                                                            \
    return VEISD::VVPNAME;
#include "VVPNodes.def"
  // TODO: Map those in VVPNodes.def too
  case ISD::EXPERIMENTAL_VP_STRIDED_LOAD:
    return VEISD::VVP_LOAD;
  case ISD::EXPERIMENTAL_VP_STRIDED_STORE:
    return VEISD::VVP_STORE;
  }
  return std::nullopt;
}
84
85
/// \returns true if the mask operand of \p Op may be ignored without
/// changing observable behavior. Division ops and select are the opcodes
/// for which the mask must be honored.
bool maySafelyIgnoreMask(SDValue Op) {
  // Normalize to the VVP opcode where a mapping exists.
  const unsigned NormalizedOpc =
      getVVPOpcode(Op->getOpcode()).value_or(Op->getOpcode());

  switch (NormalizedOpc) {
  case VEISD::VVP_SDIV:
  case VEISD::VVP_UDIV:
  case VEISD::VVP_FDIV:
  case VEISD::VVP_SELECT:
    return false;
  default:
    return true;
  }
}
100
101
/// \returns true if \p Opcode supports packed-mode execution for the
/// idiomatic vector type \p IdiomVT.
bool supportsPackedMode(unsigned Opcode, EVT IdiomVT) {
  bool IsPackedOp = isPackedVectorType(IdiomVT);
  bool IsMaskOp = isMaskType(IdiomVT);
  switch (Opcode) {
  default:
    return false;

  // Broadcast always supports packing.
  case VEISD::VEC_BROADCAST:
    return true;
  // Every opcode registered as packed in VVPNodes.def supports packing for
  // packed (non-mask) idiomatic types.
#define REGISTER_PACKED(VVP_NAME) case VEISD::VVP_NAME:
#include "VVPNodes.def"
    return IsPackedOp && !IsMaskOp;
  }
}
115
116
bool isPackingSupportOpcode(unsigned Opc) {
117
switch (Opc) {
118
case VEISD::VEC_PACK:
119
case VEISD::VEC_UNPACK_LO:
120
case VEISD::VEC_UNPACK_HI:
121
return true;
122
}
123
return false;
124
}
125
126
/// \returns true if \p Opcode is a VVP_* opcode or VEC_BROADCAST.
bool isVVPOrVEC(unsigned Opcode) {
  switch (Opcode) {
  case VEISD::VEC_BROADCAST:
#define ADD_VVP_OP(VVPNAME, ...) case VEISD::VVPNAME:
#include "VVPNodes.def"
    return true;
  }
  return false;
}
135
136
/// \returns true if \p VVPOpcode is one of the unary VVP operations
/// registered in VVPNodes.def.
bool isVVPUnaryOp(unsigned VVPOpcode) {
  switch (VVPOpcode) {
#define ADD_UNARY_VVP_OP(VVPNAME, ...)                                         \
  case VEISD::VVPNAME:                                                         \
    return true;
#include "VVPNodes.def"
  }
  return false;
}
145
146
/// \returns true if \p VVPOpcode is one of the binary VVP operations
/// registered in VVPNodes.def.
bool isVVPBinaryOp(unsigned VVPOpcode) {
  switch (VVPOpcode) {
#define ADD_BINARY_VVP_OP(VVPNAME, ...)                                        \
  case VEISD::VVPNAME:                                                         \
    return true;
#include "VVPNodes.def"
  }
  return false;
}
155
156
/// \returns true if \p Opcode is one of the VVP reduction operations
/// registered in VVPNodes.def.
bool isVVPReductionOp(unsigned Opcode) {
  switch (Opcode) {
#define ADD_REDUCE_VVP_OP(VVP_NAME, SDNAME) case VEISD::VVP_NAME:
#include "VVPNodes.def"
    return true;
  }
  return false;
}
164
165
// Return the AVL operand position for this VVP or VEC Op.
166
std::optional<int> getAVLPos(unsigned Opc) {
167
// This is only available for VP SDNodes
168
auto PosOpt = ISD::getVPExplicitVectorLengthIdx(Opc);
169
if (PosOpt)
170
return *PosOpt;
171
172
// VVP Opcodes.
173
if (isVVPBinaryOp(Opc))
174
return 3;
175
176
// VM Opcodes.
177
switch (Opc) {
178
case VEISD::VEC_BROADCAST:
179
return 1;
180
case VEISD::VVP_SELECT:
181
return 3;
182
case VEISD::VVP_LOAD:
183
return 4;
184
case VEISD::VVP_STORE:
185
return 5;
186
}
187
188
return std::nullopt;
189
}
190
191
std::optional<int> getMaskPos(unsigned Opc) {
192
// This is only available for VP SDNodes
193
auto PosOpt = ISD::getVPMaskIdx(Opc);
194
if (PosOpt)
195
return *PosOpt;
196
197
// VVP Opcodes.
198
if (isVVPBinaryOp(Opc))
199
return 2;
200
201
// Other opcodes.
202
switch (Opc) {
203
case ISD::MSTORE:
204
return 4;
205
case ISD::MLOAD:
206
return 3;
207
case VEISD::VVP_SELECT:
208
return 2;
209
}
210
211
return std::nullopt;
212
}
213
214
// An AVL value is "legal" once it is wrapped in a VEISD::LEGALAVL node.
bool isLegalAVL(SDValue AVL) { return AVL->getOpcode() == VEISD::LEGALAVL; }
215
216
/// Node Properties {
217
218
/// \returns the chain operand of \p Op, or an empty SDValue if \p Op has no
/// chain.
SDValue getNodeChain(SDValue Op) {
  // Memory SDNodes expose the chain through an accessor.
  if (auto *MemN = dyn_cast<MemSDNode>(Op.getNode()))
    return MemN->getChain();

  // The custom VVP memory nodes keep the chain at operand 0.
  const unsigned OC = Op->getOpcode();
  if (OC == VEISD::VVP_LOAD || OC == VEISD::VVP_STORE)
    return Op->getOperand(0);
  return SDValue();
}
229
230
/// \returns the memory (base) pointer operand of \p Op, or an empty SDValue
/// for non-memory nodes.
SDValue getMemoryPtr(SDValue Op) {
  // Memory SDNodes expose the base pointer through an accessor.
  if (auto *MemN = dyn_cast<MemSDNode>(Op.getNode()))
    return MemN->getBasePtr();

  // The custom VVP memory nodes keep the pointer at a fixed operand slot.
  const unsigned OC = Op->getOpcode();
  if (OC == VEISD::VVP_LOAD)
    return Op->getOperand(1);
  if (OC == VEISD::VVP_STORE)
    return Op->getOperand(2);
  return SDValue();
}
242
243
/// \returns the idiomatic vector type of \p Op — the type that determines
/// the operation's vector length and packing.
std::optional<EVT> getIdiomaticVectorType(SDNode *Op) {
  unsigned OC = Op->getOpcode();

  // For memory ops -> the transferred data type
  if (auto MemN = dyn_cast<MemSDNode>(Op))
    return MemN->getMemoryVT();

  switch (OC) {
  // Standard ISD: the result type is the idiom.
  case ISD::SELECT: // not aliased with VVP_SELECT
  case ISD::CONCAT_VECTORS:
  case ISD::EXTRACT_SUBVECTOR:
  case ISD::VECTOR_SHUFFLE:
  case ISD::BUILD_VECTOR:
  case ISD::SCALAR_TO_VECTOR:
    return Op->getValueType(0);
  }

  // Translate to VVP where possible.
  unsigned OriginalOC = OC;
  if (auto VVPOpc = getVVPOpcode(OC))
    OC = *VVPOpc;

  // Reductions: the idiom is the type of the reduced vector operand. When
  // the original opcode carries a start parameter, that vector is operand 1,
  // otherwise operand 0.
  if (isVVPReductionOp(OC))
    return Op->getOperand(hasReductionStartParam(OriginalOC) ? 1 : 0)
        .getValueType();

  switch (OC) {
  default:
    // NOTE: 'default' intentionally falls through to the VVP_SETCC case —
    // use the first operand's type.
  case VEISD::VVP_SETCC:
    return Op->getOperand(0).getValueType();

  case VEISD::VVP_SELECT:
#define ADD_BINARY_VVP_OP(VVP_NAME, ...) case VEISD::VVP_NAME:
#include "VVPNodes.def"
    return Op->getValueType(0);

  case VEISD::VVP_LOAD:
    return Op->getValueType(0);

  case VEISD::VVP_STORE:
    // The type of the stored value (operand 1).
    return Op->getOperand(1)->getValueType(0);

  // VEC
  case VEISD::VEC_BROADCAST:
    return Op->getValueType(0);
  }
}
291
292
/// \returns the byte stride of the memory access \p Op, synthesizing the
/// contiguous element size for non-strided loads/stores. Empty SDValue for
/// non-memory nodes.
SDValue getLoadStoreStride(SDValue Op, VECustomDAG &CDAG) {
  // Custom VVP memory nodes keep the stride at a fixed operand slot.
  const unsigned OC = Op->getOpcode();
  if (OC == VEISD::VVP_STORE)
    return Op->getOperand(3);
  if (OC == VEISD::VVP_LOAD)
    return Op->getOperand(2);

  // VP strided accesses carry an explicit stride operand.
  if (auto *StridedStoreN = dyn_cast<VPStridedStoreSDNode>(Op.getNode()))
    return StridedStoreN->getStride();
  if (auto *StridedLoadN = dyn_cast<VPStridedLoadSDNode>(Op.getNode()))
    return StridedLoadN->getStride();

  if (isa<MemSDNode>(Op.getNode())) {
    // Regular MLOAD/MSTORE/LOAD/STORE: no stride operand, so the contiguous
    // element store size acts as the stride.
    uint64_t ElemStride = getIdiomaticVectorType(Op.getNode())
                              ->getVectorElementType()
                              .getStoreSize();
    return CDAG.getConstant(ElemStride, MVT::i64);
  }
  return SDValue();
}
315
316
/// \returns the index vector of a gather/scatter \p Op, or an empty SDValue.
SDValue getGatherScatterIndex(SDValue Op) {
  SDNode *RawN = Op.getNode();
  if (auto *MaskedN = dyn_cast<MaskedGatherScatterSDNode>(RawN))
    return MaskedN->getIndex();
  if (auto *VPN = dyn_cast<VPGatherScatterSDNode>(RawN))
    return VPN->getIndex();
  return SDValue();
}
323
324
/// \returns the scale operand of a gather/scatter \p Op, or an empty SDValue.
SDValue getGatherScatterScale(SDValue Op) {
  SDNode *RawN = Op.getNode();
  if (auto *MaskedN = dyn_cast<MaskedGatherScatterSDNode>(RawN))
    return MaskedN->getScale();
  if (auto *VPN = dyn_cast<VPGatherScatterSDNode>(RawN))
    return VPN->getScale();
  return SDValue();
}
331
332
/// \returns the value being stored by store-like node \p Op, or an empty
/// SDValue if \p Op is not a store.
SDValue getStoredValue(SDValue Op) {
  // Custom VVP and VP-strided stores keep the data at operand 1.
  switch (Op->getOpcode()) {
  case ISD::EXPERIMENTAL_VP_STRIDED_STORE:
  case VEISD::VVP_STORE:
    return Op->getOperand(1);
  }

  // Otherwise defer to the accessor of whichever store node class matches.
  SDNode *RawN = Op.getNode();
  if (auto *SN = dyn_cast<StoreSDNode>(RawN))
    return SN->getValue();
  if (auto *SN = dyn_cast<MaskedStoreSDNode>(RawN))
    return SN->getValue();
  if (auto *SN = dyn_cast<VPStridedStoreSDNode>(RawN))
    return SN->getValue();
  if (auto *SN = dyn_cast<VPStoreSDNode>(RawN))
    return SN->getValue();
  if (auto *SN = dyn_cast<MaskedScatterSDNode>(RawN))
    return SN->getValue();
  if (auto *SN = dyn_cast<VPScatterSDNode>(RawN))
    return SN->getValue();
  return SDValue();
}
352
353
/// \returns the passthru operand of \p Op. Only masked loads and masked
/// gathers carry one; otherwise an empty SDValue.
SDValue getNodePassthru(SDValue Op) {
  SDNode *RawN = Op.getNode();
  if (auto *MLoadN = dyn_cast<MaskedLoadSDNode>(RawN))
    return MLoadN->getPassThru();
  if (auto *MGatherN = dyn_cast<MaskedGatherSDNode>(RawN))
    return MGatherN->getPassThru();

  return SDValue();
}
361
362
bool hasReductionStartParam(unsigned OPC) {
363
// TODO: Ordered reduction opcodes.
364
if (ISD::isVPReduction(OPC))
365
return true;
366
return false;
367
}
368
369
/// Map VVP reduction opcode \p VVPOC to the scalar ISD opcode used to fold a
/// start value into the reduction result.
/// \p IsMask must be false; mask reductions are not implemented here.
unsigned getScalarReductionOpcode(unsigned VVPOC, bool IsMask) {
  assert(!IsMask && "Mask reduction isel");

  switch (VVPOC) {
#define HANDLE_VVP_REDUCE_TO_SCALAR(VVP_RED_ISD, REDUCE_ISD)                   \
  case VEISD::VVP_RED_ISD:                                                     \
    return ISD::REDUCE_ISD;
#include "VVPNodes.def"
  default:
    break;
  }
  // Fixed diagnostic: the original read "Cannot not scalarize", a double
  // negative that said the opposite of what is meant.
  llvm_unreachable("Cannot scalarize this reduction Opcode!");
}
382
383
/// } Node Properties
384
385
/// \returns the AVL operand of \p Op, or an empty SDValue if its opcode has
/// no AVL slot.
SDValue getNodeAVL(SDValue Op) {
  if (auto PosOpt = getAVLPos(Op->getOpcode()))
    return Op->getOperand(*PosOpt);
  return SDValue();
}
389
390
/// \returns the mask operand of \p Op, or an empty SDValue if its opcode has
/// no mask slot.
SDValue getNodeMask(SDValue Op) {
  if (auto PosOpt = getMaskPos(Op->getOpcode()))
    return Op->getOperand(*PosOpt);
  return SDValue();
}
394
395
/// \returns {AVL, IsLegal} for \p Op: the AVL with any LEGALAVL annotation
/// peeled off, and whether it was annotated (a missing AVL counts as legal).
std::pair<SDValue, bool> getAnnotatedNodeAVL(SDValue Op) {
  SDValue RawAVL = getNodeAVL(Op);
  if (!RawAVL)
    return {SDValue(), true};
  const bool IsAnnotated = isLegalAVL(RawAVL);
  SDValue Inner = IsAnnotated ? RawAVL->getOperand(0) : RawAVL;
  return {Inner, IsAnnotated};
}
403
404
/// Thin wrapper over SelectionDAG::getConstant that supplies this
/// VECustomDAG's debug location.
SDValue VECustomDAG::getConstant(uint64_t Val, EVT VT, bool IsTarget,
                                 bool IsOpaque) const {
  return DAG.getConstant(Val, DL, VT, IsTarget, IsOpaque);
}
408
409
/// Materialize a constant all-true (or, with \p AllTrue false, all-false)
/// mask of the legal mask type for \p Packing.
SDValue VECustomDAG::getConstantMask(Packing Packing, bool AllTrue) const {
  MVT MaskVT = getLegalVectorType(Packing, MVT::i1);

  // Broadcast an all-ones element across the whole mask length;
  // VEISelDAGtoDAG will replace this pattern with the constant-true VM.
  SDValue AllOnes = DAG.getConstant(-1, DL, MVT::i32);
  SDValue FullAVL = getConstant(MaskVT.getVectorNumElements(), MVT::i32);
  SDValue TrueMask = getNode(VEISD::VEC_BROADCAST, MaskVT, {AllOnes, FullAVL});

  // The all-false mask is the negation of the all-true one.
  return AllTrue ? TrueMask
                 : DAG.getNOT(DL, TrueMask, TrueMask.getValueType());
}
421
422
/// Broadcast the scalar condition \p Scalar into the mask vector type
/// \p ResultVT. The scalar is expected in an i32 register (asserted below).
SDValue VECustomDAG::getMaskBroadcast(EVT ResultVT, SDValue Scalar,
                                      SDValue AVL) const {
  // Constant mask splat: lower directly to the constant all-true/all-false
  // mask for this packing.
  if (auto BcConst = dyn_cast<ConstantSDNode>(Scalar))
    return getConstantMask(getTypePacking(ResultVT),
                           BcConst->getSExtValue() != 0);

  // Expand the broadcast to a vector comparison.
  auto ScalarBoolVT = Scalar.getSimpleValueType();
  assert(ScalarBoolVT == MVT::i32);

  // Cast to i32 ty.
  SDValue CmpElem = DAG.getSExtOrTrunc(Scalar, DL, MVT::i32);
  unsigned ElemCount = ResultVT.getVectorNumElements();
  MVT CmpVecTy = MVT::getVectorVT(ScalarBoolVT, ElemCount);

  // Broadcast the condition and a zero vector of the same i32 element type.
  SDValue BCVec =
      DAG.getNode(VEISD::VEC_BROADCAST, DL, CmpVecTy, {CmpElem, AVL});
  SDValue ZeroVec =
      getBroadcast(CmpVecTy, {DAG.getConstant(0, DL, ScalarBoolVT)}, AVL);

  MVT BoolVecTy = MVT::getVectorVT(MVT::i1, ElemCount);

  // The mask is Broadcast(Data) != Broadcast(0).
  // TODO: Use a VVP operation for this.
  return DAG.getSetCC(DL, BoolVecTy, BCVec, ZeroVec, ISD::CondCode::SETNE);
}
450
451
/// Broadcast \p Scalar into vector type \p ResultVT under vector length
/// \p AVL, handling mask and packed result types.
SDValue VECustomDAG::getBroadcast(EVT ResultVT, SDValue Scalar,
                                  SDValue AVL) const {
  assert(ResultVT.isVector());

  // Mask broadcast takes a dedicated lowering path.
  if (isMaskType(ResultVT))
    return getMaskBroadcast(ResultVT, Scalar, AVL);

  if (isPackedVectorType(ResultVT)) {
    // v512x packed mode broadcast: replicate the 32-bit scalar (f32 or i32)
    // onto the opposing half of the full 64-bit scalar register. An i64
    // scalar is assumed to already be replicated.
    auto ScaVT = Scalar.getValueType();
    if (ScaVT == MVT::f32)
      Scalar = getNode(VEISD::REPL_F32, MVT::i64, Scalar);
    else if (ScaVT == MVT::i32)
      Scalar = getNode(VEISD::REPL_I32, MVT::i64, Scalar);
  }

  return getNode(VEISD::VEC_BROADCAST, ResultVT, {Scalar, AVL});
}
473
474
/// Wrap \p AVL in a LEGALAVL node. Idempotent: an already-annotated AVL is
/// returned unchanged.
SDValue VECustomDAG::annotateLegalAVL(SDValue AVL) const {
  return isLegalAVL(AVL) ? AVL
                         : getNode(VEISD::LEGALAVL, AVL.getValueType(), AVL);
}
479
480
/// Extract the \p Part half of packed vector \p Vec as type \p DestVT.
/// \p AVL must already be pack-legalized (asserted).
SDValue VECustomDAG::getUnpack(EVT DestVT, SDValue Vec, PackElem Part,
                               SDValue AVL) const {
  assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL");

  // TODO: Peek through VEC_PACK and VEC_BROADCAST(REPL_<sth> ..) operands.
  const bool WantLo = (Part == PackElem::Lo);
  unsigned UnpackOC = WantLo ? VEISD::VEC_UNPACK_LO : VEISD::VEC_UNPACK_HI;
  return DAG.getNode(UnpackOC, DL, DestVT, Vec, AVL);
}
489
490
/// Combine \p LoVec and \p HiVec into one packed vector of type \p DestVT.
/// \p AVL must already be pack-legalized (asserted).
SDValue VECustomDAG::getPack(EVT DestVT, SDValue LoVec, SDValue HiVec,
                             SDValue AVL) const {
  assert(getAnnotatedNodeAVL(AVL).second && "Expected a pack-legalized AVL");

  // TODO: Peek through VEC_UNPACK_LO|HI operands.
  return DAG.getNode(VEISD::VEC_PACK, DL, DestVT, LoVec, HiVec, AVL);
}
497
498
/// Derive the (mask, AVL) pair for one half (\p Part) of a split packed
/// operation from the packed-mode \p RawMask / \p RawAVL.
VETargetMasks VECustomDAG::getTargetSplitMask(SDValue RawMask, SDValue RawAVL,
                                              PackElem Part) const {
  // Adjust AVL for this part: Hi gets (AVL + 1) >> 1 (round up), Lo gets
  // AVL >> 1 (round down).
  SDValue NewAVL;
  SDValue OneV = getConstant(1, MVT::i32);
  if (Part == PackElem::Hi)
    NewAVL = getNode(ISD::ADD, MVT::i32, {RawAVL, OneV});
  else
    NewAVL = RawAVL;
  NewAVL = getNode(ISD::SRL, MVT::i32, {NewAVL, OneV});

  NewAVL = annotateLegalAVL(NewAVL);

  // Legalize Mask: unpack this part of the raw mask, or use an all-true
  // normal-width mask when no mask was given.
  SDValue NewMask;
  if (!RawMask)
    NewMask = getConstantMask(Packing::Normal, true);
  else
    NewMask = getUnpack(MVT::v256i1, RawMask, Part, NewAVL);

  return VETargetMasks(NewMask, NewAVL);
}
520
521
/// Compute the starting pointer for the \p Part half of a split packed
/// memory access.
SDValue VECustomDAG::getSplitPtrOffset(SDValue Ptr, SDValue ByteStride,
                                       PackElem Part) const {
  // The high part starts at the base pointer itself — its elements occupy
  // the more significant bits of each 64-bit vector element. Only the low
  // part is offset by one stride.
  if (Part == PackElem::Hi)
    return Ptr;
  return getNode(ISD::ADD, MVT::i64, {Ptr, ByteStride});
}
529
530
/// \returns twice \p PackStride — each half of a split packed access steps
/// over two packed elements at a time.
SDValue VECustomDAG::getSplitPtrStride(SDValue PackStride) const {
  // Fold the doubling when the stride is a constant.
  if (auto *ConstBytes = dyn_cast<ConstantSDNode>(PackStride))
    return getConstant(2 * ConstBytes->getSExtValue(), MVT::i64);
  return getNode(ISD::SHL, MVT::i64, {PackStride, getConstant(1, MVT::i32)});
}
535
536
/// Build the address vector basePtr + index * scale for a gather/scatter,
/// skipping the multiply for scale 1 and the add for a null base pointer.
SDValue VECustomDAG::getGatherScatterAddress(SDValue BasePtr, SDValue Scale,
                                             SDValue Index, SDValue Mask,
                                             SDValue AVL) const {
  EVT IndexVT = Index.getValueType();

  // Apply the scale unless it is a no-op (absent or == 1).
  SDValue AddrVec = Index;
  if (Scale && !isOneConstant(Scale)) {
    SDValue ScaleSplat = getBroadcast(IndexVT, Scale, AVL);
    AddrVec = getNode(VEISD::VVP_MUL, IndexVT, {Index, ScaleSplat, Mask, AVL});
  }

  // A zero base pointer needs no further adjustment.
  if (isNullConstant(BasePtr))
    return AddrVec;

  // Re-constitute the pointer vector (basePtr + index * scale).
  SDValue BaseSplat = getBroadcast(IndexVT, BasePtr, AVL);
  return getNode(VEISD::VVP_ADD, IndexVT, {BaseSplat, AddrVec, Mask, AVL});
}
561
562
/// Emit the VVP reduction \p VVPOpcode over \p VectorV. When \p StartV is
/// given but the opcode cannot carry a start parameter, the start value is
/// folded into the result afterwards with the matching scalar opcode.
SDValue VECustomDAG::getLegalReductionOpVVP(unsigned VVPOpcode, EVT ResVT,
                                            SDValue StartV, SDValue VectorV,
                                            SDValue Mask, SDValue AVL,
                                            SDNodeFlags Flags) const {

  // Optionally attach the start param with a scalar op (where it is
  // unsupported).
  bool scalarizeStartParam = StartV && !hasReductionStartParam(VVPOpcode);
  // Mask reductions are not implemented here (asserted below).
  bool IsMaskReduction = isMaskType(VectorV.getValueType());
  assert(!IsMaskReduction && "TODO Implement");
  // Fold the start value into the reduction result via the scalar opcode,
  // when the VVP node itself could not take it.
  auto AttachStartValue = [&](SDValue ReductionResV) {
    if (!scalarizeStartParam)
      return ReductionResV;
    auto ScalarOC = getScalarReductionOpcode(VVPOpcode, IsMaskReduction);
    return getNode(ScalarOC, ResVT, {StartV, ReductionResV});
  };

  // Fixup: Always Use sequential 'fmul' reduction.
  if (!scalarizeStartParam && StartV) {
    // The opcode supports a start param: pass it as the first operand.
    assert(hasReductionStartParam(VVPOpcode));
    return AttachStartValue(
        getNode(VVPOpcode, ResVT, {StartV, VectorV, Mask, AVL}, Flags));
  } else
    return AttachStartValue(
        getNode(VVPOpcode, ResVT, {VectorV, Mask, AVL}, Flags));
}
588
589
} // namespace llvm
590
591