Path: blob/main/tests/atf_python/sys/net/rtsock.py
39553 views
#!/usr/local/bin/python3
import os
import socket
import struct
import sys
from ctypes import c_byte
from ctypes import c_char
from ctypes import c_int
from ctypes import c_long
from ctypes import c_uint32
from ctypes import c_ulong
from ctypes import c_ushort
from ctypes import sizeof
from ctypes import Structure
from typing import Dict
from typing import List
from typing import Optional
from typing import Union


def roundup2(val: int, num: int) -> int:
    """Round ``val`` up to the nearest multiple of ``num``.

    ``num`` must be a power of two; the bit trick below relies on it.
    """
    if val % num:
        return (val | (num - 1)) + 1
    return val


def _sa_size(sa_len: int) -> int:
    """Return how many message bytes a sockaddr of ``sa_len`` bytes occupies.

    Sockaddrs inside a routing message are padded to ``RtConst.ALIGN``; a
    zero-length sockaddr still reserves one ``ALIGN`` slot (kernel SA_SIZE()).
    """
    if sa_len == 0:
        return RtConst.ALIGN
    return roundup2(sa_len, RtConst.ALIGN)


class RtSockException(OSError):
    pass


class RtConst:
    """Routing-socket constants plus helpers to turn values back into names."""

    RTM_VERSION = 5
    ALIGN = sizeof(c_long)

    AF_INET = socket.AF_INET
    AF_INET6 = socket.AF_INET6
    # socket.AF_LINK only exists on BSD-like hosts; fall back to FreeBSD's
    # value so this module can still be imported (e.g. for its formatting
    # helpers) where the constant is missing.
    AF_LINK = getattr(socket, "AF_LINK", 18)

    RTA_DST = 0x1
    RTA_GATEWAY = 0x2
    RTA_NETMASK = 0x4
    RTA_GENMASK = 0x8
    RTA_IFP = 0x10
    RTA_IFA = 0x20
    RTA_AUTHOR = 0x40
    RTA_BRD = 0x80

    RTM_ADD = 1
    RTM_DELETE = 2
    RTM_CHANGE = 3
    RTM_GET = 4

    RTF_UP = 0x1
    RTF_GATEWAY = 0x2
    RTF_HOST = 0x4
    RTF_REJECT = 0x8
    RTF_DYNAMIC = 0x10
    RTF_MODIFIED = 0x20
    RTF_DONE = 0x40
    RTF_XRESOLVE = 0x200
    RTF_LLINFO = 0x400
    RTF_LLDATA = 0x400
    RTF_STATIC = 0x800
    RTF_BLACKHOLE = 0x1000
    RTF_PROTO2 = 0x4000
    RTF_PROTO1 = 0x8000
    RTF_PROTO3 = 0x40000
    RTF_FIXEDMTU = 0x80000
    RTF_PINNED = 0x100000
    RTF_LOCAL = 0x200000
    RTF_BROADCAST = 0x400000
    RTF_MULTICAST = 0x800000
    RTF_STICKY = 0x10000000
    RTF_RNH_LOCKED = 0x40000000
    RTF_GWFLAG_COMPAT = 0x80000000

    RTV_MTU = 0x1
    RTV_HOPCOUNT = 0x2
    RTV_EXPIRE = 0x4
    RTV_RPIPE = 0x8
    RTV_SPIPE = 0x10
    RTV_SSTHRESH = 0x20
    RTV_RTT = 0x40
    RTV_RTTVAR = 0x80
    RTV_WEIGHT = 0x100

    @staticmethod
    def get_props(prefix: str) -> List[str]:
        """Return the names of all class attributes starting with ``prefix``."""
        return [n for n in dir(RtConst) if n.startswith(prefix)]

    @staticmethod
    def get_name(prefix: str, value: int) -> str:
        """Return the constant name for ``value`` or ``U:<prefix>:<value>``."""
        props = RtConst.get_props(prefix)
        for prop in props:
            if getattr(RtConst, prop) == value:
                return prop
        return "U:{}:{}".format(prefix, value)

    @staticmethod
    def get_bitmask_map(prefix: str, value: int) -> Dict[int, str]:
        """Map each set bit of ``value`` to its constant name (or hex if unknown)."""
        props = RtConst.get_props(prefix)
        propmap = {getattr(RtConst, prop): prop for prop in props}
        v = 1
        ret = {}
        while value:
            if v & value:
                if v in propmap:
                    ret[v] = propmap[v]
                else:
                    ret[v] = hex(v)
                value -= v
            v *= 2
        return ret

    @staticmethod
    def get_bitmask_str(prefix: str, value: int) -> str:
        """Return the comma-separated names of the bits set in ``value``."""
        return ",".join(RtConst.get_bitmask_map(prefix, value).values())


