Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/contrib/llvm-project/llvm/lib/Analysis/IR2Vec.cpp
213766 views
1
//===- IR2Vec.cpp - Implementation of IR2Vec -----------------------------===//
2
//
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4
// Exceptions. See the LICENSE file for license information.
5
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
//
7
//===----------------------------------------------------------------------===//
8
///
9
/// \file
10
/// This file implements the IR2Vec algorithm.
11
///
12
//===----------------------------------------------------------------------===//
13
14
#include "llvm/Analysis/IR2Vec.h"
15
16
#include "llvm/ADT/DepthFirstIterator.h"
17
#include "llvm/ADT/Sequence.h"
18
#include "llvm/ADT/Statistic.h"
19
#include "llvm/IR/CFG.h"
20
#include "llvm/IR/Module.h"
21
#include "llvm/IR/PassManager.h"
22
#include "llvm/Support/Debug.h"
23
#include "llvm/Support/Errc.h"
24
#include "llvm/Support/Error.h"
25
#include "llvm/Support/ErrorHandling.h"
26
#include "llvm/Support/Format.h"
27
#include "llvm/Support/MemoryBuffer.h"
28
29
using namespace llvm;
30
using namespace ir2vec;
31
32
#define DEBUG_TYPE "ir2vec"
33
34
// Counts vocabulary lookups that fell back to a zero vector because the
// entity had no entry in the loaded vocabulary file.
STATISTIC(VocabMissCounter,
          "Number of lookups to entities not present in the vocabulary");
36
37
namespace llvm {
38
namespace ir2vec {
39
static cl::OptionCategory IR2VecCategory("IR2Vec Options");
40
41
// FIXME: Use a default vocab when not specified
42
static cl::opt<std::string>
43
VocabFile("ir2vec-vocab-path", cl::Optional,
44
cl::desc("Path to the vocabulary file for IR2Vec"), cl::init(""),
45
cl::cat(IR2VecCategory));
46
cl::opt<float> OpcWeight("ir2vec-opc-weight", cl::Optional, cl::init(1.0),
47
cl::desc("Weight for opcode embeddings"),
48
cl::cat(IR2VecCategory));
49
cl::opt<float> TypeWeight("ir2vec-type-weight", cl::Optional, cl::init(0.5),
50
cl::desc("Weight for type embeddings"),
51
cl::cat(IR2VecCategory));
52
cl::opt<float> ArgWeight("ir2vec-arg-weight", cl::Optional, cl::init(0.2),
53
cl::desc("Weight for argument embeddings"),
54
cl::cat(IR2VecCategory));
55
} // namespace ir2vec
56
} // namespace llvm
57
58
// Out-of-line definition of IR2VecVocabAnalysis's static AnalysisKey, which
// the new pass manager uses to identify this analysis.
AnalysisKey IR2VecVocabAnalysis::Key;
59
60
//===----------------------------------------------------------------------===//
61
// Local helper functions
62
//===----------------------------------------------------------------------===//
63
namespace llvm::json {
64
inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
65
llvm::json::Path P) {
66
std::vector<double> TempOut;
67
if (!llvm::json::fromJSON(E, TempOut, P))
68
return false;
69
Out = Embedding(std::move(TempOut));
70
return true;
71
}
72
} // namespace llvm::json
73
74
//===----------------------------------------------------------------------===//
75
// Embedding
76
//===----------------------------------------------------------------------===//
77
/// Element-wise in-place addition; both vectors must have the same size.
Embedding &Embedding::operator+=(const Embedding &RHS) {
  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
  for (size_t I = 0, E = size(); I != E; ++I)
    (*this)[I] += RHS[I];
  return *this;
}

/// Element-wise sum of this vector and \p RHS, returned as a new Embedding.
Embedding Embedding::operator+(const Embedding &RHS) const {
  Embedding Sum = *this;
  Sum += RHS;
  return Sum;
}
89
90
/// Element-wise in-place subtraction; both vectors must have the same size.
Embedding &Embedding::operator-=(const Embedding &RHS) {
  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
  for (size_t I = 0, E = size(); I != E; ++I)
    (*this)[I] -= RHS[I];
  return *this;
}

/// Element-wise difference (this - \p RHS), returned as a new Embedding.
Embedding Embedding::operator-(const Embedding &RHS) const {
  Embedding Difference = *this;
  Difference -= RHS;
  return Difference;
}
102
103
/// Multiply every component by \p Factor in place.
Embedding &Embedding::operator*=(double Factor) {
  for (size_t I = 0, E = size(); I != E; ++I)
    (*this)[I] *= Factor;
  return *this;
}

