Path: blob/main/contrib/llvm-project/llvm/lib/Target/DirectX/DXILOpBuilder.cpp
35266 views
//===- DXILOpBuilder.cpp - Helper class for build DIXLOp functions --------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7///8/// \file This file contains class to help build DXIL op functions.9//===----------------------------------------------------------------------===//1011#include "DXILOpBuilder.h"12#include "DXILConstants.h"13#include "llvm/IR/IRBuilder.h"14#include "llvm/IR/Module.h"15#include "llvm/Support/DXILABI.h"16#include "llvm/Support/ErrorHandling.h"1718using namespace llvm;19using namespace llvm::dxil;2021constexpr StringLiteral DXILOpNamePrefix = "dx.op.";2223namespace {2425enum OverloadKind : uint16_t {26VOID = 1,27HALF = 1 << 1,28FLOAT = 1 << 2,29DOUBLE = 1 << 3,30I1 = 1 << 4,31I8 = 1 << 5,32I16 = 1 << 6,33I32 = 1 << 7,34I64 = 1 << 8,35UserDefineType = 1 << 9,36ObjectType = 1 << 10,37};3839} // namespace4041static const char *getOverloadTypeName(OverloadKind Kind) {42switch (Kind) {43case OverloadKind::HALF:44return "f16";45case OverloadKind::FLOAT:46return "f32";47case OverloadKind::DOUBLE:48return "f64";49case OverloadKind::I1:50return "i1";51case OverloadKind::I8:52return "i8";53case OverloadKind::I16:54return "i16";55case OverloadKind::I32:56return "i32";57case OverloadKind::I64:58return "i64";59case OverloadKind::VOID:60case OverloadKind::ObjectType:61case OverloadKind::UserDefineType:62break;63}64llvm_unreachable("invalid overload type for name");65return "void";66}6768static OverloadKind getOverloadKind(Type *Ty) {69Type::TypeID T = Ty->getTypeID();70switch (T) {71case Type::VoidTyID:72return OverloadKind::VOID;73case Type::HalfTyID:74return OverloadKind::HALF;75case Type::FloatTyID:76return OverloadKind::FLOAT;77case Type::DoubleTyID:78return OverloadKind::DOUBLE;79case Type::IntegerTyID: {80IntegerType *ITy = cast<IntegerType>(Ty);81unsigned Bits = ITy->getBitWidth();82switch (Bits) {83case 1:84return OverloadKind::I1;85case 8:86return OverloadKind::I8;87case 16:88return OverloadKind::I16;89case 32:90return OverloadKind::I32;91case 64:92return OverloadKind::I64;93default:94llvm_unreachable("invalid overload type");95return OverloadKind::VOID;96}97}98case Type::PointerTyID:99return OverloadKind::UserDefineType;100case Type::StructTyID:101return OverloadKind::ObjectType;102default:103llvm_unreachable("invalid overload type");104return OverloadKind::VOID;105}106}107108static std::string getTypeName(OverloadKind Kind, Type *Ty) {109if (Kind < OverloadKind::UserDefineType) {110return getOverloadTypeName(Kind);111} else if (Kind == OverloadKind::UserDefineType) {112StructType *ST = cast<StructType>(Ty);113return ST->getStructName().str();114} else if (Kind == OverloadKind::ObjectType) {115StructType *ST = cast<StructType>(Ty);116return ST->getStructName().str();117} else {118std::string Str;119raw_string_ostream OS(Str);120Ty->print(OS);121return OS.str();122}123}124125// Static properties.126struct OpCodeProperty {127dxil::OpCode OpCode;128// Offset in DXILOpCodeNameTable.129unsigned OpCodeNameOffset;130dxil::OpCodeClass OpCodeClass;131// Offset in DXILOpCodeClassNameTable.132unsigned OpCodeClassNameOffset;133uint16_t OverloadTys;134llvm::Attribute::AttrKind FuncAttr;135int OverloadParamIndex; // parameter index which control the overload.136// When < 0, should be only 1 overload type.137unsigned NumOfParameters; // Number of parameters include return value.138unsigned ParameterTableOffset; // Offset in ParameterTable.139};140141// Include getOpCodeClassName getOpCodeProperty, getOpCodeName and142// getOpCodeParameterKind which generated by tableGen.143#define DXIL_OP_OPERATION_TABLE144#include "DXILOperation.inc"145#undef DXIL_OP_OPERATION_TABLE146147static std::string constructOverloadName(OverloadKind Kind, Type *Ty,148const OpCodeProperty &Prop) {149if (Kind == OverloadKind::VOID) {150return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop)).str();151}152return (Twine(DXILOpNamePrefix) + getOpCodeClassName(Prop) + "." +153getTypeName(Kind, Ty))154.str();155}156157static std::string constructOverloadTypeName(OverloadKind Kind,158StringRef TypeName) {159if (Kind == OverloadKind::VOID)160return TypeName.str();161162assert(Kind < OverloadKind::UserDefineType && "invalid overload kind");163return (Twine(TypeName) + getOverloadTypeName(Kind)).str();164}165166static StructType *getOrCreateStructType(StringRef Name,167ArrayRef<Type *> EltTys,168LLVMContext &Ctx) {169StructType *ST = StructType::getTypeByName(Ctx, Name);170if (ST)171return ST;172173return StructType::create(Ctx, EltTys, Name);174}175176static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) {177OverloadKind Kind = getOverloadKind(OverloadTy);178std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet.");179Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy,180Type::getInt32Ty(Ctx)};181return getOrCreateStructType(TypeName, FieldTypes, Ctx);182}183184static StructType *getHandleType(LLVMContext &Ctx) {185return getOrCreateStructType("dx.types.Handle", PointerType::getUnqual(Ctx),186Ctx);187}188189static Type *getTypeFromParameterKind(ParameterKind Kind, Type *OverloadTy) {190auto &Ctx = OverloadTy->getContext();191switch (Kind) {192case ParameterKind::Void:193return Type::getVoidTy(Ctx);194case ParameterKind::Half:195return Type::getHalfTy(Ctx);196case ParameterKind::Float:197return Type::getFloatTy(Ctx);198case ParameterKind::Double:199return Type::getDoubleTy(Ctx);200case ParameterKind::I1:201return Type::getInt1Ty(Ctx);202case ParameterKind::I8:203return Type::getInt8Ty(Ctx);204case ParameterKind::I16:205return Type::getInt16Ty(Ctx);206case ParameterKind::I32:207return Type::getInt32Ty(Ctx);208case ParameterKind::I64:209return Type::getInt64Ty(Ctx);210case ParameterKind::Overload:211return OverloadTy;212case ParameterKind::ResourceRet:213return getResRetType(OverloadTy, Ctx);214case ParameterKind::DXILHandle:215return getHandleType(Ctx);216default:217break;218}219llvm_unreachable("Invalid parameter kind");220return nullptr;221}222223/// Construct DXIL function type. This is the type of a function with224/// the following prototype225/// OverloadType dx.op.<opclass>.<return-type>(int opcode, <param types>)226/// <param-types> are constructed from types in Prop.227/// \param Prop Structure containing DXIL Operation properties based on228/// its specification in DXIL.td.229/// \param OverloadTy Return type to be used to construct DXIL function type.230static FunctionType *getDXILOpFunctionType(const OpCodeProperty *Prop,231Type *ReturnTy, Type *OverloadTy) {232SmallVector<Type *> ArgTys;233234auto ParamKinds = getOpCodeParameterKind(*Prop);235236// Add ReturnTy as return type of the function237ArgTys.emplace_back(ReturnTy);238239// Add DXIL Opcode value type viz., Int32 as first argument240ArgTys.emplace_back(Type::getInt32Ty(OverloadTy->getContext()));241242// Add DXIL Operation parameter types as specified in DXIL properties243for (unsigned I = 0; I < Prop->NumOfParameters; ++I) {244ParameterKind Kind = ParamKinds[I];245ArgTys.emplace_back(getTypeFromParameterKind(Kind, OverloadTy));246}247return FunctionType::get(248ArgTys[0], ArrayRef<Type *>(&ArgTys[1], ArgTys.size() - 1), false);249}250251namespace llvm {252namespace dxil {253254CallInst *DXILOpBuilder::createDXILOpCall(dxil::OpCode OpCode, Type *ReturnTy,255Type *OverloadTy,256SmallVector<Value *> Args) {257const OpCodeProperty *Prop = getOpCodeProperty(OpCode);258259OverloadKind Kind = getOverloadKind(OverloadTy);260if ((Prop->OverloadTys & (uint16_t)Kind) == 0) {261report_fatal_error("Invalid Overload Type", /* gen_crash_diag=*/false);262}263264std::string DXILFnName = constructOverloadName(Kind, OverloadTy, *Prop);265FunctionCallee DXILFn;266// Get the function with name DXILFnName, if one exists267if (auto *Func = M.getFunction(DXILFnName)) {268DXILFn = FunctionCallee(Func);269} else {270// Construct and add a function with name DXILFnName271FunctionType *DXILOpFT = getDXILOpFunctionType(Prop, ReturnTy, OverloadTy);272DXILFn = M.getOrInsertFunction(DXILFnName, DXILOpFT);273}274275return B.CreateCall(DXILFn, Args);276}277278Type *DXILOpBuilder::getOverloadTy(dxil::OpCode OpCode, FunctionType *FT) {279280const OpCodeProperty *Prop = getOpCodeProperty(OpCode);281// If DXIL Op has no overload parameter, just return the282// precise return type specified.283if (Prop->OverloadParamIndex < 0) {284auto &Ctx = FT->getContext();285switch (Prop->OverloadTys) {286case OverloadKind::VOID:287return Type::getVoidTy(Ctx);288case OverloadKind::HALF:289return Type::getHalfTy(Ctx);290case OverloadKind::FLOAT:291return Type::getFloatTy(Ctx);292case OverloadKind::DOUBLE:293return Type::getDoubleTy(Ctx);294case OverloadKind::I1:295return Type::getInt1Ty(Ctx);296case OverloadKind::I8:297return Type::getInt8Ty(Ctx);298case OverloadKind::I16:299return Type::getInt16Ty(Ctx);300case OverloadKind::I32:301return Type::getInt32Ty(Ctx);302case OverloadKind::I64:303return Type::getInt64Ty(Ctx);304default:305llvm_unreachable("invalid overload type");306return nullptr;307}308}309310// Prop->OverloadParamIndex is 0, overload type is FT->getReturnType().311Type *OverloadType = FT->getReturnType();312if (Prop->OverloadParamIndex != 0) {313// Skip Return Type.314OverloadType = FT->getParamType(Prop->OverloadParamIndex - 1);315}316317auto ParamKinds = getOpCodeParameterKind(*Prop);318auto Kind = ParamKinds[Prop->OverloadParamIndex];319// For ResRet and CBufferRet, OverloadTy is in field of StructType.320if (Kind == ParameterKind::CBufferRet ||321Kind == ParameterKind::ResourceRet) {322auto *ST = cast<StructType>(OverloadType);323OverloadType = ST->getElementType(0);324}325return OverloadType;326}327328const char *DXILOpBuilder::getOpCodeName(dxil::OpCode DXILOp) {329return ::getOpCodeName(DXILOp);330}331} // namespace dxil332} // namespace llvm333334335