Path: blob/main/tests/atf_python/sys/netlink/netlink.py
39553 views
#!/usr/local/bin/python3
import os
import socket
import sys
from ctypes import c_int
from ctypes import c_ubyte
from ctypes import c_uint
from ctypes import c_ushort
from ctypes import sizeof
from ctypes import Structure
from enum import auto
from enum import Enum

from atf_python.sys.netlink.attrs import NlAttr
from atf_python.sys.netlink.attrs import NlAttrStr
from atf_python.sys.netlink.attrs import NlAttrU32
from atf_python.sys.netlink.base_headers import GenlMsgHdr
from atf_python.sys.netlink.base_headers import NlmBaseFlags
from atf_python.sys.netlink.base_headers import Nlmsghdr
from atf_python.sys.netlink.base_headers import NlMsgType
from atf_python.sys.netlink.message import BaseNetlinkMessage
from atf_python.sys.netlink.message import NlMsgCategory
from atf_python.sys.netlink.message import NlMsgProps
from atf_python.sys.netlink.message import StdNetlinkMessage
from atf_python.sys.netlink.netlink_generic import GenlCtrlAttrType
from atf_python.sys.netlink.netlink_generic import GenlCtrlMsgType
from atf_python.sys.netlink.netlink_generic import handler_classes as genl_classes
from atf_python.sys.netlink.netlink_route import handler_classes as rt_classes
from atf_python.sys.netlink.utils import align4
from atf_python.sys.netlink.utils import AttrDescr
from atf_python.sys.netlink.utils import build_propmap
from atf_python.sys.netlink.utils import enum_or_int
from atf_python.sys.netlink.utils import get_bitmask_map
from atf_python.sys.netlink.utils import NlConst
from atf_python.sys.netlink.utils import prepare_attrs_map

# Netlink socket-option level and option names passed to setsockopt().
_SOL_NETLINK = 270
_NETLINK_ADD_MEMBERSHIP = 1
_NETLINK_CAP_ACK = 10
_NETLINK_EXT_ACK = 11
# Message types below this value are core netlink control messages
# (NLMSG_NOOP/ERROR/DONE/OVERRUN); family-specific types start here.
_NLMSG_MIN_TYPE = 16
# NLMSG_ERROR flag: the echoed request was capped to its header only.
_NLM_F_CAPPED = 0x100


class SockaddrNl(Structure):
    """ctypes mirror of the netlink socket address (``struct sockaddr_nl``)."""

    _fields_ = [
        ("nl_len", c_ubyte),
        ("nl_family", c_ubyte),
        ("nl_pad", c_ushort),
        ("nl_pid", c_uint),
        ("nl_groups", c_uint),
    ]


class Nlmsgdone(Structure):
    """Payload of an NLMSG_DONE message: a single error code."""

    _fields_ = [
        ("error", c_int),
    ]


class Nlmsgerr(Structure):
    """Payload of an NLMSG_ERROR message: error code plus the request header."""

    _fields_ = [
        ("error", c_int),
        ("msg", Nlmsghdr),
    ]


class NlErrattrType(Enum):
    """Extended-ACK attribute types carried after an NLMSG_ERROR payload."""

    NLMSGERR_ATTR_UNUSED = 0
    NLMSGERR_ATTR_MSG = auto()
    NLMSGERR_ATTR_OFFS = auto()
    NLMSGERR_ATTR_COOKIE = auto()
    NLMSGERR_ATTR_POLICY = auto()


class AddressFamilyLinux(Enum):
    AF_INET = socket.AF_INET
    AF_INET6 = socket.AF_INET6
    AF_NETLINK = 16


class AddressFamilyBsd(Enum):
    AF_INET = socket.AF_INET
    AF_INET6 = socket.AF_INET6
    AF_NETLINK = 38


