Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
torvalds
GitHub Repository: torvalds/linux
Path: blob/master/tools/net/ynl/pyynl/lib/ynl.py
50619 views
1
# SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause
2
3
from collections import namedtuple
4
from enum import Enum
5
import functools
6
import os
7
import random
8
import socket
9
import struct
10
from struct import Struct
11
import sys
12
import ipaddress
13
import uuid
14
import queue
15
import selectors
16
import time
17
18
from .nlspec import SpecFamily
19
20
#
21
# Generic Netlink code which should really be in some library, but I can't quickly find one.
22
#
23
24
25
class Netlink:
    """Namespace for the numeric Netlink / Generic Netlink constants used below.

    The class is never instantiated by the code in this file; the members are
    read as plain class attributes (e.g. ``Netlink.NLM_F_ACK``).
    """

    # Netlink socket
    SOL_NETLINK = 270

    NETLINK_ADD_MEMBERSHIP = 1
    NETLINK_CAP_ACK = 10
    NETLINK_EXT_ACK = 11
    NETLINK_GET_STRICT_CHK = 12

    # Netlink message
    NLMSG_ERROR = 2
    NLMSG_DONE = 3

    NLM_F_REQUEST = 0x1
    NLM_F_ACK = 0x4
    NLM_F_ROOT = 0x100
    NLM_F_MATCH = 0x200

    # The next two groups reuse the 0x100 / 0x200 bit values of the group
    # above, so the meaning of a flag bit depends on which group applies.
    NLM_F_REPLACE = 0x100
    NLM_F_EXCL = 0x200
    NLM_F_CREATE = 0x400
    NLM_F_APPEND = 0x800

    NLM_F_CAPPED = 0x100
    NLM_F_ACK_TLVS = 0x200

    NLM_F_DUMP = NLM_F_ROOT | NLM_F_MATCH

    NLA_F_NESTED = 0x8000
    NLA_F_NET_BYTEORDER = 0x4000

    # Union of the two flag bits above; users clear it with "& ~NLA_TYPE_MASK"
    # to obtain the bare attribute type.
    NLA_TYPE_MASK = NLA_F_NESTED | NLA_F_NET_BYTEORDER

    # Genetlink defines
    NETLINK_GENERIC = 16

    GENL_ID_CTRL = 0x10

    # nlctrl
    CTRL_CMD_GETFAMILY = 3

    CTRL_ATTR_FAMILY_ID = 1
    CTRL_ATTR_FAMILY_NAME = 2
    CTRL_ATTR_MAXATTR = 5
    CTRL_ATTR_MCAST_GROUPS = 7

    CTRL_ATTR_MCAST_GRP_NAME = 1
    CTRL_ATTR_MCAST_GRP_ID = 2

    # Extack types
    NLMSGERR_ATTR_MSG = 1
    NLMSGERR_ATTR_OFFS = 2
    NLMSGERR_ATTR_COOKIE = 3
    NLMSGERR_ATTR_POLICY = 4
    NLMSGERR_ATTR_MISS_TYPE = 5
    NLMSGERR_ATTR_MISS_NEST = 6

    # Policy types
    NL_POLICY_TYPE_ATTR_TYPE = 1
    NL_POLICY_TYPE_ATTR_MIN_VALUE_S = 2
    NL_POLICY_TYPE_ATTR_MAX_VALUE_S = 3
    NL_POLICY_TYPE_ATTR_MIN_VALUE_U = 4
    NL_POLICY_TYPE_ATTR_MAX_VALUE_U = 5
    NL_POLICY_TYPE_ATTR_MIN_LENGTH = 6
    NL_POLICY_TYPE_ATTR_MAX_LENGTH = 7
    NL_POLICY_TYPE_ATTR_POLICY_IDX = 8
    NL_POLICY_TYPE_ATTR_POLICY_MAXTYPE = 9
    NL_POLICY_TYPE_ATTR_BITFIELD32_MASK = 10
    NL_POLICY_TYPE_ATTR_PAD = 11
    NL_POLICY_TYPE_ATTR_MASK = 12

    # Functional Enum API: members are numbered from 1 in the listed order.
    # Some names contain '-', so they are only reachable by value or by
    # item lookup, not as attributes.
    AttrType = Enum('AttrType', [
        'flag',
        'u8', 'u16', 'u32', 'u64',
        's8', 's16', 's32', 's64',
        'binary', 'string', 'nul-string',
        'nested', 'nested-array',
        'bitfield32', 'sint', 'uint',
    ])
101
102
class NlError(Exception):
    """Exception built from a Netlink message that reported an error.

    Attributes:
        nl_msg: the message object given to the constructor; it must expose
            ``error`` and ``extack``.
        error: ``nl_msg.error`` with the sign flipped; it is handed to
            ``os.strerror()`` when the exception is rendered.
    """

    def __init__(self, nl_msg):
        self.nl_msg = nl_msg
        self.error = -nl_msg.error

    def __str__(self):
        # Work on a copy so rendering never alters the message's own extack.
        details = dict(self.nl_msg.extack) if self.nl_msg.extack else {}

        pieces = ["Netlink error: "]
        if 'msg' in details:
            pieces.append(details.pop('msg') + ': ')
        pieces.append(os.strerror(self.error))
        # Whatever extack entries remain are appended in their dict form.
        if details:
            pieces.append(' ' + str(details))
        return ''.join(pieces)
118
119
120
class ConfigError(Exception):
    """Exception type raised for invalid configuration values."""
