Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Roblox
GitHub Repository: Roblox/luau
Path: blob/master/tests/Generalization.test.cpp
2723 views
1
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
2
3
#include "Luau/Generalization.h"
4
#include "Luau/Scope.h"
5
#include "Luau/ToString.h"
6
#include "Luau/Type.h"
7
#include "Luau/TypeArena.h"
8
#include "Luau/Error.h"
9
10
#include "Fixture.h"
11
#include "ScopedFlags.h"
12
13
#include "doctest.h"
14
15
using namespace Luau;
16
17
LUAU_FASTFLAG(DebugLuauForceOldSolver)
18
LUAU_FASTFLAG(DebugLuauForbidInternalTypes)
19
LUAU_FASTFLAG(LuauOverloadGetsInstantiated)
20
LUAU_FASTFLAG(LuauReplacerRespectsReboundGenerics)
21
22
TEST_SUITE_BEGIN("Generalization");
23
24
struct GeneralizationFixture
25
{
26
TypeArena arena;
27
BuiltinTypes builtinTypes;
28
ScopePtr globalScope = std::make_shared<Scope>(builtinTypes.anyTypePack);
29
ScopePtr scope = std::make_shared<Scope>(globalScope);
30
ToStringOptions opts;
31
32
DenseHashSet<TypeId> generalizedTypes_{nullptr};
33
NotNull<DenseHashSet<TypeId>> generalizedTypes{&generalizedTypes_};
34
35
ScopedFastFlag sff{FFlag::DebugLuauForceOldSolver, false};
36
37
std::pair<TypeId, FreeType*> freshType()
38
{
39
FreeType ft{scope.get(), builtinTypes.neverType, builtinTypes.unknownType};
40
41
TypeId ty = arena.addType(ft);
42
FreeType* ftv = getMutable<FreeType>(ty);
43
REQUIRE(ftv != nullptr);
44
45
return {ty, ftv};
46
}
47
48
std::string toString(TypeId ty)
49
{
50
return ::Luau::toString(ty, opts);
51
}
52
53
std::string toString(TypePackId ty)
54
{
55
return ::Luau::toString(ty, opts);
56
}
57
58
std::optional<TypeId> generalize(TypeId ty)
59
{
60
return ::Luau::generalize(NotNull{&arena}, NotNull{&builtinTypes}, NotNull{scope.get()}, generalizedTypes, ty);
61
}
62
};
63
64
TEST_CASE_FIXTURE(GeneralizationFixture, "generalize_a_type_that_is_bounded_by_another_generalizable_type")
65
{
66
auto [t1, ft1] = freshType();
67
auto [t2, ft2] = freshType();
68
69
// t2 <: t1 <: unknown
70
// unknown <: t2 <: t1
71
72
ft1->lowerBound = t2;
73
ft2->upperBound = t1;
74
ft2->lowerBound = builtinTypes.unknownType;
75
76
auto t2generalized = generalize(t2);
77
REQUIRE(t2generalized);
78
79
CHECK(follow(t1) == follow(t2));
80
81
auto t1generalized = generalize(t1);
82
REQUIRE(t1generalized);
83
84
CHECK(builtinTypes.unknownType == follow(t1));
85
CHECK(builtinTypes.unknownType == follow(t2));
86
}
87
88
// Same as generalize_a_type_that_is_bounded_by_another_generalizable_type
89
// except that we generalize the types in the opposite order
90
TEST_CASE_FIXTURE(GeneralizationFixture, "generalize_a_type_that_is_bounded_by_another_generalizable_type_in_reverse_order")
91
{
92
auto [t1, ft1] = freshType();
93
auto [t2, ft2] = freshType();
94
95
// t2 <: t1 <: unknown
96
// unknown <: t2 <: t1
97
98
ft1->lowerBound = t2;
99
ft2->upperBound = t1;
100
ft2->lowerBound = builtinTypes.unknownType;
101
102
auto t1generalized = generalize(t1);
103
REQUIRE(t1generalized);
104
105
CHECK(follow(t1) == follow(t2));
106
107
auto t2generalized = generalize(t2);
108
REQUIRE(t2generalized);
109
110
CHECK(builtinTypes.unknownType == follow(t1));
111
CHECK(builtinTypes.unknownType == follow(t2));
112
}
113
114
TEST_CASE_FIXTURE(GeneralizationFixture, "dont_traverse_into_class_types_when_generalizing")
115
{
116
auto [propTy, _] = freshType();
117
118
TypeId cursedExternType =
119
arena.addType(ExternType{"Cursed", {{"oh_no", Property::readonly(propTy)}}, std::nullopt, std::nullopt, {}, {}, "", {}});
120
121
auto genExternType = generalize(cursedExternType);
122
REQUIRE(genExternType);
123
124
auto genPropTy = get<ExternType>(*genExternType)->props.at("oh_no").readTy;
125
CHECK(is<FreeType>(*genPropTy));
126
}
127
128
TEST_CASE_FIXTURE(GeneralizationFixture, "cache_fully_generalized_types")
129
{
130
CHECK(generalizedTypes->empty());
131
132
TypeId tinyTable = arena.addType(
133
TableType{TableType::Props{{"one", builtinTypes.numberType}, {"two", builtinTypes.stringType}}, std::nullopt, TypeLevel{}, TableState::Sealed}
134
);
135
136
generalize(tinyTable);
137
138
CHECK(generalizedTypes->contains(tinyTable));
139
CHECK(generalizedTypes->contains(builtinTypes.numberType));
140
CHECK(generalizedTypes->contains(builtinTypes.stringType));
141
}
142
143
TEST_CASE_FIXTURE(GeneralizationFixture, "dont_cache_types_that_arent_done_yet")
144
{
145
TypeId freeTy = arena.addType(FreeType{NotNull{globalScope.get()}, builtinTypes.neverType, builtinTypes.stringType});
146
147
TypeId fnTy = arena.addType(FunctionType{builtinTypes.emptyTypePack, arena.addTypePack(TypePack{{builtinTypes.numberType}})});
148
149
TypeId tableTy = arena.addType(
150
TableType{TableType::Props{{"one", builtinTypes.numberType}, {"two", freeTy}, {"three", fnTy}}, std::nullopt, TypeLevel{}, TableState::Sealed}
151
);
152
153
generalize(tableTy);
154
155
CHECK(generalizedTypes->contains(fnTy));
156
CHECK(generalizedTypes->contains(builtinTypes.numberType));
157
CHECK(generalizedTypes->contains(builtinTypes.neverType));
158
CHECK(generalizedTypes->contains(builtinTypes.stringType));
159
CHECK(!generalizedTypes->contains(freeTy));
160
CHECK(!generalizedTypes->contains(tableTy));
161
}
162
163
TEST_CASE_FIXTURE(GeneralizationFixture, "functions_containing_cyclic_tables_can_be_cached")
164
{
165
TypeId selfTy = arena.addType(BlockedType{});
166
167
TypeId methodTy = arena.addType(
168
FunctionType{
169
arena.addTypePack({selfTy}),
170
arena.addTypePack({builtinTypes.numberType}),
171
}
172
);
173
174
asMutable(selfTy)->ty.emplace<TableType>(
175
TableType::Props{{"count", builtinTypes.numberType}, {"method", methodTy}}, std::nullopt, TypeLevel{}, TableState::Sealed
176
);
177
178
generalize(methodTy);
179
180
CHECK(generalizedTypes->contains(methodTy));
181
CHECK(generalizedTypes->contains(selfTy));
182
CHECK(generalizedTypes->contains(builtinTypes.numberType));
183
}
184
185
TEST_CASE_FIXTURE(GeneralizationFixture, "union_type_traversal_doesnt_crash")
186
{
187
// t1 where t1 = ('h <: (t1 <: 'i)) | ('j <: (t1 <: 'i))
188
TypeId i = arena.freshType(NotNull{&builtinTypes}, globalScope.get());
189
TypeId h = arena.freshType(NotNull{&builtinTypes}, globalScope.get());
190
TypeId j = arena.freshType(NotNull{&builtinTypes}, globalScope.get());
191
TypeId unionType = arena.addType(UnionType{{h, j}});
192
getMutable<FreeType>(h)->upperBound = i;
193
getMutable<FreeType>(h)->lowerBound = builtinTypes.neverType;
194
getMutable<FreeType>(i)->upperBound = builtinTypes.unknownType;
195
getMutable<FreeType>(i)->lowerBound = unionType;
196
getMutable<FreeType>(j)->upperBound = i;
197
getMutable<FreeType>(j)->lowerBound = builtinTypes.neverType;
198
199
generalize(unionType);
200
}
201
202
TEST_CASE_FIXTURE(GeneralizationFixture, "intersection_type_traversal_doesnt_crash")
203
{
204
// t1 where t1 = ('h <: (t1 <: 'i)) & ('j <: (t1 <: 'i))
205
TypeId i = arena.freshType(NotNull{&builtinTypes}, globalScope.get());
206
TypeId h = arena.freshType(NotNull{&builtinTypes}, globalScope.get());
207
TypeId j = arena.freshType(NotNull{&builtinTypes}, globalScope.get());
208
TypeId intersectionType = arena.addType(IntersectionType{{h, j}});
209
210
getMutable<FreeType>(h)->upperBound = i;
211
getMutable<FreeType>(h)->lowerBound = builtinTypes.neverType;
212
getMutable<FreeType>(i)->upperBound = builtinTypes.unknownType;
213
getMutable<FreeType>(i)->lowerBound = intersectionType;
214
getMutable<FreeType>(j)->upperBound = i;
215
getMutable<FreeType>(j)->lowerBound = builtinTypes.neverType;
216
217
generalize(intersectionType);
218
}
219
220
TEST_CASE_FIXTURE(GeneralizationFixture, "('a) -> 'a")
221
{
222
TypeId freeTy = freshType().first;
223
TypeId fnTy = arena.addType(FunctionType{arena.addTypePack({freeTy}), arena.addTypePack({freeTy})});
224
225
generalize(fnTy);
226
227
CHECK("<a>(a) -> a" == toString(fnTy));
228
}
229
230
TEST_CASE_FIXTURE(GeneralizationFixture, "(t1, (t1 <: 'b)) -> () where t1 = ('a <: (t1 <: 'b) & {number} & {number})")
231
{
232
TableType tt;
233
tt.indexer = TableIndexer{builtinTypes.numberType, builtinTypes.numberType};
234
TypeId numberArray = arena.addType(TableType{tt});
235
236
auto [aTy, aFree] = freshType();
237
auto [bTy, bFree] = freshType();
238
239
aFree->upperBound = arena.addType(IntersectionType{{bTy, numberArray, numberArray}});
240
bFree->lowerBound = aTy;
241
242
TypeId functionTy = arena.addType(FunctionType{arena.addTypePack({aTy, bTy}), builtinTypes.emptyTypePack});
243
244
generalize(functionTy);
245
246
CHECK("(unknown & {number}, unknown) -> ()" == toString(functionTy));
247
}
248
249
TEST_CASE_FIXTURE(GeneralizationFixture, "(('a <: number | string)) -> string?")
250
{
251
auto [aTy, aFree] = freshType();
252
253
aFree->upperBound = arena.addType(UnionType{{builtinTypes.numberType, builtinTypes.stringType}});
254
255
TypeId fnType = arena.addType(FunctionType{arena.addTypePack({aTy}), arena.addTypePack({builtinTypes.optionalStringType})});
256
257
generalize(fnType);
258
259
CHECK("(number | string) -> string?" == toString(fnType));
260
}
261
262
TEST_CASE_FIXTURE(GeneralizationFixture, "(('a <: {'b})) -> ()")
263
{
264
auto [aTy, aFree] = freshType();
265
auto [bTy, bFree] = freshType();
266
267
TableType tt;
268
tt.indexer = TableIndexer{builtinTypes.numberType, bTy};
269
270
aFree->upperBound = arena.addType(tt);
271
272
TypeId functionTy = arena.addType(FunctionType{arena.addTypePack({aTy}), builtinTypes.emptyTypePack});
273
274
generalize(functionTy);
275
276
// The free type 'b is not replace with unknown because it appears in an
277
// invariant context.
278
CHECK("<a>({a}) -> ()" == toString(functionTy));
279
}
280
281
TEST_CASE_FIXTURE(GeneralizationFixture, "(('b <: {t1}), ('a <: t1)) -> t1 where t1 = (('a <: t1) <: 'c)")
282
{
283
auto [aTy, aFree] = freshType();
284
auto [bTy, bFree] = freshType();
285
auto [cTy, cFree] = freshType();
286
287
aFree->upperBound = cTy;
288
cFree->lowerBound = aTy;
289
290
TableType tt;
291
tt.indexer = TableIndexer{builtinTypes.numberType, cTy};
292
293
bFree->upperBound = arena.addType(tt);
294
295
TypeId functionTy = arena.addType(FunctionType{arena.addTypePack({bTy, aTy}), arena.addTypePack({cTy})});
296
297
generalize(functionTy);
298
299
CHECK("<a>({a}, a) -> a" == toString(functionTy));
300
}
301
302
TEST_CASE_FIXTURE(BuiltinsFixture, "generalization_traversal_should_re_traverse_unions_if_they_change_type")
303
{
304
// This test case should just not assert
305
CheckResult result = check(R"(
306
function byId(p)
307
return p.id
308
end
309
310
function foo()
311
312
local productButtonPairs = {}
313
local func = byId
314
local dir = -1
315
316
local function updateSearch()
317
for product, button in pairs(productButtonPairs) do
318
button.LayoutOrder = func(product) * dir
319
end
320
end
321
322
function(mode)
323
if mode == 'Name'then
324
else
325
if mode == 'New'then
326
func = function(p)
327
return p.id
328
end
329
elseif mode == 'Price'then
330
func = function(p)
331
return p.price
332
end
333
end
334
335
end
336
end
337
end
338
)");
339
}
340
341
TEST_CASE_FIXTURE(BuiltinsFixture, "generalization_should_not_leak_free_type")
342
{
343
ScopedFastFlag _{FFlag::DebugLuauForbidInternalTypes, true};
344
345
// This test case should just not assert
346
CheckResult result = check(R"(
347
function foo()
348
349
local productButtonPairs = {}
350
local func
351
local dir = -1
352
353
local function updateSearch()
354
for product, button in pairs(productButtonPairs) do
355
-- This line may have a floating free type pack.
356
button.LayoutOrder = func(product) * dir
357
end
358
end
359
360
function(mode)
361
if mode == 'New'then
362
func = function(p)
363
return p.id
364
end
365
elseif mode == 'Price'then
366
func = function(p)
367
return p.price
368
end
369
end
370
end
371
end
372
)");
373
}
374
375
TEST_CASE_FIXTURE(Fixture, "generics_dont_leak_into_callback")
376
{
377
ScopedFastFlag _{FFlag::DebugLuauForceOldSolver, false};
378
379
LUAU_REQUIRE_NO_ERRORS(check(R"(
380
local func: <T>(T, (T) -> ()) -> () = nil :: any
381
func({}, function(obj)
382
local _ = obj
383
end)
384
)"));
385
386
// `unknown` is correct here
387
// - The lambda given can be generalized to `(unknown) -> ()`
388
// - We can substitute the `T` in `func` for either `{}` or `unknown` and
389
// still have a well typed program.
390
// We *probably* can do a better job bidirectionally inferring the types.
391
CHECK_EQ("unknown", toString(requireTypeAtPosition(Position{3, 23})));
392
}
393
394
TEST_CASE_FIXTURE(Fixture, "generics_dont_leak_into_callback_2")
395
{
396
ScopedFastFlag sffs[] = {
397
{FFlag::DebugLuauForceOldSolver, false},
398
{FFlag::LuauReplacerRespectsReboundGenerics, true},
399
{FFlag::LuauOverloadGetsInstantiated, true},
400
};
401
402
CheckResult result = check(R"(
403
local func: <T>(T, (T) -> ()) -> () = nil :: any
404
local foobar: (number) -> () = nil :: any
405
func({}, function(obj)
406
foobar(obj)
407
end)
408
)");
409
410
LUAU_REQUIRE_ERROR_COUNT(1, result);
411
auto err = get<TypeMismatch>(result.errors[0]);
412
REQUIRE(err);
413
CHECK_EQ("number", toString(err->wantedType));
414
CHECK_EQ("{ }", toString(err->givenType));
415
}
416
417
TEST_CASE_FIXTURE(Fixture, "generic_argument_with_singleton_oss_1808")
418
{
419
// All we care about here is that this has no errors, and we correctly
420
// infer that the `false` literal should be typed as `false`.
421
LUAU_REQUIRE_NO_ERRORS(check(R"(
422
local function test<T>(value: false | (T) -> T)
423
return value
424
end
425
test(false)
426
)"));
427
}
428
429
TEST_CASE_FIXTURE(BuiltinsFixture, "avoid_cross_module_mutation_in_bidirectional_inference")
430
{
431
fileResolver.source["Module/ListFns"] = R"(
432
local mod = {}
433
function mod.findWhere(list, predicate): number?
434
for i = 1, #list do
435
if predicate(list[i], i) then
436
return i
437
end
438
end
439
return nil
440
end
441
return mod
442
)";
443
444
fileResolver.source["Module/B"] = R"(
445
local funs = require(script.Parent.ListFns)
446
local accessories = funs.findWhere(getList(), function(accessory)
447
return accessory.AccessoryType ~= accessoryTypeEnum
448
end)
449
return {}
450
)";
451
452
CheckResult result = getFrontend().check("Module/ListFns");
453
auto modListFns = getFrontend().moduleResolver.getModule("Module/ListFns");
454
freeze(modListFns->interfaceTypes);
455
freeze(modListFns->internalTypes);
456
LUAU_REQUIRE_NO_ERRORS(result);
457
CheckResult result2 = getFrontend().check("Module/B");
458
LUAU_REQUIRE_NO_ERRORS(result);
459
}
460
461
TEST_CASE_FIXTURE(BuiltinsFixture, "generalization_fuzzer_crash")
462
{
463
LUAU_REQUIRE_ERRORS(check(R"(
464
type function t0<A>(l0,...):""
465
type t0 = any
466
do
467
_()
468
_ = {_=...,}
469
_ = {_=rawget({_=_,l0,},_,- _),}
470
end
471
end
472
)"));
473
}
474
475
476
TEST_SUITE_END();
477
478