Path: blob/main/contrib/llvm-project/clang/lib/Analysis/Consumed.cpp
35234 views
//===- Consumed.cpp -------------------------------------------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// A intra-procedural analysis for checking consumed properties. This is based,9// in part, on research on linear types.10//11//===----------------------------------------------------------------------===//1213#include "clang/Analysis/Analyses/Consumed.h"14#include "clang/AST/Attr.h"15#include "clang/AST/Decl.h"16#include "clang/AST/DeclCXX.h"17#include "clang/AST/Expr.h"18#include "clang/AST/ExprCXX.h"19#include "clang/AST/Stmt.h"20#include "clang/AST/StmtVisitor.h"21#include "clang/AST/Type.h"22#include "clang/Analysis/Analyses/PostOrderCFGView.h"23#include "clang/Analysis/AnalysisDeclContext.h"24#include "clang/Analysis/CFG.h"25#include "clang/Basic/LLVM.h"26#include "clang/Basic/OperatorKinds.h"27#include "clang/Basic/SourceLocation.h"28#include "llvm/ADT/DenseMap.h"29#include "llvm/ADT/STLExtras.h"30#include "llvm/ADT/StringRef.h"31#include "llvm/Support/Casting.h"32#include "llvm/Support/ErrorHandling.h"33#include <cassert>34#include <memory>35#include <optional>36#include <utility>3738// TODO: Adjust states of args to constructors in the same way that arguments to39// function calls are handled.40// TODO: Use information from tests in for- and while-loop conditional.41// TODO: Add notes about the actual and expected state for42// TODO: Correctly identify unreachable blocks when chaining boolean operators.43// TODO: Adjust the parser and AttributesList class to support lists of44// identifiers.45// TODO: Warn about unreachable code.46// TODO: Switch to using a bitmap to track unreachable blocks.47// TODO: Handle variable definitions, e.g. bool valid = x.isValid();48// if (valid) ...; (Deferred)49// TODO: Take notes on state transitions to provide better warning messages.50// (Deferred)51// TODO: Test nested conditionals: A) Checking the same value multiple times,52// and 2) Checking different values. (Deferred)5354using namespace clang;55using namespace consumed;5657// Key method definition58ConsumedWarningsHandlerBase::~ConsumedWarningsHandlerBase() = default;5960static SourceLocation getFirstStmtLoc(const CFGBlock *Block) {61// Find the source location of the first statement in the block, if the block62// is not empty.63for (const auto &B : *Block)64if (std::optional<CFGStmt> CS = B.getAs<CFGStmt>())65return CS->getStmt()->getBeginLoc();6667// Block is empty.68// If we have one successor, return the first statement in that block69if (Block->succ_size() == 1 && *Block->succ_begin())70return getFirstStmtLoc(*Block->succ_begin());7172return {};73}7475static SourceLocation getLastStmtLoc(const CFGBlock *Block) {76// Find the source location of the last statement in the block, if the block77// is not empty.78if (const Stmt *StmtNode = Block->getTerminatorStmt()) {79return StmtNode->getBeginLoc();80} else {81for (CFGBlock::const_reverse_iterator BI = Block->rbegin(),82BE = Block->rend(); BI != BE; ++BI) {83if (std::optional<CFGStmt> CS = BI->getAs<CFGStmt>())84return CS->getStmt()->getBeginLoc();85}86}8788// If we have one successor, return the first statement in that block89SourceLocation Loc;90if (Block->succ_size() == 1 && *Block->succ_begin())91Loc = getFirstStmtLoc(*Block->succ_begin());92if (Loc.isValid())93return Loc;9495// If we have one predecessor, return the last statement in that block96if (Block->pred_size() == 1 && *Block->pred_begin())97return getLastStmtLoc(*Block->pred_begin());9899return Loc;100}101102static ConsumedState invertConsumedUnconsumed(ConsumedState State) {103switch (State) {104case CS_Unconsumed:105return CS_Consumed;106case CS_Consumed:107return CS_Unconsumed;108case CS_None:109return CS_None;110case CS_Unknown:111return CS_Unknown;112}113llvm_unreachable("invalid enum");114}115116static bool isCallableInState(const CallableWhenAttr *CWAttr,117ConsumedState State) {118for (const auto &S : CWAttr->callableStates()) {119ConsumedState MappedAttrState = CS_None;120121switch (S) {122case CallableWhenAttr::Unknown:123MappedAttrState = CS_Unknown;124break;125126case CallableWhenAttr::Unconsumed:127MappedAttrState = CS_Unconsumed;128break;129130case CallableWhenAttr::Consumed:131MappedAttrState = CS_Consumed;132break;133}134135if (MappedAttrState == State)136return true;137}138139return false;140}141142static bool isConsumableType(const QualType &QT) {143if (QT->isPointerType() || QT->isReferenceType())144return false;145146if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())147return RD->hasAttr<ConsumableAttr>();148149return false;150}151152static bool isAutoCastType(const QualType &QT) {153if (QT->isPointerType() || QT->isReferenceType())154return false;155156if (const CXXRecordDecl *RD = QT->getAsCXXRecordDecl())157return RD->hasAttr<ConsumableAutoCastAttr>();158159return false;160}161162static bool isSetOnReadPtrType(const QualType &QT) {163if (const CXXRecordDecl *RD = QT->getPointeeCXXRecordDecl())164return RD->hasAttr<ConsumableSetOnReadAttr>();165return false;166}167168static bool isKnownState(ConsumedState State) {169switch (State) {170case CS_Unconsumed:171case CS_Consumed:172return true;173case CS_None:174case CS_Unknown:175return false;176}177llvm_unreachable("invalid enum");178}179180static bool isRValueRef(QualType ParamType) {181return ParamType->isRValueReferenceType();182}183184static bool isTestingFunction(const FunctionDecl *FunDecl) {185return FunDecl->hasAttr<TestTypestateAttr>();186}187188static bool isPointerOrRef(QualType ParamType) {189return ParamType->isPointerType() || ParamType->isReferenceType();190}191192static ConsumedState mapConsumableAttrState(const QualType QT) {193assert(isConsumableType(QT));194195const ConsumableAttr *CAttr =196QT->getAsCXXRecordDecl()->getAttr<ConsumableAttr>();197198switch (CAttr->getDefaultState()) {199case ConsumableAttr::Unknown:200return CS_Unknown;201case ConsumableAttr::Unconsumed:202return CS_Unconsumed;203case ConsumableAttr::Consumed:204return CS_Consumed;205}206llvm_unreachable("invalid enum");207}208209static ConsumedState210mapParamTypestateAttrState(const ParamTypestateAttr *PTAttr) {211switch (PTAttr->getParamState()) {212case ParamTypestateAttr::Unknown:213return CS_Unknown;214case ParamTypestateAttr::Unconsumed:215return CS_Unconsumed;216case ParamTypestateAttr::Consumed:217return CS_Consumed;218}219llvm_unreachable("invalid_enum");220}221222static ConsumedState223mapReturnTypestateAttrState(const ReturnTypestateAttr *RTSAttr) {224switch (RTSAttr->getState()) {225case ReturnTypestateAttr::Unknown:226return CS_Unknown;227case ReturnTypestateAttr::Unconsumed:228return CS_Unconsumed;229case ReturnTypestateAttr::Consumed:230return CS_Consumed;231}232llvm_unreachable("invalid enum");233}234235static ConsumedState mapSetTypestateAttrState(const SetTypestateAttr *STAttr) {236switch (STAttr->getNewState()) {237case SetTypestateAttr::Unknown:238return CS_Unknown;239case SetTypestateAttr::Unconsumed:240return CS_Unconsumed;241case SetTypestateAttr::Consumed:242return CS_Consumed;243}244llvm_unreachable("invalid_enum");245}246247static StringRef stateToString(ConsumedState State) {248switch (State) {249case consumed::CS_None:250return "none";251252case consumed::CS_Unknown:253return "unknown";254255case consumed::CS_Unconsumed:256return "unconsumed";257258case consumed::CS_Consumed:259return "consumed";260}261llvm_unreachable("invalid enum");262}263264static ConsumedState testsFor(const FunctionDecl *FunDecl) {265assert(isTestingFunction(FunDecl));266switch (FunDecl->getAttr<TestTypestateAttr>()->getTestState()) {267case TestTypestateAttr::Unconsumed:268return CS_Unconsumed;269case TestTypestateAttr::Consumed:270return CS_Consumed;271}272llvm_unreachable("invalid enum");273}274275namespace {276277struct VarTestResult {278const VarDecl *Var;279ConsumedState TestsFor;280};281282} // namespace283284namespace clang {285namespace consumed {286287enum EffectiveOp {288EO_And,289EO_Or290};291292class PropagationInfo {293enum {294IT_None,295IT_State,296IT_VarTest,297IT_BinTest,298IT_Var,299IT_Tmp300} InfoType = IT_None;301302struct BinTestTy {303const BinaryOperator *Source;304EffectiveOp EOp;305VarTestResult LTest;306VarTestResult RTest;307};308309union {310ConsumedState State;311VarTestResult VarTest;312const VarDecl *Var;313const CXXBindTemporaryExpr *Tmp;314BinTestTy BinTest;315};316317public:318PropagationInfo() = default;319PropagationInfo(const VarTestResult &VarTest)320: InfoType(IT_VarTest), VarTest(VarTest) {}321322PropagationInfo(const VarDecl *Var, ConsumedState TestsFor)323: InfoType(IT_VarTest) {324VarTest.Var = Var;325VarTest.TestsFor = TestsFor;326}327328PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,329const VarTestResult <est, const VarTestResult &RTest)330: InfoType(IT_BinTest) {331BinTest.Source = Source;332BinTest.EOp = EOp;333BinTest.LTest = LTest;334BinTest.RTest = RTest;335}336337PropagationInfo(const BinaryOperator *Source, EffectiveOp EOp,338const VarDecl *LVar, ConsumedState LTestsFor,339const VarDecl *RVar, ConsumedState RTestsFor)340: InfoType(IT_BinTest) {341BinTest.Source = Source;342BinTest.EOp = EOp;343BinTest.LTest.Var = LVar;344BinTest.LTest.TestsFor = LTestsFor;345BinTest.RTest.Var = RVar;346BinTest.RTest.TestsFor = RTestsFor;347}348349PropagationInfo(ConsumedState State)350: InfoType(IT_State), State(State) {}351PropagationInfo(const VarDecl *Var) : InfoType(IT_Var), Var(Var) {}352PropagationInfo(const CXXBindTemporaryExpr *Tmp)353: InfoType(IT_Tmp), Tmp(Tmp) {}354355const ConsumedState &getState() const {356assert(InfoType == IT_State);357return State;358}359360const VarTestResult &getVarTest() const {361assert(InfoType == IT_VarTest);362return VarTest;363}364365const VarTestResult &getLTest() const {366assert(InfoType == IT_BinTest);367return BinTest.LTest;368}369370const VarTestResult &getRTest() const {371assert(InfoType == IT_BinTest);372return BinTest.RTest;373}374375const VarDecl *getVar() const {376assert(InfoType == IT_Var);377return Var;378}379380const CXXBindTemporaryExpr *getTmp() const {381assert(InfoType == IT_Tmp);382return Tmp;383}384385ConsumedState getAsState(const ConsumedStateMap *StateMap) const {386assert(isVar() || isTmp() || isState());387388if (isVar())389return StateMap->getState(Var);390else if (isTmp())391return StateMap->getState(Tmp);392else if (isState())393return State;394else395return CS_None;396}397398EffectiveOp testEffectiveOp() const {399assert(InfoType == IT_BinTest);400return BinTest.EOp;401}402403const BinaryOperator * testSourceNode() const {404assert(InfoType == IT_BinTest);405return BinTest.Source;406}407408bool isValid() const { return InfoType != IT_None; }409bool isState() const { return InfoType == IT_State; }410bool isVarTest() const { return InfoType == IT_VarTest; }411bool isBinTest() const { return InfoType == IT_BinTest; }412bool isVar() const { return InfoType == IT_Var; }413bool isTmp() const { return InfoType == IT_Tmp; }414415bool isTest() const {416return InfoType == IT_VarTest || InfoType == IT_BinTest;417}418419bool isPointerToValue() const {420return InfoType == IT_Var || InfoType == IT_Tmp;421}422423PropagationInfo invertTest() const {424assert(InfoType == IT_VarTest || InfoType == IT_BinTest);425426if (InfoType == IT_VarTest) {427return PropagationInfo(VarTest.Var,428invertConsumedUnconsumed(VarTest.TestsFor));429430} else if (InfoType == IT_BinTest) {431return PropagationInfo(BinTest.Source,432BinTest.EOp == EO_And ? EO_Or : EO_And,433BinTest.LTest.Var, invertConsumedUnconsumed(BinTest.LTest.TestsFor),434BinTest.RTest.Var, invertConsumedUnconsumed(BinTest.RTest.TestsFor));435} else {436return {};437}438}439};440441} // namespace consumed442} // namespace clang443444static void445setStateForVarOrTmp(ConsumedStateMap *StateMap, const PropagationInfo &PInfo,446ConsumedState State) {447assert(PInfo.isVar() || PInfo.isTmp());448449if (PInfo.isVar())450StateMap->setState(PInfo.getVar(), State);451else452StateMap->setState(PInfo.getTmp(), State);453}454455namespace clang {456namespace consumed {457458class ConsumedStmtVisitor : public ConstStmtVisitor<ConsumedStmtVisitor> {459using MapType = llvm::DenseMap<const Stmt *, PropagationInfo>;460using PairType= std::pair<const Stmt *, PropagationInfo>;461using InfoEntry = MapType::iterator;462using ConstInfoEntry = MapType::const_iterator;463464ConsumedAnalyzer &Analyzer;465ConsumedStateMap *StateMap;466MapType PropagationMap;467468InfoEntry findInfo(const Expr *E) {469if (const auto Cleanups = dyn_cast<ExprWithCleanups>(E))470if (!Cleanups->cleanupsHaveSideEffects())471E = Cleanups->getSubExpr();472return PropagationMap.find(E->IgnoreParens());473}474475ConstInfoEntry findInfo(const Expr *E) const {476if (const auto Cleanups = dyn_cast<ExprWithCleanups>(E))477if (!Cleanups->cleanupsHaveSideEffects())478E = Cleanups->getSubExpr();479return PropagationMap.find(E->IgnoreParens());480}481482void insertInfo(const Expr *E, const PropagationInfo &PI) {483PropagationMap.insert(PairType(E->IgnoreParens(), PI));484}485486void forwardInfo(const Expr *From, const Expr *To);487void copyInfo(const Expr *From, const Expr *To, ConsumedState CS);488ConsumedState getInfo(const Expr *From);489void setInfo(const Expr *To, ConsumedState NS);490void propagateReturnType(const Expr *Call, const FunctionDecl *Fun);491492public:493void checkCallability(const PropagationInfo &PInfo,494const FunctionDecl *FunDecl,495SourceLocation BlameLoc);496bool handleCall(const CallExpr *Call, const Expr *ObjArg,497const FunctionDecl *FunD);498499void VisitBinaryOperator(const BinaryOperator *BinOp);500void VisitCallExpr(const CallExpr *Call);501void VisitCastExpr(const CastExpr *Cast);502void VisitCXXBindTemporaryExpr(const CXXBindTemporaryExpr *Temp);503void VisitCXXConstructExpr(const CXXConstructExpr *Call);504void VisitCXXMemberCallExpr(const CXXMemberCallExpr *Call);505void VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *Call);506void VisitDeclRefExpr(const DeclRefExpr *DeclRef);507void VisitDeclStmt(const DeclStmt *DelcS);508void VisitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *Temp);509void VisitMemberExpr(const MemberExpr *MExpr);510void VisitParmVarDecl(const ParmVarDecl *Param);511void VisitReturnStmt(const ReturnStmt *Ret);512void VisitUnaryOperator(const UnaryOperator *UOp);513void VisitVarDecl(const VarDecl *Var);514515ConsumedStmtVisitor(ConsumedAnalyzer &Analyzer, ConsumedStateMap *StateMap)516: Analyzer(Analyzer), StateMap(StateMap) {}517518PropagationInfo getInfo(const Expr *StmtNode) const {519ConstInfoEntry Entry = findInfo(StmtNode);520521if (Entry != PropagationMap.end())522return Entry->second;523else524return {};525}526527void reset(ConsumedStateMap *NewStateMap) {528StateMap = NewStateMap;529}530};531532} // namespace consumed533} // namespace clang534535void ConsumedStmtVisitor::forwardInfo(const Expr *From, const Expr *To) {536InfoEntry Entry = findInfo(From);537if (Entry != PropagationMap.end())538insertInfo(To, Entry->second);539}540541// Create a new state for To, which is initialized to the state of From.542// If NS is not CS_None, sets the state of From to NS.543void ConsumedStmtVisitor::copyInfo(const Expr *From, const Expr *To,544ConsumedState NS) {545InfoEntry Entry = findInfo(From);546if (Entry != PropagationMap.end()) {547PropagationInfo& PInfo = Entry->second;548ConsumedState CS = PInfo.getAsState(StateMap);549if (CS != CS_None)550insertInfo(To, PropagationInfo(CS));551if (NS != CS_None && PInfo.isPointerToValue())552setStateForVarOrTmp(StateMap, PInfo, NS);553}554}555556// Get the ConsumedState for From557ConsumedState ConsumedStmtVisitor::getInfo(const Expr *From) {558InfoEntry Entry = findInfo(From);559if (Entry != PropagationMap.end()) {560PropagationInfo& PInfo = Entry->second;561return PInfo.getAsState(StateMap);562}563return CS_None;564}565566// If we already have info for To then update it, otherwise create a new entry.567void ConsumedStmtVisitor::setInfo(const Expr *To, ConsumedState NS) {568InfoEntry Entry = findInfo(To);569if (Entry != PropagationMap.end()) {570PropagationInfo& PInfo = Entry->second;571if (PInfo.isPointerToValue())572setStateForVarOrTmp(StateMap, PInfo, NS);573} else if (NS != CS_None) {574insertInfo(To, PropagationInfo(NS));575}576}577578void ConsumedStmtVisitor::checkCallability(const PropagationInfo &PInfo,579const FunctionDecl *FunDecl,580SourceLocation BlameLoc) {581assert(!PInfo.isTest());582583const CallableWhenAttr *CWAttr = FunDecl->getAttr<CallableWhenAttr>();584if (!CWAttr)585return;586587if (PInfo.isVar()) {588ConsumedState VarState = StateMap->getState(PInfo.getVar());589590if (VarState == CS_None || isCallableInState(CWAttr, VarState))591return;592593Analyzer.WarningsHandler.warnUseInInvalidState(594FunDecl->getNameAsString(), PInfo.getVar()->getNameAsString(),595stateToString(VarState), BlameLoc);596} else {597ConsumedState TmpState = PInfo.getAsState(StateMap);598599if (TmpState == CS_None || isCallableInState(CWAttr, TmpState))600return;601602Analyzer.WarningsHandler.warnUseOfTempInInvalidState(603FunDecl->getNameAsString(), stateToString(TmpState), BlameLoc);604}605}606607// Factors out common behavior for function, method, and operator calls.608// Check parameters and set parameter state if necessary.609// Returns true if the state of ObjArg is set, or false otherwise.610bool ConsumedStmtVisitor::handleCall(const CallExpr *Call, const Expr *ObjArg,611const FunctionDecl *FunD) {612unsigned Offset = 0;613if (isa<CXXOperatorCallExpr>(Call) && isa<CXXMethodDecl>(FunD))614Offset = 1; // first argument is 'this'615616// check explicit parameters617for (unsigned Index = Offset; Index < Call->getNumArgs(); ++Index) {618// Skip variable argument lists.619if (Index - Offset >= FunD->getNumParams())620break;621622const ParmVarDecl *Param = FunD->getParamDecl(Index - Offset);623QualType ParamType = Param->getType();624625InfoEntry Entry = findInfo(Call->getArg(Index));626627if (Entry == PropagationMap.end() || Entry->second.isTest())628continue;629PropagationInfo PInfo = Entry->second;630631// Check that the parameter is in the correct state.632if (ParamTypestateAttr *PTA = Param->getAttr<ParamTypestateAttr>()) {633ConsumedState ParamState = PInfo.getAsState(StateMap);634ConsumedState ExpectedState = mapParamTypestateAttrState(PTA);635636if (ParamState != ExpectedState)637Analyzer.WarningsHandler.warnParamTypestateMismatch(638Call->getArg(Index)->getExprLoc(),639stateToString(ExpectedState), stateToString(ParamState));640}641642if (!(Entry->second.isVar() || Entry->second.isTmp()))643continue;644645// Adjust state on the caller side.646if (ReturnTypestateAttr *RT = Param->getAttr<ReturnTypestateAttr>())647setStateForVarOrTmp(StateMap, PInfo, mapReturnTypestateAttrState(RT));648else if (isRValueRef(ParamType) || isConsumableType(ParamType))649setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Consumed);650else if (isPointerOrRef(ParamType) &&651(!ParamType->getPointeeType().isConstQualified() ||652isSetOnReadPtrType(ParamType)))653setStateForVarOrTmp(StateMap, PInfo, consumed::CS_Unknown);654}655656if (!ObjArg)657return false;658659// check implicit 'self' parameter, if present660InfoEntry Entry = findInfo(ObjArg);661if (Entry != PropagationMap.end()) {662PropagationInfo PInfo = Entry->second;663checkCallability(PInfo, FunD, Call->getExprLoc());664665if (SetTypestateAttr *STA = FunD->getAttr<SetTypestateAttr>()) {666if (PInfo.isVar()) {667StateMap->setState(PInfo.getVar(), mapSetTypestateAttrState(STA));668return true;669}670else if (PInfo.isTmp()) {671StateMap->setState(PInfo.getTmp(), mapSetTypestateAttrState(STA));672return true;673}674}675else if (isTestingFunction(FunD) && PInfo.isVar()) {676PropagationMap.insert(PairType(Call,677PropagationInfo(PInfo.getVar(), testsFor(FunD))));678}679}680return false;681}682683void ConsumedStmtVisitor::propagateReturnType(const Expr *Call,684const FunctionDecl *Fun) {685QualType RetType = Fun->getCallResultType();686if (RetType->isReferenceType())687RetType = RetType->getPointeeType();688689if (isConsumableType(RetType)) {690ConsumedState ReturnState;691if (ReturnTypestateAttr *RTA = Fun->getAttr<ReturnTypestateAttr>())692ReturnState = mapReturnTypestateAttrState(RTA);693else694ReturnState = mapConsumableAttrState(RetType);695696PropagationMap.insert(PairType(Call, PropagationInfo(ReturnState)));697}698}699700void ConsumedStmtVisitor::VisitBinaryOperator(const BinaryOperator *BinOp) {701switch (BinOp->getOpcode()) {702case BO_LAnd:703case BO_LOr : {704InfoEntry LEntry = findInfo(BinOp->getLHS()),705REntry = findInfo(BinOp->getRHS());706707VarTestResult LTest, RTest;708709if (LEntry != PropagationMap.end() && LEntry->second.isVarTest()) {710LTest = LEntry->second.getVarTest();711} else {712LTest.Var = nullptr;713LTest.TestsFor = CS_None;714}715716if (REntry != PropagationMap.end() && REntry->second.isVarTest()) {717RTest = REntry->second.getVarTest();718} else {719RTest.Var = nullptr;720RTest.TestsFor = CS_None;721}722723if (!(LTest.Var == nullptr && RTest.Var == nullptr))724PropagationMap.insert(PairType(BinOp, PropagationInfo(BinOp,725static_cast<EffectiveOp>(BinOp->getOpcode() == BO_LOr), LTest, RTest)));726break;727}728729case BO_PtrMemD:730case BO_PtrMemI:731forwardInfo(BinOp->getLHS(), BinOp);732break;733734default:735break;736}737}738739void ConsumedStmtVisitor::VisitCallExpr(const CallExpr *Call) {740const FunctionDecl *FunDecl = Call->getDirectCallee();741if (!FunDecl)742return;743744// Special case for the std::move function.745// TODO: Make this more specific. (Deferred)746if (Call->isCallToStdMove()) {747copyInfo(Call->getArg(0), Call, CS_Consumed);748return;749}750751handleCall(Call, nullptr, FunDecl);752propagateReturnType(Call, FunDecl);753}754755void ConsumedStmtVisitor::VisitCastExpr(const CastExpr *Cast) {756forwardInfo(Cast->getSubExpr(), Cast);757}758759void ConsumedStmtVisitor::VisitCXXBindTemporaryExpr(760const CXXBindTemporaryExpr *Temp) {761762InfoEntry Entry = findInfo(Temp->getSubExpr());763764if (Entry != PropagationMap.end() && !Entry->second.isTest()) {765StateMap->setState(Temp, Entry->second.getAsState(StateMap));766PropagationMap.insert(PairType(Temp, PropagationInfo(Temp)));767}768}769770void ConsumedStmtVisitor::VisitCXXConstructExpr(const CXXConstructExpr *Call) {771CXXConstructorDecl *Constructor = Call->getConstructor();772773QualType ThisType = Constructor->getFunctionObjectParameterType();774775if (!isConsumableType(ThisType))776return;777778// FIXME: What should happen if someone annotates the move constructor?779if (ReturnTypestateAttr *RTA = Constructor->getAttr<ReturnTypestateAttr>()) {780// TODO: Adjust state of args appropriately.781ConsumedState RetState = mapReturnTypestateAttrState(RTA);782PropagationMap.insert(PairType(Call, PropagationInfo(RetState)));783} else if (Constructor->isDefaultConstructor()) {784PropagationMap.insert(PairType(Call,785PropagationInfo(consumed::CS_Consumed)));786} else if (Constructor->isMoveConstructor()) {787copyInfo(Call->getArg(0), Call, CS_Consumed);788} else if (Constructor->isCopyConstructor()) {789// Copy state from arg. If setStateOnRead then set arg to CS_Unknown.790ConsumedState NS =791isSetOnReadPtrType(Constructor->getThisType()) ?792CS_Unknown : CS_None;793copyInfo(Call->getArg(0), Call, NS);794} else {795// TODO: Adjust state of args appropriately.796ConsumedState RetState = mapConsumableAttrState(ThisType);797PropagationMap.insert(PairType(Call, PropagationInfo(RetState)));798}799}800801void ConsumedStmtVisitor::VisitCXXMemberCallExpr(802const CXXMemberCallExpr *Call) {803CXXMethodDecl* MD = Call->getMethodDecl();804if (!MD)805return;806807handleCall(Call, Call->getImplicitObjectArgument(), MD);808propagateReturnType(Call, MD);809}810811void ConsumedStmtVisitor::VisitCXXOperatorCallExpr(812const CXXOperatorCallExpr *Call) {813const auto *FunDecl = dyn_cast_or_null<FunctionDecl>(Call->getDirectCallee());814if (!FunDecl) return;815816if (Call->getOperator() == OO_Equal) {817ConsumedState CS = getInfo(Call->getArg(1));818if (!handleCall(Call, Call->getArg(0), FunDecl))819setInfo(Call->getArg(0), CS);820return;821}822823if (const auto *MCall = dyn_cast<CXXMemberCallExpr>(Call))824handleCall(MCall, MCall->getImplicitObjectArgument(), FunDecl);825else826handleCall(Call, Call->getArg(0), FunDecl);827828propagateReturnType(Call, FunDecl);829}830831void ConsumedStmtVisitor::VisitDeclRefExpr(const DeclRefExpr *DeclRef) {832if (const auto *Var = dyn_cast_or_null<VarDecl>(DeclRef->getDecl()))833if (StateMap->getState(Var) != consumed::CS_None)834PropagationMap.insert(PairType(DeclRef, PropagationInfo(Var)));835}836837void ConsumedStmtVisitor::VisitDeclStmt(const DeclStmt *DeclS) {838for (const auto *DI : DeclS->decls())839if (isa<VarDecl>(DI))840VisitVarDecl(cast<VarDecl>(DI));841842if (DeclS->isSingleDecl())843if (const auto *Var = dyn_cast_or_null<VarDecl>(DeclS->getSingleDecl()))844PropagationMap.insert(PairType(DeclS, PropagationInfo(Var)));845}846847void ConsumedStmtVisitor::VisitMaterializeTemporaryExpr(848const MaterializeTemporaryExpr *Temp) {849forwardInfo(Temp->getSubExpr(), Temp);850}851852void ConsumedStmtVisitor::VisitMemberExpr(const MemberExpr *MExpr) {853forwardInfo(MExpr->getBase(), MExpr);854}855856void ConsumedStmtVisitor::VisitParmVarDecl(const ParmVarDecl *Param) {857QualType ParamType = Param->getType();858ConsumedState ParamState = consumed::CS_None;859860if (const ParamTypestateAttr *PTA = Param->getAttr<ParamTypestateAttr>())861ParamState = mapParamTypestateAttrState(PTA);862else if (isConsumableType(ParamType))863ParamState = mapConsumableAttrState(ParamType);864else if (isRValueRef(ParamType) &&865isConsumableType(ParamType->getPointeeType()))866ParamState = mapConsumableAttrState(ParamType->getPointeeType());867else if (ParamType->isReferenceType() &&868isConsumableType(ParamType->getPointeeType()))869ParamState = consumed::CS_Unknown;870871if (ParamState != CS_None)872StateMap->setState(Param, ParamState);873}874875void ConsumedStmtVisitor::VisitReturnStmt(const ReturnStmt *Ret) {876ConsumedState ExpectedState = Analyzer.getExpectedReturnState();877878if (ExpectedState != CS_None) {879InfoEntry Entry = findInfo(Ret->getRetValue());880881if (Entry != PropagationMap.end()) {882ConsumedState RetState = Entry->second.getAsState(StateMap);883884if (RetState != ExpectedState)885Analyzer.WarningsHandler.warnReturnTypestateMismatch(886Ret->getReturnLoc(), stateToString(ExpectedState),887stateToString(RetState));888}889}890891StateMap->checkParamsForReturnTypestate(Ret->getBeginLoc(),892Analyzer.WarningsHandler);893}894895void ConsumedStmtVisitor::VisitUnaryOperator(const UnaryOperator *UOp) {896InfoEntry Entry = findInfo(UOp->getSubExpr());897if (Entry == PropagationMap.end()) return;898899switch (UOp->getOpcode()) {900case UO_AddrOf:901PropagationMap.insert(PairType(UOp, Entry->second));902break;903904case UO_LNot:905if (Entry->second.isTest())906PropagationMap.insert(PairType(UOp, Entry->second.invertTest()));907break;908909default:910break;911}912}913914// TODO: See if I need to check for reference types here.915void ConsumedStmtVisitor::VisitVarDecl(const VarDecl *Var) {916if (isConsumableType(Var->getType())) {917if (Var->hasInit()) {918MapType::iterator VIT = findInfo(Var->getInit()->IgnoreImplicit());919if (VIT != PropagationMap.end()) {920PropagationInfo PInfo = VIT->second;921ConsumedState St = PInfo.getAsState(StateMap);922923if (St != consumed::CS_None) {924StateMap->setState(Var, St);925return;926}927}928}929// Otherwise930StateMap->setState(Var, consumed::CS_Unknown);931}932}933934static void splitVarStateForIf(const IfStmt *IfNode, const VarTestResult &Test,935ConsumedStateMap *ThenStates,936ConsumedStateMap *ElseStates) {937ConsumedState VarState = ThenStates->getState(Test.Var);938939if (VarState == CS_Unknown) {940ThenStates->setState(Test.Var, Test.TestsFor);941ElseStates->setState(Test.Var, invertConsumedUnconsumed(Test.TestsFor));942} else if (VarState == invertConsumedUnconsumed(Test.TestsFor)) {943ThenStates->markUnreachable();944} else if (VarState == Test.TestsFor) {945ElseStates->markUnreachable();946}947}948949static void splitVarStateForIfBinOp(const PropagationInfo &PInfo,950ConsumedStateMap *ThenStates,951ConsumedStateMap *ElseStates) {952const VarTestResult <est = PInfo.getLTest(),953&RTest = PInfo.getRTest();954955ConsumedState LState = LTest.Var ? ThenStates->getState(LTest.Var) : CS_None,956RState = RTest.Var ? ThenStates->getState(RTest.Var) : CS_None;957958if (LTest.Var) {959if (PInfo.testEffectiveOp() == EO_And) {960if (LState == CS_Unknown) {961ThenStates->setState(LTest.Var, LTest.TestsFor);962} else if (LState == invertConsumedUnconsumed(LTest.TestsFor)) {963ThenStates->markUnreachable();964} else if (LState == LTest.TestsFor && isKnownState(RState)) {965if (RState == RTest.TestsFor)966ElseStates->markUnreachable();967else968ThenStates->markUnreachable();969}970} else {971if (LState == CS_Unknown) {972ElseStates->setState(LTest.Var,973invertConsumedUnconsumed(LTest.TestsFor));974} else if (LState == LTest.TestsFor) {975ElseStates->markUnreachable();976} else if (LState == invertConsumedUnconsumed(LTest.TestsFor) &&977isKnownState(RState)) {978if (RState == RTest.TestsFor)979ElseStates->markUnreachable();980else981ThenStates->markUnreachable();982}983}984}985986if (RTest.Var) {987if (PInfo.testEffectiveOp() == EO_And) {988if (RState == CS_Unknown)989ThenStates->setState(RTest.Var, RTest.TestsFor);990else if (RState == invertConsumedUnconsumed(RTest.TestsFor))991ThenStates->markUnreachable();992} else {993if (RState == CS_Unknown)994ElseStates->setState(RTest.Var,995invertConsumedUnconsumed(RTest.TestsFor));996else if (RState == RTest.TestsFor)997ElseStates->markUnreachable();998}999}1000}10011002bool ConsumedBlockInfo::allBackEdgesVisited(const CFGBlock *CurrBlock,1003const CFGBlock *TargetBlock) {1004assert(CurrBlock && "Block pointer must not be NULL");1005assert(TargetBlock && "TargetBlock pointer must not be NULL");10061007unsigned int CurrBlockOrder = VisitOrder[CurrBlock->getBlockID()];1008for (CFGBlock::const_pred_iterator PI = TargetBlock->pred_begin(),1009PE = TargetBlock->pred_end(); PI != PE; ++PI) {1010if (*PI && CurrBlockOrder < VisitOrder[(*PI)->getBlockID()] )1011return false;1012}1013return true;1014}10151016void ConsumedBlockInfo::addInfo(1017const CFGBlock *Block, ConsumedStateMap *StateMap,1018std::unique_ptr<ConsumedStateMap> &OwnedStateMap) {1019assert(Block && "Block pointer must not be NULL");10201021auto &Entry = StateMapsArray[Block->getBlockID()];10221023if (Entry) {1024Entry->intersect(*StateMap);1025} else if (OwnedStateMap)1026Entry = std::move(OwnedStateMap);1027else1028Entry = std::make_unique<ConsumedStateMap>(*StateMap);1029}10301031void ConsumedBlockInfo::addInfo(const CFGBlock *Block,1032std::unique_ptr<ConsumedStateMap> StateMap) {1033assert(Block && "Block pointer must not be NULL");10341035auto &Entry = StateMapsArray[Block->getBlockID()];10361037if (Entry) {1038Entry->intersect(*StateMap);1039} else {1040Entry = std::move(StateMap);1041}1042}10431044ConsumedStateMap* ConsumedBlockInfo::borrowInfo(const CFGBlock *Block) {1045assert(Block && "Block pointer must not be NULL");1046assert(StateMapsArray[Block->getBlockID()] && "Block has no block info");10471048return StateMapsArray[Block->getBlockID()].get();1049}10501051void ConsumedBlockInfo::discardInfo(const CFGBlock *Block) {1052StateMapsArray[Block->getBlockID()] = nullptr;1053}10541055std::unique_ptr<ConsumedStateMap>1056ConsumedBlockInfo::getInfo(const CFGBlock *Block) {1057assert(Block && "Block pointer must not be NULL");10581059auto &Entry = StateMapsArray[Block->getBlockID()];1060return isBackEdgeTarget(Block) ? std::make_unique<ConsumedStateMap>(*Entry)1061: std::move(Entry);1062}10631064bool ConsumedBlockInfo::isBackEdge(const CFGBlock *From, const CFGBlock *To) {1065assert(From && "From block must not be NULL");1066assert(To && "From block must not be NULL");10671068return VisitOrder[From->getBlockID()] > VisitOrder[To->getBlockID()];1069}10701071bool ConsumedBlockInfo::isBackEdgeTarget(const CFGBlock *Block) {1072assert(Block && "Block pointer must not be NULL");10731074// Anything with less than two predecessors can't be the target of a back1075// edge.1076if (Block->pred_size() < 2)1077return false;10781079unsigned int BlockVisitOrder = VisitOrder[Block->getBlockID()];1080for (CFGBlock::const_pred_iterator PI = Block->pred_begin(),1081PE = Block->pred_end(); PI != PE; ++PI) {1082if (*PI && BlockVisitOrder < VisitOrder[(*PI)->getBlockID()])1083return true;1084}1085return false;1086}10871088void ConsumedStateMap::checkParamsForReturnTypestate(SourceLocation BlameLoc,1089ConsumedWarningsHandlerBase &WarningsHandler) const {10901091for (const auto &DM : VarMap) {1092if (isa<ParmVarDecl>(DM.first)) {1093const auto *Param = cast<ParmVarDecl>(DM.first);1094const ReturnTypestateAttr *RTA = Param->getAttr<ReturnTypestateAttr>();10951096if (!RTA)1097continue;10981099ConsumedState ExpectedState = mapReturnTypestateAttrState(RTA);1100if (DM.second != ExpectedState)1101WarningsHandler.warnParamReturnTypestateMismatch(BlameLoc,1102Param->getNameAsString(), stateToString(ExpectedState),1103stateToString(DM.second));1104}1105}1106}11071108void ConsumedStateMap::clearTemporaries() {1109TmpMap.clear();1110}11111112ConsumedState ConsumedStateMap::getState(const VarDecl *Var) const {1113VarMapType::const_iterator Entry = VarMap.find(Var);11141115if (Entry != VarMap.end())1116return Entry->second;11171118return CS_None;1119}11201121ConsumedState1122ConsumedStateMap::getState(const CXXBindTemporaryExpr *Tmp) const {1123TmpMapType::const_iterator Entry = TmpMap.find(Tmp);11241125if (Entry != TmpMap.end())1126return Entry->second;11271128return CS_None;1129}11301131void ConsumedStateMap::intersect(const ConsumedStateMap &Other) {1132ConsumedState LocalState;11331134if (this->From && this->From == Other.From && !Other.Reachable) {1135this->markUnreachable();1136return;1137}11381139for (const auto &DM : Other.VarMap) {1140LocalState = this->getState(DM.first);11411142if (LocalState == CS_None)1143continue;11441145if (LocalState != DM.second)1146VarMap[DM.first] = CS_Unknown;1147}1148}11491150void ConsumedStateMap::intersectAtLoopHead(const CFGBlock *LoopHead,1151const CFGBlock *LoopBack, const ConsumedStateMap *LoopBackStates,1152ConsumedWarningsHandlerBase &WarningsHandler) {11531154ConsumedState LocalState;1155SourceLocation BlameLoc = getLastStmtLoc(LoopBack);11561157for (const auto &DM : LoopBackStates->VarMap) {1158LocalState = this->getState(DM.first);11591160if (LocalState == CS_None)1161continue;11621163if (LocalState != DM.second) {1164VarMap[DM.first] = CS_Unknown;1165WarningsHandler.warnLoopStateMismatch(BlameLoc,1166DM.first->getNameAsString());1167}1168}1169}11701171void ConsumedStateMap::markUnreachable() {1172this->Reachable = false;1173VarMap.clear();1174TmpMap.clear();1175}11761177void ConsumedStateMap::setState(const VarDecl *Var, ConsumedState State) {1178VarMap[Var] = State;1179}11801181void ConsumedStateMap::setState(const CXXBindTemporaryExpr *Tmp,1182ConsumedState State) {1183TmpMap[Tmp] = State;1184}11851186void ConsumedStateMap::remove(const CXXBindTemporaryExpr *Tmp) {1187TmpMap.erase(Tmp);1188}11891190bool ConsumedStateMap::operator!=(const ConsumedStateMap *Other) const {1191for (const auto &DM : Other->VarMap)1192if (this->getState(DM.first) != DM.second)1193return true;1194return false;1195}11961197void ConsumedAnalyzer::determineExpectedReturnState(AnalysisDeclContext &AC,1198const FunctionDecl *D) {1199QualType ReturnType;1200if (const auto *Constructor = dyn_cast<CXXConstructorDecl>(D)) {1201ReturnType = Constructor->getFunctionObjectParameterType();1202} else1203ReturnType = D->getCallResultType();12041205if (const ReturnTypestateAttr *RTSAttr = D->getAttr<ReturnTypestateAttr>()) {1206const CXXRecordDecl *RD = ReturnType->getAsCXXRecordDecl();1207if (!RD || !RD->hasAttr<ConsumableAttr>()) {1208// FIXME: This should be removed when template instantiation propagates1209// attributes at template specialization definition, not1210// declaration. When it is removed the test needs to be enabled1211// in SemaDeclAttr.cpp.1212WarningsHandler.warnReturnTypestateForUnconsumableType(1213RTSAttr->getLocation(), ReturnType.getAsString());1214ExpectedReturnState = CS_None;1215} else1216ExpectedReturnState = mapReturnTypestateAttrState(RTSAttr);1217} else if (isConsumableType(ReturnType)) {1218if (isAutoCastType(ReturnType)) // We can auto-cast the state to the1219ExpectedReturnState = CS_None; // expected state.1220else1221ExpectedReturnState = mapConsumableAttrState(ReturnType);1222}1223else1224ExpectedReturnState = CS_None;1225}12261227bool ConsumedAnalyzer::splitState(const CFGBlock *CurrBlock,1228const ConsumedStmtVisitor &Visitor) {1229std::unique_ptr<ConsumedStateMap> FalseStates(1230new ConsumedStateMap(*CurrStates));1231PropagationInfo PInfo;12321233if (const auto *IfNode =1234dyn_cast_or_null<IfStmt>(CurrBlock->getTerminator().getStmt())) {1235const Expr *Cond = IfNode->getCond();12361237PInfo = Visitor.getInfo(Cond);1238if (!PInfo.isValid() && isa<BinaryOperator>(Cond))1239PInfo = Visitor.getInfo(cast<BinaryOperator>(Cond)->getRHS());12401241if (PInfo.isVarTest()) {1242CurrStates->setSource(Cond);1243FalseStates->setSource(Cond);1244splitVarStateForIf(IfNode, PInfo.getVarTest(), CurrStates.get(),1245FalseStates.get());1246} else if (PInfo.isBinTest()) {1247CurrStates->setSource(PInfo.testSourceNode());1248FalseStates->setSource(PInfo.testSourceNode());1249splitVarStateForIfBinOp(PInfo, CurrStates.get(), FalseStates.get());1250} else {1251return false;1252}1253} else if (const auto *BinOp =1254dyn_cast_or_null<BinaryOperator>(CurrBlock->getTerminator().getStmt())) {1255PInfo = Visitor.getInfo(BinOp->getLHS());1256if (!PInfo.isVarTest()) {1257if ((BinOp = dyn_cast_or_null<BinaryOperator>(BinOp->getLHS()))) {1258PInfo = Visitor.getInfo(BinOp->getRHS());12591260if (!PInfo.isVarTest())1261return false;1262} else {1263return false;1264}1265}12661267CurrStates->setSource(BinOp);1268FalseStates->setSource(BinOp);12691270const VarTestResult &Test = PInfo.getVarTest();1271ConsumedState VarState = CurrStates->getState(Test.Var);12721273if (BinOp->getOpcode() == BO_LAnd) {1274if (VarState == CS_Unknown)1275CurrStates->setState(Test.Var, Test.TestsFor);1276else if (VarState == invertConsumedUnconsumed(Test.TestsFor))1277CurrStates->markUnreachable();12781279} else if (BinOp->getOpcode() == BO_LOr) {1280if (VarState == CS_Unknown)1281FalseStates->setState(Test.Var,1282invertConsumedUnconsumed(Test.TestsFor));1283else if (VarState == Test.TestsFor)1284FalseStates->markUnreachable();1285}1286} else {1287return false;1288}12891290CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin();12911292if (*SI)1293BlockInfo.addInfo(*SI, std::move(CurrStates));1294else1295CurrStates = nullptr;12961297if (*++SI)1298BlockInfo.addInfo(*SI, std::move(FalseStates));12991300return true;1301}13021303void ConsumedAnalyzer::run(AnalysisDeclContext &AC) {1304const auto *D = dyn_cast_or_null<FunctionDecl>(AC.getDecl());1305if (!D)1306return;13071308CFG *CFGraph = AC.getCFG();1309if (!CFGraph)1310return;13111312determineExpectedReturnState(AC, D);13131314PostOrderCFGView *SortedGraph = AC.getAnalysis<PostOrderCFGView>();1315// AC.getCFG()->viewCFG(LangOptions());13161317BlockInfo = ConsumedBlockInfo(CFGraph->getNumBlockIDs(), SortedGraph);13181319CurrStates = std::make_unique<ConsumedStateMap>();1320ConsumedStmtVisitor Visitor(*this, CurrStates.get());13211322// Add all trackable parameters to the state map.1323for (const auto *PI : D->parameters())1324Visitor.VisitParmVarDecl(PI);13251326// Visit all of the function's basic blocks.1327for (const auto *CurrBlock : *SortedGraph) {1328if (!CurrStates)1329CurrStates = BlockInfo.getInfo(CurrBlock);13301331if (!CurrStates) {1332continue;1333} else if (!CurrStates->isReachable()) {1334CurrStates = nullptr;1335continue;1336}13371338Visitor.reset(CurrStates.get());13391340// Visit all of the basic block's statements.1341for (const auto &B : *CurrBlock) {1342switch (B.getKind()) {1343case CFGElement::Statement:1344Visitor.Visit(B.castAs<CFGStmt>().getStmt());1345break;13461347case CFGElement::TemporaryDtor: {1348const CFGTemporaryDtor &DTor = B.castAs<CFGTemporaryDtor>();1349const CXXBindTemporaryExpr *BTE = DTor.getBindTemporaryExpr();13501351Visitor.checkCallability(PropagationInfo(BTE),1352DTor.getDestructorDecl(AC.getASTContext()),1353BTE->getExprLoc());1354CurrStates->remove(BTE);1355break;1356}13571358case CFGElement::AutomaticObjectDtor: {1359const CFGAutomaticObjDtor &DTor = B.castAs<CFGAutomaticObjDtor>();1360SourceLocation Loc = DTor.getTriggerStmt()->getEndLoc();1361const VarDecl *Var = DTor.getVarDecl();13621363Visitor.checkCallability(PropagationInfo(Var),1364DTor.getDestructorDecl(AC.getASTContext()),1365Loc);1366break;1367}13681369default:1370break;1371}1372}13731374// TODO: Handle other forms of branching with precision, including while-1375// and for-loops. (Deferred)1376if (!splitState(CurrBlock, Visitor)) {1377CurrStates->setSource(nullptr);13781379if (CurrBlock->succ_size() > 1 ||1380(CurrBlock->succ_size() == 1 &&1381(*CurrBlock->succ_begin())->pred_size() > 1)) {13821383auto *RawState = CurrStates.get();13841385for (CFGBlock::const_succ_iterator SI = CurrBlock->succ_begin(),1386SE = CurrBlock->succ_end(); SI != SE; ++SI) {1387if (*SI == nullptr) continue;13881389if (BlockInfo.isBackEdge(CurrBlock, *SI)) {1390BlockInfo.borrowInfo(*SI)->intersectAtLoopHead(1391*SI, CurrBlock, RawState, WarningsHandler);13921393if (BlockInfo.allBackEdgesVisited(CurrBlock, *SI))1394BlockInfo.discardInfo(*SI);1395} else {1396BlockInfo.addInfo(*SI, RawState, CurrStates);1397}1398}13991400CurrStates = nullptr;1401}1402}14031404if (CurrBlock == &AC.getCFG()->getExit() &&1405D->getCallResultType()->isVoidType())1406CurrStates->checkParamsForReturnTypestate(D->getLocation(),1407WarningsHandler);1408} // End of block iterator.14091410// Delete the last existing state map.1411CurrStates = nullptr;14121413WarningsHandler.emitDiagnostics();1414}141514161417