class NlHelper:
    """Per-test helper: sequence numbers and enum value <-> name lookups."""

    def __init__(self):
        self._pmap = {}
        self._af_cls = self.get_af_cls()
        self._seq_counter = 1
        self.pid = os.getpid()

    def get_seq(self):
        """Return the next message sequence number (first call returns 1)."""
        ret = self._seq_counter
        self._seq_counter += 1
        return ret

    def get_af_cls(self):
        """Return the address-family enum matching the running platform."""
        if sys.platform.startswith("freebsd"):
            return AddressFamilyBsd
        return AddressFamilyLinux

    def get_propmap(self, cls):
        """Return the (cached) value -> name map built for ``cls``."""
        if cls not in self._pmap:
            self._pmap[cls] = build_propmap(cls)
        return self._pmap[cls]

    def get_name_propmap(self, cls):
        """Return a name -> value map for the public attributes of ``cls``."""
        ret = {}
        for prop in dir(cls):
            if not prop.startswith("_"):
                ret[prop] = getattr(cls, prop).value
        return ret

    def get_attr_byval(self, cls, attr_val):
        """Return the name in ``cls`` for ``attr_val``, or None if unknown."""
        propmap = self.get_propmap(cls)
        return propmap.get(attr_val)

    def get_af_name(self, family):
        """Return the address family name, or ``af#<n>`` when unknown."""
        v = self.get_attr_byval(self._af_cls, family)
        if v is not None:
            return v
        return "af#{}".format(family)

    def get_af_value(self, family_str: str) -> int:
        """Return the numeric value of a family name (None when unknown)."""
        propmap = self.get_name_propmap(self._af_cls)
        return propmap.get(family_str)

    def get_bitmask_str(self, cls, val):
        """Render the bits of ``val`` as comma-separated ``cls`` names."""
        bmap = get_bitmask_map(self.get_propmap(cls), val)
        return ",".join(bmap.values())

    @staticmethod
    def get_bitmask_str_uncached(cls, val):
        """Same as ``get_bitmask_str`` but without an NlHelper instance/cache."""
        # build_propmap/get_bitmask_map are module-level helpers imported from
        # utils; they are not attributes of NlHelper.
        pmap = build_propmap(cls)
        bmap = get_bitmask_map(pmap, val)
        return ",".join(bmap.values())


nldone_attrs = prepare_attrs_map([])

nlerr_attrs = prepare_attrs_map(
    [
        AttrDescr(NlErrattrType.NLMSGERR_ATTR_MSG, NlAttrStr),
        AttrDescr(NlErrattrType.NLMSGERR_ATTR_OFFS, NlAttrU32),
        AttrDescr(NlErrattrType.NLMSGERR_ATTR_COOKIE, NlAttr),
    ]
)


class NetlinkDoneMessage(StdNetlinkMessage):
    """NLMSG_DONE: terminates a multipart reply and carries an error code."""

    messages = [NlMsgProps(NlMsgType.NLMSG_DONE, NlMsgCategory.ACK)]
    nl_attrs_map = nldone_attrs

    @property
    def error_code(self):
        return self.base_hdr.error

    def parse_base_header(self, data):
        """Parse the Nlmsgdone payload; return ``(header, consumed_size)``."""
        if len(data) < sizeof(Nlmsgdone):
            raise ValueError("length less than nlmsgdone header")
        done_hdr = Nlmsgdone.from_buffer_copy(data)
        sz = sizeof(Nlmsgdone)
        return (done_hdr, sz)

    def print_base_header(self, hdr, prepend=""):
        print("{}error={}".format(prepend, hdr.error))


