Path: blob/main/contrib/llvm-project/llvm/lib/Analysis/IR2Vec.cpp
213766 views
//===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===//1//2// Part of the LLVM Project, under the Apache License v2.0 with LLVM3// Exceptions. See the LICENSE file for license information.4// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception5//6//===----------------------------------------------------------------------===//7///8/// \file9/// This file implements the IR2Vec algorithm.10///11//===----------------------------------------------------------------------===//1213#include "llvm/Analysis/IR2Vec.h"1415#include "llvm/ADT/DepthFirstIterator.h"16#include "llvm/ADT/Sequence.h"17#include "llvm/ADT/Statistic.h"18#include "llvm/IR/CFG.h"19#include "llvm/IR/Module.h"20#include "llvm/IR/PassManager.h"21#include "llvm/Support/Debug.h"22#include "llvm/Support/Errc.h"23#include "llvm/Support/Error.h"24#include "llvm/Support/ErrorHandling.h"25#include "llvm/Support/Format.h"26#include "llvm/Support/MemoryBuffer.h"2728using namespace llvm;29using namespace ir2vec;3031#define DEBUG_TYPE "ir2vec"3233STATISTIC(VocabMissCounter,34"Number of lookups to entites not present in the vocabulary");3536namespace llvm {37namespace ir2vec {38static cl::OptionCategory IR2VecCategory("IR2Vec Options");3940// FIXME: Use a default vocab when not specified41static cl::opt<std::string>42VocabFile("ir2vec-vocab-path", cl::Optional,43cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),44cl::cat(IR2VecCategory));45cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0),46cl::desc("Weight for opcode embeddings"),47cl::cat(IR2VecCategory));48cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),49cl::desc("Weight for type embeddings"),50cl::cat(IR2VecCategory));51cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),52cl::desc("Weight for argument embeddings"),53cl::cat(IR2VecCategory));54} // namespace ir2vec55} // namespace llvm5657AnalysisKey IR2VecVocabAnalysis::Key;5859// ==----------------------------------------------------------------------===//60// Local helper functions61//===----------------------------------------------------------------------===//62namespace llvm::json {63inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,64llvm::json::Path P) {65std::vector<double> TempOut;66if (!llvm::json::fromJSON(E, TempOut, P))67return false;68Out = Embedding(std::move(TempOut));69return true;70}71} // namespace llvm::json7273// ==----------------------------------------------------------------------===//74// Embedding75//===----------------------------------------------------------------------===//76Embedding &Embedding::operator+=(const Embedding &RHS) {77assert(this->size() == RHS.size() && "Vectors must have the same dimension");78std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),79std::plus<double>());80return *this;81}8283Embedding Embedding::operator+(const Embedding &RHS) const {84Embedding Result(*this);85Result += RHS;86return Result;87}8889Embedding &Embedding::operator-=(const Embedding &RHS) {90assert(this->size() == RHS.size() && "Vectors must have the same dimension");91std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),92std::minus<double>());93return *this;94}9596Embedding Embedding::operator-(const Embedding &RHS) const {97Embedding Result(*this);98Result -= RHS;99return Result;100}101102Embedding &Embedding::operator*=(double Factor) {103std::transform(this->begin(), this->end(), this->begin(),104[Factor](double Elem) { return Elem * Factor; });105return *this;106}107108Embedding Embedding::operator*(double Factor) const {109Embedding Result(*this);110Result *= Factor;111return Result;112}113114Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {115assert(this->size() == Src.size() && "Vectors must have the same dimension");116for (size_t Itr = 0; Itr < this->size(); ++Itr)117(*this)[Itr] += Src[Itr] * Factor;118return *this;119}120121bool Embedding::approximatelyEquals(const Embedding &RHS,122double Tolerance) const {123assert(this->size() == RHS.size() && "Vectors must have the same dimension");124for (size_t Itr = 0; Itr < this->size(); ++Itr)125if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance)126return false;127return true;128}129130void Embedding::print(raw_ostream &OS) const {131OS << " [";132for (const auto &Elem : Data)133OS << " " << format("%.2f", Elem) << " ";134OS << "]\n";135}136137// ==----------------------------------------------------------------------===//138// Embedder and its subclasses139//===----------------------------------------------------------------------===//140141Embedder::Embedder(const Function &F, const Vocabulary &Vocab)142: F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),143OpcWeight(::OpcWeight), TypeWeight(::TypeWeight), ArgWeight(::ArgWeight) {144}145146std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,147const Vocabulary &Vocab) {148switch (Mode) {149case IR2VecKind::Symbolic:150return std::make_unique<SymbolicEmbedder>(F, Vocab);151}152return nullptr;153}154155const InstEmbeddingsMap &Embedder::getInstVecMap() const {156if (InstVecMap.empty())157computeEmbeddings();158return InstVecMap;159}160161const BBEmbeddingsMap &Embedder::getBBVecMap() const {162if (BBVecMap.empty())163computeEmbeddings();164return BBVecMap;165}166167const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {168auto It = BBVecMap.find(&BB);169if (It != BBVecMap.end())170return It->second;171computeEmbeddings(BB);172return BBVecMap[&BB];173}174175const Embedding &Embedder::getFunctionVector() const {176// Currently, we always (re)compute the embeddings for the function.177// This is cheaper than caching the vector.178computeEmbeddings();179return FuncVector;180}181182void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {183Embedding BBVector(Dimension, 0);184185// We consider only the non-debug and non-pseudo instructions186for (const auto &I : BB.instructionsWithoutDebug()) {187Embedding ArgEmb(Dimension, 0);188for (const auto &Op : I.operands())189ArgEmb += Vocab[Op];190auto InstVector =191Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;192InstVecMap[&I] = InstVector;193BBVector += InstVector;194}195BBVecMap[&BB] = BBVector;196}197198void SymbolicEmbedder::computeEmbeddings() const {199if (F.isDeclaration())200return;201202// Consider only the basic blocks that are reachable from entry203for (const BasicBlock *BB : depth_first(&F)) {204computeEmbeddings(*BB);205FuncVector += BBVecMap[BB];206}207}208209// ==----------------------------------------------------------------------===//210// Vocabulary211//===----------------------------------------------------------------------===//212213Vocabulary::Vocabulary(VocabVector &&Vocab)214: Vocab(std::move(Vocab)), Valid(true) {}215216bool Vocabulary::isValid() const {217return Vocab.size() == (MaxOpcodes + MaxTypeIDs + MaxOperandKinds) && Valid;218}219220size_t Vocabulary::size() const {221assert(Valid && "IR2Vec Vocabulary is invalid");222return Vocab.size();223}224225unsigned Vocabulary::getDimension() const {226assert(Valid && "IR2Vec Vocabulary is invalid");227return Vocab[0].size();228}229230const Embedding &Vocabulary::operator[](unsigned Opcode) const {231assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");232return Vocab[Opcode - 1];233}234235const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const {236assert(static_cast<unsigned>(TypeId) < MaxTypeIDs && "Invalid type ID");237return Vocab[MaxOpcodes + static_cast<unsigned>(TypeId)];238}239240const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {241OperandKind ArgKind = getOperandKind(Arg);242return Vocab[MaxOpcodes + MaxTypeIDs + static_cast<unsigned>(ArgKind)];243}244245StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {246assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");247#define HANDLE_INST(NUM, OPCODE, CLASS) \248if (Opcode == NUM) { \249return #OPCODE; \250}251#include "llvm/IR/Instruction.def"252#undef HANDLE_INST253return "UnknownOpcode";254}255256StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {257switch (TypeID) {258case Type::VoidTyID:259return "VoidTy";260case Type::HalfTyID:261case Type::BFloatTyID:262case Type::FloatTyID:263case Type::DoubleTyID:264case Type::X86_FP80TyID:265case Type::FP128TyID:266case Type::PPC_FP128TyID:267return "FloatTy";268case Type::IntegerTyID:269return "IntegerTy";270case Type::FunctionTyID:271return "FunctionTy";272case Type::StructTyID:273return "StructTy";274case Type::ArrayTyID:275return "ArrayTy";276case Type::PointerTyID:277case Type::TypedPointerTyID:278return "PointerTy";279case Type::FixedVectorTyID:280case Type::ScalableVectorTyID:281return "VectorTy";282case Type::LabelTyID:283return "LabelTy";284case Type::TokenTyID:285return "TokenTy";286case Type::MetadataTyID:287return "MetadataTy";288case Type::X86_AMXTyID:289case Type::TargetExtTyID:290return "UnknownTy";291}292return "UnknownTy";293}294295StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {296unsigned Index = static_cast<unsigned>(Kind);297assert(Index < MaxOperandKinds && "Invalid OperandKind");298return OperandKindNames[Index];299}300301Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {302VocabVector DummyVocab;303float DummyVal = 0.1f;304// Create a dummy vocabulary with entries for all opcodes, types, and305// operand306for (unsigned _ : seq(0u, Vocabulary::MaxOpcodes + Vocabulary::MaxTypeIDs +307Vocabulary::MaxOperandKinds)) {308DummyVocab.push_back(Embedding(Dim, DummyVal));309DummyVal += 0.1;310}311return DummyVocab;312}313314// Helper function to classify an operand into OperandKind315Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {316if (isa<Function>(Op))317return OperandKind::FunctionID;318if (isa<PointerType>(Op->getType()))319return OperandKind::PointerID;320if (isa<Constant>(Op))321return OperandKind::ConstantID;322return OperandKind::VariableID;323}324325StringRef Vocabulary::getStringKey(unsigned Pos) {326assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&327"Position out of bounds in vocabulary");328// Opcode329if (Pos < MaxOpcodes)330return getVocabKeyForOpcode(Pos + 1);331// Type332if (Pos < MaxOpcodes + MaxTypeIDs)333return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - MaxOpcodes));334// Operand335return getVocabKeyForOperandKind(336static_cast<OperandKind>(Pos - MaxOpcodes - MaxTypeIDs));337}338339// For now, assume vocabulary is stable unless explicitly invalidated.340bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA,341ModuleAnalysisManager::Invalidator &Inv) const {342auto PAC = PA.getChecker<IR2VecVocabAnalysis>();343return !(PAC.preservedWhenStateless());344}345346// ==----------------------------------------------------------------------===//347// IR2VecVocabAnalysis348//===----------------------------------------------------------------------===//349350Error IR2VecVocabAnalysis::parseVocabSection(351StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab,352unsigned &Dim) {353json::Path::Root Path("");354const json::Object *RootObj = ParsedVocabValue.getAsObject();355if (!RootObj)356return createStringError(errc::invalid_argument,357"JSON root is not an object");358359const json::Value *SectionValue = RootObj->get(Key);360if (!SectionValue)361return createStringError(errc::invalid_argument,362"Missing '" + std::string(Key) +363"' section in vocabulary file");364if (!json::fromJSON(*SectionValue, TargetVocab, Path))365return createStringError(errc::illegal_byte_sequence,366"Unable to parse '" + std::string(Key) +367"' section from vocabulary");368369Dim = TargetVocab.begin()->second.size();370if (Dim == 0)371return createStringError(errc::illegal_byte_sequence,372"Dimension of '" + std::string(Key) +373"' section of the vocabulary is zero");374375if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),376[Dim](const std::pair<StringRef, Embedding> &Entry) {377return Entry.second.size() == Dim;378}))379return createStringError(380errc::illegal_byte_sequence,381"All vectors in the '" + std::string(Key) +382"' section of the vocabulary are not of the same dimension");383384return Error::success();385}386387// FIXME: Make this optional. We can avoid file reads388// by auto-generating a default vocabulary during the build time.389Error IR2VecVocabAnalysis::readVocabulary() {390auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);391if (!BufOrError)392return createFileError(VocabFile, BufOrError.getError());393394auto Content = BufOrError.get()->getBuffer();395396Expected<json::Value> ParsedVocabValue = json::parse(Content);397if (!ParsedVocabValue)398return ParsedVocabValue.takeError();399400unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;401if (auto Err =402parseVocabSection("Opcodes", *ParsedVocabValue, OpcVocab, OpcodeDim))403return Err;404405if (auto Err =406parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim))407return Err;408409if (auto Err =410parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim))411return Err;412413if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))414return createStringError(errc::illegal_byte_sequence,415"Vocabulary sections have different dimensions");416417return Error::success();418}419420void IR2VecVocabAnalysis::generateNumMappedVocab() {421422// Helper for handling missing entities in the vocabulary.423// Currently, we use a zero vector. In the future, we will throw an error to424// ensure that *all* known entities are present in the vocabulary.425auto handleMissingEntity = [](const std::string &Val) {426LLVM_DEBUG(errs() << Val427<< " is not in vocabulary, using zero vector; This "428"would result in an error in future.\n");429++VocabMissCounter;430};431432unsigned Dim = OpcVocab.begin()->second.size();433assert(Dim > 0 && "Vocabulary dimension must be greater than zero");434435// Handle Opcodes436std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,437Embedding(Dim, 0));438for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {439StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);440auto It = OpcVocab.find(VocabKey.str());441if (It != OpcVocab.end())442NumericOpcodeEmbeddings[Opcode] = It->second;443else444handleMissingEntity(VocabKey.str());445}446Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),447NumericOpcodeEmbeddings.end());448449// Handle Types450std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,451Embedding(Dim, 0));452for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {453StringRef VocabKey =454Vocabulary::getVocabKeyForTypeID(static_cast<Type::TypeID>(TypeID));455if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {456NumericTypeEmbeddings[TypeID] = It->second;457continue;458}459handleMissingEntity(VocabKey.str());460}461Vocab.insert(Vocab.end(), NumericTypeEmbeddings.begin(),462NumericTypeEmbeddings.end());463464// Handle Arguments/Operands465std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,466Embedding(Dim, 0));467for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {468Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);469StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);470auto It = ArgVocab.find(VocabKey.str());471if (It != ArgVocab.end()) {472NumericArgEmbeddings[OpKind] = It->second;473continue;474}475handleMissingEntity(VocabKey.str());476}477Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),478NumericArgEmbeddings.end());479}480481IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &Vocab)482: Vocab(Vocab) {}483484IR2VecVocabAnalysis::IR2VecVocabAnalysis(VocabVector &&Vocab)485: Vocab(std::move(Vocab)) {}486487void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {488handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EI) {489Ctx.emitError("Error reading vocabulary: " + EI.message());490});491}492493IR2VecVocabAnalysis::Result494IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {495auto Ctx = &M.getContext();496// If vocabulary is already populated by the constructor, use it.497if (!Vocab.empty())498return Vocabulary(std::move(Vocab));499500// Otherwise, try to read from the vocabulary file.501if (VocabFile.empty()) {502// FIXME: Use default vocabulary503Ctx->emitError("IR2Vec vocabulary file path not specified; You may need to "504"set it using --ir2vec-vocab-path");505return Vocabulary(); // Return invalid result506}507if (auto Err = readVocabulary()) {508emitError(std::move(Err), *Ctx);509return Vocabulary();510}511512// Scale the vocabulary sections based on the provided weights513auto scaleVocabSection = [](VocabMap &Vocab, double Weight) {514for (auto &Entry : Vocab)515Entry.second *= Weight;516};517scaleVocabSection(OpcVocab, OpcWeight);518scaleVocabSection(TypeVocab, TypeWeight);519scaleVocabSection(ArgVocab, ArgWeight);520521// Generate the numeric lookup vocabulary522generateNumMappedVocab();523524return Vocabulary(std::move(Vocab));525}526527// ==----------------------------------------------------------------------===//528// Printer Passes529//===----------------------------------------------------------------------===//530531PreservedAnalyses IR2VecPrinterPass::run(Module &M,532ModuleAnalysisManager &MAM) {533auto Vocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);534assert(Vocabulary.isValid() && "IR2Vec Vocabulary is invalid");535536for (Function &F : M) {537std::unique_ptr<Embedder> Emb =538Embedder::create(IR2VecKind::Symbolic, F, Vocabulary);539if (!Emb) {540OS << "Error creating IR2Vec embeddings \n";541continue;542}543544OS << "IR2Vec embeddings for function " << F.getName() << ":\n";545OS << "Function vector: ";546Emb->getFunctionVector().print(OS);547548OS << "Basic block vectors:\n";549const auto &BBMap = Emb->getBBVecMap();550for (const BasicBlock &BB : F) {551auto It = BBMap.find(&BB);552if (It != BBMap.end()) {553OS << "Basic block: " << BB.getName() << ":\n";554It->second.print(OS);555}556}557558OS << "Instruction vectors:\n";559const auto &InstMap = Emb->getInstVecMap();560for (const BasicBlock &BB : F) {561for (const Instruction &I : BB) {562auto It = InstMap.find(&I);563if (It != InstMap.end()) {564OS << "Instruction: ";565I.print(OS);566It->second.print(OS);567}568}569}570}571return PreservedAnalyses::all();572}573574PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,575ModuleAnalysisManager &MAM) {576auto IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);577assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid");578579// Print each entry580unsigned Pos = 0;581for (const auto &Entry : IR2VecVocabulary) {582OS << "Key: " << IR2VecVocabulary.getStringKey(Pos++) << ": ";583Entry.print(OS);584}585return PreservedAnalyses::all();586}587588589