# ctypes mirrors of the routing-socket structures; field order and types
# define the wire layout.
class RtMetrics(Structure):
    _fields_ = [
        ("rmx_locks", c_ulong),
        ("rmx_mtu", c_ulong),
        ("rmx_hopcount", c_ulong),
        ("rmx_expire", c_ulong),
        ("rmx_recvpipe", c_ulong),
        ("rmx_sendpipe", c_ulong),
        ("rmx_ssthresh", c_ulong),
        ("rmx_rtt", c_ulong),
        ("rmx_rttvar", c_ulong),
        ("rmx_pksent", c_ulong),
        ("rmx_weight", c_ulong),
        ("rmx_nhidx", c_ulong),
        ("rmx_filler", c_ulong * 2),
    ]


class RtMsgHdr(Structure):
    _fields_ = [
        ("rtm_msglen", c_ushort),
        ("rtm_version", c_byte),
        ("rtm_type", c_byte),
        ("rtm_index", c_ushort),
        ("_rtm_spare1", c_ushort),
        ("rtm_flags", c_int),
        ("rtm_addrs", c_int),
        ("rtm_pid", c_int),
        ("rtm_seq", c_int),
        ("rtm_errno", c_int),
        ("rtm_fmask", c_int),
        ("rtm_inits", c_ulong),
        ("rtm_rmx", RtMetrics),
    ]


class SockaddrIn(Structure):
    _fields_ = [
        ("sin_len", c_byte),
        ("sin_family", c_byte),
        ("sin_port", c_ushort),
        ("sin_addr", c_uint32),
        ("sin_zero", c_char * 8),
    ]


class SockaddrIn6(Structure):
    _fields_ = [
        ("sin6_len", c_byte),
        ("sin6_family", c_byte),
        ("sin6_port", c_ushort),
        ("sin6_flowinfo", c_uint32),
        ("sin6_addr", c_byte * 16),
        ("sin6_scope_id", c_uint32),
    ]


class SockaddrDl(Structure):
    _fields_ = [
        ("sdl_len", c_byte),
        ("sdl_family", c_byte),
        ("sdl_index", c_ushort),
        ("sdl_type", c_byte),
        ("sdl_nlen", c_byte),
        ("sdl_alen", c_byte),
        ("sdl_slen", c_byte),
        ("sdl_data", c_byte * 8),
    ]