class NetlinkErrorMessage(StdNetlinkMessage):
    """NLMSG_ERROR: error/ACK reply, optionally with extended-ACK attributes."""

    messages = [NlMsgProps(NlMsgType.NLMSG_ERROR, NlMsgCategory.ACK)]
    nl_attrs_map = nlerr_attrs

    @property
    def error_code(self):
        return self.base_hdr.error

    @property
    def error_str(self):
        """Text of NLMSGERR_ATTR_MSG, or None if absent."""
        nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_MSG)
        if nla:
            return nla.text
        return None

    @property
    def error_offset(self):
        """Value of NLMSGERR_ATTR_OFFS, or None if absent."""
        nla = self.get_nla(NlErrattrType.NLMSGERR_ATTR_OFFS)
        if nla:
            return nla.u32
        return None

    @property
    def cookie(self):
        return self.get_nla(NlErrattrType.NLMSGERR_ATTR_COOKIE)

    def parse_base_header(self, data):
        """Parse the Nlmsgerr payload; return ``(header, consumed_size)``.

        When the capped flag is not set, the echoed request payload follows
        the header and is skipped (4-byte aligned) so the attributes that
        come after it are parsed correctly.
        """
        if len(data) < sizeof(Nlmsgerr):
            raise ValueError("length less than nlmsgerr header")
        err_hdr = Nlmsgerr.from_buffer_copy(data)
        sz = sizeof(Nlmsgerr)
        if (self.nl_hdr.nlmsg_flags & _NLM_F_CAPPED) == 0:
            sz += align4(err_hdr.msg.nlmsg_len - sizeof(Nlmsghdr))
        return (err_hdr, sz)

    def print_base_header(self, errhdr, prepend=""):
        print("{}error={}, ".format(prepend, errhdr.error), end="")
        hdr = errhdr.msg
        print(
            "{}len={}, type={}, flags={}(0x{:X}), seq={}, pid={}".format(
                prepend,
                hdr.nlmsg_len,
                "msg#{}".format(hdr.nlmsg_type),
                self.helper.get_bitmask_str(NlmBaseFlags, hdr.nlmsg_flags),
                hdr.nlmsg_flags,
                hdr.nlmsg_seq,
                hdr.nlmsg_pid,
            )
        )


core_classes = {
    "netlink_core": [
        NetlinkDoneMessage,
        NetlinkErrorMessage,
    ],
}


