Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/tests/atf_python/sys/net/rtsock.py
39553 views
1
#!/usr/local/bin/python3
2
import os
3
import socket
4
import struct
5
import sys
6
from ctypes import c_byte
7
from ctypes import c_char
8
from ctypes import c_int
9
from ctypes import c_long
10
from ctypes import c_uint32
11
from ctypes import c_ulong
12
from ctypes import c_ushort
13
from ctypes import sizeof
14
from ctypes import Structure
15
from typing import Dict
16
from typing import List
17
from typing import Optional
18
from typing import Union
19
20
21
def roundup2(val: int, num: int) -> int:
    """Round *val* up to the next multiple of *num*.

    *num* is expected to be a power of two, as with the C ``roundup2()``
    macro; the bit trick below is only a true round-up in that case.
    Values that are already a multiple of *num* come back unchanged.
    """
    remainder = val % num
    if not remainder:
        return val
    # Fill every bit below the alignment, then step onto the next boundary.
    return (val | (num - 1)) + 1
26
27
28
class RtSockException(OSError):
    """Raised when routing-socket data (e.g. a sockaddr buffer) fails a size check."""
30
31
32
class RtConst:
    """Routing-socket constants plus helpers that render them by name.

    Related constants share a name prefix (``AF_``, ``RTA_``, ``RTM_``,
    ``RTF_``, ``RTV_``); the static helpers discover a group with ``dir()``
    and translate numeric values back into those names.
    """

    # Value written into rtm_version of outgoing messages.
    RTM_VERSION = 5
    # Sockaddrs inside a routing message are padded to a multiple of this.
    ALIGN = sizeof(c_long)

    AF_INET = socket.AF_INET
    AF_INET6 = socket.AF_INET6
    # 18 is AF_LINK on FreeBSD; the fallback only matters on hosts whose
    # socket module does not export AF_LINK, where it keeps this importable.
    AF_LINK = getattr(socket, "AF_LINK", 18)

    # Sockaddr presence bits (rtm_addrs).
    RTA_DST = 0x1
    RTA_GATEWAY = 0x2
    RTA_NETMASK = 0x4
    RTA_GENMASK = 0x8
    RTA_IFP = 0x10
    RTA_IFA = 0x20
    RTA_AUTHOR = 0x40
    RTA_BRD = 0x80

    # Message types (rtm_type).
    RTM_ADD = 1
    RTM_DELETE = 2
    RTM_CHANGE = 3
    RTM_GET = 4

    # Route flags (rtm_flags).
    RTF_UP = 0x1
    RTF_GATEWAY = 0x2
    RTF_HOST = 0x4
    RTF_REJECT = 0x8
    RTF_DYNAMIC = 0x10
    RTF_MODIFIED = 0x20
    RTF_DONE = 0x40
    RTF_XRESOLVE = 0x200
    RTF_LLINFO = 0x400
    RTF_LLDATA = 0x400
    RTF_STATIC = 0x800
    RTF_BLACKHOLE = 0x1000
    RTF_PROTO2 = 0x4000
    RTF_PROTO1 = 0x8000
    RTF_PROTO3 = 0x40000
    RTF_FIXEDMTU = 0x80000
    RTF_PINNED = 0x100000
    RTF_LOCAL = 0x200000
    RTF_BROADCAST = 0x400000
    RTF_MULTICAST = 0x800000
    RTF_STICKY = 0x10000000
    RTF_RNH_LOCKED = 0x40000000
    RTF_GWFLAG_COMPAT = 0x80000000

    # Metric bits (RTV_*), as used in rtm_inits / rmx_locks.
    RTV_MTU = 0x1
    RTV_HOPCOUNT = 0x2
    RTV_EXPIRE = 0x4
    RTV_RPIPE = 0x8
    RTV_SPIPE = 0x10
    RTV_SSTHRESH = 0x20
    RTV_RTT = 0x40
    RTV_RTTVAR = 0x80
    RTV_WEIGHT = 0x100

    @staticmethod
    def get_props(prefix: str) -> List[str]:
        """Return the (sorted) names of all class attributes starting with *prefix*."""
        return [n for n in dir(RtConst) if n.startswith(prefix)]

    @staticmethod
    def get_name(prefix: str, value: int) -> str:
        """Return the first *prefix* constant equal to *value*.

        Falls back to ``"U:<prefix>:<value>"`` when no constant matches.
        """
        for prop in RtConst.get_props(prefix):
            if getattr(RtConst, prop) == value:
                return prop
        return "U:{}:{}".format(prefix, value)

    @staticmethod
    def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]:
        """Split *value* into its set bits, in ascending order.

        Each bit maps to the *prefix* constant with that value, or to
        ``hex(bit)`` if none exists. A negative *value* is treated as a
        32-bit two's-complement C int.
        """
        propmap = {getattr(RtConst, prop): prop for prop in RtConst.get_props(prefix)}
        if value < 0:
            # Header fields such as rtm_flags are signed C ints, so bit 31
            # (e.g. RTF_GWFLAG_COMPAT) arrives negative. Reinterpret it as
            # unsigned; otherwise the bit walk below never terminates.
            value &= 0xFFFFFFFF
        ret = {}
        bit = 1
        while value:
            if bit & value:
                ret[bit] = propmap.get(bit, hex(bit))
                value -= bit
            bit <<= 1
        return ret

    @staticmethod
    def get_bitmask_str(prefix: str, value: int) -> str:
        """Return the comma-separated names of the bits set in *value*."""
        return ",".join(RtConst.get_bitmask_map(prefix, value).values())
