Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
allendowney
GitHub Repository: allendowney/cpython
Path: blob/main/Python/ast_opt.c
12 views
1
/* AST Optimizer */
2
#include "Python.h"
3
#include "pycore_ast.h" // _PyAST_GetDocString()
4
#include "pycore_long.h" // _PyLong
5
#include "pycore_pystate.h" // _PyThreadState_GET()
6
#include "pycore_format.h" // F_LJUST
7
8
9
/* State shared by every astfold_* function during one optimizer run. */
typedef struct {
    /* When non-zero, the name __debug__ is folded to False (otherwise True). */
    int optimize;
    /* Feature bit flags; tested against CO_FUTURE_ANNOTATIONS to decide
       whether annotations are folded. */
    int ff_features;

    /* Depth counter incremented on entry to the recursive visitors. */
    int recursion_depth;
    /* Depth at which the visitors raise RecursionError. */
    int recursion_limit;
} _PyASTOptimizeState;
16
17
18
static int
19
make_const(expr_ty node, PyObject *val, PyArena *arena)
20
{
21
// Even if no new value was calculated, make_const may still
22
// need to clear an error (e.g. for division by zero)
23
if (val == NULL) {
24
if (PyErr_ExceptionMatches(PyExc_KeyboardInterrupt)) {
25
return 0;
26
}
27
PyErr_Clear();
28
return 1;
29
}
30
if (_PyArena_AddPyObject(arena, val) < 0) {
31
Py_DECREF(val);
32
return 0;
33
}
34
node->kind = Constant_kind;
35
node->v.Constant.kind = NULL;
36
node->v.Constant.value = val;
37
return 1;
38
}
39
40
/* Overwrite the expression node TO with a byte-wise copy of FROM
   (sizeof(struct _expr) bytes); both arguments are pointers. */
#define COPY_NODE(TO, FROM) (memcpy((TO), (FROM), sizeof(struct _expr)))
41
42
static int
43
has_starred(asdl_expr_seq *elts)
44
{
45
Py_ssize_t n = asdl_seq_LEN(elts);
46
for (Py_ssize_t i = 0; i < n; i++) {
47
expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
48
if (e->kind == Starred_kind) {
49
return 1;
50
}
51
}
52
return 0;
53
}
54
55
56
/* Implementation of the `not` operator for the unary-op dispatch table.
   Returns a new bool object, or NULL if the truth test of `v` failed. */
static PyObject*
unary_not(PyObject *v)
{
    int truth = PyObject_IsTrue(v);
    return (truth < 0) ? NULL : PyBool_FromLong(!truth);
}
64
65
static int
66
fold_unaryop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
67
{
68
expr_ty arg = node->v.UnaryOp.operand;
69
70
if (arg->kind != Constant_kind) {
71
/* Fold not into comparison */
72
if (node->v.UnaryOp.op == Not && arg->kind == Compare_kind &&
73
asdl_seq_LEN(arg->v.Compare.ops) == 1) {
74
/* Eq and NotEq are often implemented in terms of one another, so
75
folding not (self == other) into self != other breaks implementation
76
of !=. Detecting such cases doesn't seem worthwhile.
77
Python uses </> for 'is subset'/'is superset' operations on sets.
78
They don't satisfy not folding laws. */
79
cmpop_ty op = asdl_seq_GET(arg->v.Compare.ops, 0);
80
switch (op) {
81
case Is:
82
op = IsNot;
83
break;
84
case IsNot:
85
op = Is;
86
break;
87
case In:
88
op = NotIn;
89
break;
90
case NotIn:
91
op = In;
92
break;
93
// The remaining comparison operators can't be safely inverted
94
case Eq:
95
case NotEq:
96
case Lt:
97
case LtE:
98
case Gt:
99
case GtE:
100
op = 0; // The AST enums leave "0" free as an "unused" marker
101
break;
102
// No default case, so the compiler will emit a warning if new
103
// comparison operators are added without being handled here
104
}
105
if (op) {
106
asdl_seq_SET(arg->v.Compare.ops, 0, op);
107
COPY_NODE(node, arg);
108
return 1;
109
}
110
}
111
return 1;
112
}
113
114
typedef PyObject *(*unary_op)(PyObject*);
115
static const unary_op ops[] = {
116
[Invert] = PyNumber_Invert,
117
[Not] = unary_not,
118
[UAdd] = PyNumber_Positive,
119
[USub] = PyNumber_Negative,
120
};
121
PyObject *newval = ops[node->v.UnaryOp.op](arg->v.Constant.value);
122
return make_const(node, newval, arena);
123
}
124
125
/* Check whether a collection doesn't containing too much items (including
126
subcollections). This protects from creating a constant that needs
127
too much time for calculating a hash.
128
"limit" is the maximal number of items.
129
Returns the negative number if the total number of items exceeds the
130
limit. Otherwise returns the limit minus the total number of items.
131
*/
132
133
static Py_ssize_t
134
check_complexity(PyObject *obj, Py_ssize_t limit)
135
{
136
if (PyTuple_Check(obj)) {
137
Py_ssize_t i;
138
limit -= PyTuple_GET_SIZE(obj);
139
for (i = 0; limit >= 0 && i < PyTuple_GET_SIZE(obj); i++) {
140
limit = check_complexity(PyTuple_GET_ITEM(obj, i), limit);
141
}
142
return limit;
143
}
144
else if (PyFrozenSet_Check(obj)) {
145
Py_ssize_t i = 0;
146
PyObject *item;
147
Py_hash_t hash;
148
limit -= PySet_GET_SIZE(obj);
149
while (limit >= 0 && _PySet_NextEntry(obj, &i, &item, &hash)) {
150
limit = check_complexity(item, limit);
151
}
152
}
153
return limit;
154
}
155
156
/* Upper bounds on the size of values produced by constant folding. */
#define MAX_INT_SIZE         128   /* bits */
#define MAX_COLLECTION_SIZE  256   /* items */
#define MAX_STR_SIZE         4096  /* characters */
#define MAX_TOTAL_ITEMS      1024  /* including nested collections */
160
161
/* Multiply two constants, refusing (by returning NULL) when the result
   would exceed the folding limits:

   - int * int: combined bit length above MAX_INT_SIZE;
   - int * tuple/frozenset: more than MAX_COLLECTION_SIZE items, or more
     than MAX_TOTAL_ITEMS items once nested collections are counted;
   - int * str/bytes: more than MAX_STR_SIZE characters;
   - a negative repeat count for a collection or string.

   With the sequence on the left and the int on the right, the operands are
   swapped and checked the same way.  NULL may be returned with or without
   an exception set. */
