#include "Luau/RequireTracer.h"
#include "Luau/Ast.h"
#include "Luau/Module.h"
namespace Luau
{
struct RequireTracer : AstVisitor
{
RequireTracer(RequireTraceResult& result, FileResolver* fileResolver, const ModuleName& currentModuleName)
: result(result)
, fileResolver(fileResolver)
, currentModuleName(currentModuleName)
, locals(nullptr)
{
}
bool visit(AstExprTypeAssertion* expr) override
{
return false;
}
bool visit(AstExprCall* expr) override
{
AstExprGlobal* global = expr->func->as<AstExprGlobal>();
if (global && global->name == "require" && expr->args.size >= 1)
requireCalls.push_back(expr);
return true;
}
bool visit(AstStatLocal* stat) override
{
for (size_t i = 0; i < stat->vars.size && i < stat->values.size; ++i)
{
AstLocal* local = stat->vars.data[i];
AstExpr* expr = stat->values.data[i];
locals[local] = expr;
}
return true;
}
bool visit(AstStatAssign* stat) override
{
for (size_t i = 0; i < stat->vars.size; ++i)
{
if (AstExprLocal* expr = stat->vars.data[i]->as<AstExprLocal>())
locals[expr->local] = nullptr;
}
return true;
}
bool visit(AstType* node) override
{
return true;
}
bool visit(AstTypePack* node) override
{
return true;
}
AstNode* getDependent(AstNode* node)
{
if (AstExprLocal* expr = node->as<AstExprLocal>())
return locals[expr->local];
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
return expr->expr;
else if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
return expr->expr;
else if (AstExprCall* expr = node->as<AstExprCall>(); expr && expr->self)
return expr->func->as<AstExprIndexName>()->expr;
else if (AstExprGroup* expr = node->as<AstExprGroup>())
return expr->expr;
else if (AstExprTypeAssertion* expr = node->as<AstExprTypeAssertion>())
return expr->annotation;
else if (AstTypeGroup* expr = node->as<AstTypeGroup>())
return expr->type;
else if (AstTypeTypeof* expr = node->as<AstTypeTypeof>())
return expr->expr;
else
return nullptr;
}
void process(const TypeCheckLimits& limits)
{
ModuleInfo moduleContext{currentModuleName};
work.reserve(requireCalls.size());
for (AstExprCall* require : requireCalls)
work.push_back(require->args.data[0]);
for (size_t i = 0; i < work.size(); ++i)
{
if (AstNode* dep = getDependent(work[i]))
work.push_back(dep);
}
for (size_t i = work.size(); i > 0; --i)
{
AstNode* expr = work[i - 1];
if (result.exprs.contains(expr))
continue;
std::optional<ModuleInfo> info;
if (AstNode* dep = getDependent(expr))
{
const ModuleInfo* context = result.exprs.find(dep);
if (context && expr->is<AstExprLocal>())
info = *context;
else if (context && (expr->is<AstExprGroup>() || expr->is<AstTypeGroup>()))
info = *context;
else if (context && (expr->is<AstTypeTypeof>() || expr->is<AstExprTypeAssertion>()))
info = *context;
else if (AstExpr* asExpr = expr->asExpr())
info = fileResolver->resolveModule(context, asExpr, limits);
}
else if (AstExpr* asExpr = expr->asExpr())
{
info = fileResolver->resolveModule(&moduleContext, asExpr, limits);
}
if (info)
result.exprs[expr] = std::move(*info);
}
result.requireList.reserve(requireCalls.size());
for (AstExprCall* require : requireCalls)
{
AstExpr* arg = require->args.data[0];
if (const ModuleInfo* info = result.exprs.find(arg))
{
result.requireList.push_back({info->name, require->location});
ModuleInfo infoCopy = *info;
result.exprs[require] = std::move(infoCopy);
}
else
{
result.exprs[require] = {};
}
}
}
RequireTraceResult& result;
FileResolver* fileResolver;
ModuleName currentModuleName;
DenseHashMap<AstLocal*, AstExpr*> locals;
std::vector<AstNode*> work;
std::vector<AstExprCall*> requireCalls;
};
// Finds every `require` call under `root` and resolves it through `fileResolver`.
//
// fileResolver:      resolver used to map require arguments to modules.
// root:              top-level block of the module to scan.
// currentModuleName: name of the module that `root` belongs to.
// limits:            passed through to the resolution step.
//
// Returns the trace result populated by RequireTracer::process.
RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName, const TypeCheckLimits& limits)
{
    RequireTraceResult result;

    // Phase 1: walk the AST to collect require calls and local initializers.
    RequireTracer visitor(result, fileResolver, currentModuleName);
    root->visit(&visitor);

    // Phase 2: resolve what was collected into `result`.
    visitor.process(limits);

    return result;
}
}