/// Return a copy of this vector with every component multiplied by \p Factor.
Embedding Embedding::operator*(double Factor) const {
  Embedding Scaled = *this;
  Scaled *= Factor;
  return Scaled;
}
114
115
/// Accumulate \p Src scaled by \p Factor into this vector (this += Src * F).
/// Both vectors must have the same size.
Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) {
  assert(this->size() == Src.size() && "Vectors must have the same dimension");
  const size_t N = size();
  for (size_t I = 0; I != N; ++I)
    (*this)[I] += Src[I] * Factor;
  return *this;
}
121
122
bool Embedding::approximatelyEquals(const Embedding &RHS,
123
double Tolerance) const {
124
assert(this->size() == RHS.size() && "Vectors must have the same dimension");
125
for (size_t Itr = 0; Itr < this->size(); ++Itr)
126
if (std::abs((*this)[Itr] - RHS[Itr]) > Tolerance)
127
return false;
128
return true;
129
}
130
131
/// Print the components as " [ a  b  ... ]" followed by a newline, each with
/// two decimal places.
void Embedding::print(raw_ostream &OS) const {
  OS << " [";
  for (const auto &Component : Data)
    OS << " " << format("%.2f", Component) << " ";
  OS << "]\n";
}
137
138
//===----------------------------------------------------------------------===//
139
// Embedder and its subclasses
140
//===----------------------------------------------------------------------===//
141
142
/// Bind the embedder to \p F and \p Vocab. The dimension is taken from the
/// vocabulary and the entity weights from the -ir2vec-*-weight options.
Embedder::Embedder(const Function &F, const Vocabulary &Vocab)
    : F(F), Vocab(Vocab), Dimension(Vocab.getDimension()),
      OpcWeight(llvm::ir2vec::OpcWeight), TypeWeight(llvm::ir2vec::TypeWeight),
      ArgWeight(llvm::ir2vec::ArgWeight) {}
