Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
freebsd
GitHub Repository: freebsd/freebsd-src
Path: blob/main/tests/atf_python/sys/netlink/netlink.py
39553 views
1
#!/usr/local/bin/python3
2
import os
3
import socket
4
import sys
5
from ctypes import c_int
6
from ctypes import c_ubyte
7
from ctypes import c_uint
8
from ctypes import c_ushort
9
from ctypes import sizeof
10
from ctypes import Structure
11
from enum import auto
12
from enum import Enum
13
14
from atf_python.sys.netlink.attrs import NlAttr
15
from atf_python.sys.netlink.attrs import NlAttrStr
16
from atf_python.sys.netlink.attrs import NlAttrU32
17
from atf_python.sys.netlink.base_headers import GenlMsgHdr
18
from atf_python.sys.netlink.base_headers import NlmBaseFlags
19
from atf_python.sys.netlink.base_headers import Nlmsghdr
20
from atf_python.sys.netlink.base_headers import NlMsgType
21
from atf_python.sys.netlink.message import BaseNetlinkMessage
22
from atf_python.sys.netlink.message import NlMsgCategory
23
from atf_python.sys.netlink.message import NlMsgProps
24
from atf_python.sys.netlink.message import StdNetlinkMessage
25
from atf_python.sys.netlink.netlink_generic import GenlCtrlAttrType
26
from atf_python.sys.netlink.netlink_generic import GenlCtrlMsgType
27
from atf_python.sys.netlink.netlink_generic import handler_classes as genl_classes
28
from atf_python.sys.netlink.netlink_route import handler_classes as rt_classes
29
from atf_python.sys.netlink.utils import align4
30
from atf_python.sys.netlink.utils import AttrDescr
31
from atf_python.sys.netlink.utils import build_propmap
32
from atf_python.sys.netlink.utils import enum_or_int
33
from atf_python.sys.netlink.utils import get_bitmask_map
34
from atf_python.sys.netlink.utils import NlConst
35
from atf_python.sys.netlink.utils import prepare_attrs_map
36
37
38
class SockaddrNl(Structure):
    """ctypes view of a netlink socket address (BSD-style, length byte first)."""

    _fields_ = [
        ("nl_len", c_ubyte),  # total size of this structure
        ("nl_family", c_ubyte),  # address family (AF_NETLINK)
        ("nl_pad", c_ushort),  # padding, left zero
        ("nl_pid", c_uint),  # port id
        ("nl_groups", c_uint),  # multicast group bitmask
    ]
46
47
48
class Nlmsgdone(Structure):
    """Fixed payload of an NLMSG_DONE message: a single signed error code."""

    _fields_ = [("error", c_int)]
52
53
54
class Nlmsgerr(Structure):
    """Fixed payload of an NLMSG_ERROR message.

    ``error`` is the signed error code; ``msg`` is the netlink header of the
    message this error/ACK refers to.
    """

    _fields_ = [
        ("error", c_int),
        ("msg", Nlmsghdr),  # echoed header of the original message
    ]
59
60
61
class NlErrattrType(Enum):
    """Attribute types that may follow an NLMSG_ERROR payload (extended ACK)."""

    # Explicit numbering; identical to the consecutive auto() values.
    NLMSGERR_ATTR_UNUSED = 0
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
67
68
69
class AddressFamilyLinux(Enum):
    """Address-family numbers used when not running on FreeBSD."""

    AF_INET = socket.AF_INET
    AF_INET6 = socket.AF_INET6
    AF_NETLINK = 16  # Linux AF_NETLINK value
73
74
75
class AddressFamilyBsd(Enum):
    """Address-family numbers used when running on FreeBSD."""

    AF_INET = socket.AF_INET
    AF_INET6 = socket.AF_INET6
    AF_NETLINK = 38  # FreeBSD AF_NETLINK value
