Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Roblox
GitHub Repository: Roblox/luau
Path: blob/master/Analysis/src/Generalization.cpp
2725 views
1
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
2
3
#include "Luau/Generalization.h"
4
5
#include "Luau/Common.h"
6
#include "Luau/DenseHash.h"
7
#include "Luau/InsertionOrderedMap.h"
8
#include "Luau/OrderedSet.h"
9
#include "Luau/Polarity.h"
10
#include "Luau/Scope.h"
11
#include "Luau/ToString.h"
12
#include "Luau/Type.h"
13
#include "Luau/TypeArena.h"
14
#include "Luau/TypeIds.h"
15
#include "Luau/TypePack.h"
16
#include "Luau/VisitType.h"
17
18
// Traversal budgets/flags for generalization.
// LuauGenericCounterMaxSteps is consumed by GenericCounter::checkLimits; when
// exceeded, count-based generic pruning is skipped.
// (LuauGenericCounterMaxDepth is declared here but not referenced in this
// file's visible code — TODO confirm its consumer.)
LUAU_FASTINTVARIABLE(LuauGenericCounterMaxDepth, 15)
LUAU_FASTINTVARIABLE(LuauGenericCounterMaxSteps, 1500)
// Gates the bound-forwarding behavior in generalizeType.
LUAU_FASTFLAGVARIABLE(LuauGeneralizationMoreAwareOfBounds3)
21
22
namespace Luau
23
{
24
25
// Walks a type graph and collects everything that generalization will act on:
//   * free types whose scope is subsumed by `scope`, each with a use count
//     and the polarity at which it appears,
//   * free type packs, likewise, and
//   * free/unsealed tables eligible for sealing.
// Types already present in `cachedTypes` are known to be stable under
// generalization and are skipped.
struct FreeTypeSearcher : TypeVisitor
{
    NotNull<Scope> scope;
    NotNull<DenseHashSet<TypeId>> cachedTypes;

    explicit FreeTypeSearcher(NotNull<Scope> scope, NotNull<DenseHashSet<TypeId>> cachedTypes)
        : TypeVisitor("FreeTypeSearcher", /* skipBoundTypes */ true)
        , scope(scope)
        , cachedTypes(cachedTypes)
    {
    }

    // True while traversing inside a function type's signature.
    bool isWithinFunction = false;
    // Polarity of the position currently being traversed.
    Polarity polarity = Polarity::Positive;

    void flip()
    {
        polarity = invert(polarity);
    }

    // Per-polarity seen sets: a type may legitimately be visited once in a
    // positive position and once more in a negative position.
    DenseHashSet<const void*> seenPositive{nullptr};
    DenseHashSet<const void*> seenNegative{nullptr};

    // Returns true if `ty` was already visited at the current polarity;
    // otherwise records the visit (at both polarities when Mixed) and
    // returns false.
    bool seenWithCurrentPolarity(const void* ty)
    {
        switch (polarity)
        {
        case Polarity::Positive:
        {
            if (seenPositive.contains(ty))
                return true;

            seenPositive.insert(ty);
            return false;
        }
        case Polarity::Negative:
        {
            if (seenNegative.contains(ty))
                return true;

            seenNegative.insert(ty);
            return false;
        }
        case Polarity::Mixed:
        {
            if (seenPositive.contains(ty) && seenNegative.contains(ty))
                return true;

            seenPositive.insert(ty);
            seenNegative.insert(ty);
            return false;
        }
        default:
            LUAU_ASSERT(!"Unreachable");
        }

        return false;
    }

    // NOTE(review): these two maps are not referenced in this file's visible
    // code — possibly consumed externally or vestigial; confirm before removal.
    DenseHashMap<const void*, size_t> negativeTypes{0};
    DenseHashMap<const void*, size_t> positiveTypes{0};

    // Free types/packs discovered during the walk, in insertion order, each
    // paired with the GeneralizationParams accumulated for it.
    InsertionOrderedMap<TypeId, GeneralizationParams<TypeId>> types;
    InsertionOrderedMap<TypePackId, GeneralizationParams<TypePackId>> typePacks;

    // Free/unsealed tables that may be sealed after generalization.
    OrderedSet<TypeId> unsealedTables;

    bool visit(TypeId ty) override
    {
        if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty))
            return false;

        LUAU_ASSERT(ty);
        return true;
    }

    bool visit(TypeId ty, const FreeType& ft) override
    {
        if (!subsumes(scope, ft.scope))
            return true;

        // Count every occurrence, even repeat visits: useCount feeds the
        // decision of how the free type is generalized.
        GeneralizationParams<TypeId>& params = types[ty];
        ++params.useCount;

        if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty))
            return false;

        if (!isWithinFunction)
            params.foundOutsideFunctions = true;

        params.polarity |= polarity;

        return true;
    }

    bool visit(TypeId ty, const TableType& tt) override
    {
        if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty))
            return false;

        if ((tt.state == TableState::Free || tt.state == TableState::Unsealed) && subsumes(scope, tt.scope))
            unsealedTables.insert(ty);

        // Property polarity: read-only props keep the current polarity,
        // write-only props are negative, shared (read == write) props are
        // mixed, and distinct read/write types are traversed separately.
        for (const auto& [_name, prop] : tt.props)
        {
            if (prop.isReadOnly())
            {
                traverse(*prop.readTy);
            }
            else if (prop.isWriteOnly())
            {
                Polarity p = polarity;
                polarity = Polarity::Negative;
                traverse(*prop.writeTy);
                polarity = p;
            }
            else if (prop.isShared())
            {
                Polarity p = polarity;
                polarity = Polarity::Mixed;
                traverse(*prop.readTy);
                polarity = p;
            }
            else
            {
                LUAU_ASSERT(prop.isReadWrite() && !prop.isShared());

                traverse(*prop.readTy);
                Polarity p = polarity;
                polarity = Polarity::Negative;
                traverse(*prop.writeTy);
                polarity = p;
            }
        }

        if (tt.indexer)
        {
            // {[K]: V} is equivalent to three functions: get, set, and iterate
            //
            // (K) -> V
            // (K, V) -> ()
            // () -> {K}
            //
            // K and V therefore both have mixed polarity.

            const Polarity p = polarity;
            polarity = Polarity::Mixed;
            traverse(tt.indexer->indexType);
            traverse(tt.indexer->indexResultType);
            polarity = p;
        }

        return false;
    }

    bool visit(TypeId ty, const FunctionType& ft) override
    {
        if (cachedTypes->contains(ty) || seenWithCurrentPolarity(ty))
            return false;

        const bool oldValue = isWithinFunction;
        isWithinFunction = true;

        // Arguments are contravariant: flip polarity around the argument pack.
        flip();
        traverse(ft.argTypes);
        flip();

        traverse(ft.retTypes);

        isWithinFunction = oldValue;

        return false;
    }

    bool visit(TypeId, const ExternType&) override
    {
        // Extern types are opaque: nothing inside them can be generalized.
        return false;
    }

    bool visit(TypePackId tp, const FreeTypePack& ftp) override
    {
        if (seenWithCurrentPolarity(tp))
            return false;

        if (!subsumes(scope, ftp.scope))
            return true;

        GeneralizationParams<TypePackId>& params = typePacks[tp];
        ++params.useCount;

        if (!isWithinFunction)
            params.foundOutsideFunctions = true;

        params.polarity |= polarity;

        return true;
    }
};
223
224
// We keep a running set of types that will not change under generalization and
225
// only have outgoing references to types that are the same. We use this to
226
// short circuit generalization. It improves performance quite a lot.
227
//
228
// We do this by tracing through the type and searching for types that are
229
// uncacheable. If a type has a reference to an uncacheable type, it is itself
230
// uncacheable.
231
//
232
// If a type has no outbound references to uncacheable types, we add it to the
233
// cache.
234
struct TypeCacher : TypeOnceVisitor
{
    // Shared, cross-run cache of types known to be stable under generalization.
    NotNull<DenseHashSet<TypeId>> cachedTypes;