class Nlsock:
    """Raw netlink socket wrapper that frames, sends and parses messages."""

    HANDLER_CLASSES = [core_classes, rt_classes, genl_classes]

    def __init__(self, family, helper):
        self.helper = helper
        self.sock_fd = self._setup_netlink(family)
        self._sock_family = family
        # Bytes received but not yet consumed as whole messages.
        self._data = bytes()
        self.msgmap = self.build_msgmap()
        # Generic netlink family id -> family name (filled by lookups).
        self._family_map = {
            NlConst.GENL_ID_CTRL: "nlctrl",
        }

    def build_msgmap(self):
        """Build ``{family_name: {msg_type: handler_class}}`` from HANDLER_CLASSES."""
        handler_classes = {}
        for d in self.HANDLER_CLASSES:
            handler_classes.update(d)
        xmap = {}
        for family_id, family_classes in handler_classes.items():
            xmap[family_id] = {}
            for cls in family_classes:
                for msg_props in cls.messages:
                    xmap[family_id][enum_or_int(msg_props.msg)] = cls
        return xmap

    def _setup_netlink(self, netlink_family) -> socket.socket:
        """Open a raw netlink socket with capped and extended ACKs enabled."""
        family = self.helper.get_af_value("AF_NETLINK")
        s = socket.socket(family, socket.SOCK_RAW, netlink_family)
        s.setsockopt(_SOL_NETLINK, _NETLINK_CAP_ACK, 1)
        s.setsockopt(_SOL_NETLINK, _NETLINK_EXT_ACK, 1)
        return s

    def set_groups(self, mask: int):
        self.sock_fd.setsockopt(socket.SOL_SOCKET, 1, mask)
        # NOTE: alternative bind()-based approach, kept for reference:
        # snl = SockaddrNl(nl_len = sizeof(SockaddrNl), nl_family=38,
        #                  nl_pid=self.pid, nl_groups=mask)
        # xbuffer = create_string_buffer(sizeof(SockaddrNl))
        # memmove(xbuffer, addressof(snl), sizeof(SockaddrNl))
        # k = struct.pack("@BBHII", 12, 38, 0, self.pid, mask)
        # self.sock_fd.bind(k)

    def join_group(self, group_id: int):
        """Subscribe the socket to netlink multicast group ``group_id``."""
        self.sock_fd.setsockopt(_SOL_NETLINK, _NETLINK_ADD_MEMBERSHIP, group_id)

    def write_message(self, msg, verbose=True):
        """Send ``msg``; write failures are reported on stdout, not raised."""
        if verbose:
            print("vvvvvvvv OUT vvvvvvvv")
            msg.print_message()
        msg_bytes = bytes(msg)
        try:
            ret = os.write(self.sock_fd.fileno(), msg_bytes)
        except OSError as e:
            print("write({}) -> {}".format(len(msg_bytes), e))
            return
        # Explicit check instead of assert, which would vanish under -O.
        if ret != len(msg_bytes):
            print("write({}) -> short write of {} bytes".format(len(msg_bytes), ret))

    def parse_message(self, data: bytes):
        """Decode one raw message into the matching handler class instance.

        Falls back to BaseNetlinkMessage for unknown families/types.
        """
        if len(data) < sizeof(Nlmsghdr):
            raise Exception("Short read from nl: {} bytes".format(len(data)))
        hdr = Nlmsghdr.from_buffer_copy(data)
        if hdr.nlmsg_type < _NLMSG_MIN_TYPE:
            family_name = "netlink_core"
            nlmsg_type = hdr.nlmsg_type
        elif self._sock_family == NlConst.NETLINK_ROUTE:
            family_name = "netlink_route"
            nlmsg_type = hdr.nlmsg_type
        else:
            # Generic netlink: nlmsg_type is the family id and the command
            # lives in the genl header that follows the netlink header.
            if len(data) < sizeof(Nlmsghdr) + sizeof(GenlMsgHdr):
                raise Exception("Short read from genl: {} bytes".format(len(data)))
            family_name = self._family_map.get(hdr.nlmsg_type, "")
            ghdr = GenlMsgHdr.from_buffer_copy(data[sizeof(Nlmsghdr):])
            nlmsg_type = ghdr.cmd
        cls = self.msgmap.get(family_name, {}).get(nlmsg_type)
        if not cls:
            cls = BaseNetlinkMessage
        return cls.from_bytes(self.helper, data)

    def get_genl_family_id(self, family_name):
        """Resolve a generic netlink family name to its numeric id.

        Sends CTRL_CMD_GETFAMILY to the controller family and waits for the
        reply carrying the request's sequence number. The id -> name mapping
        is cached so later messages of that family can be decoded.

        Raises:
            ValueError: if the request fails or the reply carries no
                CTRL_ATTR_FAMILY_ID attribute.
        """
        hdr = Nlmsghdr(
            nlmsg_type=NlConst.GENL_ID_CTRL,
            nlmsg_flags=NlmBaseFlags.NLM_F_REQUEST.value,
            nlmsg_seq=self.helper.get_seq(),
        )
        ghdr = GenlMsgHdr(cmd=GenlCtrlMsgType.CTRL_CMD_GETFAMILY.value)
        nla = NlAttrStr(GenlCtrlAttrType.CTRL_ATTR_FAMILY_NAME, family_name)
        nla_bytes = bytes(nla)
        hdr.nlmsg_len = sizeof(Nlmsghdr) + sizeof(GenlMsgHdr) + len(nla_bytes)

        self.write_data(bytes(hdr) + bytes(ghdr) + nla_bytes)
        while True:
            rx_msg = self.read_message()
            if rx_msg.nl_hdr.nlmsg_seq != hdr.nlmsg_seq:
                continue  # not a reply to our request
            if rx_msg.is_type(NlMsgType.NLMSG_ERROR):
                if rx_msg.error_code != 0:
                    raise ValueError("unable to get family {}".format(family_name))
                # Zero error code is a plain ACK; keep waiting for the reply.
                continue
            id_attr = rx_msg.get_nla(GenlCtrlAttrType.CTRL_ATTR_FAMILY_ID)
            if id_attr is None:
                raise ValueError("unable to get family {}".format(family_name))
            family_id = id_attr.u16
            self._family_map[family_id] = family_name
            return family_id

    def write_data(self, data: bytes):
        self.sock_fd.send(data)

    def read_data(self):
        """Receive until at least one netlink header is buffered.

        Raises:
            EOFError: if the socket returns an empty read, which would
                otherwise make this loop spin forever.
        """
        while True:
            data = self.sock_fd.recv(65535)
            if not data:
                raise EOFError("Empty read from nl socket")
            self._data += data
            if len(self._data) >= sizeof(Nlmsghdr):
                break

    def read_message(self) -> BaseNetlinkMessage:
        """Return the next complete message, reading more data as needed."""
        if len(self._data) < sizeof(Nlmsghdr):
            self.read_data()
        hdr = Nlmsghdr.from_buffer_copy(self._data)
        while hdr.nlmsg_len > len(self._data):
            self.read_data()
        raw_msg = self._data[: hdr.nlmsg_len]
        self._data = self._data[hdr.nlmsg_len:]
        return self.parse_message(raw_msg)

    def get_reply(self, tx_msg):
        """Send ``tx_msg`` and return the first message with its sequence number."""
        self.write_message(tx_msg)
        while True:
            rx_msg = self.read_message()
            if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq:
                return rx_msg