122
123
124
class NlAttr:
    """One attribute parsed out of a byte buffer.

    The first four bytes at ``offset`` are read as two native-endian 16 bit
    values: the attribute length (header included) and the type word.

    Attributes:
        type: type word with the NLA_F_NESTED / NLA_F_NET_BYTEORDER bits cleared.
        is_nest: non-zero when the NLA_F_NESTED bit is set in the type word.
        payload_len: the length field as found in the header.
        full_len: ``payload_len`` rounded up to a multiple of 4; the distance
            to the next attribute.
        raw: the payload bytes that follow the header.
    """

    ScalarFormat = namedtuple('ScalarFormat', ['native', 'big', 'little'])
    # Pre-compiled struct formats for every fixed-size scalar type, in
    # native, big-endian and little-endian flavours.
    type_formats = {
        'u8' : ScalarFormat(Struct('B'), Struct("B"), Struct("B")),
        's8' : ScalarFormat(Struct('b'), Struct("b"), Struct("b")),
        'u16': ScalarFormat(Struct('H'), Struct(">H"), Struct("<H")),
        's16': ScalarFormat(Struct('h'), Struct(">h"), Struct("<h")),
        'u32': ScalarFormat(Struct('I'), Struct(">I"), Struct("<I")),
        's32': ScalarFormat(Struct('i'), Struct(">i"), Struct("<i")),
        'u64': ScalarFormat(Struct('Q'), Struct(">Q"), Struct("<Q")),
        's64': ScalarFormat(Struct('q'), Struct(">q"), Struct("<q"))
    }

    def __init__(self, raw, offset):
        """Parse the attribute starting at ``offset`` in ``raw``.

        Raises:
            Exception: if the length field is smaller than the 4 byte header.
            struct.error: if fewer than 4 bytes are available at ``offset``.
        """
        self._len, self._type = struct.unpack("HH", raw[offset : offset + 4])
        # The length covers the header itself, so anything below 4 is
        # malformed.  A length of 0 would also make full_len 0, and a
        # caller stepping by full_len would never advance.
        if self._len < 4:
            raise Exception(f"Malformed attribute at offset {offset}: "
                            f"length {self._len} is shorter than the header")
        self.type = self._type & ~Netlink.NLA_TYPE_MASK
        self.is_nest = self._type & Netlink.NLA_F_NESTED
        self.payload_len = self._len
        self.full_len = (self.payload_len + 3) & ~3
        self.raw = raw[offset + 4 : offset + self.payload_len]

    @classmethod
    def get_format(cls, attr_type, byte_order=None):
        """Return the ``Struct`` for ``attr_type``.

        ``byte_order`` of "big-endian" selects the big-endian format, any
        other truthy value the little-endian one, and a falsy value the
        native one.  Raises KeyError for an unknown ``attr_type``.
        """
        formats = cls.type_formats[attr_type]
        if byte_order:
            return formats.big if byte_order == "big-endian" \
                   else formats.little
        return formats.native

    def as_scalar(self, attr_type, byte_order=None):
        """Unpack the whole payload as one scalar of ``attr_type``."""
        fmt = self.get_format(attr_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_auto_scalar(self, attr_type, byte_order=None):
        """Unpack a variable-width scalar, 32 or 64 bit wide.

        The width comes from the payload size and the signedness from the
        first letter of ``attr_type``.  Raises Exception for any other size.
        """
        if len(self.raw) != 4 and len(self.raw) != 8:
            raise Exception(f"Auto-scalar len payload be 4 or 8 bytes, got {len(self.raw)}")
        real_type = attr_type[0] + str(len(self.raw) * 8)
        fmt = self.get_format(real_type, byte_order)
        return fmt.unpack(self.raw)[0]

    def as_strz(self):
        """Decode the payload as ASCII and drop its last character."""
        return self.raw.decode('ascii')[:-1]

    def as_bin(self):
        """Return the payload bytes untouched."""
        return self.raw

    def as_c_array(self, type):
        """Unpack the payload as consecutive native-endian scalars of ``type``."""
        fmt = self.get_format(type)
        return [ x[0] for x in fmt.iter_unpack(self.raw) ]

    def __repr__(self):
        return f"[type:{self.type} len:{self._len}] {self.raw}"
176
177
178
class NlAttrs:
    """Sequence of ``NlAttr`` objects parsed back to back from a buffer.

    Parsing starts at ``offset`` and continues, stepping by each attribute's
    ``full_len``, until the end of ``msg`` is reached.
    """

    def __init__(self, msg, offset=0):
        self.attrs = []

        end = len(msg)
        pos = offset
        while pos < end:
            parsed = NlAttr(msg, pos)
            pos += parsed.full_len
            self.attrs.append(parsed)

    def __iter__(self):
        return iter(self.attrs)

    def __repr__(self):
        # One attribute per line, no trailing newline.
        return '\n'.join(repr(a) for a in self.attrs)
197
198
199
class NlMsg:
    """One Netlink message parsed from a receive (or request) buffer.

    The 16 byte header at ``offset`` is unpacked into ``nl_len``, ``nl_type``,
    ``nl_flags``, ``nl_seq`` and ``nl_portid``; ``raw`` holds the bytes after
    the header, up to ``nl_len``.  For NLMSG_ERROR and NLMSG_DONE messages the
    leading 32 bit error code is stored in ``error`` and ``done`` is set to 1.
    When the NLM_F_ACK_TLVS flag is set on such a message, the trailing
    attributes are decoded into the ``extack`` dict; otherwise ``extack``
    stays None.
    """

    def __init__(self, msg, offset, attr_space=None):
        self.hdr = msg[offset : offset + 16]

        (self.nl_len, self.nl_type, self.nl_flags,
         self.nl_seq, self.nl_portid) = struct.unpack("IHHII", self.hdr)

        self.raw = msg[offset + 16 : offset + self.nl_len]

        self.error = 0
        self.done = 0
        self.extack = None

        # Where the extack attributes start inside self.raw: 20 bytes in for
        # error messages, 4 bytes in for done messages.
        extack_off = None
        if self.nl_type in (Netlink.NLMSG_ERROR, Netlink.NLMSG_DONE):
            self.error = struct.unpack("i", self.raw[0:4])[0]
            self.done = 1
            extack_off = 20 if self.nl_type == Netlink.NLMSG_ERROR else 4

        if not (self.nl_flags & Netlink.NLM_F_ACK_TLVS and extack_off):
            return

        # Extack attributes that are plain u32 values, keyed by attribute type.
        u32_keys = {
            Netlink.NLMSGERR_ATTR_MISS_TYPE: 'miss-type',
            Netlink.NLMSGERR_ATTR_MISS_NEST: 'miss-nest',
            Netlink.NLMSGERR_ATTR_OFFS: 'bad-attr-offs',
        }

        self.extack = dict()
        for ext in NlAttrs(self.raw[extack_off:]):
            if ext.type == Netlink.NLMSGERR_ATTR_MSG:
                self.extack['msg'] = ext.as_strz()
            elif ext.type in u32_keys:
                self.extack[u32_keys[ext.type]] = ext.as_scalar('u32')
            elif ext.type == Netlink.NLMSGERR_ATTR_POLICY:
                self.extack['policy'] = self._decode_policy(ext.raw)
            else:
                # Keep the undecoded attribute objects for inspection.
                self.extack.setdefault('unknown', []).append(ext)

        if attr_space:
            self.annotate_extack(attr_space)

    def _decode_policy(self, raw):
        """Decode a nested policy description into a plain dict."""
        # attribute type -> (result key, scalar format)
        scalar_fields = {
            Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_S: ('min-value', 's64'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_S: ('max-value', 's64'),
            Netlink.NL_POLICY_TYPE_ATTR_MIN_VALUE_U: ('min-value', 'u64'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_VALUE_U: ('max-value', 'u64'),
            Netlink.NL_POLICY_TYPE_ATTR_MIN_LENGTH: ('min-length', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_MAX_LENGTH: ('max-length', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_BITFIELD32_MASK: ('bitfield32-mask', 'u32'),
            Netlink.NL_POLICY_TYPE_ATTR_MASK: ('mask', 'u64'),
        }

        policy = {}
        for attr in NlAttrs(raw):
            if attr.type == Netlink.NL_POLICY_TYPE_ATTR_TYPE:
                type_id = attr.as_scalar('u32')
                policy['type'] = Netlink.AttrType(type_id).name
            elif attr.type in scalar_fields:
                key, scalar_type = scalar_fields[attr.type]
                policy[key] = attr.as_scalar(scalar_type)
            # Other policy attributes are ignored.
        return policy

    def annotate_extack(self, attr_space):
        """ Make extack more human friendly with attribute information """

        # We don't have the ability to parse nests yet, so only do global
        if 'miss-type' not in self.extack or 'miss-nest' in self.extack:
            return

        miss_type = self.extack['miss-type']
        if miss_type not in attr_space.attrs_by_val:
            return

        # Replace the numeric type with the attribute name from the spec.
        spec = attr_space.attrs_by_val[miss_type]
        self.extack['miss-type'] = spec['name']
        if 'doc' in spec:
            self.extack['miss-type-doc'] = spec['doc']

    def cmd(self):
        """Return the message type from the Netlink header."""
        return self.nl_type

    def __repr__(self):
        lines = [f"nl_len = {self.nl_len} ({len(self.raw)}) nl_flags = 0x{self.nl_flags:x} nl_type = {self.nl_type}"]
        if self.error:
            lines.append('\terror: ' + str(self.error))
        if self.extack:
            lines.append('\textack: ' + repr(self.extack))
        return '\n'.join(lines)
290
291
292
class NlMsgs:
    """All Netlink messages contained in one buffer.

    Messages are parsed back to back, each one starting ``nl_len`` bytes
    after the previous one, until the end of ``data`` is reached.
    """

    def __init__(self, data):
        self.msgs = []

        total = len(data)
        pos = 0
        while pos < total:
            parsed = NlMsg(data, pos)
            pos += parsed.nl_len
            self.msgs.append(parsed)

    def __iter__(self):
        return iter(self.msgs)
304
305
306
# Module-wide cache of Generic Netlink families, filled in by
# _genl_load_families(): maps a family name to the dict of values parsed for
# that family ('id', 'name' and, when present, 'maxattr' and 'mcast').
# Stays None until the first load.
genl_family_name_to_id = None
307
308
309
def _genl_msg(nl_type, nl_flags, genl_cmd, genl_version, seq=None):
310
# we prepend length in _genl_msg_finalize()
311
if seq is None:
312
seq = random.randint(1, 1024)
313
nlmsg = struct.pack("HHII", nl_type, nl_flags, seq, 0)
314
genlmsg = struct.pack("BBH", genl_cmd, genl_version, 0)
315
return nlmsg + genlmsg
316
317
318
def _genl_msg_finalize(msg):
319
return struct.pack("I", len(msg) + 4) + msg
320
321
322
def _genl_load_families():
    """Dump the Generic Netlink families and cache them by name.

    Sends a CTRL_CMD_GETFAMILY dump request on a fresh NETLINK_GENERIC socket
    and stores one dict per family in the module-level
    ``genl_family_name_to_id``.  Reading stops at the first message that is
    marked done; an error message is printed and also ends the load, leaving
    whatever was collected so far.
    """
    global genl_family_name_to_id

    with socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, Netlink.NETLINK_GENERIC) as sock:
        sock.setsockopt(Netlink.SOL_NETLINK, Netlink.NETLINK_CAP_ACK, 1)

        flags = Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK | Netlink.NLM_F_DUMP
        request = _genl_msg_finalize(
            _genl_msg(Netlink.GENL_ID_CTRL, flags, Netlink.CTRL_CMD_GETFAMILY, 1))

        sock.send(request, 0)

        # The cache is only (re)created once the request has been sent.
        genl_family_name_to_id = dict()

        while True:
            reply = sock.recv(128 * 1024)
            for nl_msg in NlMsgs(reply):
                if nl_msg.error:
                    print("Netlink error:", nl_msg.error)
                    return
                if nl_msg.done:
                    return

                family = dict()
                for attr in NlAttrs(GenlMsg(nl_msg).raw):
                    if attr.type == Netlink.CTRL_ATTR_FAMILY_ID:
                        family['id'] = attr.as_scalar('u16')
                    elif attr.type == Netlink.CTRL_ATTR_FAMILY_NAME:
                        family['name'] = attr.as_strz()
                    elif attr.type == Netlink.CTRL_ATTR_MAXATTR:
                        family['maxattr'] = attr.as_scalar('u32')
                    elif attr.type == Netlink.CTRL_ATTR_MCAST_GROUPS:
                        groups = family['mcast'] = dict()
                        for entry in NlAttrs(attr.raw):
                            group_name = None
                            group_id = None
                            for field in NlAttrs(entry.raw):
                                if field.type == Netlink.CTRL_ATTR_MCAST_GRP_NAME:
                                    group_name = field.as_strz()
                                elif field.type == Netlink.CTRL_ATTR_MCAST_GRP_ID:
                                    group_id = field.as_scalar('u32')
                            # Only complete (name, id) pairs are recorded.
                            if group_name and group_id is not None:
                                groups[group_name] = group_id

                # Families missing a name or an id are dropped.
                if 'name' in family and 'id' in family:
                    genl_family_name_to_id[family['name']] = family
369
370
371
class GenlMsg:
    """Generic Netlink view of an already parsed Netlink message.

    The first four payload bytes of ``nl_msg`` are read as command, version
    and a 16 bit field that is discarded; ``raw`` is the payload after those
    four bytes.

    Attributes:
        nl: the wrapped message object.
        genl_cmd: command byte from the Generic Netlink header.
        genl_version: version byte from the Generic Netlink header.
        raw: payload bytes following the Generic Netlink header.
    """

    def __init__(self, nl_msg):
        self.nl = nl_msg
        self.genl_cmd, self.genl_version, _ = struct.unpack_from("BBH", nl_msg.raw, 0)
        self.raw = nl_msg.raw[4:]

    def cmd(self):
        """Return the Generic Netlink command."""
        return self.genl_cmd

    def __repr__(self):
        msg = repr(self.nl)
        msg += f"\tgenl_cmd = {self.genl_cmd} genl_ver = {self.genl_version}\n"
        # raw_attrs is attached later by the protocol decode step, so it may
        # not exist yet; repr() must not raise in that case.
        for a in getattr(self, 'raw_attrs', None) or ():
            msg += '\t\t' + repr(a) + '\n'
        return msg
386
387
388
class NetlinkProtocol:
    """Message framing for a raw Netlink family.

    Attributes:
        family_name: name given to the constructor.
        proto_num: protocol number given to the constructor.
    """

    def __init__(self, family_name, proto_num):
        self.family_name = family_name
        self.proto_num = proto_num

    def _message(self, nl_type, nl_flags, seq=None):
        """Pack type, flags, sequence and a zero port id (12 bytes, no length).

        A random sequence number between 1 and 1024 is used when ``seq`` is None.
        """
        sequence = random.randint(1, 1024) if seq is None else seq
        return struct.pack("HHII", nl_type, nl_flags, sequence, 0)

    def message(self, flags, command, version, seq=None):
        """Build the message header; ``command`` is used as the message type.

        ``version`` is accepted but not used here.
        """
        return self._message(command, flags, seq)

    def _decode(self, nl_msg):
        """Return the message object to decode; here the input itself."""
        return nl_msg

    def decode(self, ynl, nl_msg, op):
        """Attach the parsed attributes to the message as ``raw_attrs``.

        When ``op`` is None it is looked up in ``ynl.rsp_by_value`` by the
        message's command.  Attribute parsing starts after the operation's
        fixed header, whose size is obtained from ``ynl._struct_size()``.
        """
        msg = self._decode(nl_msg)
        operation = ynl.rsp_by_value[msg.cmd()] if op is None else op
        attrs_start = ynl._struct_size(operation.fixed_header)
        msg.raw_attrs = NlAttrs(msg.raw, attrs_start)
        return msg

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the ``value`` of the named group from ``mcast_groups``.

        Raises Exception when the group is not in ``mcast_groups``.
        """
        if mcast_name in mcast_groups:
            return mcast_groups[mcast_name].value
        raise Exception(f'Multicast group "{mcast_name}" not present in the spec')

    def msghdr_size(self):
        """Size in bytes of the message header."""
        return 16
420
421
422
class GenlProtocol(NetlinkProtocol):
    """Message framing for a Generic Netlink family.

    Attributes:
        genl_family: the cached dict describing this family.
        family_id: the family's ``'id'`` entry, used as the message type.
    """

    def __init__(self, family_name):
        """Look the family up in the module-level family cache.

        The cache is loaded on first use.  Raises KeyError when the family
        is not in the cache.
        """
        super().__init__(family_name, Netlink.NETLINK_GENERIC)

        global genl_family_name_to_id
        if genl_family_name_to_id is None:
            _genl_load_families()

        family = genl_family_name_to_id[family_name]
        self.genl_family = family
        self.family_id = family['id']

    def message(self, flags, command, version, seq=None):
        """Build the Netlink header followed by the 4 byte genetlink header."""
        header = self._message(self.family_id, flags, seq)
        return header + struct.pack("BBH", command, version, 0)

    def _decode(self, nl_msg):
        """Wrap the message so that the genetlink header is parsed."""
        return GenlMsg(nl_msg)

    def get_mcast_id(self, mcast_name, mcast_groups):
        """Return the id of the named group from the cached family data.

        ``mcast_groups`` is not consulted.  Raises Exception when the group
        is unknown to the family.
        """
        known_groups = self.genl_family['mcast']
        if mcast_name not in known_groups:
            raise Exception(f'Multicast group "{mcast_name}" not present in the family')
        return known_groups[mcast_name]

    def msghdr_size(self):
        """Header size of the base protocol plus the 4 byte genetlink header."""
        return super().msghdr_size() + 4
448
449
450
class SpaceAttrs:
    """Chain of (attribute-set spec, values) scopes, innermost first.

    ``lookup()`` walks the chain and answers from the first scope whose spec
    contains the requested name.
    """

    SpecValuesPair = namedtuple('SpecValuesPair', ['spec', 'values'])

    def __init__(self, attr_space, attrs, outer = None):
        own_scope = self.SpecValuesPair(attr_space, attrs)
        inherited = outer.scopes if outer else []
        # A new list is built, so the outer object's scopes are not modified.
        self.scopes = [own_scope, *inherited]

    def lookup(self, name):
        """Return the value of ``name`` from the nearest scope defining it.

        Raises:
            Exception: if the first scope whose spec contains ``name`` has no
                value for it, or if no scope's spec contains ``name``.
        """
        for spec, values in self.scopes:
            if name not in spec:
                continue
            if name in values:
                return values[name]
            # The search stops at the first spec that knows the name, even
            # if an outer scope might hold a value.
            spec_name = spec.yaml['name']
            raise Exception(
                f"No value for '{name}' in attribute space '{spec_name}'")
        raise Exception(f"Attribute '{name}' not defined in any attribute-set")
467
468
469
#
470
# YNL implementation details.
471
#
472
473
474
class YnlFamily(SpecFamily):
475
def __init__(self, def_path, schema=None, process_unknown=False,
             recv_size=0):
    """Load the family spec and open a Netlink socket for it.

    Args:
        def_path: passed through to the base class constructor.
        schema: passed through to the base class constructor.
        process_unknown: stored as ``self.process_unknown``; consulted by
            the decoding helpers when they meet something not in the spec.
        recv_size: receive buffer size; 0 selects 131072.

    Raises:
        Exception: if setting up the protocol object raises KeyError.
        ConfigError: if the receive size is below 4000.
    """
    super().__init__(def_path, schema)

    self.include_raw = False
    self.process_unknown = process_unknown

    try:
        if self.proto == "netlink-raw":
            self.nlproto = NetlinkProtocol(self.yaml['name'],
                                           self.yaml['protonum'])
        else:
            self.nlproto = GenlProtocol(self.yaml['name'])
    except KeyError:
        raise Exception(f"Family '{self.yaml['name']}' not supported by the kernel")

    self._recv_dbg = False
    # Note that netlink will use conservative (min) message size for
    # the first dump recv() on the socket, our setting will only matter
    # from the second recv() on.
    self._recv_size = recv_size or 131072
    # Netlink will always allocate at least PAGE_SIZE - sizeof(skb_shinfo)
    # for a message, so smaller receive sizes will lead to truncation.
    # Note that the min size for other families may be larger than 4k!
    if self._recv_size < 4000:
        raise ConfigError()

    self.sock = socket.socket(socket.AF_NETLINK, socket.SOCK_RAW, self.nlproto.proto_num)
    for option in (Netlink.NETLINK_CAP_ACK,
                   Netlink.NETLINK_EXT_ACK,
                   Netlink.NETLINK_GET_STRICT_CHK):
        self.sock.setsockopt(Netlink.SOL_NETLINK, option, 1)

    self.async_msg_ids = set()
    self.async_msg_queue = queue.Queue()

    # Remember which response values belong to asynchronous messages.
    for msg in self.msgs.values():
        if msg.is_async:
            self.async_msg_ids.add(msg.rsp_value)

    # Expose every operation as a method named after its ident_name.
    for op_name, op in self.ops.items():
        setattr(self, op.ident_name, functools.partial(self._op, op_name))
517
518
519
def ntf_subscribe(self, mcast_name):
    """Join the named multicast group on this object's socket.

    The group id is resolved through the protocol object; the socket is
    bound to ``(0, 0)`` before the membership option is set.
    """
    group_id = self.nlproto.get_mcast_id(mcast_name, self.mcast_groups)
    self.sock.bind((0, 0))
    self.sock.setsockopt(Netlink.SOL_NETLINK,
                         Netlink.NETLINK_ADD_MEMBERSHIP,
                         group_id)
524
525
def set_recv_dbg(self, enabled):
    """Store ``enabled`` as the receive-debug switch (``self._recv_dbg``)."""
    self._recv_dbg = enabled
527
528
def _recv_dbg_print(self, reply, nl_msgs):
529
if not self._recv_dbg:
530
return
531
print("Recv: read", len(reply), "bytes,",
532
len(nl_msgs.msgs), "messages", file=sys.stderr)
533
for nl_msg in nl_msgs:
534
print(" ", nl_msg, file=sys.stderr)
535
536
def _encode_enum(self, attr_spec, value):
537
enum = self.consts[attr_spec['enum']]
538
if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
539
scalar = 0
540
if isinstance(value, str):
541
value = [value]
542
for single_value in value:
543
scalar += enum.entries[single_value].user_value(as_flags = True)
544
return scalar
545
else:
546
return enum.entries[value].user_value()
547
548
def _get_scalar(self, attr_spec, value):
549
try:
550
return int(value)
551
except (ValueError, TypeError) as e:
552
if 'enum' in attr_spec:
553
return self._encode_enum(attr_spec, value)
554
if attr_spec.display_hint:
555
return self._from_string(value, attr_spec)
556
raise e
557
558
def _add_attr(self, space, name, value, search_attrs):
559
try:
560
attr = self.attr_sets[space][name]
561
except KeyError:
562
raise Exception(f"Space '{space}' has no attribute '{name}'")
563
nl_type = attr.value
564
565
if attr.is_multi and isinstance(value, list):
566
attr_payload = b''
567
for subvalue in value:
568
attr_payload += self._add_attr(space, name, subvalue, search_attrs)
569
return attr_payload
570
571
if attr["type"] == 'nest':
572
nl_type |= Netlink.NLA_F_NESTED
573
sub_space = attr['nested-attributes']
574
attr_payload = self._add_nest_attrs(value, sub_space, search_attrs)
575
elif attr['type'] == 'indexed-array' and attr['sub-type'] == 'nest':
576
nl_type |= Netlink.NLA_F_NESTED
577
sub_space = attr['nested-attributes']
578
attr_payload = self._encode_indexed_array(value, sub_space,
579
search_attrs)
580
elif attr["type"] == 'flag':
581
if not value:
582
# If value is absent or false then skip attribute creation.
583
return b''
584
attr_payload = b''
585
elif attr["type"] == 'string':
586
attr_payload = str(value).encode('ascii') + b'\x00'
587
elif attr["type"] == 'binary':
588
if value is None:
589
attr_payload = b''
590
elif isinstance(value, bytes):
591
attr_payload = value
592
elif isinstance(value, str):
593
if attr.display_hint:
594
attr_payload = self._from_string(value, attr)
595
else:
596
attr_payload = bytes.fromhex(value)
597
elif isinstance(value, dict) and attr.struct_name:
598
attr_payload = self._encode_struct(attr.struct_name, value)
599
elif isinstance(value, list) and attr.sub_type in NlAttr.type_formats:
600
format = NlAttr.get_format(attr.sub_type)
601
attr_payload = b''.join([format.pack(x) for x in value])
602
else:
603
raise Exception(f'Unknown type for binary attribute, value: {value}')
604
elif attr['type'] in NlAttr.type_formats or attr.is_auto_scalar:
605
scalar = self._get_scalar(attr, value)
606
if attr.is_auto_scalar:
607
attr_type = attr["type"][0] + ('32' if scalar.bit_length() <= 32 else '64')
608
else:
609
attr_type = attr["type"]
610
format = NlAttr.get_format(attr_type, attr.byte_order)
611
attr_payload = format.pack(scalar)
612
elif attr['type'] in "bitfield32":
613
scalar_value = self._get_scalar(attr, value["value"])
614
scalar_selector = self._get_scalar(attr, value["selector"])
615
attr_payload = struct.pack("II", scalar_value, scalar_selector)
616
elif attr['type'] == 'sub-message':
617
msg_format, _ = self._resolve_selector(attr, search_attrs)
618
attr_payload = b''
619
if msg_format.fixed_header:
620
attr_payload += self._encode_struct(msg_format.fixed_header, value)
621
if msg_format.attr_set:
622
if msg_format.attr_set in self.attr_sets:
623
nl_type |= Netlink.NLA_F_NESTED
624
sub_attrs = SpaceAttrs(msg_format.attr_set, value, search_attrs)
625
for subname, subvalue in value.items():
626
attr_payload += self._add_attr(msg_format.attr_set,
627
subname, subvalue, sub_attrs)
628
else:
629
raise Exception(f"Unknown attribute-set '{msg_format.attr_set}'")
630
else:
631
raise Exception(f'Unknown type at {space} {name} {value} {attr["type"]}')
632
633
return self._add_attr_raw(nl_type, attr_payload)
634
635
def _add_attr_raw(self, nl_type, attr_payload):
636
pad = b'\x00' * ((4 - len(attr_payload) % 4) % 4)
637
return struct.pack('HH', len(attr_payload) + 4, nl_type) + attr_payload + pad
638
639
def _add_nest_attrs(self, value, sub_space, search_attrs):
    """Encode every item of ``value`` as an attribute of ``sub_space``.

    A new lookup scope for ``sub_space`` is pushed in front of
    ``search_attrs`` and used while encoding the members.  Returns the
    concatenated encoded attributes, without an enclosing header.
    """
    nested_scope = SpaceAttrs(self.attr_sets[sub_space], value, search_attrs)
    encoded = b''
    for member_name, member_value in value.items():
        encoded += self._add_attr(sub_space, member_name, member_value,
                                  nested_scope)
    return encoded
646
647
def _encode_indexed_array(self, vals, sub_space, search_attrs):
648
attr_payload = b''
649
for i, val in enumerate(vals):
650
idx = i | Netlink.NLA_F_NESTED
651
val_payload = self._add_nest_attrs(val, sub_space, search_attrs)
652
attr_payload += self._add_attr_raw(idx, val_payload)
653
return attr_payload
654
655
def _get_enum_or_unknown(self, enum, raw):
656
try:
657
name = enum.entries_by_val[raw].name
658
except KeyError as error:
659
if self.process_unknown:
660
name = f"Unknown({raw})"
661
else:
662
raise error
663
return name
664
665
def _decode_enum(self, raw, attr_spec):
666
enum = self.consts[attr_spec['enum']]
667
if enum.type == 'flags' or attr_spec.get('enum-as-flags', False):
668
i = 0
669
value = set()
670
while raw:
671
if raw & 1:
672
value.add(self._get_enum_or_unknown(enum, i))
673
raw >>= 1
674
i += 1
675
else:
676
value = self._get_enum_or_unknown(enum, raw)
677
return value
678
679
def _decode_binary(self, attr, attr_spec):
680
if attr_spec.struct_name:
681
decoded = self._decode_struct(attr.raw, attr_spec.struct_name)
682
elif attr_spec.sub_type:
683
decoded = attr.as_c_array(attr_spec.sub_type)
684
if 'enum' in attr_spec:
685
decoded = [ self._decode_enum(x, attr_spec) for x in decoded ]
686
elif attr_spec.display_hint:
687
decoded = [ self._formatted_string(x, attr_spec.display_hint)
688
for x in decoded ]
689
else:
690
decoded = attr.as_bin()
691
if attr_spec.display_hint:
692
decoded = self._formatted_string(decoded, attr_spec.display_hint)
693
return decoded
694
695
def _decode_array_attr(self, attr, attr_spec):
696
decoded = []
697
offset = 0
698
while offset < len(attr.raw):
699
item = NlAttr(attr.raw, offset)
700
offset += item.full_len
701
702
if attr_spec["sub-type"] == 'nest':
703
subattrs = self._decode(NlAttrs(item.raw), attr_spec['nested-attributes'])
704
decoded.append({ item.type: subattrs })
705
elif attr_spec["sub-type"] == 'binary':
706
subattr = item.as_bin()
707
if attr_spec.display_hint:
708
subattr = self._formatted_string(subattr, attr_spec.display_hint)
709
decoded.append(subattr)
710
elif attr_spec["sub-type"] in NlAttr.type_formats:
711
subattr = item.as_scalar(attr_spec['sub-type'], attr_spec.byte_order)
712
if 'enum' in attr_spec:
713
subattr = self._decode_enum(subattr, attr_spec)
714
elif attr_spec.display_hint:
715
subattr = self._formatted_string(subattr, attr_spec.display_hint)
716
decoded.append(subattr)
717
else:
718
raise Exception(f'Unknown {attr_spec["sub-type"]} with name {attr_spec["name"]}')
719
return decoded
720
721
def _decode_nest_type_value(self, attr, attr_spec):
    """Decode a nest whose attribute types carry values.

    For every name listed under ``'type-value'`` one level of nesting is
    entered and the type of the attribute found there is stored under that
    name.  The payload of the innermost attribute is then decoded as the
    spec's nested attribute set and merged into the result.
    """
    result = {}
    current = attr
    for key in attr_spec['type-value']:
        current = NlAttr(current.raw, 0)
        result[key] = current.type
    result.update(self._decode(NlAttrs(current.raw),
                               attr_spec['nested-attributes']))
    return result
730
731
def _decode_unknown(self, attr):
732
if attr.is_nest:
733
return self._decode(NlAttrs(attr.raw), None)
734
else:
735
return attr.as_bin()
736
737
def _rsp_add(self, rsp, name, is_multi, decoded):
738
if is_multi is None:
739
if name in rsp and type(rsp[name]) is not list:
740
rsp[name] = [rsp[name]]
741
is_multi = True
742
else:
743
is_multi = False
744
745
if not is_multi:
746
rsp[name] = decoded
747
elif name in rsp:
748
rsp[name].append(decoded)
749
else:
750
rsp[name] = [decoded]
751
752
def _resolve_selector(self, attr_spec, search_attrs):
753
sub_msg = attr_spec.sub_message
754
if sub_msg not in self.sub_msgs:
755
raise Exception(f"No sub-message spec named {sub_msg} for {attr_spec.name}")
756
sub_msg_spec = self.sub_msgs[sub_msg]
757
758
selector = attr_spec.selector
759
value = search_attrs.lookup(selector)
760
if value not in sub_msg_spec.formats:
761
raise Exception(f"No message format for '{value}' in sub-message spec '{sub_msg}'")
762
763
spec = sub_msg_spec.formats[value]
764
return spec, value
765
766
def _decode_sub_msg(self, attr, attr_spec, search_attrs):
767
msg_format, _ = self._resolve_selector(attr_spec, search_attrs)
768
decoded = {}
769
offset = 0
770
if msg_format.fixed_header:
771
decoded.update(self._decode_struct(attr.raw, msg_format.fixed_header))
772
offset = self._struct_size(msg_format.fixed_header)
773
if msg_format.attr_set:
774
if msg_format.attr_set in self.attr_sets:
775
subdict = self._decode(NlAttrs(attr.raw, offset), msg_format.attr_set)
776
decoded.update(subdict)
777
else:
778
raise Exception(f"Unknown attribute-set '{msg_format.attr_set}' when decoding '{attr_spec.name}'")
779
return decoded
780
781
def _decode(self, attrs, space, outer_attrs = None):
782
rsp = dict()
783
if space:
784
attr_space = self.attr_sets[space]
785
search_attrs = SpaceAttrs(attr_space, rsp, outer_attrs)
786
787
for attr in attrs:
788
try:
789
attr_spec = attr_space.attrs_by_val[attr.type]
790
except (KeyError, UnboundLocalError):
791
if not self.process_unknown:
792
raise Exception(f"Space '{space}' has no attribute with value '{attr.type}'")
793
attr_name = f"UnknownAttr({attr.type})"
794
self._rsp_add(rsp, attr_name, None, self._decode_unknown(attr))
795
continue
796
797
try:
798
if attr_spec["type"] == 'nest':
799
subdict = self._decode(NlAttrs(attr.raw), attr_spec['nested-attributes'], search_attrs)
800
decoded = subdict
801
elif attr_spec["type"] == 'string':
802
decoded = attr.as_strz()
803
elif attr_spec["type"] == 'binary':
804
decoded = self._decode_binary(attr, attr_spec)
805
elif attr_spec["type"] == 'flag':
806
decoded = True
807
elif attr_spec.is_auto_scalar:
808
decoded = attr.as_auto_scalar(attr_spec['type'], attr_spec.byte_order)
809
if 'enum' in attr_spec:
810
decoded = self._decode_enum(decoded, attr_spec)
811
elif attr_spec["type"] in NlAttr.type_formats:
812
decoded = attr.as_scalar(attr_spec['type'], attr_spec.byte_order)
813
if 'enum' in attr_spec:
814
decoded = self._decode_enum(decoded, attr_spec)
815
elif attr_spec.display_hint:
816
decoded = self._formatted_string(decoded, attr_spec.display_hint)
817
elif attr_spec["type"] == 'indexed-array':
818
decoded = self._decode_array_attr(attr, attr_spec)
819
elif attr_spec["type"] == 'bitfield32':
820
value, selector = struct.unpack("II", attr.raw)
821
if 'enum' in attr_spec:
822
value = self._decode_enum(value, attr_spec)
823
selector = self._decode_enum(selector, attr_spec)
824
decoded = {"value": value, "selector": selector}
825
elif attr_spec["type"] == 'sub-message':
826
decoded = self._decode_sub_msg(attr, attr_spec, search_attrs)
827
elif attr_spec["type"] == 'nest-type-value':
828
decoded = self._decode_nest_type_value(attr, attr_spec)
829
else:
830
if not self.process_unknown:
831
raise Exception(f'Unknown {attr_spec["type"]} with name {attr_spec["name"]}')
832
decoded = self._decode_unknown(attr)
833
834
self._rsp_add(rsp, attr_spec["name"], attr_spec.is_multi, decoded)
835
except:
836
print(f"Error decoding '{attr_spec.name}' from '{space}'")
837
raise
838
839
return rsp
840
841
def _decode_extack_path(self, attrs, attr_set, offset, target, search_attrs):
842
for attr in attrs:
843
try:
844
attr_spec = attr_set.attrs_by_val[attr.type]
845
except KeyError:
846
raise Exception(f"Space '{attr_set.name}' has no attribute with value '{attr.type}'")
847
if offset > target:
848
break
849
if offset == target:
850
return '.' + attr_spec.name
851
852
if offset + attr.full_len <= target:
853
offset += attr.full_len
854
continue
855
856
pathname = attr_spec.name
857
if attr_spec['type'] == 'nest':
858
sub_attrs = self.attr_sets[attr_spec['nested-attributes']]
859
search_attrs = SpaceAttrs(sub_attrs, search_attrs.lookup(attr_spec['name']))
860
elif attr_spec['type'] == 'sub-message':
861
msg_format, value = self._resolve_selector(attr_spec, search_attrs)
862
if msg_format is None:
863
raise Exception(f"Can't resolve sub-message of {attr_spec['name']} for extack")
864
sub_attrs = self.attr_sets[msg_format.attr_set]
865
pathname += f"({value})"
866
else:
867
raise Exception(f"Can't dive into {attr.type} ({attr_spec['name']}) for extack")
868
offset += 4
869
subpath = self._decode_extack_path(NlAttrs(attr.raw), sub_attrs,
870
offset, target, search_attrs)
871
if subpath is None:
872
return None
873
return '.' + pathname + subpath
874
875
return None
876
877
def _decode_extack(self, request, op, extack, vals):
878
if 'bad-attr-offs' not in extack:
879
return
880
881
msg = self.nlproto.decode(self, NlMsg(request, 0, op.attr_set), op)
882
offset = self.nlproto.msghdr_size() + self._struct_size(op.fixed_header)
883
search_attrs = SpaceAttrs(op.attr_set, vals)
884
path = self._decode_extack_path(msg.raw_attrs, op.attr_set, offset,
885
extack['bad-attr-offs'], search_attrs)
886
if path:
887
del extack['bad-attr-offs']
888
extack['bad-attr'] = path
889
890
def _struct_size(self, name):
891
if name:
892
members = self.consts[name].members
893
size = 0
894
for m in members:
895
if m.type in ['pad', 'binary']:
896
if m.struct:
897
size += self._struct_size(m.struct)
898
else:
899
size += m.len
900
else:
901
format = NlAttr.get_format(m.type, m.byte_order)
902
size += format.size
903
return size
904
else:
905
return 0
906
907
def _decode_struct(self, data, name):
908
members = self.consts[name].members
909
attrs = dict()
910
offset = 0
911
for m in members:
912
value = None
913
if m.type == 'pad':
914
offset += m.len
915
elif m.type == 'binary':
916
if m.struct:
917
len = self._struct_size(m.struct)
918
value = self._decode_struct(data[offset : offset + len],
919
m.struct)
920
offset += len
921
else:
922
value = data[offset : offset + m.len]
923
offset += m.len
924
else:
925
format = NlAttr.get_format(m.type, m.byte_order)
926
[ value ] = format.unpack_from(data, offset)
927
offset += format.size
928
if value is not None:
929
if m.enum:
930
value = self._decode_enum(value, m)
931
elif m.display_hint:
932
value = self._formatted_string(value, m.display_hint)
933
attrs[m.name] = value
934
return attrs
935
936
def _encode_struct(self, name, vals):
937
members = self.consts[name].members
938
attr_payload = b''
939
for m in members:
940
value = vals.pop(m.name) if m.name in vals else None
941
if m.type == 'pad':
942
attr_payload += bytearray(m.len)
943
elif m.type == 'binary':
944
if m.struct:
945
if value is None:
946
value = dict()
947
attr_payload += self._encode_struct(m.struct, value)
948
else:
949
if value is None:
950
attr_payload += bytearray(m.len)
951
else:
952
attr_payload += bytes.fromhex(value)
953
else:
954
if value is None:
955
value = 0
956
format = NlAttr.get_format(m.type, m.byte_order)
957
attr_payload += format.pack(value)
958
return attr_payload
959
960
def _formatted_string(self, raw, display_hint):
961
if display_hint == 'mac':
962
formatted = ':'.join('%02x' % b for b in raw)
963
elif display_hint == 'hex':
964
if isinstance(raw, int):
965
formatted = hex(raw)
966
else:
967
formatted = bytes.hex(raw, ' ')
968
elif display_hint in [ 'ipv4', 'ipv6', 'ipv4-or-v6' ]:
969
formatted = format(ipaddress.ip_address(raw))
970
elif display_hint == 'uuid':
971
formatted = str(uuid.UUID(bytes=raw))
972
else:
973
formatted = raw
974
return formatted
975
976
def _from_string(self, string, attr_spec):
977
if attr_spec.display_hint in ['ipv4', 'ipv6', 'ipv4-or-v6']:
978
ip = ipaddress.ip_address(string)
979
if attr_spec['type'] == 'binary':
980
raw = ip.packed
981
else:
982
raw = int(ip)
983
elif attr_spec.display_hint == 'hex':
984
if attr_spec['type'] == 'binary':
985
raw = bytes.fromhex(string)
986
else:
987
raw = int(string, 16)
988
elif attr_spec.display_hint == 'mac':
989
# Parse MAC address in format "00:11:22:33:44:55" or "001122334455"
990
if ':' in string:
991
mac_bytes = [int(x, 16) for x in string.split(':')]
992
else:
993
if len(string) % 2 != 0:
994
raise Exception(f"Invalid MAC address format: {string}")
995
mac_bytes = [int(string[i:i+2], 16) for i in range(0, len(string), 2)]
996
raw = bytes(mac_bytes)
997
else:
998
raise Exception(f"Display hint '{attr_spec.display_hint}' not implemented"
999
f" when parsing '{attr_spec['name']}'")
1000
return raw
1001
1002
def handle_ntf(self, decoded):
    """
    Decode a notification message and put it on the async message queue.

    The queued dict has the operation name under 'name' and the decoded
    attributes (merged with the fixed header members, if the operation
    has a fixed header) under 'msg'. When self.include_raw is set the
    decoded message object itself is included under 'raw'.
    """
    ntf = {'raw': decoded} if self.include_raw else {}

    op = self.rsp_by_value[decoded.cmd()]
    payload = self._decode(decoded.raw_attrs, op.attr_set.name)
    if op.fixed_header:
        payload.update(self._decode_struct(decoded.raw, op.fixed_header))

    ntf.update(name=op['name'], msg=payload)
    self.async_msg_queue.put(ntf)
1014
1015
def check_ntf(self):
    """
    Drain the socket without blocking and queue pending notifications.

    Reads until recv() reports that nothing more is available. Messages
    whose command is in self.async_msg_ids are passed to handle_ntf();
    errors, 'done' messages and unknown commands are only reported with
    print() and otherwise skipped.
    """
    while True:
        try:
            reply = self.sock.recv(self._recv_size, socket.MSG_DONTWAIT)
        except BlockingIOError:
            # Nothing (more) to read right now.
            return

        msgs = NlMsgs(reply)
        self._recv_dbg_print(reply, msgs)
        for nl_msg in msgs:
            if nl_msg.error:
                print("Netlink error in ntf!?", os.strerror(-nl_msg.error))
                print(nl_msg)
            elif nl_msg.done:
                print("Netlink done while checking for ntf!?")
            else:
                decoded = self.nlproto.decode(self, nl_msg, None)
                if decoded.cmd() in self.async_msg_ids:
                    self.handle_ntf(decoded)
                else:
                    print("Unexpected msg id while checking for ntf", decoded)
1039
1040
def poll_ntf(self, duration=None):
    """
    Generator yielding notifications as they arrive.

    Queued notifications are yielded first. Once the queue is empty the
    socket is polled for readability and check_ntf() is called to refill
    the queue.

    duration is the maximum time to keep polling, in seconds; None
    (the default) means poll forever. The time limit is only checked
    when the queue is empty, so already queued notifications are always
    delivered.
    """
    # Use the monotonic clock so that wall clock adjustments can neither
    # cut the polling short nor extend it.
    deadline = None if duration is None else time.monotonic() + duration

    # The selector owns an OS level resource (e.g. an epoll fd), release
    # it as soon as the generator finishes or is closed.
    with selectors.DefaultSelector() as selector:
        selector.register(self.sock, selectors.EVENT_READ)

        while True:
            try:
                ntf = self.async_msg_queue.get_nowait()
            except queue.Empty:
                pass
            else:
                yield ntf
                continue

            if deadline is None:
                timeout = None
            else:
                timeout = deadline - time.monotonic()
                if timeout <= 0:
                    return

            if selector.select(timeout):
                self.check_ntf()
1058
1059
def operation_do_attributes(self, name):
    """
    Look up the operation called name and return a copy of the
    attributes accepted by its 'do' request.

    Returns None when there is no such operation.
    """
    op = self.find_operation(name)
    return op['do']['request']['attributes'].copy() if op else None
1069
1070
def _encode_message(self, op, vals, flags, req_seq):
    """
    Build the complete request message for op.

    The request always carries NLM_F_REQUEST | NLM_F_ACK, ORed with
    every entry of flags (which may be None). If the operation has a
    fixed header it is encoded first, which removes its members from
    vals; whatever is left in vals is then encoded as attributes.

    Returns the finalized message.
    """
    nl_flags = functools.reduce(lambda acc, flag: acc | flag, flags or [],
                                Netlink.NLM_F_REQUEST | Netlink.NLM_F_ACK)

    request = self.nlproto.message(nl_flags, op.req_value, 1, req_seq)
    if op.fixed_header:
        request += self._encode_struct(op.fixed_header, vals)

    # Must be created after the fixed header members were taken out of vals.
    search_attrs = SpaceAttrs(op.attr_set, vals)
    for attr_name, attr_val in vals.items():
        request += self._add_attr(op.attr_set.name, attr_name, attr_val,
                                  search_attrs)
    return _genl_msg_finalize(request)
1083
1084
def _ops(self, ops):
    """
    Send a batch of requests and collect the responses.

    ops is an iterable of (method, vals, flags) tuples. All requests
    are encoded with consecutive sequence numbers and written to the
    socket with a single send().

    Returns a list with one entry appended each time a message marked
    'done' is processed. An entry is the list of decoded replies if
    NLM_F_DUMP is in the request flags, None if there were no replies,
    the reply itself if there was exactly one, or the list of replies
    otherwise.

    Raises NlError as soon as a received message carries an error.
    Notifications received while waiting are passed to handle_ntf().
    """
    # Outstanding requests keyed by the sequence number they were sent with.
    reqs_by_seq = {}
    req_seq = random.randint(1024, 65535)
    payload = b''
    for (method, vals, flags) in ops:
        op = self.ops[method]
        msg = self._encode_message(op, vals, flags, req_seq)
        reqs_by_seq[req_seq] = (op, vals, msg, flags)
        payload += msg
        req_seq += 1

    self.sock.send(payload, 0)

    done = False
    rsp = []
    # Replies collected for the request currently being answered.
    op_rsp = []
    while not done:
        reply = self.sock.recv(self._recv_size)
        nms = NlMsgs(reply)
        self._recv_dbg_print(reply, nms)
        for nl_msg in nms:
            # Match the message to one of our requests by sequence number.
            if nl_msg.nl_seq in reqs_by_seq:
                (op, vals, req_msg, req_flags) = reqs_by_seq[nl_msg.nl_seq]
                if nl_msg.extack:
                    nl_msg.annotate_extack(op.attr_set)
                    self._decode_extack(req_msg, op, nl_msg.extack, vals)
            else:
                op = None
                req_flags = []

            if nl_msg.error:
                raise NlError(nl_msg)
            if nl_msg.done:
                if nl_msg.extack:
                    print("Netlink warning:")
                    print(nl_msg)

                # Pick the shape of the response entry for this request.
                if Netlink.NLM_F_DUMP in req_flags:
                    rsp.append(op_rsp)
                elif not op_rsp:
                    rsp.append(None)
                elif len(op_rsp) == 1:
                    rsp.append(op_rsp[0])
                else:
                    rsp.append(op_rsp)
                op_rsp = []

                # NOTE(review): this raises KeyError if a 'done' message
                # arrives with a sequence number we did not send.
                del reqs_by_seq[nl_msg.nl_seq]
                done = len(reqs_by_seq) == 0
                # The rest of this receive buffer is not processed.
                break

            decoded = self.nlproto.decode(self, nl_msg, op)

            # Check if this is a reply to our request
            if nl_msg.nl_seq not in reqs_by_seq or decoded.cmd() != op.rsp_value:
                if decoded.cmd() in self.async_msg_ids:
                    self.handle_ntf(decoded)
                    continue
                else:
                    print('Unexpected message: ' + repr(decoded))
                    continue

            rsp_msg = self._decode(decoded.raw_attrs, op.attr_set.name)
            if op.fixed_header:
                rsp_msg.update(self._decode_struct(decoded.raw, op.fixed_header))
            op_rsp.append(rsp_msg)

    return rsp
1152
1153
def _op(self, method, vals, flags=None, dump=False):
1154
req_flags = flags or []
1155
if dump:
1156
req_flags.append(Netlink.NLM_F_DUMP)
1157
1158
ops = [(method, vals, req_flags)]
1159
return self._ops(ops)[0]
1160
1161
def do(self, method, vals, flags=None):
    """
    Execute the operation called method as a single (non-dump) request
    and return its response.
    """
    return self._op(method, vals, flags=flags)
1163
1164
def dump(self, method, vals):
    """
    Execute the operation called method as a dump request and return
    its response.
    """
    return self._op(method, vals, flags=None, dump=True)
1166
1167
def do_multi(self, ops):
    """
    Execute several operations as one batch.

    ops is passed to _ops() unchanged (an iterable of
    (method, vals, flags) tuples); its list of responses is returned.
    """
    responses = self._ops(ops)
    return responses
1169
1170