    // Types/packs found to be non-cacheable during this traversal.
    DenseHashSet<TypeId> uncacheable{nullptr};
    DenseHashSet<TypePackId> uncacheablePacks{nullptr};

    explicit TypeCacher(NotNull<DenseHashSet<TypeId>> cachedTypes)
        : TypeOnceVisitor("TypeCacher", /* skipBoundTypes */ true)
        , cachedTypes(cachedTypes)
    {
    }

    // Record `ty` (after following bound types) as cacheable.
    void cache(TypeId ty) const
    {
        cachedTypes->insert(follow(ty));
    }

    bool isCached(TypeId ty) const
    {
        return cachedTypes->contains(follow(ty));
    }

    void markUncacheable(TypeId ty)
    {
        uncacheable.insert(follow(ty));
    }

    void markUncacheable(TypePackId tp)
    {
        uncacheablePacks.insert(follow(tp));
    }

    bool isUncacheable(TypeId ty) const
    {
        return uncacheable.contains(follow(ty));
    }

    bool isUncacheable(TypePackId tp) const
    {
        return uncacheablePacks.contains(follow(tp));
    }

    bool visit(TypeId ty) override
    {
        // NOTE: `TypeCacher` should explicitly visit _all_ types and type packs,
        // otherwise it's prone to marking types that cannot be cached as
        // cacheable.
        LUAU_ASSERT(false);
        LUAU_UNREACHABLE();
    }

    bool visit(TypeId ty, const FreeType& ft) override
    {
        // Free types are never cacheable.
        LUAU_ASSERT(!isCached(ty));

        if (!isUncacheable(ty))
        {
            // Classify everything reachable through the bounds, then record
            // the free type itself as uncacheable.
            traverse(ft.lowerBound);
            traverse(ft.upperBound);

            markUncacheable(ty);
        }

        return false;
    }