static PyObject *
safe_multiply(PyObject *v, PyObject *w)
{
    const int v_is_int = PyLong_Check(v);

    if (v_is_int && PyLong_Check(w) &&
        !_PyLong_IsZero((PyLongObject *)v) &&
        !_PyLong_IsZero((PyLongObject *)w))
    {
        size_t vbits = _PyLong_NumBits(v);
        size_t wbits = _PyLong_NumBits(w);
        /* (size_t)-1 is treated as a failure indicator. */
        if (vbits == (size_t)-1 || wbits == (size_t)-1) {
            return NULL;
        }
        if (vbits + wbits > MAX_INT_SIZE) {
            return NULL;
        }
    }
    else if (v_is_int && (PyTuple_Check(w) || PyFrozenSet_Check(w))) {
        Py_ssize_t size;
        if (PyTuple_Check(w)) {
            size = PyTuple_GET_SIZE(w);
        }
        else {
            size = PySet_GET_SIZE(w);
        }
        if (size != 0) {
            long n = PyLong_AsLong(v);
            if (n < 0 || n > MAX_COLLECTION_SIZE / size) {
                return NULL;
            }
            if (n && check_complexity(w, MAX_TOTAL_ITEMS / n) < 0) {
                return NULL;
            }
        }
    }
    else if (v_is_int && (PyUnicode_Check(w) || PyBytes_Check(w))) {
        Py_ssize_t size;
        if (PyUnicode_Check(w)) {
            size = PyUnicode_GET_LENGTH(w);
        }
        else {
            size = PyBytes_GET_SIZE(w);
        }
        if (size != 0) {
            long n = PyLong_AsLong(v);
            if (n < 0 || n > MAX_STR_SIZE / size) {
                return NULL;
            }
        }
    }
    else if (PyLong_Check(w) &&
             (PyTuple_Check(v) || PyFrozenSet_Check(v) ||
              PyUnicode_Check(v) || PyBytes_Check(v)))
    {
        /* sequence * int: reuse the int * sequence checks above. */
        return safe_multiply(w, v);
    }

    return PyNumber_Multiply(v, w);
}
208
209
/* Compute v ** w, returning NULL instead when both operands are ints,
   the base is non-zero, the exponent is positive, and the bit length of
   the base exceeds MAX_INT_SIZE divided by the exponent. */
static PyObject *
safe_power(PyObject *v, PyObject *w)
{
    const int int_pow = PyLong_Check(v) && PyLong_Check(w) &&
                        !_PyLong_IsZero((PyLongObject *)v) &&
                        _PyLong_IsPositive((PyLongObject *)w);
    if (int_pow) {
        size_t base_bits = _PyLong_NumBits(v);
        size_t exponent = PyLong_AsSize_t(w);
        /* (size_t)-1 is treated as a failure indicator. */
        if (base_bits == (size_t)-1 || exponent == (size_t)-1) {
            return NULL;
        }
        if (base_bits > MAX_INT_SIZE / exponent) {
            return NULL;
        }
    }

    return PyNumber_Power(v, w, Py_None);
}
227
228
/* Compute v << w, returning NULL instead when both operands are non-zero
   ints and either the shift count alone, or the bit length of v plus the
   shift count, exceeds MAX_INT_SIZE. */
static PyObject *
safe_lshift(PyObject *v, PyObject *w)
{
    const int int_shift = PyLong_Check(v) && PyLong_Check(w) &&
                          !_PyLong_IsZero((PyLongObject *)v) &&
                          !_PyLong_IsZero((PyLongObject *)w);
    if (int_shift) {
        size_t value_bits = _PyLong_NumBits(v);
        size_t shift = PyLong_AsSize_t(w);
        /* (size_t)-1 is treated as a failure indicator. */
        if (value_bits == (size_t)-1 || shift == (size_t)-1) {
            return NULL;
        }
        if (shift > MAX_INT_SIZE || value_bits > MAX_INT_SIZE - shift) {
            return NULL;
        }
    }

    return PyNumber_Lshift(v, w);
}
246
247
/* Compute v % w, except that a str or bytes left operand (i.e. %-style
   formatting) is never folded here: NULL is returned for it. */
