#include "Luau/OverloadResolution.h"
#include "Luau/Common.h"
#include "Luau/Instantiation2.h"
#include "Luau/Subtyping.h"
#include "Luau/TxnLog.h"
#include "Luau/Type.h"
#include "Luau/TypeFunction.h"
#include "Luau/TypePack.h"
#include "Luau/TypePath.h"
#include "Luau/TypeUtils.h"
#include "Luau/Unifier2.h"
namespace Luau
{
// Pick a single overload out of the resolution buckets, if one can be chosen.
//
// - More than one definitely-ok overload: nothing is selected.
// - Exactly one ok overload: it wins; the selection is flagged ambiguous (third field = true) if any
//   potential overloads also exist.
// - No ok overload: the first potential overload is taken (with its constraints), flagged ambiguous
//   when there was more than one potential overload.
// - Neither ok nor potential overloads: a lone incompatible overload is returned (so callers can report
//   its errors); otherwise nothing is selected.
SelectedOverload OverloadResolution::getUnambiguousOverload() const
{
    const size_t okCount = ok.size();
    const size_t potentialCount = potentialOverloads.size();

    if (okCount > 1)
        return {std::nullopt, {}, false};

    if (okCount == 1)
        return {ok.front(), {}, potentialCount > 0};

    if (potentialCount > 0)
    {
        const auto& firstPotential = potentialOverloads.front();
        return {firstPotential.first, firstPotential.second, potentialCount > 1};
    }

    LUAU_ASSERT(okCount + potentialCount == 0);

    if (incompatibleOverloads.size() == 1)
        return {incompatibleOverloads.front().first, {}, false};

    return {std::nullopt, {}, false};
}
// Construct a resolver for one call site.
//
// All collaborators are stored as non-owning NotNull handles (presumably they must outlive the resolver —
// confirm against callers). `callLocation` is used for type-function reduction and metatable lookups.
OverloadResolver::OverloadResolver(
    NotNull<BuiltinTypes> builtinTypes,
    NotNull<TypeArena> arena,
    NotNull<Normalizer> normalizer,
    NotNull<TypeFunctionRuntime> typeFunctionRuntime,
    NotNull<Scope> scope,
    NotNull<InternalErrorReporter> reporter,
    NotNull<TypeCheckLimits> limits,
    Location callLocation
)
    : builtinTypes(builtinTypes)
    , arena(arena)
    , normalizer(normalizer)
    , typeFunctionRuntime(typeFunctionRuntime)
    , scope(scope)
    , ice(reporter)
    , limits(limits)
    // Pass the `reporter` parameter, not the `ice` member: members are initialized in declaration order,
    // so reading `ice` here is only safe if it happens to be declared before `subtyping`.
    , subtyping({builtinTypes, arena, normalizer, typeFunctionRuntime, reporter})
    , callLoc(callLocation)
{
}
// True when `path` is non-empty and its first step is `PackField::Returns`,
// i.e. the path points somewhere inside a function's return pack.
static bool reasoningIsReturnTypes(const Path& path)
{
    if (path.empty())
        return false;

    const auto* leading = get_if<TypePath::PackField>(&path.components.front());
    if (!leading)
        return false;

    return *leading == TypePath::PackField::Returns;
}
// Drop every reasoning whose sub- and super-paths both lie in the return pack, since overload selection
// only cares about arguments. If nothing else is left to explain a failure (no remaining reasonings,
// generic-bounds mismatches, or errors), the result is upgraded to a successful subtype check.
static void ignoreReasoningForReturnType(SubtypingResult& sr)
{
    SubtypingReasonings kept{kEmptyReasoning};
    for (const SubtypingReasoning& candidate : sr.reasoning)
    {
        const bool onlyAboutReturns = reasoningIsReturnTypes(candidate.subPath) && reasoningIsReturnTypes(candidate.superPath);
        if (!onlyAboutReturns)
            kept.insert(candidate);
    }

    std::swap(sr.reasoning, kept);

    const bool nothingLeftToExplain = sr.reasoning.empty() && sr.genericBoundsMismatches.empty() && sr.errors.empty();
    if (nothingLeftToExplain)
        sr.isSubtype = true;
}
static bool areUnsatisfiedArgumentsOptional(const SubtypingReasonings& reasonings, TypePackId argPack, TypePackId funcArgPack)
{
if (1 != reasonings.size())
return false;
const TypePath::Path justArguments{TypePath::PackField::Arguments};
const auto& reason = *reasonings.begin();
if (reason.subPath != justArguments || reason.superPath != justArguments)
return false;
const auto [argHead, argTail] = flatten(argPack);
const auto [funArgHead, funArgTail] = flatten(funcArgPack);
if (argHead.size() >= funArgHead.size())
return false;
for (size_t i = argHead.size(); i < funArgHead.size(); ++i)
{
if (!isOptional(funArgHead[i]))
return false;
}
return true;
}
// Test every overload of `ty` against `argsPack` and bucket the outcomes into an OverloadResolution.
//
// An intersection type is treated as an overload set: each component is tested in iteration order.
// Any other type (including a union) is tested as a single candidate. `useFreeTypeBounds` is not read
// by this implementation.
OverloadResolution OverloadResolver::resolveOverload(
    TypeId ty,
    TypePackId argsPack,
    Location fnLocation,
    NotNull<DenseHashSet<TypeId>> uniqueTypes,
    bool useFreeTypeBounds
)
{
    OverloadResolution resolution;
    const TypeId callee = follow(ty);

    const IntersectionType* overloadSet = get<IntersectionType>(callee);
    if (overloadSet == nullptr)
    {
        testFunctionOrUnion(resolution, callee, argsPack, fnLocation, uniqueTypes);
        return resolution;
    }

    for (TypeId overload : overloadSet)
        testFunctionOrUnion(resolution, overload, argsPack, fnLocation, uniqueTypes);

    return resolution;
}
// True when `path` refers to the argument pack as a whole, rather than to one element inside it.
//
// The first component, if it is a PackField, must be `Arguments`; other leading component kinds are
// accepted. Every later component must be a PackSlice, a GenericPackMapping, or `PackField::Tail`.
// An empty path is rejected.
static bool isPathOnArgumentList(const Path& path)
{
    const auto& components = path.components;
    if (components.empty())
        return false;

    if (const auto* head = get_if<TypePath::PackField>(&components.front()); head && *head != TypePath::PackField::Arguments)
        return false;

    for (size_t i = 1; i < components.size(); ++i)
    {
        const auto& component = components[i];

        if (get_if<TypePath::PackSlice>(&component) || get_if<TypePath::GenericPackMapping>(&component))
            continue;

        const auto* field = get_if<TypePath::PackField>(&component);
        if (field == nullptr || *field != TypePath::PackField::Tail)
            return false;
    }

    return true;
}
static std::optional<size_t> getArgumentIndex(const Path& path, TypeId fnTy)
{
auto iter = begin(path.components);
const auto endIter = end(path.components);
if (iter == endIter)
return std::nullopt;
if (auto args = get_if<TypePath::PackField>(&*iter); args && *args != TypePath::PackField::Arguments)
return std::nullopt;
++iter;
const FunctionType* ft = get<FunctionType>(fnTy);
LUAU_ASSERT(fnTy);
size_t result = 0;
TypeOrPack ty = ft->argTypes;
while (iter != endIter)
{
const auto& component = *iter;
++iter;
if (auto index = get_if<TypePath::Index>(&component))
return result + index->index;
else if (auto subst = get_if<TypePath::GenericPackMapping>(&component))
ty = subst->mappedType;
else if (auto slice = get_if<TypePath::PackSlice>(&component))
result += slice->start_index;
else if (auto packField = get_if<TypePath::PackField>(&component); packField && *packField == TypePath::PackField::Tail)
{
TypePackId* tp = get_if<TypePackId>(&ty);
LUAU_ASSERT(tp);
if (!tp)
return std::nullopt;
auto packIter = begin(*tp);
auto packEndIter = end(*tp);
while (packIter != packEndIter)
{
result += 1;
++packIter;
}
if (!packIter.tail())
return std::nullopt;
ty = *packIter.tail();
continue;
}
else
return std::nullopt;
}
return std::nullopt;
}
// Translate one failed SubtypingReasoning from testing the call `fnTy(argPack)` into errors appended to
// `errors`.
//
// Parameters:
//   errors      output vector; errors are appended, never cleared.
//   fnTy        the overload that was tested (the subtype side of the check in testFunction).
//   fnLocation  location of the call; the fallback when no argument location applies.
//   moduleName  module name attached to every emitted error.
//   argPack     the argument pack supplied at the call site.
//   argExprs    argument expressions, used to pinpoint the offending argument.
//   reason      the subtyping failure to report.
//
// Depending on the reasoning path, the function emits one of the following:
// - a pack mismatch against a generic pack;
// - an arity/variadic error for whole-argument-list failures;
// - a per-argument type mismatch;
// - a whole-pack mismatch.
void OverloadResolver::reportErrors(
    ErrorVec& errors,
    TypeId fnTy,
    Location fnLocation,
    const ModuleName& moduleName,
    TypePackId argPack,
    const std::vector<AstExpr*>& argExprs,
    const SubtypingReasoning& reason
) const
{
    // Pick the most precise location available, in this order: the argument expression the reasoning
    // path identifies, then the last argument, then the call itself. Nothing below reassigns this, so
    // every later branch shares it.
    std::optional<size_t> argumentIndex = getArgumentIndex(reason.subPath, fnTy);
    Location argLocation;
    if (argumentIndex && *argumentIndex < argExprs.size())
        argLocation = argExprs.at(*argumentIndex)->location;
    else if (argExprs.size() != 0)
        argLocation = argExprs.back()->location;
    else
        argLocation = fnLocation;

    // Note the cross-over in naming:
    // - the call-site argument pack, reached via superPath on the prospective function, is the "sub" pack;
    // - the overload's parameter pack, reached via subPath on fnTy, is the "super" pack.
    const TypeId prospectiveFunction = arena->addType(FunctionType{argPack, builtinTypes->anyTypePack});
    std::optional<TypePackId> failedSubPack = traverseForPack(prospectiveFunction, reason.superPath, builtinTypes, arena);
    std::optional<TypePackId> failedSuperPack = traverseForPack(fnTy, reason.subPath, builtinTypes, arena);

    // A failure against a generic pack parameter is reported directly as a pack mismatch.
    if (failedSuperPack && get<GenericTypePack>(*failedSuperPack))
    {
        maybeEmplaceError(&errors, argLocation, moduleName, &reason, failedSuperPack, failedSubPack.value_or(builtinTypes->emptyTypePack));
        return;
    }

    // The failure concerns the argument list as a whole: report an arity problem (or a pack mismatch for
    // variadic argument tails).
    if (isPathOnArgumentList(reason.subPath))
    {
        if (!failedSuperPack)
        {
            errors.emplace_back(fnLocation, moduleName, InternalError{"Malformed SubtypingReasoning"});
            return;
        }

        const TypePackId requiredMappedArgs = arena->addTypePack(traverseForFlattenedPack(fnTy, reason.subPath, builtinTypes, arena));
        const auto [paramsHead, paramsTail] = flatten(requiredMappedArgs);
        const auto [argHead, argTail] = flatten(argPack);

        const size_t argCount = argHead.size();

        auto [minParams, optMaxParams] = getParameterExtents(TxnLog::empty(), requiredMappedArgs);

        switch (shouldSuppressErrors(normalizer, argPack))
        {
        case ErrorSuppression::Suppress:
            return;
        case ErrorSuppression::DoNotSuppress:
            break;
        case ErrorSuppression::NormalizationFailed:
            errors.emplace_back(fnLocation, moduleName, NormalizationTooComplex{});
            return;
        }

        // failedSuperPack is known to be engaged here (the early return above), so this check is unconditional.
        switch (shouldSuppressErrors(normalizer, requiredMappedArgs))
        {
        case ErrorSuppression::Suppress:
            return;
        case ErrorSuppression::DoNotSuppress:
            break;
        case ErrorSuppression::NormalizationFailed:
            errors.emplace_back(fnLocation, moduleName, NormalizationTooComplex{});
            return;
        }

        const bool isVariadic = argTail && Luau::isVariadic(*argTail);
        if (isVariadic)
        {
            maybeEmplaceError(&errors, argLocation, moduleName, &reason, failedSuperPack, failedSubPack.value_or(builtinTypes->emptyTypePack));
        }
        else
            errors.emplace_back(fnLocation, moduleName, CountMismatch{paramsHead.size(), optMaxParams, argCount, CountMismatch::Arg, isVariadic});
        return;
    }

    // The failure points at one specific argument: report a type mismatch at that argument.
    // argLocation already reflects argumentIndex (computed at the top of this function).
    if (argumentIndex)
    {
        LUAU_ASSERT(reason.subPath.components.size() > 1);

        // Drop the leading component of superPath so the rest can be traversed directly on argPack
        // (presumably that component is PackField::Arguments — TODO confirm).
        Path superPathTail = reason.superPath;
        superPathTail.components.erase(superPathTail.components.begin());

        std::optional<TypeOrPack> failedSub = traverse(argPack, superPathTail, builtinTypes, arena);
        std::optional<TypeOrPack> failedSuper = traverse(fnTy, reason.subPath, builtinTypes, arena);
        maybeEmplaceError(&errors, argLocation, moduleName, &reason, failedSuper, failedSub);
        return;
    }

    if (failedSubPack && !failedSuperPack && get<GenericTypePack>(*failedSubPack))
    {
        errors.emplace_back(argLocation, moduleName, TypePackMismatch{*failedSubPack, builtinTypes->emptyTypePack});
    }

    if (failedSubPack && failedSuperPack)
    {
        // argumentIndex is empty on this path, so argLocation is already the last argument
        // (or the call location when there are no arguments).
        auto errorSuppression = shouldSuppressErrors(normalizer, *failedSubPack).orElse(shouldSuppressErrors(normalizer, *failedSuperPack));
        if (errorSuppression == ErrorSuppression::Suppress)
            return;

        switch (reason.variance)
        {
        case SubtypingVariance::Covariant:
            errors.emplace_back(argLocation, moduleName, TypePackMismatch{*failedSubPack, *failedSuperPack});
            break;
        case SubtypingVariance::Contravariant:
            errors.emplace_back(argLocation, moduleName, TypePackMismatch{*failedSuperPack, *failedSubPack});
            break;
        case SubtypingVariance::Invariant:
            errors.emplace_back(argLocation, moduleName, TypePackMismatch{*failedSubPack, *failedSuperPack});
            break;
        default:
            LUAU_ASSERT(0);
            break;
        }
    }
}
// Test a single candidate `fnTy` against `argsPack` and record the outcome in exactly one bucket of `result`:
// - potentialOverloads: the candidate is not yet known (free, blocked, pending expansion, or an unsolved
//   type function instance), or it matches only under assumed constraints;
// - nonFunctions:       the candidate is not a function type;
// - arityMismatches:    the argument count cannot fit the parameters;
// - incompatibleOverloads: type-function reduction failed, generic bounds mismatched, or subtyping failed;
// - ok:                 the candidate matches with no assumptions.
void OverloadResolver::testFunction(
    OverloadResolution& result,
    TypeId fnTy,
    TypePackId argsPack,
    Location fnLocation,
    NotNull<DenseHashSet<TypeId>> uniqueTypes
)
{
    fnTy = follow(fnTy);

    const TypeFunctionInstanceType* pendingFunction = get<TypeFunctionInstanceType>(fnTy);
    const bool notYetKnown =
        is<FreeType, BlockedType, PendingExpansionType>(fnTy) || (pendingFunction && pendingFunction->state == TypeFunctionInstanceState::Unsolved);
    if (notYetKnown)
    {
        result.potentialOverloads.emplace_back(fnTy, std::vector<ConstraintV>{});
        return;
    }

    const FunctionType* candidateFn = get<FunctionType>(fnTy);
    if (candidateFn == nullptr)
    {
        result.nonFunctions.emplace_back(fnTy);
        return;
    }

    if (!isArityCompatible(argsPack, candidateFn->argTypes, builtinTypes))
    {
        result.arityMismatches.emplace_back(fnTy);
        return;
    }

    TypeFunctionContext reductionContext{arena, builtinTypes, scope, normalizer, typeFunctionRuntime, ice, limits};
    FunctionGraphReductionResult reduction = reduceTypeFunctions(fnTy, callLoc, NotNull{&reductionContext}, true);
    if (!reduction.errors.empty())
    {
        result.incompatibleOverloads.emplace_back(fnTy, std::move(reduction.errors));
        return;
    }

    // Check `fnTy <: (argsPack) -> ...any`; return-type complaints are filtered out afterwards.
    TypeId callShape = arena->addType(FunctionType{argsPack, builtinTypes->anyTypePack});
    subtyping.uniqueTypes = uniqueTypes;
    SubtypingResult check = subtyping.isSubtype(fnTy, callShape, scope);
    ignoreReasoningForReturnType(check);

    // A match with no assumptions is definitely ok; otherwise it depends on the assumed constraints.
    auto recordMatch = [&]()
    {
        if (check.assumedConstraints.empty())
            result.ok.emplace_back(fnTy);
        else
            result.potentialOverloads.emplace_back(fnTy, std::move(check.assumedConstraints));
    };

    if (check.isSubtype)
    {
        recordMatch();
        return;
    }

    if (!check.genericBoundsMismatches.empty())
    {
        ErrorVec boundErrors;
        for (const auto& mismatch : check.genericBoundsMismatches)
            boundErrors.emplace_back(fnLocation, mismatch);
        result.incompatibleOverloads.emplace_back(fnTy, std::move(boundErrors));
        return;
    }

    // Omitting only trailing optional parameters still counts as a match.
    if (areUnsatisfiedArgumentsOptional(check.reasoning, argsPack, candidateFn->argTypes))
    {
        recordMatch();
        return;
    }

    result.incompatibleOverloads.emplace_back(fnTy, std::move(check.reasoning));
}
// Test a (followed) candidate `fnTy`, treating a union as a single candidate that is callable only if
// every member is callable with `argsPack`.
//
// For a union, each member is tested independently:
// - every member definitely ok:           `fnTy` goes to `result.ok`;
// - every member ok or potentially ok:    `fnTy` goes to `result.potentialOverloads`, carrying the
//                                          constraints of the potentially-ok members in member order;
// - any member neither:                    `fnTy` goes to `result.incompatibleOverloads` with a
//                                          CannotCallNonFunction error.
// A member counts as ok if any of its overloads (e.g. components of an intersection `__call`) is ok.
// Non-union candidates are delegated to testFunctionOrCallMetamethod.
void OverloadResolver::testFunctionOrUnion(
    OverloadResolution& result,
    TypeId fnTy,
    TypePackId argsPack,
    Location fnLocation,
    NotNull<DenseHashSet<TypeId>> uniqueTypes
)
{
    LUAU_ASSERT(fnTy == follow(fnTy));

    auto ut = get<UnionType>(fnTy);
    if (!ut)
    {
        testFunctionOrCallMetamethod(result, fnTy, argsPack, fnLocation, uniqueTypes);
        return;
    }

    // Classify each member on its own: a member whose `__call` is an intersection can add several
    // entries to a shared result, so counting entries across all members would misclassify the union.
    bool anyMemberFailed = false;
    bool anyMemberPotential = false;
    std::vector<ConstraintV> allConstraints;
    for (TypeId member : ut)
    {
        OverloadResolution memberResult;
        testFunctionOrCallMetamethod(memberResult, member, argsPack, fnLocation, uniqueTypes);

        if (!memberResult.ok.empty())
            continue;

        if (memberResult.potentialOverloads.empty())
        {
            // Keep testing the remaining members so their side effects (e.g. reduction) match prior behaviour.
            anyMemberFailed = true;
            continue;
        }

        anyMemberPotential = true;
        for (const auto& potential : memberResult.potentialOverloads)
            allConstraints.insert(allConstraints.end(), potential.second.begin(), potential.second.end());
    }

    if (anyMemberFailed)
        result.incompatibleOverloads.emplace_back(fnTy, ErrorVec{{fnLocation, CannotCallNonFunction{fnTy}}});
    else if (!anyMemberPotential)
        result.ok.emplace_back(fnTy);
    else
        result.potentialOverloads.emplace_back(fnTy, std::move(allConstraints));
}
// Test `fnTy`, redirecting through its `__call` metamethod when one exists.
//
// With a `__call` entry, the called value is prepended to the argument pack. Each metamethod (each
// component, if the metamethod is an intersection) is then recorded in `result.metamethods` and tested.
// Intersection components that are functions with an incompatible arity go straight to
// `result.arityMismatches`. Without `__call`, `fnTy` itself is tested.
void OverloadResolver::testFunctionOrCallMetamethod(
    OverloadResolution& result,
    TypeId fnTy,
    TypePackId argsPack,
    Location fnLocation,
    NotNull<DenseHashSet<TypeId>> uniqueTypes
)
{
    const TypeId callee = follow(fnTy);

    // Errors produced by the metatable lookup are not reported from here.
    ErrorVec lookupErrors;
    const auto callMetamethod = findMetatableEntry(builtinTypes, lookupErrors, callee, "__call", callLoc);
    if (!callMetamethod)
    {
        testFunction(result, callee, argsPack, fnLocation, uniqueTypes);
        return;
    }

    // The called value is passed to __call as its first argument.
    const TypePackId selfAndArgs = arena->addTypePack({callee}, argsPack);
    const TypeId handler = follow(*callMetamethod);

    const IntersectionType* handlerOverloads = get<IntersectionType>(handler);
    if (!handlerOverloads)
    {
        result.metamethods.insert(handler);
        testFunction(result, handler, selfAndArgs, fnLocation, uniqueTypes);
        return;
    }

    for (TypeId rawOverload : handlerOverloads)
    {
        const TypeId overload = follow(rawOverload);
        result.metamethods.insert(overload);

        const FunctionType* overloadFn = get<FunctionType>(overload);
        const bool arityRejected = overloadFn != nullptr && !isArityCompatible(selfAndArgs, overloadFn->argTypes, builtinTypes);
        if (arityRejected)
            result.arityMismatches.emplace_back(overload);
        else
            testFunction(result, overload, selfAndArgs, fnLocation, uniqueTypes);
    }
}
// Convenience overload for callers without a module name: forwards to the TypeId overload with an
// empty (default-constructed) ModuleName.
void OverloadResolver::maybeEmplaceError(
    ErrorVec* errors,
    Location argLocation,
    const SubtypingReasoning* reason,
    const std::optional<TypeId> wantedType,
    const std::optional<TypeId> givenType
) const
{
    const ModuleName noModule{};
    maybeEmplaceError(errors, argLocation, noModule, reason, wantedType, givenType);
}
// Append a TypeMismatch between `wantedType` and `givenType` unless error suppression applies.
//
// Does nothing if either type is missing or suppression is requested.
// If normalization failed, a NormalizationTooComplex error is appended first and the mismatch is still
// reported, because suppression could not be proven.
// The mismatch context follows `reason->variance`:
// - covariant and contravariant reasonings use CovariantContext;
// - invariant reasonings use InvariantContext.
// `reason` is dereferenced unconditionally once an error is to be emitted.
void OverloadResolver::maybeEmplaceError(
    ErrorVec* errors,
    Location argLocation,
    const ModuleName& moduleName,
    const SubtypingReasoning* reason,
    const std::optional<TypeId> wantedType,
    const std::optional<TypeId> givenType
) const
{
    if (!wantedType || !givenType)
        return;

    switch (shouldSuppressErrors(normalizer, *wantedType).orElse(shouldSuppressErrors(normalizer, *givenType)))
    {
    case ErrorSuppression::Suppress:
        return;
    case ErrorSuppression::NormalizationFailed:
        errors->emplace_back(argLocation, moduleName, NormalizationTooComplex{});
        break;
    case ErrorSuppression::DoNotSuppress:
        break;
    }

    const SubtypingVariance variance = reason->variance;
    if (variance == SubtypingVariance::Covariant || variance == SubtypingVariance::Contravariant)
        errors->emplace_back(argLocation, moduleName, TypeMismatch{*wantedType, *givenType, TypeMismatch::CovariantContext});
    else if (variance == SubtypingVariance::Invariant)
        errors->emplace_back(argLocation, moduleName, TypeMismatch{*wantedType, *givenType, TypeMismatch::InvariantContext});
    else
        LUAU_ASSERT(0);
}
// Append a TypePackMismatch between `wantedTp` and `givenTp` unless error suppression applies.
//
// Does nothing if either pack is missing or suppression is requested.
// If normalization failed, only NormalizationTooComplex is appended (no pack mismatch).
// `reason` is not consulted by this overload.
void OverloadResolver::maybeEmplaceError(
    ErrorVec* errors,
    Location argLocation,
    const ModuleName& moduleName,
    const SubtypingReasoning* reason,
    const std::optional<TypePackId> wantedTp,
    const std::optional<TypePackId> givenTp
) const
{
    const bool bothPresent = wantedTp && givenTp;
    if (!bothPresent)
        return;

    switch (shouldSuppressErrors(normalizer, *wantedTp).orElse(shouldSuppressErrors(normalizer, *givenTp)))
    {
    case ErrorSuppression::DoNotSuppress:
        errors->emplace_back(argLocation, moduleName, TypePackMismatch{*wantedTp, *givenTp});
        return;
    case ErrorSuppression::NormalizationFailed:
        errors->emplace_back(argLocation, moduleName, NormalizationTooComplex{});
        return;
    case ErrorSuppression::Suppress:
        return;
    }
}
// Dispatch a TypeOrPack pair to the TypeId or TypePackId overload.
//
// Both sides must hold the same alternative (both types, or both packs). A missing side, or mixed
// alternatives, emits nothing.
void OverloadResolver::maybeEmplaceError(
    ErrorVec* errors,
    Location argLocation,
    const ModuleName& moduleName,
    const SubtypingReasoning* reason,
    const std::optional<TypeOrPack> wantedType,
    const std::optional<TypeOrPack> givenType
) const
{
    if (!wantedType || !givenType)
        return;

    if (const TypeId* wantedTy = get_if<TypeId>(&*wantedType))
    {
        if (const TypeId* givenTy = get_if<TypeId>(&*givenType))
            maybeEmplaceError(errors, argLocation, moduleName, reason, std::optional<TypeId>{*wantedTy}, std::optional<TypeId>{*givenTy});
        return;
    }

    if (const TypePackId* wantedTp = get_if<TypePackId>(&*wantedType))
    {
        if (const TypePackId* givenTp = get_if<TypePackId>(&*givenType))
            maybeEmplaceError(errors, argLocation, moduleName, reason, std::optional<TypePackId>{*wantedTp}, std::optional<TypePackId>{*givenTp});
    }
}
bool OverloadResolver::isArityCompatible(const TypePackId candidate, const TypePackId desired, NotNull<BuiltinTypes> builtinTypes) const
{
auto [candidateHead, candidateTail] = flatten(candidate);
auto [desiredHead, desiredTail] = flatten(desired);
if (candidateHead.size() < desiredHead.size())
{
if (candidateTail)
return true;
for (size_t i = candidateHead.size(); i < desiredHead.size(); ++i)
{
if (const TypeId ty = follow(desiredHead[i]); !isOptionalType(ty, builtinTypes))
return false;
}
}
if (candidateHead.size() > desiredHead.size())
{
return desiredTail.has_value();
}
return true;
}
}