    bool visit(TypeId ty, const GenericType&) override
    {
        cache(ty);
        return false;
    }

    bool visit(TypeId ty, const ErrorType&) override
    {
        cache(ty);
        return false;
    }

    bool visit(TypeId ty, const PrimitiveType&) override
    {
        cache(ty);
        return false;
    }

    bool visit(TypeId ty, const SingletonType&) override
    {
        cache(ty);
        return false;
    }

    bool visit(TypeId ty, const BlockedType&) override
    {
        // Blocked types are placeholders that will be overwritten later.
        markUncacheable(ty);
        return false;
    }

    bool visit(TypeId ty, const PendingExpansionType&) override
    {
        markUncacheable(ty);
        return false;
    }

    bool visit(TypeId ty, const FunctionType& ft) override
    {
        if (isCached(ty) || isUncacheable(ty))
            return false;

        // Classify everything the signature refers to first.
        traverse(ft.argTypes);
        traverse(ft.retTypes);
        for (TypeId gen : ft.generics)
            traverse(gen);

        bool uncacheable = false;

        // A function is uncacheable if its argument/return packs, any element
        // of those packs, or any of its generics is uncacheable.
        if (isUncacheable(ft.argTypes))
            uncacheable = true;

        else if (isUncacheable(ft.retTypes))
            uncacheable = true;

        for (TypeId argTy : ft.argTypes)
        {
            if (isUncacheable(argTy))
            {
                uncacheable = true;
                break;
            }
        }

        for (TypeId retTy : ft.retTypes)
        {
            if (isUncacheable(retTy))
            {
                uncacheable = true;
                break;
            }
        }

        for (TypeId g : ft.generics)
        {
            if (isUncacheable(g))
            {
                uncacheable = true;
                break;
            }
        }

        if (uncacheable)
            markUncacheable(ty);
        else
            cache(ty);

        return false;
    }

    bool visit(TypeId ty, const TableType& tt) override
    {
        if (isCached(ty) || isUncacheable(ty))
            return false;

        if (tt.boundTo)
        {
            traverse(*tt.boundTo);
            if (isUncacheable(*tt.boundTo))
            {
                markUncacheable(ty);
                return false;
            }
        }

        bool uncacheable = false;

        // This logic runs immediately after generalization, so any remaining
        // unsealed tables are assuredly not cacheable. They may yet have
        // properties added to them.
        if (tt.state == TableState::Free || tt.state == TableState::Unsealed)
            uncacheable = true;

        for (const auto& [_name, prop] : tt.props)
        {
            if (prop.readTy)
            {
                traverse(*prop.readTy);

                if (isUncacheable(*prop.readTy))
                    uncacheable = true;
            }
            // Skip the write type when it aliases the read type (shared prop).
            if (prop.writeTy && prop.writeTy != prop.readTy)
            {
                traverse(*prop.writeTy);

                if (isUncacheable(*prop.writeTy))
                    uncacheable = true;
            }
        }

        if (tt.indexer)
        {
            traverse(tt.indexer->indexType);
            if (isUncacheable(tt.indexer->indexType))
                uncacheable = true;

            traverse(tt.indexer->indexResultType);
            if (isUncacheable(tt.indexer->indexResultType))
                uncacheable = true;
        }

        if (uncacheable)
            markUncacheable(ty);
        else
            cache(ty);

        return false;
    }

    bool visit(TypeId ty, const MetatableType& mtv) override
    {
        // A metatable type is cacheable iff both of its components are.
        traverse(mtv.table);
        traverse(mtv.metatable);
        if (isUncacheable(mtv.table) || isUncacheable(mtv.metatable))
            markUncacheable(ty);
        else
            cache(ty);
        return false;
    }

    bool visit(TypeId ty, const ExternType&) override
    {
        cache(ty);
        return false;
    }

    bool visit(TypeId ty, const AnyType&) override
    {
        cache(ty);
        return false;
    }

    bool visit(TypeId ty, const NoRefineType&) override
    {
        cache(ty);
        return false;
    }

    bool visit(TypeId ty, const UnionType& ut) override
    {
        if (isUncacheable(ty) || isCached(ty))
            return false;

        // A union is cacheable iff every option is.
        bool uncacheable = false;

        for (TypeId partTy : ut.options)
        {
            traverse(partTy);

            uncacheable |= isUncacheable(partTy);
        }

        if (uncacheable)
            markUncacheable(ty);
        else
            cache(ty);

        return false;
    }