146
147
/// Factory for the embedder matching \p Mode. Returns nullptr if \p Mode is
/// not handled by the switch below.
std::unique_ptr<Embedder> Embedder::create(IR2VecKind Mode, const Function &F,
                                           const Vocabulary &Vocab) {
  std::unique_ptr<Embedder> Result;
  switch (Mode) {
  case IR2VecKind::Symbolic:
    Result = std::make_unique<SymbolicEmbedder>(F, Vocab);
    break;
  }
  return Result;
}
155
156
/// Instruction -> embedding map; computed on first request while it is empty.
const InstEmbeddingsMap &Embedder::getInstVecMap() const {
  if (!InstVecMap.empty())
    return InstVecMap;
  computeEmbeddings();
  return InstVecMap;
}
161
162
/// Basic block -> embedding map; computed on first request while it is empty.
const BBEmbeddingsMap &Embedder::getBBVecMap() const {
  if (!BBVecMap.empty())
    return BBVecMap;
  computeEmbeddings();
  return BBVecMap;
}
167
168
/// Embedding of \p BB: served from the cache when present, otherwise
/// computed for this single block and cached.
const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
  if (auto Cached = BBVecMap.find(&BB); Cached != BBVecMap.end())
    return Cached->second;
  computeEmbeddings(BB);
  return BBVecMap[&BB];
}
175
176
/// Embedding of the whole function. Not cached: every call (re)computes the
/// embeddings before returning FuncVector.
const Embedding &Embedder::getFunctionVector() const {
  computeEmbeddings();
  return FuncVector;
}
182
183
void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
184
Embedding BBVector(Dimension, 0);
185
186
// We consider only the non-debug and non-pseudo instructions
187
for (const auto &I : BB.instructionsWithoutDebug()) {
188
Embedding ArgEmb(Dimension, 0);
189
for (const auto &Op : I.operands())
190
ArgEmb += Vocab[Op];
191
auto InstVector =
192
Vocab[I.getOpcode()] + Vocab[I.getType()->getTypeID()] + ArgEmb;
193
InstVecMap[&I] = InstVector;
194
BBVector += InstVector;
195
}
196
BBVecMap[&BB] = BBVector;
197
}
198
199
void SymbolicEmbedder::computeEmbeddings() const {
200
if (F.isDeclaration())
201
return;
202
203
// Consider only the basic blocks that are reachable from entry
204
for (const BasicBlock *BB : depth_first(&F)) {
205
computeEmbeddings(*BB);
206
FuncVector += BBVecMap[BB];
207
}
208
}
209
210
//===----------------------------------------------------------------------===//
211
// Vocabulary
212
//===----------------------------------------------------------------------===//
213
214
Vocabulary::Vocabulary(VocabVector &&Vocab)
215
: Vocab(std::move(Vocab)), Valid(true) {}
216
217
bool Vocabulary::isValid() const {
218
return Vocab.size() == (MaxOpcodes + MaxTypeIDs + MaxOperandKinds) && Valid;
219
}
220
221
size_t Vocabulary::size() const {
222
assert(Valid && "IR2Vec Vocabulary is invalid");
223
return Vocab.size();
224
}
225
226
unsigned Vocabulary::getDimension() const {
227
assert(Valid && "IR2Vec Vocabulary is invalid");
228
return Vocab[0].size();
229
}
230
231
/// Embedding for an instruction opcode. Opcodes are 1-based, so opcode N
/// lives in slot N - 1 at the start of the vocabulary.
const Embedding &Vocabulary::operator[](unsigned Opcode) const {
  assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
  const unsigned Slot = Opcode - 1;
  return Vocab[Slot];
}
235
236
/// Embedding for a type ID; type entries follow the MaxOpcodes opcode slots.
const Embedding &Vocabulary::operator[](Type::TypeID TypeId) const {
  const auto TypeIdx = static_cast<unsigned>(TypeId);
  assert(TypeIdx < MaxTypeIDs && "Invalid type ID");
  return Vocab[MaxOpcodes + TypeIdx];
}
240
241
/// Embedding for an operand, selected by its OperandKind classification;
/// operand-kind entries follow the opcode and type-ID slots.
const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const {
  const unsigned SectionStart = MaxOpcodes + MaxTypeIDs;
  return Vocab[SectionStart + static_cast<unsigned>(getOperandKind(Arg))];
}
245
246
/// Vocabulary key for \p Opcode: the opcode's name as spelled in
/// llvm/IR/Instruction.def, or "UnknownOpcode" if no entry matches.
StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) {
  assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode");
  switch (Opcode) {
    // One case label per instruction listed in Instruction.def.
#define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
  case NUM:                                                                    \
    return #OPCODE;
#include "llvm/IR/Instruction.def"
#undef HANDLE_INST
  default:
    break;
  }
  return "UnknownOpcode";
}
256
257
/// Vocabulary key for an LLVM type ID. Several IDs share one key: all
/// floating-point kinds map to "FloatTy", both pointer kinds to "PointerTy",
/// both vector kinds to "VectorTy"; AMX and target-extension types, like any
/// unmatched value, map to "UnknownTy".
StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) {
  switch (TypeID) {
  case Type::IntegerTyID:
    return "IntegerTy";
  case Type::HalfTyID:
  case Type::BFloatTyID:
  case Type::FloatTyID:
  case Type::DoubleTyID:
  case Type::X86_FP80TyID:
  case Type::FP128TyID:
  case Type::PPC_FP128TyID:
    return "FloatTy";
  case Type::PointerTyID:
  case Type::TypedPointerTyID:
    return "PointerTy";
  case Type::FixedVectorTyID:
  case Type::ScalableVectorTyID:
    return "VectorTy";
  case Type::StructTyID:
    return "StructTy";
  case Type::ArrayTyID:
    return "ArrayTy";
  case Type::FunctionTyID:
    return "FunctionTy";
  case Type::VoidTyID:
    return "VoidTy";
  case Type::LabelTyID:
    return "LabelTy";
  case Type::TokenTyID:
    return "TokenTy";
  case Type::MetadataTyID:
    return "MetadataTy";
  case Type::X86_AMXTyID:
  case Type::TargetExtTyID:
    return "UnknownTy";
  }
  // Reached only for values outside the enumerators handled above.
  return "UnknownTy";
}
295
296
/// Vocabulary key for an operand kind, looked up in OperandKindNames.
StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) {
  const auto Idx = static_cast<unsigned>(Kind);
  assert(Idx < MaxOperandKinds && "Invalid OperandKind");
  return OperandKindNames[Idx];
}
301
302
/// Build a test vocabulary with one entry per opcode, type ID and operand
/// kind. Every component of the first entry is 0.1f, and each later entry's
/// fill value is the previous one plus 0.1 (accumulated in a float).
Vocabulary::VocabVector Vocabulary::createDummyVocabForTest(unsigned Dim) {
  const unsigned NumEntries = Vocabulary::MaxOpcodes + Vocabulary::MaxTypeIDs +
                              Vocabulary::MaxOperandKinds;
  VocabVector DummyVocab;
  float FillValue = 0.1f;
  for (unsigned Idx = 0; Idx < NumEntries; ++Idx) {
    DummyVocab.push_back(Embedding(Dim, FillValue));
    FillValue += 0.1;
  }
  return DummyVocab;
}
314
315
// Helper function to classify an operand into OperandKind
316
Vocabulary::OperandKind Vocabulary::getOperandKind(const Value *Op) {
317
if (isa<Function>(Op))
318
return OperandKind::FunctionID;
319
if (isa<PointerType>(Op->getType()))
320
return OperandKind::PointerID;
321
if (isa<Constant>(Op))
322
return OperandKind::ConstantID;
323
return OperandKind::VariableID;
324
}
325
326
/// Map a flat vocabulary position to its string key. The vocabulary is laid
/// out as [opcodes | type IDs | operand kinds].
StringRef Vocabulary::getStringKey(unsigned Pos) {
  assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds &&
         "Position out of bounds in vocabulary");
  const unsigned TypeStart = MaxOpcodes;
  const unsigned OperandStart = MaxOpcodes + MaxTypeIDs;

  if (Pos >= OperandStart)
    return getVocabKeyForOperandKind(
        static_cast<OperandKind>(Pos - OperandStart));
  if (Pos >= TypeStart)
    return getVocabKeyForTypeID(static_cast<Type::TypeID>(Pos - TypeStart));
  // Positions are 0-based while opcodes start at 1.
  return getVocabKeyForOpcode(Pos + 1);
}
339
340
// For now, assume vocabulary is stable unless explicitly invalidated.
341
bool Vocabulary::invalidate(Module &M, const PreservedAnalyses &PA,
342
ModuleAnalysisManager::Invalidator &Inv) const {
343
auto PAC = PA.getChecker<IR2VecVocabAnalysis>();
344
return !(PAC.preservedWhenStateless());
345
}
346
347
//===----------------------------------------------------------------------===//
348
// IR2VecVocabAnalysis
349
//===----------------------------------------------------------------------===//
350
351
/// Parse section \p Key of the vocabulary JSON into \p TargetVocab and store
/// the common embedding dimension of its entries in \p Dim.
/// Returns an error if the root is not an object, if the section is missing,
/// malformed or empty, if its dimension is zero, or if its entries disagree
/// on dimension.
Error IR2VecVocabAnalysis::parseVocabSection(
    StringRef Key, const json::Value &ParsedVocabValue, VocabMap &TargetVocab,
    unsigned &Dim) {
  json::Path::Root Path("");
  const json::Object *RootObj = ParsedVocabValue.getAsObject();
  if (!RootObj)
    return createStringError(errc::invalid_argument,
                             "JSON root is not an object");

  const json::Value *SectionValue = RootObj->get(Key);
  if (!SectionValue)
    return createStringError(errc::invalid_argument,
                             "Missing '" + std::string(Key) +
                                 "' section in vocabulary file");
  if (!json::fromJSON(*SectionValue, TargetVocab, Path))
    return createStringError(errc::illegal_byte_sequence,
                             "Unable to parse '" + std::string(Key) +
                                 "' section from vocabulary");

  // An empty section (e.g. "{}") has no entry to take the dimension from;
  // reject it rather than dereferencing begin() == end().
  if (TargetVocab.empty())
    return createStringError(errc::illegal_byte_sequence,
                             "'" + std::string(Key) +
                                 "' section of the vocabulary is empty");

  Dim = TargetVocab.begin()->second.size();
  if (Dim == 0)
    return createStringError(errc::illegal_byte_sequence,
                             "Dimension of '" + std::string(Key) +
                                 "' section of the vocabulary is zero");

  // Take entries by reference to the map's own value_type; a fixed pair type
  // here would force a converted copy of every Embedding.
  if (!std::all_of(TargetVocab.begin(), TargetVocab.end(),
                   [Dim](const auto &Entry) {
                     return Entry.second.size() == Dim;
                   }))
    return createStringError(
        errc::illegal_byte_sequence,
        "All vectors in the '" + std::string(Key) +
            "' section of the vocabulary are not of the same dimension");

  return Error::success();
}
387
388
// FIXME: Make this optional. We can avoid file reads
389
// by auto-generating a default vocabulary during the build time.
390
Error IR2VecVocabAnalysis::readVocabulary() {
391
auto BufOrError = MemoryBuffer::getFileOrSTDIN(VocabFile, /*IsText=*/true);
392
if (!BufOrError)
393
return createFileError(VocabFile, BufOrError.getError());
394
395
auto Content = BufOrError.get()->getBuffer();
396
397
Expected<json::Value> ParsedVocabValue = json::parse(Content);
398
if (!ParsedVocabValue)
399
return ParsedVocabValue.takeError();
400
401
unsigned OpcodeDim = 0, TypeDim = 0, ArgDim = 0;
402
if (auto Err =
403
parseVocabSection("Opcodes", *ParsedVocabValue, OpcVocab, OpcodeDim))
404
return Err;
405
406
if (auto Err =
407
parseVocabSection("Types", *ParsedVocabValue, TypeVocab, TypeDim))
408
return Err;
409
410
if (auto Err =
411
parseVocabSection("Arguments", *ParsedVocabValue, ArgVocab, ArgDim))
412
return Err;
413
414
if (!(OpcodeDim == TypeDim && TypeDim == ArgDim))
415
return createStringError(errc::illegal_byte_sequence,
416
"Vocabulary sections have different dimensions");
417
418
return Error::success();
419
}
420
421
void IR2VecVocabAnalysis::generateNumMappedVocab() {
422
423
// Helper for handling missing entities in the vocabulary.
424
// Currently, we use a zero vector. In the future, we will throw an error to
425
// ensure that *all* known entities are present in the vocabulary.
426
auto handleMissingEntity = [](const std::string &Val) {
427
LLVM_DEBUG(errs() << Val
428
<< " is not in vocabulary, using zero vector; This "
429
"would result in an error in future.\n");
430
++VocabMissCounter;
431
};
432
433
unsigned Dim = OpcVocab.begin()->second.size();
434
assert(Dim > 0 && "Vocabulary dimension must be greater than zero");
435
436
// Handle Opcodes
437
std::vector<Embedding> NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes,
438
Embedding(Dim, 0));
439
for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) {
440
StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1);
441
auto It = OpcVocab.find(VocabKey.str());
442
if (It != OpcVocab.end())
443
NumericOpcodeEmbeddings[Opcode] = It->second;
444
else
445
handleMissingEntity(VocabKey.str());
446
}
447
Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(),
448
NumericOpcodeEmbeddings.end());
449
450
// Handle Types
451
std::vector<Embedding> NumericTypeEmbeddings(Vocabulary::MaxTypeIDs,
452
Embedding(Dim, 0));
453
for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {
454
StringRef VocabKey =
455
Vocabulary::getVocabKeyForTypeID(static_cast<Type::TypeID>(TypeID));
456
if (auto It = TypeVocab.find(VocabKey.str()); It != TypeVocab.end()) {
457
NumericTypeEmbeddings[TypeID] = It->second;
458
continue;
459
}
460
handleMissingEntity(VocabKey.str());
461
}
462
Vocab.insert(Vocab.end(), NumericTypeEmbeddings.begin(),
463
NumericTypeEmbeddings.end());
464
465
// Handle Arguments/Operands
466
std::vector<Embedding> NumericArgEmbeddings(Vocabulary::MaxOperandKinds,
467
Embedding(Dim, 0));
468
for (unsigned OpKind : seq(0u, Vocabulary::MaxOperandKinds)) {
469
Vocabulary::OperandKind Kind = static_cast<Vocabulary::OperandKind>(OpKind);
470
StringRef VocabKey = Vocabulary::getVocabKeyForOperandKind(Kind);
471
auto It = ArgVocab.find(VocabKey.str());
472
if (It != ArgVocab.end()) {
473
NumericArgEmbeddings[OpKind] = It->second;
474
continue;
475
}
476
handleMissingEntity(VocabKey.str());
477
}
478
Vocab.insert(Vocab.end(), NumericArgEmbeddings.begin(),
479
NumericArgEmbeddings.end());
480
}
481
482
/// Seed the analysis with a ready-made numeric vocabulary (copied); run() then
/// uses it instead of reading a file.
IR2VecVocabAnalysis::IR2VecVocabAnalysis(const VocabVector &InitialVocab)
    : Vocab(InitialVocab) {}