static PyObject *
safe_mod(PyObject *v, PyObject *w)
{
    const int is_formatting = PyUnicode_Check(v) || PyBytes_Check(v);
    return is_formatting ? NULL : PyNumber_Remainder(v, w);
}
256
257
258
static expr_ty
259
parse_literal(PyObject *fmt, Py_ssize_t *ppos, PyArena *arena)
260
{
261
const void *data = PyUnicode_DATA(fmt);
262
int kind = PyUnicode_KIND(fmt);
263
Py_ssize_t size = PyUnicode_GET_LENGTH(fmt);
264
Py_ssize_t start, pos;
265
int has_percents = 0;
266
start = pos = *ppos;
267
while (pos < size) {
268
if (PyUnicode_READ(kind, data, pos) != '%') {
269
pos++;
270
}
271
else if (pos+1 < size && PyUnicode_READ(kind, data, pos+1) == '%') {
272
has_percents = 1;
273
pos += 2;
274
}
275
else {
276
break;
277
}
278
}
279
*ppos = pos;
280
if (pos == start) {
281
return NULL;
282
}
283
PyObject *str = PyUnicode_Substring(fmt, start, pos);
284
/* str = str.replace('%%', '%') */
285
if (str && has_percents) {
286
_Py_DECLARE_STR(percent, "%");
287
_Py_DECLARE_STR(dbl_percent, "%%");
288
Py_SETREF(str, PyUnicode_Replace(str, &_Py_STR(dbl_percent),
289
&_Py_STR(percent), -1));
290
}
291
if (!str) {
292
return NULL;
293
}
294
295
if (_PyArena_AddPyObject(arena, str) < 0) {
296
Py_DECREF(str);
297
return NULL;
298
}
299
return _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
300
}
301
302
/* Digit bound for a width or precision in a %-format unit; used both by the
   parser's digit counter and to size the format-spec buffer. */