79
80
81
class NlHelper:
    """Per-test helper: sequence numbers and enum <-> value lookups.

    It also selects the address-family enum for the running platform,
    because AF_NETLINK has a different numeric value on FreeBSD and Linux.
    """

    def __init__(self):
        self._pmap = {}  # cache: class -> result of build_propmap(class)
        self._af_cls = self.get_af_cls()
        self._seq_counter = 1
        self.pid = os.getpid()

    def get_seq(self):
        """Return the next message sequence number (the first call returns 1)."""
        ret = self._seq_counter
        self._seq_counter += 1
        return ret

    def get_af_cls(self):
        """Return AddressFamilyBsd on FreeBSD, AddressFamilyLinux otherwise."""
        if sys.platform.startswith("freebsd"):
            return AddressFamilyBsd
        return AddressFamilyLinux

    def get_propmap(self, cls):
        """Return ``build_propmap(cls)``, memoized per class on this instance."""
        if cls not in self._pmap:
            self._pmap[cls] = build_propmap(cls)
        return self._pmap[cls]

    def get_name_propmap(self, cls):
        """Return a ``{name: value}`` mapping for the members of ``cls``.

        Enum classes are read through ``__members__``. ``dir()`` on a
        mixed-in enum (e.g. ``IntEnum``) also lists the mixin type's public
        attributes, which have no ``.value``. Other classes fall back to
        every public attribute's ``.value``.
        """
        members = getattr(cls, "__members__", None)
        if members is not None:
            return {name: member.value for name, member in members.items()}
        return {
            prop: getattr(cls, prop).value
            for prop in dir(cls)
            if not prop.startswith("_")
        }

    def get_attr_byval(self, cls, attr_val):
        """Look up ``attr_val`` in the (cached) propmap of ``cls``; None if absent."""
        propmap = self.get_propmap(cls)
        return propmap.get(attr_val)

    def get_af_name(self, family):
        """Return the name for address family ``family``, or ``"af#<n>"`` if unknown."""
        v = self.get_attr_byval(self._af_cls, family)
        if v is not None:
            return v
        return "af#{}".format(family)

    def get_af_value(self, family_str: str) -> int:
        """Return the numeric value of address family ``family_str`` (e.g. "AF_NETLINK").

        Returns None if the name is not a member of the platform's family enum.
        """
        propmap = self.get_name_propmap(self._af_cls)
        return propmap.get(family_str)

    def get_bitmask_str(self, cls, val):
        """Render the bits of ``val`` as comma-separated names from ``cls``."""
        bmap = get_bitmask_map(self.get_propmap(cls), val)
        return ",".join(bmap.values())

    @staticmethod
    def get_bitmask_str_uncached(cls, val):
        """Like ``get_bitmask_str`` but without an instance or its cache.

        ``build_propmap`` and ``get_bitmask_map`` are module-level functions,
        not NlHelper attributes.
        """
        pmap = build_propmap(cls)
        bmap = get_bitmask_map(pmap, val)
        return ",".join(bmap.values())
135
136
137
# NLMSG_DONE carries no attributes.
nldone_attrs = prepare_attrs_map([])

# Attributes that may trail an NLMSG_ERROR payload, with their parser classes.
nlerr_attrs = prepare_attrs_map(
    [
        AttrDescr(attr_type, attr_cls)
        for attr_type, attr_cls in (
            (NlErrattrType.NLMSGERR_ATTR_MSG, NlAttrStr),
            (NlErrattrType.NLMSGERR_ATTR_OFFS, NlAttrU32),
            (NlErrattrType.NLMSGERR_ATTR_COOKIE, NlAttr),
        )
    ]
)
146
147
148
class NetlinkDoneMessage(StdNetlinkMessage):
    """NLMSG_DONE: ends a multipart reply and carries a status code."""

    messages = [NlMsgProps(NlMsgType.NLMSG_DONE, NlMsgCategory.ACK)]
    nl_attrs_map = nldone_attrs

    @property
    def error_code(self):
        """Status code from the parsed Nlmsgdone payload (0 means success)."""
        return self.base_hdr.error

    def parse_base_header(self, data):
        """Parse the Nlmsgdone payload; return ``(header, consumed_bytes)``.

        Raises ValueError if ``data`` is shorter than Nlmsgdone.
        """
        hdr_len = sizeof(Nlmsgdone)
        if len(data) < hdr_len:
            raise ValueError("length less than nlmsgdone header")
        return (Nlmsgdone.from_buffer_copy(data), hdr_len)

    def print_base_header(self, hdr, prepend=""):
        """Print the status code, prefixed by ``prepend``."""
        print(f"{prepend}error={hdr.error}")
