Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Roblox
GitHub Repository: Roblox/luau
Path: blob/master/Analysis/src/AstQuery.cpp
2725 views
1
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
2
#include "Luau/AstQuery.h"

#include "Luau/Frontend.h"
#include "Luau/Module.h"
#include "Luau/Scope.h"
#include "Luau/TypeInfer.h"
#include "Luau/Type.h"
#include "Luau/ToString.h"

#include "Luau/Common.h"

#include <algorithm>
#include <utility>
14
15
namespace Luau
16
{
17
18
namespace
19
{
20
21
struct AutocompleteNodeFinder : public AstVisitor
22
{
23
const Position pos;
24
std::vector<AstNode*> ancestry;
25
26
explicit AutocompleteNodeFinder(Position pos, AstNode* root)
27
: pos(pos)
28
{
29
}
30
31
bool visit(AstExpr* expr) override
32
{
33
// If the expression size is 0 (begin == end), we don't want to include it in the ancestry
34
if (expr->location.begin <= pos && pos <= expr->location.end && expr->location.begin != expr->location.end)
35
{
36
ancestry.push_back(expr);
37
return true;
38
}
39
return false;
40
}
41
42
bool visit(AstStat* stat) override
43
{
44
// Consider 'local myLocal = 4;|' and 'local myLocal = 4', where '|' is the cursor position. In both cases, the cursor position is equal
45
// to `AstStatLocal.location.end`. However, in the first case (semicolon), we are starting a new statement, whilst in the second case
46
// (no semicolon) we are still part of the AstStatLocal, hence the different comparison check.
47
if (stat->location.begin < pos && (stat->hasSemicolon ? pos < stat->location.end : pos <= stat->location.end))
48
{
49
ancestry.push_back(stat);
50
return true;
51
}
52
53
return false;
54
}
55
56
bool visit(AstType* type) override
57
{
58
if (type->location.begin < pos && pos <= type->location.end)
59
{
60
ancestry.push_back(type);
61
return true;
62
}
63
return false;
64
}
65
66
bool visit(AstTypeError* type) override
67
{
68
// For a missing type, match the whole range including the start position
69
if (type->isMissing && type->location.containsClosed(pos))
70
{
71
ancestry.push_back(type);
72
return true;
73
}
74
return false;
75
}
76
77
bool visit(class AstTypePack* typePack) override
78
{
79
return true;
80
}
81
82
bool visit(AstStatBlock* block) override
83
{
84
// If ancestry is empty, we are inspecting the root of the AST. Its extent is considered to be infinite.
85
if (ancestry.empty())
86
{
87
ancestry.push_back(block);
88
return true;
89
}
90
91
// AstExprIndexName nodes are nested outside-in, so we want the outermost node in the case of nested nodes.
92
// ex foo.bar.baz is represented in the AST as IndexName{ IndexName {foo, bar}, baz}
93
if (!ancestry.empty() && ancestry.back()->is<AstExprIndexName>())
94
return false;
95
96
// Type annotation error might intersect the block statement when the function header is being written,
97
// annotation takes priority
98
if (!ancestry.empty() && ancestry.back()->is<AstTypeError>())
99
return false;
100
101
// If the cursor is at the end of an expression or type and simultaneously at the beginning of a block,
102
// the expression or type wins out.
103
// The exception to this is if we are in a block under an AstExprFunction. In this case, we consider the position to
104
// be within the block.
105
if (block->location.begin == pos && !ancestry.empty())
106
{
107
if (ancestry.back()->asExpr() && !ancestry.back()->is<AstExprFunction>())
108
return false;
109
110
if (ancestry.back()->asType())
111
return false;
112
}
113
114
if (block->location.begin <= pos && pos <= block->location.end)
115
{
116
ancestry.push_back(block);
117
return true;
118
}
119
return false;
120
}
121
};
122
123
struct FindNode : public AstVisitor
124
{
125
const Position pos;
126
const Position documentEnd;
127
AstNode* best = nullptr;
128
129
explicit FindNode(Position pos, Position documentEnd)
130
: pos(pos)
131
, documentEnd(documentEnd)
132
{
133
}
134
135
bool visit(AstNode* node) override
136
{
137
if (node->location.contains(pos))
138
{
139
best = node;
140
return true;
141
}
142
143
// Edge case: If we ask for the node at the position that is the very end of the document
144
// return the innermost AST element that ends at that position.
145
146
if (node->location.end == documentEnd && pos >= documentEnd)
147
{
148
best = node;
149
return true;
150
}
151
152
return false;
153
}
154
155
bool visit(AstStatFunction* node) override
156
{
157
visit(static_cast<AstNode*>(node));
158
if (node->name->location.contains(pos))
159
node->name->visit(this);
160
else if (node->func->location.contains(pos))
161
node->func->visit(this);
162
return false;
163
}
164
165
bool visit(AstStatBlock* block) override
166
{
167
visit(static_cast<AstNode*>(block));
168
169
for (AstStat* stat : block->body)
170
{
171
if (stat->location.end < pos)
172
continue;
173
if (stat->location.begin > pos)
174
break;
175
176
stat->visit(this);
177
}
178
179
return false;
180
}
181
};
182
183
} // namespace
184
185
FindFullAncestry::FindFullAncestry(Position pos, Position documentEnd, bool includeTypes)
186
: pos(pos)
187
, documentEnd(documentEnd)
188
, includeTypes(includeTypes)
189
{
190
}
191
192
// Type nodes take part in the ancestry (and are descended into) only when `includeTypes` was requested.
bool FindFullAncestry::visit(AstType* type)
{
    return includeTypes && visit(static_cast<AstNode*>(type));
}
199
200
// Records the statement itself, then descends only into the part (name or function body) containing `pos`.
bool FindFullAncestry::visit(AstStatFunction* node)
{
    visit(static_cast<AstNode*>(node));

    AstNode* part = nullptr;
    if (node->name->location.contains(pos))
        part = node->name;
    else if (node->func->location.contains(pos))
        part = node->func;

    if (part)
        part->visit(this);

    return false;
}
209
210
// Appends `node` to `nodes` (and descends) when its location contains `pos`.
// Edge case: if we ask for the node at the position that is the very end of the document, any AST element
// that ends at that position also matches, so the innermost such element ends up last in `nodes`.
bool FindFullAncestry::visit(AstNode* node)
{
    const Location& span = node->location;
    const bool matches = span.contains(pos) || (span.end == documentEnd && pos >= documentEnd);

    if (matches)
        nodes.push_back(node);

    return matches;
}
229
230
std::vector<AstNode*> findAncestryAtPositionForAutocomplete(const SourceModule& source, Position pos)
231
{
232
return findAncestryAtPositionForAutocomplete(source.root, pos);
233
}
234
235
// Returns the autocomplete-relevant chain of nodes enclosing `pos`, ordered from the root block outwards-in.
std::vector<AstNode*> findAncestryAtPositionForAutocomplete(AstStatBlock* root, Position pos)
{
    AutocompleteNodeFinder finder{pos, root};
    root->visit(&finder);

    // `finder.ancestry` is a member of a local object, so it is not implicitly moved on return;
    // move it explicitly to avoid copying the whole vector.
    return std::move(finder.ancestry);
}
241
242
std::vector<AstNode*> findAstAncestryOfPosition(const SourceModule& source, Position pos, bool includeTypes)
243
{
244
return findAstAncestryOfPosition(source.root, pos, includeTypes);
245
}
246
247
// Returns every node whose location contains `pos`, from the root block down to the innermost node.
// Positions past the end of the document are clamped to the document end.
std::vector<AstNode*> findAstAncestryOfPosition(AstStatBlock* root, Position pos, bool includeTypes)
{
    const Position end = root->location.end;
    if (pos > end)
        pos = end;

    FindFullAncestry finder(pos, end, includeTypes);
    root->visit(&finder);

    // `finder.nodes` is a member of a local object, so it is not implicitly moved on return;
    // move it explicitly to avoid copying the whole vector.
    return std::move(finder.nodes);
}
257
258
// Convenience overload: finds the innermost node at `pos` starting from the module's root block.
AstNode* findNodeAtPosition(const SourceModule& source, Position pos)
{
    AstStatBlock* rootBlock = source.root;
    return findNodeAtPosition(rootBlock, pos);
}
262
263
// Returns the innermost node containing `pos`, or nullptr if none matches.
// A position before the root block yields the root itself; a position past its end is clamped to the end.
AstNode* findNodeAtPosition(AstStatBlock* root, Position pos)
{
    const Position documentBegin = root->location.begin;
    const Position documentEnd = root->location.end;

    if (pos < documentBegin)
        return root;

    const Position clamped = pos > documentEnd ? documentEnd : pos;

    FindNode finder{clamped, documentEnd};
    finder.visit(root);
    return finder.best;
}
276
277
// Returns the innermost node at `pos` if it is an expression, otherwise nullptr.
AstExpr* findExprAtPosition(const SourceModule& source, Position pos)
{
    AstNode* found = findNodeAtPosition(source, pos);
    return found ? found->asExpr() : nullptr;
}
285
286
// Returns the scope recorded in `module.scopes` whose location is the narrowest one containing `pos`.
// Falls back to the first recorded scope when no location contains `pos`, and to nullptr when there are no scopes.
ScopePtr findScopeAtPosition(const Module& module, Position pos)
{
    if (module.scopes.empty())
        return nullptr;

    Location bestLocation = module.scopes.front().first;
    ScopePtr bestScope = module.scopes.front().second;

    for (const auto& [location, candidate] : module.scopes)
    {
        if (!location.contains(pos))
            continue;

        // Only replace the current best when it is unset or the candidate lies inside it.
        if (bestScope && !bestLocation.encloses(location))
            continue;

        bestLocation = location;
        bestScope = candidate;
    }

    return bestScope;
}
306
307
std::optional<TypeId> findTypeAtPosition(const Module& module, const SourceModule& sourceModule, Position pos)
308
{
309
if (auto expr = findExprAtPosition(sourceModule, pos))
310
{
311
if (auto it = module.astTypes.find(expr))
312
return *it;
313
}
314
315
return std::nullopt;
316
}
317
318
std::optional<TypeId> findExpectedTypeAtPosition(const Module& module, const SourceModule& sourceModule, Position pos)
319
{
320
if (auto expr = findExprAtPosition(sourceModule, pos))
321
{
322
if (auto it = module.astExpectedTypes.find(expr))
323
return *it;
324
}
325
326
return std::nullopt;
327
}
328
329
static std::optional<AstStatLocal*> findBindingLocalStatement(const SourceModule& source, const Binding& binding)
330
{
331
// Bindings coming from global sources (e.g., definition files) have a zero position.
332
// They cannot be defined from a local statement
333
if (binding.location == Location{{0, 0}, {0, 0}})
334
return std::nullopt;
335
336
std::vector<AstNode*> nodes = findAstAncestryOfPosition(source, binding.location.begin);
337
auto iter = std::find_if(
338
nodes.rbegin(),
339
nodes.rend(),
340
[](AstNode* node)
341
{
342
return node->is<AstStatLocal>();
343
}
344
);
345
return iter != nodes.rend() ? std::make_optional((*iter)->as<AstStatLocal>()) : std::nullopt;
346
}
347
348
std::optional<Binding> findBindingAtPosition(const Module& module, const SourceModule& source, Position pos)
349
{
350
ExprOrLocal exprOrLocal = findExprOrLocalAtPosition(source, pos);
351
352
Symbol name;
353
if (auto expr = exprOrLocal.getExpr())
354
{
355
if (auto g = expr->as<AstExprGlobal>())
356
name = g->name;
357
else if (auto l = expr->as<AstExprLocal>())
358
name = l->local;
359
else
360
return std::nullopt;
361
}
362
else if (auto local = exprOrLocal.getLocal())
363
name = local;
364
else
365
return std::nullopt;
366
367
ScopePtr currentScope = findScopeAtPosition(module, pos);
368
369
while (currentScope)
370
{
371
auto iter = currentScope->bindings.find(name);
372
if (iter != currentScope->bindings.end() && iter->second.location.begin <= pos)
373
{
374
// Ignore this binding if we're inside its definition. e.g. local abc = abc -- Will take the definition of abc from outer scope
375
std::optional<AstStatLocal*> bindingStatement = findBindingLocalStatement(source, iter->second);
376
if (!bindingStatement || !(*bindingStatement)->location.contains(pos))
377
return iter->second;
378
}
379
currentScope = currentScope->parent;
380
}
381
382
return std::nullopt;
383
}
384
385
namespace
386
{
387
struct FindExprOrLocal : public AstVisitor
388
{
389
const Position pos;
390
ExprOrLocal result;
391
392
explicit FindExprOrLocal(Position pos)
393
: pos(pos)
394
{
395
}
396
397
// We want to find the result with the smallest location range.
398
bool isCloserMatch(Location newLocation)
399
{
400
auto current = result.getLocation();
401
return newLocation.contains(pos) && (!current || current->encloses(newLocation));
402
}
403
404
bool visit(AstStatBlock* block) override
405
{
406
for (AstStat* stat : block->body)
407
{
408
if (stat->location.end <= pos)
409
continue;
410
if (stat->location.begin > pos)
411
break;
412
413
stat->visit(this);
414
}
415
416
return false;
417
}
418
419
bool visit(AstExpr* expr) override
420
{
421
if (isCloserMatch(expr->location))
422
{
423
result.setExpr(expr);
424
return true;
425
}
426
return false;
427
}
428
429
bool visitLocal(AstLocal* local)
430
{
431
if (isCloserMatch(local->location))
432
{
433
result.setLocal(local);
434
return true;
435
}
436
return false;
437
}
438
439
bool visit(AstStatLocalFunction* function) override
440
{
441
visitLocal(function->name);
442
return true;
443
}
444
445
bool visit(AstStatLocal* al) override
446
{
447
for (size_t i = 0; i < al->vars.size; ++i)
448
{
449
visitLocal(al->vars.data[i]);
450
}
451
return true;
452
}
453
454
bool visit(AstExprFunction* fn) override
455
{
456
for (size_t i = 0; i < fn->args.size; ++i)
457
{
458
visitLocal(fn->args.data[i]);
459
}
460
return visit((class AstExpr*)fn);
461
}
462
463
bool visit(AstStatFor* forStat) override
464
{
465
visitLocal(forStat->var);
466
return true;
467
}
468
469
bool visit(AstStatForIn* forIn) override
470
{
471
for (AstLocal* var : forIn->vars)
472
{
473
visitLocal(var);
474
}
475
return true;
476
}
477
};
478
}; // namespace
479
480
// Returns the expression or local declaration with the smallest location containing `pos` in the module.
ExprOrLocal findExprOrLocalAtPosition(const SourceModule& source, Position pos)
{
    FindExprOrLocal finder{pos};
    AstStatBlock* rootBlock = source.root;
    finder.visit(rootBlock);
    return finder.result;
}
486
487
static std::optional<DocumentationSymbol> checkOverloadedDocumentationSymbol(
488
const Module& module,
489
const TypeId ty,
490
const AstExpr* parentExpr,
491
std::optional<DocumentationSymbol> documentationSymbol
492
)
493
{
494
if (!documentationSymbol)
495
return std::nullopt;
496
497
// This might be an overloaded function.
498
if (get<IntersectionType>(follow(ty)))
499
{
500
TypeId matchingOverload = nullptr;
501
if (parentExpr && parentExpr->is<AstExprCall>())
502
{
503
if (auto it = module.astOverloadResolvedTypes.find(parentExpr))
504
{
505
matchingOverload = *it;
506
}
507
}
508
509
if (matchingOverload)
510
{
511
std::string overloadSymbol = *documentationSymbol + "/overload/";
512
// Default toString options are fine for this purpose.
513
overloadSymbol += toString(matchingOverload);
514
return overloadSymbol;
515
}
516
}
517
518
return documentationSymbol;
519
}
520
521
static std::optional<DocumentationSymbol> getMetatableDocumentation(
522
const Module& module,
523
AstExpr* parentExpr,
524
const TableType* mtable,
525
const AstName& index
526
)
527
{
528
auto indexIt = mtable->props.find("__index");
529
if (indexIt == mtable->props.end())
530
return std::nullopt;
531
532
TypeId followed;
533
if (indexIt->second.readTy)
534
followed = follow(*indexIt->second.readTy);
535
else if (indexIt->second.writeTy)
536
followed = follow(*indexIt->second.writeTy);
537
else
538
return std::nullopt;
539
540
const TableType* ttv = get<TableType>(followed);
541
if (!ttv)
542
return std::nullopt;
543
544
auto propIt = ttv->props.find(index.value);
545
if (propIt == ttv->props.end())
546
return std::nullopt;
547
548
if (auto ty = propIt->second.readTy)
549
return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol);
550
551
return std::nullopt;
552
}
553
554
std::optional<DocumentationSymbol> getDocumentationSymbolAtPosition(const SourceModule& source, const Module& module, Position position)
555
{
556
std::vector<AstNode*> ancestry = findAstAncestryOfPosition(source, position);
557
558
AstExpr* targetExpr = ancestry.size() >= 1 ? ancestry[ancestry.size() - 1]->asExpr() : nullptr;
559
AstExpr* parentExpr = ancestry.size() >= 2 ? ancestry[ancestry.size() - 2]->asExpr() : nullptr;
560
561
if (targetExpr)
562
{
563
if (AstExprIndexName* indexName = targetExpr->as<AstExprIndexName>())
564
{
565
if (auto it = module.astTypes.find(indexName->expr))
566
{
567
TypeId parentTy = follow(*it);
568
if (const TableType* ttv = get<TableType>(parentTy))
569
{
570
if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end())
571
{
572
if (auto ty = propIt->second.readTy)
573
return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol);
574
}
575
}
576
else if (const ExternType* etv = get<ExternType>(parentTy))
577
{
578
while (etv)
579
{
580
if (auto propIt = etv->props.find(indexName->index.value); propIt != etv->props.end())
581
{
582
583
if (auto ty = propIt->second.readTy)
584
return checkOverloadedDocumentationSymbol(module, *ty, parentExpr, propIt->second.documentationSymbol);
585
}
586
etv = etv->parent ? Luau::get<Luau::ExternType>(*etv->parent) : nullptr;
587
}
588
}
589
else if (const PrimitiveType* ptv = get<PrimitiveType>(parentTy); ptv && ptv->metatable)
590
{
591
if (auto mtable = get<TableType>(*ptv->metatable))
592
{
593
if (std::optional<std::string> docSymbol = getMetatableDocumentation(module, parentExpr, mtable, indexName->index))
594
return docSymbol;
595
}
596
}
597
}
598
}
599
else if (AstExprFunction* fn = targetExpr->as<AstExprFunction>())
600
{
601
// Handle event connection-like structures where we have
602
// something:Connect(function(a, b, c) end)
603
// In this case, we want to ascribe a documentation symbol to 'a'
604
// based on the documentation symbol of Connect.
605
if (parentExpr && parentExpr->is<AstExprCall>())
606
{
607
AstExprCall* call = parentExpr->as<AstExprCall>();
608
if (std::optional<DocumentationSymbol> parentSymbol = getDocumentationSymbolAtPosition(source, module, call->func->location.begin))
609
{
610
for (size_t i = 0; i < call->args.size; ++i)
611
{
612
AstExpr* callArg = call->args.data[i];
613
if (callArg == targetExpr)
614
{
615
std::string fnSymbol = *parentSymbol + "/param/" + std::to_string(i);
616
for (size_t j = 0; j < fn->args.size; ++j)
617
{
618
AstLocal* fnArg = fn->args.data[j];
619
620
if (fnArg->location.contains(position))
621
{
622
return fnSymbol + "/param/" + std::to_string(j);
623
}
624
}
625
}
626
}
627
}
628
}
629
}
630
}
631
632
if (std::optional<Binding> binding = findBindingAtPosition(module, source, position))
633
return checkOverloadedDocumentationSymbol(module, binding->typeId, parentExpr, binding->documentationSymbol);
634
635
if (std::optional<TypeId> ty = findTypeAtPosition(module, source, position))
636
{
637
if ((*ty)->documentationSymbol)
638
{
639
return (*ty)->documentationSymbol;
640
}
641
}
642
643
return std::nullopt;
644
}
645
646
} // namespace Luau
647
648