Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
google
GitHub Repository: google/crosvm
Path: blob/main/cros_async/src/sys/windows/io_completion_port.rs
5394 views
1
// Copyright 2022 The ChromiumOS Authors
2
// Use of this source code is governed by a BSD-style license that can be
3
// found in the LICENSE file.
4
5
//! IO completion port wrapper.
6
7
use std::collections::VecDeque;
8
use std::io;
9
use std::ptr::null_mut;
10
use std::sync::Arc;
11
use std::sync::Condvar;
12
use std::time::Duration;
13
14
use base::error;
15
use base::info;
16
use base::AsRawDescriptor;
17
use base::Error as SysError;
18
use base::Event;
19
use base::EventWaitResult;
20
use base::FromRawDescriptor;
21
use base::RawDescriptor;
22
use base::SafeDescriptor;
23
use base::WorkerThread;
24
use smallvec::smallvec;
25
use smallvec::SmallVec;
26
use sync::Mutex;
27
use winapi::shared::minwindef::BOOL;
28
use winapi::shared::minwindef::DWORD;
29
use winapi::shared::minwindef::ULONG;
30
use winapi::um::handleapi::INVALID_HANDLE_VALUE;
31
use winapi::um::ioapiset::CreateIoCompletionPort;
32
use winapi::um::ioapiset::GetOverlappedResult;
33
use winapi::um::ioapiset::GetQueuedCompletionStatus;
34
use winapi::um::ioapiset::GetQueuedCompletionStatusEx;
35
use winapi::um::ioapiset::PostQueuedCompletionStatus;
36
use winapi::um::minwinbase::LPOVERLAPPED_ENTRY;
37
use winapi::um::minwinbase::OVERLAPPED;
38
use winapi::um::minwinbase::OVERLAPPED_ENTRY;
39
use winapi::um::winbase::INFINITE;
40
41
use super::handle_executor::Error;
42
use super::handle_executor::Result;
43
44
/// The maximum number of IOCP packets accepted per poll operation.
///
/// It bounds how many packets one call to the free function `poll` drains from the port. It
/// sizes the inline storage of the `SmallVec`s returned by the polling methods. It also sets
/// how many wake packets `IoCompletionPort`'s `Drop` posts per worker thread.
const ENTRIES_PER_POLL: usize = 16;
48
49
/// A minimal version of completion packets from an IoCompletionPort.
50
pub(crate) struct CompletionPacket {
51
pub completion_key: usize,
52
pub overlapped_ptr: usize,
53
pub result: std::result::Result<usize, SysError>,
54
}
55
56
struct Port {
57
inner: RawDescriptor,
58
}
59
60
// SAFETY:
61
// Safe because the Port is dropped before IoCompletionPort goes out of scope
62
unsafe impl Send for Port {}
63
64
/// Wraps an IO Completion Port (iocp). These ports are very similar to an epoll
65
/// context on unix. Handles (equivalent to FDs) we want to wait on for
66
/// readiness are added to the port, and then the port can be waited on using a
67
/// syscall (GetQueuedCompletionStatus). IOCP is a little more flexible than
68
/// epoll because custom messages can be enqueued and received from the port
69
/// just like if a handle became ready (see [IoCompletionPort::post_status]).
70
///
71
/// Note that completion ports can only be subscribed to a handle, they
72
/// can never be unsubscribed. Handles are removed from the port automatically when they are closed.
73
///
74
/// Registered handles have their completion key set to their handle number.
75
pub(crate) struct IoCompletionPort {
76
port: SafeDescriptor,
77
threads: Vec<WorkerThread<()>>,
78
completed: Arc<(Mutex<VecDeque<CompletionPacket>>, Condvar)>,
79
concurrency: u32,
80
}
81
82
/// Gets a completion packet from the completion port. If the underlying IO operation
83
/// encountered an error, it will be contained inside the completion packet. If this method
84
/// encountered an error getting a completion packet, the error will be returned directly.
85
/// Safety: caller needs to ensure that the `handle` is valid and is for io completion port.
86
#[deny(unsafe_op_in_unsafe_fn)]
87
unsafe fn get_completion_status(
88
handle: RawDescriptor,
89
timeout: DWORD,
90
) -> io::Result<CompletionPacket> {
91
let mut bytes_transferred = 0;
92
let mut completion_key = 0;
93
// SAFETY: trivially safe
94
let mut overlapped: *mut OVERLAPPED = unsafe { std::mem::zeroed() };
95
96
// SAFETY:
97
// Safe because:
98
// 1. Memory of pointers passed is stack allocated and lives as long as the syscall.
99
// 2. We check the error so we don't use invalid output values (e.g. overlapped).
100
let success = unsafe {
101
GetQueuedCompletionStatus(
102
handle,
103
&mut bytes_transferred,
104
&mut completion_key,
105
&mut overlapped as *mut *mut OVERLAPPED,
106
timeout,
107
)
108
} != 0;
109
110
if success {
111
return Ok(CompletionPacket {
112
result: Ok(bytes_transferred as usize),
113
completion_key,
114
overlapped_ptr: overlapped as usize,
115
});
116
}
117
118
// Did the IOCP operation fail, or did the overlapped operation fail?
119
if overlapped.is_null() {
120
// IOCP failed somehow.
121
Err(io::Error::last_os_error())
122
} else {
123
// Overlapped operation failed.
124
Ok(CompletionPacket {
125
result: Err(SysError::last()),
126
completion_key,
127
overlapped_ptr: overlapped as usize,
128
})
129
}
130
}
131
132
/// Waits for completion events to arrive & returns the completion keys.
133
/// Safety: caller needs to ensure that the `handle` is valid and is for io completion port.
134
#[deny(unsafe_op_in_unsafe_fn)]
135
unsafe fn poll(port: RawDescriptor) -> Result<Vec<CompletionPacket>> {
136
let mut completion_packets = vec![];
137
completion_packets.push(
138
// SAFETY: caller has ensured that the handle is valid and is for io completion port
139
unsafe {
140
get_completion_status(port, INFINITE)
141
.map_err(|e| Error::IocpOperationFailed(SysError::from(e)))?
142
},
143
);
144
145
// Drain any waiting completion packets.
146
//
147
// Wondering why we don't use GetQueuedCompletionStatusEx instead? Well, there's no way to
148
// get detailed error information for each of the returned overlapped IO operations without
149
// calling GetOverlappedResult. If we have to do that, then it's cheaper to just get each
150
// completion packet individually.
151
while completion_packets.len() < ENTRIES_PER_POLL {
152
// SAFETY:
153
// Safety: caller has ensured that the handle is valid and is for io completion port
154
match unsafe { get_completion_status(port, 0) } {
155
Ok(pkt) => {
156
completion_packets.push(pkt);
157
}
158
Err(e) if e.kind() == io::ErrorKind::TimedOut => break,
159
Err(e) => return Err(Error::IocpOperationFailed(SysError::from(e))),
160
}
161
}
162
163
Ok(completion_packets)
164
}
165
166
/// Safety: caller needs to ensure that the `handle` is valid and is for io completion port.
167
fn iocp_waiter_thread(
168
port: Arc<Mutex<Port>>,
169
kill_evt: Event,
170
completed: Arc<(Mutex<VecDeque<CompletionPacket>>, Condvar)>,
171
) -> Result<()> {
172
let port = port.lock();
173
loop {
174
// SAFETY: caller has ensured that the handle is valid and is for io completion port
175
let packets = unsafe { poll(port.inner)? };
176
if !packets.is_empty() {
177
{
178
let mut c = completed.0.lock();
179
for packet in packets {
180
c.push_back(packet);
181
}
182
completed.1.notify_one();
183
}
184
}
185
if kill_evt
186
.wait_timeout(Duration::from_nanos(0))
187
.map_err(Error::IocpOperationFailed)?
188
== EventWaitResult::Signaled
189
{
190
return Ok(());
191
}
192
}
193
}
194
195
impl Drop for IoCompletionPort {
196
fn drop(&mut self) {
197
if !self.threaded() {
198
return;
199
}
200
201
let mut threads = std::mem::take(&mut self.threads);
202
for thread in &mut threads {
203
// let the thread know that it should exit
204
if let Err(e) = thread.signal() {
205
error!("faild to signal iocp thread: {}", e);
206
}
207
}
208
209
// interrupt all poll/get status on ports.
210
// Single thread can consume more ENTRIES_PER_POLL number of completion statuses.
211
// We send enough post_status so that all threads have enough data to be woken up by the
212
// completion ports.
213
// This is slightly unpleasant way to interrupt all the threads.
214
for _ in 0..(threads.len() * ENTRIES_PER_POLL) {
215
if let Err(e) = self.wake() {
216
error!("post_status failed during thread exit:{}", e);
217
}
218
}
219
}
220
}
221
222
impl IoCompletionPort {
223
pub fn new(concurrency: u32) -> Result<Self> {
224
let completed = Arc::new((Mutex::new(VecDeque::new()), Condvar::new()));
225
// Unwrap is safe because we're creating a new IOCP and will receive the owned handle
226
// back.
227
let port = create_iocp(None, None, 0, concurrency)?.unwrap();
228
let mut threads = vec![];
229
if concurrency > 1 {
230
info!("creating iocp with concurrency: {}", concurrency);
231
for i in 0..concurrency {
232
let completed_clone = completed.clone();
233
let port_desc = Arc::new(Mutex::new(Port {
234
inner: port.as_raw_descriptor(),
235
}));
236
threads.push(WorkerThread::start(
237
format!("overlapped_io_{i}"),
238
move |kill_evt| {
239
iocp_waiter_thread(port_desc, kill_evt, completed_clone).unwrap();
240
},
241
));
242
}
243
}
244
Ok(Self {
245
port,
246
threads,
247
completed,
248
concurrency,
249
})
250
}
251
252
fn threaded(&self) -> bool {
253
self.concurrency > 1
254
}
255
256
/// Register the provided descriptor with this completion port. Registered descriptors cannot
257
/// be deregistered. To deregister, close the descriptor.
258
pub fn register_descriptor(&self, desc: &dyn AsRawDescriptor) -> Result<()> {
259
create_iocp(
260
Some(desc),
261
Some(&self.port),
262
desc.as_raw_descriptor() as usize,
263
self.concurrency,
264
)?;
265
Ok(())
266
}
267
268
/// Posts a completion packet to the IO completion port.
269
pub fn post_status(&self, bytes_transferred: u32, completion_key: usize) -> Result<()> {
270
// SAFETY:
271
// Safe because the IOCP handle is valid.
272
let res = unsafe {
273
PostQueuedCompletionStatus(
274
self.port.as_raw_descriptor(),
275
bytes_transferred,
276
completion_key,
277
null_mut(),
278
)
279
};
280
if res == 0 {
281
return Err(Error::IocpOperationFailed(SysError::last()));
282
}
283
Ok(())
284
}
285
286
/// Wake up thread waiting on this iocp.
287
/// If there are more than one thread waiting, then you may need to call this function
288
/// multiple times.
289
pub fn wake(&self) -> Result<()> {
290
self.post_status(0, INVALID_HANDLE_VALUE as usize)
291
}
292
293
/// Get up to ENTRIES_PER_POLL completion packets from the IOCP in one shot.
294
#[allow(dead_code)]
295
fn get_completion_status_ex(
296
&self,
297
timeout: DWORD,
298
) -> Result<SmallVec<[OVERLAPPED_ENTRY; ENTRIES_PER_POLL]>> {
299
let mut overlapped_entries: SmallVec<[OVERLAPPED_ENTRY; ENTRIES_PER_POLL]> =
300
smallvec!(OVERLAPPED_ENTRY::default(); ENTRIES_PER_POLL);
301
302
let mut entries_removed: ULONG = 0;
303
// SAFETY:
304
// Safe because:
305
// 1. IOCP is guaranteed to exist by self.
306
// 2. Memory of pointers passed is stack allocated and lives as long as the syscall.
307
// 3. We check the error so we don't use invalid output values (e.g. overlapped).
308
let success = unsafe {
309
GetQueuedCompletionStatusEx(
310
self.port.as_raw_descriptor(),
311
overlapped_entries.as_mut_ptr() as LPOVERLAPPED_ENTRY,
312
ENTRIES_PER_POLL as ULONG,
313
&mut entries_removed,
314
timeout,
315
// We are normally called from a polling loop. It's more efficient (loop latency
316
// wise) to hold the thread instead of performing an alertable wait.
317
/* fAlertable= */
318
false as BOOL,
319
)
320
} != 0;
321
322
if success {
323
overlapped_entries.truncate(entries_removed as usize);
324
return Ok(overlapped_entries);
325
}
326
327
// Overlapped operation failed.
328
Err(Error::IocpOperationFailed(SysError::last()))
329
}
330
331
/// Waits for completion events to arrive & returns the completion keys.
332
fn poll_threaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {
333
let mut completion_packets = SmallVec::with_capacity(ENTRIES_PER_POLL);
334
let mut packets = self.completed.0.lock();
335
loop {
336
let len = usize::min(ENTRIES_PER_POLL, packets.len());
337
for p in packets.drain(..len) {
338
completion_packets.push(p)
339
}
340
if !completion_packets.is_empty() {
341
return Ok(completion_packets);
342
}
343
packets = self.completed.1.wait(packets).unwrap();
344
}
345
}
346
347
/// Waits for completion events to arrive & returns the completion keys.
348
fn poll_unthreaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {
349
// SAFETY: safe because port is in scope for the duration of the call.
350
let packets = unsafe { poll(self.port.as_raw_descriptor())? };
351
let mut completion_packets = SmallVec::with_capacity(ENTRIES_PER_POLL);
352
for pkt in packets {
353
completion_packets.push(pkt);
354
}
355
Ok(completion_packets)
356
}
357
358
pub fn poll(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {
359
if self.threaded() {
360
self.poll_threaded()
361
} else {
362
self.poll_unthreaded()
363
}
364
}
365
366
/// Waits for completion events to arrive & returns the completion keys. Internally uses
367
/// GetCompletionStatusEx.
368
///
369
/// WARNING: do NOT use completion keys that are not IO handles except for INVALID_HANDLE_VALUE
370
/// or undefined behavior will result.
371
#[allow(dead_code)]
372
pub fn poll_ex(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {
373
if self.threaded() {
374
self.poll()
375
} else {
376
self.poll_ex_unthreaded()
377
}
378
}
379
380
pub fn poll_ex_unthreaded(&self) -> Result<SmallVec<[CompletionPacket; ENTRIES_PER_POLL]>> {
381
let mut completion_packets = SmallVec::with_capacity(ENTRIES_PER_POLL);
382
let overlapped_entries = self.get_completion_status_ex(INFINITE)?;
383
384
for entry in &overlapped_entries {
385
if entry.lpCompletionKey as RawDescriptor == INVALID_HANDLE_VALUE {
386
completion_packets.push(CompletionPacket {
387
result: Ok(0),
388
completion_key: entry.lpCompletionKey,
389
overlapped_ptr: entry.lpOverlapped as usize,
390
});
391
continue;
392
}
393
394
let mut bytes_transferred = 0;
395
// SAFETY: trivially safe with return value checked
396
let success = unsafe {
397
GetOverlappedResult(
398
entry.lpCompletionKey as RawDescriptor,
399
entry.lpOverlapped,
400
&mut bytes_transferred,
401
// We don't need to wait because IOCP told us the IO is complete.
402
/* bWait= */
403
false as BOOL,
404
)
405
} != 0;
406
if success {
407
completion_packets.push(CompletionPacket {
408
result: Ok(bytes_transferred as usize),
409
completion_key: entry.lpCompletionKey,
410
overlapped_ptr: entry.lpOverlapped as usize,
411
});
412
} else {
413
completion_packets.push(CompletionPacket {
414
result: Err(SysError::last()),
415
completion_key: entry.lpCompletionKey,
416
overlapped_ptr: entry.lpOverlapped as usize,
417
});
418
}
419
}
420
Ok(completion_packets)
421
}
422
}
423
424
/// If existing_iocp is None, will return the created IOCP.
425
fn create_iocp(
426
file: Option<&dyn AsRawDescriptor>,
427
existing_iocp: Option<&dyn AsRawDescriptor>,
428
completion_key: usize,
429
concurrency: u32,
430
) -> Result<Option<SafeDescriptor>> {
431
let raw_file = match file {
432
Some(file) => file.as_raw_descriptor(),
433
None => INVALID_HANDLE_VALUE,
434
};
435
let raw_existing_iocp = match existing_iocp {
436
Some(iocp) => iocp.as_raw_descriptor(),
437
None => null_mut(),
438
};
439
440
let port =
441
// SAFETY:
442
// Safe because:
443
// 1. The file handle is open because we have a reference to it.
444
// 2. The existing IOCP (if applicable) is valid.
445
unsafe { CreateIoCompletionPort(raw_file, raw_existing_iocp, completion_key, concurrency) };
446
447
if port.is_null() {
448
return Err(Error::IocpOperationFailed(SysError::last()));
449
}
450
451
if existing_iocp.is_some() {
452
Ok(None)
453
} else {
454
// SAFETY:
455
// Safe because:
456
// 1. We are creating a new IOCP.
457
// 2. We exclusively own the handle.
458
// 3. The handle is valid since CreateIoCompletionPort returned without errors.
459
Ok(Some(unsafe { SafeDescriptor::from_raw_descriptor(port) }))
460
}
461
}
462
463
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::fs::OpenOptions;
    use std::os::windows::fs::OpenOptionsExt;
    use std::path::Path;
    use std::path::PathBuf;

    use tempfile::TempDir;
    use winapi::um::winbase::FILE_FLAG_OVERLAPPED;

    use super::*;

    static TEST_IO_CONCURRENCY: u32 = 4;

    /// Returns a path named "test" inside a fresh temporary directory. The directory is
    /// returned too and must be kept alive for as long as the path is used.
    fn tempfile_path() -> (PathBuf, TempDir) {
        let dir = TempDir::new().unwrap();
        let file_path = dir.path().join("test");
        (file_path, dir)
    }

    /// Opens (creating if needed) the file at `path` for overlapped read/write IO.
    fn open_overlapped(path: &Path) -> File {
        let mut options = OpenOptions::new();
        options
            .create(true)
            .read(true)
            .write(true)
            .custom_flags(FILE_FLAG_OVERLAPPED);
        options.open(path).unwrap()
    }

    /// Creates an IOCP with `concurrency`, registers a fresh overlapped file with it and
    /// issues one 16-byte write. Returns whatever `count_packets` reports for that port.
    fn write_once_and_count_packets(
        concurrency: u32,
        count_packets: impl FnOnce(&IoCompletionPort) -> usize,
    ) -> usize {
        let iocp = IoCompletionPort::new(concurrency).unwrap();
        let (file_path, _tmpdir) = tempfile_path();
        let mut overlapped = OVERLAPPED::default();
        let f = open_overlapped(&file_path);

        iocp.register_descriptor(&f).unwrap();
        let buf = [0u8; 16];
        // SAFETY: Safe given file is valid, buffers are allocated and initialized and return value
        // is checked.
        unsafe {
            base::windows::write_file(&f, buf.as_ptr(), buf.len(), Some(&mut overlapped)).unwrap()
        };
        count_packets(&iocp)
    }

    fn basic_iocp_test_with(concurrency: u32) {
        let received = write_once_and_count_packets(concurrency, |iocp| iocp.poll().unwrap().len());
        assert_eq!(received, 1);
    }

    #[test]
    fn basic_iocp_test_unthreaded() {
        basic_iocp_test_with(1)
    }

    #[test]
    fn basic_iocp_test_threaded() {
        basic_iocp_test_with(TEST_IO_CONCURRENCY)
    }

    fn basic_iocp_test_poll_ex(concurrency: u32) {
        let received =
            write_once_and_count_packets(concurrency, |iocp| iocp.poll_ex().unwrap().len());
        assert_eq!(received, 1);
    }

    #[test]
    fn basic_iocp_test_poll_ex_unthreaded() {
        basic_iocp_test_poll_ex(1);
    }

    #[test]
    fn basic_iocp_test_poll_ex_threaded() {
        basic_iocp_test_poll_ex(TEST_IO_CONCURRENCY);
    }
}
546
547