Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Roblox
GitHub Repository: Roblox/luau
Path: blob/master/Analysis/src/NonStrictTypeChecker.cpp
2725 views
1
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
2
#include "Luau/NonStrictTypeChecker.h"
3
4
#include "Luau/Ast.h"
5
#include "Luau/AstQuery.h"
6
#include "Luau/Common.h"
7
#include "Luau/Def.h"
8
#include "Luau/Error.h"
9
#include "Luau/Normalize.h"
10
#include "Luau/RecursionCounter.h"
11
#include "Luau/Simplify.h"
12
#include "Luau/Subtyping.h"
13
#include "Luau/TimeTrace.h"
14
#include "Luau/ToString.h"
15
#include "Luau/Type.h"
16
#include "Luau/TypeArena.h"
17
#include "Luau/TypeFunction.h"
18
#include "Luau/TypeUtils.h"
19
20
#include <iterator>
21
22
// When set, the magic type names kLuauPrint / kLuauForceConstraintSolvingIncomplete are
// special-cased by visit(AstTypeReference*).
LUAU_FASTFLAG(DebugLuauMagicTypes)

// Maximum nesting depth for the recursion-counted visitors; a value <= 0 disables the limit.
LUAU_FASTINTVARIABLE(LuauNonStrictTypeCheckerRecursionLimit, 300)
// Gates the recursion counter used by the block, expression and table visitors.
LUAU_FASTFLAGVARIABLE(LuauAddRecursionCounterToNonStrictTypeChecker)
26
27
namespace Luau
28
{
29
30
/* Push a scope onto the end of a stack for the lifetime of the StackPusher instance.
31
* NonStrictTypeChecker uses this to maintain knowledge about which scope encloses every
32
* given AstNode.
33
*/
34
struct StackPusher
35
{
36
std::vector<NotNull<Scope>>* stack;
37
NotNull<Scope> scope;
38
39
explicit StackPusher(std::vector<NotNull<Scope>>& stack, Scope* scope)
40
: stack(&stack)
41
, scope(scope)
42
{
43
stack.emplace_back(scope);
44
}
45
46
~StackPusher()
47
{
48
if (stack)
49
{
50
LUAU_ASSERT(stack->back() == scope);
51
stack->pop_back();
52
}
53
}
54
55
StackPusher(const StackPusher&) = delete;
56
StackPusher&& operator=(const StackPusher&) = delete;
57
58
StackPusher(StackPusher&& other)
59
: stack(std::exchange(other.stack, nullptr))
60
, scope(other.scope)
61
{
62
}
63
};
64
65
66
struct NonStrictContext
67
{
68
NonStrictContext() = default;
69
70
NonStrictContext(const NonStrictContext&) = delete;
71
NonStrictContext& operator=(const NonStrictContext&) = delete;
72
73
NonStrictContext(NonStrictContext&&) = default;
74
NonStrictContext& operator=(NonStrictContext&&) = default;
75
76
static NonStrictContext disjunction(
77
NotNull<BuiltinTypes> builtinTypes,
78
NotNull<TypeArena> arena,
79
const NonStrictContext& left,
80
const NonStrictContext& right
81
)
82
{
83
// disjunction implements union over the domain of keys
84
// if the default value for a defId not in the map is `never`
85
// then never | T is T
86
NonStrictContext disj{};
87
88
for (auto [def, leftTy] : left.context)
89
{
90
if (std::optional<TypeId> rightTy = right.find(def))
91
disj.context[def] = simplifyUnion(builtinTypes, arena, leftTy, *rightTy).result;
92
else
93
disj.context[def] = leftTy;
94
}
95
96
for (auto [def, rightTy] : right.context)
97
{
98
if (!left.find(def).has_value())
99
disj.context[def] = rightTy;
100
}
101
102
return disj;
103
}
104
105
static NonStrictContext conjunction(
106
NotNull<BuiltinTypes> builtins,
107
NotNull<TypeArena> arena,
108
const NonStrictContext& left,
109
const NonStrictContext& right
110
)
111
{
112
NonStrictContext conj{};
113
114
for (auto [def, leftTy] : left.context)
115
{
116
if (std::optional<TypeId> rightTy = right.find(def))
117
conj.context[def] = simplifyIntersection(builtins, arena, leftTy, *rightTy).result;
118
}
119
120
return conj;
121
}
122
123
// Returns true if the removal was successful
124
bool remove(const DefId& def)
125
{
126
std::vector<DefId> defs;
127
collectOperands(def, &defs);
128
bool result = true;
129
for (DefId def : defs)
130
result = result && context.erase(def.get()) == 1;
131
return result;
132
}
133
134
std::optional<TypeId> find(const DefId& def) const
135
{
136
const Def* d = def.get();
137
return find(d);
138
}
139
140
void addContext(const DefId& def, TypeId ty)
141
{
142
std::vector<DefId> defs;
143
collectOperands(def, &defs);
144
for (DefId def : defs)
145
context[def.get()] = ty;
146
}
147
148
private:
149
std::optional<TypeId> find(const Def* d) const
150
{
151
auto it = context.find(d);
152
if (it != context.end())
153
return {it->second};
154
return {};
155
}
156
157
std::unordered_map<const Def*, TypeId> context;
158
};
159
160
struct NonStrictTypeChecker
161
{
162
NotNull<BuiltinTypes> builtinTypes;
163
NotNull<TypeFunctionRuntime> typeFunctionRuntime;
164
const NotNull<InternalErrorReporter> ice;
165
NotNull<TypeArena> arena;
166
Module* module;
167
Normalizer normalizer;
168
Subtyping subtyping;
169
NotNull<const DataFlowGraph> dfg;
170
DenseHashSet<TypeId> noTypeFunctionErrors{nullptr};
171
std::vector<NotNull<Scope>> stack;
172
DenseHashMap<TypeId, TypeId> cachedNegations{nullptr};
173
174
const NotNull<TypeCheckLimits> limits;
175
176
NonStrictTypeChecker(
177
NotNull<TypeArena> arena,
178
NotNull<BuiltinTypes> builtinTypes,
179
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
180
const NotNull<InternalErrorReporter> ice,
181
NotNull<UnifierSharedState> unifierState,
182
NotNull<const DataFlowGraph> dfg,
183
NotNull<TypeCheckLimits> limits,
184
Module* module
185
)
186
: builtinTypes(builtinTypes)
187
, typeFunctionRuntime(typeFunctionRuntime)
188
, ice(ice)
189
, arena(arena)
190
, module(module)
191
, normalizer{arena, builtinTypes, unifierState, SolverMode::New, /* cache inhabitance */ true}
192
, subtyping{builtinTypes, arena, NotNull(&normalizer), typeFunctionRuntime, ice}
193
, dfg(dfg)
194
, limits(limits)
195
{
196
}
197
198
// Pushes the scope the module recorded for `node` onto `stack`.
// Returns a pusher that pops it again when destroyed, or nullopt if `node` has no scope.
std::optional<StackPusher> pushStack(AstNode* node)
{
    Scope** recorded = module->astScopes.find(node);
    if (!recorded)
        return std::nullopt;

    return StackPusher{stack, *recorded};
}
205
206
// Collapses a type pack to the single type it contributes as an expression value:
// - if the pack has a first element (hidden variadics included), that element;
// - a free pack is rebound in place to `(fresh, ...freeTail)` and the fresh type returned;
// - an error pack yields the error type;
// - an empty finite pack yields nil, e.g. `(f())` where `f()` returns nothing.
// Any other pack is an internal error.
TypeId flattenPack(TypePackId pack)
{
    pack = follow(pack);

    if (auto head = first(pack, /*ignoreHiddenVariadics*/ false))
        return *head;

    if (auto freePack = get<FreeTypePack>(pack))
    {
        // Read everything needed from `freePack` before the pack is overwritten below.
        TypeId fresh = arena->freshType(builtinTypes, freePack->scope);
        TypePackId freshTail = arena->addTypePack(FreeTypePack{freePack->scope});

        TypePack* bound = emplaceTypePack<TypePack>(asMutable(pack));
        bound->head.assign(1, fresh);
        bound->tail = freshTail;

        return fresh;
    }

    if (get<ErrorTypePack>(pack))
        return builtinTypes->errorType;

    if (finite(pack) && size(pack) == 0)
        return builtinTypes->nilType; // `(f())` where `f()` returns no values is coerced into `nil`

    ice->ice("flattenPack got a weird pack!");
}
230
231
232
// Reduces type functions inside `instance`, in the current innermost scope.
// Instances that reduce without errors are remembered and not reduced again.
// Reduction errors are currently dropped (see TODO). Always returns `instance` itself.
TypeId checkForTypeFunctionInhabitance(TypeId instance, Location location)
{
    if (!noTypeFunctionErrors.find(instance))
    {
        TypeFunctionContext reductionCtx{arena, builtinTypes, stack.back(), NotNull{&normalizer}, typeFunctionRuntime, ice, limits};
        ErrorVec errors = reduceTypeFunctions(instance, location, NotNull{&reductionCtx}, true).errors;

        if (errors.empty())
            noTypeFunctionErrors.insert(instance);
        // TODO??
        // if (!isErrorSuppressing(location, instance))
        //     reportErrors(std::move(errors));
    }

    return instance;
}
247
248
249
// Returns the type the module recorded for `expr`, after type-function reduction.
// A recorded type pack is first flattened to a single type. Unknown expressions are `any`.
TypeId lookupType(AstExpr* expr)
{
    if (TypeId* recordedTy = module->astTypes.find(expr))
        return checkForTypeFunctionInhabitance(follow(*recordedTy), expr->location);

    if (TypePackId* recordedPack = module->astTypePacks.find(expr))
        return checkForTypeFunctionInhabitance(flattenPack(*recordedPack), expr->location);

    return builtinTypes->anyType;
}
260
261
// Dispatches `stat` to the visitor for its concrete node type. The statement's recorded
// scope, if any, stays pushed for the duration. Unknown statement kinds are internal errors.
NonStrictContext visit(AstStat* stat)
{
    auto scopeGuard = pushStack(stat);

    if (auto* node = stat->as<AstStatBlock>())
        return visit(node);
    if (auto* node = stat->as<AstStatIf>())
        return visit(node);
    if (auto* node = stat->as<AstStatWhile>())
        return visit(node);
    if (auto* node = stat->as<AstStatRepeat>())
        return visit(node);
    if (auto* node = stat->as<AstStatBreak>())
        return visit(node);
    if (auto* node = stat->as<AstStatContinue>())
        return visit(node);
    if (auto* node = stat->as<AstStatReturn>())
        return visit(node);
    if (auto* node = stat->as<AstStatExpr>())
        return visit(node);
    if (auto* node = stat->as<AstStatLocal>())
        return visit(node);
    if (auto* node = stat->as<AstStatFor>())
        return visit(node);
    if (auto* node = stat->as<AstStatForIn>())
        return visit(node);
    if (auto* node = stat->as<AstStatAssign>())
        return visit(node);
    if (auto* node = stat->as<AstStatCompoundAssign>())
        return visit(node);
    if (auto* node = stat->as<AstStatFunction>())
        return visit(node);
    if (auto* node = stat->as<AstStatLocalFunction>())
        return visit(node);
    if (auto* node = stat->as<AstStatTypeAlias>())
        return visit(node);
    if (auto* node = stat->as<AstStatTypeFunction>())
        return visit(node);
    if (auto* node = stat->as<AstStatDeclareFunction>())
        return visit(node);
    if (auto* node = stat->as<AstStatDeclareGlobal>())
        return visit(node);
    if (auto* node = stat->as<AstStatDeclareExternType>())
        return visit(node);
    if (auto* node = stat->as<AstStatError>())
        return visit(node);

    LUAU_ASSERT(!"NonStrictTypeChecker encountered an unknown statement type");
    ice->ice("NonStrictTypeChecker encountered an unknown statement type");
}
312
313
// Visits a block's statements from last to first, threading a context backwards.
// - A `local` statement visits its initializers and annotations, then drops the declared
//   locals from the context accumulated from the statements after it.
// - Any other statement's context is merged in via disjunction.
// When the recursion-counter flag is on and the depth limit is reached, the block is
// skipped and an empty context returned.
NonStrictContext visit(AstStatBlock* block)
{
    std::optional<RecursionCounter> recursionGuard;
    if (FFlag::LuauAddRecursionCounterToNonStrictTypeChecker)
    {
        recursionGuard.emplace(&nonStrictRecursionCount);
        if (FInt::LuauNonStrictTypeCheckerRecursionLimit > 0 && nonStrictRecursionCount >= FInt::LuauNonStrictTypeCheckerRecursionLimit)
            return {};
    }

    auto scopeGuard = pushStack(block);
    NonStrictContext ctx;

    for (auto it = block->body.rbegin(); it != block->body.rend(); ++it)
    {
        AstStat* stat = *it;
        AstStatLocal* localStat = stat->as<AstStatLocal>();
        if (!localStat)
        {
            ctx = NonStrictContext::disjunction(builtinTypes, arena, visit(stat), ctx);
            continue;
        }

        // `local x ; B` generates the context of B without x.
        visit(localStat);
        for (AstLocal* var : localStat->vars)
        {
            ctx.remove(dfg->getDef(var));
            visit(var->annotation);
        }
    }

    return ctx;
}
347
348
// The condition's context is always kept. The branches contribute only what holds on both
// of them (conjunction), so with no else branch they contribute nothing. The then branch
// is still visited in that case so its diagnostics are reported.
NonStrictContext visit(AstStatIf* ifStatement)
{
    NonStrictContext conditionCtx = visit(ifStatement->condition, ValueContext::RValue);
    NonStrictContext branchesCtx;

    NonStrictContext thenCtx = visit(ifStatement->thenbody);
    if (ifStatement->elsebody)
    {
        NonStrictContext elseCtx = visit(ifStatement->elsebody);
        branchesCtx = NonStrictContext::conjunction(builtinTypes, arena, thenCtx, elseCtx);
    }

    return NonStrictContext::disjunction(builtinTypes, arena, conditionCtx, branchesCtx);
}
362
363
// Condition first, then body; both contexts are merged via disjunction.
NonStrictContext visit(AstStatWhile* whileStatement)
{
    NonStrictContext conditionCtx = visit(whileStatement->condition, ValueContext::RValue);
    NonStrictContext bodyCtx = visit(whileStatement->body);
    return NonStrictContext::disjunction(builtinTypes, arena, conditionCtx, bodyCtx);
}
369
370
// Body first (it runs before the condition), then condition; merged via disjunction.
NonStrictContext visit(AstStatRepeat* repeatStatement)
{
    NonStrictContext bodyCtx = visit(repeatStatement->body);
    NonStrictContext conditionCtx = visit(repeatStatement->condition, ValueContext::RValue);
    return NonStrictContext::disjunction(builtinTypes, arena, bodyCtx, conditionCtx);
}
376
377
// `break` has no operands and contributes an empty context.
NonStrictContext visit(AstStatBreak* breakStatement)
{
    return NonStrictContext{};
}

// `continue` has no operands and contributes an empty context.
NonStrictContext visit(AstStatContinue* continueStatement)
{
    return NonStrictContext{};
}
386
387
// Returned expressions are visited for their diagnostics; their contexts are discarded.
NonStrictContext visit(AstStatReturn* returnStatement)
{
    // TODO: this is believing existing code, but i'm not sure if this makes sense
    // for how the contexts are handled
    for (AstExpr* returned : returnStatement->list)
        visit(returned, ValueContext::RValue);

    return NonStrictContext{};
}
396
397
// An expression statement contributes exactly the context of its expression.
NonStrictContext visit(AstStatExpr* expr)
{
    AstExpr* inner = expr->expr;
    return visit(inner, ValueContext::RValue);
}
401
402
// Initializers are visited for their diagnostics; their contexts are discarded.
NonStrictContext visit(AstStatLocal* local)
{
    for (AstExpr* initializer : local->values)
        visit(initializer, ValueContext::RValue);

    return NonStrictContext{};
}
408
409
// Checks the loop variable's annotation and the present range expressions, whose contexts
// are discarded. The loop contributes the context of its body.
NonStrictContext visit(AstStatFor* forStatement)
{
    visit(forStatement->var->annotation);

    // TODO: throwing out context based on same principle as existing code?
    for (AstExpr* rangeExpr : {forStatement->from, forStatement->to, forStatement->step})
    {
        if (rangeExpr)
            visit(rangeExpr, ValueContext::RValue);
    }

    return visit(forStatement->body);
}
422
423
// Checks variable annotations and the iterated expressions, whose contexts are discarded.
// The loop contributes the context of its body.
NonStrictContext visit(AstStatForIn* forInStatement)
{
    for (AstLocal* loopVar : forInStatement->vars)
        visit(loopVar->annotation);

    for (AstExpr* iterated : forInStatement->values)
        visit(iterated, ValueContext::RValue);

    return visit(forInStatement->body);
}
432
433
// Targets are visited as LValues, then values as RValues; no context is produced.
NonStrictContext visit(AstStatAssign* assign)
{
    for (AstExpr* target : assign->vars)
        visit(target, ValueContext::LValue);

    for (AstExpr* value : assign->values)
        visit(value, ValueContext::RValue);

    return NonStrictContext{};
}
442
443
// Target is visited as an LValue, then the value as an RValue; no context is produced.
NonStrictContext visit(AstStatCompoundAssign* compoundAssign)
{
    AstExpr* target = compoundAssign->var;
    AstExpr* value = compoundAssign->value;

    visit(target, ValueContext::LValue);
    visit(value, ValueContext::RValue);

    return NonStrictContext{};
}
450
451
// `function name() ... end` contributes the context of its function expression.
NonStrictContext visit(AstStatFunction* statFn)
{
    AstExprFunction* fn = statFn->func;
    return visit(fn, ValueContext::RValue);
}

// `local function name() ... end` contributes the context of its function expression.
NonStrictContext visit(AstStatLocalFunction* localFn)
{
    AstExprFunction* fn = localFn->func;
    return visit(fn, ValueContext::RValue);
}
460
461
// Validates the alias's generic parameters and its aliased type; no context is produced.
NonStrictContext visit(AstStatTypeAlias* typeAlias)
{
    visitGenerics(typeAlias->generics, typeAlias->genericPacks);
    visit(typeAlias->type);
    return NonStrictContext{};
}
468
469
// Type function declarations are not inspected by this checker.
NonStrictContext visit(AstStatTypeFunction* typeFunc)
{
    return NonStrictContext{};
}
473
474
// Validates the declared function's generics, parameter types and return types.
NonStrictContext visit(AstStatDeclareFunction* declFn)
{
    visitGenerics(declFn->generics, declFn->genericPacks);

    visit(declFn->params);
    visit(declFn->retTypes);

    return NonStrictContext{};
}
482
483
// Validates the declared global's type annotation.
NonStrictContext visit(AstStatDeclareGlobal* declGlobal)
{
    AstType* declaredType = declGlobal->type;
    visit(declaredType);
    return NonStrictContext{};
}
489
490
// Validates the indexer types (if any) and every property type of an extern type declaration.
NonStrictContext visit(AstStatDeclareExternType* declClass)
{
    if (auto* indexer = declClass->indexer)
    {
        visit(indexer->indexType);
        visit(indexer->resultType);
    }

    for (const auto& property : declClass->props)
        visit(property.ty);

    return NonStrictContext{};
}
503
504
// Visits whatever statements and expressions the parser recovered inside an error node.
NonStrictContext visit(AstStatError* error)
{
    for (AstStat* recoveredStat : error->statements)
        visit(recoveredStat);

    for (AstExpr* recoveredExpr : error->expressions)
        visit(recoveredExpr, ValueContext::RValue);

    return NonStrictContext{};
}
513
514
// Dispatches `expr` to the visitor for its concrete node type. The expression's recorded
// scope, if any, stays pushed for the duration. `context` is forwarded to the visitors
// that distinguish reads from writes. When the recursion-counter flag is on and the depth
// limit is reached, the expression is skipped and an empty context returned.
NonStrictContext visit(AstExpr* expr, ValueContext context)
{
    std::optional<RecursionCounter> recursionGuard;
    if (FFlag::LuauAddRecursionCounterToNonStrictTypeChecker)
    {
        recursionGuard.emplace(&nonStrictRecursionCount);
        if (FInt::LuauNonStrictTypeCheckerRecursionLimit > 0 && nonStrictRecursionCount >= FInt::LuauNonStrictTypeCheckerRecursionLimit)
            return {};
    }

    auto scopeGuard = pushStack(expr);

    if (auto* node = expr->as<AstExprGroup>())
        return visit(node, context);
    if (auto* node = expr->as<AstExprConstantNil>())
        return visit(node);
    if (auto* node = expr->as<AstExprConstantBool>())
        return visit(node);
    if (auto* node = expr->as<AstExprConstantNumber>())
        return visit(node);
    if (auto* node = expr->as<AstExprConstantInteger>())
        return visit(node);
    if (auto* node = expr->as<AstExprConstantString>())
        return visit(node);
    if (auto* node = expr->as<AstExprLocal>())
        return visit(node, context);
    if (auto* node = expr->as<AstExprGlobal>())
        return visit(node, context);
    if (auto* node = expr->as<AstExprVarargs>())
        return visit(node);
    if (auto* node = expr->as<AstExprCall>())
        return visit(node);
    if (auto* node = expr->as<AstExprIndexName>())
        return visit(node, context);
    if (auto* node = expr->as<AstExprIndexExpr>())
        return visit(node, context);
    if (auto* node = expr->as<AstExprFunction>())
        return visit(node);
    if (auto* node = expr->as<AstExprTable>())
        return visit(node);
    if (auto* node = expr->as<AstExprUnary>())
        return visit(node);
    if (auto* node = expr->as<AstExprBinary>())
        return visit(node);
    if (auto* node = expr->as<AstExprTypeAssertion>())
        return visit(node);
    if (auto* node = expr->as<AstExprIfElse>())
        return visit(node);
    if (auto* node = expr->as<AstExprInterpString>())
        return visit(node);
    if (auto* node = expr->as<AstExprError>())
        return visit(node);
    if (auto* node = expr->as<AstExprInstantiate>())
        return visit(node);

    LUAU_ASSERT(!"NonStrictTypeChecker encountered an unknown expression type");
    ice->ice("NonStrictTypeChecker encountered an unknown expression type");
}
573
574
// Parentheses are transparent: visit the wrapped expression in the same value context.
NonStrictContext visit(AstExprGroup* group, ValueContext context)
{
    AstExpr* wrapped = group->expr;
    return visit(wrapped, context);
}
578
579
// Literal constants and local references contribute no context of their own.
NonStrictContext visit(AstExprConstantNil* expr)
{
    return NonStrictContext{};
}

NonStrictContext visit(AstExprConstantBool* expr)
{
    return NonStrictContext{};
}

NonStrictContext visit(AstExprConstantNumber* expr)
{
    return NonStrictContext{};
}

NonStrictContext visit(AstExprConstantInteger* expr)
{
    return NonStrictContext{};
}

NonStrictContext visit(AstExprConstantString* expr)
{
    return NonStrictContext{};
}

NonStrictContext visit(AstExprLocal* local, ValueContext context)
{
    return NonStrictContext{};
}
608
609
// Reports UnknownSymbol for reads of a global that the current innermost scope cannot
// resolve. Writes (LValues) are never reported. Produces no context.
NonStrictContext visit(AstExprGlobal* global, ValueContext context)
{
    if (context != ValueContext::LValue)
    {
        NotNull<Scope> innermost = stack.back();
        if (!innermost->lookup(global->name))
            reportError(UnknownSymbol{global->name.value, UnknownSymbol::Binding}, global->location);
    }

    return {};
}
623
624
// `...` contributes no context.
NonStrictContext visit(AstExprVarargs* varargs)
{
    return NonStrictContext{};
}
628
629
// Visits the callee and arguments. If the callee's original type is a checked function, it
// also builds a context mapping each argument's definition to the type that would make the
// call fail at runtime, and reports:
// - CheckedFunctionIncorrectArgs when too many arguments are passed, or too few and a
//   missing parameter is not optional;
// - CheckedFunctionCallError for each argument willRunTimeError flags.
// Returns that context (empty when the callee is not a checked function).
NonStrictContext visit(AstExprCall* call)
{
    visit(call->func, ValueContext::RValue);
    for (AstExpr* argExpr : call->args)
        visit(argExpr, ValueContext::RValue);

    NonStrictContext fresh{};

    TypeId* originalCallTy = module->astOriginalCallTypes.find(call->func);
    if (!originalCallTy)
        return fresh;

    auto fn = get<FunctionType>(follow(*originalCallTy));
    if (!fn || !fn->isCheckedFunction)
        return fresh;

    // We know fn is a checked function, which means it looks like:
    // (S1, ... SN) -> T &
    // (~S1, unknown^N-1) -> error &
    // (unknown, ~S2, unknown^N-2) -> error
    // ...
    // ...
    // (unknown^N-1, ~S_N) -> error

    // For a method call `a:m(...)` the receiver `a` is the implicit first argument.
    std::vector<AstExpr*> arguments;
    arguments.reserve(call->args.size + (call->self ? 1 : 0));
    if (call->self)
    {
        if (auto indexExpr = call->func->as<AstExprIndexName>())
            arguments.push_back(indexExpr->expr);
        else
            ice->ice("method call expression has no 'self'");
    }
    arguments.insert(arguments.end(), call->args.begin(), call->args.end());

    // Expected parameter types: the finite head of fn->argTypes, padded with the variadic
    // tail's element type (if the tail is variadic) until there is one per argument.
    std::vector<TypeId> argTypes;
    argTypes.reserve(arguments.size());

    TypePackIterator paramIt = begin(fn->argTypes);
    TypePackIterator paramEnd = end(fn->argTypes);
    for (; paramIt != paramEnd; ++paramIt)
        argTypes.push_back(*paramIt);

    if (auto argTail = paramIt.tail())
    {
        if (const VariadicTypePack* vtp = get<VariadicTypePack>(follow(*argTail)))
        {
            if (argTypes.size() < arguments.size())
                argTypes.resize(arguments.size(), vtp->ty);
        }
    }

    std::string functionName = getFunctionNameAsString(*call->func).value_or("");
    if (arguments.size() > argTypes.size())
    {
        // We are passing more arguments than we expect, so we should error
        reportError(CheckedFunctionIncorrectArgs{std::move(functionName), argTypes.size(), arguments.size()}, call->location);
        return fresh;
    }

    // Record, for each argument's definition, the negation of its expected type.
    // For example, if the arg is "hi" and the expected arg type is number, the overload's
    // parameter type is ~number, and the argument will later be compared against ~number.
    for (size_t idx = 0; idx < arguments.size(); ++idx)
    {
        AstExpr* argExpr = arguments[idx];
        TypeId expectedArgType = argTypes[idx];
        std::shared_ptr<const NormalizedType> norm = normalizer.normalize(expectedArgType);
        DefId argDef = dfg->getDef(argExpr);

        if (!norm)
            reportError(NormalizationTooComplex{}, argExpr->location);

        // If we're dealing with any, negating any will cause all subtype tests to fail.
        // However, callers will want to be able to pass such a function anything, so we
        // inject never into the context so that the runtime test always passes.
        const bool expectsAny = norm && get<AnyType>(norm->tops);
        TypeId runTimeErrorTy = expectsAny ? builtinTypes->neverType : getOrCreateNegation(expectedArgType);
        fresh.addContext(argDef, runTimeErrorTy);
    }

    // With the context populated, check each argument against it.
    NotNull<Scope> scope{findInnermostScope(call->location)};
    for (size_t idx = 0; idx < arguments.size(); ++idx)
    {
        AstExpr* argExpr = arguments[idx];
        if (auto runTimeFailureType = willRunTimeError(argExpr, fresh, scope))
            reportError(CheckedFunctionCallError{argTypes[idx], *runTimeFailureType, functionName, idx}, argExpr->location);
    }

    if (arguments.size() < argTypes.size())
    {
        // We are passing fewer arguments than we expect,
        // so the rest of the parameters must be optional.
        bool remainingArgsOptional = true;
        for (size_t idx = arguments.size(); remainingArgsOptional && idx < argTypes.size(); ++idx)
            remainingArgsOptional = isOptional(argTypes[idx]);

        if (!remainingArgsOptional)
            reportError(CheckedFunctionIncorrectArgs{std::move(functionName), argTypes.size(), arguments.size()}, call->location);
    }

    return fresh;
}
744
745
// `a.name`: only the indexed object is visited, in the caller's value context.
NonStrictContext visit(AstExprIndexName* indexName, ValueContext context)
{
    AstExpr* object = indexName->expr;
    return visit(object, context);
}

// `a[k]`: the object is visited in the caller's value context and the key as an RValue.
// The two contexts are merged via disjunction.
NonStrictContext visit(AstExprIndexExpr* indexExpr, ValueContext context)
{
    NonStrictContext objectCtx = visit(indexExpr->expr, context);
    NonStrictContext keyCtx = visit(indexExpr->index, ValueContext::RValue);
    return NonStrictContext::disjunction(builtinTypes, arena, objectCtx, keyCtx);
}
756
757
758
// Visits the function body in the function's own scope (the module scope if none was
// recorded). For each parameter, reports NonStrictFunctionDefinitionError when
// willRunTimeErrorFunctionDefinition flags it, then removes the parameter from the body's
// context. Generics, the return annotation and the vararg annotation are also validated.
// Returns the body's context minus the parameters.
NonStrictContext visit(AstExprFunction* exprFn)
{
    // TODO: should a function being used as an expression generate a context without the arguments?
    auto scopeGuard = pushStack(exprFn);
    NonStrictContext remainder = visit(exprFn->body);
    NotNull<Scope> scope = scopeGuard ? scopeGuard->scope : NotNull{module->getModuleScope().get()};

    for (AstLocal* param : exprFn->args)
    {
        if (std::optional<TypeId> failingTy = willRunTimeErrorFunctionDefinition(param, scope, remainder))
        {
            const char* debugname = exprFn->debugname.value;
            reportError(NonStrictFunctionDefinitionError{debugname ? debugname : "", param->name.value, *failingTy}, param->location);
        }
        remainder.remove(dfg->getDef(param));

        visit(param->annotation);
    }

    visitGenerics(exprFn->generics, exprFn->genericPacks);
    visit(exprFn->returnAnnotation);
    if (exprFn->varargAnnotation)
        visit(exprFn->varargAnnotation);

    return remainder;
}
784
785
// Visits every key (when present) and value of a table constructor. The entries' contexts
// are discarded. When the recursion-counter flag is on and the depth limit is reached, the
// table is skipped.
NonStrictContext visit(AstExprTable* table)
{
    std::optional<RecursionCounter> recursionGuard;
    if (FFlag::LuauAddRecursionCounterToNonStrictTypeChecker)
    {
        recursionGuard.emplace(&nonStrictRecursionCount);
        if (FInt::LuauNonStrictTypeCheckerRecursionLimit > 0 && nonStrictRecursionCount >= FInt::LuauNonStrictTypeCheckerRecursionLimit)
            return {};
    }

    for (const AstExprTable::Item& item : table->items)
    {
        if (item.key)
            visit(item.key, ValueContext::RValue);
        visit(item.value, ValueContext::RValue);
    }

    return {};
}
804
805
// A unary expression contributes the context of its operand.
NonStrictContext visit(AstExprUnary* unary)
{
    AstExpr* operand = unary->expr;
    return visit(operand, ValueContext::RValue);
}

// A binary expression visits left then right and merges both via disjunction.
NonStrictContext visit(AstExprBinary* binary)
{
    NonStrictContext leftCtx = visit(binary->left, ValueContext::RValue);
    NonStrictContext rightCtx = visit(binary->right, ValueContext::RValue);
    return NonStrictContext::disjunction(builtinTypes, arena, leftCtx, rightCtx);
}
816
817
// `e :: T`: validates T, then contributes the context of `e`.
NonStrictContext visit(AstExprTypeAssertion* typeAssertion)
{
    AstType* asserted = typeAssertion->annotation;
    visit(asserted);

    return visit(typeAssertion->expr, ValueContext::RValue);
}
823
824
// `if c then a else b`: the condition's context, plus what holds in both branches.
NonStrictContext visit(AstExprIfElse* ifElse)
{
    NonStrictContext conditionCtx = visit(ifElse->condition, ValueContext::RValue);
    NonStrictContext trueCtx = visit(ifElse->trueExpr, ValueContext::RValue);
    NonStrictContext falseCtx = visit(ifElse->falseExpr, ValueContext::RValue);

    NonStrictContext bothBranches = NonStrictContext::conjunction(builtinTypes, arena, trueCtx, falseCtx);
    return NonStrictContext::disjunction(builtinTypes, arena, conditionCtx, bothBranches);
}
831
832
NonStrictContext visit(AstExprInterpString* interpString)
833
{
834
for (AstExpr* expr : interpString->expressions)
835
visit(expr, ValueContext::RValue);
836
837
return {};
838
}
839
840
NonStrictContext visit(AstExprError* error)
841
{
842
for (AstExpr* expr : error->expressions)
843
visit(expr, ValueContext::RValue);
844
845
return {};
846
}
847
848
// Explicit instantiation: validates each type or type-pack argument, then contributes the
// context of the instantiated expression.
NonStrictContext visit(AstExprInstantiate* instantiate)
{
    for (const AstTypeOrPack& typeArg : instantiate->typeArguments)
    {
        if (!typeArg.type)
        {
            visit(typeArg.typePack);
            continue;
        }
        visit(typeArg.type);
    }

    return visit(instantiate->expr, ValueContext::RValue);
}
860
861
// Dispatches a type annotation to the validator for its node kind. A null annotation, and
// kinds not listed here, are ignored. Groups are unwrapped.
void visit(AstType* ty)
{
    if (ty == nullptr)
        return;

    if (auto* ref = ty->as<AstTypeReference>())
        visit(ref);
    else if (auto* tableTy = ty->as<AstTypeTable>())
        visit(tableTy);
    else if (auto* fnTy = ty->as<AstTypeFunction>())
        visit(fnTy);
    else if (auto* typeOf = ty->as<AstTypeTypeof>())
        visit(typeOf);
    else if (auto* unionTy = ty->as<AstTypeUnion>())
        visit(unionTy);
    else if (auto* intersectionTy = ty->as<AstTypeIntersection>())
        visit(intersectionTy);
    else if (auto* group = ty->as<AstTypeGroup>())
        visit(group->type);
}
882
883
// Validates a type reference `[prefix.]Name<params>`.
// Visits its parameters, resolves the name in the innermost scope enclosing the reference,
// and checks the supplied type / type-pack arguments against the alias's parameters.
// Reports GenericError, IncorrectGenericParameterCount, SwappedGenericTypeParameter or
// UnknownSymbol as appropriate.
void visit(AstTypeReference* ty)
{
    if (FFlag::DebugLuauMagicTypes)
    {
        // No further validation is necessary in this case.
        if (ty->name == kLuauPrint)
            return;

        if (ty->name == kLuauForceConstraintSolvingIncomplete)
        {
            reportError(ConstraintSolvingIncompleteError{}, ty->location);
            return;
        }
    }

    for (const AstTypeOrPack& param : ty->parameters)
    {
        if (param.type)
            visit(param.type);
        else
            visit(param.typePack);
    }

    Scope* scope = findInnermostScope(ty->location);
    LUAU_ASSERT(scope);

    std::optional<TypeFun> alias = ty->prefix ? scope->lookupImportedType(ty->prefix->value, ty->name.value) : scope->lookupType(ty->name.value);

    if (alias.has_value())
    {
        size_t typesRequired = alias->typeParams.size();
        size_t packsRequired = alias->typePackParams.size();

        bool hasDefaultTypes = std::any_of(
            alias->typeParams.begin(),
            alias->typeParams.end(),
            [](auto&& el)
            {
                return el.defaultValue.has_value();
            }
        );

        bool hasDefaultPacks = std::any_of(
            alias->typePackParams.begin(),
            alias->typePackParams.end(),
            [](auto&& el)
            {
                return el.defaultValue.has_value();
            }
        );

        if (!ty->hasParameterList)
        {
            if ((!alias->typeParams.empty() && !hasDefaultTypes) || (!alias->typePackParams.empty() && !hasDefaultPacks))
                reportError(GenericError{"Type parameter list is required"}, ty->location);
        }

        size_t typesProvided = 0;
        size_t extraTypes = 0;
        size_t packsProvided = 0;

        for (const AstTypeOrPack& p : ty->parameters)
        {
            if (p.type)
            {
                if (packsProvided != 0)
                {
                    reportError(GenericError{"Type parameters must come before type pack parameters"}, ty->location);
                    continue;
                }

                if (typesProvided < typesRequired)
                    typesProvided += 1;
                else
                    extraTypes += 1;
            }
            else if (p.typePack)
            {
                std::optional<TypePackId> tp = lookupPackAnnotation(p.typePack);
                if (!tp.has_value())
                    continue;

                // A finite single-element pack may fill a remaining type parameter slot.
                if (typesProvided < typesRequired && size(*tp) == 1 && finite(*tp) && first(*tp))
                    typesProvided += 1;
                else
                    packsProvided += 1;
            }
        }

        if (extraTypes != 0 && packsProvided == 0)
        {
            // Extra types are only collected into a pack if a pack is expected
            if (packsRequired != 0)
                packsProvided += 1;
            else
                typesProvided += extraTypes;
        }

        // Missing parameters that have defaults count as provided.
        for (size_t i = typesProvided; i < typesRequired; ++i)
        {
            if (alias->typeParams[i].defaultValue)
                typesProvided += 1;
        }

        for (size_t i = packsProvided; i < packsRequired; ++i)
        {
            if (alias->typePackParams[i].defaultValue)
                packsProvided += 1;
        }

        // Exactly one omitted pack parameter is tolerated when no extra types were given.
        if (extraTypes == 0 && packsProvided + 1 == packsRequired)
            packsProvided += 1;

        if (typesProvided != typesRequired || packsProvided != packsRequired)
        {
            reportError(
                IncorrectGenericParameterCount{
                    /* name */ ty->name.value,
                    /* typeFun */ *alias,
                    /* actualParameters */ typesProvided,
                    /* actualPackParameters */ packsProvided,
                },
                ty->location
            );
        }
    }
    else
    {
        if (scope->lookupPack(ty->name.value))
        {
            reportError(
                SwappedGenericTypeParameter{
                    ty->name.value,
                    SwappedGenericTypeParameter::Kind::Type,
                },
                ty->location
            );
        }
        else
        {
            std::string symbol = "";
            if (ty->prefix)
            {
                symbol += (*(ty->prefix)).value;
                symbol += ".";
            }
            symbol += ty->name.value;

            reportError(UnknownSymbol{std::move(symbol), UnknownSymbol::Context::Type}, ty->location);
        }
    }
}
1039
1040
// Validates a table type: indexer key/value types (if any) and every property type.
void visit(AstTypeTable* table)
{
    if (auto* indexer = table->indexer)
    {
        visit(indexer->indexType);
        visit(indexer->resultType);
    }

    for (const auto& property : table->props)
        visit(property.type);
}
1051
1052
void visit(AstTypeFunction* function)
1053
{
1054
visit(function->argTypes);
1055
visit(function->returnTypes);
1056
}
1057
1058
void visit(AstTypeTypeof* typeOf)
1059
{
1060
visit(typeOf->expr, ValueContext::RValue);
1061
}
1062
1063
void visit(AstTypeUnion* unionType)
1064
{
1065
for (auto typ : unionType->types)
1066
visit(typ);
1067
}
1068
1069
void visit(AstTypeIntersection* intersectionType)
1070
{
1071
for (auto typ : intersectionType->types)
1072
visit(typ);
1073
}
1074
1075
// Validates each type in a type list, then its tail pack if present.
void visit(AstTypeList& list)
{
    for (AstType* element : list.types)
        visit(element);

    if (AstTypePack* tail = list.tailType)
        visit(tail);
}
1082
1083
// Dispatches a type-pack annotation to the validator for its node kind.
// A null pack, and kinds not listed here, are ignored.
void visit(AstTypePack* pack)
{
    if (pack == nullptr)
        return;

    if (auto* explicitPack = pack->as<AstTypePackExplicit>())
        visit(explicitPack);
    else if (auto* variadicPack = pack->as<AstTypePackVariadic>())
        visit(variadicPack);
    else if (auto* genericPack = pack->as<AstTypePackGeneric>())
        visit(genericPack);
}
1096
1097
void visit(AstTypePackExplicit* tp)
1098
{
1099
for (AstType* type : tp->typeList.types)
1100
visit(type);
1101
1102
if (tp->typeList.tailType)
1103
visit(tp->typeList.tailType);
1104
}
1105
1106
void visit(AstTypePackVariadic* tp)
1107
{
1108
visit(tp->variadicType);
1109
}
1110
1111
// A generic pack `T...` must name a pack in the innermost enclosing scope.
// If the name resolves to a type instead, reports SwappedGenericTypeParameter (Kind::Pack);
// otherwise reports UnknownSymbol.
void visit(AstTypePackGeneric* tp)
{
    Scope* scope = findInnermostScope(tp->location);
    LUAU_ASSERT(scope);

    const char* packName = tp->genericName.value;
    if (scope->lookupPack(packName))
        return;

    if (scope->lookupType(packName))
    {
        reportError(
            SwappedGenericTypeParameter{
                packName,
                SwappedGenericTypeParameter::Kind::Pack,
            },
            tp->location
        );
    }
    else
    {
        reportError(UnknownSymbol{packName, UnknownSymbol::Context::Type}, tp->location);
    }
}
1130
1131
// Validates a generic parameter list: reports DuplicateGenericParameter for any repeated name, and
// visits each parameter's default (when present). Type parameters are processed before pack parameters.
// A single `seen` set covers both lists, so a pack reusing a type parameter's name is also a duplicate.
void visitGenerics(AstArray<AstGenericType*> generics, AstArray<AstGenericTypePack*> genericPacks)
{
    DenseHashSet<AstName> seen{AstName{}};

    // Shared by both parameter kinds; `this->` is explicit because this is a generic lambda.
    auto checkParameter = [&](const auto* param)
    {
        if (seen.contains(param->name))
            this->reportError(DuplicateGenericParameter{param->name.value}, param->location);
        else
            seen.insert(param->name);

        if (param->defaultValue)
            this->visit(param->defaultValue);
    };

    for (const AstGenericType* typeParam : generics)
        checkParameter(typeParam);

    for (const AstGenericTypePack* packParam : genericPacks)
        checkParameter(packParam);
}
1157
1158
// Returns the deepest scope whose location encloses `location`, starting from the module scope.
// At each level it descends into the first child that encloses `location`, and stops when
// no child of the current scope does.
Scope* findInnermostScope(Location location) const
{
    Scope* innermost = module->getModuleScope().get();

    for (;;)
    {
        Scope* enclosingChild = nullptr;
        for (const auto& child : innermost->children)
        {
            if (child->location.encloses(location))
            {
                enclosingChild = child.get();
                break;
            }
        }

        if (enclosingChild == nullptr)
            return innermost;

        innermost = enclosingChild;
    }
}
1179
1180
// Returns the (followed) type pack the module resolved for `annotation`, or std::nullopt when
// `module->astResolvedTypePacks` has no entry for it.
std::optional<TypePackId> lookupPackAnnotation(AstTypePack* annotation) const
{
    if (const TypePackId* resolved = module->astResolvedTypePacks.find(annotation))
        return follow(*resolved);

    return std::nullopt;
}
1187
1188
// Appends an error carrying `data` at `location` to this module's error list, tagged with the module's name.
void reportError(TypeErrorData data, const Location& location)
{
    auto& errors = module->errors;
    errors.emplace_back(location, module->name, std::move(data));
    // TODO: weave in logger here?
}
1193
1194
// Tests `fragment` against the requirements `context` records for its definition.
//
// The fragment's definition is expanded into its operand definitions (collectOperands). For each
// operand that `context` has a type for, the fragment's inferred type is checked for being a subtype
// of that context type. Fragments whose inferred type is `never` or an unreduced type-function
// instance are skipped (see shouldSkipRuntimeErrorTesting).
//
// Returns the fragment's inferred type on the first such subtype match, and std::nullopt otherwise.
// Any subtyping check that hits the normalization limit reports NormalizationTooComplex at the
// fragment's location.
std::optional<TypeId> willRunTimeError(AstExpr* fragment, const NonStrictContext& context, NotNull<Scope> scope)
{
    // Named distinctly so the loop variable below does not shadow it.
    DefId fragmentDef = dfg->getDef(fragment);
    std::vector<DefId> defs;
    collectOperands(fragmentDef, &defs);

    for (DefId def : defs)
    {
        std::optional<TypeId> contextTy = context.find(def);
        if (!contextTy)
            continue;

        TypeId actualType = lookupType(fragment);
        if (shouldSkipRuntimeErrorTesting(actualType))
            continue;

        SubtypingResult r = subtyping.isSubtype(actualType, *contextTy, scope);
        if (r.normalizationTooComplex)
            reportError(NormalizationTooComplex{}, fragment->location);
        if (r.isSubtype)
            return {actualType};
    }

    return {};
}
1217
1218
// Checks whether `context` constrains any operand definition of the local `fragment` to a type
// equivalent to `unknown`. Equivalence means the two types are subtypes of each other.
//
// Returns builtinTypes->unknownType on the first such operand, and std::nullopt otherwise.
// If either subtyping check hits the normalization limit, NormalizationTooComplex is reported at
// the fragment's location.
std::optional<TypeId> willRunTimeErrorFunctionDefinition(AstLocal* fragment, NotNull<Scope> scope, const NonStrictContext& context)
{
    // Named distinctly so the loop variable below does not shadow it.
    DefId fragmentDef = dfg->getDef(fragment);
    std::vector<DefId> defs;
    collectOperands(fragmentDef, &defs);

    for (DefId def : defs)
    {
        std::optional<TypeId> contextTy = context.find(def);
        if (!contextTy)
            continue;

        SubtypingResult r1 = subtyping.isSubtype(builtinTypes->unknownType, *contextTy, scope);
        SubtypingResult r2 = subtyping.isSubtype(*contextTy, builtinTypes->unknownType, scope);
        if (r1.normalizationTooComplex || r2.normalizationTooComplex)
            reportError(NormalizationTooComplex{}, fragment->location);

        bool isUnknown = r1.isSubtype && r2.isSubtype;
        if (isUnknown)
            return {builtinTypes->unknownType};
    }

    return {};
}
1238
1239
private:
1240
int nonStrictRecursionCount = 0;
1241
1242
TypeId getOrCreateNegation(TypeId baseType)
1243
{
1244
TypeId& cachedResult = cachedNegations[baseType];
1245
if (!cachedResult)
1246
cachedResult = arena->addType(NegationType{baseType});
1247
return cachedResult;
1248
}
1249
1250
// True when `test`, after following, is `never` or a type-function instance. willRunTimeError
// skips the subtyping check for such types.
bool shouldSkipRuntimeErrorTesting(TypeId test)
{
    return is<NeverType, TypeFunctionInstanceType>(follow(test));
}
1255
};
1256
1257
void checkNonStrict(
1258
NotNull<BuiltinTypes> builtinTypes,
1259
NotNull<TypeFunctionRuntime> typeFunctionRuntime,
1260
NotNull<InternalErrorReporter> ice,
1261
NotNull<UnifierSharedState> unifierState,
1262
NotNull<const DataFlowGraph> dfg,
1263
NotNull<TypeCheckLimits> limits,
1264
const SourceModule& sourceModule,
1265
Module* module
1266
)
1267
{
1268
LUAU_TIMETRACE_SCOPE("checkNonStrict", "Typechecking");
1269
1270
NonStrictTypeChecker typeChecker{NotNull{&module->internalTypes}, builtinTypes, typeFunctionRuntime, ice, unifierState, dfg, limits, module};
1271
typeChecker.visit(sourceModule.root);
1272
unfreeze(module->interfaceTypes);
1273
copyErrors(module->errors, module->interfaceTypes, builtinTypes);
1274
1275
module->errors.erase(
1276
std::remove_if(
1277
module->errors.begin(),
1278
module->errors.end(),
1279
[](auto err)
1280
{
1281
return get<UnknownRequire>(err) != nullptr;
1282
}
1283
),
1284
module->errors.end()
1285
);
1286
1287
freeze(module->interfaceTypes);
1288
}
1289
1290
} // namespace Luau
1291
1292