#include "Luau/TableLiteralInference.h"
#include "Luau/Ast.h"
#include "Luau/Common.h"
#include "Luau/ConstraintSolver.h"
#include "Luau/HashUtil.h"
#include "Luau/Simplify.h"
#include "Luau/Subtyping.h"
#include "Luau/Type.h"
#include "Luau/ToString.h"
#include "Luau/TypeUtils.h"
#include "Luau/Unifier2.h"
namespace Luau
{
namespace
{
struct BidirectionalTypePusher
{
NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes;
NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes;
NotNull<ConstraintSolver> solver;
NotNull<const Constraint> constraint;
NotNull<DenseHashSet<const void*>> genericTypesAndPacks;
NotNull<Unifier2> unifier;
NotNull<Subtyping> subtyping;
std::vector<IncompleteInference> incompleteInferences;
DenseHashSet<std::pair<TypeId, const AstExpr*>, PairHash<TypeId, const AstExpr*>> seen{{nullptr, nullptr}};
BidirectionalTypePusher(
NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes,
NotNull<ConstraintSolver> solver,
NotNull<const Constraint> constraint,
NotNull<DenseHashSet<const void*>> genericTypesAndPacks,
NotNull<Unifier2> unifier,
NotNull<Subtyping> subtyping
)
: astTypes{astTypes}
, astExpectedTypes{astExpectedTypes}
, solver{solver}
, constraint{constraint}
, genericTypesAndPacks{genericTypesAndPacks}
, unifier{unifier}
, subtyping{subtyping}
{
}
TypeId pushType(TypeId expectedType, const AstExpr* expr)
{
(*astExpectedTypes)[expr] = expectedType;
if (!astTypes->contains(expr))
return solver->builtinTypes->anyType;
TypeId exprType = *astTypes->find(expr);
if (seen.contains({expectedType, expr}))
return exprType;
seen.insert({expectedType, expr});
expectedType = follow(expectedType);
exprType = follow(exprType);
if (auto tfit = get<TypeFunctionInstanceType>(expectedType); tfit && tfit->state == TypeFunctionInstanceState::Unsolved)
{
incompleteInferences.push_back(IncompleteInference{expectedType, exprType, expr});
return exprType;
}
if (is<BlockedType, PendingExpansionType>(expectedType))
{
incompleteInferences.push_back(IncompleteInference{expectedType, exprType, expr});
return exprType;
}
if (is<AnyType, UnknownType>(expectedType))
return exprType;
if (auto group = expr->as<AstExprGroup>())
{
pushType(expectedType, group->expr);
return exprType;
}
if (auto ternary = expr->as<AstExprIfElse>())
{
pushType(expectedType, ternary->trueExpr);
pushType(expectedType, ternary->falseExpr);
return exprType;
}
if (!isLiteral(expr))
return exprType;
if (expr->is<AstExprConstantString>() || expr->is<AstExprConstantNumber>() || expr->is<AstExprConstantBool>() ||
expr->is<AstExprConstantNil>())
{
if (auto ft = get<FreeType>(exprType))
{
if (maybeSingleton(expectedType) && maybeSingleton(ft->lowerBound))
{
solver->bind(constraint, exprType, ft->lowerBound);
return exprType;
}
Relation upperBoundRelation = relate(ft->upperBound, expectedType);
if (upperBoundRelation == Relation::Subset || upperBoundRelation == Relation::Coincident)
{
solver->bind(constraint, exprType, expectedType);
return exprType;
}
Relation lowerBoundRelation = relate(ft->lowerBound, expectedType);
if (lowerBoundRelation == Relation::Subset || lowerBoundRelation == Relation::Coincident)
{
solver->bind(constraint, exprType, expectedType);
return exprType;
}
}
}
if (auto exprLambda = expr->as<AstExprFunction>())
{
const auto lambdaTy = get<FunctionType>(exprType);
const auto expectedLambdaTy = get<FunctionType>(stripNil(solver->builtinTypes, *solver->arena, expectedType));
if (lambdaTy && expectedLambdaTy)
{
const auto& [lambdaArgTys, _lambdaTail] = flatten(lambdaTy->argTypes);
const auto& [expectedLambdaArgTys, _expectedLambdaTail] = flatten(expectedLambdaTy->argTypes);
auto limit = std::min({lambdaArgTys.size(), expectedLambdaArgTys.size(), exprLambda->args.size});
for (size_t argIndex = 0; argIndex < limit; argIndex++)
{
if (!exprLambda->args.data[argIndex]->annotation && get<FreeType>(follow(lambdaArgTys[argIndex])) &&
!containsGeneric(expectedLambdaArgTys[argIndex], NotNull{genericTypesAndPacks}))
solver->bind(NotNull{constraint}, lambdaArgTys[argIndex], expectedLambdaArgTys[argIndex]);
}
if (!exprLambda->returnAnnotation && get<FreeTypePack>(follow(lambdaTy->retTypes)) &&
!containsGeneric(expectedLambdaTy->retTypes, NotNull{genericTypesAndPacks}))
solver->bind(NotNull{constraint}, lambdaTy->retTypes, expectedLambdaTy->retTypes);
}
}
if (auto exprTable = expr->as<AstExprTable>())
{
const TableType* expectedTableTy = get<TableType>(expectedType);
if (!expectedTableTy)
{
if (auto utv = get<UnionType>(expectedType))
{
std::vector<TypeId> parts{begin(utv), end(utv)};
std::optional<TypeId> tt = extractMatchingTableType(parts, exprType, solver->builtinTypes);
if (tt)
(void)pushType(*tt, expr);
}
else if (auto itv = get<IntersectionType>(expectedType))
{
for (const auto part : itv)
(void)pushType(part, expr);
(*astExpectedTypes)[expr] = expectedType;
}
return exprType;
}
for (const AstExprTable::Item& item : exprTable->items)
{
if (isRecord(item))
{
const AstArray<char>& s = item.key->as<AstExprConstantString>()->value;
std::string keyStr{s.data, s.data + s.size};
auto it = expectedTableTy->props.find(keyStr);
if (it == expectedTableTy->props.end())
{
if (expectedTableTy->indexer)
(void)pushType(expectedTableTy->indexer->indexResultType, item.value);
continue;
}
LUAU_ASSERT(it != expectedTableTy->props.end());
const Property& expectedProp = it->second;
if (expectedProp.readTy)
(void)pushType(*expectedProp.readTy, item.value);
}
else if (item.kind == AstExprTable::Item::List)
{
if (expectedTableTy->indexer)
{
unifier->unify(expectedTableTy->indexer->indexType, solver->builtinTypes->numberType);
(void)pushType(expectedTableTy->indexer->indexResultType, item.value);
}
}
else if (item.kind == AstExprTable::Item::General)
{
if (expectedTableTy->indexer)
{
(void)pushType(expectedTableTy->indexer->indexType, item.key);
(void)pushType(expectedTableTy->indexer->indexResultType, item.value);
}
}
else
LUAU_ASSERT(!"Unexpected");
}
}
return exprType;
}
};
}
PushTypeResult pushTypeInto(
    NotNull<DenseHashMap<const AstExpr*, TypeId>> astTypes,
    NotNull<DenseHashMap<const AstExpr*, TypeId>> astExpectedTypes,
    NotNull<ConstraintSolver> solver,
    NotNull<const Constraint> constraint,
    NotNull<DenseHashSet<const void*>> genericTypesAndPacks,
    NotNull<Unifier2> unifier,
    NotNull<Subtyping> subtyping,
    TypeId expectedType,
    const AstExpr* expr
)
{
    // A fresh pusher per call, so its visited set only spans this one push.
    BidirectionalTypePusher pusher{astTypes, astExpectedTypes, solver, constraint, genericTypesAndPacks, unifier, subtyping};

    // Only the side effects matter here; the pushed-into type itself is not needed.
    (void)pusher.pushType(expectedType, expr);

    // Hand back every push that had to be deferred.
    PushTypeResult result{std::move(pusher.incompleteInferences)};
    return result;
}
}