// Path: cros_async/src/sys/windows/io_completion_port.rs
// Copyright 2022 The ChromiumOS Authors1// Use of this source code is governed by a BSD-style license that can be2// found in the LICENSE file.34//! IO completion port wrapper.56use std::collections::VecDeque;7use std::io;8use std::ptr::null_mut;9use std::sync::Arc;10use std::sync::Condvar;11use std::time::Duration;1213use base::error;14use base::info;15use base::AsRawDescriptor;16use base::Error as SysError;17use base::Event;18use base::EventWaitResult;19use base::FromRawDescriptor;20use base::RawDescriptor;21use base::SafeDescriptor;22use base::WorkerThread;23use smallvec::smallvec;24use smallvec::SmallVec;25use sync::Mutex;26use winapi::shared::minwindef::BOOL;27use winapi::shared::minwindef::DWORD;28use winapi::shared::minwindef::ULONG;29use winapi::um::handleapi::INVALID_HANDLE_VALUE;30use winapi::um::ioapiset::CreateIoCompletionPort;31use winapi::um::ioapiset::GetOverlappedResult;32use winapi::um::ioapiset::GetQueuedCompletionStatus;33use winapi::um::ioapiset::GetQueuedCompletionStatusEx;34use winapi::um::ioapiset::PostQueuedCompletionStatus;35use winapi::um::minwinbase::LPOVERLAPPED_ENTRY;36use winapi::um::minwinbase::OVERLAPPED;37use winapi::um::minwinbase::OVERLAPPED_ENTRY;38use winapi::um::winbase::INFINITE;3940use super::handle_executor::Error;41use super::handle_executor::Result;4243/// The number of IOCP packets we accept per poll operation.44/// Because this is only used for SmallVec sizes, clippy thinks it is unused.45#[allow(dead_code)]46const ENTRIES_PER_POLL: usize = 16;4748/// A minimal version of completion packets from an IoCompletionPort.49pub(crate) struct CompletionPacket {50pub completion_key: usize,51pub overlapped_ptr: usize,52pub result: std::result::Result<usize, SysError>,53}5455struct Port {56inner: RawDescriptor,57}5859// SAFETY:60// Safe because the Port is dropped before IoCompletionPort goes out of scope61unsafe impl Send for Port {}6263/// Wraps an IO Completion Port (iocp). 
These ports are very similar to an epoll64/// context on unix. Handles (equivalent to FDs) we want to wait on for65/// readiness are added to the port, and then the port can be waited on using a66/// syscall (GetQueuedCompletionStatus). IOCP is a little more flexible than67/// epoll because custom messages can be enqueued and received from the port68/// just like if a handle became ready (see [IoCompletionPort::post_status]).69///70/// Note that completion ports can only be subscribed to a handle, they71/// can never be unsubscribed. Handles are removed from the port automatically when they are closed.72///73/// Registered handles have their completion key set to their handle number.74pub(crate) struct IoCompletionPort {75port: SafeDescriptor,76threads: Vec<WorkerThread<()>>,77completed: Arc<(Mutex<VecDeque<CompletionPacket>>, Condvar)>,78concurrency: u32,79}8081/// Gets a completion packet from the completion port. If the underlying IO operation82/// encountered an error, it will be contained inside the completion packet. If this method83/// encountered an error getting a completion packet, the error will be returned directly.84/// Safety: caller needs to ensure that the `handle` is valid and is for io completion port.85#[deny(unsafe_op_in_unsafe_fn)]86unsafe fn get_completion_status(87handle: RawDescriptor,88timeout: DWORD,89) -> io::Result<CompletionPacket> {90let mut bytes_transferred = 0;91let mut completion_key = 0;92// SAFETY: trivially safe93let mut overlapped: *mut OVERLAPPED = unsafe { std::mem::zeroed() };9495// SAFETY:96// Safe because:97// 1. Memory of pointers passed is stack allocated and lives as long as the syscall.98// 2. We check the error so we don't use invalid output values (e.g. 
overlapped).99let success = unsafe {100GetQueuedCompletionStatus(101handle,102&mut bytes_transferred,103&mut completion_key,104&mut overlapped as *mut *mut OVERLAPPED,105timeout,106)107} != 0;108109if success {110return Ok(CompletionPacket {111result: Ok(bytes_transferred as usize),112completion_key,113overlapped_ptr: overlapped as usize,114});115}116117// Did the IOCP operation fail, or did the overlapped operation fail?118if overlapped.is_null() {119// IOCP failed somehow.120Err(io::Error::last_os_error())121} else {122// Overlapped operation failed.123Ok(CompletionPacket {124result: Err(SysError::last()),125completion_key,126overlapped_ptr: overlapped as usize,127})128}129}130131/// Waits for completion events to arrive & returns the completion keys.132/// Safety: caller needs to ensure that the `handle` is valid and is for io completion port.133#[deny(unsafe_op_in_unsafe_fn)]134unsafe fn poll(port: RawDescriptor) -> Result<Vec<CompletionPacket>> {135let mut completion_packets = vec![];136completion_packets.push(137// SAFETY: caller has ensured that the handle is valid and is for io completion port138unsafe {139get_completion_status(port, INFINITE)140.map_err(|e| Error::IocpOperationFailed(SysError::from(e)))?141},142);143144// Drain any waiting completion packets.145//146// Wondering why we don't use GetQueuedCompletionStatusEx instead? Well, there's no way to147// get detailed error information for each of the returned overlapped IO operations without148// calling GetOverlappedResult. 
If we have to do that, then it's cheaper to just get each149// completion packet individually.150while completion_packets.len() < ENTRIES_PER_POLL {151// SAFETY:152// Safety: caller has ensured that the handle is valid and is for io completion port153match unsafe { get_completion_status(port, 0) } {154Ok(pkt) => {155completion_packets.push(pkt);156}157Err(e) if e.kind() == io::ErrorKind::TimedOut => break,158Err(e) => return Err(Error::IocpOperationFailed(SysError::from(e))),159}160}161162Ok(completion_packets)163}164165/// Safety: caller needs to ensure that the `handle` is valid and is for io completion port.166fn iocp_waiter_thread(167port: Arc<Mutex<Port>>,168kill_evt: Event,169completed: Arc<(Mutex<VecDeque<CompletionPacket>>, Condvar)>,170) -> Result<()> {171let port = port.lock();172loop {173// SAFETY: caller has ensured that the handle is valid and is for io completion port174let packets = unsafe { poll(port.inner)? };175if !packets.is_empty() {176{177let mut c = completed.0.lock();178for packet in packets {179c.push_back(packet);180}181completed.1.notify_one();182}183}184if kill_evt185.wait_timeout(Duration::from_nanos(0))186.map_err(Error::IocpOperationFailed)?187== EventWaitResult::Signaled188{189return Ok(());190}191}192}193194impl Drop for IoCompletionPort {195fn drop(&mut self) {196if !self.threaded() {197return;198}199200let mut threads = std::mem::take(&mut self.threads);201for thread in &mut threads {202// let the thread know that it should exit203if let Err(e) = thread.signal() {204error!("faild to signal iocp thread: {}", e);205}206}207208// interrupt all poll/get status on ports.209// Single thread can consume more ENTRIES_PER_POLL number of completion statuses.210// We send enough post_status so that all threads have enough data to be woken up by the211// completion ports.212// This is slightly unpleasant way to interrupt all the threads.213for _ in 0..(threads.len() * ENTRIES_PER_POLL) {214if let Err(e) = self.wake() {215error!("post_status 
failed during thread exit:{}", e);216}217}218}219}220221impl IoCompletionPort {222pub fn new(concurrency: u32) -> Result<Self> {223let completed = Arc::new((Mutex::new(VecDeque::new()), Condvar::new()));224// Unwrap is safe because we're creating a new IOCP and will receive the owned handle225// back.226let port = create_iocp(None, None, 0, concurrency)?.unwrap();227let mut threads = vec![];228if concurrency > 1 {229info!("creating iocp with concurrency: {}", concurrency);230for i in 0..concurrency {231let completed_clone = completed.clone();232let port_desc = Arc::new(Mutex::new(Port {233inner: port.as_raw_descriptor(),234}));235threads.push(WorkerThread::start(236format!("overlapped_io_{i}"),237move |kill_evt| {238iocp_waiter_thread(port_desc, kill_evt, completed_clone).unwrap();239},240));241}242}243Ok(Self {244port,245threads,246completed,247concurrency,248})249}250251fn threaded(&self) -> bool {252self.concurrency > 1253}254255/// Register the provided descriptor with this completion port. Registered descriptors cannot256/// be deregistered. 
To deregister, close the descriptor.257pub fn register_descriptor(&self, desc: &dyn AsRawDescriptor) -> Result<()> {258create_iocp(259Some(desc),260Some(&self.port),261desc.as_raw_descriptor() as usize,262self.concurrency,263)?;264Ok(())265}266267/// Posts a completion packet to the IO completion port.268pub fn post_status(&self, bytes_transferred: u32, completion_key: usize) -> Result<()> {269// SAFETY:270// Safe because the IOCP handle is valid.271let res = unsafe {272PostQueuedCompletionStatus(273self.port.as_raw_descriptor(),274bytes_transferred,275completion_key,276null_mut(),277)278};279if res == 0 {280return Err(Error::IocpOperationFailed(SysError::last()));281}282Ok(())283}284285/// Wake up thread waiting on this iocp.286/// If there are more than one thread waiting, then you may need to call this function287/// multiple times.288pub fn wake(&self) -> Result<()> {289self.post_status(0, INVALID_HANDLE_VALUE as usize)290}291292/// Get up to ENTRIES_PER_POLL completion packets from the IOCP in one shot.293#[allow(dead_code)]294fn get_completion_status_ex(295&self,296timeout: DWORD,297) -> Result<SmallVec<[OVERLAPPED_ENTRY; ENTRIES_PER_POLL]>> {298let mut overlapped_entries: SmallVec<[OVERLAPPED_ENTRY; ENTRIES_PER_POLL]> =299smallvec!(OVERLAPPED_ENTRY::default(); ENTRIES_PER_POLL);300301let mut entries_removed: ULONG = 0;302// SAFETY:303// Safe because:304// 1. IOCP is guaranteed to exist by self.305// 2. Memory of pointers passed is stack allocated and lives as long as the syscall.306// 3. We check the error so we don't use invalid output values (e.g. overlapped).307let success = unsafe {308GetQueuedCompletionStatusEx(309self.port.as_raw_descriptor(),310overlapped_entries.as_mut_ptr() as LPOVERLAPPED_ENTRY,311ENTRIES_PER_POLL as ULONG,312&mut entries_removed,313timeout,314// We are normally called from a polling loop. 
It's more efficient (loop latency315// wise) to hold the thread instead of performing an alertable wait.316/* fAlertable= */317false as BOOL,318)319} != 0;320321if success {322overlapped_entries.truncate(entries_removed as usize);323return Ok(overlapped_entries);324}325326// Overlapped operation failed.327Err(Error::IocpOperationFailed(SysError::last()))328}329330/// Waits for completion events to arrive & returns the completion keys.331fn poll_threaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {332let mut completion_packets = SmallVec::with_capacity(ENTRIES_PER_POLL);333let mut packets = self.completed.0.lock();334loop {335let len = usize::min(ENTRIES_PER_POLL, packets.len());336for p in packets.drain(..len) {337completion_packets.push(p)338}339if !completion_packets.is_empty() {340return Ok(completion_packets);341}342packets = self.completed.1.wait(packets).unwrap();343}344}345346/// Waits for completion events to arrive & returns the completion keys.347fn poll_unthreaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {348// SAFETY: safe because port is in scope for the duration of the call.349let packets = unsafe { poll(self.port.as_raw_descriptor())? };350let mut completion_packets = SmallVec::with_capacity(ENTRIES_PER_POLL);351for pkt in packets {352completion_packets.push(pkt);353}354Ok(completion_packets)355}356357pub fn poll(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {358if self.threaded() {359self.poll_threaded()360} else {361self.poll_unthreaded()362}363}364365/// Waits for completion events to arrive & returns the completion keys. 
Internally uses366/// GetCompletionStatusEx.367///368/// WARNING: do NOT use completion keys that are not IO handles except for INVALID_HANDLE_VALUE369/// or undefined behavior will result.370#[allow(dead_code)]371pub fn poll_ex(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {372if self.threaded() {373self.poll()374} else {375self.poll_ex_unthreaded()376}377}378379pub fn poll_ex_unthreaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {380let mut completion_packets = SmallVec::with_capacity(ENTRIES_PER_POLL);381let overlapped_entries = self.get_completion_status_ex(INFINITE)?;382383for entry in &overlapped_entries {384if entry.lpCompletionKey as RawDescriptor == INVALID_HANDLE_VALUE {385completion_packets.push(CompletionPacket {386result: Ok(0),387completion_key: entry.lpCompletionKey,388overlapped_ptr: entry.lpOverlapped as usize,389});390continue;391}392393let mut bytes_transferred = 0;394// SAFETY: trivially safe with return value checked395let success = unsafe {396GetOverlappedResult(397entry.lpCompletionKey as RawDescriptor,398entry.lpOverlapped,399&mut bytes_transferred,400// We don't need to wait because IOCP told us the IO is complete.401/* bWait= */402false as BOOL,403)404} != 0;405if success {406completion_packets.push(CompletionPacket {407result: Ok(bytes_transferred as usize),408completion_key: entry.lpCompletionKey,409overlapped_ptr: entry.lpOverlapped as usize,410});411} else {412completion_packets.push(CompletionPacket {413result: Err(SysError::last()),414completion_key: entry.lpCompletionKey,415overlapped_ptr: entry.lpOverlapped as usize,416});417}418}419Ok(completion_packets)420}421}422423/// If existing_iocp is None, will return the created IOCP.424fn create_iocp(425file: Option<&dyn AsRawDescriptor>,426existing_iocp: Option<&dyn AsRawDescriptor>,427completion_key: usize,428concurrency: u32,429) -> Result<Option<SafeDescriptor>> {430let raw_file = match file {431Some(file) => file.as_raw_descriptor(),432None => 
INVALID_HANDLE_VALUE,433};434let raw_existing_iocp = match existing_iocp {435Some(iocp) => iocp.as_raw_descriptor(),436None => null_mut(),437};438439let port =440// SAFETY:441// Safe because:442// 1. The file handle is open because we have a reference to it.443// 2. The existing IOCP (if applicable) is valid.444unsafe { CreateIoCompletionPort(raw_file, raw_existing_iocp, completion_key, concurrency) };445446if port.is_null() {447return Err(Error::IocpOperationFailed(SysError::last()));448}449450if existing_iocp.is_some() {451Ok(None)452} else {453// SAFETY:454// Safe because:455// 1. We are creating a new IOCP.456// 2. We exclusively own the handle.457// 3. The handle is valid since CreateIoCompletionPort returned without errors.458Ok(Some(unsafe { SafeDescriptor::from_raw_descriptor(port) }))459}460}461462#[cfg(test)]463mod tests {464use std::fs::File;465use std::fs::OpenOptions;466use std::os::windows::fs::OpenOptionsExt;467use std::path::PathBuf;468469use tempfile::TempDir;470use winapi::um::winbase::FILE_FLAG_OVERLAPPED;471472use super::*;473474static TEST_IO_CONCURRENCY: u32 = 4;475476fn tempfile_path() -> (PathBuf, TempDir) {477let dir = tempfile::TempDir::new().unwrap();478let mut file_path = PathBuf::from(dir.path());479file_path.push("test");480(file_path, dir)481}482483fn open_overlapped(path: &PathBuf) -> File {484OpenOptions::new()485.create(true)486.read(true)487.write(true)488.custom_flags(FILE_FLAG_OVERLAPPED)489.open(path)490.unwrap()491}492493fn basic_iocp_test_with(concurrency: u32) {494let iocp = IoCompletionPort::new(concurrency).unwrap();495let (file_path, _tmpdir) = tempfile_path();496let mut overlapped = OVERLAPPED::default();497let f = open_overlapped(&file_path);498499iocp.register_descriptor(&f).unwrap();500let buf = [0u8; 16];501// SAFETY: Safe given file is valid, buffers are allocated and initialized and return value502// is checked.503unsafe {504base::windows::write_file(&f, buf.as_ptr(), buf.len(), Some(&mut 
overlapped)).unwrap()505};506assert_eq!(iocp.poll().unwrap().len(), 1);507}508509#[test]510fn basic_iocp_test_unthreaded() {511basic_iocp_test_with(1)512}513514#[test]515fn basic_iocp_test_threaded() {516basic_iocp_test_with(TEST_IO_CONCURRENCY)517}518519fn basic_iocp_test_poll_ex(concurrency: u32) {520let iocp = IoCompletionPort::new(concurrency).unwrap();521let (file_path, _tmpdir) = tempfile_path();522let mut overlapped = OVERLAPPED::default();523let f = open_overlapped(&file_path);524525iocp.register_descriptor(&f).unwrap();526let buf = [0u8; 16];527// SAFETY: Safe given file is valid, buffers are allocated and initialized and return value528// is checked.529unsafe {530base::windows::write_file(&f, buf.as_ptr(), buf.len(), Some(&mut overlapped)).unwrap()531};532assert_eq!(iocp.poll_ex().unwrap().len(), 1);533}534535#[test]536fn basic_iocp_test_poll_ex_unthreaded() {537basic_iocp_test_poll_ex(1);538}539540#[test]541fn basic_iocp_test_poll_ex_threaded() {542basic_iocp_test_poll_ex(TEST_IO_CONCURRENCY);543}544}545546547