165
166
167
class NetlinkErrorMessage(StdNetlinkMessage):
    """NLMSG_ERROR (error or ACK) message.

    The payload is an Nlmsgerr (error code plus the echoed header of the
    original message), optionally followed by the original payload and
    extended-ACK attributes (see ``nlerr_attrs``).
    """

    messages = [NlMsgProps(NlMsgType.NLMSG_ERROR, NlMsgCategory.ACK)]
    nl_attrs_map = nlerr_attrs

    # NLM_F_CAPPED: when set in nlmsg_flags, only the header of the original
    # message was echoed back, not its payload.
    _NLM_F_CAPPED = 0x100

    @property
    def error_code(self):
        """Error code from the Nlmsgerr payload (0 for a plain ACK)."""
        return self.base_hdr.error

    @property
    def error_str(self):
        """Text of the NLMSGERR_ATTR_MSG attribute, or None if absent."""
        nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_MSG)
        return nla.text if nla else None

    @property
    def error_offset(self):
        """Value of the NLMSGERR_ATTR_OFFS attribute, or None if absent."""
        nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_OFFS)
        return nla.u32 if nla else None

    @property
    def cookie(self):
        """The raw NLMSGERR_ATTR_COOKIE attribute object (whatever get_nla returns)."""
        return self.get_nla(NlErrattrType.NLMSGERR_ATTR_COOKIE)

    def parse_base_header(self, data):
        """Parse the Nlmsgerr payload; return ``(header, consumed_bytes)``.

        When the reply is not capped, the echoed original payload (4-byte
        aligned) is counted as part of the base header, so it is skipped.

        Raises:
            ValueError: if ``data`` is shorter than Nlmsgerr.
        """
        hdr_len = sizeof(Nlmsgerr)
        if len(data) < hdr_len:
            raise ValueError("length less than nlmsgerr header")
        err_hdr = Nlmsgerr.from_buffer_copy(data)
        sz = hdr_len
        if (self.nl_hdr.nlmsg_flags & self._NLM_F_CAPPED) == 0:
            # Clamp so a malformed echoed nlmsg_len (< header size) cannot
            # shrink the consumed size below the fixed header.
            payload_len = max(err_hdr.msg.nlmsg_len - sizeof(Nlmsghdr), 0)
            sz += align4(payload_len)
        return (err_hdr, sz)

    def print_base_header(self, errhdr, prepend=""):
        """Print the error code and the echoed header on one line.

        ``prepend`` is emitted once, at the start of the line.
        """
        print("{}error={}, ".format(prepend, errhdr.error), end="")
        hdr = errhdr.msg
        print(
            "len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format(
                hdr.nlmsg_len,
                "msg#{}".format(hdr.nlmsg_type),
                self.helper.get_bitmask_str(NlmBaseFlags, hdr.nlmsg_flags),
                hdr.nlmsg_flags,
                hdr.nlmsg_seq,
                hdr.nlmsg_pid,
            )
        )
