Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
allendowney
GitHub Repository: allendowney/cpython
Path: blob/main/Parser/asdl_c.py
12 views
1
#! /usr/bin/env python
2
"""Generate C code from an ASDL description."""
3
4
import sys
5
import textwrap
6
import types
7
8
from argparse import ArgumentParser
9
from contextlib import contextmanager
10
from pathlib import Path
11
12
import asdl
13
14
# Spaces per indentation level in the generated C code, and the column
# limit that reflow_lines() tries to keep emitted lines under.
TABSIZE, MAX_COL = 4, 80

# Banner for generated files; the "{}" placeholder is filled in with
# str.format() (presumably with the generating script's name — confirm
# at the call site).
AUTOGEN_MESSAGE = "// File automatically generated by {}.\n\n"


def get_c_type(name):
    """Map an ASDL type name to the name of its C type.

    Builtin ASDL types (those in ``asdl.builtin_types``) keep their own
    name; every other type gets a ``_ty`` suffix, e.g. ``expr`` becomes
    ``expr_ty``.
    """
    if name not in asdl.builtin_types:
        return "%s_ty" % name
    return name


def reflow_lines(s, depth):
    """Reflow the line *s*, which will be indented *depth* tabs.

    Return a list of lines where no line extends beyond MAX_COL when
    properly indented.  The first line is indented based exclusively on
    ``depth * TABSIZE``.  All following lines -- the continuation lines
    produced by this function -- start at the same column as the first
    character beyond the opening ``{`` in the first line (or just past the
    first ``(`` when there is no brace).

    Raises AssertionError when a chunk has no space to break at.  This is
    an explicit ``raise`` rather than an ``assert`` statement on purpose:
    under ``python -O`` an ``assert`` is stripped, and ``rfind``'s -1 would
    then leave ``cur`` unchanged, so the loop would never terminate.
    """
    size = MAX_COL - depth * TABSIZE
    if len(s) < size:
        return [s]

    lines = []
    cur = s
    padding = ""
    while len(cur) > size:
        i = cur.rfind(' ', 0, size)
        # XXX this should be fixed for real: lines mentioning GeneratorExp
        # are force-split 3 characters past the limit.
        if i == -1 and 'GeneratorExp' in cur:
            i = size + 3
        if i == -1:
            raise AssertionError(
                "Impossible line %d to reflow: %r" % (size, s))
        lines.append(padding + cur[:i])
        if len(lines) == 1:
            # Continuation lines align after the first brace (plus the
            # space following it) or, failing that, after the first paren.
            j = cur.find('{', 0, i)
            if j >= 0:
                j += 2  # account for the brace and the space after it
            else:
                j = cur.find('(', 0, i)
                if j >= 0:
                    j += 1  # account for the paren (no space after it)
            if j >= 0:
                size -= j
                padding = " " * j
        # Drop the space we broke at.
        cur = cur[i+1:]
    lines.append(padding + cur)
    return lines


def reflow_c_string(s, depth):
    """Render *s* as C string literal(s), one literal per line of *s*.

    Each newline in *s* becomes an escaped ``\\n`` that ends the current
    literal.  The next literal starts on a new line indented *depth* tabs,
    relying on C's concatenation of adjacent string literals.
    """
    indent = ' ' * (depth * TABSIZE)
    continuation = '\\n"\n' + indent + '"'
    return '"' + s.replace('\n', continuation) + '"'


def is_simple(sum_type):
    """Return True if a sum is simple.

    A sum is simple when it has no attributes and none of its
    constructors have fields, e.g.

        unaryop = Invert | Not | UAdd | USub

    Instances of such types are cached at C level and act like
    singletons when parser-generated nodes are propagated to Python.
    """
    if sum_type.attributes:
        return False
    for constructor in sum_type.types:
        if constructor.fields:
            return False
    return True


def asdl_of(name, obj):
    """Return ASDL source text describing *obj*, declared as *name*.

    Products and constructors render as ``name(field, ...)``, or just
    ``name`` when they have no fields.  Sums render as ``name = A | B``.
    For non-simple sums each alternative is rendered recursively on its
    own line, with the ``|`` aligned under the ``=``.
    """
    if isinstance(obj, (asdl.Product, asdl.Constructor)):
        field_text = ", ".join(str(field) for field in obj.fields)
        if field_text:
            return "{}({})".format(name, field_text)
        return "{}".format(name)

    if is_simple(obj):
        alternatives = " | ".join(cons.name for cons in obj.types)
    else:
        separator = "\n" + " " * (len(name) + 1) + "| "
        alternatives = separator.join(
            asdl_of(cons.name, cons) for cons in obj.types
        )
    return "{} = {}".format(name, alternatives)


class EmitVisitor(asdl.VisitorBase):
    """Visitor that writes lines of generated C code to a file."""

    def __init__(self, file, metadata=None):
        # file: writable text stream used by emit().
        # metadata: optional namespace (presumably the one built by
        # MetadataVisitor); reading .metadata while unset raises ValueError.
        self.file = file
        self._metadata = metadata
        super().__init__()

    def emit(self, s, depth, reflow=True):
        """Write *s* followed by a newline, indented by *depth* tabs.

        With *reflow* true the text is first split by reflow_lines() so it
        fits within MAX_COL.  Empty lines are written without indentation.
        """
        # XXX reflow long lines?
        lines = reflow_lines(s, depth) if reflow else [s]
        for line in lines:
            if line:
                line = (" " * TABSIZE * depth) + line
            self.file.write(line + "\n")

    @property
    def metadata(self):
        """The metadata namespace; raises ValueError if it was never set."""
        if self._metadata is None:
            raise ValueError(
                "%s was expecting to be annotated with metadata"
                % type(self).__name__
            )
        return self._metadata

    @metadata.setter
    def metadata(self, value):
        self._metadata = value


class MetadataVisitor(asdl.VisitorBase):
    """Collect module-wide facts that the emitting visitors consult."""

    ROOT_TYPE = "AST"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Collected metadata:
        # - simple_sums: names of sum types for which is_simple() holds
        #                (no attributes, no constructor has fields).
        # - identifiers: every field and attribute name in the declarations.
        # - singletons:  names of constructors that belong to simple sums.
        # - types:       every sum, product and constructor name, plus
        #                ROOT_TYPE.
        self.metadata = types.SimpleNamespace(
            simple_sums=set(),
            identifiers=set(),
            singletons=set(),
            types={self.ROOT_TYPE},
        )

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type):
        self.visit(type.value, type.name)

    def visitSum(self, sum, name):
        meta = self.metadata
        meta.types.add(name)

        simple = is_simple(sum)
        if simple:
            meta.simple_sums.add(name)
            meta.singletons.update(cons.name for cons in sum.types)

        for constructor in sum.types:
            self.visitConstructor(constructor)
        self.visitFields(sum.attributes)

    def visitConstructor(self, constructor):
        self.metadata.types.add(constructor.name)
        self.visitFields(constructor.fields)

    def visitProduct(self, product, name):
        self.metadata.types.add(name)
        self.visitFields(product.attributes)
        self.visitFields(product.fields)

    def visitFields(self, fields):
        for field in fields:
            self.visitField(field)

    def visitField(self, field):
        self.metadata.identifiers.add(field.name)


class TypeDefVisitor(EmitVisitor):
    """Emit the C typedef for every ASDL type.

    Simple sums become enums; other sums and products become typedefs
    for a pointer to ``struct _<name>``.
    """

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        if is_simple(sum):
            handler = self.simple_sum
        else:
            handler = self.sum_with_constructors
        handler(sum, name, depth)

    def simple_sum(self, sum, name, depth):
        # Enumerators are numbered from 1 in declaration order.
        members = ", ".join(
            "%s=%d" % (alternative.name, number)
            for number, alternative in enumerate(sum.types, start=1)
        )
        declaration = "typedef enum _%s { %s } %s;" % (
            name, members, get_c_type(name))
        self.emit(declaration, depth)
        self.emit("", depth)

    def sum_with_constructors(self, sum, name, depth):
        self._emit_struct_pointer_typedef(name, depth)

    def visitProduct(self, product, name, depth):
        self._emit_struct_pointer_typedef(name, depth)

    def _emit_struct_pointer_typedef(self, name, depth):
        # e.g. "typedef struct _expr *expr_ty;" followed by a blank line.
        self.emit("typedef struct _%s *%s;" % (name, get_c_type(name)), depth)
        self.emit("", depth)


