Path: blob/main/contrib/llvm-project/clang/lib/CIR/CodeGen/CIRGenStmtOpenACC.cpp
213799 views
//===----------------------------------------------------------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// Emit OpenACC Stmt nodes as CIR code.9//10//===----------------------------------------------------------------------===//1112#include "CIRGenBuilder.h"13#include "CIRGenFunction.h"14#include "mlir/Dialect/OpenACC/OpenACC.h"15#include "clang/AST/OpenACCClause.h"16#include "clang/AST/StmtOpenACC.h"1718using namespace clang;19using namespace clang::CIRGen;20using namespace cir;21using namespace mlir::acc;2223template <typename Op, typename TermOp>24mlir::LogicalResult CIRGenFunction::emitOpenACCOpAssociatedStmt(25mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,26SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,27const Stmt *associatedStmt) {28mlir::LogicalResult res = mlir::success();2930llvm::SmallVector<mlir::Type> retTy;31llvm::SmallVector<mlir::Value> operands;32auto op = builder.create<Op>(start, retTy, operands);3334emitOpenACCClauses(op, dirKind, dirLoc, clauses);3536{37mlir::Block &block = op.getRegion().emplaceBlock();38mlir::OpBuilder::InsertionGuard guardCase(builder);39builder.setInsertionPointToEnd(&block);4041LexicalScope ls{*this, start, builder.getInsertionBlock()};42res = emitStmt(associatedStmt, /*useCurrentScope=*/true);4344builder.create<TermOp>(end);45}46return res;47}4849namespace {50template <typename Op> struct CombinedType;51template <> struct CombinedType<ParallelOp> {52static constexpr mlir::acc::CombinedConstructsType value =53mlir::acc::CombinedConstructsType::ParallelLoop;54};55template <> struct CombinedType<SerialOp> {56static constexpr mlir::acc::CombinedConstructsType value =57mlir::acc::CombinedConstructsType::SerialLoop;58};59template <> struct CombinedType<KernelsOp> {60static constexpr mlir::acc::CombinedConstructsType value =61mlir::acc::CombinedConstructsType::KernelsLoop;62};63} // namespace6465template <typename Op, typename TermOp>66mlir::LogicalResult CIRGenFunction::emitOpenACCOpCombinedConstruct(67mlir::Location start, mlir::Location end, OpenACCDirectiveKind dirKind,68SourceLocation dirLoc, llvm::ArrayRef<const OpenACCClause *> clauses,69const Stmt *loopStmt) {70mlir::LogicalResult res = mlir::success();7172llvm::SmallVector<mlir::Type> retTy;73llvm::SmallVector<mlir::Value> operands;7475auto computeOp = builder.create<Op>(start, retTy, operands);76computeOp.setCombinedAttr(builder.getUnitAttr());77mlir::acc::LoopOp loopOp;7879// First, emit the bodies of both operations, with the loop inside the body of80// the combined construct.81{82mlir::Block &block = computeOp.getRegion().emplaceBlock();83mlir::OpBuilder::InsertionGuard guardCase(builder);84builder.setInsertionPointToEnd(&block);8586LexicalScope ls{*this, start, builder.getInsertionBlock()};87auto loopOp = builder.create<LoopOp>(start, retTy, operands);88loopOp.setCombinedAttr(mlir::acc::CombinedConstructsTypeAttr::get(89builder.getContext(), CombinedType<Op>::value));9091{92mlir::Block &innerBlock = loopOp.getRegion().emplaceBlock();93mlir::OpBuilder::InsertionGuard guardCase(builder);94builder.setInsertionPointToEnd(&innerBlock);9596LexicalScope ls{*this, start, builder.getInsertionBlock()};97ActiveOpenACCLoopRAII activeLoop{*this, &loopOp};9899res = emitStmt(loopStmt, /*useCurrentScope=*/true);100101builder.create<mlir::acc::YieldOp>(end);102}103104emitOpenACCClauses(computeOp, loopOp, dirKind, dirLoc, clauses);105106updateLoopOpParallelism(loopOp, /*isOrphan=*/false, dirKind);107108builder.create<TermOp>(end);109}110111return res;112}113114template <typename Op>115Op CIRGenFunction::emitOpenACCOp(116mlir::Location start, OpenACCDirectiveKind dirKind, SourceLocation dirLoc,117llvm::ArrayRef<const OpenACCClause *> clauses) {118llvm::SmallVector<mlir::Type> retTy;119llvm::SmallVector<mlir::Value> operands;120auto op = builder.create<Op>(start, retTy, operands);121122emitOpenACCClauses(op, dirKind, dirLoc, clauses);123return op;124}125126mlir::LogicalResult127CIRGenFunction::emitOpenACCComputeConstruct(const OpenACCComputeConstruct &s) {128mlir::Location start = getLoc(s.getSourceRange().getBegin());129mlir::Location end = getLoc(s.getSourceRange().getEnd());130131switch (s.getDirectiveKind()) {132case OpenACCDirectiveKind::Parallel:133return emitOpenACCOpAssociatedStmt<ParallelOp, mlir::acc::YieldOp>(134start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),135s.getStructuredBlock());136case OpenACCDirectiveKind::Serial:137return emitOpenACCOpAssociatedStmt<SerialOp, mlir::acc::YieldOp>(138start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),139s.getStructuredBlock());140case OpenACCDirectiveKind::Kernels:141return emitOpenACCOpAssociatedStmt<KernelsOp, mlir::acc::TerminatorOp>(142start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),143s.getStructuredBlock());144default:145llvm_unreachable("invalid compute construct kind");146}147}148149mlir::LogicalResult150CIRGenFunction::emitOpenACCDataConstruct(const OpenACCDataConstruct &s) {151mlir::Location start = getLoc(s.getSourceRange().getBegin());152mlir::Location end = getLoc(s.getSourceRange().getEnd());153154return emitOpenACCOpAssociatedStmt<DataOp, mlir::acc::TerminatorOp>(155start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),156s.getStructuredBlock());157}158159mlir::LogicalResult160CIRGenFunction::emitOpenACCInitConstruct(const OpenACCInitConstruct &s) {161mlir::Location start = getLoc(s.getSourceRange().getBegin());162emitOpenACCOp<InitOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),163s.clauses());164return mlir::success();165}166167mlir::LogicalResult168CIRGenFunction::emitOpenACCSetConstruct(const OpenACCSetConstruct &s) {169mlir::Location start = getLoc(s.getSourceRange().getBegin());170emitOpenACCOp<SetOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),171s.clauses());172return mlir::success();173}174175mlir::LogicalResult CIRGenFunction::emitOpenACCShutdownConstruct(176const OpenACCShutdownConstruct &s) {177mlir::Location start = getLoc(s.getSourceRange().getBegin());178emitOpenACCOp<ShutdownOp>(start, s.getDirectiveKind(),179s.getDirectiveLoc(), s.clauses());180return mlir::success();181}182183mlir::LogicalResult184CIRGenFunction::emitOpenACCWaitConstruct(const OpenACCWaitConstruct &s) {185mlir::Location start = getLoc(s.getSourceRange().getBegin());186auto waitOp = emitOpenACCOp<WaitOp>(start, s.getDirectiveKind(),187s.getDirectiveLoc(), s.clauses());188189auto createIntExpr = [this](const Expr *intExpr) {190mlir::Value expr = emitScalarExpr(intExpr);191mlir::Location exprLoc = cgm.getLoc(intExpr->getBeginLoc());192193mlir::IntegerType targetType = mlir::IntegerType::get(194&getMLIRContext(), getContext().getIntWidth(intExpr->getType()),195intExpr->getType()->isSignedIntegerOrEnumerationType()196? mlir::IntegerType::SignednessSemantics::Signed197: mlir::IntegerType::SignednessSemantics::Unsigned);198199auto conversionOp = builder.create<mlir::UnrealizedConversionCastOp>(200exprLoc, targetType, expr);201return conversionOp.getResult(0);202};203204// Emit the correct 'wait' clauses.205{206mlir::OpBuilder::InsertionGuard guardCase(builder);207builder.setInsertionPoint(waitOp);208209if (s.hasDevNumExpr())210waitOp.getWaitDevnumMutable().append(createIntExpr(s.getDevNumExpr()));211212for (Expr *QueueExpr : s.getQueueIdExprs())213waitOp.getWaitOperandsMutable().append(createIntExpr(QueueExpr));214}215216return mlir::success();217}218219mlir::LogicalResult CIRGenFunction::emitOpenACCCombinedConstruct(220const OpenACCCombinedConstruct &s) {221mlir::Location start = getLoc(s.getSourceRange().getBegin());222mlir::Location end = getLoc(s.getSourceRange().getEnd());223224switch (s.getDirectiveKind()) {225case OpenACCDirectiveKind::ParallelLoop:226return emitOpenACCOpCombinedConstruct<ParallelOp, mlir::acc::YieldOp>(227start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),228s.getLoop());229case OpenACCDirectiveKind::SerialLoop:230return emitOpenACCOpCombinedConstruct<SerialOp, mlir::acc::YieldOp>(231start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),232s.getLoop());233case OpenACCDirectiveKind::KernelsLoop:234return emitOpenACCOpCombinedConstruct<KernelsOp, mlir::acc::TerminatorOp>(235start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),236s.getLoop());237default:238llvm_unreachable("invalid compute construct kind");239}240}241242mlir::LogicalResult CIRGenFunction::emitOpenACCHostDataConstruct(243const OpenACCHostDataConstruct &s) {244mlir::Location start = getLoc(s.getSourceRange().getBegin());245mlir::Location end = getLoc(s.getSourceRange().getEnd());246247return emitOpenACCOpAssociatedStmt<HostDataOp, mlir::acc::TerminatorOp>(248start, end, s.getDirectiveKind(), s.getDirectiveLoc(), s.clauses(),249s.getStructuredBlock());250}251252mlir::LogicalResult CIRGenFunction::emitOpenACCEnterDataConstruct(253const OpenACCEnterDataConstruct &s) {254mlir::Location start = getLoc(s.getSourceRange().getBegin());255emitOpenACCOp<EnterDataOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),256s.clauses());257return mlir::success();258}259260mlir::LogicalResult CIRGenFunction::emitOpenACCExitDataConstruct(261const OpenACCExitDataConstruct &s) {262mlir::Location start = getLoc(s.getSourceRange().getBegin());263emitOpenACCOp<ExitDataOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),264s.clauses());265return mlir::success();266}267268mlir::LogicalResult269CIRGenFunction::emitOpenACCUpdateConstruct(const OpenACCUpdateConstruct &s) {270mlir::Location start = getLoc(s.getSourceRange().getBegin());271emitOpenACCOp<UpdateOp>(start, s.getDirectiveKind(), s.getDirectiveLoc(),272s.clauses());273return mlir::success();274}275276mlir::LogicalResult277CIRGenFunction::emitOpenACCCacheConstruct(const OpenACCCacheConstruct &s) {278// The 'cache' directive 'may' be at the top of a loop by standard, but279// doesn't have to be. Additionally, there is nothing that requires this be a280// loop affected by an OpenACC pragma. Sema doesn't do any level of281// enforcement here, since it isn't particularly valuable to do so thanks to282// that. Instead, we treat cache as a 'noop' if there is no acc.loop to apply283// it to.284if (!activeLoopOp)285return mlir::success();286287mlir::acc::LoopOp loopOp = *activeLoopOp;288289mlir::OpBuilder::InsertionGuard guard(builder);290builder.setInsertionPoint(loopOp);291292for (const Expr *var : s.getVarList()) {293CIRGenFunction::OpenACCDataOperandInfo opInfo =294getOpenACCDataOperandInfo(var);295296auto cacheOp = builder.create<CacheOp>(297opInfo.beginLoc, opInfo.varValue,298/*structured=*/false, /*implicit=*/false, opInfo.name, opInfo.bounds);299300loopOp.getCacheOperandsMutable().append(cacheOp.getResult());301}302303return mlir::success();304}305306mlir::LogicalResult307CIRGenFunction::emitOpenACCAtomicConstruct(const OpenACCAtomicConstruct &s) {308cgm.errorNYI(s.getSourceRange(), "OpenACC Atomic Construct");309return mlir::failure();310}311312313