class NetlinkMultipartIterator(object):
    """Iterate over the parts of a multipart reply until NLMSG_DONE.

    Raises ValueError on a sequence mismatch, an NLMSG_ERROR, a non-zero
    DONE error code, or a message that is not of ``msg_type``.
    """

    def __init__(self, obj, seq_number: int, msg_type):
        self._obj = obj
        self._seq = seq_number
        self._msg_type = msg_type

    def __iter__(self):
        return self

    def __next__(self):
        msg = self._obj.read_message()
        if self._seq != msg.nl_hdr.nlmsg_seq:
            raise ValueError("bad sequence number")
        if msg.is_type(NlMsgType.NLMSG_ERROR):
            raise ValueError(
                "error while handling multipart msg: {}".format(msg.error_code)
            )
        elif msg.is_type(NlMsgType.NLMSG_DONE):
            if msg.error_code == 0:
                raise StopIteration
            raise ValueError(
                "error listing some parts of the multipart msg: {}".format(
                    msg.error_code
                )
            )
        elif not msg.is_type(self._msg_type):
            raise ValueError("bad message type: {}".format(msg))
        return msg


class NetlinkTestTemplate(object):
    """Mixin for netlink tests: socket setup plus verbose TX/RX helpers."""

    REQUIRED_MODULES = ["netlink"]

    def setup_netlink(self, netlink_family: NlConst):
        self.helper = NlHelper()
        self.nlsock = Nlsock(netlink_family, self.helper)

    def write_message(self, msg, silent=False):
        if not silent:
            print("")
            print("============= >> TX MESSAGE =============")
            msg.print_message()
            msg.print_as_bytes(bytes(msg), "-- DATA --")
        self.nlsock.write_data(bytes(msg))

    def read_message(self, silent=False):
        msg = self.nlsock.read_message()
        if not silent:
            print("")
            print("============= << RX MESSAGE =============")
            msg.print_message()
        return msg

    def get_reply(self, tx_msg):
        self.write_message(tx_msg)
        while True:
            rx_msg = self.read_message()
            if tx_msg.nl_hdr.nlmsg_seq == rx_msg.nl_hdr.nlmsg_seq:
                return rx_msg

    def read_msg_list(self, seq, msg_type):
        """Collect every part of a multipart reply into a list."""
        return list(NetlinkMultipartIterator(self, seq, msg_type))