Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
35266 views
1
//===- SPIRVLegalizerInfo.cpp --- SPIR-V Legalization Rules ------*- C++ -*-==//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file implements the GlobalISel LegalizerInfo (legalization rules) for
// the SPIR-V target.
10
//
11
//===----------------------------------------------------------------------===//
12
13
#include "SPIRVLegalizerInfo.h"
14
#include "SPIRV.h"
15
#include "SPIRVGlobalRegistry.h"
16
#include "SPIRVSubtarget.h"
17
#include "llvm/CodeGen/GlobalISel/LegalizerHelper.h"
18
#include "llvm/CodeGen/GlobalISel/MachineIRBuilder.h"
19
#include "llvm/CodeGen/MachineInstr.h"
20
#include "llvm/CodeGen/MachineRegisterInfo.h"
21
#include "llvm/CodeGen/TargetOpcodes.h"
22
23
using namespace llvm;
24
using namespace llvm::LegalizeActions;
25
using namespace llvm::LegalityPredicates;
26
27
static const std::set<unsigned> TypeFoldingSupportingOpcs = {
28
TargetOpcode::G_ADD,
29
TargetOpcode::G_FADD,
30
TargetOpcode::G_SUB,
31
TargetOpcode::G_FSUB,
32
TargetOpcode::G_MUL,
33
TargetOpcode::G_FMUL,
34
TargetOpcode::G_SDIV,
35
TargetOpcode::G_UDIV,
36
TargetOpcode::G_FDIV,
37
TargetOpcode::G_SREM,
38
TargetOpcode::G_UREM,
39
TargetOpcode::G_FREM,
40
TargetOpcode::G_FNEG,
41
TargetOpcode::G_CONSTANT,
42
TargetOpcode::G_FCONSTANT,
43
TargetOpcode::G_AND,
44
TargetOpcode::G_OR,
45
TargetOpcode::G_XOR,
46
TargetOpcode::G_SHL,
47
TargetOpcode::G_ASHR,
48
TargetOpcode::G_LSHR,
49
TargetOpcode::G_SELECT,
50
TargetOpcode::G_EXTRACT_VECTOR_ELT,
51
};
52
53
bool isTypeFoldingSupported(unsigned Opcode) {
54
return TypeFoldingSupportingOpcs.count(Opcode) > 0;
55
}
56
57
/// Builds the legalization rule table for the SPIR-V target.
///
/// Stores the subtarget and its global registry, declares the scalar, vector
/// and pointer LLTs the rules refer to, registers one rule set per generic
/// opcode (or opcode group), and finally computes the legacy tables and
/// verifies the result against the subtarget's instruction info.
SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
  using namespace TargetOpcode;

  // The parameter shadows the member, hence the explicit `this->`.
  this->ST = &ST;
  GR = ST.getSPIRVGlobalRegistry();

  // Scalars.
  const LLT s1 = LLT::scalar(1);
  const LLT s8 = LLT::scalar(8);
  const LLT s16 = LLT::scalar(16);
  const LLT s32 = LLT::scalar(32);
  const LLT s64 = LLT::scalar(64);

  // 16-element vectors.
  const LLT v16s64 = LLT::fixed_vector(16, 64);
  const LLT v16s32 = LLT::fixed_vector(16, 32);
  const LLT v16s16 = LLT::fixed_vector(16, 16);
  const LLT v16s8 = LLT::fixed_vector(16, 8);
  const LLT v16s1 = LLT::fixed_vector(16, 1);

  // 8-element vectors.
  const LLT v8s64 = LLT::fixed_vector(8, 64);
  const LLT v8s32 = LLT::fixed_vector(8, 32);
  const LLT v8s16 = LLT::fixed_vector(8, 16);
  const LLT v8s8 = LLT::fixed_vector(8, 8);
  const LLT v8s1 = LLT::fixed_vector(8, 1);

  // 4-element vectors.
  const LLT v4s64 = LLT::fixed_vector(4, 64);
  const LLT v4s32 = LLT::fixed_vector(4, 32);
  const LLT v4s16 = LLT::fixed_vector(4, 16);
  const LLT v4s8 = LLT::fixed_vector(4, 8);
  const LLT v4s1 = LLT::fixed_vector(4, 1);

  // 3-element vectors.
  const LLT v3s64 = LLT::fixed_vector(3, 64);
  const LLT v3s32 = LLT::fixed_vector(3, 32);
  const LLT v3s16 = LLT::fixed_vector(3, 16);
  const LLT v3s8 = LLT::fixed_vector(3, 8);
  const LLT v3s1 = LLT::fixed_vector(3, 1);

  // 2-element vectors.
  const LLT v2s64 = LLT::fixed_vector(2, 64);
  const LLT v2s32 = LLT::fixed_vector(2, 32);
  const LLT v2s16 = LLT::fixed_vector(2, 16);
  const LLT v2s8 = LLT::fixed_vector(2, 8);
  const LLT v2s1 = LLT::fixed_vector(2, 1);

  // Pointers: one LLT per address space, all of the subtarget's pointer size.
  const unsigned PtrWidth = ST.getPointerSize();
  const LLT p0 = LLT::pointer(0, PtrWidth); // Function
  const LLT p1 = LLT::pointer(1, PtrWidth); // CrossWorkgroup
  const LLT p2 = LLT::pointer(2, PtrWidth); // UniformConstant
  const LLT p3 = LLT::pointer(3, PtrWidth); // Workgroup
  const LLT p4 = LLT::pointer(4, PtrWidth); // Generic
  // Input, SPV_INTEL_usm_storage_classes (Device)
  const LLT p5 = LLT::pointer(5, PtrWidth);
  // SPV_INTEL_usm_storage_classes (Host)
  const LLT p6 = LLT::pointer(6, PtrWidth);

  // Type sets used by the rules below.
  // TODO: remove copy-pasting here by using concatenation in some way.
  auto allPtrsScalarsAndVectors = {
      p0,    p1,    p2,    p3,     p4,     p5,     p6,
      s1,    s8,    s16,   s32,    s64,
      v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
      v3s1,  v3s8,  v3s16,  v3s32,  v3s64,
      v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
      v8s1,  v8s8,  v8s16,  v8s32,  v8s64,
      v16s1, v16s8, v16s16, v16s32, v16s64};

  auto allVectors = {
      v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
      v3s1,  v3s8,  v3s16,  v3s32,  v3s64,
      v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
      v8s1,  v8s8,  v8s16,  v8s32,  v8s64,
      v16s1, v16s8, v16s16, v16s32, v16s64};

  auto allScalarsAndVectors = {
      s1,    s8,    s16,    s32,    s64,
      v2s1,  v2s8,  v2s16,  v2s32,  v2s64,
      v3s1,  v3s8,  v3s16,  v3s32,  v3s64,
      v4s1,  v4s8,  v4s16,  v4s32,  v4s64,
      v8s1,  v8s8,  v8s16,  v8s32,  v8s64,
      v16s1, v16s8, v16s16, v16s32, v16s64};

  auto allIntScalarsAndVectors = {
      s8,    s16,    s32,    s64,
      v2s8,  v2s16,  v2s32,  v2s64,
      v3s8,  v3s16,  v3s32,  v3s64,
      v4s8,  v4s16,  v4s32,  v4s64,
      v8s8,  v8s16,  v8s32,  v8s64,
      v16s8, v16s16, v16s32, v16s64};

  auto allBoolScalarsAndVectors = {s1, v2s1, v3s1, v4s1, v8s1, v16s1};

  auto allIntScalars = {s8, s16, s32, s64};

  auto allFloatScalars = {s16, s32, s64};

  auto allFloatScalarsAndVectors = {
      s16,    s32,    s64,
      v2s16,  v2s32,  v2s64,
      v3s16,  v3s32,  v3s64,
      v4s16,  v4s32,  v4s64,
      v8s16,  v8s32,  v8s64,
      v16s16, v16s32, v16s64};

  auto allFloatAndIntScalarsAndPtrs = {s8, s16, s32, s64, p0, p1,
                                       p2, p3,  p4,  p5,  p6};

  auto allPtrs = {p0, p1, p2, p3, p4, p5, p6};
  // Same as allPtrs but without p2 (UniformConstant).
  auto allWritablePtrs = {p0, p1, p3, p4, p5, p6};

  // Every type-folding opcode is routed through legalizeCustom().
  for (unsigned FoldableOpc : TypeFoldingSupportingOpcs)
    getActionDefinitionsBuilder(FoldableOpc).custom();

  getActionDefinitionsBuilder(G_GLOBAL_VALUE).alwaysLegal();

  // TODO: add proper rules for vectors legalization.
  getActionDefinitionsBuilder(
      {G_BUILD_VECTOR, G_SHUFFLE_VECTOR, G_SPLAT_VECTOR})
      .alwaysLegal();

  // Vector Reduction Operations
  getActionDefinitionsBuilder(
      {G_VECREDUCE_SMIN, G_VECREDUCE_SMAX, G_VECREDUCE_UMIN, G_VECREDUCE_UMAX,
       G_VECREDUCE_ADD, G_VECREDUCE_MUL, G_VECREDUCE_FMUL, G_VECREDUCE_FMIN,
       G_VECREDUCE_FMAX, G_VECREDUCE_FMINIMUM, G_VECREDUCE_FMAXIMUM,
       G_VECREDUCE_OR, G_VECREDUCE_AND, G_VECREDUCE_XOR})
      .legalFor(allVectors)
      .scalarize(1)
      .lower();

  getActionDefinitionsBuilder({G_VECREDUCE_SEQ_FADD, G_VECREDUCE_SEQ_FMUL})
      .scalarize(2)
      .lower();

  // Merge/Unmerge
  // TODO: add proper legalization rules.
  getActionDefinitionsBuilder(G_UNMERGE_VALUES).alwaysLegal();

  // Memory intrinsics: type index 0 must be a writable pointer.
  getActionDefinitionsBuilder({G_MEMCPY, G_MEMMOVE})
      .legalIf(all(typeInSet(0, allWritablePtrs), typeInSet(1, allPtrs)));

  getActionDefinitionsBuilder(G_MEMSET)
      .legalIf(all(typeInSet(0, allWritablePtrs),
                   typeInSet(1, allIntScalars)));

  getActionDefinitionsBuilder(G_ADDRSPACE_CAST)
      .legalForCartesianProduct(allPtrs, allPtrs);

  // Loads and stores are legal whenever type index 1 is a known pointer type.
  getActionDefinitionsBuilder({G_LOAD, G_STORE})
      .legalIf(typeInSet(1, allPtrs));

  getActionDefinitionsBuilder(G_BITREVERSE).legalFor(allIntScalarsAndVectors);

  getActionDefinitionsBuilder(G_FMA).legalFor(allFloatScalarsAndVectors);

  // Float <-> integer conversions.
  getActionDefinitionsBuilder({G_FPTOSI, G_FPTOUI})
      .legalForCartesianProduct(allIntScalarsAndVectors,
                                allFloatScalarsAndVectors);

  getActionDefinitionsBuilder({G_SITOFP, G_UITOFP})
      .legalForCartesianProduct(allFloatScalarsAndVectors,
                                allScalarsAndVectors);

  getActionDefinitionsBuilder({G_SMIN, G_SMAX, G_UMIN, G_UMAX, G_ABS})
      .legalFor(allIntScalarsAndVectors);

  getActionDefinitionsBuilder(G_CTPOP)
      .legalForCartesianProduct(allIntScalarsAndVectors,
                                allIntScalarsAndVectors);

  getActionDefinitionsBuilder(G_PHI).legalFor(allPtrsScalarsAndVectors);

  getActionDefinitionsBuilder(G_BITCAST)
      .legalIf(all(typeInSet(0, allPtrsScalarsAndVectors),
                   typeInSet(1, allPtrsScalarsAndVectors)));

  getActionDefinitionsBuilder({G_IMPLICIT_DEF, G_FREEZE}).alwaysLegal();

  getActionDefinitionsBuilder({G_STACKSAVE, G_STACKRESTORE}).alwaysLegal();

  // Pointer <-> integer conversions and pointer arithmetic.
  getActionDefinitionsBuilder(G_INTTOPTR)
      .legalForCartesianProduct(allPtrs, allIntScalars);
  getActionDefinitionsBuilder(G_PTRTOINT)
      .legalForCartesianProduct(allIntScalars, allPtrs);
  getActionDefinitionsBuilder(G_PTR_ADD)
      .legalForCartesianProduct(allPtrs, allIntScalars);

  // ST.canDirectlyComparePointers() for pointer args is supported in
  // legalizeCustom().
  getActionDefinitionsBuilder(G_ICMP)
      .customIf(all(typeInSet(0, allBoolScalarsAndVectors),
                    typeInSet(1, allPtrsScalarsAndVectors)));

  getActionDefinitionsBuilder(G_FCMP)
      .legalIf(all(typeInSet(0, allBoolScalarsAndVectors),
                   typeInSet(1, allFloatScalarsAndVectors)));

  // Atomic read-modify-write operations.
  getActionDefinitionsBuilder({G_ATOMICRMW_OR, G_ATOMICRMW_ADD, G_ATOMICRMW_AND,
                               G_ATOMICRMW_MAX, G_ATOMICRMW_MIN,
                               G_ATOMICRMW_SUB, G_ATOMICRMW_XOR,
                               G_ATOMICRMW_UMAX, G_ATOMICRMW_UMIN})
      .legalForCartesianProduct(allIntScalars, allWritablePtrs);

  getActionDefinitionsBuilder(
      {G_ATOMICRMW_FADD, G_ATOMICRMW_FSUB, G_ATOMICRMW_FMIN, G_ATOMICRMW_FMAX})
      .legalForCartesianProduct(allFloatScalars, allWritablePtrs);

  getActionDefinitionsBuilder(G_ATOMICRMW_XCHG)
      .legalForCartesianProduct(allFloatAndIntScalarsAndPtrs, allWritablePtrs);

  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG_WITH_SUCCESS).lower();
  // TODO: add proper legalization rules.
  getActionDefinitionsBuilder(G_ATOMIC_CMPXCHG).alwaysLegal();

  getActionDefinitionsBuilder({G_UADDO, G_USUBO, G_SMULO, G_UMULO})
      .alwaysLegal();

  // Extensions.
  getActionDefinitionsBuilder({G_TRUNC, G_ZEXT, G_SEXT, G_ANYEXT})
      .legalForCartesianProduct(allScalarsAndVectors);

  // FP conversions.
  getActionDefinitionsBuilder({G_FPTRUNC, G_FPEXT})
      .legalForCartesianProduct(allFloatScalarsAndVectors);

  // Pointer-handling.
  getActionDefinitionsBuilder(G_FRAME_INDEX).legalFor({p0});

  // Control-flow. In some cases (e.g. constants) s1 may be promoted to s32.
  getActionDefinitionsBuilder(G_BRCOND).legalFor({s1, s32});

  // TODO: Review the target OpenCL and GLSL Extended Instruction Set specs to
  // tighten these requirements. Many of these math functions are only legal on
  // specific bitwidths, so they are not selectable for
  // allFloatScalarsAndVectors.
  getActionDefinitionsBuilder(
      {G_FPOW,           G_FEXP,
       G_FEXP2,          G_FLOG,
       G_FLOG2,          G_FLOG10,
       G_FABS,           G_FMINNUM,
       G_FMAXNUM,        G_FCEIL,
       G_FCOS,           G_FSIN,
       G_FTAN,           G_FACOS,
       G_FASIN,          G_FATAN,
       G_FCOSH,          G_FSINH,
       G_FTANH,          G_FSQRT,
       G_FFLOOR,         G_FRINT,
       G_FNEARBYINT,     G_INTRINSIC_ROUND,
       G_INTRINSIC_TRUNC, G_FMINIMUM,
       G_FMAXIMUM,       G_INTRINSIC_ROUNDEVEN})
      .legalFor(allFloatScalarsAndVectors);

  getActionDefinitionsBuilder(G_FCOPYSIGN)
      .legalForCartesianProduct(allFloatScalarsAndVectors,
                                allFloatScalarsAndVectors);

  getActionDefinitionsBuilder(G_FPOWI)
      .legalForCartesianProduct(allFloatScalarsAndVectors,
                                allIntScalarsAndVectors);

  // Rules registered only when the OpenCL extended instruction set is usable.
  if (ST.canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
    getActionDefinitionsBuilder(
        {G_CTTZ, G_CTTZ_ZERO_UNDEF, G_CTLZ, G_CTLZ_ZERO_UNDEF})
        .legalForCartesianProduct(allIntScalarsAndVectors,
                                  allIntScalarsAndVectors);

    // Struct return types become a single scalar, so cannot easily legalize.
    getActionDefinitionsBuilder({G_SMULH, G_UMULH}).alwaysLegal();

    // supported saturation arithmetic
    getActionDefinitionsBuilder({G_SADDSAT, G_UADDSAT, G_SSUBSAT, G_USUBSAT})
        .legalFor(allIntScalarsAndVectors);
  }

  getLegacyLegalizerInfo().computeTables();
  verify(*ST.getInstrInfo());
}
323
324
/// Emits a G_PTRTOINT that converts \p Reg into a fresh generic virtual
/// register of type \p ConvTy.
///
/// The new register is created in \p MRI, gets \p SpirvType assigned through
/// \p GR, and is returned to the caller. The instruction is built with
/// \p Helper's MIRBuilder.
static Register convertPtrToInt(Register Reg, LLT ConvTy, SPIRVType *SpirvType,
                                LegalizerHelper &Helper,
                                MachineRegisterInfo &MRI,
                                SPIRVGlobalRegistry *GR) {
  MachineIRBuilder &Builder = Helper.MIRBuilder;

  // Create the destination register and record its SPIR-V type before it is
  // used as a def.
  const Register IntReg = MRI.createGenericVirtualRegister(ConvTy);
  GR->assignSPIRVTypeToVReg(SpirvType, IntReg, Builder.getMF());

  auto Cast = Builder.buildInstr(TargetOpcode::G_PTRTOINT);
  Cast.addDef(IntReg);
  Cast.addUse(Reg);
  return IntReg;
}
335
336
bool SPIRVLegalizerInfo::legalizeCustom(
337
LegalizerHelper &Helper, MachineInstr &MI,
338
LostDebugLocObserver &LocObserver) const {
339
auto Opc = MI.getOpcode();
340
MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
341
if (!isTypeFoldingSupported(Opc)) {
342
assert(Opc == TargetOpcode::G_ICMP);
343
assert(GR->getSPIRVTypeForVReg(MI.getOperand(0).getReg()));
344
auto &Op0 = MI.getOperand(2);
345
auto &Op1 = MI.getOperand(3);
346
Register Reg0 = Op0.getReg();
347
Register Reg1 = Op1.getReg();
348
CmpInst::Predicate Cond =
349
static_cast<CmpInst::Predicate>(MI.getOperand(1).getPredicate());
350
if ((!ST->canDirectlyComparePointers() ||
351
(Cond != CmpInst::ICMP_EQ && Cond != CmpInst::ICMP_NE)) &&
352
MRI.getType(Reg0).isPointer() && MRI.getType(Reg1).isPointer()) {
353
LLT ConvT = LLT::scalar(ST->getPointerSize());
354
Type *LLVMTy = IntegerType::get(MI.getMF()->getFunction().getContext(),
355
ST->getPointerSize());
356
SPIRVType *SpirvTy = GR->getOrCreateSPIRVType(LLVMTy, Helper.MIRBuilder);
357
Op0.setReg(convertPtrToInt(Reg0, ConvT, SpirvTy, Helper, MRI, GR));
358
Op1.setReg(convertPtrToInt(Reg1, ConvT, SpirvTy, Helper, MRI, GR));
359
}
360
return true;
361
}
362
// TODO: implement legalization for other opcodes.
363
return true;
364
}
365
366