Path: blob/main/contrib/llvm-project/clang/lib/CodeGen/CGGPUBuiltin.cpp
35233 views
//===------ CGGPUBuiltin.cpp - Codegen for GPU builtins -------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// Generates code for built-in GPU calls which are not runtime-specific.9// (Runtime-specific codegen lives in programming model specific files.)10//11//===----------------------------------------------------------------------===//1213#include "CodeGenFunction.h"14#include "clang/Basic/Builtins.h"15#include "llvm/IR/DataLayout.h"16#include "llvm/IR/Instruction.h"17#include "llvm/Support/MathExtras.h"18#include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h"1920using namespace clang;21using namespace CodeGen;2223namespace {24llvm::Function *GetVprintfDeclaration(llvm::Module &M) {25llvm::Type *ArgTypes[] = {llvm::PointerType::getUnqual(M.getContext()),26llvm::PointerType::getUnqual(M.getContext())};27llvm::FunctionType *VprintfFuncType = llvm::FunctionType::get(28llvm::Type::getInt32Ty(M.getContext()), ArgTypes, false);2930if (auto *F = M.getFunction("vprintf")) {31// Our CUDA system header declares vprintf with the right signature, so32// nobody else should have been able to declare vprintf with a bogus33// signature.34assert(F->getFunctionType() == VprintfFuncType);35return F;36}3738// vprintf doesn't already exist; create a declaration and insert it into the39// module.40return llvm::Function::Create(41VprintfFuncType, llvm::GlobalVariable::ExternalLinkage, "vprintf", &M);42}4344llvm::Function *GetOpenMPVprintfDeclaration(CodeGenModule &CGM) {45const char *Name = "__llvm_omp_vprintf";46llvm::Module &M = CGM.getModule();47llvm::Type *ArgTypes[] = {llvm::PointerType::getUnqual(M.getContext()),48llvm::PointerType::getUnqual(M.getContext()),49llvm::Type::getInt32Ty(M.getContext())};50llvm::FunctionType *VprintfFuncType = llvm::FunctionType::get(51llvm::Type::getInt32Ty(M.getContext()), ArgTypes, false);5253if (auto *F = M.getFunction(Name)) {54if (F->getFunctionType() != VprintfFuncType) {55CGM.Error(SourceLocation(),56"Invalid type declaration for __llvm_omp_vprintf");57return nullptr;58}59return F;60}6162return llvm::Function::Create(63VprintfFuncType, llvm::GlobalVariable::ExternalLinkage, Name, &M);64}6566// Transforms a call to printf into a call to the NVPTX vprintf syscall (which67// isn't particularly special; it's invoked just like a regular function).68// vprintf takes two args: A format string, and a pointer to a buffer containing69// the varargs.70//71// For example, the call72//73// printf("format string", arg1, arg2, arg3);74//75// is converted into something resembling76//77// struct Tmp {78// Arg1 a1;79// Arg2 a2;80// Arg3 a3;81// };82// char* buf = alloca(sizeof(Tmp));83// *(Tmp*)buf = {a1, a2, a3};84// vprintf("format string", buf);85//86// buf is aligned to the max of {alignof(Arg1), ...}. Furthermore, each of the87// args is itself aligned to its preferred alignment.88//89// Note that by the time this function runs, E's args have already undergone the90// standard C vararg promotion (short -> int, float -> double, etc.).9192std::pair<llvm::Value *, llvm::TypeSize>93packArgsIntoNVPTXFormatBuffer(CodeGenFunction *CGF, const CallArgList &Args) {94const llvm::DataLayout &DL = CGF->CGM.getDataLayout();95llvm::LLVMContext &Ctx = CGF->CGM.getLLVMContext();96CGBuilderTy &Builder = CGF->Builder;9798// Construct and fill the args buffer that we'll pass to vprintf.99if (Args.size() <= 1) {100// If there are no args, pass a null pointer and size 0101llvm::Value *BufferPtr =102llvm::ConstantPointerNull::get(llvm::PointerType::getUnqual(Ctx));103return {BufferPtr, llvm::TypeSize::getFixed(0)};104} else {105llvm::SmallVector<llvm::Type *, 8> ArgTypes;106for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I)107ArgTypes.push_back(Args[I].getRValue(*CGF).getScalarVal()->getType());108109// Using llvm::StructType is correct only because printf doesn't accept110// aggregates. If we had to handle aggregates here, we'd have to manually111// compute the offsets within the alloca -- we wouldn't be able to assume112// that the alignment of the llvm type was the same as the alignment of the113// clang type.114llvm::Type *AllocaTy = llvm::StructType::create(ArgTypes, "printf_args");115llvm::Value *Alloca = CGF->CreateTempAlloca(AllocaTy);116117for (unsigned I = 1, NumArgs = Args.size(); I < NumArgs; ++I) {118llvm::Value *P = Builder.CreateStructGEP(AllocaTy, Alloca, I - 1);119llvm::Value *Arg = Args[I].getRValue(*CGF).getScalarVal();120Builder.CreateAlignedStore(Arg, P, DL.getPrefTypeAlign(Arg->getType()));121}122llvm::Value *BufferPtr =123Builder.CreatePointerCast(Alloca, llvm::PointerType::getUnqual(Ctx));124return {BufferPtr, DL.getTypeAllocSize(AllocaTy)};125}126}127128bool containsNonScalarVarargs(CodeGenFunction *CGF, const CallArgList &Args) {129return llvm::any_of(llvm::drop_begin(Args), [&](const CallArg &A) {130return !A.getRValue(*CGF).isScalar();131});132}133134RValue EmitDevicePrintfCallExpr(const CallExpr *E, CodeGenFunction *CGF,135llvm::Function *Decl, bool WithSizeArg) {136CodeGenModule &CGM = CGF->CGM;137CGBuilderTy &Builder = CGF->Builder;138assert(E->getBuiltinCallee() == Builtin::BIprintf ||139E->getBuiltinCallee() == Builtin::BI__builtin_printf);140assert(E->getNumArgs() >= 1); // printf always has at least one arg.141142// Uses the same format as nvptx for the argument packing, but also passes143// an i32 for the total size of the passed pointer144CallArgList Args;145CGF->EmitCallArgs(Args,146E->getDirectCallee()->getType()->getAs<FunctionProtoType>(),147E->arguments(), E->getDirectCallee(),148/* ParamsToSkip = */ 0);149150// We don't know how to emit non-scalar varargs.151if (containsNonScalarVarargs(CGF, Args)) {152CGM.ErrorUnsupported(E, "non-scalar arg to printf");153return RValue::get(llvm::ConstantInt::get(CGF->IntTy, 0));154}155156auto r = packArgsIntoNVPTXFormatBuffer(CGF, Args);157llvm::Value *BufferPtr = r.first;158159llvm::SmallVector<llvm::Value *, 3> Vec = {160Args[0].getRValue(*CGF).getScalarVal(), BufferPtr};161if (WithSizeArg) {162// Passing > 32bit of data as a local alloca doesn't work for nvptx or163// amdgpu164llvm::Constant *Size =165llvm::ConstantInt::get(llvm::Type::getInt32Ty(CGM.getLLVMContext()),166static_cast<uint32_t>(r.second.getFixedValue()));167168Vec.push_back(Size);169}170return RValue::get(Builder.CreateCall(Decl, Vec));171}172} // namespace173174RValue CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E) {175assert(getTarget().getTriple().isNVPTX());176return EmitDevicePrintfCallExpr(177E, this, GetVprintfDeclaration(CGM.getModule()), false);178}179180RValue CodeGenFunction::EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E) {181assert(getTarget().getTriple().isAMDGCN() ||182(getTarget().getTriple().isSPIRV() &&183getTarget().getTriple().getVendor() == llvm::Triple::AMD));184assert(E->getBuiltinCallee() == Builtin::BIprintf ||185E->getBuiltinCallee() == Builtin::BI__builtin_printf);186assert(E->getNumArgs() >= 1); // printf always has at least one arg.187188CallArgList CallArgs;189EmitCallArgs(CallArgs,190E->getDirectCallee()->getType()->getAs<FunctionProtoType>(),191E->arguments(), E->getDirectCallee(),192/* ParamsToSkip = */ 0);193194SmallVector<llvm::Value *, 8> Args;195for (const auto &A : CallArgs) {196// We don't know how to emit non-scalar varargs.197if (!A.getRValue(*this).isScalar()) {198CGM.ErrorUnsupported(E, "non-scalar arg to printf");199return RValue::get(llvm::ConstantInt::get(IntTy, -1));200}201202llvm::Value *Arg = A.getRValue(*this).getScalarVal();203Args.push_back(Arg);204}205206llvm::IRBuilder<> IRB(Builder.GetInsertBlock(), Builder.GetInsertPoint());207IRB.SetCurrentDebugLocation(Builder.getCurrentDebugLocation());208209bool isBuffered = (CGM.getTarget().getTargetOpts().AMDGPUPrintfKindVal ==210clang::TargetOptions::AMDGPUPrintfKind::Buffered);211auto Printf = llvm::emitAMDGPUPrintfCall(IRB, Args, isBuffered);212Builder.SetInsertPoint(IRB.GetInsertBlock(), IRB.GetInsertPoint());213return RValue::get(Printf);214}215216RValue CodeGenFunction::EmitOpenMPDevicePrintfCallExpr(const CallExpr *E) {217assert(getTarget().getTriple().isNVPTX() ||218getTarget().getTriple().isAMDGCN());219return EmitDevicePrintfCallExpr(E, this, GetOpenMPVprintfDeclaration(CGM),220true);221}222223224