Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp
35294 views
1
//===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
// See https://llvm.org/LICENSE.txt for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
//
9
// This file implements the lowering of LLVM calls to machine code calls for
10
// GlobalISel.
11
//
12
//===----------------------------------------------------------------------===//
13
14
#include "SPIRVCallLowering.h"
15
#include "MCTargetDesc/SPIRVBaseInfo.h"
16
#include "SPIRV.h"
17
#include "SPIRVBuiltins.h"
18
#include "SPIRVGlobalRegistry.h"
19
#include "SPIRVISelLowering.h"
20
#include "SPIRVMetadata.h"
21
#include "SPIRVRegisterInfo.h"
22
#include "SPIRVSubtarget.h"
23
#include "SPIRVUtils.h"
24
#include "llvm/CodeGen/FunctionLoweringInfo.h"
25
#include "llvm/IR/IntrinsicInst.h"
26
#include "llvm/IR/IntrinsicsSPIRV.h"
27
#include "llvm/Support/ModRef.h"
28
29
using namespace llvm;
30
31
// Constructs the SPIR-V call lowering: forwards the target lowering to the
// generic CallLowering base and stores the global registry pointer for later
// type lookups.
SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
                                     SPIRVGlobalRegistry *GR)
    : CallLowering(&TLI), GR(GR) {}
