#include "doctest.h"
#include "Fixture.h"
#include "Luau/OverloadResolution.h"
#include "Luau/Normalize.h"
#include "Luau/UnifierSharedState.h"
LUAU_FASTFLAG(DebugLuauForceOldSolver)
using namespace Luau;
// Test fixture for OverloadResolver. It owns a type arena, a Normalizer, a TypeFunctionRuntime,
// a root Scope and a ready-made `resolver`, plus small helpers for building function,
// intersection (meet), union (join) and callable-table (__call) types in that arena.
//
// Non-static data members are initialized in declaration order, and several initializers read
// members declared above them (e.g. `normalizer` uses `arena`/`sharedState`, `resolver` uses
// everything passed in mkResolver()). Do not reorder the members below.
struct OverloadResolverFixture : Fixture
{
// Backing storage for all types/packs created by the helpers; `arena` is a non-null view of it.
TypeArena arena_;
NotNull<TypeArena> arena{&arena_};
// NOTE(review): sharedState reports to `ice` (presumably inherited from Fixture), while the
// resolver and typeFunctionRuntime use the local `iceReporter`; confirm the split is intentional.
UnifierSharedState sharedState{&ice};
// Uses the new solver mode unless FFlag::DebugLuauForceOldSolver is set.
Normalizer normalizer{arena, getBuiltins(), NotNull{&sharedState}, !FFlag::DebugLuauForceOldSolver ? SolverMode::New : SolverMode::Old};
InternalErrorReporter iceReporter;
TypeCheckLimits limits;
TypeFunctionRuntime typeFunctionRuntime{NotNull{&iceReporter}, NotNull{&limits}};
Scope rootScope{getBuiltins()->emptyTypePack};
// Default-constructed location handed to every resolver built by mkResolver().
Location callLocation;
// Shared resolver; must stay declared after every member that mkResolver() reads.
OverloadResolver resolver = mkResolver();
// Builds a fresh OverloadResolver wired to this fixture's arena, normalizer, type-function
// runtime, root scope, error reporter, limits and call location.
OverloadResolver mkResolver()
{
return OverloadResolver{
getBuiltins(),
arena,
NotNull{&normalizer},
NotNull{&typeFunctionRuntime},
NotNull{&rootScope},
NotNull{&iceReporter},
NotNull{&limits},
callLocation
};
}
// Empty set passed as the `uniqueTypes` argument of resolveOverload in the tests.
// nullptr is passed to the DenseHashSet constructor (presumably its empty-key sentinel).
DenseHashSet<TypeId> kEmptySet{nullptr};
NotNull<DenseHashSet<TypeId>> emptySet{&kEmptySet};
// Dummy AST scaffolding; not referenced by the test cases in this file as shown.
Location kDummyLocation;
AstExprConstantNil kDummyExpr{kDummyLocation};
std::vector<AstExpr*> kEmptyExprs;
// Allocates a finite type pack holding `tys`.
TypePackId pack(std::vector<TypeId> tys) const
{
return arena->addTypePack(std::move(tys));
}
// Allocates a finite type pack holding `tys` (chosen for brace-list arguments like pack({a, b})).
TypePackId pack(std::initializer_list<TypeId> tys) const
{
return arena->addTypePack(tys);
}
// Allocates a pack of `tys` whose tail is a separately allocated pack built from `tail`.
TypePackId pack(std::initializer_list<TypeId> tys, TypePackVariant tail) const
{
return arena->addTypePack(tys, arena->addTypePack(std::move(tail)));
}
// Allocates the function type (args...) -> (rets...).
TypeId fn(std::initializer_list<TypeId> args, std::initializer_list<TypeId> rets) const
{
return arena->addType(FunctionType{pack(args), pack(rets)});
}
// Allocates the intersection a & b (an overload set when both parts are functions).
TypeId meet(TypeId a, TypeId b) const
{
return arena->addType(IntersectionType{{a, b}});
}
// Allocates the intersection of every type in `parts`.
TypeId meet(std::initializer_list<TypeId> parts) const
{
return arena->addType(IntersectionType{parts});
}
// Allocates the union a | b.
TypeId join(TypeId a, TypeId b) const
{
return arena->addType(UnionType{{a, b}});
}
// Allocates an empty sealed table whose sealed metatable has `__call` set to `callMm`,
// making the result callable through that metamethod.
TypeId tableWithCall(TypeId callMm) const
{
TypeId table = arena->addType(TableType{TableState::Sealed, TypeLevel{}, nullptr});
TypeId metatable = arena->addType(TableType{TableType::Props{{"__call", callMm}}, std::nullopt, TypeLevel{}, TableState::Sealed});
return arena->addType(MetatableType{table, metatable});
}
// Commonly used function types and overload sets. They are declared after `arena` (and after
// `resolver`), so they are allocated once the arena and resolver already exist.
const TypeId numberToNumber = fn({getBuiltins()->numberType}, {getBuiltins()->numberType});
const TypeId numberNumberToNumber = fn({getBuiltins()->numberType, getBuiltins()->numberType}, {getBuiltins()->numberType});
const TypeId numberToString = fn({getBuiltins()->numberType}, {getBuiltins()->stringType});
const TypeId stringToString = fn({getBuiltins()->stringType}, {getBuiltins()->stringType});
const TypeId numberToNumberAndStringToString = meet(numberToNumber, stringToString);
const TypeId numberToNumberAndNumberNumberToNumber = meet(numberToNumber, numberNumberToNumber);
};
// Every test case up to TEST_SUITE_END belongs to the "OverloadResolverTest" doctest suite.
TEST_SUITE_BEGIN("OverloadResolverTest");
TEST_CASE_FIXTURE(OverloadResolverFixture, "new_basic_overload_selection")
{
    // Calling (number -> number) & (string -> string) with a number must pick number -> number.
    const TypePackId numberArgs = pack({getBuiltins()->numberType});
    const OverloadResolution result = resolver.resolveOverload(numberToNumberAndStringToString, numberArgs, Location{}, emptySet, false);

    CHECK(result.ok.size() == 1);
    CHECK(numberToNumber == result.ok.at(0));
}
TEST_CASE_FIXTURE(OverloadResolverFixture, "new_basic_overload_selection1")
{
    // A string argument selects string -> string; number -> number is reported as incompatible.
    const TypePackId stringArgs = pack({getBuiltins()->stringType});
    const OverloadResolution result = resolver.resolveOverload(numberToNumberAndStringToString, stringArgs, Location{}, emptySet, false);

    CHECK(result.ok.size() == 1);
    CHECK(result.ok.at(0) == stringToString);

    CHECK(result.incompatibleOverloads.size() == 1);
    CHECK(result.incompatibleOverloads.at(0).first == numberToNumber);
}
TEST_CASE_FIXTURE(OverloadResolverFixture, "new_match_call_metamethod")
{
    // A table whose metatable defines __call: (unknown, number) -> number is callable with a
    // single number; the leading unknown parameter presumably receives the table itself.
    const TypeId numberTy = builtinTypes->numberType;
    const TypeId callMetamethod = fn({builtinTypes->unknownType, numberTy}, {numberTy});
    const TypeId callable = tableWithCall(callMetamethod);

    const OverloadResolution result = resolver.resolveOverload(callable, pack({numberTy}), Location{}, emptySet, false);

    CHECK(result.ok.size() == 1);
    CHECK(result.ok.at(0) == callMetamethod);
}
TEST_CASE_FIXTURE(OverloadResolverFixture, "new_metamethod_could_be_overloaded")
{
    // The __call metamethod is itself an overload set; a number argument must select the
    // number overload and report the string overload as incompatible.
    const TypeId unknownTy = builtinTypes->unknownType;
    const TypeId numberOverload = fn({unknownTy, builtinTypes->numberType}, {builtinTypes->numberType});
    const TypeId stringOverload = fn({unknownTy, builtinTypes->stringType}, {builtinTypes->stringType});
    const TypeId callable = tableWithCall(meet(numberOverload, stringOverload));

    const OverloadResolution result = resolver.resolveOverload(callable, pack({builtinTypes->numberType}), Location{}, emptySet, false);

    CHECK(result.ok.size() == 1);
    CHECK(result.ok.at(0) == numberOverload);
    CHECK(result.incompatibleOverloads.size() == 1);
    CHECK(result.incompatibleOverloads.at(0).first == stringOverload);
}
TEST_CASE_FIXTURE(OverloadResolverFixture, "new_overload_group_could_include_metamethod")
{
    // An intersection of (callable table with overloaded __call) & (boolean -> boolean): a
    // number argument must resolve to the number overload nested inside the metamethod.
    const TypeId unknownTy = builtinTypes->unknownType;
    const TypeId numberOverload = fn({unknownTy, builtinTypes->numberType}, {builtinTypes->numberType});
    const TypeId stringOverload = fn({unknownTy, builtinTypes->stringType}, {builtinTypes->stringType});
    const TypeId callableTable = tableWithCall(meet(numberOverload, stringOverload));
    const TypeId booleanToBoolean = fn({builtinTypes->booleanType}, {builtinTypes->booleanType});
    const TypeId monstrosity = meet(callableTable, booleanToBoolean);

    const OverloadResolution result = resolver.resolveOverload(monstrosity, pack({builtinTypes->numberType}), Location{}, emptySet, false);

    CHECK(result.ok.size() == 1);
    CHECK(result.ok.at(0) == numberOverload);
}
TEST_CASE_FIXTURE(OverloadResolverFixture, "new_overloads_with_different_arities")
{
    // One argument: the unary overload matches and the binary overload is an arity mismatch.
    const TypePackId oneNumber = pack({getBuiltins()->numberType});
    const OverloadResolution result =
        resolver.resolveOverload(numberToNumberAndNumberNumberToNumber, oneNumber, Location{}, emptySet, false);

    CHECK(result.ok.size() == 1);
    CHECK(result.ok.at(0) == numberToNumber);

    CHECK(result.arityMismatches.size() == 1);
    CHECK(result.arityMismatches.at(0) == numberNumberToNumber);
}
TEST_CASE_FIXTURE(OverloadResolverFixture, "new_overloads_with_different_arities1")
{
    // Two arguments: the binary overload matches and the unary overload is an arity mismatch.
    const TypePackId twoNumbers = pack({getBuiltins()->numberType, getBuiltins()->numberType});
    const OverloadResolution result =
        resolver.resolveOverload(numberToNumberAndNumberNumberToNumber, twoNumbers, Location{}, emptySet, false);

    CHECK(result.ok.size() == 1);
    CHECK(result.ok.at(0) == numberNumberToNumber);

    CHECK(result.arityMismatches.size() == 1);
    CHECK(result.arityMismatches.at(0) == numberToNumber);
}
TEST_CASE_FIXTURE(OverloadResolverFixture, "new_separate_non_viable_overloads_by_arity_mismatch")
{
    // Call (number -> number) & (number -> string) & ((number, number) -> number) with a single
    // string. Nothing succeeds; the binary overload must be classified as an arity mismatch, and
    // the two unary overloads as incompatible (wrong argument type).
    OverloadResolution resolution = resolver.resolveOverload(
        meet({numberToNumber, numberToString, numberNumberToNumber}), pack({builtinTypes->stringType}), Location{}, emptySet, false
    );

    CHECK(resolution.ok.empty());
    CHECK(resolution.nonFunctions.empty());

    CHECK_EQ(1, resolution.arityMismatches.size());
    // Bounds-checked access: CHECK_EQ above does not abort the test, so operator[] on an empty
    // vector would be undefined behaviour. `.at()` turns that case into a reported failure instead.
    CHECK_EQ(numberNumberToNumber, resolution.arityMismatches.at(0));

    // The incompatible overloads may be reported in any order, so look for each one.
    CHECK_EQ(2, resolution.incompatibleOverloads.size());
    bool numberToNumberFound = false;
    bool numberToStringFound = false;
    for (const auto& [ty, _] : resolution.incompatibleOverloads)
    {
        if (ty == numberToNumber)
            numberToNumberFound = true;
        else if (ty == numberToString)
            numberToStringFound = true;
    }
    CHECK(numberToNumberFound);
    CHECK(numberToStringFound);
}
TEST_CASE_FIXTURE(OverloadResolverFixture, "new_select")
{
    // Models `select`: <A...>(number | string, A...) -> ...any, called with (number | string, ...any).
    const TypeId numberOrString = join(builtinTypes->numberType, builtinTypes->stringType);
    const TypePackId variadicA = arena->addTypePack(GenericTypePack{"A"});
    const TypePackId selectParams = arena->addTypePack({numberOrString}, variadicA);
    const TypeId selectTy = arena->addType(FunctionType{{}, {variadicA}, selectParams, builtinTypes->anyTypePack});

    OverloadResolver freshResolver = mkResolver();
    const TypePackId callArgs = arena->addTypePack({numberOrString}, builtinTypes->anyTypePack);
    const OverloadResolution resolution = freshResolver.resolveOverload(selectTy, callArgs, Location{}, emptySet, false);

    CHECK(resolution.ok.size() == 1);
}
TEST_CASE_FIXTURE(OverloadResolverFixture, "new_pass_table_with_indexer")
{
    // A sealed table with an [any]: number indexer passed to a function expecting exactly that
    // table type must resolve cleanly, with nothing landing in any failure category.
    const TypeId indexedTable = arena->addType(
        TableType{TableType::Props{}, TableIndexer{builtinTypes->anyType, builtinTypes->numberType}, TypeLevel{}, &rootScope, TableState::Sealed}
    );
    const TypeId identityOnTable = fn({indexedTable}, {indexedTable});

    OverloadResolver freshResolver = mkResolver();
    const OverloadResolution resolution = freshResolver.resolveOverload(identityOnTable, pack({indexedTable}), Location{}, emptySet, false);

    CHECK(resolution.ok.size() == 1);
    CHECK(resolution.potentialOverloads.size() == 0);
    CHECK(resolution.incompatibleOverloads.size() == 0);
    CHECK(resolution.nonFunctions.size() == 0);
    CHECK(resolution.arityMismatches.size() == 0);
}
TEST_CASE_FIXTURE(OverloadResolverFixture, "generic_higher_order_function_called_improperly")
{
    // apply: <A, B..., C...>((A, B...) -> C..., A) -> C..., called with ((number, number) -> number, number).
    const TypeId elementA = arena->addType(GenericType{"A", Polarity::Mixed});
    const TypePackId restB = arena->addTypePack(GenericTypePack{"B"});
    const TypePackId resultsC = arena->addTypePack(GenericTypePack{"C"});

    const TypePackId callbackParams = arena->addTypePack({elementA}, restB);
    const TypeId callbackTy = arena->addType(FunctionType{callbackParams, resultsC});

    const TypePackId applyParams = pack({callbackTy, elementA});
    const TypeId applyTy = arena->addType(FunctionType{{elementA}, {restB, resultsC}, applyParams, resultsC});

    const TypePackId actualArgs = pack({numberNumberToNumber, builtinTypes->numberType});

    OverloadResolver freshResolver = mkResolver();
    const OverloadResolution resolution = freshResolver.resolveOverload(applyTy, actualArgs, Location{}, emptySet, false);

    CHECK(resolution.ok.size() == 1);
}
TEST_CASE_FIXTURE(OverloadResolverFixture, "debug_traceback")
{
    // debug.traceback has two overloads:
    //   (string?, number?) -> string
    //   (thread, string?, number?) -> string
    // Each subcase checks that exactly one overload accepts a given argument shape.
    const TypeId messageForm = fn({builtinTypes->optionalStringType, builtinTypes->optionalNumberType}, {builtinTypes->stringType});
    const TypeId threadForm =
        fn({builtinTypes->threadType, builtinTypes->optionalStringType, builtinTypes->optionalNumberType}, {builtinTypes->stringType});
    const TypeId debugTraceback = meet({messageForm, threadForm});

    OverloadResolver freshResolver = mkResolver();
    auto resolveWith = [&](TypePackId callArgs)
    {
        return freshResolver.resolveOverload(debugTraceback, callArgs, Location{}, emptySet, false);
    };

    SUBCASE("no_arguments")
    {
        const OverloadResolution resolution = resolveWith(builtinTypes->emptyTypePack);
        CHECK(resolution.ok.size() == 1);
    }
    SUBCASE("message_only")
    {
        const OverloadResolution resolution = resolveWith(pack({builtinTypes->stringType}));
        CHECK(resolution.ok.size() == 1);
    }
    SUBCASE("message_and_level")
    {
        const OverloadResolution resolution = resolveWith(pack({builtinTypes->stringType, builtinTypes->numberType}));
        CHECK(resolution.ok.size() == 1);
    }
    SUBCASE("thread")
    {
        const OverloadResolution resolution = resolveWith(pack({builtinTypes->threadType}));
        CHECK(resolution.ok.size() == 1);
    }
    SUBCASE("thread_and_message")
    {
        const OverloadResolution resolution = resolveWith(pack({builtinTypes->threadType, builtinTypes->stringType}));
        CHECK(resolution.ok.size() == 1);
    }
    SUBCASE("thread_message_and_level")
    {
        const OverloadResolution resolution =
            resolveWith(pack({builtinTypes->threadType, builtinTypes->stringType, builtinTypes->numberType}));
        CHECK(resolution.ok.size() == 1);
    }
}
// Closes the "OverloadResolverTest" doctest suite.
TEST_SUITE_END();