216
217
218
# Handlers for core netlink control messages (types below 16 in
# Nlsock.parse_message), registered under the "netlink_core" family name.
core_classes = {"netlink_core": [NetlinkDoneMessage, NetlinkErrorMessage]}
224
225
226
class Nlsock:
    """Raw netlink socket with a receive buffer and message-class dispatch.

    ``read_message`` reassembles one netlink message from the socket and
    returns it parsed by the handler class registered for its family and
    type, falling back to ``BaseNetlinkMessage``.
    """

    HANDLER_CLASSES = [core_classes, rt_classes, genl_classes]

    # setsockopt() level/option numbers (values from Linux <linux/netlink.h>;
    # this module uses the same numbers on FreeBSD and Linux).
    _SOL_NETLINK = 270
    _NETLINK_ADD_MEMBERSHIP = 1
    _NETLINK_CAP_ACK = 10
    _NETLINK_EXT_ACK = 11

    def __init__(self, family, helper):
        self.helper = helper
        self.sock_fd = self._setup_netlink(family)  # a socket.socket object
        self._sock_family = family
        self._data = b""  # received bytes not yet consumed by read_message()
        self.msgmap = self.build_msgmap()
        # Generic-netlink family id -> family name. Extended by
        # get_genl_family_id() and consulted by parse_message().
        self._family_map = {
            NlConst.GENL_ID_CTRL: "nlctrl",
        }

    def build_msgmap(self):
        """Return ``{family_name: {message_type (int): handler_class}}``.

        Built from HANDLER_CLASSES. On a family-name clash, later dicts
        replace earlier ones (dict.update semantics).
        """
        handler_classes = {}
        for d in self.HANDLER_CLASSES:
            handler_classes.update(d)
        xmap = {}
        for family_id, family_classes in handler_classes.items():
            xmap[family_id] = {}
            for cls in family_classes:
                for msg_props in cls.messages:
                    xmap[family_id][enum_or_int(msg_props.msg)] = cls
        return xmap

    def _setup_netlink(self, netlink_family) -> socket.socket:
        """Open a raw netlink socket for protocol ``netlink_family``.

        Enables NETLINK_CAP_ACK and NETLINK_EXT_ACK. The socket is closed if
        either option cannot be set, so it is not leaked.
        """
        family = self.helper.get_af_value("AF_NETLINK")
        s = socket.socket(family, socket.SOCK_RAW, netlink_family)
        try:
            s.setsockopt(self._SOL_NETLINK, self._NETLINK_CAP_ACK, 1)
            s.setsockopt(self._SOL_NETLINK, self._NETLINK_EXT_ACK, 1)
        except OSError:
            s.close()
            raise
        return s

    def set_groups(self, mask: int):
        """Set ``mask`` via setsockopt(SOL_SOCKET, 1, mask).

        NOTE(review): option 1 at SOL_SOCKET is SO_DEBUG on both Linux and
        FreeBSD, not a netlink multicast option. Confirm this is intended;
        join_group() uses NETLINK_ADD_MEMBERSHIP instead. A bind() with a
        sockaddr_nl carrying nl_groups=mask is the usual alternative.
        """
        self.sock_fd.setsockopt(socket.SOL_SOCKET, 1, mask)

    def join_group(self, group_id: int):
        """Subscribe to netlink multicast group ``group_id``."""
        self.sock_fd.setsockopt(
            self._SOL_NETLINK, self._NETLINK_ADD_MEMBERSHIP, group_id
        )

    def write_message(self, msg, verbose=True):
        """Serialize ``msg`` and write it to the socket (best effort).

        Write failures, including short writes, are printed, not raised.
        """
        if verbose:
            print("vvvvvvvv OUT vvvvvvvv")
            msg.print_message()
        msg_bytes = bytes(msg)
        try:
            ret = os.write(self.sock_fd.fileno(), msg_bytes)
            # Explicit check instead of `assert`, which is stripped under -O.
            if ret != len(msg_bytes):
                raise OSError(
                    "short write: {} of {} bytes".format(ret, len(msg_bytes))
                )
        except Exception as e:
            print("write({}) -> {}".format(len(msg_bytes), e))

    def parse_message(self, data: bytes):
        """Parse one complete raw netlink message into its handler class.

        Raises:
            ValueError: if ``data`` is too short for the required headers.
        """
        if len(data) < sizeof(Nlmsghdr):
            raise ValueError("Short read from nl: {} bytes".format(len(data)))
        hdr = Nlmsghdr.from_buffer_copy(data)
        if hdr.nlmsg_type < 16:
            # Core control messages (NLMSG_DONE, NLMSG_ERROR, ...).
            family_name = "netlink_core"
            nlmsg_type = hdr.nlmsg_type
        elif self._sock_family == NlConst.NETLINK_ROUTE:
            family_name = "netlink_route"
            nlmsg_type = hdr.nlmsg_type
        else:
            # Generic netlink: nlmsg_type is the family id; the command
            # lives in the GenlMsgHdr that follows.
            if len(data) < sizeof(Nlmsghdr) + sizeof(GenlMsgHdr):
                raise ValueError("Short read from genl: {} bytes".format(len(data)))
            family_name = self._family_map.get(hdr.nlmsg_type, "")
            ghdr = GenlMsgHdr.from_buffer_copy(data[sizeof(Nlmsghdr):])
            nlmsg_type = ghdr.cmd
        cls = self.msgmap.get(family_name, {}).get(nlmsg_type)
        if not cls:
            cls = BaseNetlinkMessage
        return cls.from_bytes(self.helper, data)

    def get_genl_family_id(self, family_name):
        """Resolve a generic-netlink family name to its numeric id.

        Sends CTRL_CMD_GETFAMILY, waits for the reply with the matching
        sequence number, caches the id in ``_family_map`` and returns it.

        Raises:
            ValueError: on an error reply, or if the reply has no
                CTRL_ATTR_FAMILY_ID attribute.
        """
        hdr = Nlmsghdr(
            nlmsg_type=NlConst.GENL_ID_CTRL,
            nlmsg_flags=NlmBaseFlags.NLM_F_REQUEST.value,
            nlmsg_seq=self.helper.get_seq(),
        )
        ghdr = GenlMsgHdr(cmd=GenlCtrlMsgType.CTRL_CMD_GETFAMILY.value)
        nla = NlAttrStr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, family_name)
        hdr.nlmsg_len = sizeof(Nlmsghdr) + sizeof(GenlMsgHdr) + len(bytes(nla))

        msg_bytes = bytes(hdr) + bytes(ghdr) + bytes(nla)
        self.write_data(msg_bytes)
        while True:
            rx_msg = self.read_message()
            if rx_msg.nl_hdr.nlmsg_seq != hdr.nlmsg_seq:
                continue  # not a reply to this request
            if rx_msg.is_type(NlMsgType.NLMSG_ERROR):
                if rx_msg.error_code != 0:
                    raise ValueError("unable to get family {}".format(family_name))
                continue  # a zero-code ACK carries no family id; keep reading
            id_nla = rx_msg.get_nla(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID)
            if not id_nla:
                raise ValueError("unable to get family {}".format(family_name))
            family_id = id_nla.u16
            self._family_map[family_id] = family_name
            return family_id

    def write_data(self, data: bytes):
        """Send ``data`` on the socket as-is."""
        self.sock_fd.send(data)

    def read_data(self):
        """recv() into the buffer until it holds at least one Nlmsghdr.

        Always performs at least one recv(); read_message() relies on this
        to make progress when waiting for the rest of a message.

        Raises:
            ConnectionError: if recv() returns no data. Without this check
                the loop would spin forever on a closed socket.
        """
        while True:
            data = self.sock_fd.recv(65535)
            if not data:
                raise ConnectionError("netlink socket returned no data")
            self._data += data
            if len(self._data) >= sizeof(Nlmsghdr):
                break

    def read_message(self):
        """Remove the next complete message from the buffer and return it parsed."""
        if len(self._data) < sizeof(Nlmsghdr):
            self.read_data()
        hdr = Nlmsghdr.from_buffer_copy(self._data)
        # nlmsg_len includes the header; wait until the whole message is buffered.
        while hdr.nlmsg_len > len(self._data):
            self.read_data()
        raw_msg = self._data[: hdr.nlmsg_len]
        self._data = self._data[hdr.nlmsg_len:]
        return self.parse_message(raw_msg)

    def get_reply(self, tx_msg):
        """Send ``tx_msg`` and return the first received message with its sequence number."""
        self.write_message(tx_msg)
        while True:
            rx_msg = self.read_message()
            if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq:
                return rx_msg