120
121
122
class RtMetrics(Structure):
    """ctypes mirror of the kernel route metrics block (``struct rt_metrics``).

    Every member is a C ``unsigned long``; the trailing filler is a
    two-element array.
    """

    _fields_ = [
        (name, c_ulong)
        for name in (
            "rmx_locks",
            "rmx_mtu",
            "rmx_hopcount",
            "rmx_expire",
            "rmx_recvpipe",
            "rmx_sendpipe",
            "rmx_ssthresh",
            "rmx_rtt",
            "rmx_rttvar",
            "rmx_pksent",
            "rmx_weight",
            "rmx_nhidx",
        )
    ] + [("rmx_filler", c_ulong * 2)]
138
139
140
class RtMsgHdr(Structure):
    """ctypes mirror of the routing-message header (``struct rt_msghdr``).

    Every routing message begins with this header; the sockaddrs selected
    by ``rtm_addrs`` follow it, each padded to the platform alignment.
    """

    _fields_ = [
        ("rtm_msglen", c_ushort),  # whole message length, header included
        ("rtm_version", c_byte),
        ("rtm_type", c_byte),  # RTM_* message type
        ("rtm_index", c_ushort),
        ("_rtm_spare1", c_ushort),  # padding, expected to be zero
        ("rtm_flags", c_int),  # RTF_* bitmask
        ("rtm_addrs", c_int),  # RTA_* bitmask of sockaddrs that follow
        ("rtm_pid", c_int),
        ("rtm_seq", c_int),
        ("rtm_errno", c_int),
        ("rtm_fmask", c_int),
        ("rtm_inits", c_ulong),
        ("rtm_rmx", RtMetrics),
    ]
156
157
158
class SockaddrIn(Structure):
    """ctypes mirror of a length-prefixed BSD ``struct sockaddr_in`` (16 bytes)."""

    _fields_ = [
        ("sin_len", c_byte),  # total sockaddr length
        ("sin_family", c_byte),  # AF_INET
        ("sin_port", c_ushort),
        ("sin_addr", c_uint32),  # address bytes (network order in memory)
        ("sin_zero", c_char * 8),  # padding up to 16 bytes
    ]
166
167
168
class SockaddrIn6(Structure):
    """ctypes mirror of a length-prefixed BSD ``struct sockaddr_in6`` (28 bytes)."""

    _fields_ = [
        ("sin6_len", c_byte),  # total sockaddr length
        ("sin6_family", c_byte),  # AF_INET6
        ("sin6_port", c_ushort),
        ("sin6_flowinfo", c_uint32),
        ("sin6_addr", c_byte * 16),  # raw 128-bit address, byte offset 8
        ("sin6_scope_id", c_uint32),  # native-endian uint32, byte offset 24
    ]
177
178
179
class SockaddrDl(Structure):
    """ctypes mirror of the fixed head of a link-level ``struct sockaddr_dl``.

    Only 8 bytes of ``sdl_data`` are declared (16 bytes total); longer
    link sockaddrs must be read from the raw buffer.
    """

    _fields_ = [
        ("sdl_len", c_byte),  # total sockaddr length
        ("sdl_family", c_byte),  # AF_LINK
        ("sdl_index", c_ushort),  # interface index, 0 if unset
        ("sdl_type", c_byte),
        ("sdl_nlen", c_byte),  # interface-name length at the start of sdl_data
        ("sdl_alen", c_byte),
        ("sdl_slen", c_byte),
        ("sdl_data", c_byte * 8),  # name, then link-level address
    ]