    bool visit(TypeId ty, const IntersectionType& it) override
    {
        if (isUncacheable(ty) || isCached(ty))
            return false;

        // An intersection is cacheable iff every part is.
        bool uncacheable = false;

        for (TypeId partTy : it.parts)
        {
            traverse(partTy);

            uncacheable |= isUncacheable(partTy);
        }

        if (uncacheable)
            markUncacheable(ty);
        else
            cache(ty);

        return false;
    }

    bool visit(TypeId ty, const UnknownType&) override
    {
        cache(ty);
        return false;
    }

    bool visit(TypeId ty, const NeverType&) override
    {
        cache(ty);
        return false;
    }

    bool visit(TypeId ty, const NegationType& nt) override
    {
        // ~T is cacheable iff T is.
        if (!isCached(ty) && !isUncacheable(ty))
        {
            traverse(nt.ty);

            if (isUncacheable(nt.ty))
                markUncacheable(ty);
            else
                cache(ty);
        }

        return false;
    }

    bool visit(TypeId ty, const TypeFunctionInstanceType& tfit) override
    {
        if (isCached(ty) || isUncacheable(ty))
            return false;

        // A type-function instance is cacheable iff all of its type and pack
        // arguments are.
        bool uncacheable = false;

        for (TypeId argTy : tfit.typeArguments)
        {
            traverse(argTy);

            if (isUncacheable(argTy))
                uncacheable = true;
        }

        for (TypePackId argPack : tfit.packArguments)
        {
            traverse(argPack);

            if (isUncacheable(argPack))
                uncacheable = true;
        }

        if (uncacheable)
            markUncacheable(ty);
        else
            cache(ty);

        return false;
    }

    bool visit(TypePackId tp) override
    {
        // NOTE: `TypeCacher` should explicitly visit _all_ types and type packs,
        // otherwise it's prone to marking types that cannot be cached as
        // cacheable, which will segfault down the line.
        LUAU_ASSERT(false);
        LUAU_UNREACHABLE();
    }

    bool visit(TypePackId tp, const FreeTypePack&) override
    {
        markUncacheable(tp);
        return false;
    }

    bool visit(TypePackId tp, const GenericTypePack& gtp) override
    {
        return true;
    }

    bool visit(TypePackId tp, const ErrorTypePack& etp) override
    {
        return true;
    }

    bool visit(TypePackId tp, const VariadicTypePack& vtp) override
    {
        if (isUncacheable(tp))
            return false;

        traverse(vtp.ty);

        // The pack inherits uncacheability from its element type.
        if (isUncacheable(vtp.ty))
            markUncacheable(tp);

        return false;
    }

    bool visit(TypePackId tp, const BlockedTypePack&) override
    {
        markUncacheable(tp);
        return false;
    }

    bool visit(TypePackId tp, const TypeFunctionInstanceTypePack&) override
    {
        markUncacheable(tp);
        return false;
    }

    bool visit(TypePackId tp, const BoundTypePack& btp) override
    {
        traverse(btp.boundTo);
        if (isUncacheable(btp.boundTo))
            markUncacheable(tp);
        return false;
    }