#define MAXDIGITS 3
303
304
static int
305
simple_format_arg_parse(PyObject *fmt, Py_ssize_t *ppos,
306
int *spec, int *flags, int *width, int *prec)
307
{
308
Py_ssize_t pos = *ppos, len = PyUnicode_GET_LENGTH(fmt);
309
Py_UCS4 ch;
310
311
#define NEXTC do { \
312
if (pos >= len) { \
313
return 0; \
314
} \
315
ch = PyUnicode_READ_CHAR(fmt, pos); \
316
pos++; \
317
} while (0)
318
319
*flags = 0;
320
while (1) {
321
NEXTC;
322
switch (ch) {
323
case '-': *flags |= F_LJUST; continue;
324
case '+': *flags |= F_SIGN; continue;
325
case ' ': *flags |= F_BLANK; continue;
326
case '#': *flags |= F_ALT; continue;
327
case '0': *flags |= F_ZERO; continue;
328
}
329
break;
330
}
331
if ('0' <= ch && ch <= '9') {
332
*width = 0;
333
int digits = 0;
334
while ('0' <= ch && ch <= '9') {
335
*width = *width * 10 + (ch - '0');
336
NEXTC;
337
if (++digits >= MAXDIGITS) {
338
return 0;
339
}
340
}
341
}
342
343
if (ch == '.') {
344
NEXTC;
345
*prec = 0;
346
if ('0' <= ch && ch <= '9') {
347
int digits = 0;
348
while ('0' <= ch && ch <= '9') {
349
*prec = *prec * 10 + (ch - '0');
350
NEXTC;
351
if (++digits >= MAXDIGITS) {
352
return 0;
353
}
354
}
355
}
356
}
357
*spec = ch;
358
*ppos = pos;
359
return 1;
360
361
#undef NEXTC
362
}
363
364
static expr_ty
365
parse_format(PyObject *fmt, Py_ssize_t *ppos, expr_ty arg, PyArena *arena)
366
{
367
int spec, flags, width = -1, prec = -1;
368
if (!simple_format_arg_parse(fmt, ppos, &spec, &flags, &width, &prec)) {
369
// Unsupported format.
370
return NULL;
371
}
372
if (spec == 's' || spec == 'r' || spec == 'a') {
373
char buf[1 + MAXDIGITS + 1 + MAXDIGITS + 1], *p = buf;
374
if (!(flags & F_LJUST) && width > 0) {
375
*p++ = '>';
376
}
377
if (width >= 0) {
378
p += snprintf(p, MAXDIGITS + 1, "%d", width);
379
}
380
if (prec >= 0) {
381
p += snprintf(p, MAXDIGITS + 2, ".%d", prec);
382
}
383
expr_ty format_spec = NULL;
384
if (p != buf) {
385
PyObject *str = PyUnicode_FromString(buf);
386
if (str == NULL) {
387
return NULL;
388
}
389
if (_PyArena_AddPyObject(arena, str) < 0) {
390
Py_DECREF(str);
391
return NULL;
392
}
393
format_spec = _PyAST_Constant(str, NULL, -1, -1, -1, -1, arena);
394
if (format_spec == NULL) {
395
return NULL;
396
}
397
}
398
return _PyAST_FormattedValue(arg, spec, format_spec,
399
arg->lineno, arg->col_offset,
400
arg->end_lineno, arg->end_col_offset,
401
arena);
402
}
403
// Unsupported format.
404
return NULL;
405
}
406
407
static int
408
optimize_format(expr_ty node, PyObject *fmt, asdl_expr_seq *elts, PyArena *arena)
409
{
410
Py_ssize_t pos = 0;
411
Py_ssize_t cnt = 0;
412
asdl_expr_seq *seq = _Py_asdl_expr_seq_new(asdl_seq_LEN(elts) * 2 + 1, arena);
413
if (!seq) {
414
return 0;
415
}
416
seq->size = 0;
417
418
while (1) {
419
expr_ty lit = parse_literal(fmt, &pos, arena);
420
if (lit) {
421
asdl_seq_SET(seq, seq->size++, lit);
422
}
423
else if (PyErr_Occurred()) {
424
return 0;
425
}
426
427
if (pos >= PyUnicode_GET_LENGTH(fmt)) {
428
break;
429
}
430
if (cnt >= asdl_seq_LEN(elts)) {
431
// More format units than items.
432
return 1;
433
}
434
assert(PyUnicode_READ_CHAR(fmt, pos) == '%');
435
pos++;
436
expr_ty expr = parse_format(fmt, &pos, asdl_seq_GET(elts, cnt), arena);
437
cnt++;
438
if (!expr) {
439
return !PyErr_Occurred();
440
}
441
asdl_seq_SET(seq, seq->size++, expr);
442
}
443
if (cnt < asdl_seq_LEN(elts)) {
444
// More items than format units.
445
return 1;
446
}
447
expr_ty res = _PyAST_JoinedStr(seq,
448
node->lineno, node->col_offset,
449
node->end_lineno, node->end_col_offset,
450
arena);
451
if (!res) {
452
return 0;
453
}
454
COPY_NODE(node, res);
455
// PySys_FormatStderr("format = %R\n", fmt);
456
return 1;
457
}
458
459
static int
460
fold_binop(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
461
{
462
expr_ty lhs, rhs;
463
lhs = node->v.BinOp.left;
464
rhs = node->v.BinOp.right;
465
if (lhs->kind != Constant_kind) {
466
return 1;
467
}
468
PyObject *lv = lhs->v.Constant.value;
469
470
if (node->v.BinOp.op == Mod &&
471
rhs->kind == Tuple_kind &&
472
PyUnicode_Check(lv) &&
473
!has_starred(rhs->v.Tuple.elts))
474
{
475
return optimize_format(node, lv, rhs->v.Tuple.elts, arena);
476
}
477
478
if (rhs->kind != Constant_kind) {
479
return 1;
480
}
481
482
PyObject *rv = rhs->v.Constant.value;
483
PyObject *newval = NULL;
484
485
switch (node->v.BinOp.op) {
486
case Add:
487
newval = PyNumber_Add(lv, rv);
488
break;
489
case Sub:
490
newval = PyNumber_Subtract(lv, rv);
491
break;
492
case Mult:
493
newval = safe_multiply(lv, rv);
494
break;
495
case Div:
496
newval = PyNumber_TrueDivide(lv, rv);
497
break;
498
case FloorDiv:
499
newval = PyNumber_FloorDivide(lv, rv);
500
break;
501
case Mod:
502
newval = safe_mod(lv, rv);
503
break;
504
case Pow:
505
newval = safe_power(lv, rv);
506
break;
507
case LShift:
508
newval = safe_lshift(lv, rv);
509
break;
510
case RShift:
511
newval = PyNumber_Rshift(lv, rv);
512
break;
513
case BitOr:
514
newval = PyNumber_Or(lv, rv);
515
break;
516
case BitXor:
517
newval = PyNumber_Xor(lv, rv);
518
break;
519
case BitAnd:
520
newval = PyNumber_And(lv, rv);
521
break;
522
// No builtin constants implement the following operators
523
case MatMult:
524
return 1;
525
// No default case, so the compiler will emit a warning if new binary
526
// operators are added without being handled here
527
}
528
529
return make_const(node, newval, arena);
530
}
531
532
static PyObject*
533
make_const_tuple(asdl_expr_seq *elts)
534
{
535
for (int i = 0; i < asdl_seq_LEN(elts); i++) {
536
expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
537
if (e->kind != Constant_kind) {
538
return NULL;
539
}
540
}
541
542
PyObject *newval = PyTuple_New(asdl_seq_LEN(elts));
543
if (newval == NULL) {
544
return NULL;
545
}
546
547
for (int i = 0; i < asdl_seq_LEN(elts); i++) {
548
expr_ty e = (expr_ty)asdl_seq_GET(elts, i);
549
PyObject *v = e->v.Constant.value;
550
PyTuple_SET_ITEM(newval, i, Py_NewRef(v));
551
}
552
return newval;
553
}
554
555
static int
556
fold_tuple(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
557
{
558
PyObject *newval;
559
560
if (node->v.Tuple.ctx != Load)
561
return 1;
562
563
newval = make_const_tuple(node->v.Tuple.elts);
564
return make_const(node, newval, arena);
565
}
566
567
static int
568
fold_subscr(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
569
{
570
PyObject *newval;
571
expr_ty arg, idx;
572
573
arg = node->v.Subscript.value;
574
idx = node->v.Subscript.slice;
575
if (node->v.Subscript.ctx != Load ||
576
arg->kind != Constant_kind ||
577
idx->kind != Constant_kind)
578
{
579
return 1;
580
}
581
582
newval = PyObject_GetItem(arg->v.Constant.value, idx->v.Constant.value);
583
return make_const(node, newval, arena);
584
}
585
586
/* Change literal list or set of constants into constant
587
tuple or frozenset respectively. Change literal list of
588
non-constants into tuple.
589
Used for right operand of "in" and "not in" tests and for iterable
590
in "for" loop and comprehensions.
591
*/
592
static int
593
fold_iter(expr_ty arg, PyArena *arena, _PyASTOptimizeState *state)
594
{
595
PyObject *newval;
596
if (arg->kind == List_kind) {
597
/* First change a list into tuple. */
598
asdl_expr_seq *elts = arg->v.List.elts;
599
if (has_starred(elts)) {
600
return 1;
601
}
602
expr_context_ty ctx = arg->v.List.ctx;
603
arg->kind = Tuple_kind;
604
arg->v.Tuple.elts = elts;
605
arg->v.Tuple.ctx = ctx;
606
/* Try to create a constant tuple. */
607
newval = make_const_tuple(elts);
608
}
609
else if (arg->kind == Set_kind) {
610
newval = make_const_tuple(arg->v.Set.elts);
611
if (newval) {
612
Py_SETREF(newval, PyFrozenSet_New(newval));
613
}
614
}
615
else {
616
return 1;
617
}
618
return make_const(arg, newval, arena);
619
}
620
621
static int
622
fold_compare(expr_ty node, PyArena *arena, _PyASTOptimizeState *state)
623
{
624
asdl_int_seq *ops;
625
asdl_expr_seq *args;
626
Py_ssize_t i;
627
628
ops = node->v.Compare.ops;
629
args = node->v.Compare.comparators;
630
/* Change literal list or set in 'in' or 'not in' into
631
tuple or frozenset respectively. */
632
i = asdl_seq_LEN(ops) - 1;
633
int op = asdl_seq_GET(ops, i);
634
if (op == In || op == NotIn) {
635
if (!fold_iter((expr_ty)asdl_seq_GET(args, i), arena, state)) {
636
return 0;
637
}
638
}
639
return 1;
640
}
641
642
static int astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
643
static int astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
644
static int astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
645
static int astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
646
static int astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
647
static int astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
648
static int astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
649
static int astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
650
static int astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
651
static int astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
652
static int astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
653
static int astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *state);
654
655
/* Helper macros for the astfold_* visitors.  They rely on the enclosing
   function having `ctx_` (the arena) and `state` in scope, and make that
   function return 0 as soon as a visited child reports failure.

   Each macro expands to a single `do { ... } while (0)` statement, so it
   must be followed by a semicolon and is safe inside unbraced if/else. */

/* Visit ARG with FUNC.  TYPE only documents the argument type; it is not
   used in the expansion. */
#define CALL(FUNC, TYPE, ARG) \
    do { \
        if (!FUNC((ARG), ctx_, state)) { \
            return 0; \
        } \
    } while (0)

/* Like CALL, but skips ARG when it is NULL. */
#define CALL_OPT(FUNC, TYPE, ARG) \
    do { \
        if ((ARG) != NULL && !FUNC((ARG), ctx_, state)) { \
            return 0; \
        } \
    } while (0)

/* Visit every non-NULL element of the asdl_<TYPE>_seq ARG with FUNC.
   The index is a Py_ssize_t to match the type of the sequence length. */
#define CALL_SEQ(FUNC, TYPE, ARG) \
    do { \
        Py_ssize_t i; \
        asdl_ ## TYPE ## _seq *seq = (ARG); /* avoid variable capture */ \
        for (i = 0; i < asdl_seq_LEN(seq); i++) { \
            TYPE ## _ty elt = (TYPE ## _ty)asdl_seq_GET(seq, i); \
            if (elt != NULL && !FUNC(elt, ctx_, state)) { \
                return 0; \
            } \
        } \
    } while (0)
672
673
674
static int
675
astfold_body(asdl_stmt_seq *stmts, PyArena *ctx_, _PyASTOptimizeState *state)
676
{
677
int docstring = _PyAST_GetDocString(stmts) != NULL;
678
CALL_SEQ(astfold_stmt, stmt, stmts);
679
if (!docstring && _PyAST_GetDocString(stmts) != NULL) {
680
stmt_ty st = (stmt_ty)asdl_seq_GET(stmts, 0);
681
asdl_expr_seq *values = _Py_asdl_expr_seq_new(1, ctx_);
682
if (!values) {
683
return 0;
684
}
685
asdl_seq_SET(values, 0, st->v.Expr.value);
686
expr_ty expr = _PyAST_JoinedStr(values, st->lineno, st->col_offset,
687
st->end_lineno, st->end_col_offset,
688
ctx_);
689
if (!expr) {
690
return 0;
691
}
692
st->v.Expr.value = expr;
693
}
694
return 1;
695
}
696
697
static int
698
astfold_mod(mod_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
699
{
700
switch (node_->kind) {
701
case Module_kind:
702
CALL(astfold_body, asdl_seq, node_->v.Module.body);
703
break;
704
case Interactive_kind:
705
CALL_SEQ(astfold_stmt, stmt, node_->v.Interactive.body);
706
break;
707
case Expression_kind:
708
CALL(astfold_expr, expr_ty, node_->v.Expression.body);
709
break;
710
// The following top level nodes don't participate in constant folding
711
case FunctionType_kind:
712
break;
713
// No default case, so the compiler will emit a warning if new top level
714
// compilation nodes are added without being handled here
715
}
716
return 1;
717
}
718
719
static int
720
astfold_expr(expr_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
721
{
722
if (++state->recursion_depth > state->recursion_limit) {
723
PyErr_SetString(PyExc_RecursionError,
724
"maximum recursion depth exceeded during compilation");
725
return 0;
726
}
727
switch (node_->kind) {
728
case BoolOp_kind:
729
CALL_SEQ(astfold_expr, expr, node_->v.BoolOp.values);
730
break;
731
case BinOp_kind:
732
CALL(astfold_expr, expr_ty, node_->v.BinOp.left);
733
CALL(astfold_expr, expr_ty, node_->v.BinOp.right);
734
CALL(fold_binop, expr_ty, node_);
735
break;
736
case UnaryOp_kind:
737
CALL(astfold_expr, expr_ty, node_->v.UnaryOp.operand);
738
CALL(fold_unaryop, expr_ty, node_);
739
break;
740
case Lambda_kind:
741
CALL(astfold_arguments, arguments_ty, node_->v.Lambda.args);
742
CALL(astfold_expr, expr_ty, node_->v.Lambda.body);
743
break;
744
case IfExp_kind:
745
CALL(astfold_expr, expr_ty, node_->v.IfExp.test);
746
CALL(astfold_expr, expr_ty, node_->v.IfExp.body);
747
CALL(astfold_expr, expr_ty, node_->v.IfExp.orelse);
748
break;
749
case Dict_kind:
750
CALL_SEQ(astfold_expr, expr, node_->v.Dict.keys);
751
CALL_SEQ(astfold_expr, expr, node_->v.Dict.values);
752
break;
753
case Set_kind:
754
CALL_SEQ(astfold_expr, expr, node_->v.Set.elts);
755
break;
756
case ListComp_kind:
757
CALL(astfold_expr, expr_ty, node_->v.ListComp.elt);
758
CALL_SEQ(astfold_comprehension, comprehension, node_->v.ListComp.generators);
759
break;
760
case SetComp_kind:
761
CALL(astfold_expr, expr_ty, node_->v.SetComp.elt);
762
CALL_SEQ(astfold_comprehension, comprehension, node_->v.SetComp.generators);
763
break;
764
case DictComp_kind:
765
CALL(astfold_expr, expr_ty, node_->v.DictComp.key);
766
CALL(astfold_expr, expr_ty, node_->v.DictComp.value);
767
CALL_SEQ(astfold_comprehension, comprehension, node_->v.DictComp.generators);
768
break;
769
case GeneratorExp_kind:
770
CALL(astfold_expr, expr_ty, node_->v.GeneratorExp.elt);
771
CALL_SEQ(astfold_comprehension, comprehension, node_->v.GeneratorExp.generators);
772
break;
773
case Await_kind:
774
CALL(astfold_expr, expr_ty, node_->v.Await.value);
775
break;
776
case Yield_kind:
777
CALL_OPT(astfold_expr, expr_ty, node_->v.Yield.value);
778
break;
779
case YieldFrom_kind:
780
CALL(astfold_expr, expr_ty, node_->v.YieldFrom.value);
781
break;
782
case Compare_kind:
783
CALL(astfold_expr, expr_ty, node_->v.Compare.left);
784
CALL_SEQ(astfold_expr, expr, node_->v.Compare.comparators);
785
CALL(fold_compare, expr_ty, node_);
786
break;
787
case Call_kind:
788
CALL(astfold_expr, expr_ty, node_->v.Call.func);
789
CALL_SEQ(astfold_expr, expr, node_->v.Call.args);
790
CALL_SEQ(astfold_keyword, keyword, node_->v.Call.keywords);
791
break;
792
case FormattedValue_kind:
793
CALL(astfold_expr, expr_ty, node_->v.FormattedValue.value);
794
CALL_OPT(astfold_expr, expr_ty, node_->v.FormattedValue.format_spec);
795
break;
796
case JoinedStr_kind:
797
CALL_SEQ(astfold_expr, expr, node_->v.JoinedStr.values);
798
break;
799
case Attribute_kind:
800
CALL(astfold_expr, expr_ty, node_->v.Attribute.value);
801
break;
802
case Subscript_kind:
803
CALL(astfold_expr, expr_ty, node_->v.Subscript.value);
804
CALL(astfold_expr, expr_ty, node_->v.Subscript.slice);
805
CALL(fold_subscr, expr_ty, node_);
806
break;
807
case Starred_kind:
808
CALL(astfold_expr, expr_ty, node_->v.Starred.value);
809
break;
810
case Slice_kind:
811
CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.lower);
812
CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.upper);
813
CALL_OPT(astfold_expr, expr_ty, node_->v.Slice.step);
814
break;
815
case List_kind:
816
CALL_SEQ(astfold_expr, expr, node_->v.List.elts);
817
break;
818
case Tuple_kind:
819
CALL_SEQ(astfold_expr, expr, node_->v.Tuple.elts);
820
CALL(fold_tuple, expr_ty, node_);
821
break;
822
case Name_kind:
823
if (node_->v.Name.ctx == Load &&
824
_PyUnicode_EqualToASCIIString(node_->v.Name.id, "__debug__")) {
825
state->recursion_depth--;
826
return make_const(node_, PyBool_FromLong(!state->optimize), ctx_);
827
}
828
break;
829
case NamedExpr_kind:
830
CALL(astfold_expr, expr_ty, node_->v.NamedExpr.value);
831
break;
832
case Constant_kind:
833
// Already a constant, nothing further to do
834
break;
835
// No default case, so the compiler will emit a warning if new expression
836
// kinds are added without being handled here
837
}
838
state->recursion_depth--;
839
return 1;
840
}
841
842
static int
843
astfold_keyword(keyword_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
844
{
845
CALL(astfold_expr, expr_ty, node_->value);
846
return 1;
847
}
848
849
static int
850
astfold_comprehension(comprehension_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
851
{
852
CALL(astfold_expr, expr_ty, node_->target);
853
CALL(astfold_expr, expr_ty, node_->iter);
854
CALL_SEQ(astfold_expr, expr, node_->ifs);
855
856
CALL(fold_iter, expr_ty, node_->iter);
857
return 1;
858
}
859
860
static int
861
astfold_arguments(arguments_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
862
{
863
CALL_SEQ(astfold_arg, arg, node_->posonlyargs);
864
CALL_SEQ(astfold_arg, arg, node_->args);
865
CALL_OPT(astfold_arg, arg_ty, node_->vararg);
866
CALL_SEQ(astfold_arg, arg, node_->kwonlyargs);
867
CALL_SEQ(astfold_expr, expr, node_->kw_defaults);
868
CALL_OPT(astfold_arg, arg_ty, node_->kwarg);
869
CALL_SEQ(astfold_expr, expr, node_->defaults);
870
return 1;
871
}
872
873
static int
874
astfold_arg(arg_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
875
{
876
if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
877
CALL_OPT(astfold_expr, expr_ty, node_->annotation);
878
}
879
return 1;
880
}
881
882
static int
883
astfold_stmt(stmt_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
884
{
885
if (++state->recursion_depth > state->recursion_limit) {
886
PyErr_SetString(PyExc_RecursionError,
887
"maximum recursion depth exceeded during compilation");
888
return 0;
889
}
890
switch (node_->kind) {
891
case FunctionDef_kind:
892
CALL_SEQ(astfold_type_param, type_param, node_->v.FunctionDef.type_params);
893
CALL(astfold_arguments, arguments_ty, node_->v.FunctionDef.args);
894
CALL(astfold_body, asdl_seq, node_->v.FunctionDef.body);
895
CALL_SEQ(astfold_expr, expr, node_->v.FunctionDef.decorator_list);
896
if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
897
CALL_OPT(astfold_expr, expr_ty, node_->v.FunctionDef.returns);
898
}
899
break;
900
case AsyncFunctionDef_kind:
901
CALL_SEQ(astfold_type_param, type_param, node_->v.AsyncFunctionDef.type_params);
902
CALL(astfold_arguments, arguments_ty, node_->v.AsyncFunctionDef.args);
903
CALL(astfold_body, asdl_seq, node_->v.AsyncFunctionDef.body);
904
CALL_SEQ(astfold_expr, expr, node_->v.AsyncFunctionDef.decorator_list);
905
if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
906
CALL_OPT(astfold_expr, expr_ty, node_->v.AsyncFunctionDef.returns);
907
}
908
break;
909
case ClassDef_kind:
910
CALL_SEQ(astfold_type_param, type_param, node_->v.ClassDef.type_params);
911
CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.bases);
912
CALL_SEQ(astfold_keyword, keyword, node_->v.ClassDef.keywords);
913
CALL(astfold_body, asdl_seq, node_->v.ClassDef.body);
914
CALL_SEQ(astfold_expr, expr, node_->v.ClassDef.decorator_list);
915
break;
916
case Return_kind:
917
CALL_OPT(astfold_expr, expr_ty, node_->v.Return.value);
918
break;
919
case Delete_kind:
920
CALL_SEQ(astfold_expr, expr, node_->v.Delete.targets);
921
break;
922
case Assign_kind:
923
CALL_SEQ(astfold_expr, expr, node_->v.Assign.targets);
924
CALL(astfold_expr, expr_ty, node_->v.Assign.value);
925
break;
926
case AugAssign_kind:
927
CALL(astfold_expr, expr_ty, node_->v.AugAssign.target);
928
CALL(astfold_expr, expr_ty, node_->v.AugAssign.value);
929
break;
930
case AnnAssign_kind:
931
CALL(astfold_expr, expr_ty, node_->v.AnnAssign.target);
932
if (!(state->ff_features & CO_FUTURE_ANNOTATIONS)) {
933
CALL(astfold_expr, expr_ty, node_->v.AnnAssign.annotation);
934
}
935
CALL_OPT(astfold_expr, expr_ty, node_->v.AnnAssign.value);
936
break;
937
case TypeAlias_kind:
938
CALL(astfold_expr, expr_ty, node_->v.TypeAlias.name);
939
CALL_SEQ(astfold_type_param, type_param, node_->v.TypeAlias.type_params);
940
CALL(astfold_expr, expr_ty, node_->v.TypeAlias.value);
941
break;
942
case For_kind:
943
CALL(astfold_expr, expr_ty, node_->v.For.target);
944
CALL(astfold_expr, expr_ty, node_->v.For.iter);
945
CALL_SEQ(astfold_stmt, stmt, node_->v.For.body);
946
CALL_SEQ(astfold_stmt, stmt, node_->v.For.orelse);
947
948
CALL(fold_iter, expr_ty, node_->v.For.iter);
949
break;
950
case AsyncFor_kind:
951
CALL(astfold_expr, expr_ty, node_->v.AsyncFor.target);
952
CALL(astfold_expr, expr_ty, node_->v.AsyncFor.iter);
953
CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.body);
954
CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncFor.orelse);
955
break;
956
case While_kind:
957
CALL(astfold_expr, expr_ty, node_->v.While.test);
958
CALL_SEQ(astfold_stmt, stmt, node_->v.While.body);
959
CALL_SEQ(astfold_stmt, stmt, node_->v.While.orelse);
960
break;
961
case If_kind:
962
CALL(astfold_expr, expr_ty, node_->v.If.test);
963
CALL_SEQ(astfold_stmt, stmt, node_->v.If.body);
964
CALL_SEQ(astfold_stmt, stmt, node_->v.If.orelse);
965
break;
966
case With_kind:
967
CALL_SEQ(astfold_withitem, withitem, node_->v.With.items);
968
CALL_SEQ(astfold_stmt, stmt, node_->v.With.body);
969
break;
970
case AsyncWith_kind:
971
CALL_SEQ(astfold_withitem, withitem, node_->v.AsyncWith.items);
972
CALL_SEQ(astfold_stmt, stmt, node_->v.AsyncWith.body);
973
break;
974
case Raise_kind:
975
CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.exc);
976
CALL_OPT(astfold_expr, expr_ty, node_->v.Raise.cause);
977
break;
978
case Try_kind:
979
CALL_SEQ(astfold_stmt, stmt, node_->v.Try.body);
980
CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.Try.handlers);
981
CALL_SEQ(astfold_stmt, stmt, node_->v.Try.orelse);
982
CALL_SEQ(astfold_stmt, stmt, node_->v.Try.finalbody);
983
break;
984
case TryStar_kind:
985
CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.body);
986
CALL_SEQ(astfold_excepthandler, excepthandler, node_->v.TryStar.handlers);
987
CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.orelse);
988
CALL_SEQ(astfold_stmt, stmt, node_->v.TryStar.finalbody);
989
break;
990
case Assert_kind:
991
CALL(astfold_expr, expr_ty, node_->v.Assert.test);
992
CALL_OPT(astfold_expr, expr_ty, node_->v.Assert.msg);
993
break;
994
case Expr_kind:
995
CALL(astfold_expr, expr_ty, node_->v.Expr.value);
996
break;
997
case Match_kind:
998
CALL(astfold_expr, expr_ty, node_->v.Match.subject);
999
CALL_SEQ(astfold_match_case, match_case, node_->v.Match.cases);
1000
break;
1001
// The following statements don't contain any subexpressions to be folded
1002
case Import_kind:
1003
case ImportFrom_kind:
1004
case Global_kind:
1005
case Nonlocal_kind:
1006
case Pass_kind:
1007
case Break_kind:
1008
case Continue_kind:
1009
break;
1010
// No default case, so the compiler will emit a warning if new statement
1011
// kinds are added without being handled here
1012
}
1013
state->recursion_depth--;
1014
return 1;
1015
}
1016
1017
static int
1018
astfold_excepthandler(excepthandler_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1019
{
1020
switch (node_->kind) {
1021
case ExceptHandler_kind:
1022
CALL_OPT(astfold_expr, expr_ty, node_->v.ExceptHandler.type);
1023
CALL_SEQ(astfold_stmt, stmt, node_->v.ExceptHandler.body);
1024
break;
1025
// No default case, so the compiler will emit a warning if new handler
1026
// kinds are added without being handled here
1027
}
1028
return 1;
1029
}
1030
1031
static int
1032
astfold_withitem(withitem_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1033
{
1034
CALL(astfold_expr, expr_ty, node_->context_expr);
1035
CALL_OPT(astfold_expr, expr_ty, node_->optional_vars);
1036
return 1;
1037
}
1038
1039
static int
1040
astfold_pattern(pattern_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1041
{
1042
// Currently, this is really only used to form complex/negative numeric
1043
// constants in MatchValue and MatchMapping nodes
1044
// We still recurse into all subexpressions and subpatterns anyway
1045
if (++state->recursion_depth > state->recursion_limit) {
1046
PyErr_SetString(PyExc_RecursionError,
1047
"maximum recursion depth exceeded during compilation");
1048
return 0;
1049
}
1050
switch (node_->kind) {
1051
case MatchValue_kind:
1052
CALL(astfold_expr, expr_ty, node_->v.MatchValue.value);
1053
break;
1054
case MatchSingleton_kind:
1055
break;
1056
case MatchSequence_kind:
1057
CALL_SEQ(astfold_pattern, pattern, node_->v.MatchSequence.patterns);
1058
break;
1059
case MatchMapping_kind:
1060
CALL_SEQ(astfold_expr, expr, node_->v.MatchMapping.keys);
1061
CALL_SEQ(astfold_pattern, pattern, node_->v.MatchMapping.patterns);
1062
break;
1063
case MatchClass_kind:
1064
CALL(astfold_expr, expr_ty, node_->v.MatchClass.cls);
1065
CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.patterns);
1066
CALL_SEQ(astfold_pattern, pattern, node_->v.MatchClass.kwd_patterns);
1067
break;
1068
case MatchStar_kind:
1069
break;
1070
case MatchAs_kind:
1071
if (node_->v.MatchAs.pattern) {
1072
CALL(astfold_pattern, pattern_ty, node_->v.MatchAs.pattern);
1073
}
1074
break;
1075
case MatchOr_kind:
1076
CALL_SEQ(astfold_pattern, pattern, node_->v.MatchOr.patterns);
1077
break;
1078
// No default case, so the compiler will emit a warning if new pattern
1079
// kinds are added without being handled here
1080
}
1081
state->recursion_depth--;
1082
return 1;
1083
}
1084
1085
static int
1086
astfold_match_case(match_case_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1087
{
1088
CALL(astfold_pattern, expr_ty, node_->pattern);
1089
CALL_OPT(astfold_expr, expr_ty, node_->guard);
1090
CALL_SEQ(astfold_stmt, stmt, node_->body);
1091
return 1;
1092
}
1093
1094
static int
1095
astfold_type_param(type_param_ty node_, PyArena *ctx_, _PyASTOptimizeState *state)
1096
{
1097
switch (node_->kind) {
1098
case TypeVar_kind:
1099
CALL_OPT(astfold_expr, expr_ty, node_->v.TypeVar.bound);
1100
break;
1101
case ParamSpec_kind:
1102
break;
1103
case TypeVarTuple_kind:
1104
break;
1105
}
1106
return 1;
1107
}
1108
1109
/* The visitor helper macros are not needed past this point. */
#undef CALL
#undef CALL_OPT
#undef CALL_SEQ
1112
1113
/* Factor applied to the C recursion counters when setting up the
   optimizer's own depth accounting.  See comments in symtable.c. */
#define COMPILER_STACK_FRAME_SCALE 3
1115
1116
int
1117
_PyAST_Optimize(mod_ty mod, PyArena *arena, int optimize, int ff_features)
1118
{
1119
PyThreadState *tstate;
1120
int starting_recursion_depth;
1121
1122
_PyASTOptimizeState state;
1123
state.optimize = optimize;
1124
state.ff_features = ff_features;
1125
1126
/* Setup recursion depth check counters */
1127
tstate = _PyThreadState_GET();
1128
if (!tstate) {
1129
return 0;
1130
}
1131
/* Be careful here to prevent overflow. */
1132
int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
1133
starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
1134
state.recursion_depth = starting_recursion_depth;
1135
state.recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
1136
1137
int ret = astfold_mod(mod, arena, &state);
1138
assert(ret || PyErr_Occurred());
1139
1140
/* Check that the recursion depth counting balanced correctly */
1141
if (ret && state.recursion_depth != starting_recursion_depth) {
1142
PyErr_Format(PyExc_SystemError,
1143
"AST optimizer recursion depth mismatch (before=%d, after=%d)",
1144
starting_recursion_depth, state.recursion_depth);
1145
return 0;
1146
}
1147
1148
return ret;
1149
}
1150
1151