190
191
192
class SaHelper(object):
    """Build and pretty-print BSD-style (length-prefixed) sockaddr byte strings."""

    @staticmethod
    def is_ipv6(ip: str) -> bool:
        """Return True if *ip* looks like an IPv6 literal (contains a colon)."""
        return ":" in ip

    @staticmethod
    def ip_sa(ip: str, scopeid: int = 0) -> bytes:
        """Return sockaddr_in / sockaddr_in6 bytes for *ip*.

        *scopeid* is only used for IPv6 addresses.
        """
        if SaHelper.is_ipv6(ip):
            return SaHelper.ip6_sa(ip, scopeid)
        else:
            return SaHelper.ip4_sa(ip)

    @staticmethod
    def ip4_sa(ip: str) -> bytes:
        """Return the bytes of a sockaddr_in holding *ip* with port 0."""
        # inet_pton() yields network-order bytes; reading them as a native int
        # makes ctypes store them back unchanged in the c_uint32 field.
        addr_int = int.from_bytes(socket.inet_pton(socket.AF_INET, ip), sys.byteorder)
        sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int)
        return bytes(sin)

    @staticmethod
    def ip6_sa(ip6: str, scopeid: int) -> bytes:
        """Return the bytes of a sockaddr_in6 for *ip6* (port and flowinfo 0)."""
        addr_bytes = (c_byte * 16).from_buffer_copy(
            socket.inet_pton(socket.AF_INET6, ip6)
        )
        sin6 = SockaddrIn6(
            sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, addr_bytes, scopeid
        )
        return bytes(sin6)

    @staticmethod
    def link_sa(ifindex: int = 0, iftype: int = 0) -> bytes:
        """Return the bytes of a minimal AF_LINK sockaddr for *ifindex*/*iftype*."""
        sa = SockaddrDl(sizeof(SockaddrDl), socket.AF_LINK, c_ushort(ifindex), iftype)
        return bytes(sa)

    @staticmethod
    def pxlen4_sa(pxlen: int) -> bytes:
        """Return an IPv4 netmask sockaddr for prefix length *pxlen*."""
        return SaHelper.ip_sa(SaHelper.pxlen_to_ip4(pxlen))

    @staticmethod
    def pxlen_to_ip4(pxlen: int) -> str:
        """Return the dotted-quad netmask for prefix length *pxlen* (0..32)."""
        if pxlen == 32:
            return "255.255.255.255"
        else:
            addr = 0xFFFFFFFF - ((1 << (32 - pxlen)) - 1)
            addr_bytes = struct.pack("!I", addr)
            return socket.inet_ntop(socket.AF_INET, addr_bytes)

    @staticmethod
    def pxlen6_sa(pxlen: int) -> bytes:
        """Return an IPv6 netmask sockaddr for prefix length *pxlen*."""
        return SaHelper.ip_sa(SaHelper.pxlen_to_ip6(pxlen))

    @staticmethod
    def pxlen_to_ip6(pxlen: int) -> str:
        """Return the textual IPv6 netmask for prefix length *pxlen* (0..128)."""
        ip6_b = [0] * 16
        start = 0
        while pxlen > 8:
            ip6_b[start] = 0xFF
            pxlen -= 8
            start += 1
        ip6_b[start] = 0xFF - ((1 << (8 - pxlen)) - 1)
        return socket.inet_ntop(socket.AF_INET6, bytes(ip6_b))

    @staticmethod
    def print_sa_inet(sa: bytes):
        """Return the IPv4 address in *sa*; raise RtSockException if under 8 bytes."""
        if len(sa) < 8:
            raise RtSockException("IPv4 sa size too small: {}".format(len(sa)))
        addr = socket.inet_ntop(socket.AF_INET, sa[4:8])
        return "{}".format(addr)

    @staticmethod
    def print_sa_inet6(sa: bytes):
        """Return "<addr> scopeid <id>" for a sockaddr_in6 in *sa*."""
        if len(sa) < sizeof(SockaddrIn6):
            raise RtSockException("IPv6 sa size too small: {}".format(len(sa)))
        addr = socket.inet_ntop(socket.AF_INET6, sa[8:24])
        # sin6_scope_id is a host-order uint32 (ip6_sa stores it through a
        # native c_uint32), so decode it with native byte order, not ">I".
        (scopeid,) = struct.unpack("=I", sa[24:28])
        return "{} scopeid {}".format(addr, scopeid)

    @staticmethod
    def print_sa_link(sa: bytes, hd: Optional[bool] = True):
        """Return "link#<idx> ifname:<name> " parts present in an AF_LINK *sa*.

        *hd* is accepted for signature compatibility and is not used.
        """
        if len(sa) < sizeof(SockaddrDl):
            raise RtSockException("LINK sa size too small: {}".format(len(sa)))
        sdl = SockaddrDl.from_buffer_copy(sa)
        if sdl.sdl_index:
            ifindex = "link#{} ".format(sdl.sdl_index)
        else:
            ifindex = ""
        if sdl.sdl_nlen:
            # The interface name starts right after the 8-byte fixed header.
            iface_offset = 8
            if sdl.sdl_nlen + iface_offset > len(sa):
                raise RtSockException(
                    "LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa))
                )
            ifname = "ifname:{} ".format(
                bytes.decode(sa[iface_offset : iface_offset + sdl.sdl_nlen])
            )
        else:
            ifname = ""
        return "{}{}".format(ifindex, ifname)

    @staticmethod
    def print_sa_unknown(sa: bytes):
        """Describe a sockaddr of an unhandled family."""
        return "unknown_type:{}".format(sa[1])

    @classmethod
    def print_sa(cls, sa: bytes, hd: Optional[bool] = False):
        """Return a human-readable description of sockaddr *sa*.

        With *hd* true, the repr of the raw bytes is appended.

        Raises Exception if *sa* is shorter than its 2-byte len/family
        prefix or if its sa_len byte disagrees with the buffer length.
        """
        # Check the length first: both checks below index into the buffer.
        if len(sa) < 2:
            raise Exception("sa too short: {} bytes".format(len(sa)))
        if sa[0] != len(sa):
            raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa)))

        family = sa[1]
        if family == socket.AF_INET:
            text = cls.print_sa_inet(sa)
        elif family == socket.AF_INET6:
            text = cls.print_sa_inet6(sa)
        elif family == socket.AF_LINK:
            text = cls.print_sa_link(sa)
        else:
            text = cls.print_sa_unknown(sa)
        dump = " [{!r}]".format(sa) if hd else ""
        return "{}{}".format(text, dump)