    bool visit(TypePackId tp, const TypePack& typ) override
    {
        // A concrete pack is uncacheable if any head element or the tail is.
        bool uncacheable = false;
        for (TypeId ty : typ.head)
        {
            traverse(ty);
            uncacheable |= isUncacheable(ty);
        }
        if (typ.tail)
        {
            traverse(*typ.tail);
            uncacheable |= isUncacheable(*typ.tail);
        }
        if (uncacheable)
            markUncacheable(tp);
        return false;
    }
};
658
659
namespace
660
{
661
662
// Rewrites unions and intersections reachable from a root type so they no
// longer mention `needle`. While rebuilding, it also drops `never` options
// from unions, `unknown` parts from intersections, and direct
// self-references. Rewritten types are bound in place over the originals.
struct TypeRemover
{
    NotNull<BuiltinTypes> builtinTypes;
    NotNull<TypeArena> arena;

    // The type to be scrubbed out of reachable unions/intersections.
    TypeId needle;
    DenseHashSet<TypeId> seen{nullptr};

    void process(TypeId item)
    {
        item = follow(item);

        // If we've already visited this item, or it's outside our arena, then
        // do not try to mutate it.
        if (seen.contains(item) || item->owningArena != arena || item->persistent)
            return;
        seen.insert(item);

        if (auto ut = getMutable<UnionType>(item))
        {
            TypeIds newOptions;
            for (TypeId option : ut->options)
            {
                process(option);
                option = follow(option);
                // Keep everything except the needle, `never` (the identity
                // element of unions), and direct self-references.
                if (option != needle && !is<NeverType>(option) && option != item)
                    newOptions.insert(option);
            }
            // Only rewrite the union if something was actually removed.
            if (ut->options.size() != newOptions.size())
            {
                if (newOptions.empty())
                    emplaceType<BoundType>(asMutable(item), builtinTypes->neverType);
                else if (newOptions.size() == 1)
                    emplaceType<BoundType>(asMutable(item), *newOptions.begin());
                else
                    emplaceType<BoundType>(asMutable(item), arena->addType(UnionType{newOptions.take()}));
            }
        }
        else if (auto it = getMutable<IntersectionType>(item))
        {
            TypeIds newParts;
            for (TypeId part : it->parts)
            {
                process(part);
                part = follow(part);
                // Keep everything except the needle, `unknown` (the identity
                // element of intersections), and direct self-references.
                if (part != needle && !is<UnknownType>(part) && part != item)
                    newParts.insert(part);
            }
            // Only rewrite the intersection if something was actually removed.
            if (it->parts.size() != newParts.size())
            {
                if (newParts.empty())
                    emplaceType<BoundType>(asMutable(item), builtinTypes->unknownType);
                else if (newParts.size() == 1)
                    emplaceType<BoundType>(asMutable(item), *newParts.begin());
                else
                    emplaceType<BoundType>(asMutable(item), arena->addType(IntersectionType{newParts.take()}));
            }
        }
    }
};
722
723
// Scrub `needle` out of any unions/intersections reachable from `haystack`,
// mutating the reachable types in place.
void removeType(NotNull<TypeArena> arena, NotNull<BuiltinTypes> builtinTypes, TypeId haystack, TypeId needle)
{
    TypeRemover remover{builtinTypes, arena, needle};
    remover.process(haystack);
}
728
729
} // namespace
730
731
// Generalizes a single free type in place according to `params`.
//
// Depending on its bounds and the polarity at which it was used, the free
// type is either bound to a type derived from its lower/upper bound, bound
// to `unknown`, or replaced outright by a fresh GenericType (signalled by
// wasReplacedByGeneric in the result).
GeneralizationResult<TypeId> generalizeType(
    NotNull<TypeArena> arena,
    NotNull<BuiltinTypes> builtinTypes,
    NotNull<Scope> scope,
    TypeId freeTy,
    const GeneralizationParams<TypeId>& params
)
{
    freeTy = follow(freeTy);

    FreeType* ft = getMutable<FreeType>(freeTy);
    LUAU_ASSERT(ft);

    LUAU_ASSERT(isKnown(params.polarity));

    // `never` and `unknown` are the trivial lower/upper bounds respectively.
    const bool hasLowerBound = !get<NeverType>(follow(ft->lowerBound));
    const bool hasUpperBound = !get<UnknownType>(follow(ft->upperBound));

    const bool isWithinFunction = !params.foundOutsideFunctions;

    if (!hasLowerBound && !hasUpperBound)
    {
        // Wholly unconstrained: only frees used exclusively inside a function
        // signature become generics; others are simply `unknown`.
        if (!isWithinFunction)
            emplaceType<BoundType>(asMutable(freeTy), builtinTypes->unknownType);
        else
        {
            emplaceType<GenericType>(asMutable(freeTy), scope, params.polarity);
            return {freeTy, /*wasReplacedByGeneric*/ true};
        }
    }
    // It is possible that this free type has other free types in its upper
    // or lower bounds. If this is the case, we must replace those
    // references with never (for the lower bound) or unknown (for the upper
    // bound).
    //
    // If we do not do this, we get tautological bounds like a <: a <: unknown.
    else if (isPositive(params.polarity) && !hasUpperBound)
    {
        TypeId lb = follow(ft->lowerBound);
        if (FreeType* lowerFree = getMutable<FreeType>(lb); lowerFree && lowerFree->upperBound == freeTy)
        {
            // If we are generalizing 'a in:
            //
            // LO <: 'b <: 'a <: UP
            //
            // ... we can hold onto the bound UP and forward it to 'b.
            if (FFlag::LuauGeneralizationMoreAwareOfBounds3)
            {
                TypeId upperBound = follow(ft->upperBound);
                removeType(arena, builtinTypes, upperBound, freeTy);
                lowerFree->upperBound = follow(upperBound);
            }
            else
                lowerFree->upperBound = builtinTypes->unknownType;
        }
        else
            removeType(arena, builtinTypes, lb, freeTy);

        if (follow(lb) != freeTy)
            emplaceType<BoundType>(asMutable(freeTy), lb);
        else if (!isWithinFunction)
            emplaceType<BoundType>(asMutable(freeTy), builtinTypes->unknownType);
        else
        {
            // if the lower bound is the type in question (eg 'a <: 'a), we don't actually have a lower bound.
            emplaceType<GenericType>(asMutable(freeTy), scope, params.polarity);
            return {freeTy, /*wasReplacedByGeneric*/ true};
        }
    }
    else
    {
        TypeId ub = follow(ft->upperBound);
        if (FreeType* upperFree = getMutable<FreeType>(ub); upperFree && upperFree->lowerBound == freeTy)
        {
            if (FFlag::LuauGeneralizationMoreAwareOfBounds3)
            {
                // If we are generalizing 'a in:
                //
                // LO <: 'a <: 'b <: UP
                //
                // ... we can hold onto the bound LO and forward it to 'b.
                TypeId lowerBound = follow(ft->lowerBound);
                removeType(arena, builtinTypes, lowerBound, freeTy);
                upperFree->lowerBound = follow(lowerBound);
            }
            else
                upperFree->lowerBound = builtinTypes->neverType;
        }
        else
            removeType(arena, builtinTypes, ub, freeTy);

        if (follow(ub) != freeTy)
            emplaceType<BoundType>(asMutable(freeTy), ub);
        else if (!isWithinFunction || params.useCount == 1)
        {
            // If we have some free type:
            //
            // A <: 'b < C
            //
            // We can approximately generalize this to the intersection of its
            // bounds, taking care to avoid constructing a degenerate
            // union or intersection by clipping the free type from the upper
            // and lower bounds, then also cleaning the resulting intersection.
            removeType(arena, builtinTypes, ft->lowerBound, freeTy);
            TypeId cleanedTy = arena->addType(IntersectionType{{ft->lowerBound, ub}});
            removeType(arena, builtinTypes, cleanedTy, freeTy);
            emplaceType<BoundType>(asMutable(freeTy), cleanedTy);
        }
        else
        {
            // if the upper bound is the type in question, we don't actually have an upper bound.
            emplaceType<GenericType>(asMutable(freeTy), scope, params.polarity);
            return {freeTy, /*wasReplacedByGeneric*/ true};
        }
    }

    return {freeTy, /*wasReplacedByGeneric*/ false};
}
849
850
// Generalizes a free type pack in place. A pack used exactly once is bound
// to the unknown type pack; any other eligible free pack becomes a generic
// pack (signalled by wasReplacedByGeneric). Packs we don't own, non-free
// packs, and packs whose scope is not subsumed are returned untouched.
GeneralizationResult<TypePackId> generalizeTypePack(
    NotNull<TypeArena> arena,
    NotNull<BuiltinTypes> builtinTypes,
    NotNull<Scope> scope,
    TypePackId tp,
    const GeneralizationParams<TypePackId>& params
)
{
    tp = follow(tp);

    // Only mutate free packs that live in our arena.
    const FreeTypePack* ftp = tp->owningArena == arena ? get<FreeTypePack>(tp) : nullptr;

    if (!ftp || !subsumes(scope, ftp->scope))
        return {tp, /*wasReplacedByGeneric*/ false};

    if (params.useCount != 1)
    {
        emplaceTypePack<GenericTypePack>(asMutable(tp), scope, params.polarity);
        return {tp, /*wasReplacedByGeneric*/ true};
    }

    // A single-use pack carries no information worth keeping generic.
    emplaceTypePack<BoundTypePack>(asMutable(tp), builtinTypes->unknownTypePack);
    return {tp, /*wasReplacedByGeneric*/ false};
}
880
881
// Transitions a free or unsealed table to the sealed state, provided it is
// a table whose scope is subsumed by `scope`. All other types are left
// untouched.
void sealTable(NotNull<Scope> scope, TypeId ty)
{
    TableType* tableTy = getMutable<TableType>(follow(ty));
    if (!tableTy || !subsumes(scope, tableTy->scope))
        return;

    const bool sealable = tableTy->state == TableState::Unsealed || tableTy->state == TableState::Free;
    if (sealable)
        tableTy->state = TableState::Sealed;
}
893
894
// Top-level generalization entry point for `ty` at `scope`:
//   1. collect candidate free types/packs and unsealed tables,
//   2. generalize each free type (or only `generalizationTarget` if given),
//   3. seal eligible tables,
//   4. generalize free type packs,
//   5. record which types are now stable in `cachedTypes`.
// Returns std::nullopt when a step reports resource limits exceeded.
std::optional<TypeId> generalize(
    NotNull<TypeArena> arena,
    NotNull<BuiltinTypes> builtinTypes,
    NotNull<Scope> scope,
    NotNull<DenseHashSet<TypeId>> cachedTypes,
    TypeId ty,
    std::optional<TypeId> generalizationTarget
)
{
    ty = follow(ty);

    // Never mutate types owned by other arenas or marked persistent.
    if (ty->owningArena != arena || ty->persistent)
        return ty;

    FreeTypeSearcher fts{scope, cachedTypes};
    fts.traverse(ty);

    // When `ty` is a function, newly minted generics are appended to its
    // generic lists; otherwise they are simply dropped.
    FunctionType* functionTy = getMutable<FunctionType>(ty);
    auto pushGeneric = [&](TypeId t)
    {
        if (functionTy)
            functionTy->generics.push_back(t);
    };

    auto pushGenericPack = [&](TypePackId tp)
    {
        if (functionTy)
            functionTy->genericPacks.push_back(tp);
    };

    for (const auto& [freeTy, params] : fts.types)
    {
        if (!generalizationTarget || freeTy == *generalizationTarget)
        {
            GeneralizationResult<TypeId> res = generalizeType(arena, builtinTypes, scope, freeTy, params);

            if (res.resourceLimitsExceeded)
                return std::nullopt;

            if (res && res.wasReplacedByGeneric)
                pushGeneric(*res.result);
        }
    }

    for (TypeId unsealedTableTy : fts.unsealedTables)
    {
        if (!generalizationTarget || unsealedTableTy == *generalizationTarget)
            sealTable(scope, unsealedTableTy);
    }

    // Note: unlike types, packs are only generalized when no specific
    // generalization target was requested.
    for (const auto& [freePackId, params] : fts.typePacks)
    {
        TypePackId freePack = follow(freePackId);
        if (!generalizationTarget)
        {
            GeneralizationResult<TypePackId> generalizedTp = generalizeTypePack(arena, builtinTypes, scope, freePack, params);

            if (generalizedTp.resourceLimitsExceeded)
                return std::nullopt;

            if (generalizedTp && generalizedTp.wasReplacedByGeneric)
                pushGenericPack(freePack);
        }
    }

    // Record generalization-stable types so future runs can skip them.
    TypeCacher cacher{cachedTypes};
    cacher.traverse(ty);

    return ty;
}
964
965
// Counts how many times each pre-registered generic type/pack is referenced
// within a type, and at which polarity. Only keys placed into `generics` /
// `genericPacks` by the caller before traversal are counted. Traversal is
// bounded by FInt::LuauGenericCounterMaxSteps; if the budget is exhausted,
// `hitLimits` is set and the counts must not be trusted.
struct GenericCounter : TypeVisitor
{
    struct CounterState
    {
        size_t count = 0;
        Polarity polarity = Polarity::None;
    };