354
355
356
class NetlinkMultipartIterator:
    """Iterate over the parts of a multipart netlink reply.

    Messages are pulled from ``obj.read_message()``. Iteration stops cleanly
    on NLMSG_DONE with a zero status. ValueError is raised on:
    - a sequence-number mismatch,
    - an NLMSG_ERROR,
    - a non-zero NLMSG_DONE,
    - any message that is not of ``msg_type``.
    """

    def __init__(self, obj, seq_number: int, msg_type):
        self._obj = obj
        self._seq = seq_number
        self._msg_type = msg_type

    def __iter__(self):
        return self

    def __next__(self):
        msg = self._obj.read_message()

        if msg.nl_hdr.nlmsg_seq != self._seq:
            raise ValueError("bad sequence number")

        if msg.is_type(NlMsgType.NLMSG_ERROR):
            raise ValueError(
                "error while handling multipart msg: {}".format(msg.error_code)
            )

        if msg.is_type(NlMsgType.NLMSG_DONE):
            status = msg.error_code
            if status != 0:
                raise ValueError(
                    "error listing some parts of the multipart msg: {}".format(status)
                )
            raise StopIteration

        if msg.is_type(self._msg_type):
            return msg
        raise ValueError("bad message type: {}".format(msg))
384
385
386
class NetlinkTestTemplate:
    """Mixin for netlink tests: owns an NlHelper/Nlsock pair and traces I/O."""

    REQUIRED_MODULES = ["netlink"]

    def setup_netlink(self, netlink_family: NlConst):
        """Create the helper and a netlink socket for ``netlink_family``."""
        self.helper = NlHelper()
        self.nlsock = Nlsock(netlink_family, self.helper)

    def write_message(self, msg, silent=False):
        """Send ``msg``. Unless ``silent``, first print it and its raw bytes."""
        if silent:
            self.nlsock.write_data(bytes(msg))
            return
        print("\n============= >> TX MESSAGE =============")
        msg.print_message()
        msg.print_as_bytes(bytes(msg), "-- DATA --")
        self.nlsock.write_data(bytes(msg))

    def read_message(self, silent=False):
        """Read one parsed message. Unless ``silent``, print it."""
        msg = self.nlsock.read_message()
        if silent:
            return msg
        print("\n============= << RX MESSAGE =============")
        msg.print_message()
        return msg

    def get_reply(self, tx_msg):
        """Send ``tx_msg`` and return the first reply with the same sequence number."""
        self.write_message(tx_msg)
        wanted_seq = tx_msg.nl_hdr.nlmsg_seq
        while True:
            rx_msg = self.read_message()
            if rx_msg.nl_hdr.nlmsg_seq == wanted_seq:
                return rx_msg

    def read_msg_list(self, seq, msg_type):
        """Collect every part of a multipart reply with sequence ``seq``."""
        parts = NetlinkMultipartIterator(self, seq, msg_type)
        return list(parts)
418
419