319
320
321
class BaseRtsockMessage(object):
    """Shared base of routing-socket message classes; remembers the RTM_* type."""

    def __init__(self, rtm_type):
        self.rtm_type = rtm_type
        # Sockaddr builders, reached by subclasses through self.sa.
        self.sa = SaHelper()

    @staticmethod
    def print_rtm_type(rtm_type):
        """Return the symbolic RTM_* name for *rtm_type* ("U:RTM_:<n>" if unknown)."""
        name = RtConst.get_name("RTM_", rtm_type)
        return name

    @property
    def rtm_type_str(self):
        """Symbolic RTM_* name of this message's type."""
        current_type = self.rtm_type
        return self.print_rtm_type(current_type)
333
334
335
class RtsockRtMessage(BaseRtsockMessage):
    """Route message (RTM_ADD/DELETE/CHANGE/GET) together with its sockaddrs.

    Sockaddrs are kept in ``_attrs`` keyed by their RTA_* bit.
    ``__bytes__`` serializes the message and ``from_bytes`` parses one.
    """

    # RTM_* types handled by this class (consumed by Rtsock.build_msgmap()).
    messages = [
        RtConst.RTM_ADD,
        RtConst.RTM_DELETE,
        RtConst.RTM_CHANGE,
        RtConst.RTM_GET,
    ]

    def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None):
        super().__init__(rtm_type)
        self.rtm_flags = 0
        self.rtm_seq = rtm_seq
        self._attrs = {}
        self.rtm_errno = 0
        self.rtm_pid = 0
        self.rtm_inits = 0
        self.rtm_rmx = RtMetrics()
        # Raw buffer this message was parsed from (set by from_bytes()).
        self._orig_data = None
        if dst_sa:
            self.add_sa_attr(RtConst.RTA_DST, dst_sa)
        if mask_sa:
            self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa)

    @staticmethod
    def _sa_size(sa_len: int) -> int:
        """Bytes a sockaddr of length *sa_len* occupies in a message (cf. C SA_SIZE()).

        A zero-length sockaddr still takes one alignment unit.
        """
        if sa_len == 0:
            return RtConst.ALIGN
        return roundup2(sa_len, RtConst.ALIGN)

    def add_sa_attr(self, attr_type, attr_bytes: bytes):
        """Attach raw sockaddr bytes under RTA_* bit *attr_type*, replacing any previous one."""
        self._attrs[attr_type] = attr_bytes

    def add_ip_attr(self, attr_type, ip_addr: str, scopeid: int = 0):
        """Attach an IPv4 or IPv6 sockaddr for *ip_addr*; *scopeid* is IPv6-only."""
        if SaHelper.is_ipv6(ip_addr):
            self.add_ip6_attr(attr_type, ip_addr, scopeid)
        else:
            self.add_ip4_attr(attr_type, ip_addr)

    def add_ip4_attr(self, attr_type, ip: str):
        """Attach a sockaddr_in for *ip*."""
        self.add_sa_attr(attr_type, self.sa.ip_sa(ip))

    def add_ip6_attr(self, attr_type, ip6: str, scopeid: int):
        """Attach a sockaddr_in6 for *ip6* with *scopeid*."""
        self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid))

    def add_link_attr(self, attr_type, ifindex: Optional[int] = 0):
        """Attach an AF_LINK sockaddr referencing interface *ifindex*."""
        self.add_sa_attr(attr_type, self.sa.link_sa(ifindex))

    def get_sa(self, attr_type) -> Optional[bytes]:
        """Return the sockaddr bytes stored for *attr_type*, or None."""
        return self._attrs.get(attr_type)

    def print_message(self):
        """Print the header summary and one line per attached sockaddr."""
        # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags:<UP,GATEWAY,DONE,STATIC>
        if self._orig_data:
            rtm_len = len(self._orig_data)
        else:
            rtm_len = len(bytes(self))
        print(
            "{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format(
                self.rtm_type_str,
                rtm_len,
                self.rtm_pid,
                self.rtm_seq,
                self.rtm_errno,
                RtConst.get_bitmask_str("RTF_", self.rtm_flags),
            )
        )
        rtm_addrs = sum(self._attrs)
        print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs)))
        for attr in sorted(self._attrs):
            sa_data = SaHelper.print_sa(self._attrs[attr])
            print("  {}: {}".format(RtConst.get_name("RTA_", attr), sa_data))

    def print_in_message(self):
        """Print this message framed as an incoming one."""
        print("vvvvvvvv  IN vvvvvvvv")
        self.print_message()
        print()

    def verify_sa_inet(self, sa_data):
        """Sanity-check an IPv4 sockaddr.

        Raises Exception if *sa_data* is shorter than 8 bytes or its sin_len
        exceeds the buffer; asserts that the port and whatever part of the
        sin_zero padding is present are all zero.
        """
        if len(sa_data) < 8:
            raise Exception("IPv4 sa size too small: {}".format(sa_data))
        if sa_data[0] > len(sa_data):
            raise Exception(
                "IPv4 sin_len too big: {} vs sa size {}: {}".format(
                    sa_data[0], len(sa_data), sa_data
                )
            )
        # Inspect raw bytes: the ctypes c_char[8] sin_zero field reads back as
        # a NUL-truncated bytes object, which never equals a list of ints.
        assert sa_data[2:4] == b"\x00\x00"  # sin_port
        assert not any(sa_data[8:16])  # sin_zero

    def compare_sa(self, sa_type, sa_data):
        """Assert that *sa_data* matches the sockaddr stored for *sa_type*."""
        if len(sa_data) < 4:
            sa_type_name = RtConst.get_name("RTA_", sa_type)
            raise Exception(
                "sa_len for type {} too short: {}".format(sa_type_name, len(sa_data))
            )
        our_sa = self._attrs[sa_type]
        assert SaHelper.print_sa(sa_data) == SaHelper.print_sa(our_sa)
        assert len(sa_data) == len(our_sa)
        assert sa_data == our_sa

    def verify(self, rtm_type: int, rtm_sa):
        """Assert this parsed message has *rtm_type*, no error, and the sockaddrs in *rtm_sa*."""
        assert self.rtm_type_str == self.print_rtm_type(rtm_type)
        assert self.rtm_errno == 0
        hdr = RtMsgHdr.from_buffer_copy(self._orig_data)
        assert hdr._rtm_spare1 == 0
        for sa_type, sa_data in rtm_sa.items():
            if sa_type not in self._attrs:
                sa_type_name = RtConst.get_name("RTA_", sa_type)
                raise Exception("SA type {} not present".format(sa_type_name))
            self.compare_sa(sa_type, sa_data)

    @classmethod
    def from_bytes(cls, data: bytes):
        """Parse a raw routing message into a new instance.

        After the header, one sockaddr per bit set in ``rtm_addrs`` is
        sliced out in ascending RTA_* order. The raw buffer is kept in
        ``_orig_data``.

        Raises Exception if the buffer is shorter than the header or a
        sockaddr does not fit inside it.
        """
        if len(data) < sizeof(RtMsgHdr):
            raise Exception(
                "messages size {} is less than expected {}".format(
                    len(data), sizeof(RtMsgHdr)
                )
            )
        hdr = RtMsgHdr.from_buffer_copy(data)

        msg = cls(hdr.rtm_type)
        msg.rtm_flags = hdr.rtm_flags
        msg.rtm_seq = hdr.rtm_seq
        msg.rtm_errno = hdr.rtm_errno
        msg.rtm_pid = hdr.rtm_pid
        msg.rtm_inits = hdr.rtm_inits
        msg.rtm_rmx = hdr.rtm_rmx
        msg._orig_data = data

        off = sizeof(RtMsgHdr)
        # rtm_addrs is a signed C int; view it as unsigned so a set bit 31
        # cannot make the bit walk below run forever.
        addrs_mask = hdr.rtm_addrs & 0xFFFFFFFF
        v = 1
        while addrs_mask:
            if addrs_mask & v:
                addrs_mask -= v
                if off >= len(data):
                    raise Exception(
                        "SA {} starts past total message length: {} >= {}".format(
                            RtConst.get_name("RTA_", v), off, len(data)
                        )
                    )
                sa_len = data[off]
                if off + sa_len > len(data):
                    raise Exception(
                        "SA sizeof for {} > total message length: {}+{} > {}".format(
                            RtConst.get_name("RTA_", v), off, sa_len, len(data)
                        )
                    )
                msg._attrs[v] = data[off : off + sa_len]
                off += cls._sa_size(sa_len)
            v *= 2
        return msg

    def __bytes__(self):
        """Serialize to wire format: header, then sockaddrs in ascending RTA_* order."""
        sz = sizeof(RtMsgHdr)
        addrs_mask = 0
        for attr_type, sa in self._attrs.items():
            sz += self._sa_size(len(sa))
            addrs_mask += attr_type
        hdr = RtMsgHdr(
            rtm_msglen=sz,
            rtm_version=RtConst.RTM_VERSION,
            rtm_type=self.rtm_type,
            rtm_flags=self.rtm_flags,
            rtm_seq=self.rtm_seq,
            rtm_addrs=addrs_mask,
            rtm_inits=self.rtm_inits,
            rtm_rmx=self.rtm_rmx,
        )
        buf = bytearray(sz)
        buf[0 : sizeof(RtMsgHdr)] = hdr
        off = sizeof(RtMsgHdr)
        for attr_type in sorted(self._attrs):
            sa = self._attrs[attr_type]
            buf[off : off + len(sa)] = sa
            off += self._sa_size(len(sa))
        return bytes(buf)