    // This traversal does need to walk into types multiple times because we
    // care about generics that are only referred to once. If a type is present
    // more than once, however, we don't care exactly how many times, so we also
    // track counts in our "seen set."
    DenseHashMap<TypeId, size_t> seenCounts{nullptr};

    NotNull<DenseHashSet<TypeId>> cachedTypes;
    DenseHashMap<TypeId, CounterState> generics{nullptr};
    DenseHashMap<TypePackId, CounterState> genericPacks{nullptr};

    Polarity polarity = Polarity::Positive;

    int steps = 0;
    bool hitLimits = false;

    explicit GenericCounter(NotNull<DenseHashSet<TypeId>> cachedTypes)
        : TypeVisitor("GenericCounter", /* skipBoundTypes */ true)
        , cachedTypes(cachedTypes)
    {
    }

    // Consume one unit of the traversal budget; latches `hitLimits` once
    // the step count exceeds FInt::LuauGenericCounterMaxSteps.
    void checkLimits()
    {
        steps++;
        hitLimits |= steps > FInt::LuauGenericCounterMaxSteps;
    }

    bool visit(TypeId ty) override
    {
        checkLimits();
        return !hitLimits;
    }

    bool visit(TypeId ty, const FunctionType& ft) override
    {
        checkLimits();

        if (ty->persistent)
            return false;

        // Visit any given function at most twice (see seenCounts note above).
        size_t& seenCount = seenCounts[ty];
        if (seenCount > 1)
            return false;

        ++seenCount;

        // Arguments are contravariant: invert polarity around the argument pack.
        polarity = invert(polarity);
        traverse(ft.argTypes);
        polarity = invert(polarity);
        traverse(ft.retTypes);

        return false;
    }