34
35
bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
36
const Value *Val, ArrayRef<Register> VRegs,
37
FunctionLoweringInfo &FLI,
38
Register SwiftErrorVReg) const {
39
// Maybe run postponed production of types for function pointers
40
if (IndirectCalls.size() > 0) {
41
produceIndirectPtrTypes(MIRBuilder);
42
IndirectCalls.clear();
43
}
44
45
// Currently all return types should use a single register.
46
// TODO: handle the case of multiple registers.
47
if (VRegs.size() > 1)
48
return false;
49
if (Val) {
50
const auto &STI = MIRBuilder.getMF().getSubtarget();
51
return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
52
.addUse(VRegs[0])
53
.constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
54
*STI.getRegBankInfo());
55
}
56
MIRBuilder.buildInstr(SPIRV::OpReturn);
57
return true;
58
}
59
60
// Based on the LLVM function attributes, get a SPIR-V FunctionControl.
61
static uint32_t getFunctionControl(const Function &F) {
62
MemoryEffects MemEffects = F.getMemoryEffects();
63
64
uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
65
66
if (F.hasFnAttribute(Attribute::AttrKind::NoInline))
67
FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
68
else if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline))
69
FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
70
71
if (MemEffects.doesNotAccessMemory())
72
FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
73
else if (MemEffects.onlyReadsMemory())
74
FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
75
76
return FuncControl;
77
}
78
79
// Returns operand NumOp of MD as a ConstantInt, or nullptr when MD has no such
// operand, the operand is not constant metadata, or the constant is not an
// integer.
static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
  if (NumOp >= MD->getNumOperands())
    return nullptr;
  if (auto *ConstMD = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp)))
    return dyn_cast<ConstantInt>(ConstMD->getValue());
  return nullptr;
}
87
88
// If the function has pointer arguments, we are forced to re-create this
89
// function type from the very beginning, changing PointerType by
90
// TypedPointerType for each pointer argument. Otherwise, the same `Type*`
91
// potentially corresponds to different SPIR-V function type, effectively
92
// invalidating logic behind global registry and duplicates tracker.
93
static FunctionType *
94
fixFunctionTypeIfPtrArgs(SPIRVGlobalRegistry *GR, const Function &F,
95
FunctionType *FTy, const SPIRVType *SRetTy,
96
const SmallVector<SPIRVType *, 4> &SArgTys) {
97
if (F.getParent()->getNamedMetadata("spv.cloned_funcs"))
98
return FTy;
99
100
bool hasArgPtrs = false;
101
for (auto &Arg : F.args()) {
102
// check if it's an instance of a non-typed PointerType
103
if (Arg.getType()->isPointerTy()) {
104
hasArgPtrs = true;
105
break;
106
}
107
}
108
if (!hasArgPtrs) {
109
Type *RetTy = FTy->getReturnType();
110
// check if it's an instance of a non-typed PointerType
111
if (!RetTy->isPointerTy())
112
return FTy;
113
}
114
115
// re-create function type, using TypedPointerType instead of PointerType to
116
// properly trace argument types
117
const Type *RetTy = GR->getTypeForSPIRVType(SRetTy);
118
SmallVector<Type *, 4> ArgTys;
119
for (auto SArgTy : SArgTys)
120
ArgTys.push_back(const_cast<Type *>(GR->getTypeForSPIRVType(SArgTy)));
121
return FunctionType::get(const_cast<Type *>(RetTy), ArgTys, false);
122
}
123
124
// This code restores function args/retvalue types for composite cases
125
// because the final types should still be aggregate whereas they're i32
126
// during the translation to cope with aggregate flattening etc.
127
static FunctionType *getOriginalFunctionType(const Function &F) {
128
auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
129
if (NamedMD == nullptr)
130
return F.getFunctionType();
131
132
Type *RetTy = F.getFunctionType()->getReturnType();
133
SmallVector<Type *, 4> ArgTypes;
134
for (auto &Arg : F.args())
135
ArgTypes.push_back(Arg.getType());
136
137
auto ThisFuncMDIt =
138
std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
139
return isa<MDString>(N->getOperand(0)) &&
140
cast<MDString>(N->getOperand(0))->getString() == F.getName();
141
});
142
// TODO: probably one function can have numerous type mutations,
143
// so we should support this.
144
if (ThisFuncMDIt != NamedMD->op_end()) {
145
auto *ThisFuncMD = *ThisFuncMDIt;
146
MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1));
147
assert(MD && "MDNode operand is expected");
148
ConstantInt *Const = getConstInt(MD, 0);
149
if (Const) {
150
auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
151
assert(CMeta && "ConstantAsMetadata operand is expected");
152
assert(Const->getSExtValue() >= -1);
153
// Currently -1 indicates return value, greater values mean
154
// argument numbers.
155
if (Const->getSExtValue() == -1)
156
RetTy = CMeta->getType();
157
else
158
ArgTypes[Const->getSExtValue()] = CMeta->getType();
159
}
160
}
161
162
return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
163
}
164
165
// Returns the access qualifier of argument ArgIdx of F. Only SPIR_KERNEL
// functions are inspected: their kernel-argument access qualifier metadata
// string selects ReadOnly ("read_only") or WriteOnly ("write_only"). Every
// other case yields ReadWrite.
static SPIRV::AccessQualifier::AccessQualifier
getArgAccessQual(const Function &F, unsigned ArgIdx) {
  SPIRV::AccessQualifier::AccessQualifier Qual =
      SPIRV::AccessQualifier::ReadWrite;

  if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
    if (MDString *QualMD = getOCLKernelArgAccessQual(F, ArgIdx)) {
      if (QualMD->getString() == "read_only")
        Qual = SPIRV::AccessQualifier::ReadOnly;
      else if (QualMD->getString() == "write_only")
        Qual = SPIRV::AccessQualifier::WriteOnly;
    }
  }

  return Qual;
}
180
181
static std::vector<SPIRV::Decoration::Decoration>
182
getKernelArgTypeQual(const Function &F, unsigned ArgIdx) {
183
MDString *ArgAttribute = getOCLKernelArgTypeQual(F, ArgIdx);
184
if (ArgAttribute && ArgAttribute->getString() == "volatile")
185
return {SPIRV::Decoration::Volatile};
186
return {};
187
}
188
189
// Computes the SPIR-V type of formal argument ArgIdx of F.
//
// Non-pointer arguments use their original (pre-mutation) type directly.
// Pointer arguments are resolved, in order, from: the typed pointer type of
// the argument, a pointee-type attribute, an spv_assign_type or
// spv_assign_ptr_type intrinsic among the argument's users, and finally the
// typed-pointer form of the original argument type.
static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
                                  SPIRVGlobalRegistry *GR,
                                  MachineIRBuilder &MIRBuilder,
                                  const SPIRVSubtarget &ST) {
  // Read argument's access qualifier from metadata or default.
  const SPIRV::AccessQualifier::AccessQualifier AccessQual =
      getArgAccessQual(F, ArgIdx);

  Type *OrigTy = getOriginalFunctionType(F)->getParamType(ArgIdx);

  // If the original type is non-pointer, use it as is (the type cannot be
  // legally reassigned later).
  if (!isPointerTy(OrigTy))
    return GR->getOrCreateSPIRVType(OrigTy, MIRBuilder, AccessQual);

  // Builds the SPIR-V pointer type to Pointee in the storage class that
  // corresponds to the given address space.
  const auto MakePointer = [&](SPIRVType *Pointee, unsigned AddrSpace) {
    return GR->getOrCreateSPIRVPointerType(
        Pointee, MIRBuilder, addressSpaceToStorageClass(AddrSpace, ST));
  };

  Argument *Arg = F.getArg(ArgIdx);
  Type *ArgType = Arg->getType();
  if (isTypedPointerTy(ArgType)) {
    SPIRVType *Pointee = GR->getOrCreateSPIRVType(
        cast<TypedPointerType>(ArgType)->getElementType(), MIRBuilder);
    return MakePointer(Pointee, getPointerAddressSpace(ArgType));
  }

  // In case OrigTy is of untyped pointer type, there are three
  // possibilities:
  // 1) This is a pointer of an LLVM IR element type, passed byval/byref.
  // 2) This is an OpenCL/SPIR-V builtin type if there is spv_assign_type
  //    intrinsic assigning a TargetExtType.
  // 3) This is a pointer, try to retrieve pointer element type from a
  //    spv_assign_ptr_type intrinsic or otherwise use default pointer element
  //    type.
  if (hasPointeeTypeAttr(Arg)) {
    SPIRVType *Pointee =
        GR->getOrCreateSPIRVType(getPointeeTypeByAttr(Arg), MIRBuilder);
    return MakePointer(Pointee, getPointerAddressSpace(ArgType));
  }

  for (auto *ArgUser : Arg->users()) {
    auto *II = dyn_cast<IntrinsicInst>(ArgUser);
    if (!II)
      continue;

    switch (II->getIntrinsicID()) {
    case Intrinsic::spv_assign_type: {
      // spv_assign_type assigning an OpenCL/SPIR-V builtin type.
      MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
      Type *BuiltinType =
          cast<ConstantAsMetadata>(VMD->getMetadata())->getType();
      assert(BuiltinType->isTargetExtTy() && "Expected TargetExtType");
      return GR->getOrCreateSPIRVType(BuiltinType, MIRBuilder, AccessQual);
    }
    case Intrinsic::spv_assign_ptr_type: {
      // spv_assign_ptr_type assigning the pointer element type; its third
      // operand carries the address space.
      MetadataAsValue *VMD = cast<MetadataAsValue>(II->getOperand(1));
      Type *ElementTy = toTypedPointer(
          cast<ConstantAsMetadata>(VMD->getMetadata())->getType());
      SPIRVType *Pointee = GR->getOrCreateSPIRVType(ElementTy, MIRBuilder);
      return MakePointer(
          Pointee, cast<ConstantInt>(II->getOperand(2))->getZExtValue());
    }
    default:
      break;
    }
  }

  // Replace PointerType with TypedPointerType to be able to map SPIR-V types
  // to LLVM types in a consistent manner
  return GR->getOrCreateSPIRVType(toTypedPointer(OrigTy), MIRBuilder,
                                  AccessQual);
}
260
261
// Picks the execution model for entry point F: Kernel in an OpenCL
// environment, otherwise GLCompute when the "hlsl.shader" attribute equals
// "compute". A missing attribute or any other value is a fatal error.
static SPIRV::ExecutionModel::ExecutionModel
getExecutionModel(const SPIRVSubtarget &STI, const Function &F) {
  if (STI.isOpenCLEnv())
    return SPIRV::ExecutionModel::Kernel;

  const auto ShaderAttr = F.getFnAttribute("hlsl.shader");
  if (!ShaderAttr.isValid())
    report_fatal_error(
        "This entry point lacks mandatory hlsl.shader attribute.");

  const auto ShaderKind = ShaderAttr.getValueAsString();
  if (ShaderKind != "compute")
    report_fatal_error(
        "This HLSL entry point is not supported by this backend.");

  return SPIRV::ExecutionModel::GLCompute;
}
278
279
bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
280
const Function &F,
281
ArrayRef<ArrayRef<Register>> VRegs,
282
FunctionLoweringInfo &FLI) const {
283
assert(GR && "Must initialize the SPIRV type registry before lowering args.");
284
GR->setCurrentFunc(MIRBuilder.getMF());
285
286
// Get access to information about available extensions
287
const SPIRVSubtarget *ST =
288
static_cast<const SPIRVSubtarget *>(&MIRBuilder.getMF().getSubtarget());
289
290
// Assign types and names to all args, and store their types for later.
291
SmallVector<SPIRVType *, 4> ArgTypeVRegs;
292
if (VRegs.size() > 0) {
293
unsigned i = 0;
294
for (const auto &Arg : F.args()) {
295
// Currently formal args should use single registers.
296
// TODO: handle the case of multiple registers.
297
if (VRegs[i].size() > 1)
298
return false;
299
auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder, *ST);
300
GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
301
ArgTypeVRegs.push_back(SpirvTy);
302
303
if (Arg.hasName())
304
buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
305
if (isPointerTy(Arg.getType())) {
306
auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
307
if (DerefBytes != 0)
308
buildOpDecorate(VRegs[i][0], MIRBuilder,
309
SPIRV::Decoration::MaxByteOffset, {DerefBytes});
310
}
311
if (Arg.hasAttribute(Attribute::Alignment)) {
312
auto Alignment = static_cast<unsigned>(
313
Arg.getAttribute(Attribute::Alignment).getValueAsInt());
314
buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
315
{Alignment});
316
}
317
if (Arg.hasAttribute(Attribute::ReadOnly)) {
318
auto Attr =
319
static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
320
buildOpDecorate(VRegs[i][0], MIRBuilder,
321
SPIRV::Decoration::FuncParamAttr, {Attr});
322
}
323
if (Arg.hasAttribute(Attribute::ZExt)) {
324
auto Attr =
325
static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
326
buildOpDecorate(VRegs[i][0], MIRBuilder,
327
SPIRV::Decoration::FuncParamAttr, {Attr});
328
}
329
if (Arg.hasAttribute(Attribute::NoAlias)) {
330
auto Attr =
331
static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias);
332
buildOpDecorate(VRegs[i][0], MIRBuilder,
333
SPIRV::Decoration::FuncParamAttr, {Attr});
334
}
335
if (Arg.hasAttribute(Attribute::ByVal)) {
336
auto Attr =
337
static_cast<unsigned>(SPIRV::FunctionParameterAttribute::ByVal);
338
buildOpDecorate(VRegs[i][0], MIRBuilder,
339
SPIRV::Decoration::FuncParamAttr, {Attr});
340
}
341
342
if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
343
std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
344
getKernelArgTypeQual(F, i);
345
for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
346
buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {});
347
}
348
349
MDNode *Node = F.getMetadata("spirv.ParameterDecorations");
350
if (Node && i < Node->getNumOperands() &&
351
isa<MDNode>(Node->getOperand(i))) {
352
MDNode *MD = cast<MDNode>(Node->getOperand(i));
353
for (const MDOperand &MDOp : MD->operands()) {
354
MDNode *MD2 = dyn_cast<MDNode>(MDOp);
355
assert(MD2 && "Metadata operand is expected");
356
ConstantInt *Const = getConstInt(MD2, 0);
357
assert(Const && "MDOperand should be ConstantInt");
358
auto Dec =
359
static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue());
360
std::vector<uint32_t> DecVec;
361
for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
362
ConstantInt *Const = getConstInt(MD2, j);
363
assert(Const && "MDOperand should be ConstantInt");
364
DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue()));
365
}
366
buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec);
367
}
368
}
369
++i;
370
}
371
}
372
373
auto MRI = MIRBuilder.getMRI();
374
Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
375
MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
376
if (F.isDeclaration())
377
GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
378
FunctionType *FTy = getOriginalFunctionType(F);
379
Type *FRetTy = FTy->getReturnType();
380
if (isUntypedPointerTy(FRetTy)) {
381
if (Type *FRetElemTy = GR->findDeducedElementType(&F)) {
382
TypedPointerType *DerivedTy = TypedPointerType::get(
383
toTypedPointer(FRetElemTy), getPointerAddressSpace(FRetTy));
384
GR->addReturnType(&F, DerivedTy);
385
FRetTy = DerivedTy;
386
}
387
}
388
SPIRVType *RetTy = GR->getOrCreateSPIRVType(FRetTy, MIRBuilder);
389
FTy = fixFunctionTypeIfPtrArgs(GR, F, FTy, RetTy, ArgTypeVRegs);
390
SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
391
FTy, RetTy, ArgTypeVRegs, MIRBuilder);
392
uint32_t FuncControl = getFunctionControl(F);
393
394
// Add OpFunction instruction
395
MachineInstrBuilder MB = MIRBuilder.buildInstr(SPIRV::OpFunction)
396
.addDef(FuncVReg)
397
.addUse(GR->getSPIRVTypeID(RetTy))
398
.addImm(FuncControl)
399
.addUse(GR->getSPIRVTypeID(FuncTy));
400
GR->recordFunctionDefinition(&F, &MB.getInstr()->getOperand(0));
401
402
// Add OpFunctionParameter instructions
403
int i = 0;
404
for (const auto &Arg : F.args()) {
405
assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
406
MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
407
MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
408
.addDef(VRegs[i][0])
409
.addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
410
if (F.isDeclaration())
411
GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
412
i++;
413
}
414
// Name the function.
415
if (F.hasName())
416
buildOpName(FuncVReg, F.getName(), MIRBuilder);
417
418
// Handle entry points and function linkage.
419
if (isEntryPoint(F)) {
420
const auto &STI = MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>();
421
auto executionModel = getExecutionModel(STI, F);
422
auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
423
.addImm(static_cast<uint32_t>(executionModel))
424
.addUse(FuncVReg);
425
addStringImm(F.getName(), MIB);
426
} else if (F.getLinkage() != GlobalValue::InternalLinkage &&
427
F.getLinkage() != GlobalValue::PrivateLinkage) {
428
SPIRV::LinkageType::LinkageType LnkTy =
429
F.isDeclaration()
430
? SPIRV::LinkageType::Import
431
: (F.getLinkage() == GlobalValue::LinkOnceODRLinkage &&
432
ST->canUseExtension(
433
SPIRV::Extension::SPV_KHR_linkonce_odr)
434
? SPIRV::LinkageType::LinkOnceODR
435
: SPIRV::LinkageType::Export);
436
buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
437
{static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
438
}
439
440
// Handle function pointers decoration
441
bool hasFunctionPointers =
442
ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers);
443
if (hasFunctionPointers) {
444
if (F.hasFnAttribute("referenced-indirectly")) {
445
assert((F.getCallingConv() != CallingConv::SPIR_KERNEL) &&
446
"Unexpected 'referenced-indirectly' attribute of the kernel "
447
"function");
448
buildOpDecorate(FuncVReg, MIRBuilder,
449
SPIRV::Decoration::ReferencedIndirectlyINTEL, {});
450
}
451
}
452
453
return true;
454
}
455
456
// Used to postpone producing of indirect function pointer types after all
457
// indirect calls info is collected
458
// TODO:
459
// - add a topological sort of IndirectCalls to ensure the best types knowledge
460
// - we may need to fix function formal parameter types if they are opaque
461
// pointers used as function pointers in these indirect calls
462
void SPIRVCallLowering::produceIndirectPtrTypes(
463
MachineIRBuilder &MIRBuilder) const {
464
// Create indirect call data types if any
465
MachineFunction &MF = MIRBuilder.getMF();
466
for (auto const &IC : IndirectCalls) {
467
SPIRVType *SpirvRetTy = GR->getOrCreateSPIRVType(IC.RetTy, MIRBuilder);
468
SmallVector<SPIRVType *, 4> SpirvArgTypes;
469
for (size_t i = 0; i < IC.ArgTys.size(); ++i) {
470
SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(IC.ArgTys[i], MIRBuilder);
471
SpirvArgTypes.push_back(SPIRVTy);
472
if (!GR->getSPIRVTypeForVReg(IC.ArgRegs[i]))
473
GR->assignSPIRVTypeToVReg(SPIRVTy, IC.ArgRegs[i], MF);
474
}
475
// SPIR-V function type:
476
FunctionType *FTy =
477
FunctionType::get(const_cast<Type *>(IC.RetTy), IC.ArgTys, false);
478
SPIRVType *SpirvFuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
479
FTy, SpirvRetTy, SpirvArgTypes, MIRBuilder);
480
// SPIR-V pointer to function type:
481
SPIRVType *IndirectFuncPtrTy = GR->getOrCreateSPIRVPointerType(
482
SpirvFuncTy, MIRBuilder, SPIRV::StorageClass::Function);
483
// Correct the Callee type
484
GR->assignSPIRVTypeToVReg(IndirectFuncPtrTy, IC.Callee, MF);
485
}
486
}
487
488
bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
489
CallLoweringInfo &Info) const {
490
// Currently call returns should have single vregs.
491
// TODO: handle the case of multiple registers.
492
if (Info.OrigRet.Regs.size() > 1)
493
return false;
494
MachineFunction &MF = MIRBuilder.getMF();
495
GR->setCurrentFunc(MF);
496
const Function *CF = nullptr;
497
std::string DemangledName;
498
const Type *OrigRetTy = Info.OrigRet.Ty;
499
500
// Emit a regular OpFunctionCall. If it's an externally declared function,
501
// be sure to emit its type and function declaration here. It will be hoisted
502
// globally later.
503
if (Info.Callee.isGlobal()) {
504
std::string FuncName = Info.Callee.getGlobal()->getName().str();
505
DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
506
CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
507
// TODO: support constexpr casts and indirect calls.
508
if (CF == nullptr)
509
return false;
510
if (FunctionType *FTy = getOriginalFunctionType(*CF)) {
511
OrigRetTy = FTy->getReturnType();
512
if (isUntypedPointerTy(OrigRetTy)) {
513
if (auto *DerivedRetTy = GR->findReturnType(CF))
514
OrigRetTy = DerivedRetTy;
515
}
516
}
517
}
518
519
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
520
Register ResVReg =
521
Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
522
const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
523
524
bool isFunctionDecl = CF && CF->isDeclaration();
525
bool canUseOpenCL = ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std);
526
bool canUseGLSL = ST->canUseExtInstSet(SPIRV::InstructionSet::GLSL_std_450);
527
assert(canUseGLSL != canUseOpenCL &&
528
"Scenario where both sets are enabled is not supported.");
529
530
if (isFunctionDecl && !DemangledName.empty() &&
531
(canUseGLSL || canUseOpenCL)) {
532
SmallVector<Register, 8> ArgVRegs;
533
for (auto Arg : Info.OrigArgs) {
534
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
535
ArgVRegs.push_back(Arg.Regs[0]);
536
SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
537
if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
538
GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MF);
539
}
540
auto instructionSet = canUseOpenCL ? SPIRV::InstructionSet::OpenCL_std
541
: SPIRV::InstructionSet::GLSL_std_450;
542
if (auto Res =
543
SPIRV::lowerBuiltin(DemangledName, instructionSet, MIRBuilder,
544
ResVReg, OrigRetTy, ArgVRegs, GR))
545
return *Res;
546
}
547
548
if (isFunctionDecl && !GR->find(CF, &MF).isValid()) {
549
// Emit the type info and forward function declaration to the first MBB
550
// to ensure VReg definition dependencies are valid across all MBBs.
551
MachineIRBuilder FirstBlockBuilder;
552
FirstBlockBuilder.setMF(MF);
553
FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0));
554
555
SmallVector<ArrayRef<Register>, 8> VRegArgs;
556
SmallVector<SmallVector<Register, 1>, 8> ToInsert;
557
for (const Argument &Arg : CF->args()) {
558
if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
559
continue; // Don't handle zero sized types.
560
Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32));
561
MRI->setRegClass(Reg, &SPIRV::IDRegClass);
562
ToInsert.push_back({Reg});
563
VRegArgs.push_back(ToInsert.back());
564
}
565
// TODO: Reuse FunctionLoweringInfo
566
FunctionLoweringInfo FuncInfo;
567
lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
568
}
569
570
unsigned CallOp;
571
if (Info.CB->isIndirectCall()) {
572
if (!ST->canUseExtension(SPIRV::Extension::SPV_INTEL_function_pointers))
573
report_fatal_error("An indirect call is encountered but SPIR-V without "
574
"extensions does not support it",
575
false);
576
// Set instruction operation according to SPV_INTEL_function_pointers
577
CallOp = SPIRV::OpFunctionPointerCallINTEL;
578
// Collect information about the indirect call to support possible
579
// specification of opaque ptr types of parent function's parameters
580
Register CalleeReg = Info.Callee.getReg();
581
if (CalleeReg.isValid()) {
582
SPIRVCallLowering::SPIRVIndirectCall IndirectCall;
583
IndirectCall.Callee = CalleeReg;
584
IndirectCall.RetTy = OrigRetTy;
585
for (const auto &Arg : Info.OrigArgs) {
586
assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
587
IndirectCall.ArgTys.push_back(Arg.Ty);
588
IndirectCall.ArgRegs.push_back(Arg.Regs[0]);
589
}
590
IndirectCalls.push_back(IndirectCall);
591
}
592
} else {
593
// Emit a regular OpFunctionCall
594
CallOp = SPIRV::OpFunctionCall;
595
}
596
597
// Make sure there's a valid return reg, even for functions returning void.
598
if (!ResVReg.isValid())
599
ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
600
SPIRVType *RetType = GR->assignTypeToVReg(OrigRetTy, ResVReg, MIRBuilder);
601
602
// Emit the call instruction and its args.
603
auto MIB = MIRBuilder.buildInstr(CallOp)
604
.addDef(ResVReg)
605
.addUse(GR->getSPIRVTypeID(RetType))
606
.add(Info.Callee);
607
608
for (const auto &Arg : Info.OrigArgs) {
609
// Currently call args should have single vregs.
610
if (Arg.Regs.size() > 1)
611
return false;
612
MIB.addUse(Arg.Regs[0]);
613
}
614
return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(),
615
*ST->getRegBankInfo());
616
}
617
618