503
504
505
class Rtsock:
    """Raw AF_ROUTE socket wrapper that builds, sends and reads routing messages."""

    def __init__(self):
        self.socket = self._setup_rtsock()
        # Next sequence number handed out by get_seq().
        self.rtm_seq = 1
        self.msgmap = self.build_msgmap()

    def build_msgmap(self):
        """Return a dict mapping each supported RTM_* type to its message class."""
        classes = [RtsockRtMessage]
        xmap = {}
        for cls in classes:
            for message in cls.messages:
                xmap[message] = cls
        return xmap

    def get_seq(self):
        """Return the current sequence number and advance the counter."""
        ret = self.rtm_seq
        self.rtm_seq += 1
        return ret

    def get_weight(self, weight) -> int:
        """Return *weight*, or 1 (RT_DEFAULT_WEIGHT) when it is falsy."""
        if weight:
            return weight
        else:
            return 1  # RT_DEFAULT_WEIGHT

    def new_rtm_any(self, msg_type, prefix: str, gw: Union[str, bytes]):
        """Build a route message of *msg_type* for *prefix* via *gw*.

        *prefix* is ``addr`` or ``addr/len``; a prefix length adds an
        RTA_NETMASK sockaddr. *gw* is raw sockaddr bytes or an IP string.
        """
        px = prefix.split("/")
        addr_sa = SaHelper.ip_sa(px[0])
        if len(px) > 1:
            pxlen = int(px[1])
            if SaHelper.is_ipv6(px[0]):
                mask_sa = SaHelper.pxlen6_sa(pxlen)
            else:
                mask_sa = SaHelper.pxlen4_sa(pxlen)
        else:
            mask_sa = None
        msg = RtsockRtMessage(msg_type, self.get_seq(), addr_sa, mask_sa)
        if isinstance(gw, bytes):
            msg.add_sa_attr(RtConst.RTA_GATEWAY, gw)
        else:
            # String
            msg.add_ip_attr(RtConst.RTA_GATEWAY, gw)
        return msg

    def new_rtm_add(self, prefix: str, gw: Union[str, bytes]):
        """Build an RTM_ADD message (see new_rtm_any)."""
        return self.new_rtm_any(RtConst.RTM_ADD, prefix, gw)

    def new_rtm_del(self, prefix: str, gw: Union[str, bytes]):
        """Build an RTM_DELETE message (see new_rtm_any)."""
        return self.new_rtm_any(RtConst.RTM_DELETE, prefix, gw)

    def new_rtm_change(self, prefix: str, gw: Union[str, bytes]):
        """Build an RTM_CHANGE message (see new_rtm_any)."""
        return self.new_rtm_any(RtConst.RTM_CHANGE, prefix, gw)

    def _setup_rtsock(self) -> socket.socket:
        """Open the raw routing socket with SO_USELOOPBACK set."""
        s = socket.socket(socket.AF_ROUTE, socket.SOCK_RAW, socket.AF_UNSPEC)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_USELOOPBACK, 1)
        return s

    def print_hd(self, data: bytes):
        """Print *data* as a hex dump, 16 bytes per line."""
        width = 16
        print("==========================================")
        for i in range(0, len(data), width):
            print("".join("0x{:02X} ".format(b) for b in data[i : i + width]))
        print()

    def write_message(self, msg):
        """Print *msg* as outgoing and write it to the routing socket."""
        print("vvvvvvvv  OUT vvvvvvvv")
        msg.print_message()
        print()
        msg_bytes = bytes(msg)
        # os.write() raises OSError on failure; it never returns -1.
        ret = os.write(self.socket.fileno(), msg_bytes)
        assert ret == len(msg_bytes)

    def parse_message(self, data: bytes):
        """Decode one routing message, or return None for an unhandled RTM_* type.

        Raises OSError if *data* is too short to hold the rtm_type byte.
        """
        # rt_msghdr starts with u_short rtm_msglen, then the rtm_version and
        # rtm_type bytes, so rtm_type is at byte offset 3.
        if len(data) < 4:
            raise OSError("Short read from rtsock: {} bytes".format(len(data)))
        rtm_type = data[3]
        if rtm_type not in self.msgmap:
            return None
        return self.msgmap[rtm_type].from_bytes(data)

    def write_data(self, data: bytes):
        """Send raw bytes on the routing socket."""
        self.socket.send(data)

    def read_data(self, seq: Optional[int] = None) -> bytes:
        """Receive one message (up to 4096 bytes).

        If *seq* is given, keep reading until a message with a full header
        and a matching rtm_seq arrives; other messages are discarded.
        """
        while True:
            data = self.socket.recv(4096)
            if seq is None:
                break
            # A header-only message is still a complete message.
            if len(data) >= sizeof(RtMsgHdr):
                hdr = RtMsgHdr.from_buffer_copy(data)
                if hdr.rtm_seq == seq:
                    break
        return data

    def read_message(self) -> "Optional[BaseRtsockMessage]":
        """Read one message and return it parsed (None if its type is unhandled)."""
        data = self.read_data()
        return self.parse_message(data)
605
606