Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/Utils/AMDGPUEmitPrintf.cpp
35271 views
//===- AMDGPUEmitPrintf.cpp -----------------------------------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// Utility function to lower a printf call into a series of device9// library calls on the AMDGPU target.10//11// WARNING: This file knows about certain library functions. It recognizes them12// by name, and hardwires knowledge of their semantics.13//14//===----------------------------------------------------------------------===//1516#include "llvm/Transforms/Utils/AMDGPUEmitPrintf.h"17#include "llvm/ADT/SparseBitVector.h"18#include "llvm/ADT/StringExtras.h"19#include "llvm/Analysis/ValueTracking.h"20#include "llvm/IR/Module.h"21#include "llvm/Support/DataExtractor.h"22#include "llvm/Support/MD5.h"23#include "llvm/Support/MathExtras.h"2425using namespace llvm;2627#define DEBUG_TYPE "amdgpu-emit-printf"2829static Value *fitArgInto64Bits(IRBuilder<> &Builder, Value *Arg) {30auto Int64Ty = Builder.getInt64Ty();31auto Ty = Arg->getType();3233if (auto IntTy = dyn_cast<IntegerType>(Ty)) {34switch (IntTy->getBitWidth()) {35case 32:36return Builder.CreateZExt(Arg, Int64Ty);37case 64:38return Arg;39}40}4142if (Ty->getTypeID() == Type::DoubleTyID) {43return Builder.CreateBitCast(Arg, Int64Ty);44}4546if (isa<PointerType>(Ty)) {47return Builder.CreatePtrToInt(Arg, Int64Ty);48}4950llvm_unreachable("unexpected type");51}5253static Value *callPrintfBegin(IRBuilder<> &Builder, Value *Version) {54auto Int64Ty = Builder.getInt64Ty();55auto M = Builder.GetInsertBlock()->getModule();56auto Fn = M->getOrInsertFunction("__ockl_printf_begin", Int64Ty, Int64Ty);57return Builder.CreateCall(Fn, Version);58}5960static Value *callAppendArgs(IRBuilder<> &Builder, Value *Desc, int NumArgs,61Value *Arg0, Value *Arg1, Value *Arg2, Value *Arg3,62Value *Arg4, Value *Arg5, Value *Arg6,63bool IsLast) {64auto Int64Ty = Builder.getInt64Ty();65auto Int32Ty = Builder.getInt32Ty();66auto M = Builder.GetInsertBlock()->getModule();67auto Fn = M->getOrInsertFunction("__ockl_printf_append_args", Int64Ty,68Int64Ty, Int32Ty, Int64Ty, Int64Ty, Int64Ty,69Int64Ty, Int64Ty, Int64Ty, Int64Ty, Int32Ty);70auto IsLastValue = Builder.getInt32(IsLast);71auto NumArgsValue = Builder.getInt32(NumArgs);72return Builder.CreateCall(Fn, {Desc, NumArgsValue, Arg0, Arg1, Arg2, Arg3,73Arg4, Arg5, Arg6, IsLastValue});74}7576static Value *appendArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,77bool IsLast) {78auto Arg0 = fitArgInto64Bits(Builder, Arg);79auto Zero = Builder.getInt64(0);80return callAppendArgs(Builder, Desc, 1, Arg0, Zero, Zero, Zero, Zero, Zero,81Zero, IsLast);82}8384// The device library does not provide strlen, so we build our own loop85// here. While we are at it, we also include the terminating null in the length.86static Value *getStrlenWithNull(IRBuilder<> &Builder, Value *Str) {87auto *Prev = Builder.GetInsertBlock();88Module *M = Prev->getModule();8990auto CharZero = Builder.getInt8(0);91auto One = Builder.getInt64(1);92auto Zero = Builder.getInt64(0);93auto Int64Ty = Builder.getInt64Ty();9495// The length is either zero for a null pointer, or the computed value for an96// actual string. We need a join block for a phi that represents the final97// value.98//99// Strictly speaking, the zero does not matter since100// __ockl_printf_append_string_n ignores the length if the pointer is null.101BasicBlock *Join = nullptr;102if (Prev->getTerminator()) {103Join = Prev->splitBasicBlock(Builder.GetInsertPoint(),104"strlen.join");105Prev->getTerminator()->eraseFromParent();106} else {107Join = BasicBlock::Create(M->getContext(), "strlen.join",108Prev->getParent());109}110BasicBlock *While =111BasicBlock::Create(M->getContext(), "strlen.while",112Prev->getParent(), Join);113BasicBlock *WhileDone = BasicBlock::Create(114M->getContext(), "strlen.while.done",115Prev->getParent(), Join);116117// Emit an early return for when the pointer is null.118Builder.SetInsertPoint(Prev);119auto CmpNull =120Builder.CreateICmpEQ(Str, Constant::getNullValue(Str->getType()));121BranchInst::Create(Join, While, CmpNull, Prev);122123// Entry to the while loop.124Builder.SetInsertPoint(While);125126auto PtrPhi = Builder.CreatePHI(Str->getType(), 2);127PtrPhi->addIncoming(Str, Prev);128auto PtrNext = Builder.CreateGEP(Builder.getInt8Ty(), PtrPhi, One);129PtrPhi->addIncoming(PtrNext, While);130131// Condition for the while loop.132auto Data = Builder.CreateLoad(Builder.getInt8Ty(), PtrPhi);133auto Cmp = Builder.CreateICmpEQ(Data, CharZero);134Builder.CreateCondBr(Cmp, WhileDone, While);135136// Add one to the computed length.137Builder.SetInsertPoint(WhileDone, WhileDone->begin());138auto Begin = Builder.CreatePtrToInt(Str, Int64Ty);139auto End = Builder.CreatePtrToInt(PtrPhi, Int64Ty);140auto Len = Builder.CreateSub(End, Begin);141Len = Builder.CreateAdd(Len, One);142143// Final join.144BranchInst::Create(Join, WhileDone);145Builder.SetInsertPoint(Join, Join->begin());146auto LenPhi = Builder.CreatePHI(Len->getType(), 2);147LenPhi->addIncoming(Len, WhileDone);148LenPhi->addIncoming(Zero, Prev);149150return LenPhi;151}152153static Value *callAppendStringN(IRBuilder<> &Builder, Value *Desc, Value *Str,154Value *Length, bool isLast) {155auto Int64Ty = Builder.getInt64Ty();156auto IsLastInt32 = Builder.getInt32(isLast);157auto M = Builder.GetInsertBlock()->getModule();158auto Fn = M->getOrInsertFunction("__ockl_printf_append_string_n", Int64Ty,159Desc->getType(), Str->getType(),160Length->getType(), IsLastInt32->getType());161return Builder.CreateCall(Fn, {Desc, Str, Length, IsLastInt32});162}163164static Value *appendString(IRBuilder<> &Builder, Value *Desc, Value *Arg,165bool IsLast) {166auto Length = getStrlenWithNull(Builder, Arg);167return callAppendStringN(Builder, Desc, Arg, Length, IsLast);168}169170static Value *processArg(IRBuilder<> &Builder, Value *Desc, Value *Arg,171bool SpecIsCString, bool IsLast) {172if (SpecIsCString && isa<PointerType>(Arg->getType())) {173return appendString(Builder, Desc, Arg, IsLast);174}175// If the format specifies a string but the argument is not, the frontend will176// have printed a warning. We just rely on undefined behaviour and send the177// argument anyway.178return appendArg(Builder, Desc, Arg, IsLast);179}180181// Scan the format string to locate all specifiers, and mark the ones that182// specify a string, i.e, the "%s" specifier with optional '*' characters.183static void locateCStrings(SparseBitVector<8> &BV, StringRef Str) {184static const char ConvSpecifiers[] = "diouxXfFeEgGaAcspn";185size_t SpecPos = 0;186// Skip the first argument, the format string.187unsigned ArgIdx = 1;188189while ((SpecPos = Str.find_first_of('%', SpecPos)) != StringRef::npos) {190if (Str[SpecPos + 1] == '%') {191SpecPos += 2;192continue;193}194auto SpecEnd = Str.find_first_of(ConvSpecifiers, SpecPos);195if (SpecEnd == StringRef::npos)196return;197auto Spec = Str.slice(SpecPos, SpecEnd + 1);198ArgIdx += Spec.count('*');199if (Str[SpecEnd] == 's') {200BV.set(ArgIdx);201}202SpecPos = SpecEnd + 1;203++ArgIdx;204}205}206207// helper struct to package the string related data208struct StringData {209StringRef Str;210Value *RealSize = nullptr;211Value *AlignedSize = nullptr;212bool IsConst = true;213214StringData(StringRef ST, Value *RS, Value *AS, bool IC)215: Str(ST), RealSize(RS), AlignedSize(AS), IsConst(IC) {}216};217218// Calculates frame size required for current printf expansion and allocates219// space on printf buffer. Printf frame includes following contents220// [ ControlDWord , format string/Hash , Arguments (each aligned to 8 byte) ]221static Value *callBufferedPrintfStart(222IRBuilder<> &Builder, ArrayRef<Value *> Args, Value *Fmt,223bool isConstFmtStr, SparseBitVector<8> &SpecIsCString,224SmallVectorImpl<StringData> &StringContents, Value *&ArgSize) {225Module *M = Builder.GetInsertBlock()->getModule();226Value *NonConstStrLen = nullptr;227Value *LenWithNull = nullptr;228Value *LenWithNullAligned = nullptr;229Value *TempAdd = nullptr;230231// First 4 bytes to be reserved for control dword232size_t BufSize = 4;233if (isConstFmtStr)234// First 8 bytes of MD5 hash235BufSize += 8;236else {237LenWithNull = getStrlenWithNull(Builder, Fmt);238239// Align the computed length to next 8 byte boundary240TempAdd = Builder.CreateAdd(LenWithNull,241ConstantInt::get(LenWithNull->getType(), 7U));242NonConstStrLen = Builder.CreateAnd(243TempAdd, ConstantInt::get(LenWithNull->getType(), ~7U));244245StringContents.push_back(246StringData(StringRef(), LenWithNull, NonConstStrLen, false));247}248249for (size_t i = 1; i < Args.size(); i++) {250if (SpecIsCString.test(i)) {251StringRef ArgStr;252if (getConstantStringInfo(Args[i], ArgStr)) {253auto alignedLen = alignTo(ArgStr.size() + 1, 8);254StringContents.push_back(StringData(255ArgStr,256/*RealSize*/ nullptr, /*AlignedSize*/ nullptr, /*IsConst*/ true));257BufSize += alignedLen;258} else {259LenWithNull = getStrlenWithNull(Builder, Args[i]);260261// Align the computed length to next 8 byte boundary262TempAdd = Builder.CreateAdd(263LenWithNull, ConstantInt::get(LenWithNull->getType(), 7U));264LenWithNullAligned = Builder.CreateAnd(265TempAdd, ConstantInt::get(LenWithNull->getType(), ~7U));266267if (NonConstStrLen) {268auto Val = Builder.CreateAdd(LenWithNullAligned, NonConstStrLen,269"cumulativeAdd");270NonConstStrLen = Val;271} else272NonConstStrLen = LenWithNullAligned;273274StringContents.push_back(275StringData(StringRef(), LenWithNull, LenWithNullAligned, false));276}277} else {278int AllocSize = M->getDataLayout().getTypeAllocSize(Args[i]->getType());279// We end up expanding non string arguments to 8 bytes280// (args smaller than 8 bytes)281BufSize += std::max(AllocSize, 8);282}283}284285// calculate final size value to be passed to printf_alloc286Value *SizeToReserve = ConstantInt::get(Builder.getInt64Ty(), BufSize, false);287SmallVector<Value *, 1> Alloc_args;288if (NonConstStrLen)289SizeToReserve = Builder.CreateAdd(NonConstStrLen, SizeToReserve);290291ArgSize = Builder.CreateTrunc(SizeToReserve, Builder.getInt32Ty());292Alloc_args.push_back(ArgSize);293294// call the printf_alloc function295AttributeList Attr = AttributeList::get(296Builder.getContext(), AttributeList::FunctionIndex, Attribute::NoUnwind);297298Type *Tys_alloc[1] = {Builder.getInt32Ty()};299Type *PtrTy =300Builder.getPtrTy(M->getDataLayout().getDefaultGlobalsAddressSpace());301FunctionType *FTy_alloc = FunctionType::get(PtrTy, Tys_alloc, false);302auto PrintfAllocFn =303M->getOrInsertFunction(StringRef("__printf_alloc"), FTy_alloc, Attr);304305return Builder.CreateCall(PrintfAllocFn, Alloc_args, "printf_alloc_fn");306}307308// Prepare constant string argument to push onto the buffer309static void processConstantStringArg(StringData *SD, IRBuilder<> &Builder,310SmallVectorImpl<Value *> &WhatToStore) {311std::string Str(SD->Str.str() + '\0');312313DataExtractor Extractor(Str, /*IsLittleEndian=*/true, 8);314DataExtractor::Cursor Offset(0);315while (Offset && Offset.tell() < Str.size()) {316const uint64_t ReadSize = 4;317uint64_t ReadNow = std::min(ReadSize, Str.size() - Offset.tell());318uint64_t ReadBytes = 0;319switch (ReadNow) {320default:321llvm_unreachable("min(4, X) > 4?");322case 1:323ReadBytes = Extractor.getU8(Offset);324break;325case 2:326ReadBytes = Extractor.getU16(Offset);327break;328case 3:329ReadBytes = Extractor.getU24(Offset);330break;331case 4:332ReadBytes = Extractor.getU32(Offset);333break;334}335cantFail(Offset.takeError(), "failed to read bytes from constant array");336337APInt IntVal(8 * ReadSize, ReadBytes);338339// TODO: Should not bother aligning up.340if (ReadNow < ReadSize)341IntVal = IntVal.zext(8 * ReadSize);342343Type *IntTy = Type::getIntNTy(Builder.getContext(), IntVal.getBitWidth());344WhatToStore.push_back(ConstantInt::get(IntTy, IntVal));345}346// Additional padding for 8 byte alignment347int Rem = (Str.size() % 8);348if (Rem > 0 && Rem <= 4)349WhatToStore.push_back(ConstantInt::get(Builder.getInt32Ty(), 0));350}351352static Value *processNonStringArg(Value *Arg, IRBuilder<> &Builder) {353const DataLayout &DL = Builder.GetInsertBlock()->getDataLayout();354auto Ty = Arg->getType();355356if (auto IntTy = dyn_cast<IntegerType>(Ty)) {357if (IntTy->getBitWidth() < 64) {358return Builder.CreateZExt(Arg, Builder.getInt64Ty());359}360}361362if (Ty->isFloatingPointTy()) {363if (DL.getTypeAllocSize(Ty) < 8) {364return Builder.CreateFPExt(Arg, Builder.getDoubleTy());365}366}367368return Arg;369}370371static void372callBufferedPrintfArgPush(IRBuilder<> &Builder, ArrayRef<Value *> Args,373Value *PtrToStore, SparseBitVector<8> &SpecIsCString,374SmallVectorImpl<StringData> &StringContents,375bool IsConstFmtStr) {376Module *M = Builder.GetInsertBlock()->getModule();377const DataLayout &DL = M->getDataLayout();378auto StrIt = StringContents.begin();379size_t i = IsConstFmtStr ? 1 : 0;380for (; i < Args.size(); i++) {381SmallVector<Value *, 32> WhatToStore;382if ((i == 0) || SpecIsCString.test(i)) {383if (StrIt->IsConst) {384processConstantStringArg(StrIt, Builder, WhatToStore);385StrIt++;386} else {387// This copies the contents of the string, however the next offset388// is at aligned length, the extra space that might be created due389// to alignment padding is not populated with any specific value390// here. This would be safe as long as runtime is sync with391// the offsets.392Builder.CreateMemCpy(PtrToStore, /*DstAlign*/ Align(1), Args[i],393/*SrcAlign*/ Args[i]->getPointerAlignment(DL),394StrIt->RealSize);395396PtrToStore =397Builder.CreateInBoundsGEP(Builder.getInt8Ty(), PtrToStore,398{StrIt->AlignedSize}, "PrintBuffNextPtr");399LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:"400<< *PtrToStore << '\n');401402// done with current argument, move to next403StrIt++;404continue;405}406} else {407WhatToStore.push_back(processNonStringArg(Args[i], Builder));408}409410for (Value *toStore : WhatToStore) {411StoreInst *StBuff = Builder.CreateStore(toStore, PtrToStore);412LLVM_DEBUG(dbgs() << "inserting store to printf buffer:" << *StBuff413<< '\n');414(void)StBuff;415PtrToStore = Builder.CreateConstInBoundsGEP1_32(416Builder.getInt8Ty(), PtrToStore,417M->getDataLayout().getTypeAllocSize(toStore->getType()),418"PrintBuffNextPtr");419LLVM_DEBUG(dbgs() << "inserting gep to the printf buffer:" << *PtrToStore420<< '\n');421}422}423}424425Value *llvm::emitAMDGPUPrintfCall(IRBuilder<> &Builder, ArrayRef<Value *> Args,426bool IsBuffered) {427auto NumOps = Args.size();428assert(NumOps >= 1);429430auto Fmt = Args[0];431SparseBitVector<8> SpecIsCString;432StringRef FmtStr;433434if (getConstantStringInfo(Fmt, FmtStr))435locateCStrings(SpecIsCString, FmtStr);436437if (IsBuffered) {438SmallVector<StringData, 8> StringContents;439Module *M = Builder.GetInsertBlock()->getModule();440LLVMContext &Ctx = Builder.getContext();441auto Int8Ty = Builder.getInt8Ty();442auto Int32Ty = Builder.getInt32Ty();443bool IsConstFmtStr = !FmtStr.empty();444445Value *ArgSize = nullptr;446Value *Ptr =447callBufferedPrintfStart(Builder, Args, Fmt, IsConstFmtStr,448SpecIsCString, StringContents, ArgSize);449450// The buffered version still follows OpenCL printf standards for451// printf return value, i.e 0 on success, -1 on failure.452ConstantPointerNull *zeroIntPtr =453ConstantPointerNull::get(cast<PointerType>(Ptr->getType()));454455auto *Cmp = cast<ICmpInst>(Builder.CreateICmpNE(Ptr, zeroIntPtr, ""));456457BasicBlock *End = BasicBlock::Create(Ctx, "end.block",458Builder.GetInsertBlock()->getParent());459BasicBlock *ArgPush = BasicBlock::Create(460Ctx, "argpush.block", Builder.GetInsertBlock()->getParent());461462BranchInst::Create(ArgPush, End, Cmp, Builder.GetInsertBlock());463Builder.SetInsertPoint(ArgPush);464465// Create controlDWord and store as the first entry, format as follows466// Bit 0 (LSB) -> stream (1 if stderr, 0 if stdout, printf always outputs to467// stdout) Bit 1 -> constant format string (1 if constant) Bits 2-31 -> size468// of printf data frame469auto ConstantTwo = Builder.getInt32(2);470auto ControlDWord = Builder.CreateShl(ArgSize, ConstantTwo);471if (IsConstFmtStr)472ControlDWord = Builder.CreateOr(ControlDWord, ConstantTwo);473474Builder.CreateStore(ControlDWord, Ptr);475476Ptr = Builder.CreateConstInBoundsGEP1_32(Int8Ty, Ptr, 4);477478// Create MD5 hash for costant format string, push low 64 bits of the479// same onto buffer and metadata.480NamedMDNode *metaD = M->getOrInsertNamedMetadata("llvm.printf.fmts");481if (IsConstFmtStr) {482MD5 Hasher;483MD5::MD5Result Hash;484Hasher.update(FmtStr);485Hasher.final(Hash);486487// Try sticking to llvm.printf.fmts format, although we are not going to488// use the ID and argument size fields while printing,489std::string MetadataStr =490"0:0:" + llvm::utohexstr(Hash.low(), /*LowerCase=*/true) + "," +491FmtStr.str();492MDString *fmtStrArray = MDString::get(Ctx, MetadataStr);493MDNode *myMD = MDNode::get(Ctx, fmtStrArray);494metaD->addOperand(myMD);495496Builder.CreateStore(Builder.getInt64(Hash.low()), Ptr);497Ptr = Builder.CreateConstInBoundsGEP1_32(Int8Ty, Ptr, 8);498} else {499// Include a dummy metadata instance in case of only non constant500// format string usage, This might be an absurd usecase but needs to501// be done for completeness502if (metaD->getNumOperands() == 0) {503MDString *fmtStrArray =504MDString::get(Ctx, "0:0:ffffffff,\"Non const format string\"");505MDNode *myMD = MDNode::get(Ctx, fmtStrArray);506metaD->addOperand(myMD);507}508}509510// Push The printf arguments onto buffer511callBufferedPrintfArgPush(Builder, Args, Ptr, SpecIsCString, StringContents,512IsConstFmtStr);513514// End block, returns -1 on failure515BranchInst::Create(End, ArgPush);516Builder.SetInsertPoint(End);517return Builder.CreateSExt(Builder.CreateNot(Cmp), Int32Ty, "printf_result");518}519520auto Desc = callPrintfBegin(Builder, Builder.getIntN(64, 0));521Desc = appendString(Builder, Desc, Fmt, NumOps == 1);522523// FIXME: This invokes hostcall once for each argument. We can pack up to524// seven scalar printf arguments in a single hostcall. See the signature of525// callAppendArgs().526for (unsigned int i = 1; i != NumOps; ++i) {527bool IsLast = i == NumOps - 1;528bool IsCString = SpecIsCString.test(i);529Desc = processArg(Builder, Desc, Args[i], IsCString, IsLast);530}531532return Builder.CreateTrunc(Desc, Builder.getInt32Ty());533}534535536