Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
allendowney
GitHub Repository: allendowney/cpython
Path: blob/main/Tools/build/umarshal.py
12 views
1
# Implement marshal.loads() in pure Python
2
3
import ast
4
5
from typing import Any, Tuple
6
7
8
class Type:
    """One-byte type codes of the marshal format (adapted from marshal.c).

    The reader masks off FLAG_REF before comparing a type byte against
    these values.
    """

    # Singletons and markers
    NULL = ord("0")
    NONE = ord("N")
    FALSE = ord("F")
    TRUE = ord("T")
    STOPITER = ord("S")
    ELLIPSIS = ord(".")

    # Numbers
    INT = ord("i")
    INT64 = ord("I")
    FLOAT = ord("f")
    BINARY_FLOAT = ord("g")
    COMPLEX = ord("x")
    BINARY_COMPLEX = ord("y")
    LONG = ord("l")

    # Byte/text strings and back-references
    STRING = ord("s")
    INTERNED = ord("t")
    REF = ord("r")

    # Containers and code objects
    TUPLE = ord("(")
    LIST = ord("[")
    DICT = ord("{")
    CODE = ord("c")
    UNICODE = ord("u")
    UNKNOWN = ord("?")
    SET = ord("<")
    FROZENSET = ord(">")

    # Compact string / tuple encodings
    ASCII = ord("a")
    ASCII_INTERNED = ord("A")
    SMALL_TUPLE = ord(")")
    SHORT_ASCII = ord("z")
    SHORT_ASCII_INTERNED = ord("Z")
39
40
41
# High bit of a type byte: when set, the decoded object is recorded in the
# reader's reference table so later Type.REF entries can refer back to it.
FLAG_REF = 1 << 7

# Unique marker object returned when a Type.NULL byte is decoded (the dict
# decoder uses it to detect the end of its key/value stream).
NULL = object()

# Cell kinds: bit flags, one set per entry of a code object's
# co_localspluskinds (tested with `kind & flag`).
CO_FAST_LOCAL = 1 << 5
CO_FAST_CELL = 1 << 6
CO_FAST_FREE = 1 << 7
49
50
51
class Code:
    """Attribute bag standing in for a code object decoded by the reader.

    Attributes come from keyword arguments (or are assigned afterwards by the
    reader). The properties below derive the classic ``co_varnames``,
    ``co_cellvars`` and ``co_freevars`` views from the packed "localsplus"
    tables.
    """

    def __init__(self, **kwds: Any):
        self.__dict__.update(kwds)

    def __repr__(self) -> str:
        return f"Code(**{self.__dict__})"

    # Names of all locals, cells and free variables, in slot order
    # (variable length, hence `...`).
    co_localsplusnames: Tuple[str, ...]
    # One kind byte per entry of co_localsplusnames; marshal stores this as a
    # bytes object (iterating it yields ints), but any int sequence works.
    co_localspluskinds: bytes

    def get_localsplus_names(self, select_kind: int) -> Tuple[str, ...]:
        """Return, in slot order, every name whose kind shares a bit with *select_kind*.

        A slot may carry several kind bits (e.g. a local that is also a
        cell), so the same name can appear in more than one view.
        """
        return tuple(
            name
            for name, kind in zip(self.co_localsplusnames,
                                  self.co_localspluskinds)
            if kind & select_kind
        )

    @property
    def co_varnames(self) -> Tuple[str, ...]:
        """Names of slots flagged CO_FAST_LOCAL."""
        return self.get_localsplus_names(CO_FAST_LOCAL)

    @property
    def co_cellvars(self) -> Tuple[str, ...]:
        """Names of slots flagged CO_FAST_CELL."""
        return self.get_localsplus_names(CO_FAST_CELL)

    @property
    def co_freevars(self) -> Tuple[str, ...]:
        """Names of slots flagged CO_FAST_FREE."""
        return self.get_localsplus_names(CO_FAST_FREE)

    @property
    def co_nlocals(self) -> int:
        """Number of local-variable slots (length of co_varnames)."""
        return len(self.co_varnames)
