Path: blob/main/contrib/llvm-project/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp
213845 views
//===- CIRAttrs.cpp - MLIR CIR Attributes ---------------------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This file defines the attributes in the CIR dialect.9//10//===----------------------------------------------------------------------===//1112#include "clang/CIR/Dialect/IR/CIRDialect.h"1314#include "mlir/IR/DialectImplementation.h"15#include "llvm/ADT/TypeSwitch.h"1617//===-----------------------------------------------------------------===//18// IntLiteral19//===-----------------------------------------------------------------===//2021static void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,22cir::IntTypeInterface ty);23static mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser,24llvm::APInt &value,25cir::IntTypeInterface ty);26//===-----------------------------------------------------------------===//27// FloatLiteral28//===-----------------------------------------------------------------===//2930static void printFloatLiteral(mlir::AsmPrinter &p, llvm::APFloat value,31mlir::Type ty);32static mlir::ParseResult33parseFloatLiteral(mlir::AsmParser &parser,34mlir::FailureOr<llvm::APFloat> &value,35cir::FPTypeInterface fpType);3637static mlir::ParseResult parseConstPtr(mlir::AsmParser &parser,38mlir::IntegerAttr &value);3940static void printConstPtr(mlir::AsmPrinter &p, mlir::IntegerAttr value);4142#define GET_ATTRDEF_CLASSES43#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"4445using namespace mlir;46using namespace cir;4748//===----------------------------------------------------------------------===//49// General CIR parsing / printing50//===----------------------------------------------------------------------===//5152Attribute CIRDialect::parseAttribute(DialectAsmParser &parser,53Type type) const {54llvm::SMLoc typeLoc = parser.getCurrentLocation();55llvm::StringRef mnemonic;56Attribute genAttr;57OptionalParseResult parseResult =58generatedAttributeParser(parser, &mnemonic, type, genAttr);59if (parseResult.has_value())60return genAttr;61parser.emitError(typeLoc, "unknown attribute in CIR dialect");62return Attribute();63}6465void CIRDialect::printAttribute(Attribute attr, DialectAsmPrinter &os) const {66if (failed(generatedAttributePrinter(attr, os)))67llvm_unreachable("unexpected CIR type kind");68}6970//===----------------------------------------------------------------------===//71// OptInfoAttr definitions72//===----------------------------------------------------------------------===//7374LogicalResult OptInfoAttr::verify(function_ref<InFlightDiagnostic()> emitError,75unsigned level, unsigned size) {76if (level > 3)77return emitError()78<< "optimization level must be between 0 and 3 inclusive";79if (size > 2)80return emitError()81<< "size optimization level must be between 0 and 2 inclusive";82return success();83}8485//===----------------------------------------------------------------------===//86// ConstPtrAttr definitions87//===----------------------------------------------------------------------===//8889// TODO(CIR): Consider encoding the null value differently and use conditional90// assembly format instead of custom parsing/printing.91static ParseResult parseConstPtr(AsmParser &parser, mlir::IntegerAttr &value) {9293if (parser.parseOptionalKeyword("null").succeeded()) {94value = parser.getBuilder().getI64IntegerAttr(0);95return success();96}9798return parser.parseAttribute(value);99}100101static void printConstPtr(AsmPrinter &p, mlir::IntegerAttr value) {102if (!value.getInt())103p << "null";104else105p << value;106}107108//===----------------------------------------------------------------------===//109// IntAttr definitions110//===----------------------------------------------------------------------===//111112template <typename IntT>113static bool isTooLargeForType(const mlir::APInt &value, IntT expectedValue) {114if constexpr (std::is_signed_v<IntT>) {115return value.getSExtValue() != expectedValue;116} else {117return value.getZExtValue() != expectedValue;118}119}120121template <typename IntT>122static mlir::ParseResult parseIntLiteralImpl(mlir::AsmParser &p,123llvm::APInt &value,124cir::IntTypeInterface ty) {125IntT ivalue;126const bool isSigned = ty.isSigned();127if (p.parseInteger(ivalue))128return p.emitError(p.getCurrentLocation(), "expected integer value");129130value = mlir::APInt(ty.getWidth(), ivalue, isSigned, /*implicitTrunc=*/true);131if (isTooLargeForType(value, ivalue))132return p.emitError(p.getCurrentLocation(),133"integer value too large for the given type");134135return success();136}137138mlir::ParseResult parseIntLiteral(mlir::AsmParser &parser, llvm::APInt &value,139cir::IntTypeInterface ty) {140if (ty.isSigned())141return parseIntLiteralImpl<int64_t>(parser, value, ty);142return parseIntLiteralImpl<uint64_t>(parser, value, ty);143}144145void printIntLiteral(mlir::AsmPrinter &p, llvm::APInt value,146cir::IntTypeInterface ty) {147if (ty.isSigned())148p << value.getSExtValue();149else150p << value.getZExtValue();151}152153LogicalResult IntAttr::verify(function_ref<InFlightDiagnostic()> emitError,154cir::IntTypeInterface type, llvm::APInt value) {155if (value.getBitWidth() != type.getWidth())156return emitError() << "type and value bitwidth mismatch: "157<< type.getWidth() << " != " << value.getBitWidth();158return success();159}160161//===----------------------------------------------------------------------===//162// FPAttr definitions163//===----------------------------------------------------------------------===//164165static void printFloatLiteral(AsmPrinter &p, APFloat value, Type ty) {166p << value;167}168169static ParseResult parseFloatLiteral(AsmParser &parser,170FailureOr<APFloat> &value,171cir::FPTypeInterface fpType) {172173APFloat parsedValue(0.0);174if (parser.parseFloat(fpType.getFloatSemantics(), parsedValue))175return failure();176177value.emplace(parsedValue);178return success();179}180181FPAttr FPAttr::getZero(Type type) {182return get(type,183APFloat::getZero(184mlir::cast<cir::FPTypeInterface>(type).getFloatSemantics()));185}186187LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError,188cir::FPTypeInterface fpType, APFloat value) {189if (APFloat::SemanticsToEnum(fpType.getFloatSemantics()) !=190APFloat::SemanticsToEnum(value.getSemantics()))191return emitError() << "floating-point semantics mismatch";192193return success();194}195196//===----------------------------------------------------------------------===//197// ConstComplexAttr definitions198//===----------------------------------------------------------------------===//199200LogicalResult201ConstComplexAttr::verify(function_ref<InFlightDiagnostic()> emitError,202cir::ComplexType type, mlir::TypedAttr real,203mlir::TypedAttr imag) {204mlir::Type elemType = type.getElementType();205if (real.getType() != elemType)206return emitError()207<< "type of the real part does not match the complex type";208209if (imag.getType() != elemType)210return emitError()211<< "type of the imaginary part does not match the complex type";212213return success();214}215216//===----------------------------------------------------------------------===//217// CIR ConstArrayAttr218//===----------------------------------------------------------------------===//219220LogicalResult221ConstArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError, Type type,222Attribute elts, int trailingZerosNum) {223224if (!(mlir::isa<ArrayAttr, StringAttr>(elts)))225return emitError() << "constant array expects ArrayAttr or StringAttr";226227if (auto strAttr = mlir::dyn_cast<StringAttr>(elts)) {228const auto arrayTy = mlir::cast<ArrayType>(type);229const auto intTy = mlir::dyn_cast<IntType>(arrayTy.getElementType());230231// TODO: add CIR type for char.232if (!intTy || intTy.getWidth() != 8)233return emitError()234<< "constant array element for string literals expects "235"!cir.int<u, 8> element type";236return success();237}238239assert(mlir::isa<ArrayAttr>(elts));240const auto arrayAttr = mlir::cast<mlir::ArrayAttr>(elts);241const auto arrayTy = mlir::cast<ArrayType>(type);242243// Make sure both number of elements and subelement types match type.244if (arrayTy.getSize() != arrayAttr.size() + trailingZerosNum)245return emitError() << "constant array size should match type size";246return success();247}248249Attribute ConstArrayAttr::parse(AsmParser &parser, Type type) {250mlir::FailureOr<Type> resultTy;251mlir::FailureOr<Attribute> resultVal;252253// Parse literal '<'254if (parser.parseLess())255return {};256257// Parse variable 'value'258resultVal = FieldParser<Attribute>::parse(parser);259if (failed(resultVal)) {260parser.emitError(261parser.getCurrentLocation(),262"failed to parse ConstArrayAttr parameter 'value' which is "263"to be a `Attribute`");264return {};265}266267// ArrayAttrrs have per-element type, not the type of the array...268if (mlir::isa<ArrayAttr>(*resultVal)) {269// Array has implicit type: infer from const array type.270if (parser.parseOptionalColon().failed()) {271resultTy = type;272} else { // Array has explicit type: parse it.273resultTy = FieldParser<Type>::parse(parser);274if (failed(resultTy)) {275parser.emitError(276parser.getCurrentLocation(),277"failed to parse ConstArrayAttr parameter 'type' which is "278"to be a `::mlir::Type`");279return {};280}281}282} else {283auto ta = mlir::cast<TypedAttr>(*resultVal);284resultTy = ta.getType();285if (mlir::isa<mlir::NoneType>(*resultTy)) {286parser.emitError(parser.getCurrentLocation(),287"expected type declaration for string literal");288return {};289}290}291292unsigned zeros = 0;293if (parser.parseOptionalComma().succeeded()) {294if (parser.parseOptionalKeyword("trailing_zeros").succeeded()) {295unsigned typeSize =296mlir::cast<cir::ArrayType>(resultTy.value()).getSize();297mlir::Attribute elts = resultVal.value();298if (auto str = mlir::dyn_cast<mlir::StringAttr>(elts))299zeros = typeSize - str.size();300else301zeros = typeSize - mlir::cast<mlir::ArrayAttr>(elts).size();302} else {303return {};304}305}306307// Parse literal '>'308if (parser.parseGreater())309return {};310311return parser.getChecked<ConstArrayAttr>(312parser.getCurrentLocation(), parser.getContext(), resultTy.value(),313resultVal.value(), zeros);314}315316void ConstArrayAttr::print(AsmPrinter &printer) const {317printer << "<";318printer.printStrippedAttrOrType(getElts());319if (getTrailingZerosNum())320printer << ", trailing_zeros";321printer << ">";322}323324//===----------------------------------------------------------------------===//325// CIR ConstVectorAttr326//===----------------------------------------------------------------------===//327328LogicalResult329cir::ConstVectorAttr::verify(function_ref<InFlightDiagnostic()> emitError,330Type type, ArrayAttr elts) {331332if (!mlir::isa<cir::VectorType>(type))333return emitError() << "type of cir::ConstVectorAttr is not a "334"cir::VectorType: "335<< type;336337const auto vecType = mlir::cast<cir::VectorType>(type);338339if (vecType.getSize() != elts.size())340return emitError()341<< "number of constant elements should match vector size";342343// Check if the types of the elements match344LogicalResult elementTypeCheck = success();345elts.walkImmediateSubElements(346[&](Attribute element) {347if (elementTypeCheck.failed()) {348// An earlier element didn't match349return;350}351auto typedElement = mlir::dyn_cast<TypedAttr>(element);352if (!typedElement ||353typedElement.getType() != vecType.getElementType()) {354elementTypeCheck = failure();355emitError() << "constant type should match vector element type";356}357},358[&](Type) {});359360return elementTypeCheck;361}362363//===----------------------------------------------------------------------===//364// CIR Dialect365//===----------------------------------------------------------------------===//366367void CIRDialect::registerAttributes() {368addAttributes<369#define GET_ATTRDEF_LIST370#include "clang/CIR/Dialect/IR/CIROpsAttributes.cpp.inc"371>();372}373374375