    bool visit(TypeId ty, const TableType& tt) override
    {
        checkLimits();

        if (ty->persistent)
            return false;

        size_t& seenCount = seenCounts[ty];
        if (seenCount > 1)
            return false;
        ++seenCount;

        const Polarity previous = polarity;

        // Property polarity mirrors FreeTypeSearcher: read-only keeps the
        // current polarity, write-only is negative, shared is mixed, and
        // distinct read/write types are traversed separately.
        for (const auto& [_name, prop] : tt.props)
        {
            if (prop.isReadOnly())
            {
                traverse(*prop.readTy);
            }
            else if (prop.isWriteOnly())
            {
                Polarity p = polarity;
                polarity = Polarity::Negative;
                traverse(*prop.writeTy);
                polarity = p;
            }
            else if (prop.isShared())
            {
                Polarity p = polarity;
                polarity = Polarity::Mixed;
                traverse(*prop.readTy);
                polarity = p;
            }
            else
            {
                LUAU_ASSERT(prop.isReadWrite() && !prop.isShared());

                traverse(*prop.readTy);
                Polarity p = polarity;
                polarity = Polarity::Negative;
                traverse(*prop.writeTy);
                polarity = p;
            }
        }

        if (tt.indexer)
        {
            // Indexer key and value both appear in read and write positions,
            // so they have mixed polarity (as in FreeTypeSearcher).
            polarity = Polarity::Mixed;
            traverse(tt.indexer->indexType);
            traverse(tt.indexer->indexResultType);
            polarity = previous;
        }

        return false;
    }