84
85
86
class Reader:
    """Decode marshal-format bytes into Python objects.

    A fairly literal translation of the marshal reader in marshal.c.
    ``refs`` collects objects registered via FLAG_REF so later Type.REF
    entries can point back to them; ``level`` tracks recursion depth (only
    read by the commented-out trace in _r_object).
    """

    def __init__(self, data: bytes):
        self.data: bytes = data
        self.end: int = len(self.data)
        self.pos: int = 0
        self.refs: list[Any] = []
        self.level: int = 0

    def r_string(self, n: int) -> bytes:
        """Consume and return the next *n* raw bytes.

        Raises AssertionError if *n* is negative or exceeds the bytes left.
        """
        # NOTE(review): assert is stripped under -O; kept so callers still see
        # AssertionError on truncated/corrupt input as before.
        assert 0 <= n <= self.end - self.pos
        buf = self.data[self.pos : self.pos + n]
        self.pos += n
        return buf

    def r_byte(self) -> int:
        """Consume one byte and return it as an unsigned int (0-255)."""
        return self.r_string(1)[0]

    def r_short(self) -> int:
        """Consume a signed 16-bit little-endian integer."""
        return int.from_bytes(self.r_string(2), "little", signed=True)

    def r_long(self) -> int:
        """Consume a signed 32-bit little-endian integer."""
        return int.from_bytes(self.r_string(4), "little", signed=True)

    def r_long64(self) -> int:
        """Consume a signed 64-bit little-endian integer."""
        # The former hand-rolled version reused buf[1] for bytes 4..7,
        # corrupting every value that did not fit in 32 bits.
        return int.from_bytes(self.r_string(8), "little", signed=True)

    def r_PyLong(self) -> int:
        """Consume an arbitrary-precision int (Type.LONG).

        Layout: a signed 32-bit digit count whose sign is the sign of the
        value, followed by that many 15-bit digits (least significant first),
        each stored in a 16-bit little-endian slot.

        Raises ValueError if a digit does not fit in 15 bits.
        """
        n = self.r_long()
        x = 0
        for i in range(abs(n)):
            digit = self.r_short()
            # A valid digit is in [0, 2**15); r_short sign-extends, so any
            # slot with the top bit set shows up as negative.
            if digit < 0:
                raise ValueError("bad marshal data (digit out of range in long)")
            x |= digit << (i * 15)
        return -x if n < 0 else x

    def r_float_bin(self) -> float:
        """Consume an IEEE 754 binary64 value stored little-endian."""
        buf = self.r_string(8)
        import struct  # Lazy import to avoid breaking UNIX build
        # marshal always writes little-endian; "<d" decodes correctly on any
        # host (plain "d" would use native byte order).
        return struct.unpack("<d", buf)[0]

    def r_float_str(self) -> float:
        """Consume a float written as text: a length byte, then ASCII digits."""
        n = self.r_byte()
        buf = self.r_string(n)
        # float() rather than ast.literal_eval(): the text can be "inf"/"nan"
        # or lack a decimal point (e.g. "1"), which literal_eval rejects or
        # turns into an int.
        return float(buf.decode("ascii"))

    def r_ref_reserve(self, flag: int) -> int:
        """Reserve a refs slot (if *flag*) for an object built after its children."""
        if flag:
            idx = len(self.refs)
            self.refs.append(None)
            return idx
        else:
            return 0

    def r_ref_insert(self, obj: Any, idx: int, flag: int) -> Any:
        """Fill a slot reserved by r_ref_reserve (if *flag*) and return *obj*."""
        if flag:
            self.refs[idx] = obj
        return obj

    def r_ref(self, obj: Any, flag: int) -> Any:
        """Append *obj* to refs immediately and return it; *flag* must be set."""
        assert flag & FLAG_REF
        self.refs.append(obj)
        return obj

    def r_object(self) -> Any:
        """Decode the next object, restoring the nesting level afterwards."""
        old_level = self.level
        try:
            return self._r_object()
        finally:
            self.level = old_level

    def _r_object(self) -> Any:
        """Decode one object: read its type byte and dispatch on it.

        Raises AssertionError for an unrecognised type code.
        """
        code = self.r_byte()
        flag = code & FLAG_REF
        typ = code & ~FLAG_REF
        # print(" "*self.level + f"{code} {flag} {typ} {chr(typ)!r}")
        self.level += 1

        def R_REF(obj: Any) -> Any:
            # Register obj right away when the type byte carried FLAG_REF.
            if flag:
                obj = self.r_ref(obj, flag)
            return obj

        if typ == Type.NULL:
            return NULL
        elif typ == Type.NONE:
            return None
        elif typ == Type.ELLIPSIS:
            return Ellipsis
        elif typ == Type.FALSE:
            return False
        elif typ == Type.TRUE:
            return True
        elif typ == Type.STOPITER:
            # marshal.c decodes this code to the StopIteration class itself.
            return StopIteration
        elif typ == Type.INT:
            return R_REF(self.r_long())
        elif typ == Type.INT64:
            return R_REF(self.r_long64())
        elif typ == Type.LONG:
            return R_REF(self.r_PyLong())
        elif typ == Type.FLOAT:
            return R_REF(self.r_float_str())
        elif typ == Type.BINARY_FLOAT:
            return R_REF(self.r_float_bin())
        elif typ == Type.COMPLEX:
            return R_REF(complex(self.r_float_str(),
                                 self.r_float_str()))
        elif typ == Type.BINARY_COMPLEX:
            return R_REF(complex(self.r_float_bin(),
                                 self.r_float_bin()))
        elif typ == Type.STRING:
            n = self.r_long()
            return R_REF(self.r_string(n))
        elif typ == Type.ASCII_INTERNED or typ == Type.ASCII:
            n = self.r_long()
            return R_REF(self.r_string(n).decode("ascii"))
        elif typ == Type.SHORT_ASCII_INTERNED or typ == Type.SHORT_ASCII:
            n = self.r_byte()
            return R_REF(self.r_string(n).decode("ascii"))
        elif typ == Type.INTERNED or typ == Type.UNICODE:
            n = self.r_long()
            return R_REF(self.r_string(n).decode("utf8", "surrogatepass"))
        elif typ == Type.SMALL_TUPLE:
            n = self.r_byte()
            # Tuples are built after their items, so reserve the ref slot now
            # to keep reference indices in stream order.
            idx = self.r_ref_reserve(flag)
            retval: Any = tuple(self.r_object() for _ in range(n))
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif typ == Type.TUPLE:
            n = self.r_long()
            idx = self.r_ref_reserve(flag)
            retval = tuple(self.r_object() for _ in range(n))
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif typ == Type.LIST:
            n = self.r_long()
            retval = R_REF([])
            for _ in range(n):
                retval.append(self.r_object())
            return retval
        elif typ == Type.DICT:
            retval = R_REF({})
            # Key/value pairs continue until a Type.NULL key.
            while True:
                key = self.r_object()
                if key is NULL:
                    break
                val = self.r_object()
                retval[key] = val
            return retval
        elif typ == Type.SET:
            n = self.r_long()
            retval = R_REF(set())
            for _ in range(n):
                v = self.r_object()
                retval.add(v)
            return retval
        elif typ == Type.FROZENSET:
            n = self.r_long()
            s: set[Any] = set()
            idx = self.r_ref_reserve(flag)
            for _ in range(n):
                v = self.r_object()
                s.add(v)
            retval = frozenset(s)
            self.r_ref_insert(retval, idx, flag)
            return retval
        elif typ == Type.CODE:
            retval = R_REF(Code())
            retval.co_argcount = self.r_long()
            retval.co_posonlyargcount = self.r_long()
            retval.co_kwonlyargcount = self.r_long()
            retval.co_stacksize = self.r_long()
            retval.co_flags = self.r_long()
            retval.co_code = self.r_object()
            retval.co_consts = self.r_object()
            retval.co_names = self.r_object()
            retval.co_localsplusnames = self.r_object()
            retval.co_localspluskinds = self.r_object()
            retval.co_filename = self.r_object()
            retval.co_name = self.r_object()
            retval.co_qualname = self.r_object()
            retval.co_firstlineno = self.r_long()
            retval.co_linetable = self.r_object()
            retval.co_exceptiontable = self.r_object()
            return retval
        elif typ == Type.REF:
            n = self.r_long()
            retval = self.refs[n]
            # A None here means the slot was reserved but never filled.
            assert retval is not None
            return retval
        else:
            raise AssertionError(f"Unknown type {typ} {chr(typ)!r}")
302
303
304
def loads(data: bytes) -> Any:
    """Decode *data* (marshal-format bytes) and return the resulting object.

    Pure-Python counterpart of ``marshal.loads``. Raises AssertionError if
    *data* is not a ``bytes`` instance.
    """
    assert isinstance(data, bytes)
    reader = Reader(data)
    obj = reader.r_object()
    return obj
308
309
310
def main():
    """Self-test: round-trip a nested container and this function's code object."""
    import marshal
    import pprint

    # A nested dict/set/tuple must decode to an equal value.
    expected = {'foo': {(42, "bar", 3.14)}}
    decoded = loads(marshal.dumps(expected))
    assert decoded == expected, decoded

    # A code object must decode to a Code attribute bag; dump its fields.
    decoded = loads(marshal.dumps(main.__code__))
    assert isinstance(decoded, Code), decoded
    pprint.pprint(decoded.__dict__)


if __name__ == "__main__":
    main()
326
327