class SaHelper(object):
    """Build raw sockaddr byte buffers and render them as text."""

    @staticmethod
    def is_ipv6(ip: str) -> bool:
        """Return True when ``ip`` is treated as an IPv6 literal (has a colon)."""
        return ":" in ip

    @staticmethod
    def ip_sa(ip: str, scopeid: int = 0) -> bytes:
        """Return a sockaddr_in or sockaddr_in6 buffer for ``ip``."""
        if SaHelper.is_ipv6(ip):
            return SaHelper.ip6_sa(ip, scopeid)
        return SaHelper.ip4_sa(ip)

    @staticmethod
    def ip4_sa(ip: str) -> bytes:
        """Return a sockaddr_in buffer for the IPv4 literal ``ip``."""
        # inet_pton yields network-order bytes; decoding them in host order
        # makes ctypes write exactly those bytes back into sin_addr.
        addr_int = int.from_bytes(socket.inet_pton(socket.AF_INET, ip), sys.byteorder)
        sin = SockaddrIn(sizeof(SockaddrIn), socket.AF_INET, 0, addr_int)
        return bytes(sin)

    @staticmethod
    def ip6_sa(ip6: str, scopeid: int) -> bytes:
        """Return a sockaddr_in6 buffer for ``ip6`` with ``scopeid`` (host order)."""
        addr_bytes = (c_byte * 16).from_buffer_copy(socket.inet_pton(socket.AF_INET6, ip6))
        sin6 = SockaddrIn6(
            sizeof(SockaddrIn6), socket.AF_INET6, 0, 0, addr_bytes, scopeid
        )
        return bytes(sin6)

    @staticmethod
    def link_sa(ifindex: int = 0, iftype: int = 0) -> bytes:
        """Return a sockaddr_dl buffer carrying only an interface index/type."""
        sa = SockaddrDl(sizeof(SockaddrDl), RtConst.AF_LINK, c_ushort(ifindex), iftype)
        return bytes(sa)

    @staticmethod
    def pxlen4_sa(pxlen: int) -> bytes:
        """Return an IPv4 netmask sockaddr for prefix length ``pxlen``."""
        return SaHelper.ip_sa(SaHelper.pxlen_to_ip4(pxlen))

    @staticmethod
    def pxlen_to_ip4(pxlen: int) -> str:
        """Return the dotted-quad netmask for IPv4 prefix length ``pxlen``."""
        if pxlen == 32:
            return "255.255.255.255"
        addr = 0xFFFFFFFF - ((1 << (32 - pxlen)) - 1)
        return socket.inet_ntop(socket.AF_INET, struct.pack("!I", addr))

    @staticmethod
    def pxlen6_sa(pxlen: int) -> bytes:
        """Return an IPv6 netmask sockaddr for prefix length ``pxlen``."""
        return SaHelper.ip_sa(SaHelper.pxlen_to_ip6(pxlen))

    @staticmethod
    def pxlen_to_ip6(pxlen: int) -> str:
        """Return the textual IPv6 netmask for prefix length ``pxlen``."""
        ip6_b = [0] * 16
        start = 0
        while pxlen > 8:
            ip6_b[start] = 0xFF
            pxlen -= 8
            start += 1
        ip6_b[start] = 0xFF - ((1 << (8 - pxlen)) - 1)
        return socket.inet_ntop(socket.AF_INET6, bytes(ip6_b))

    @staticmethod
    def print_sa_inet(sa: bytes):
        """Render a sockaddr_in buffer as its IPv4 address."""
        if len(sa) < 8:
            raise RtSockException("IPv4 sa size too small: {}".format(len(sa)))
        addr = socket.inet_ntop(socket.AF_INET, sa[4:8])
        return "{}".format(addr)

    @staticmethod
    def print_sa_inet6(sa: bytes):
        """Render a sockaddr_in6 buffer as ``<addr> scopeid <id>``."""
        if len(sa) < sizeof(SockaddrIn6):
            raise RtSockException("IPv6 sa size too small: {}".format(len(sa)))
        sin6 = SockaddrIn6.from_buffer_copy(sa)
        addr = socket.inet_ntop(socket.AF_INET6, sa[8:24])
        # sin6_scope_id is stored in host byte order (see ip6_sa), so read it
        # through the ctypes field rather than as big-endian.
        return "{} scopeid {}".format(addr, sin6.sin6_scope_id)

    @staticmethod
    def print_sa_link(sa: bytes, hd: Optional[bool] = True):
        """Render a sockaddr_dl buffer as ``link#<idx> ifname:<name> ``."""
        if len(sa) < sizeof(SockaddrDl):
            raise RtSockException("LINK sa size too small: {}".format(len(sa)))
        sdl = SockaddrDl.from_buffer_copy(sa)
        if sdl.sdl_index:
            ifindex = "link#{} ".format(sdl.sdl_index)
        else:
            ifindex = ""
        if sdl.sdl_nlen:
            iface_offset = 8
            if sdl.sdl_nlen + iface_offset > len(sa):
                raise RtSockException(
                    "LINK sa sdl_nlen {} > total len {}".format(sdl.sdl_nlen, len(sa))
                )
            ifname = "ifname:{} ".format(
                bytes.decode(sa[iface_offset : iface_offset + sdl.sdl_nlen])
            )
        else:
            ifname = ""
        return "{}{}".format(ifindex, ifname)

    @staticmethod
    def print_sa_unknown(sa: bytes):
        """Render a sockaddr of an unsupported family."""
        return "unknown_type:{}".format(sa[1])

    @classmethod
    def print_sa(cls, sa: bytes, hd: Optional[bool] = False):
        """Render any supported sockaddr buffer; append a repr dump if ``hd``.

        Raises:
            Exception: the buffer is shorter than 2 bytes or its sa_len byte
                does not match the buffer length.
        """
        # Length must be checked before sa[0]/sa[1] are read.
        if len(sa) < 2:
            raise Exception("sa too short: {} bytes".format(len(sa)))
        if sa[0] != len(sa):
            raise Exception("sa size {} != buffer size {}".format(sa[0], len(sa)))

        family = sa[1]
        if family == socket.AF_INET:
            text = cls.print_sa_inet(sa)
        elif family == socket.AF_INET6:
            text = cls.print_sa_inet6(sa)
        elif family == RtConst.AF_LINK:
            text = cls.print_sa_link(sa)
        else:
            text = cls.print_sa_unknown(sa)
        dump = " [{!r}]".format(sa) if hd else ""
        return "{}{}".format(text, dump)