    bool visit(TypeId ty, const ExternType&) override
    {
        return false;
    }

    bool visit(TypeId ty, const GenericType&) override
    {
        // Only bump generics the caller registered ahead of time.
        auto state = generics.find(ty);
        if (state)
        {
            ++state->count;
            state->polarity |= polarity;
        }

        return false;
    }

    bool visit(TypePackId tp, const GenericTypePack&) override
    {
        // Only bump generic packs the caller registered ahead of time.
        auto state = genericPacks.find(tp);
        if (state)
        {
            ++state->count;
            state->polarity |= polarity;
        }

        return false;
    }
};
1114
1115
// Post-pass over a generalized function type: unnamed generics referenced
// exactly once (types additionally at a single, non-mixed polarity) are
// replaced with `unknown`/`unknown...`, and the function's generic lists are
// de-duplicated and cleaned of entries that are no longer actually generic.
// No-op for non-function types, foreign-arena types, and persistent types.
void pruneUnnecessaryGenerics(
    NotNull<TypeArena> arena,
    NotNull<BuiltinTypes> builtinTypes,
    NotNull<Scope> scope,
    NotNull<DenseHashSet<TypeId>> cachedTypes,
    TypeId ty
)
{
    ty = follow(ty);

    if (ty->owningArena != arena || ty->persistent)
        return;

    FunctionType* functionTy = getMutable<FunctionType>(ty);

    if (!functionTy)
        return;

    // If a generic has no explicit name and is only referred to in one place in
    // the function's signature, it can be replaced with unknown.

    // Register only the function's unnamed generic types with the counter.
    GenericCounter counter{cachedTypes};
    for (TypeId generic : functionTy->generics)
    {
        generic = follow(generic);
        auto g = get<GenericType>(generic);
        if (g && !g->explicitName)
            counter.generics[generic] = {};
    }

    // It is sometimes the case that a pack in the generic list will become a
    // pack that (transitively) has a generic tail. If it does, we need to add
    // that generic tail to the generic pack list.
    for (size_t i = 0; i < functionTy->genericPacks.size(); ++i)
    {
        TypePackId genericPack = follow(functionTy->genericPacks[i]);

        TypePackId tail = getTail(genericPack);

        if (tail != genericPack)
            functionTy->genericPacks.push_back(tail);

        if (auto g = get<GenericTypePack>(tail); g && !g->explicitName)
            counter.genericPacks[genericPack] = {};
    }

    counter.traverse(ty);

    // If traversal hit its step budget, the counts are incomplete, so every
    // count-based rewrite below is skipped.
    if (!counter.hitLimits)
    {
        for (const auto& [generic, state] : counter.generics)
        {
            if (state.count == 1 && state.polarity != Polarity::Mixed)
            {
                if (arena.get() != generic->owningArena)
                    continue;
                emplaceType<BoundType>(asMutable(generic), builtinTypes->unknownType);
            }
        }
    }

    // Remove duplicates and types that aren't actually generics.
    DenseHashSet<TypeId> seen{nullptr};
    auto it = std::remove_if(
        functionTy->generics.begin(),
        functionTy->generics.end(),
        [&](TypeId ty)
        {
            ty = follow(ty);
            if (seen.contains(ty))
                return true;
            seen.insert(ty);

            // Entries registered above but never visited (count == 0) are
            // unreferenced and can be dropped.
            if (!counter.hitLimits)
            {
                auto state = counter.generics.find(ty);
                if (state && state->count == 0)
                    return true;
            }

            return !get<GenericType>(ty);
        }
    );

    functionTy->generics.erase(it, functionTy->generics.end());

    // Generic packs referenced exactly once are bound to the unknown pack.
    if (!counter.hitLimits)
    {
        for (const auto& [genericPack, state] : counter.genericPacks)
        {
            if (state.count == 1)
                emplaceTypePack<BoundTypePack>(asMutable(genericPack), builtinTypes->unknownTypePack);
        }
    }

    // Same duplicate/non-generic cleanup for the generic pack list.
    DenseHashSet<TypePackId> seen2{nullptr};
    auto it2 = std::remove_if(
        functionTy->genericPacks.begin(),
        functionTy->genericPacks.end(),
        [&](TypePackId tp)
        {
            tp = follow(tp);
            if (seen2.contains(tp))
                return true;
            seen2.insert(tp);

            if (!counter.hitLimits)
            {
                auto state = counter.genericPacks.find(tp);
                if (state && state->count == 0)
                    return true;
            }

            return !get<GenericTypePack>(tp);
        }
    );

    functionTy->genericPacks.erase(it2, functionTy->genericPacks.end());
}
1236
1237
} // namespace Luau
1238
1239