Path: blob/main/contrib/llvm-project/llvm/lib/Transforms/IPO/CrossDSOCFI.cpp
35266 views
//===-- CrossDSOCFI.cpp - Externalize this module's CFI checks ------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.3// See https://llvm.org/LICENSE.txt for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7//8// This pass exports all llvm.bitset's found in the module in the form of a9// __cfi_check function, which can be used to verify cross-DSO call targets.10//11//===----------------------------------------------------------------------===//1213#include "llvm/Transforms/IPO/CrossDSOCFI.h"14#include "llvm/ADT/SetVector.h"15#include "llvm/ADT/Statistic.h"16#include "llvm/IR/Constants.h"17#include "llvm/IR/Function.h"18#include "llvm/IR/GlobalObject.h"19#include "llvm/IR/IRBuilder.h"20#include "llvm/IR/Instructions.h"21#include "llvm/IR/Intrinsics.h"22#include "llvm/IR/MDBuilder.h"23#include "llvm/IR/Module.h"24#include "llvm/TargetParser/Triple.h"25#include "llvm/Transforms/IPO.h"2627using namespace llvm;2829#define DEBUG_TYPE "cross-dso-cfi"3031STATISTIC(NumTypeIds, "Number of unique type identifiers");3233namespace {3435struct CrossDSOCFI {36MDNode *VeryLikelyWeights;3738ConstantInt *extractNumericTypeId(MDNode *MD);39void buildCFICheck(Module &M);40bool runOnModule(Module &M);41};4243} // anonymous namespace4445/// Extracts a numeric type identifier from an MDNode containing type metadata.46ConstantInt *CrossDSOCFI::extractNumericTypeId(MDNode *MD) {47// This check excludes vtables for classes inside anonymous namespaces.48auto TM = dyn_cast<ValueAsMetadata>(MD->getOperand(1));49if (!TM)50return nullptr;51auto C = dyn_cast_or_null<ConstantInt>(TM->getValue());52if (!C) return nullptr;53// We are looking for i64 constants.54if (C->getBitWidth() != 64) return nullptr;5556return C;57}5859/// buildCFICheck - emits __cfi_check for the current module.60void CrossDSOCFI::buildCFICheck(Module &M) {61// FIXME: verify that __cfi_check ends up near the end of the code section,62// but before the jump slots created in LowerTypeTests.63SetVector<uint64_t> TypeIds;64SmallVector<MDNode *, 2> Types;65for (GlobalObject &GO : M.global_objects()) {66Types.clear();67GO.getMetadata(LLVMContext::MD_type, Types);68for (MDNode *Type : Types)69if (ConstantInt *TypeId = extractNumericTypeId(Type))70TypeIds.insert(TypeId->getZExtValue());71}7273NamedMDNode *CfiFunctionsMD = M.getNamedMetadata("cfi.functions");74if (CfiFunctionsMD) {75for (auto *Func : CfiFunctionsMD->operands()) {76assert(Func->getNumOperands() >= 2);77for (unsigned I = 2; I < Func->getNumOperands(); ++I)78if (ConstantInt *TypeId =79extractNumericTypeId(cast<MDNode>(Func->getOperand(I).get())))80TypeIds.insert(TypeId->getZExtValue());81}82}8384LLVMContext &Ctx = M.getContext();85FunctionCallee C = M.getOrInsertFunction(86"__cfi_check", Type::getVoidTy(Ctx), Type::getInt64Ty(Ctx),87PointerType::getUnqual(Ctx), PointerType::getUnqual(Ctx));88Function *F = cast<Function>(C.getCallee());89// Take over the existing function. The frontend emits a weak stub so that the90// linker knows about the symbol; this pass replaces the function body.91F->deleteBody();92F->setAlignment(Align(4096));9394Triple T(M.getTargetTriple());95if (T.isARM() || T.isThumb())96F->addFnAttr("target-features", "+thumb-mode");9798auto args = F->arg_begin();99Value &CallSiteTypeId = *(args++);100CallSiteTypeId.setName("CallSiteTypeId");101Value &Addr = *(args++);102Addr.setName("Addr");103Value &CFICheckFailData = *(args++);104CFICheckFailData.setName("CFICheckFailData");105assert(args == F->arg_end());106107BasicBlock *BB = BasicBlock::Create(Ctx, "entry", F);108BasicBlock *ExitBB = BasicBlock::Create(Ctx, "exit", F);109110BasicBlock *TrapBB = BasicBlock::Create(Ctx, "fail", F);111IRBuilder<> IRBFail(TrapBB);112FunctionCallee CFICheckFailFn = M.getOrInsertFunction(113"__cfi_check_fail", Type::getVoidTy(Ctx), PointerType::getUnqual(Ctx),114PointerType::getUnqual(Ctx));115IRBFail.CreateCall(CFICheckFailFn, {&CFICheckFailData, &Addr});116IRBFail.CreateBr(ExitBB);117118IRBuilder<> IRBExit(ExitBB);119IRBExit.CreateRetVoid();120121IRBuilder<> IRB(BB);122SwitchInst *SI = IRB.CreateSwitch(&CallSiteTypeId, TrapBB, TypeIds.size());123for (uint64_t TypeId : TypeIds) {124ConstantInt *CaseTypeId = ConstantInt::get(Type::getInt64Ty(Ctx), TypeId);125BasicBlock *TestBB = BasicBlock::Create(Ctx, "test", F);126IRBuilder<> IRBTest(TestBB);127Function *BitsetTestFn = Intrinsic::getDeclaration(&M, Intrinsic::type_test);128129Value *Test = IRBTest.CreateCall(130BitsetTestFn, {&Addr, MetadataAsValue::get(131Ctx, ConstantAsMetadata::get(CaseTypeId))});132BranchInst *BI = IRBTest.CreateCondBr(Test, ExitBB, TrapBB);133BI->setMetadata(LLVMContext::MD_prof, VeryLikelyWeights);134135SI->addCase(CaseTypeId, TestBB);136++NumTypeIds;137}138}139140bool CrossDSOCFI::runOnModule(Module &M) {141VeryLikelyWeights = MDBuilder(M.getContext()).createLikelyBranchWeights();142if (M.getModuleFlag("Cross-DSO CFI") == nullptr)143return false;144buildCFICheck(M);145return true;146}147148PreservedAnalyses CrossDSOCFIPass::run(Module &M, ModuleAnalysisManager &AM) {149CrossDSOCFI Impl;150bool Changed = Impl.runOnModule(M);151if (!Changed)152return PreservedAnalyses::all();153return PreservedAnalyses::none();154}155156157