class BaseRtsockMessage(object):
    """Common state for routing-socket messages: type plus a SaHelper."""

    def __init__(self, rtm_type):
        self.rtm_type = rtm_type
        self.sa = SaHelper()

    @staticmethod
    def print_rtm_type(rtm_type):
        """Return the RTM_* name for ``rtm_type``."""
        return RtConst.get_name("RTM_", rtm_type)

    @property
    def rtm_type_str(self):
        return self.print_rtm_type(self.rtm_type)


class RtsockRtMessage(BaseRtsockMessage):
    """A route message (RTM_ADD/DELETE/CHANGE/GET) with its sockaddr attrs.

    Attributes are kept in ``_attrs`` keyed by their RTA_* bit.
    """

    messages = [
        RtConst.RTM_ADD,
        RtConst.RTM_DELETE,
        RtConst.RTM_CHANGE,
        RtConst.RTM_GET,
    ]

    def __init__(self, rtm_type, rtm_seq=1, dst_sa=None, mask_sa=None):
        super().__init__(rtm_type)
        self.rtm_flags = 0
        self.rtm_seq = rtm_seq
        self._attrs = {}
        self.rtm_errno = 0
        self.rtm_pid = 0
        self.rtm_inits = 0
        self.rtm_rmx = RtMetrics()
        self._orig_data = None  # raw bytes when built via from_bytes()
        if dst_sa:
            self.add_sa_attr(RtConst.RTA_DST, dst_sa)
        if mask_sa:
            self.add_sa_attr(RtConst.RTA_NETMASK, mask_sa)

    def add_sa_attr(self, attr_type, attr_bytes: bytes):
        """Set attribute ``attr_type`` (an RTA_* bit) to a raw sockaddr."""
        self._attrs[attr_type] = attr_bytes

    def add_ip_attr(self, attr_type, ip_addr: str, scopeid: int = 0):
        """Set ``attr_type`` to an IPv4 or IPv6 address sockaddr."""
        if SaHelper.is_ipv6(ip_addr):
            self.add_ip6_attr(attr_type, ip_addr, scopeid)
        else:
            self.add_ip4_attr(attr_type, ip_addr)

    def add_ip4_attr(self, attr_type, ip: str):
        self.add_sa_attr(attr_type, self.sa.ip4_sa(ip))

    def add_ip6_attr(self, attr_type, ip6: str, scopeid: int):
        self.add_sa_attr(attr_type, self.sa.ip6_sa(ip6, scopeid))

    def add_link_attr(self, attr_type, ifindex: Optional[int] = 0):
        self.add_sa_attr(attr_type, self.sa.link_sa(ifindex))

    def get_sa(self, attr_type) -> bytes:
        """Return the raw sockaddr for ``attr_type`` or None if absent."""
        return self._attrs.get(attr_type)

    def print_message(self):
        """Print a human-readable summary of the header and attributes."""
        # RTM_GET: Report Metrics: len 272, pid: 87839, seq 1, errno 0, flags:<UP,GATEWAY,DONE,STATIC>
        if self._orig_data:
            rtm_len = len(self._orig_data)
        else:
            rtm_len = len(bytes(self))
        print(
            "{}: len {}, pid: {}, seq {}, errno {}, flags: <{}>".format(
                self.rtm_type_str,
                rtm_len,
                self.rtm_pid,
                self.rtm_seq,
                self.rtm_errno,
                RtConst.get_bitmask_str("RTF_", self.rtm_flags),
            )
        )
        rtm_addrs = sum(self._attrs.keys())
        print("Addrs: <{}>".format(RtConst.get_bitmask_str("RTA_", rtm_addrs)))
        for attr in sorted(self._attrs.keys()):
            sa_data = SaHelper.print_sa(self._attrs[attr])
            print("  {}: {}".format(RtConst.get_name("RTA_", attr), sa_data))

    def print_in_message(self):
        print("vvvvvvvv  IN vvvvvvvv")
        self.print_message()
        print()

    def verify_sa_inet(self, sa_data):
        """Assert that an IPv4 sockaddr has a zero port and zero padding."""
        if len(sa_data) < 8:
            raise Exception("IPv4 sa size too small: {}".format(sa_data))
        if sa_data[0] > len(sa_data):
            raise Exception(
                "IPv4 sin_len too big: {} vs sa size {}: {}".format(
                    sa_data[0], len(sa_data), sa_data
                )
            )
        port_off = SockaddrIn.sin_port.offset
        assert sa_data[port_off : port_off + 2] == b"\x00\x00"
        # Check the raw bytes: ctypes returns c_char arrays truncated at the
        # first NUL, so the sin_zero field itself cannot be compared.
        assert not any(sa_data[SockaddrIn.sin_zero.offset : sizeof(SockaddrIn)])

    def compare_sa(self, sa_type, sa_data):
        """Assert that our attribute ``sa_type`` equals ``sa_data``."""
        if len(sa_data) < 4:
            sa_type_name = RtConst.get_name("RTA_", sa_type)
            raise Exception(
                "sa_len for type {} too short: {}".format(sa_type_name, len(sa_data))
            )
        our_sa = self._attrs[sa_type]
        assert SaHelper.print_sa(sa_data) == SaHelper.print_sa(our_sa)
        assert len(sa_data) == len(our_sa)
        assert sa_data == our_sa

    def verify(self, rtm_type: int, rtm_sa):
        """Assert type, errno, spare field and the expected attributes."""
        assert self.rtm_type_str == self.print_rtm_type(rtm_type)
        assert self.rtm_errno == 0
        data = self._orig_data if self._orig_data is not None else bytes(self)
        hdr = RtMsgHdr.from_buffer_copy(data)
        assert hdr._rtm_spare1 == 0
        for sa_type, sa_data in rtm_sa.items():
            if sa_type not in self._attrs:
                sa_type_name = RtConst.get_name("RTA_", sa_type)
                raise Exception("SA type {} not present".format(sa_type_name))
            self.compare_sa(sa_type, sa_data)

    @classmethod
    def from_bytes(cls, data: bytes):
        """Parse a raw routing message (header followed by sockaddrs).

        Raises:
            Exception: the buffer is shorter than the header, or a sockaddr
                runs past the end of the buffer.
        """
        if len(data) < sizeof(RtMsgHdr):
            raise Exception(
                "messages size {} is less than expected {}".format(
                    len(data), sizeof(RtMsgHdr)
                )
            )
        hdr = RtMsgHdr.from_buffer_copy(data)

        self = cls(hdr.rtm_type)
        self.rtm_flags = hdr.rtm_flags
        self.rtm_seq = hdr.rtm_seq
        self.rtm_errno = hdr.rtm_errno
        self.rtm_pid = hdr.rtm_pid
        self.rtm_inits = hdr.rtm_inits
        self.rtm_rmx = hdr.rtm_rmx
        self._orig_data = data

        off = sizeof(RtMsgHdr)
        v = 1
        # rtm_addrs is a signed c_int; mask it so the loop always terminates.
        addrs_mask = hdr.rtm_addrs & 0xFFFFFFFF
        while addrs_mask:
            if addrs_mask & v:
                addrs_mask -= v
                if off >= len(data):
                    raise Exception(
                        "SA {} starts past message end: {} >= {}".format(
                            RtConst.get_name("RTA_", v), off, len(data)
                        )
                    )
                sa_len = data[off]
                if off + sa_len > len(data):
                    raise Exception(
                        "SA sizeof for {} > total message length: {}+{} > {}".format(
                            RtConst.get_name("RTA_", v), off, sa_len, len(data)
                        )
                    )
                self._attrs[v] = data[off : off + sa_len]
                off += _sa_size(sa_len)
            v *= 2
        return self

    def __bytes__(self):
        """Serialize header + attributes (in ascending RTA_* order)."""
        sz = sizeof(RtMsgHdr)
        addrs_mask = 0
        for k, v in self._attrs.items():
            sz += _sa_size(len(v))
            addrs_mask += k
        hdr = RtMsgHdr(
            rtm_msglen=sz,
            rtm_version=RtConst.RTM_VERSION,
            rtm_type=self.rtm_type,
            rtm_flags=self.rtm_flags,
            rtm_seq=self.rtm_seq,
            rtm_addrs=addrs_mask,
            rtm_inits=self.rtm_inits,
            rtm_rmx=self.rtm_rmx,
        )
        buf = bytearray(sz)
        buf[0 : sizeof(RtMsgHdr)] = bytes(hdr)
        off = sizeof(RtMsgHdr)
        for attr in sorted(self._attrs.keys()):
            v = self._attrs[attr]
            buf[off : off + len(v)] = v
            off += _sa_size(len(v))
        return bytes(buf)