/// As above, but takes ownership of \p InitialVocab.
IR2VecVocabAnalysis::IR2VecVocabAnalysis(VocabVector &&InitialVocab)
    : Vocab(std::move(InitialVocab)) {}
487
488
/// Report every error payload in \p Err to \p Ctx, prefixed with
/// "Error reading vocabulary: ".
void IR2VecVocabAnalysis::emitError(Error Err, LLVMContext &Ctx) {
  handleAllErrors(std::move(Err), [&Ctx](const ErrorInfoBase &Info) {
    Ctx.emitError("Error reading vocabulary: " + Info.message());
  });
}
493
494
/// Produce the numeric vocabulary. A vocabulary passed to the constructor is
/// used as is; otherwise the file given by -ir2vec-vocab-path is read, each
/// section is scaled by its -ir2vec-*-weight option, and the result is
/// flattened. On failure an error is emitted on the module's context and an
/// invalid Vocabulary is returned.
IR2VecVocabAnalysis::Result
IR2VecVocabAnalysis::run(Module &M, ModuleAnalysisManager &AM) {
  LLVMContext &Ctx = M.getContext();

  if (!Vocab.empty())
    return Vocabulary(std::move(Vocab));

  if (VocabFile.empty()) {
    // FIXME: Use default vocabulary
    Ctx.emitError("IR2Vec vocabulary file path not specified; You may need to "
                  "set it using --ir2vec-vocab-path");
    return Vocabulary(); // Return invalid result
  }

  if (Error Err = readVocabulary()) {
    emitError(std::move(Err), Ctx);
    return Vocabulary();
  }

  auto ApplyWeight = [](VocabMap &Section, double Weight) {
    for (auto &Entry : Section)
      Entry.second *= Weight;
  };
  ApplyWeight(OpcVocab, OpcWeight);
  ApplyWeight(TypeVocab, TypeWeight);
  ApplyWeight(ArgVocab, ArgWeight);

  generateNumMappedVocab();
  return Vocabulary(std::move(Vocab));
}
527
528
//===----------------------------------------------------------------------===//
529
// Printer Passes
530
//===----------------------------------------------------------------------===//
531
532
/// Print the function, basic-block and instruction embeddings of every
/// function in \p M to OS.
PreservedAnalyses IR2VecPrinterPass::run(Module &M,
                                         ModuleAnalysisManager &MAM) {
  // Bind to the cached analysis result instead of copying the whole
  // vocabulary table; the embedders below only keep a reference to it.
  auto &Vocab = MAM.getResult<IR2VecVocabAnalysis>(M);
  assert(Vocab.isValid() && "IR2Vec Vocabulary is invalid");
  // In release builds the assert is compiled out; an invalid (e.g. empty)
  // vocabulary would make Embedder's getDimension() index Vocab[0].
  if (!Vocab.isValid())
    return PreservedAnalyses::all();

  for (Function &F : M) {
    std::unique_ptr<Embedder> Emb =
        Embedder::create(IR2VecKind::Symbolic, F, Vocab);
    if (!Emb) {
      OS << "Error creating IR2Vec embeddings \n";
      continue;
    }

    OS << "IR2Vec embeddings for function " << F.getName() << ":\n";
    OS << "Function vector: ";
    Emb->getFunctionVector().print(OS);

    OS << "Basic block vectors:\n";
    const auto &BBMap = Emb->getBBVecMap();
    for (const BasicBlock &BB : F) {
      auto It = BBMap.find(&BB);
      if (It != BBMap.end()) {
        OS << "Basic block: " << BB.getName() << ":\n";
        It->second.print(OS);
      }
    }

    OS << "Instruction vectors:\n";
    const auto &InstMap = Emb->getInstVecMap();
    for (const BasicBlock &BB : F) {
      for (const Instruction &I : BB) {
        auto It = InstMap.find(&I);
        if (It != InstMap.end()) {
          OS << "Instruction: ";
          I.print(OS);
          It->second.print(OS);
        }
      }
    }
  }
  return PreservedAnalyses::all();
}
574
575
/// Print every vocabulary entry with its string key, in positional order.
PreservedAnalyses IR2VecVocabPrinterPass::run(Module &M,
                                              ModuleAnalysisManager &MAM) {
  // Bind by reference: copying the analysis result duplicates every embedding.
  auto &IR2VecVocabulary = MAM.getResult<IR2VecVocabAnalysis>(M);
  assert(IR2VecVocabulary.isValid() && "IR2Vec Vocabulary is invalid");

  // Print each entry
  unsigned Pos = 0;
  for (const auto &Entry : IR2VecVocabulary) {
    OS << "Key: " << Vocabulary::getStringKey(Pos++) << ": ";
    Entry.print(OS);
  }
  return PreservedAnalyses::all();
}
588
589