Book a Demo!
CoCalc Logo Icon
Store | Features | Docs | Share | Support | News | About | Policies | Sign Up | Sign In
Roblox
GitHub Repository: Roblox/luau
Path: blob/master/Analysis/src/TableLiteralInference.cpp
2725 views
1
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
2
3
#include "Luau/TableLiteralInference.h"
4
5
#include "Luau/Ast.h"
6
#include "Luau/Common.h"
7
#include "Luau/ConstraintSolver.h"
8
#include "Luau/HashUtil.h"
9
#include "Luau/Simplify.h"
10
#include "Luau/Subtyping.h"
11
#include "Luau/Type.h"
12
#include "Luau/ToString.h"
13
#include "Luau/TypeUtils.h"
14
#include "Luau/Unifier2.h"
15
16
namespace Luau
17
{
18
19
namespace
20
{
21
22
struct BidirectionalTypePusher
23
{
24
25
NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes;
26
NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes;
27
28
NotNull<ConstraintSolver> solver;
29
NotNull<const Constraint> constraint;
30
NotNull<DenseHashSet<const void*>> genericTypesAndPacks;
31
NotNull<Unifier2> unifier;
32
NotNull<Subtyping> subtyping;
33
34
std::vector<IncompleteInference> incompleteInferences;
35
36
DenseHashSet<std::pair<TypeId, const AstExpr*>, PairHash<TypeId, const AstExpr*>> seen{{nullptr, nullptr}};
37
38
BidirectionalTypePusher(
39
NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
40
NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes,
41
NotNull<ConstraintSolver> solver,
42
NotNull<const Constraint> constraint,
43
NotNull<DenseHashSet<const void*>> genericTypesAndPacks,
44
NotNull<Unifier2> unifier,
45
NotNull<Subtyping> subtyping
46
)
47
: astTypes{astTypes}
48
, astExpectedTypes{astExpectedTypes}
49
, solver{solver}
50
, constraint{constraint}
51
, genericTypesAndPacks{genericTypesAndPacks}
52
, unifier{unifier}
53
, subtyping{subtyping}
54
{
55
}
56
57
TypeId pushType(TypeId expectedType, const AstExpr* expr)
58
{
59
(*astExpectedTypes)[expr] = expectedType;
60
// We may not have a type here if this is the last argument
61
// passed to a function call: this is potentially expected
62
// behavior.
63
if (!astTypes->contains(expr))
64
return solver->builtinTypes->anyType;
65
66
TypeId exprType = *astTypes->find(expr);
67
68
if (seen.contains({expectedType, expr}))
69
return exprType;
70
seen.insert({expectedType, expr});
71
72
expectedType = follow(expectedType);
73
exprType = follow(exprType);
74
75
// NOTE: We cannot block on free types here, as that trivially means
76
// any recursive function would have a cycle, consider:
77
//
78
// local function fact(n)
79
// return if n < 2 then 1 else n * fact(n - 1)
80
// end
81
//
82
// We'll have a cycle between trying to push `fact`'s type into its
83
// arguments and generalizing `fact`.
84
85
if (auto tfit = get<TypeFunctionInstanceType>(expectedType); tfit && tfit->state == TypeFunctionInstanceState::Unsolved)
86
{
87
incompleteInferences.push_back(IncompleteInference{expectedType, exprType, expr});
88
return exprType;
89
}
90
91
if (is<BlockedType, PendingExpansionType>(expectedType))
92
{
93
incompleteInferences.push_back(IncompleteInference{expectedType, exprType, expr});
94
return exprType;
95
}
96
97
if (is<AnyType, UnknownType>(expectedType))
98
return exprType;
99
100
if (auto group = expr->as<AstExprGroup>())
101
{
102
pushType(expectedType, group->expr);
103
return exprType;
104
}
105
106
if (auto ternary = expr->as<AstExprIfElse>())
107
{
108
pushType(expectedType, ternary->trueExpr);
109
pushType(expectedType, ternary->falseExpr);
110
return exprType;
111
}
112
113
if (!isLiteral(expr))
114
// NOTE: For now we aren't using the result of this function, so
115
// just return the original expression type.
116
return exprType;
117
118
if (expr->is<AstExprConstantString>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantBool>() ||
119
expr->is<AstExprConstantNil>())
120
{
121
if (auto ft = get<FreeType>(exprType))
122
{
123
if (maybeSingleton(expectedType) && maybeSingleton(ft->lowerBound))
124
{
125
// If we see a pattern like:
126
//
127
// local function foo<T>(my_enum: "foo" | "bar" | T) -> T
128
// return my_enum
129
// end
130
// local var = foo("meow")
131
//
132
// ... where we are attempting to push a singleton onto any string
133
// literal, and the lower bound is still a singleton, then snap
134
// to said lower bound.
135
solver->bind(constraint, exprType, ft->lowerBound);
136
return exprType;
137
}
138
139
// if the upper bound is a subtype of the expected type, we can push the expected type in
140
Relation upperBoundRelation = relate(ft->upperBound, expectedType);
141
if (upperBoundRelation == Relation::Subset || upperBoundRelation == Relation::Coincident)
142
{
143
solver->bind(constraint, exprType, expectedType);
144
return exprType;
145
}
146
147
// likewise, if the lower bound is a subtype, we can force the expected type in
148
// if this is the case and the previous relation failed, it means that the primitive type
149
// constraint was going to have to select the lower bound for this type anyway.
150
Relation lowerBoundRelation = relate(ft->lowerBound, expectedType);
151
if (lowerBoundRelation == Relation::Subset || lowerBoundRelation == Relation::Coincident)
152
{
153
solver->bind(constraint, exprType, expectedType);
154
return exprType;
155
}
156
}
157
}
158
159
if (auto exprLambda = expr->as<AstExprFunction>())
160
{
161
const auto lambdaTy = get<FunctionType>(exprType);
162
const auto expectedLambdaTy = get<FunctionType>(stripNil(solver->builtinTypes, *solver->arena, expectedType));
163
if (lambdaTy && expectedLambdaTy)
164
{
165
const auto& [lambdaArgTys, _lambdaTail] = flatten(lambdaTy->argTypes);
166
const auto& [expectedLambdaArgTys, _expectedLambdaTail] = flatten(expectedLambdaTy->argTypes);
167
168
auto limit = std::min({lambdaArgTys.size(), expectedLambdaArgTys.size(), exprLambda->args.size});
169
for (size_t argIndex = 0; argIndex < limit; argIndex++)
170
{
171
if (!exprLambda->args.data[argIndex]->annotation && get<FreeType>(follow(lambdaArgTys[argIndex])) &&
172
!containsGeneric(expectedLambdaArgTys[argIndex], NotNull{genericTypesAndPacks}))
173
solver->bind(NotNull{constraint}, lambdaArgTys[argIndex], expectedLambdaArgTys[argIndex]);
174
}
175
176
if (!exprLambda->returnAnnotation && get<FreeTypePack>(follow(lambdaTy->retTypes)) &&
177
!containsGeneric(expectedLambdaTy->retTypes, NotNull{genericTypesAndPacks}))
178
solver->bind(NotNull{constraint}, lambdaTy->retTypes, expectedLambdaTy->retTypes);
179
}
180
}
181
182
183
// TODO: CLI-169235: This probably ought to use the same logic as
184
// `index` to determine what the type of a given member is.
185
if (auto exprTable = expr->as<AstExprTable>())
186
{
187
const TableType* expectedTableTy = get<TableType>(expectedType);
188
189
if (!expectedTableTy)
190
{
191
if (auto utv = get<UnionType>(expectedType))
192
{
193
std::vector<TypeId> parts{begin(utv), end(utv)};
194
195
std::optional<TypeId> tt = extractMatchingTableType(parts, exprType, solver->builtinTypes);
196
197
if (tt)
198
(void)pushType(*tt, expr);
199
}
200
else if (auto itv = get<IntersectionType>(expectedType))
201
{
202
for (const auto part : itv)
203
(void)pushType(part, expr);
204
205
// Reset the expected type for this expression prior,
206
// otherwise the expected type will be the last part
207
// of the intersection, which does not seem ideal.
208
(*astExpectedTypes)[expr] = expectedType;
209
}
210
211
return exprType;
212
}
213
214
for (const AstExprTable::Item& item : exprTable->items)
215
{
216
if (isRecord(item))
217
{
218
const AstArray<char>& s = item.key->as<AstExprConstantString>()->value;
219
std::string keyStr{s.data, s.data + s.size};
220
auto it = expectedTableTy->props.find(keyStr);
221
222
if (it == expectedTableTy->props.end())
223
{
224
// If we have some type:
225
//
226
// { [T]: U }
227
//
228
// ... that we're trying to push into ...
229
//
230
// { foo = bar }
231
//
232
// Then the intent is probably to push `U` into `bar`.
233
if (expectedTableTy->indexer)
234
(void)pushType(expectedTableTy->indexer->indexResultType, item.value);
235
236
// If it's just an extra property and the expected type
237
// has no indexer, there's no work to do here.
238
continue;
239
}
240
241
LUAU_ASSERT(it != expectedTableTy->props.end());
242
243
const Property& expectedProp = it->second;
244
245
if (expectedProp.readTy)
246
(void)pushType(*expectedProp.readTy, item.value);
247
248
// NOTE: We do *not* add to the potential indexer types here.
249
// I think this is correct to support something like:
250
//
251
// { [string]: number, foo: boolean }
252
//
253
// NOTE: We also do nothing for write properties.
254
}
255
else if (item.kind == AstExprTable::Item::List)
256
{
257
if (expectedTableTy->indexer)
258
{
259
unifier->unify(expectedTableTy->indexer->indexType, solver->builtinTypes->numberType);
260
(void)pushType(expectedTableTy->indexer->indexResultType, item.value);
261
}
262
}
263
else if (item.kind == AstExprTable::Item::General)
264
{
265
266
// We have { ..., [blocked] : somePropExpr, ...}
267
// If blocked resolves to a string, we will then take care of this above
268
// If it resolves to some other kind of expression, we don't have a way of folding this information into indexer
269
// because there is no named prop to remove
270
// We should just block here
271
if (expectedTableTy->indexer)
272
{
273
(void)pushType(expectedTableTy->indexer->indexType, item.key);
274
(void)pushType(expectedTableTy->indexer->indexResultType, item.value);
275
}
276
}
277
else
278
LUAU_ASSERT(!"Unexpected");
279
}
280
}
281
282
return exprType;
283
}
284
};
285
} // namespace
286
287
// Public entry point: pushes `expectedType` into `expr` using a freshly
// constructed BidirectionalTypePusher and returns the pushes that could not
// be completed. The expression type computed by the push is discarded.
PushTypeResult pushTypeInto(
    NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
    NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes,
    NotNull<ConstraintSolver> solver,
    NotNull<const Constraint> constraint,
    NotNull<DenseHashSet<const void*>> genericTypesAndPacks,
    NotNull<Unifier2> unifier,
    NotNull<Subtyping> subtyping,
    TypeId expectedType,
    const AstExpr* expr
)
{
    BidirectionalTypePusher pusher{astTypes, astExpectedTypes, solver, constraint, genericTypesAndPacks, unifier, subtyping};

    // Only the side effects and the list of incomplete inferences matter here.
    (void)pusher.pushType(expectedType, expr);

    // The pusher is about to go out of scope, so hand its list over by move.
    return PushTypeResult{std::move(pusher.incompleteInferences)};
}
303
304
} // namespace Luau
305
306