class SequenceDefVisitor(EmitVisitor):
    """Emit each typed ``asdl_<name>_seq`` struct and its constructor prototype."""

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        # Simple sums get no dedicated sequence type.
        if not is_simple(sum):
            self.emit_sequence_constructor(name, depth)

    def emit_sequence_constructor(self, name,depth):
        ctype = get_c_type(name)
        struct_decl = (
            "typedef struct {\n"
            "    _ASDL_SEQ_HEAD\n"
            "    %s typed_elements[1];\n"
            "} asdl_%s_seq;" % (ctype, name)
        )
        self.emit(struct_decl, reflow=False, depth=depth)
        self.emit("", depth)
        prototype = ("asdl_%s_seq *_Py_asdl_%s_seq_new(Py_ssize_t size, PyArena *arena);"
                     % (name, name))
        self.emit(prototype, depth)
        self.emit("", depth)

    def visitProduct(self, product, name, depth):
        self.emit_sequence_constructor(name, depth)


class StructVisitor(EmitVisitor):
    """Visitor to generate the C struct definitions for AST."""

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type, depth=0):
        self.visit(type.value, type.name, depth)

    def visitSum(self, sum, name, depth):
        # Simple sums are plain enums and need no struct.
        if not is_simple(sum):
            self.sum_with_constructors(sum, name, depth)

    def sum_with_constructors(self, sum, name, depth):
        """Emit ``enum _<name>_kind`` and the tagged ``struct _<name>``.

        The struct holds the kind tag, a union ``v`` with one member per
        constructor that has fields, and then the sum's attributes.
        """
        # Each line is formatted explicitly.  The previous version
        # interpolated the caller's frame locals via sys._getframe(1),
        # which depends on a CPython implementation detail and re-applied
        # %-formatting to already formatted text.
        kinds = ", ".join("%s_kind=%d" % (t.name, i)
                          for i, t in enumerate(sum.types, start=1))
        self.emit("enum _%s_kind {%s};" % (name, kinds), depth)

        self.emit("struct _%s {" % name, depth)
        self.emit("enum _%s_kind kind;" % name, depth + 1)
        self.emit("union {", depth + 1)
        for t in sum.types:
            self.visit(t, depth + 2)
        self.emit("} v;", depth + 1)
        self._emit_attributes(sum.attributes, depth + 1)
        self.emit("};", depth)
        self.emit("", depth)

    def _emit_attributes(self, attributes, depth):
        # Rudimentary attribute handling: attributes must be builtin types,
        # whose C type name is the ASDL type name itself.
        for field in attributes:
            type = str(field.type)
            assert type in asdl.builtin_types, type
            self.emit("%s %s;" % (type, field.name), depth)

    def visitConstructor(self, cons, depth):
        # Constructors without fields contribute no union member.
        if cons.fields:
            self.emit("struct {", depth)
            for f in cons.fields:
                self.visit(f, depth + 1)
            self.emit("} %s;" % cons.name, depth)
            self.emit("", depth)

    def visitField(self, field, depth):
        # XXX need to lookup field.type, because it might be something
        # like a builtin...
        name = field.name
        if field.seq:
            # Sequences of simple sums are stored as plain int sequences.
            if field.type in self.metadata.simple_sums:
                self.emit("asdl_int_seq *%s;" % name, depth)
            else:
                self.emit("asdl_%s_seq *%s;" % (field.type, name), depth)
        else:
            self.emit("%s %s;" % (get_c_type(field.type), name), depth)

    def visitProduct(self, product, name, depth):
        self.emit("struct _%s {" % name, depth)
        for f in product.fields:
            self.visit(f, depth + 1)
        self._emit_attributes(product.attributes, depth + 1)
        self.emit("};", depth)
        self.emit("", depth)


def ast_func_name(name):
    """Return the name of the C constructor function for AST node *name*.

    >>> ast_func_name("Module")
    '_PyAST_Module'
    """
    return "_PyAST_{}".format(name)


class PrototypeVisitor(EmitVisitor):
    """Generate function prototypes for the .h file"""

    def visitModule(self, mod):
        for dfn in mod.dfns:
            self.visit(dfn)

    def visitType(self, type):
        self.visit(type.value, type.name)

    def visitSum(self, sum, name):
        # XXX: simple sums get no constructor prototypes.
        if not is_simple(sum):
            for t in sum.types:
                self.visit(t, name, sum.attributes)

    def get_args(self, fields):
        """Return list of C argument info, one for each field.

        Argument info is a 3-tuple of the C type, the variable name, and
        ``f.opt or f.seq`` (truthy when the argument may be NULL).
        """
        args = []
        unnamed = {}
        for f in fields:
            if f.name is None:
                name = f.type
                c = unnamed[name] = unnamed.get(name, 0) + 1
                if c > 1:
                    name = "name%d" % (c - 1)
            else:
                name = f.name
            # XXX should extend get_c_type() to handle this
            if f.seq:
                if f.type in self.metadata.simple_sums:
                    ctype = "asdl_int_seq *"
                else:
                    ctype = f"asdl_{f.type}_seq *"
            else:
                ctype = get_c_type(f.type)
            args.append((ctype, name, f.opt or f.seq))
        return args

    def visitConstructor(self, cons, type, attrs):
        args = self.get_args(cons.fields)
        attrs = self.get_args(attrs)
        ctype = get_c_type(type)
        self.emit_function(cons.name, ctype, args, attrs)

    def emit_function(self, name, ctype, args, attrs, union=True):
        """Emit the prototype ``<ctype> _PyAST_<name>(<args>, PyArena *arena);``.

        *union* is not used here; it is accepted so the signature matches
        FunctionVisitor.emit_function, which overrides this method.
        """
        params = ["%s %s" % (atype, aname) for atype, aname, _ in args + attrs]
        params.append("PyArena *arena")
        argstr = ", ".join(params)
        # Prototypes are top level, so depth is 0.  (It used to be passed
        # as the boolean False, which only worked because False == 0.)
        self.emit("%s %s(%s);" % (ctype, ast_func_name(name), argstr), 0)

    def visitProduct(self, prod, name):
        self.emit_function(name, get_c_type(name),
                           self.get_args(prod.fields),
                           self.get_args(prod.attributes),
                           union=False)


class FunctionVisitor(PrototypeVisitor):
    """Visitor to generate constructor functions for AST."""

    def emit_function(self, name, ctype, args, attrs, union=True):
        """Emit the C definition of the ``_PyAST_<name>`` constructor.

        Each argument whose flag is false and whose C type is not ``int``
        gets a NULL check that raises ValueError.  The node is allocated
        with _PyArena_Malloc and filled in either as a union member
        (*union* true) or as plain struct members.
        """
        out = self.emit
        params = ", ".join("%s %s" % (atype, aname)
                           for atype, aname, _ in args + attrs)
        if params:
            params += ", PyArena *arena"
        else:
            params = "PyArena *arena"

        out("%s" % ctype, 0)
        out("%s(%s)" % (ast_func_name(name), params), 0, True)
        out("{", 0, True)
        out("%s p;" % ctype, 1, True)
        for argtype, argname, opt in args:
            if opt or argtype == "int":
                continue
            out("if (!%s) {" % argname, 1, True)
            out("PyErr_SetString(PyExc_ValueError,", 2, True)
            message = "field '%s' is required for %s" % (argname, name)
            out(' "%s");' % message, 2, False)
            out('return NULL;', 2, True)
            out('}', 1, True)

        out("p = (%s)_PyArena_Malloc(arena, sizeof(*p));" % ctype, 1, True)
        out("if (!p)", 1, True)
        out("return NULL;", 2, True)
        fill_body = self.emit_body_union if union else self.emit_body_struct
        fill_body(name, args, attrs)
        out("return p;", 1, True)
        out("}", 0, True)
        out("", 0, True)

    def emit_body_union(self, name, args, attrs):
        # Fields live in the constructor's member of union v; attributes
        # live directly on the node.
        self.emit("p->kind = %s_kind;" % name, 1, True)
        for _, argname, _ in args:
            self.emit("p->v.%s.%s = %s;" % (name, argname, argname), 1, True)
        for _, argname, _ in attrs:
            self.emit("p->%s = %s;" % (argname, argname), 1, True)

    def emit_body_struct(self, name, args, attrs):
        for _, argname, _ in args + attrs:
            self.emit("p->%s = %s;" % (argname, argname), 1, True)


class PickleVisitor(EmitVisitor):
    """Base visitor that walks a module's definitions.

    Every per-node hook is a no-op; subclasses override the ones they need.
    """

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type):
        self.visit(type.value, type.name)

    def visitSum(self, sum, name):
        """No-op hook."""

    def visitProduct(self, sum, name):
        """No-op hook."""

    def visitConstructor(self, cons, name):
        """No-op hook."""

    def visitField(self, sum):
        """No-op hook."""


