Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Roblox
GitHub Repository: Roblox/luau
Path: blob/master/Require/src/RequireImpl.cpp
2725 views
1
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
2
3
#include "RequireImpl.h"
4
5
#include "Navigation.h"
6
7
#include "Luau/RequireNavigator.h"
8
#include "Luau/Require.h"
9
10
#include "lua.h"
11
#include "lualib.h"
12
13
namespace Luau::Require
14
{
15
16
// Registry key of the table holding modules registered explicitly through registerModuleImpl,
// keyed by lowercased '@'-prefixed path.
static constexpr const char* registeredCacheTableKey = "_REGISTEREDMODULES";

// Registry key of the table caching the value each successfully loaded module returned,
// keyed by the module's cache key.
static constexpr const char* requiredCacheTableKey = "_MODULES";
21
22
struct ResolvedRequire
23
{
24
static ResolvedRequire fromErrorHandler(const RuntimeErrorHandler& errorHandler)
25
{
26
return {ResolvedRequire::Status::ErrorReported, "", "", "", errorHandler.getReportedError()};
27
}
28
29
static ResolvedRequire fromErrorMessage(const char* message)
30
{
31
return {ResolvedRequire::Status::ErrorReported, "", "", "", message};
32
}
33
34
enum class Status
35
{
36
Cached,
37
ModuleRead,
38
ErrorReported
39
};
40
41
Status status;
42
std::string chunkname;
43
std::string loadname;
44
std::string cacheKey;
45
std::string error;
46
};
47
48
// Reports whether registry[_MODULES] already holds a non-nil value for `key`.
// luaL_findtable creates the cache table on first use; the stack is left balanced.
static bool isCached(lua_State* L, const std::string& key)
{
    luaL_findtable(L, LUA_REGISTRYINDEX, requiredCacheTableKey, 1); // cache
    lua_getfield(L, -1, key.c_str());                               // cache, entry

    const bool present = !lua_isnil(L, -1);

    lua_pop(L, 2); // drop entry and cache table
    return present;
}
57
58
// Resolves `path` on behalf of the chunk named `requirerChunkname`, using the host configuration `lrc`.
//
// On a cache hit the cached value is pushed onto the top of the Lua stack and a Cached result is returned.
// Otherwise a ModuleRead result carries the chunkname/loadname/cacheKey needed to load the module, or an
// ErrorReported result carries the failure message.
static ResolvedRequire resolveRequire(luarequire_Configuration* lrc, lua_State* L, void* ctx, const char* requirerChunkname, std::string path)
{
    if (!lrc->is_require_allowed(L, ctx, requirerChunkname))
        return ResolvedRequire::fromErrorMessage("require is not supported in this context");

    RuntimeNavigationContext context{lrc, L, ctx, requirerChunkname};
    RuntimeErrorHandler reporter{path}; // built before `path` is moved into navigate() below
    Navigator navigator(context, reporter);

    // navigate() updates `context` while it walks through the path.
    if (navigator.navigate(std::move(path)) == Navigator::Status::ErrorReported)
        return ResolvedRequire::fromErrorHandler(reporter);

    if (!context.isModulePresent())
        return ResolvedRequire::fromErrorMessage("no module present at resolved path");

    std::optional<std::string> key = context.getCacheKey();
    if (!key.has_value())
        return ResolvedRequire::fromErrorMessage("could not get cache key for module");

    if (isCached(L, *key))
    {
        // Cache hit: leave registry[_MODULES][key] on top of the stack for the caller to return.
        lua_getfield(L, LUA_REGISTRYINDEX, requiredCacheTableKey);
        lua_getfield(L, -1, key->c_str());
        lua_remove(L, -2);
        return ResolvedRequire{ResolvedRequire::Status::Cached};
    }

    std::optional<std::string> chunk = context.getChunkname();
    if (!chunk.has_value())
        return ResolvedRequire::fromErrorMessage("could not get chunkname for module");

    std::optional<std::string> load = context.getLoadname();
    if (!load.has_value())
        return ResolvedRequire::fromErrorMessage("could not get loadname for module");

    ResolvedRequire resolved{ResolvedRequire::Status::ModuleRead};
    resolved.chunkname = std::move(*chunk);
    resolved.loadname = std::move(*load);
    resolved.cacheKey = std::move(*key);
    return resolved;
}
105
106
// Looks `path` up (ASCII case-insensitively) among explicitly registered modules.
// Returns 1 with the registered value left on top of the stack on a hit; returns 0 with the
// stack restored otherwise.
static int checkRegisteredModules(lua_State* L, const char* path)
{
    luaL_findtable(L, LUA_REGISTRYINDEX, registeredCacheTableKey, 1); // registered

    // registerModuleImpl stores keys lowercased, so fold the lookup key the same way.
    std::string key(path);
    for (std::string::size_type i = 0; i < key.size(); ++i)
    {
        const char ch = key[i];
        if ('A' <= ch && ch <= 'Z')
            key[i] = static_cast<char>(ch - 'A' + 'a');
    }

    lua_getfield(L, -1, key.c_str()); // registered, value
    if (!lua_isnil(L, -1))
    {
        lua_remove(L, -2); // value
        return 1;
    }

    lua_pop(L, 2); // drop nil and the registered table
    return 0;
}
127
128
// Stack slots owned by require while a module loads: (1) path, (2) cacheKey, (3) chunkname, (4) loadname.
static constexpr int kRequireStackValues = 4;
129
130
// Finishes a require once the module has produced its results, either directly or after resuming
// from a yield. Everything above the first kRequireStackValues slots is treated as module output.
// More than one value raises an error. Exactly one value is stored in registry[_MODULES][cacheKey]
// and returned. Zero values are returned as-is without caching.
int lua_requirecont(lua_State* L, int status)
{
    LUAU_ASSERT(lua_gettop(L) >= kRequireStackValues);
    const int produced = lua_gettop(L) - kRequireStackValues;
    const char* key = luaL_checkstring(L, 2); // slot (2) holds the cache key

    if (produced > 1)
        luaL_error(L, "module must return a single value");

    if (produced != 1)
        return produced;

    lua_getfield(L, LUA_REGISTRYINDEX, requiredCacheTableKey); // result, cache
    lua_pushvalue(L, -2);                                      // result, cache, result
    lua_setfield(L, -2, key);                                  // result, cache
    lua_pop(L, 1);                                             // result
    return 1;
}
161
162
// Shared implementation of `require` and `proxyrequire`.
//
// Expects the require path at stack index 1; `requirerChunkname` names the chunk on whose behalf the
// require is performed. Registered modules and cached results are returned immediately. Otherwise the
// module is loaded through lrc->load. If that callback returns -1 the thread yields, and the require is
// finished by lua_requirecont on resumption.
int lua_requireinternal(lua_State* L, const char* requirerChunkname)
{
    // requirerChunkname may point at a string stored above index 1 (lua_proxyrequire reads it from
    // index 2). A string returned by lua_tostring is only guaranteed valid while it stays on the stack.
    // Discarding all extra arguments here would therefore leave a pointer that GC could invalidate
    // during resolution. Keep (1) path and pin the chunkname at (2) until resolution is done.
    lua_settop(L, 2);
    lua_pushstring(L, requirerChunkname);
    lua_replace(L, 2);
    requirerChunkname = lua_tostring(L, 2);

    luarequire_Configuration* lrc = static_cast<luarequire_Configuration*>(lua_touserdata(L, lua_upvalueindex(1)));
    if (!lrc)
        luaL_error(L, "unable to find require configuration");

    void* ctx = lua_tolightuserdata(L, lua_upvalueindex(2));

    // (1) path, (2) requirer chunkname
    const char* path = luaL_checkstring(L, 1);

    if (checkRegisteredModules(L, path) == 1)
        return 1; // registered value is on top of the stack

    // ResolvedRequire will be destroyed and any string will be pinned to Luau stack, so that luaL_error doesn't need destructors
    bool resolveError = false;

    {
        ResolvedRequire resolvedRequire = resolveRequire(lrc, L, ctx, requirerChunkname, path);

        if (resolvedRequire.status == ResolvedRequire::Status::Cached)
            return 1; // cached value is on top of the stack

        if (resolvedRequire.status == ResolvedRequire::Status::ErrorReported)
        {
            lua_pushstring(L, resolvedRequire.error.c_str());
            resolveError = true;
        }
        else
        {
            // (1) path, (2) requirer chunkname, ..., cacheKey, chunkname, loadname
            lua_pushstring(L, resolvedRequire.cacheKey.c_str());
            lua_pushstring(L, resolvedRequire.chunkname.c_str());
            lua_pushstring(L, resolvedRequire.loadname.c_str());
        }
    }

    if (resolveError)
        lua_error(L); // Error already on top of the stack

    // The requirer chunkname is no longer referenced; drop it so the layout matches what
    // lua_requirecont expects: (1) path, (2) cacheKey, (3) chunkname, (4) loadname
    lua_remove(L, 2);

    int stackValues = lua_gettop(L);
    LUAU_ASSERT(stackValues == kRequireStackValues);

    const char* chunkname = lua_tostring(L, -2);
    const char* loadname = lua_tostring(L, -1);

    int numResults = lrc->load(L, ctx, path, chunkname, loadname);
    if (numResults == -1)
    {
        // The continuation relies on the exact stack layout above, so the loader must not change it.
        if (lua_gettop(L) != stackValues)
            luaL_error(L, "stack cannot be modified when require yields");

        return lua_yield(L, 0);
    }

    return lua_requirecont(L, LUA_OK);
}
222
223
// `proxyrequire(path, requirerChunkname)`: performs a require on behalf of the chunk named by
// argument 2 instead of inferring the requirer from the call stack.
int lua_proxyrequire(lua_State* L)
{
    return lua_requireinternal(L, luaL_checkstring(L, 2));
}
228
229
// `require(path)`: walks up the call stack, skipping C frames, and uses the source name of the
// nearest non-C function as the requirer chunkname. Raises an error if no such frame exists.
int lua_require(lua_State* L)
{
    lua_Debug ar;

    for (int level = 1;; ++level)
    {
        if (!lua_getinfo(L, level, "s", &ar))
            luaL_error(L, "require is not supported in this context");

        if (ar.what[0] != 'C')
            break;
    }

    return lua_requireinternal(L, ar.source);
}
242
243
// `registermodule(path, result)`: records `result` as the value returned when `path` is required.
// `path` must start with '@' and is stored lowercased (ASCII), so later lookups are case-insensitive.
// Returns no values.
int registerModuleImpl(lua_State* L)
{
    if (lua_gettop(L) != 2)
        luaL_error(L, "expected 2 arguments: aliased require path and desired result");

    size_t len;
    const char* path = luaL_checklstring(L, 1, &len);
    if (len == 0 || path[0] != '@')
        luaL_argerrorL(L, 1, "path must begin with '@'");

    std::string key(path, len);
    for (std::string::size_type i = 0; i < key.size(); ++i)
    {
        const char ch = key[i];
        if ('A' <= ch && ch <= 'Z')
            key[i] = static_cast<char>(ch - 'A' + 'a');
    }

    luaL_findtable(L, LUA_REGISTRYINDEX, registeredCacheTableKey, 1);
    // (1) path, (2) result, (3) registered table

    lua_pushstring(L, key.c_str());
    lua_pushvalue(L, 2);
    // (1) path, (2) result, (3) registered table, (4) lowercased path, (5) result

    lua_settable(L, 3);
    // (1) path, (2) result, (3) registered table

    lua_pop(L, 3);

    return 0;
}
277
278
// Removes one entry from registry[_MODULES] so that the next require resolving to this cache key
// loads the module again. Argument 1 is the cache key; returns no values.
int clearCacheEntry(lua_State* L)
{
    const char* key = luaL_checkstring(L, 1);

    luaL_findtable(L, LUA_REGISTRYINDEX, requiredCacheTableKey, 1); // ..., cache
    lua_pushstring(L, key);                                         // ..., cache, key
    lua_pushnil(L);                                                 // ..., cache, key, nil
    lua_settable(L, -3);                                            // ..., cache
    lua_pop(L, 1);

    return 0;
}
287
288
// Drops every cached require result by installing a fresh, empty registry[_MODULES] table.
// Explicitly registered modules live in a separate registry table and are not affected.
int clearCache(lua_State* L)
{
    lua_createtable(L, 0, 0);
    lua_setfield(L, LUA_REGISTRYINDEX, requiredCacheTableKey);
    return 0;
}
294
295
} // namespace Luau::Require
296
297