Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
Roblox
GitHub Repository: Roblox/luau
Path: blob/master/Compiler/src/CostModel.cpp
2725 views
1
// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details
2
#include "CostModel.h"
3
4
#include "Luau/Bytecode.h"
5
#include "Luau/Common.h"
6
#include "Luau/DenseHash.h"
7
8
#include "ConstantFolding.h"
9
#include "Utils.h"
10
11
#include <limits.h>
12
13
namespace Luau
14
{
15
namespace Compile
16
{
17
18
// Adds eight packed 7-bit values (one per byte of x and y) lane by lane, saturating each lane at 0x7f.
// Both inputs are expected to keep the top bit of every byte clear; under that precondition a lane sum
// never exceeds 0xfe, so it cannot carry into the neighbouring byte and bit 7 acts as an overflow flag.
inline uint64_t parallelAddSat(uint64_t x, uint64_t y)
{
    uint64_t r = x + y;
    uint64_t s = r & 0x8080808080808080ull; // saturation mask: bit 7 is set in every lane that reached 128 or more

    // r ^ s clears the overflow flag; s - (s >> 7) turns each 0x80 flag into 0x7f, forcing overflowed lanes to the maximum
    return (r ^ s) | (s - (s >> 7));
}
25
26
// Multiplies eight packed 7-bit values (one per byte of a) by the scalar b, saturating each lane at 0x7f.
// The factor is clamped to [0, 127]: anything above 127 saturates every non-zero lane anyway, and a
// negative factor would otherwise be converted to a huge unsigned value and corrupt all lanes.
static uint64_t parallelMulSat(uint64_t a, int b)
{
    const uint64_t bs = uint64_t(b < 0 ? 0 : (b < 127 ? b : 127));

    // multiply every other value by b, yielding 14-bit products
    uint64_t l = bs * ((a >> 0) & 0x007f007f007f007full);
    uint64_t h = bs * ((a >> 8) & 0x007f007f007f007full);

    // each product is 14-bit, so adding 32768-128 sets high bit iff the sum is 128 or larger without an overflow
    uint64_t ls = l + 0x7f807f807f807f80ull;
    uint64_t hs = h + 0x7f807f807f807f80ull;

    // we now merge saturation bits as well as low 7-bits of each product into one
    uint64_t s = (hs & 0x8000800080008000ull) | ((ls & 0x8000800080008000ull) >> 8);
    uint64_t r = ((h & 0x007f007f007f007full) << 8) | (l & 0x007f007f007f007full);

    // the low bits are now correct for values that didn't saturate, and we simply need to mask them if high bit is 1
    return r | (s - (s >> 7));
}
45
46
struct Cost
47
{
48
static const uint64_t kLiteral = ~0ull;
49
50
// cost model: 8 bytes, where first byte is the baseline cost, and the next 7 bytes are discounts for when variable #i is constant
51
uint64_t model;
52
// constant mask: 8-byte 0xff mask; equal to all ff's for literals, for variables only byte #i (1+) is set to align with model
53
uint64_t constant;
54
55
Cost(int cost = 0, uint64_t constant = 0)
56
: model(cost < 0x7f ? cost : 0x7f)
57
, constant(constant)
58
{
59
}
60
61
Cost operator+(const Cost& other) const
62
{
63
Cost result;
64
result.model = parallelAddSat(model, other.model);
65
return result;
66
}
67
68
Cost& operator+=(const Cost& other)
69
{
70
model = parallelAddSat(model, other.model);
71
constant = 0;
72
return *this;
73
}
74
75
Cost operator*(int other) const
76
{
77
Cost result;
78
result.model = parallelMulSat(model, other);
79
return result;
80
}
81
82
static Cost fold(const Cost& x, const Cost& y)
83
{
84
uint64_t newmodel = parallelAddSat(x.model, y.model);
85
uint64_t newconstant = x.constant & y.constant;
86
87
// the extra cost for folding is 1; the discount is 1 for the variable that is shared by x&y (or whichever one is used in x/y if the other is
88
// literal)
89
uint64_t extra = (newconstant == kLiteral) ? 0 : (1 | (0x0101010101010101ull & newconstant));
90
91
Cost result;
92
result.model = parallelAddSat(newmodel, extra);
93
result.constant = newconstant;
94
95
return result;
96
}
97
};
98
99
struct CostVisitor : AstVisitor
100
{
101
const DenseHashMap<AstExprCall*, int>& builtins;
102
const DenseHashMap<AstExpr*, Constant>& constants;
103
104
DenseHashMap<AstLocal*, uint64_t> vars;
105
Cost result;
106
107
CostVisitor(const DenseHashMap<AstExprCall*, int>& builtins, const DenseHashMap<AstExpr*, Constant>& constants)
108
: builtins(builtins)
109
, constants(constants)
110
, vars(nullptr)
111
{
112
}
113
114
Cost model(AstExpr* node)
115
{
116
if (constants.contains(node))
117
return Cost(0, Cost::kLiteral);
118
119
if (AstExprGroup* expr = node->as<AstExprGroup>())
120
{
121
return model(expr->expr);
122
}
123
else if (node->is<AstExprConstantNil>() || node->is<AstExprConstantBool>() || node->is<AstExprConstantNumber>() ||
124
node->is<AstExprConstantString>() || node->is<AstExprConstantInteger>())
125
{
126
return Cost(0, Cost::kLiteral);
127
}
128
else if (AstExprLocal* expr = node->as<AstExprLocal>())
129
{
130
const uint64_t* i = vars.find(expr->local);
131
132
return Cost(0, i ? *i : 0); // locals typically don't require extra instructions to compute
133
}
134
else if (node->is<AstExprGlobal>())
135
{
136
return 1;
137
}
138
else if (node->is<AstExprVarargs>())
139
{
140
return 3;
141
}
142
else if (AstExprCall* expr = node->as<AstExprCall>())
143
{
144
// builtin cost modeling is different from regular calls because we use FASTCALL to compile these
145
// thus we use a cheaper baseline, don't account for function, and assume constant/local copy is free
146
const int* bfid = builtins.find(expr);
147
bool builtin = bfid != nullptr && *bfid != LBF_NONE;
148
bool builtinShort = builtin && expr->args.size <= 2; // FASTCALL1/2
149
150
Cost cost = builtin ? 2 : 3;
151
152
if (!builtin)
153
cost += model(expr->func);
154
155
for (size_t i = 0; i < expr->args.size; ++i)
156
{
157
Cost ac = model(expr->args.data[i]);
158
// for constants/locals we still need to copy them to the argument list
159
cost += ac.model == 0 && !builtinShort ? Cost(1) : ac;
160
}
161
162
return cost;
163
}
164
else if (AstExprIndexName* expr = node->as<AstExprIndexName>())
165
{
166
return model(expr->expr) + 1;
167
}
168
else if (AstExprIndexExpr* expr = node->as<AstExprIndexExpr>())
169
{
170
return model(expr->expr) + model(expr->index) + 1;
171
}
172
else if (AstExprFunction* expr = node->as<AstExprFunction>())
173
{
174
return 10; // high baseline cost due to allocation
175
}
176
else if (AstExprTable* expr = node->as<AstExprTable>())
177
{
178
Cost cost = 10; // high baseline cost due to allocation
179
180
for (size_t i = 0; i < expr->items.size; ++i)
181
{
182
const AstExprTable::Item& item = expr->items.data[i];
183
184
if (item.key)
185
cost += model(item.key);
186
187
cost += model(item.value);
188
cost += 1;
189
}
190
191
return cost;
192
}
193
else if (AstExprUnary* expr = node->as<AstExprUnary>())
194
{
195
return Cost::fold(model(expr->expr), Cost(0, Cost::kLiteral));
196
}
197
else if (AstExprBinary* expr = node->as<AstExprBinary>())
198
{
199
return Cost::fold(model(expr->left), model(expr->right));
200
}
201
else if (AstExprTypeAssertion* expr = node->as<AstExprTypeAssertion>())
202
{
203
return model(expr->expr);
204
}
205
else if (AstExprIfElse* expr = node->as<AstExprIfElse>())
206
{
207
return model(expr->condition) + model(expr->trueExpr) + model(expr->falseExpr) + 2;
208
}
209
else if (AstExprInterpString* expr = node->as<AstExprInterpString>())
210
{
211
// Baseline cost of string.format
212
Cost cost = 3;
213
214
for (AstExpr* innerExpression : expr->expressions)
215
cost += model(innerExpression);
216
217
return cost;
218
}
219
else if (AstExprInstantiate* expr = node->as<AstExprInstantiate>())
220
{
221
return model(expr->expr);
222
}
223
else
224
{
225
LUAU_ASSERT(!"Unknown expression type");
226
return {};
227
}
228
}
229
230
void assign(AstExpr* expr)
231
{
232
// variable assignments reset variable mask, so that further uses of this variable aren't discounted
233
// this doesn't work perfectly with backwards control flow like loops, but is good enough for a single pass
234
if (AstExprLocal* lv = expr->as<AstExprLocal>())
235
if (uint64_t* i = vars.find(lv->local))
236
*i = 0;
237
}
238
239
void loop(AstStatBlock* body, Cost iterCost, int factor = 3)
240
{
241
Cost before = result;
242
243
result = Cost();
244
body->visit(this);
245
246
result = before + (result + iterCost) * factor;
247
}
248
249
bool visit(AstExpr* node) override
250
{
251
// note: we short-circuit the visitor traversal through any expression trees by returning false
252
// recursive traversal is happening inside model() which makes it easier to get the resulting value of the subexpression
253
result += model(node);
254
255
return false;
256
}
257
258
bool visit(AstStatFor* node) override
259
{
260
result += model(node->from);
261
result += model(node->to);
262
263
if (node->step)
264
result += model(node->step);
265
266
int tripCount = -1;
267
double from, to, step = 1;
268
269
if (getNumber(node->from, from) && getNumber(node->to, to) && (!node->step || getNumber(node->step, step)))
270
tripCount = getTripCount(from, to, step);
271
272
loop(node->body, 1, tripCount < 0 ? 3 : tripCount);
273
return false;
274
}
275
276
bool visit(AstStatForIn* node) override
277
{
278
for (size_t i = 0; i < node->values.size; ++i)
279
result += model(node->values.data[i]);
280
281
loop(node->body, 1);
282
return false;
283
}
284
285
bool visit(AstStatWhile* node) override
286
{
287
Cost condition = model(node->condition);
288
289
loop(node->body, condition);
290
return false;
291
}
292
293
bool visit(AstStatRepeat* node) override
294
{
295
Cost condition = model(node->condition);
296
297
loop(node->body, condition);
298
return false;
299
}
300
301
bool visit(AstStatIf* node) override
302
{
303
if (isConstantFalse(constants, node->condition))
304
{
305
if (node->elsebody)
306
node->elsebody->visit(this);
307
return false;
308
}
309
310
if (isConstantTrue(constants, node->condition))
311
{
312
node->thenbody->visit(this);
313
return false;
314
}
315
316
// unconditional 'else' may require a jump after the 'if' body
317
// note: this ignores cases when 'then' always terminates and also assumes comparison requires an extra instruction which may be false
318
result += 1 + (node->elsebody && !node->elsebody->is<AstStatIf>());
319
320
return true;
321
}
322
323
bool visit(AstStatLocal* node) override
324
{
325
for (size_t i = 0; i < node->values.size; ++i)
326
{
327
Cost arg = model(node->values.data[i]);
328
329
// propagate constant mask from expression through variables
330
if (arg.constant && i < node->vars.size)
331
vars[node->vars.data[i]] = arg.constant;
332
333
result += arg;
334
}
335
336
return false;
337
}
338
339
bool visit(AstStatAssign* node) override
340
{
341
for (size_t i = 0; i < node->vars.size; ++i)
342
assign(node->vars.data[i]);
343
344
for (size_t i = 0; i < node->vars.size || i < node->values.size; ++i)
345
{
346
Cost ac;
347
if (i < node->vars.size)
348
ac += model(node->vars.data[i]);
349
if (i < node->values.size)
350
ac += model(node->values.data[i]);
351
// local->local or constant->local assignment is not free
352
result += ac.model == 0 ? Cost(1) : ac;
353
}
354
355
return false;
356
}
357
358
bool visit(AstStatCompoundAssign* node) override
359
{
360
assign(node->var);
361
362
// if lhs is not a local, setting it requires an extra table operation
363
result += node->var->is<AstExprLocal>() ? 1 : 2;
364
365
return true;
366
}
367
368
bool visit(AstStatBreak* node) override
369
{
370
result += 1;
371
372
return false;
373
}
374
375
bool visit(AstStatContinue* node) override
376
{
377
result += 1;
378
379
return false;
380
}
381
382
bool getNumber(AstExpr* node, double& result)
383
{
384
if (const Constant* constant = constants.find(node))
385
{
386
if (constant->type == Constant::Type_Number)
387
{
388
result = constant->valueNumber;
389
return true;
390
}
391
}
392
393
return false;
394
}
395
396
bool visit(AstStatBlock* node) override
397
{
398
for (size_t i = 0; i < node->body.size; ++i)
399
{
400
AstStat* stat = node->body.data[i];
401
402
stat->visit(this);
403
404
if (alwaysTerminates(constants, stat))
405
break;
406
}
407
408
return false;
409
}
410
};
411
412
uint64_t modelCost(
413
AstNode* root,
414
AstLocal* const* vars,
415
size_t varCount,
416
const DenseHashMap<AstExprCall*, int>& builtins,
417
const DenseHashMap<AstExpr*, Constant>& constants
418
)
419
{
420
CostVisitor visitor{builtins, constants};
421
for (size_t i = 0; i < varCount && i < 7; ++i)
422
visitor.vars[vars[i]] = 0xffull << (i * 8 + 8);
423
424
root->visit(&visitor);
425
426
return visitor.result.model;
427
}
428
429
// Convenience overload: computes the cost model with empty builtin and constant tables,
// i.e. every call is treated as a regular call and no expression is treated as pre-folded.
uint64_t modelCost(AstNode* root, AstLocal* const* vars, size_t varCount)
{
    DenseHashMap<AstExprCall*, int> builtins{nullptr};
    DenseHashMap<AstExpr*, Constant> constants{nullptr};

    return modelCost(root, vars, varCount, builtins, constants);
}
436
437
// Decodes a packed cost model into a single cost.
// The low byte of `model` is the baseline cost; byte i+1 is the discount that applies when variable #i
// is constant. `varsConst[i]` tells whether variable #i is constant; it must point to at least
// min(varCount, 7) elements. Only the first 7 variables are considered.
int computeCost(uint64_t model, const bool* varsConst, size_t varCount)
{
    int cost = int(model & 0x7f);

    // don't apply discounts to what is likely a saturated sum
    if (cost == 0x7f)
        return cost;

    for (size_t i = 0; i < varCount && i < 7; ++i)
    {
        if (varsConst[i])
            cost -= int((model >> (i * 8 + 8)) & 0x7f);
    }

    return cost;
}
450
451
// Returns the number of iterations of a numeric for loop running from `from` to `to` (inclusive) in
// increments of `step`, or -1 when it cannot be determined. The count is only computed when all three
// values are integers within [-32767, 32767] and the step is non-zero.
int getTripCount(double from, double to, double step)
{
    // we compute trip count in integers because that way we know that the loop math (repeated addition) is precise
    // INT_MIN marks a value that is out of range, fractional or NaN (NaN fails the range comparisons)
    auto toSmallInt = [](double v) -> int
    {
        return (v >= -32767 && v <= 32767 && double(int(v)) == v) ? int(v) : INT_MIN;
    };

    int fromi = toSmallInt(from);
    int toi = toSmallInt(to);
    int stepi = toSmallInt(step);

    if (fromi == INT_MIN || toi == INT_MIN || stepi == INT_MIN || stepi == 0)
        return -1;

    // the loop never executes when the step points away from the end value
    if ((stepi < 0 && toi > fromi) || (stepi > 0 && toi < fromi))
        return 0;

    return (toi - fromi) / stepi + 1;
}
466
467
} // namespace Compile
468
} // namespace Luau
469
470