class Rtsock:
    """A raw routing socket plus helpers to build, send and read messages.

    Usable as a context manager; the socket is closed on exit.
    """

    def __init__(self):
        self.socket = self._setup_rtsock()
        self.rtm_seq = 1
        self.msgmap = self.build_msgmap()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def close(self):
        """Close the underlying routing socket."""
        self.socket.close()

    def build_msgmap(self):
        """Map each supported RTM_* type to the class that parses it."""
        classes = [RtsockRtMessage]
        xmap = {}
        for cls in classes:
            for message in cls.messages:
                xmap[message] = cls
        return xmap

    def get_seq(self):
        """Return the next sequence number."""
        ret = self.rtm_seq
        self.rtm_seq += 1
        return ret

    def get_weight(self, weight) -> int:
        if weight:
            return weight
        return 1  # RT_DEFAULT_WEIGHT

    def new_rtm_any(self, msg_type, prefix: str, gw: Union[str, bytes]):
        """Build a route message for ``prefix`` ("addr" or "addr/len") via ``gw``.

        ``gw`` is either an IP literal or a raw sockaddr buffer.
        """
        px = prefix.split("/")
        addr_sa = SaHelper.ip_sa(px[0])
        if len(px) > 1:
            pxlen = int(px[1])
            if SaHelper.is_ipv6(px[0]):
                mask_sa = SaHelper.pxlen6_sa(pxlen)
            else:
                mask_sa = SaHelper.pxlen4_sa(pxlen)
        else:
            mask_sa = None
        msg = RtsockRtMessage(msg_type, self.get_seq(), addr_sa, mask_sa)
        if isinstance(gw, bytes):
            msg.add_sa_attr(RtConst.RTA_GATEWAY, gw)
        else:
            msg.add_ip_attr(RtConst.RTA_GATEWAY, gw)
        return msg

    def new_rtm_add(self, prefix: str, gw: Union[str, bytes]):
        return self.new_rtm_any(RtConst.RTM_ADD, prefix, gw)

    def new_rtm_del(self, prefix: str, gw: Union[str, bytes]):
        return self.new_rtm_any(RtConst.RTM_DELETE, prefix, gw)

    def new_rtm_change(self, prefix: str, gw: Union[str, bytes]):
        return self.new_rtm_any(RtConst.RTM_CHANGE, prefix, gw)

    def _setup_rtsock(self) -> socket.socket:
        s = socket.socket(socket.AF_ROUTE, socket.SOCK_RAW, socket.AF_UNSPEC)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_USELOOPBACK, 1)
        return s

    def print_hd(self, data: bytes):
        """Print ``data`` as a hex dump, 16 bytes per line."""
        width = 16
        print("==========================================")
        for chunk in [data[i : i + width] for i in range(0, len(data), width)]:
            for b in chunk:
                print("0x{:02X} ".format(b), end="")
            print()
        print()

    def write_message(self, msg):
        """Print and send ``msg``; assert the whole message was written."""
        print("vvvvvvvv  OUT vvvvvvvv")
        msg.print_message()
        print()
        msg_bytes = bytes(msg)
        # os.write raises OSError on failure instead of returning -1.
        ret = os.write(self.socket.fileno(), msg_bytes)
        assert ret == len(msg_bytes)

    def parse_message(self, data: bytes):
        """Parse ``data`` into a message object, or None for unknown types.

        Raises:
            OSError: ``data`` is too short to contain the rtm_type byte.
        """
        type_off = RtMsgHdr.rtm_type.offset  # byte 3, after msglen + version
        if len(data) <= type_off:
            raise OSError("Short read from rtsock: {} bytes".format(len(data)))
        cls = self.msgmap.get(data[type_off])
        if cls is None:
            return None
        return cls.from_bytes(data)

    def write_data(self, data: bytes):
        self.socket.send(data)

    def read_data(self, seq: Optional[int] = None) -> bytes:
        """Read one message; if ``seq`` is given, skip until rtm_seq matches."""
        while True:
            data = self.socket.recv(4096)
            if seq is None:
                return data
            # A message with no sockaddrs is exactly the header size.
            if len(data) >= sizeof(RtMsgHdr):
                hdr = RtMsgHdr.from_buffer_copy(data)
                if hdr.rtm_seq == seq:
                    return data

    def read_message(self) -> Optional[BaseRtsockMessage]:
        """Read and parse one message (None if its type is unsupported)."""
        data = self.read_data()
        return self.parse_message(data)