class Obj2ModPrototypeVisitor(PickleVisitor):
    """Emit forward declarations of the ``obj2ast_<type>()`` converters."""

    def visitProduct(self, prod, name):
        ctype = get_c_type(name)
        prototype = ("static int obj2ast_%s(struct ast_state *state, "
                     "PyObject* obj, %s* out, PyArena* arena);" % (name, ctype))
        self.emit(prototype, 0)

    # Sums need exactly the same prototype as products.
    visitSum = visitProduct


class Obj2ModVisitor(PickleVisitor):
    """Emit the ``obj2ast_<type>()`` functions (Python AST object -> C node).

    Each generated function stores the converted value in ``*out``.  It
    returns 0 on success and 1 on failure, after setting a Python
    exception.
    """

    # Optional numeric attributes whose C default is copied from another
    # attribute when the Python object omits them or sets them to None.
    attribute_special_defaults = {
        "end_lineno": "lineno",
        "end_col_offset": "col_offset",
    }

    @contextmanager
    def recursive_call(self, node, level):
        """Wrap the code emitted inside the ``with`` in a C recursion guard."""
        self.emit('if (_Py_EnterRecursiveCall(" while traversing \'%s\' node")) {' % node, level, reflow=False)
        self.emit('goto failed;', level + 1)
        self.emit('}', level)
        yield
        self.emit('_Py_LeaveRecursiveCall();', level)

    def _emit_signature(self, name):
        # Shared "int obj2ast_<name>(...)" / "{" opening of every converter.
        ctype = get_c_type(name)
        self.emit("int", 0)
        self.emit("obj2ast_%s(struct ast_state *state, PyObject* obj, %s* out, PyArena* arena)" % (name, ctype), 0)
        self.emit("{", 0)

    def funcHeader(self, name):
        """Open a converter for a sum type and declare ``isinstance``."""
        self._emit_signature(name)
        self.emit("int isinstance;", 1)
        self.emit("", 0)

    def sumTrailer(self, name, add_label=False):
        """Close a sum converter: no alternative matched, so raise TypeError.

        *add_label* also emits the ``failed:`` label that releases ``tmp``.
        It is needed only by converters that declare ``tmp``.
        """
        self.emit("", 0)
        # there's really nothing more we can do if this fails ...
        error = "expected some sort of %s, but got %%R" % name
        template = "PyErr_Format(PyExc_TypeError, \"%s\", obj);"
        self.emit(template % error, 1, reflow=False)
        if add_label:
            self.emit("failed:", 1)
            self.emit("Py_XDECREF(tmp);", 1)
        self.emit("return 1;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def simpleSum(self, sum, name):
        """Map an instance of one of the constructor types to its enum value."""
        self.funcHeader(name)
        for t in sum.types:
            line = ("isinstance = PyObject_IsInstance(obj, "
                    "state->%s_type);")
            self.emit(line % (t.name,), 1)
            self.emit("if (isinstance == -1) {", 1)
            self.emit("return 1;", 2)
            self.emit("}", 1)
            self.emit("if (isinstance) {", 1)
            self.emit("*out = %s;" % t.name, 2)
            self.emit("return 0;", 2)
            self.emit("}", 1)
        self.sumTrailer(name)

    def buildArgs(self, fields):
        """Return the C argument list for a constructor call, ending in arena."""
        return ", ".join(fields + ["arena"])

    def complexSum(self, sum, name):
        """Convert the shared attributes, then dispatch on the constructor type."""
        self.funcHeader(name)
        self.emit("PyObject *tmp = NULL;", 1)
        self.emit("PyObject *tp;", 1)
        for a in sum.attributes:
            self.visitAttributeDeclaration(a, name, sum=sum)
        self.emit("", 0)
        # XXX: should we only do this for 'expr'?
        self.emit("if (obj == Py_None) {", 1)
        self.emit("*out = NULL;", 2)
        self.emit("return 0;", 2)
        self.emit("}", 1)
        for a in sum.attributes:
            self.visitField(a, name, sum=sum, depth=1)
        for t in sum.types:
            self.emit("tp = state->%s_type;" % (t.name,), 1)
            self.emit("isinstance = PyObject_IsInstance(obj, tp);", 1)
            self.emit("if (isinstance == -1) {", 1)
            self.emit("return 1;", 2)
            self.emit("}", 1)
            self.emit("if (isinstance) {", 1)
            for f in t.fields:
                self.visitFieldDeclaration(f, t.name, sum=sum, depth=2)
            self.emit("", 0)
            for f in t.fields:
                self.visitField(f, t.name, sum=sum, depth=2)
            args = [f.name for f in t.fields] + [a.name for a in sum.attributes]
            self.emit("*out = %s(%s);" % (ast_func_name(t.name), self.buildArgs(args)), 2)
            self.emit("if (*out == NULL) goto failed;", 2)
            self.emit("return 0;", 2)
            self.emit("}", 1)
        self.sumTrailer(name, True)

    def visitAttributeDeclaration(self, a, name, sum=None):
        # The *sum* default used to be the builtin sum() function.  The
        # parameter is unused and kept only for call compatibility.
        ctype = get_c_type(a.type)
        self.emit("%s %s;" % (ctype, a.name), 1)

    def visitSum(self, sum, name):
        if is_simple(sum):
            self.simpleSum(sum, name)
        else:
            self.complexSum(sum, name)

    def visitProduct(self, prod, name):
        """Emit the converter for a product type (no isinstance dispatch)."""
        self._emit_signature(name)
        self.emit("PyObject* tmp = NULL;", 1)
        for f in prod.fields:
            self.visitFieldDeclaration(f, name, prod=prod, depth=1)
        for a in prod.attributes:
            self.visitFieldDeclaration(a, name, prod=prod, depth=1)
        self.emit("", 0)
        for f in prod.fields:
            self.visitField(f, name, prod=prod, depth=1)
        for a in prod.attributes:
            self.visitField(a, name, prod=prod, depth=1)
        args = [f.name for f in prod.fields]
        args.extend([a.name for a in prod.attributes])
        self.emit("*out = %s(%s);" % (ast_func_name(name), self.buildArgs(args)), 1)
        self.emit("if (*out == NULL) goto failed;", 1)
        self.emit("return 0;", 1)
        self.emit("failed:", 0)
        self.emit("Py_XDECREF(tmp);", 1)
        self.emit("return 1;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def visitFieldDeclaration(self, field, name, sum=None, prod=None, depth=0):
        """Emit the C local variable that will hold *field*'s converted value."""
        if field.seq:
            if self.isSimpleType(field):
                self.emit("asdl_int_seq* %s;" % field.name, depth)
            else:
                self.emit(f"asdl_{field.type}_seq* {field.name};", depth)
        else:
            ctype = get_c_type(field.type)
            self.emit("%s %s;" % (ctype, field.name), depth)

    def isNumeric(self, field):
        return get_c_type(field.type) in ("int", "bool")

    def isSimpleType(self, field):
        return field.type in self.metadata.simple_sums or self.isNumeric(field)

    def visitField(self, field, name, sum=None, prod=None, depth=0):
        """Emit code that reads attribute *field* of ``obj`` and converts it.

        A missing sequence becomes an empty list.  A missing required field
        raises TypeError.  A missing or None optional field gets a default:
        0 or another attribute for numerics, NULL otherwise.
        """
        ctype = get_c_type(field.type)
        line = "if (_PyObject_LookupAttr(obj, state->%s, &tmp) < 0) {"
        self.emit(line % field.name, depth)
        self.emit("return 1;", depth+1)
        self.emit("}", depth)
        if field.seq:
            self.emit("if (tmp == NULL) {", depth)
            self.emit("tmp = PyList_New(0);", depth+1)
            self.emit("if (tmp == NULL) {", depth+1)
            self.emit("return 1;", depth+2)
            self.emit("}", depth+1)
            self.emit("}", depth)
            self.emit("{", depth)
        else:
            if not field.opt:
                self.emit("if (tmp == NULL) {", depth)
                message = "required field \\\"%s\\\" missing from %s" % (field.name, name)
                template = "PyErr_SetString(PyExc_TypeError, \"%s\");"
                self.emit(template % message, depth+1, reflow=False)
                self.emit("return 1;", depth+1)
            else:
                self.emit("if (tmp == NULL || tmp == Py_None) {", depth)
                self.emit("Py_CLEAR(tmp);", depth+1)
                if self.isNumeric(field):
                    if field.name in self.attribute_special_defaults:
                        self.emit(
                            "%s = %s;" % (field.name, self.attribute_special_defaults[field.name]),
                            depth+1,
                        )
                    else:
                        self.emit("%s = 0;" % field.name, depth+1)
                elif not self.isSimpleType(field):
                    self.emit("%s = NULL;" % field.name, depth+1)
                else:
                    raise TypeError("could not determine the default value for %s" % field.name)
            self.emit("}", depth)
            self.emit("else {", depth)

        self.emit("int res;", depth+1)
        if field.seq:
            self.emit("Py_ssize_t len;", depth+1)
            self.emit("Py_ssize_t i;", depth+1)
            self.emit("if (!PyList_Check(tmp)) {", depth+1)
            self.emit("PyErr_Format(PyExc_TypeError, \"%s field \\\"%s\\\" must "
                      "be a list, not a %%.200s\", _PyType_Name(Py_TYPE(tmp)));" %
                      (name, field.name),
                      depth+2, reflow=False)
            self.emit("goto failed;", depth+2)
            self.emit("}", depth+1)
            self.emit("len = PyList_GET_SIZE(tmp);", depth+1)
            if self.isSimpleType(field):
                self.emit("%s = _Py_asdl_int_seq_new(len, arena);" % field.name, depth+1)
            else:
                self.emit("%s = _Py_asdl_%s_seq_new(len, arena);" % (field.name, field.type), depth+1)
            self.emit("if (%s == NULL) goto failed;" % field.name, depth+1)
            self.emit("for (i = 0; i < len; i++) {", depth+1)
            self.emit("%s val;" % ctype, depth+2)
            self.emit("PyObject *tmp2 = Py_NewRef(PyList_GET_ITEM(tmp, i));", depth+2)
            with self.recursive_call(name, depth+2):
                self.emit("res = obj2ast_%s(state, tmp2, &val, arena);" %
                          field.type, depth+2, reflow=False)
            self.emit("Py_DECREF(tmp2);", depth+2)
            self.emit("if (res != 0) goto failed;", depth+2)
            # The element conversion runs Python code that may resize the
            # list, so the generated loop re-checks the length each time.
            self.emit("if (len != PyList_GET_SIZE(tmp)) {", depth+2)
            self.emit("PyErr_SetString(PyExc_RuntimeError, \"%s field \\\"%s\\\" "
                      "changed size during iteration\");" %
                      (name, field.name),
                      depth+3, reflow=False)
            self.emit("goto failed;", depth+3)
            self.emit("}", depth+2)
            self.emit("asdl_seq_SET(%s, i, val);" % field.name, depth+2)
            self.emit("}", depth+1)
        else:
            with self.recursive_call(name, depth+1):
                self.emit("res = obj2ast_%s(state, tmp, &%s, arena);" %
                          (field.type, field.name), depth+1)
            self.emit("if (res != 0) goto failed;", depth+1)

        self.emit("Py_CLEAR(tmp);", depth+1)
        self.emit("}", depth)


class SequenceConstructorVisitor(EmitVisitor):
    """Emit one ``GENERATE_ASDL_SEQ_CONSTRUCTOR(name, ctype)`` line per sequence type."""

    def visitModule(self, mod):
        for definition in mod.dfns:
            self.visit(definition)

    def visitType(self, type):
        self.visit(type.value, type.name)

    def visitProduct(self, prod, name):
        self.emit_sequence_constructor(name, get_c_type(name))

    def visitSum(self, sum, name):
        # Simple sums use int sequences, so they need no constructor here.
        if is_simple(sum):
            return
        self.emit_sequence_constructor(name, get_c_type(name))

    def emit_sequence_constructor(self, name, type):
        line = f"GENERATE_ASDL_SEQ_CONSTRUCTOR({name}, {type})"
        self.emit(line, depth=0)


class PyTypesDeclareVisitor(PickleVisitor):
    """Emit the ast2obj_<type>() prototypes and the static C name arrays.

    The arrays hold the names of each type's fields and attributes.
    """

    def _emit_string_array(self, declaration, names):
        # `declaration` opens the array; one quoted name per line follows.
        self.emit(declaration, 0)
        for entry in names:
            self.emit('"%s",' % entry, 1)
        self.emit("};", 0)

    def visitProduct(self, prod, name):
        self.emit("static PyObject* ast2obj_%s(struct ast_state *state, void*);" % name, 0)
        if prod.attributes:
            self._emit_string_array(
                "static const char * const %s_attributes[] = {" % name,
                [a.name for a in prod.attributes])
        if prod.fields:
            self._emit_string_array(
                "static const char * const %s_fields[]={" % name,
                [f.name for f in prod.fields])

    def visitSum(self, sum, name):
        if sum.attributes:
            self._emit_string_array(
                "static const char * const %s_attributes[] = {" % name,
                [a.name for a in sum.attributes])
        # Simple sums are converted from their enum value, not a pointer.
        ptype = get_c_type(name) if is_simple(sum) else "void*"
        self.emit("static PyObject* ast2obj_%s(struct ast_state *state, %s);" % (name, ptype), 0)
        for t in sum.types:
            self.visitConstructor(t, name)

    def visitConstructor(self, cons, name):
        if cons.fields:
            self._emit_string_array(
                "static const char * const %s_fields[]={" % cons.name,
                [t.name for t in cons.fields])


class PyTypesVisitor(PickleVisitor):
768
769
def visitModule(self, mod):
    """Emit the shared C support code, then the init_types() function.

    The fixed preamble defines the ast.AST base type (AST_object, its
    slots and spec), make_type()/add_attributes(), and the generic
    ast2obj_*/obj2ast_* helpers.  init_types() is written in three parts:
    - a fixed prologue that creates AST_type;
    - the output of visiting each definition in mod.dfns;
    - a fixed epilogue that marks the state initialized.
    """
    preamble = """

typedef struct {
    PyObject_HEAD
    PyObject *dict;
} AST_object;

static void
ast_dealloc(AST_object *self)
{
    /* bpo-31095: UnTrack is needed before calling any callbacks */
    PyTypeObject *tp = Py_TYPE(self);
    PyObject_GC_UnTrack(self);
    Py_CLEAR(self->dict);
    freefunc free_func = PyType_GetSlot(tp, Py_tp_free);
    assert(free_func != NULL);
    free_func(self);
    Py_DECREF(tp);
}

static int
ast_traverse(AST_object *self, visitproc visit, void *arg)
{
    Py_VISIT(Py_TYPE(self));
    Py_VISIT(self->dict);
    return 0;
}

static int
ast_clear(AST_object *self)
{
    Py_CLEAR(self->dict);
    return 0;
}

static int
ast_type_init(PyObject *self, PyObject *args, PyObject *kw)
{
    struct ast_state *state = get_ast_state();
    if (state == NULL) {
        return -1;
    }

    Py_ssize_t i, numfields = 0;
    int res = -1;
    PyObject *key, *value, *fields;
    if (_PyObject_LookupAttr((PyObject*)Py_TYPE(self), state->_fields, &fields) < 0) {
        goto cleanup;
    }
    if (fields) {
        numfields = PySequence_Size(fields);
        if (numfields == -1) {
            goto cleanup;
        }
    }

    res = 0; /* if no error occurs, this stays 0 to the end */
    if (numfields < PyTuple_GET_SIZE(args)) {
        PyErr_Format(PyExc_TypeError, "%.400s constructor takes at most "
                     "%zd positional argument%s",
                     _PyType_Name(Py_TYPE(self)),
                     numfields, numfields == 1 ? "" : "s");
        res = -1;
        goto cleanup;
    }
    for (i = 0; i < PyTuple_GET_SIZE(args); i++) {
        /* cannot be reached when fields is NULL */
        PyObject *name = PySequence_GetItem(fields, i);
        if (!name) {
            res = -1;
            goto cleanup;
        }
        res = PyObject_SetAttr(self, name, PyTuple_GET_ITEM(args, i));
        Py_DECREF(name);
        if (res < 0) {
            goto cleanup;
        }
    }
    if (kw) {
        i = 0; /* needed by PyDict_Next */
        while (PyDict_Next(kw, &i, &key, &value)) {
            int contains = PySequence_Contains(fields, key);
            if (contains == -1) {
                res = -1;
                goto cleanup;
            } else if (contains == 1) {
                Py_ssize_t p = PySequence_Index(fields, key);
                if (p == -1) {
                    res = -1;
                    goto cleanup;
                }
                if (p < PyTuple_GET_SIZE(args)) {
                    PyErr_Format(PyExc_TypeError,
                        "%.400s got multiple values for argument '%U'",
                        Py_TYPE(self)->tp_name, key);
                    res = -1;
                    goto cleanup;
                }
            }
            res = PyObject_SetAttr(self, key, value);
            if (res < 0) {
                goto cleanup;
            }
        }
    }
  cleanup:
    Py_XDECREF(fields);
    return res;
}

/* Pickling support */
static PyObject *
ast_type_reduce(PyObject *self, PyObject *unused)
{
    struct ast_state *state = get_ast_state();
    if (state == NULL) {
        return NULL;
    }

    PyObject *dict;
    if (_PyObject_LookupAttr(self, state->__dict__, &dict) < 0) {
        return NULL;
    }
    if (dict) {
        return Py_BuildValue("O()N", Py_TYPE(self), dict);
    }
    return Py_BuildValue("O()", Py_TYPE(self));
}

static PyMemberDef ast_type_members[] = {
    {"__dictoffset__", T_PYSSIZET, offsetof(AST_object, dict), READONLY},
    {NULL} /* Sentinel */
};

static PyMethodDef ast_type_methods[] = {
    {"__reduce__", ast_type_reduce, METH_NOARGS, NULL},
    {NULL}
};

static PyGetSetDef ast_type_getsets[] = {
    {"__dict__", PyObject_GenericGetDict, PyObject_GenericSetDict},
    {NULL}
};

static PyType_Slot AST_type_slots[] = {
    {Py_tp_dealloc, ast_dealloc},
    {Py_tp_getattro, PyObject_GenericGetAttr},
    {Py_tp_setattro, PyObject_GenericSetAttr},
    {Py_tp_traverse, ast_traverse},
    {Py_tp_clear, ast_clear},
    {Py_tp_members, ast_type_members},
    {Py_tp_methods, ast_type_methods},
    {Py_tp_getset, ast_type_getsets},
    {Py_tp_init, ast_type_init},
    {Py_tp_alloc, PyType_GenericAlloc},
    {Py_tp_new, PyType_GenericNew},
    {Py_tp_free, PyObject_GC_Del},
    {0, 0},
};

static PyType_Spec AST_type_spec = {
    "ast.AST",
    sizeof(AST_object),
    0,
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC,
    AST_type_slots
};

static PyObject *
make_type(struct ast_state *state, const char *type, PyObject* base,
          const char* const* fields, int num_fields, const char *doc)
{
    PyObject *fnames, *result;
    int i;
    fnames = PyTuple_New(num_fields);
    if (!fnames) return NULL;
    for (i = 0; i < num_fields; i++) {
        PyObject *field = PyUnicode_InternFromString(fields[i]);
        if (!field) {
            Py_DECREF(fnames);
            return NULL;
        }
        PyTuple_SET_ITEM(fnames, i, field);
    }
    result = PyObject_CallFunction((PyObject*)&PyType_Type, "s(O){OOOOOOOs}",
                    type, base,
                    state->_fields, fnames,
                    state->__match_args__, fnames,
                    state->__module__,
                    state->ast,
                    state->__doc__, doc);
    Py_DECREF(fnames);
    return result;
}

static int
add_attributes(struct ast_state *state, PyObject *type, const char * const *attrs, int num_fields)
{
    int i, result;
    PyObject *s, *l = PyTuple_New(num_fields);
    if (!l)
        return 0;
    for (i = 0; i < num_fields; i++) {
        s = PyUnicode_InternFromString(attrs[i]);
        if (!s) {
            Py_DECREF(l);
            return 0;
        }
        PyTuple_SET_ITEM(l, i, s);
    }
    result = PyObject_SetAttr(type, state->_attributes, l) >= 0;
    Py_DECREF(l);
    return result;
}

/* Conversion AST -> Python */

static PyObject* ast2obj_list(struct ast_state *state, asdl_seq *seq, PyObject* (*func)(struct ast_state *state, void*))
{
    Py_ssize_t i, n = asdl_seq_LEN(seq);
    PyObject *result = PyList_New(n);
    PyObject *value;
    if (!result)
        return NULL;
    for (i = 0; i < n; i++) {
        value = func(state, asdl_seq_GET_UNTYPED(seq, i));
        if (!value) {
            Py_DECREF(result);
            return NULL;
        }
        PyList_SET_ITEM(result, i, value);
    }
    return result;
}

static PyObject* ast2obj_object(struct ast_state *Py_UNUSED(state), void *o)
{
    PyObject *op = (PyObject*)o;
    if (!op) {
        op = Py_None;
    }
    return Py_NewRef(op);
}
#define ast2obj_constant ast2obj_object
#define ast2obj_identifier ast2obj_object
#define ast2obj_string ast2obj_object

static PyObject* ast2obj_int(struct ast_state *Py_UNUSED(state), long b)
{
    return PyLong_FromLong(b);
}

/* Conversion Python -> AST */

static int obj2ast_object(struct ast_state *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena)
{
    if (obj == Py_None)
        obj = NULL;
    if (obj) {
        if (_PyArena_AddPyObject(arena, obj) < 0) {
            *out = NULL;
            return -1;
        }
        *out = Py_NewRef(obj);
    }
    else {
        *out = NULL;
    }
    return 0;
}

static int obj2ast_constant(struct ast_state *Py_UNUSED(state), PyObject* obj, PyObject** out, PyArena* arena)
{
    if (_PyArena_AddPyObject(arena, obj) < 0) {
        *out = NULL;
        return -1;
    }
    *out = Py_NewRef(obj);
    return 0;
}

static int obj2ast_identifier(struct ast_state *state, PyObject* obj, PyObject** out, PyArena* arena)
{
    if (!PyUnicode_CheckExact(obj) && obj != Py_None) {
        PyErr_SetString(PyExc_TypeError, "AST identifier must be of type str");
        return 1;
    }
    return obj2ast_object(state, obj, out, arena);
}

static int obj2ast_string(struct ast_state *state, PyObject* obj, PyObject** out, PyArena* arena)
{
    if (!PyUnicode_CheckExact(obj) && !PyBytes_CheckExact(obj)) {
        PyErr_SetString(PyExc_TypeError, "AST string must be of type str");
        return 1;
    }
    return obj2ast_object(state, obj, out, arena);
}

static int obj2ast_int(struct ast_state* Py_UNUSED(state), PyObject* obj, int* out, PyArena* arena)
{
    int i;
    if (!PyLong_Check(obj)) {
        PyErr_Format(PyExc_ValueError, "invalid integer value: %R", obj);
        return 1;
    }

    i = _PyLong_AsInt(obj);
    if (i == -1 && PyErr_Occurred())
        return 1;
    *out = i;
    return 0;
}

static int add_ast_fields(struct ast_state *state)
{
    PyObject *empty_tuple;
    empty_tuple = PyTuple_New(0);
    if (!empty_tuple ||
        PyObject_SetAttrString(state->AST_type, "_fields", empty_tuple) < 0 ||
        PyObject_SetAttrString(state->AST_type, "__match_args__", empty_tuple) < 0 ||
        PyObject_SetAttrString(state->AST_type, "_attributes", empty_tuple) < 0) {
        Py_XDECREF(empty_tuple);
        return -1;
    }
    Py_DECREF(empty_tuple);
    return 0;
}

"""
    self.emit(preamble, 0, reflow=False)

    write = self.file.write
    # init_types() prologue: bail out early if already initialized, then
    # create the AST base type.
    write(textwrap.dedent('''
        static int
        init_types(struct ast_state *state)
        {
            // init_types() must not be called after _PyAST_Fini()
            // has been called
            assert(state->initialized >= 0);

            if (state->initialized) {
                return 1;
            }
            if (init_identifiers(state) < 0) {
                return 0;
            }
            state->AST_type = PyType_FromSpec(&AST_type_spec);
            if (!state->AST_type) {
                return 0;
            }
            if (add_ast_fields(state) < 0) {
                return 0;
            }
    '''))
    for definition in mod.dfns:
        self.visit(definition)
    # init_types() epilogue.
    write(textwrap.dedent('''
            state->recursion_depth = 0;
            state->recursion_limit = 0;
            state->initialized = 1;
            return 1;
        }
    '''))

def visitProduct(self, prod, name):
    """Emit the init_types() statements that build the Python type for product *name*.

    The type derives from ``state->AST_type``, gets its ``_fields`` table
    (or NULL when the product has no fields), its ASDL signature as the
    docstring, its ``_attributes``, and ``None`` defaults for optional
    fields and attributes.
    """
    field_count = len(prod.fields)
    fields_array = "NULL" if not prod.fields else name + "_fields"
    self.emit('state->%s_type = make_type(state, "%s", state->AST_type, %s, %d,'
              % (name, name, fields_array, field_count), 1)
    signature = reflow_c_string(asdl_of(name, prod), 2)
    self.emit(signature + ');', 2, reflow=False)
    self.emit("if (!state->%s_type) return 0;" % name, 1)

    # add_attributes() receives either the static table or an empty (NULL, 0) one.
    if prod.attributes:
        attr_args = "%s_attributes, %d" % (name, len(prod.attributes))
    else:
        attr_args = "NULL, 0"
    self.emit("if (!add_attributes(state, state->%s_type, %s)) return 0;"
              % (name, attr_args), 1)

    for group in (prod.fields, prod.attributes):
        self.emit_defaults(name, group, 1)
1149
1150
def visitSum(self, sum, name):
    """Emit the init_types() statements for sum type *name* and its constructors.

    The abstract sum type has no fields of its own. Each constructor is then
    emitted by visitConstructor(), which also creates a singleton instance
    when the sum is simple (enum-like).
    """
    self.emit('state->%s_type = make_type(state, "%s", state->AST_type, NULL, 0,'
              % (name, name), 1)
    signature = reflow_c_string(asdl_of(name, sum), 2)
    self.emit(signature + ');', 2, reflow=False)
    self.emit("if (!state->%s_type) return 0;" % name, 1)

    attr_count = len(sum.attributes)
    attr_args = ("%s_attributes, %d" % (name, attr_count)) if sum.attributes else "NULL, 0"
    self.emit("if (!add_attributes(state, state->%s_type, %s)) return 0;"
              % (name, attr_args), 1)
    self.emit_defaults(name, sum.attributes, 1)

    is_enum_like = is_simple(sum)
    for constructor in sum.types:
        self.visitConstructor(constructor, name, is_enum_like)
1164
1165
def visitConstructor(self, cons, name, simple):
    """Emit the init_types() statements for constructor *cons* of sum *name*.

    The constructor type subclasses ``state-><name>_type``. For simple sums
    a singleton instance is also created and stored in
    ``state-><cons>_singleton``.
    """
    cname = cons.name
    fields_array = cname + "_fields" if cons.fields else "NULL"
    self.emit('state->%s_type = make_type(state, "%s", state->%s_type, %s, %d,'
              % (cname, cname, name, fields_array, len(cons.fields)), 1)
    signature = reflow_c_string(asdl_of(cname, cons), 2)
    self.emit(signature + ');', 2, reflow=False)
    self.emit("if (!state->%s_type) return 0;" % cname, 1)
    self.emit_defaults(cname, cons.fields, 1)
    if not simple:
        return
    # Simple-sum constructors are represented by one shared instance.
    self.emit("state->%s_singleton = PyType_GenericNew((PyTypeObject *)"
              "state->%s_type, NULL, NULL);" % (cname, cname), 1)
    self.emit("if (!state->%s_singleton) return 0;" % cname, 1)
1180
1181
def emit_defaults(self, name, fields, depth):
    """Emit code setting each optional field of type *name* to ``None``.

    For every field whose ``opt`` flag is true, emits a PyObject_SetAttr()
    call at *depth* and a ``return 0;`` (failure) one level deeper.
    Non-optional fields produce no output.
    """
    optional_fields = [field for field in fields if field.opt]
    for field in optional_fields:
        set_default = ('if (PyObject_SetAttr(state->%s_type, state->%s, Py_None) == -1)'
                       % (name, field.name))
        self.emit(set_default, depth)
        self.emit("return 0;", depth + 1)
1187
1188
1189
class ASTModuleVisitor(PickleVisitor):
    """Emit ``astmodule_exec`` plus the slot table and module def of ``_ast``.

    The exec function publishes ``AST``, three ``PyCF_*`` compiler-flag
    constants and every node type (through addObj()) on the module object.
    Each publishing step returns -1 from the exec function on failure.
    """

    def visitModule(self, mod):
        def return_on_failure(opening_line):
            # Emit "<opening_line>" / "return -1;" / "}" as one guarded block.
            self.emit(opening_line, 1)
            self.emit('return -1;', 2)
            self.emit('}', 1)

        for header_line in ("static int", "astmodule_exec(PyObject *m)", "{"):
            self.emit(header_line, 0)
        self.emit('struct ast_state *state = get_ast_state();', 1)
        return_on_failure('if (state == NULL) {')
        return_on_failure('if (PyModule_AddObjectRef(m, "AST", state->AST_type) < 0) {')
        for flag in ("PyCF_ALLOW_TOP_LEVEL_AWAIT", "PyCF_ONLY_AST", "PyCF_TYPE_COMMENTS"):
            return_on_failure('if (PyModule_AddIntMacro(m, %s) < 0) {' % flag)
        for dfn in mod.dfns:
            self.visit(dfn)
        self.emit("return 0;", 1)
        self.emit("}", 0)
        self.emit("", 0)

        module_definition = """
static PyModuleDef_Slot astmodule_slots[] = {
    {Py_mod_exec, astmodule_exec},
    {Py_mod_multiple_interpreters, Py_MOD_PER_INTERPRETER_GIL_SUPPORTED},
    {0, NULL}
};

static struct PyModuleDef _astmodule = {
    PyModuleDef_HEAD_INIT,
    .m_name = "_ast",
    // The _ast module uses a per-interpreter state (PyInterpreterState.ast)
    .m_size = 0,
    .m_slots = astmodule_slots,
};

PyMODINIT_FUNC
PyInit__ast(void)
{
    return PyModuleDef_Init(&_astmodule);
}
"""
        self.emit(module_definition.strip(), 0, reflow=False)

    def visitProduct(self, prod, name):
        """Publish the product type *name* on the module."""
        self.addObj(name)

    def visitSum(self, sum, name):
        """Publish the sum type *name*, then each of its constructor types."""
        self.addObj(name)
        for constructor in sum.types:
            self.visitConstructor(constructor, name)

    def visitConstructor(self, cons, name):
        """Publish the constructor type *cons* (its parent sum *name* is unused)."""
        self.addObj(cons.name)

    def addObj(self, name):
        """Emit code adding ``state-><name>_type`` to module ``m`` as ``<name>``."""
        opening = 'if (PyModule_AddObjectRef(m, "{0}", state->{0}_type) < 0) {{'.format(name)
        self.emit(opening, 1)
        self.emit("return -1;", 2)
        self.emit('}', 1)
1254
1255
1256
class StaticVisitor(PickleVisitor):
    """Visitor that ignores the module and emits a fixed chunk of C.

    Subclasses override ``CODE``. The text is written at depth 0 without
    reflowing.
    """

    CODE = "Very simple, always emit this static code. Override CODE"

    def visit(self, object):
        static_text = self.CODE
        self.emit(static_text, 0, reflow=False)
1261
1262
1263
class ObjVisitor(PickleVisitor):
    """Emit the ``ast2obj_*`` converters (C AST node -> Python AST object).

    Each non-simple converter increments ``state->recursion_depth`` on entry.
    PyAST_mod2obj() compares that counter against its starting value, so
    every exit taken after the increment must undo it: the normal return,
    the ``failed:`` label, and the RecursionError branch.
    """

    def func_begin(self, name):
        """Emit the prologue of ``ast2obj_<name>``: locals, NULL check, recursion guard."""
        ctype = get_c_type(name)
        self.emit("PyObject*", 0)
        self.emit("ast2obj_%s(struct ast_state *state, void* _o)" % (name), 0)
        self.emit("{", 0)
        self.emit("%s o = (%s)_o;" % (ctype, ctype), 1)
        self.emit("PyObject *result = NULL, *value = NULL;", 1)
        self.emit("PyTypeObject *tp;", 1)
        self.emit('if (!o) {', 1)
        self.emit("Py_RETURN_NONE;", 2)
        self.emit("}", 1)
        self.emit("if (++state->recursion_depth > state->recursion_limit) {", 1)
        # Undo the increment so the depth counter stays balanced on this exit too.
        self.emit("state->recursion_depth--;", 2)
        self.emit("PyErr_SetString(PyExc_RecursionError,", 2)
        self.emit('"maximum recursion depth exceeded during ast construction");', 3)
        self.emit("return NULL;", 2)
        self.emit("}", 1)

    def func_end(self):
        """Emit the success return and the ``failed:`` cleanup path."""
        self.emit("state->recursion_depth--;", 1)
        self.emit("return result;", 1)
        self.emit("failed:", 0)
        # Error exits must also undo the increment done in func_begin().
        self.emit("state->recursion_depth--;", 1)
        self.emit("Py_XDECREF(value);", 1)
        self.emit("Py_XDECREF(result);", 1)
        self.emit("return NULL;", 1)
        self.emit("}", 0)
        self.emit("", 0)

    def visitSum(self, sum, name):
        """Emit ``ast2obj_<name>`` dispatching on ``o->kind``.

        Simple (enum-like) sums are handled separately by simpleSum().
        """
        if is_simple(sum):
            self.simpleSum(sum, name)
            return
        self.func_begin(name)
        self.emit("switch (o->kind) {", 1)
        for i, t in enumerate(sum.types):
            self.visitConstructor(t, i + 1, name)
        self.emit("}", 1)
        for a in sum.attributes:
            self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
            self.emit("if (!value) goto failed;", 1)
            self.emit('if (PyObject_SetAttr(result, state->%s, value) < 0)' % a.name, 1)
            self.emit('goto failed;', 2)
            self.emit('Py_DECREF(value);', 1)
        self.func_end()

    def simpleSum(self, sum, name):
        """Emit ``ast2obj_<name>`` for an enum-like sum: return a new ref to the singleton."""
        self.emit("PyObject* ast2obj_%s(struct ast_state *state, %s_ty o)" % (name, name), 0)
        self.emit("{", 0)
        self.emit("switch(o) {", 1)
        for t in sum.types:
            self.emit("case %s:" % t.name, 2)
            self.emit("return Py_NewRef(state->%s_singleton);" % t.name, 3)
        self.emit("}", 1)
        self.emit("Py_UNREACHABLE();", 1)
        self.emit("}", 0)

    def visitProduct(self, prod, name):
        """Emit ``ast2obj_<name>`` for a product type."""
        self.func_begin(name)
        self.emit("tp = (PyTypeObject *)state->%s_type;" % name, 1)
        self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 1)
        # Go through failed: so the recursion counter is decremented here as well.
        self.emit("if (!result) goto failed;", 1)
        for field in prod.fields:
            self.visitField(field, name, 1, True)
        for a in prod.attributes:
            self.emit("value = ast2obj_%s(state, o->%s);" % (a.type, a.name), 1)
            self.emit("if (!value) goto failed;", 1)
            self.emit("if (PyObject_SetAttr(result, state->%s, value) < 0)" % a.name, 1)
            self.emit('goto failed;', 2)
            self.emit('Py_DECREF(value);', 1)
        self.func_end()

    def visitConstructor(self, cons, enum, name):
        """Emit the ``case <cons>_kind:`` arm of a sum converter (*enum* is unused)."""
        self.emit("case %s_kind:" % cons.name, 1)
        self.emit("tp = (PyTypeObject *)state->%s_type;" % cons.name, 2)
        self.emit("result = PyType_GenericNew(tp, NULL, NULL);", 2)
        self.emit("if (!result) goto failed;", 2)
        for f in cons.fields:
            self.visitField(f, cons.name, 2, False)
        self.emit("break;", 2)

    def visitField(self, field, name, depth, product):
        """Emit conversion of one field and its assignment as an attribute of ``result``."""
        def emit(s, d):
            self.emit(s, depth + d)
        if product:
            value = "o->%s" % field.name
        else:
            value = "o->v.%s.%s" % (name, field.name)
        self.set(field, value, depth)
        emit("if (!value) goto failed;", 0)
        emit("if (PyObject_SetAttr(result, state->%s, value) == -1)" % field.name, 0)
        emit("goto failed;", 1)
        emit("Py_DECREF(value);", 0)

    def set(self, field, value, depth):
        """Emit code storing the Python conversion of C expression *value* in ``value``."""
        if field.seq:
            if field.type in self.metadata.simple_sums:
                # While the sequence elements are stored as void*,
                # simple sums expects an enum
                self.emit("{", depth)
                self.emit("Py_ssize_t i, n = asdl_seq_LEN(%s);" % value, depth+1)
                self.emit("value = PyList_New(n);", depth+1)
                self.emit("if (!value) goto failed;", depth+1)
                self.emit("for(i = 0; i < n; i++)", depth+1)
                # This cannot fail, so no need for error handling
                self.emit(
                    "PyList_SET_ITEM(value, i, ast2obj_{0}(state, ({0}_ty)asdl_seq_GET({1}, i)));".format(
                        field.type,
                        value
                    ),
                    depth + 2,
                    reflow=False,
                )
                self.emit("}", depth)
            else:
                self.emit("value = ast2obj_list(state, (asdl_seq*)%s, ast2obj_%s);" % (value, field.type), depth)
        else:
            self.emit("value = ast2obj_%s(state, %s);" % (field.type, value), depth, reflow=False)
1382
1383
1384
class PartingShots(StaticVisitor):
    """Static C tail of Python-ast.c: PyAST_mod2obj, PyAST_obj2mod and PyAST_Check.

    In PyAST_mod2obj, ``result`` is a new reference, so it is released before
    returning NULL when the recursion-depth mismatch error is raised.
    Otherwise that object would leak.
    """

    CODE = """
PyObject* PyAST_mod2obj(mod_ty t)
{
    struct ast_state *state = get_ast_state();
    if (state == NULL) {
        return NULL;
    }

    int starting_recursion_depth;
    /* Be careful here to prevent overflow. */
    int COMPILER_STACK_FRAME_SCALE = 3;
    PyThreadState *tstate = _PyThreadState_GET();
    if (!tstate) {
        return NULL;
    }
    state->recursion_limit = C_RECURSION_LIMIT * COMPILER_STACK_FRAME_SCALE;
    int recursion_depth = C_RECURSION_LIMIT - tstate->c_recursion_remaining;
    starting_recursion_depth = recursion_depth * COMPILER_STACK_FRAME_SCALE;
    state->recursion_depth = starting_recursion_depth;

    PyObject *result = ast2obj_mod(state, t);

    /* Check that the recursion depth counting balanced correctly */
    if (result && state->recursion_depth != starting_recursion_depth) {
        PyErr_Format(PyExc_SystemError,
            "AST constructor recursion depth mismatch (before=%d, after=%d)",
            starting_recursion_depth, state->recursion_depth);
        /* result is a new reference: drop it before reporting the error. */
        Py_DECREF(result);
        return NULL;
    }
    return result;
}

/* mode is 0 for "exec", 1 for "eval" and 2 for "single" input */
mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode)
{
    const char * const req_name[] = {"Module", "Expression", "Interactive"};
    int isinstance;

    if (PySys_Audit("compile", "OO", ast, Py_None) < 0) {
        return NULL;
    }

    struct ast_state *state = get_ast_state();
    if (state == NULL) {
        return NULL;
    }

    PyObject *req_type[3];
    req_type[0] = state->Module_type;
    req_type[1] = state->Expression_type;
    req_type[2] = state->Interactive_type;

    assert(0 <= mode && mode <= 2);

    isinstance = PyObject_IsInstance(ast, req_type[mode]);
    if (isinstance == -1)
        return NULL;
    if (!isinstance) {
        PyErr_Format(PyExc_TypeError, "expected %s node, got %.400s",
                     req_name[mode], _PyType_Name(Py_TYPE(ast)));
        return NULL;
    }

    mod_ty res = NULL;
    if (obj2ast_mod(state, ast, &res, arena) != 0)
        return NULL;
    else
        return res;
}

int PyAST_Check(PyObject* obj)
{
    struct ast_state *state = get_ast_state();
    if (state == NULL) {
        return -1;
    }
    return PyObject_IsInstance(obj, state->AST_type);
}
"""
1465
1466
class ChainOfVisitors:
    """Run several visitors over the same tree, one after another.

    Before each visitor runs, it receives the chain's shared ``metadata``.
    After it finishes, one empty line is emitted through that visitor.
    """

    def __init__(self, *visitors, metadata = None):
        self.visitors = visitors
        self.metadata = metadata

    def visit(self, object):
        for visitor in self.visitors:
            visitor.metadata = self.metadata
            visitor.visit(object)
            # Blank separator between the outputs of consecutive visitors.
            visitor.emit("", 0)
1476
1477
1478
def generate_ast_state(module_state, f):
    """Write the C definition of ``struct ast_state`` to *f*.

    The struct holds three int counters followed by one ``PyObject *``
    member per name in *module_state*, in iteration order. No newline is
    written after the closing ``};``.
    """
    f.write('struct ast_state {\n')
    for counter in ("initialized", "recursion_depth", "recursion_limit"):
        f.write("    int %s;\n" % counter)
    for member in module_state:
        f.write("".join(("    PyObject *", member, ";\n")))
    f.write('};')
1486
1487
1488
def generate_ast_fini(module_state, f):
    """Write the C function ``_PyAST_Fini()`` to *f*.

    The function clears every ``PyObject *`` member named in *module_state*
    and the cached ``str_replace_inf`` object. It then sets
    ``state->initialized`` to -1 in debug builds (so a later init_types()
    trips its assertion) and to 0 otherwise.
    """
    prologue = (
        "\n"
        "void _PyAST_Fini(PyInterpreterState *interp)\n"
        "{\n"
        "    struct ast_state *state = &interp->ast;\n"
        "\n"
    )
    epilogue = (
        "\n"
        "    Py_CLEAR(_Py_INTERP_CACHED_OBJECT(interp, str_replace_inf));\n"
        "\n"
        "#if !defined(NDEBUG)\n"
        "    state->initialized = -1;\n"
        "#else\n"
        "    state->initialized = 0;\n"
        "#endif\n"
        "}\n"
        "\n"
    )
    f.write(prologue)
    for member in module_state:
        f.write("    Py_CLEAR(state->" + member + ');\n')
    f.write(epilogue)
1508
1509
1510
def generate_module_def(mod, metadata, f, internal_h):
    """Write the module-state plumbing of Python-ast.c and the state struct.

    Writes ``struct ast_state`` to *internal_h*. Writes to *f* the includes,
    ``get_ast_state()``, ``_PyAST_Fini()`` and ``init_identifiers()``.
    *mod* is not used; it keeps the signature parallel to the other writers.

    init_identifiers() is emitted with a plain closing brace. A trailing
    ``};`` would put an empty declaration at file scope, which ISO C does not
    allow (``-Wpedantic`` warns about it).
    """
    # Strings interned at init time; each is also a member of struct ast_state.
    state_strings = {
        "ast",
        "_fields",
        "__match_args__",
        "__doc__",
        "__dict__",
        "__module__",
        "_attributes",
        *metadata.identifiers
    }

    # All PyObject* members of the state: the interned strings, one singleton
    # per simple-sum constructor and one type object per AST type.
    module_state = state_strings.copy()
    module_state.update(
        "%s_singleton" % singleton
        for singleton in metadata.singletons
    )
    module_state.update(
        "%s_type" % type_name
        for type_name in metadata.types
    )

    # Sort so the generated C is stable across runs (set order is arbitrary).
    state_strings = sorted(state_strings)
    module_state = sorted(module_state)

    generate_ast_state(module_state, internal_h)

    print(textwrap.dedent("""
        #include "Python.h"
        #include "pycore_ast.h"
        #include "pycore_ast_state.h"     // struct ast_state
        #include "pycore_ceval.h"         // _Py_EnterRecursiveCall
        #include "pycore_interp.h"        // _PyInterpreterState.ast
        #include "pycore_pystate.h"       // _PyInterpreterState_GET()
        #include "structmember.h"
        #include <stddef.h>

        // Forward declaration
        static int init_types(struct ast_state *state);

        static struct ast_state*
        get_ast_state(void)
        {
            PyInterpreterState *interp = _PyInterpreterState_GET();
            struct ast_state *state = &interp->ast;
            if (!init_types(state)) {
                return NULL;
            }
            return state;
        }
    """).strip(), file=f)

    generate_ast_fini(module_state, f)

    # init_identifiers() returns 1 on success and 0 on failure, like init_types().
    f.write('static int init_identifiers(struct ast_state *state)\n')
    f.write('{\n')
    for identifier in state_strings:
        f.write('    if ((state->' + identifier
                + ' = PyUnicode_InternFromString("'
                + identifier + '")) == NULL) return 0;\n')
    f.write('    return 1;\n')
    f.write('}\n\n')
1573
1574
def write_header(mod, metadata, f):
    """Write the internal AST header (pycore_ast.h) for *mod* to *f*.

    Writes an include-guard prologue, then the typedefs, sequence
    definitions and structs from a visitor chain, then the constructor
    prototypes, and finally the public entry-point declarations and the
    closing guard.
    """
    prologue = "\n".join((
        "#ifndef Py_INTERNAL_AST_H",
        "#define Py_INTERNAL_AST_H",
        "#ifdef __cplusplus",
        'extern "C" {',
        "#endif",
        "",
        "#ifndef Py_BUILD_CORE",
        '#  error "this header requires Py_BUILD_CORE define"',
        "#endif",
        "",
        '#include "pycore_asdl.h"',
        "",
        "",
    ))
    epilogue = "\n".join((
        "",
        "",
        "PyObject* PyAST_mod2obj(mod_ty t);",
        "mod_ty PyAST_obj2mod(PyObject* ast, PyArena* arena, int mode);",
        "int PyAST_Check(PyObject* obj);",
        "",
        "extern int _PyAST_Validate(mod_ty);",
        "",
        "/* _PyAST_ExprAsUnicode is defined in ast_unparse.c */",
        "extern PyObject* _PyAST_ExprAsUnicode(expr_ty);",
        "",
        "/* Return the borrowed reference to the first literal string in the",
        "   sequence of statements or NULL if it doesn't start from a literal string.",
        "   Doesn't set exception. */",
        "extern PyObject* _PyAST_GetDocString(asdl_stmt_seq *);",
        "",
        "#ifdef __cplusplus",
        "}",
        "#endif",
        "#endif /* !Py_INTERNAL_AST_H */",
        "",
    ))

    f.write(prologue)

    declarations = ChainOfVisitors(
        TypeDefVisitor(f),
        SequenceDefVisitor(f),
        StructVisitor(f),
        metadata=metadata
    )
    declarations.visit(mod)

    f.write("// Note: these macros affect function definitions, not only call sites.\n")
    PrototypeVisitor(f, metadata=metadata).visit(mod)

    f.write(epilogue)
1623
1624
1625
def write_internal_h_header(mod, f):
    """Write the opening include guard of pycore_ast_state.h to *f*.

    *mod* is not used. The output ends with ``#endif`` followed by a blank
    line.
    """
    guard_lines = (
        "#ifndef Py_INTERNAL_AST_STATE_H",
        "#define Py_INTERNAL_AST_STATE_H",
        "#ifdef __cplusplus",
        'extern "C" {',
        "#endif",
        "",
        "#ifndef Py_BUILD_CORE",
        '#  error "this header requires Py_BUILD_CORE define"',
        "#endif",
    )
    print("\n".join(guard_lines) + "\n", file=f)
1637
1638
1639
def write_internal_h_footer(mod, f):
    """Write the closing ``extern "C"`` block and include guard of pycore_ast_state.h.

    *mod* is not used.
    """
    closing = "\n".join((
        "",
        "",
        "#ifdef __cplusplus",
        "}",
        "#endif",
        "#endif /* !Py_INTERNAL_AST_STATE_H */",
        "",
    ))
    print(closing, file=f)
1647
1648
def write_source(mod, metadata, f, internal_h_file):
    """Write Python-ast.c for *mod* to *f*; the state struct goes to *internal_h_file*.

    First the module-state plumbing is written. Then a fixed sequence of
    visitors, sharing *metadata*, emits the rest of the C file.
    """
    generate_module_def(mod, metadata, f, internal_h_file)

    # The order matters: each visitor's output is appended in sequence.
    visitor_classes = (
        SequenceConstructorVisitor,
        PyTypesDeclareVisitor,
        PyTypesVisitor,
        Obj2ModPrototypeVisitor,
        FunctionVisitor,
        ObjVisitor,
        Obj2ModVisitor,
        ASTModuleVisitor,
        PartingShots,
    )
    chain = ChainOfVisitors(*[cls(f) for cls in visitor_classes], metadata=metadata)
    chain.visit(mod)
1664
1665
def main(input_filename, c_filename, h_filename, internal_h_filename, dump_module=False):
    """Regenerate the AST C source, header and internal state header.

    Parses and checks the ASDL file *input_filename*, then exits with
    status 1 if the check fails. Otherwise it writes the three outputs
    (all ``pathlib.Path`` objects), each prefixed with an autogeneration
    banner. With *dump_module*, the parsed module is printed first.
    """
    generator_path = "/".join(Path(__file__).parts[-2:])
    auto_gen_msg = AUTOGEN_MESSAGE.format(generator_path)

    mod = asdl.parse(input_filename)
    if dump_module:
        print('Parsed Module:')
        print(mod)
    if not asdl.check(mod):
        sys.exit(1)

    collector = MetadataVisitor()
    collector.visit(mod)
    metadata = collector.metadata

    with c_filename.open("w") as c_file:
        with h_filename.open("w") as h_file:
            with internal_h_filename.open("w") as internal_h_file:
                for output in (c_file, h_file, internal_h_file):
                    output.write(auto_gen_msg)

                write_internal_h_header(mod, internal_h_file)
                write_source(mod, metadata, c_file, internal_h_file)
                write_header(mod, metadata, h_file)
                write_internal_h_footer(mod, internal_h_file)

    print(f"{c_filename}, {h_filename}, {internal_h_filename} regenerated.")
1691
1692
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("input_file", type=Path)
    # The three output paths are all mandatory Path options.
    for short_opt, long_opt in (("-C", "--c-file"),
                                ("-H", "--h-file"),
                                ("-I", "--internal-h-file")):
        parser.add_argument(short_opt, long_opt, type=Path, required=True)
    parser.add_argument("-d", "--dump-module", action="store_true")

    args = parser.parse_args()
    main(args.input_file, args.c_file, args.h_file,
         args.internal_h_file, args.dump_module)
1703
1704