Book a Demo!
CoCalc Logo Icon
StoreFeaturesDocsShareSupportNewsAboutPoliciesSign UpSign In
google
GitHub Repository: google/crosvm
Path: blob/main/cros_async/src/sys/windows/overlapped_source.rs
5394 views
1
// Copyright 2023 The ChromiumOS Authors
2
// Use of this source code is governed by a BSD-style license that can be
3
// found in the LICENSE file.
4
5
//! EXPERIMENTAL overlapped IO based async IO wrapper. Do not use in production.
6
7
use std::fs::File;
8
use std::io;
9
use std::io::Write;
10
use std::mem::ManuallyDrop;
11
use std::sync::Arc;
12
13
use base::error;
14
use base::AsRawDescriptor;
15
use base::Descriptor;
16
use base::FromRawDescriptor;
17
use base::PunchHole;
18
use base::RawDescriptor;
19
use base::WriteZeroesAt;
20
use thiserror::Error as ThisError;
21
use winapi::um::minwinbase::OVERLAPPED;
22
23
use crate::common_executor::RawExecutor;
24
use crate::mem::BackingMemory;
25
use crate::mem::MemRegion;
26
use crate::sys::windows::handle_executor::HandleReactor;
27
use crate::sys::windows::handle_executor::RegisteredOverlappedSource;
28
use crate::AsyncError;
29
use crate::AsyncResult;
30
use crate::BlockingPool;
31
32
#[derive(ThisError, Debug)]
33
pub enum Error {
34
#[error("An error occurred trying to get a VolatileSlice into BackingMemory: {0}.")]
35
BackingMemoryVolatileSliceFetchFailed(crate::mem::Error),
36
#[error("An error occurred trying to seek: {0}.")]
37
IoSeekError(io::Error),
38
#[error("An error occurred trying to read: {0}.")]
39
IoReadError(base::Error),
40
#[error("An error occurred trying to write: {0}.")]
41
IoWriteError(base::Error),
42
#[error("An error occurred trying to flush: {0}.")]
43
IoFlushError(io::Error),
44
#[error("An error occurred trying to punch hole: {0}.")]
45
IoPunchHoleError(io::Error),
46
#[error("An error occurred trying to write zeroes: {0}.")]
47
IoWriteZeroesError(io::Error),
48
#[error("An error occurred trying to duplicate source handles: {0}.")]
49
HandleDuplicationFailed(io::Error),
50
#[error("A IO error occurred trying to read: {0}.")]
51
StdIoReadError(io::Error),
52
#[error("An IO error occurred trying to write: {0}.")]
53
StdIoWriteError(io::Error),
54
}
55
56
impl From<Error> for io::Error {
57
fn from(e: Error) -> Self {
58
use Error::*;
59
match e {
60
BackingMemoryVolatileSliceFetchFailed(e) => io::Error::other(e),
61
IoSeekError(e) => e,
62
IoReadError(e) => e.into(),
63
IoWriteError(e) => e.into(),
64
IoFlushError(e) => e,
65
IoPunchHoleError(e) => e,
66
IoWriteZeroesError(e) => e,
67
HandleDuplicationFailed(e) => e,
68
StdIoReadError(e) => e,
69
StdIoWriteError(e) => e,
70
}
71
}
72
}
73
74
impl From<Error> for AsyncError {
    fn from(e: Error) -> AsyncError {
        // Wraps this backend's error in the executor-generic error type via `into()`.
        // NOTE(review): the concrete conversion target depends on the payload type of
        // `AsyncError::SysVariants` (declared elsewhere) — presumably it routes through
        // the `From<Error> for io::Error` impl in this file; verify against crate::AsyncError.
        AsyncError::SysVariants(e.into())
    }
}
79
80
/// Result alias for overlapped-source operations in this module.
pub type Result<T> = std::result::Result<T, Error>;
81
82
/// Async IO source for Windows that uses a multi-threaded, multi-handle approach to provide fast IO
83
/// operations. It demuxes IO requests across a set of handles that refer to the same underlying IO
84
/// source, such as a file, and executes those requests across multiple threads. Benchmarks show
85
/// that this is the fastest method to perform IO on Windows, especially for file reads.
86
pub struct OverlappedSource<F: AsRawDescriptor> {
    // Runs synchronous operations (currently only flush in `fsync`) off the
    // executor thread.
    blocking_pool: BlockingPool,
    // Registration of `source` with the executor's reactor; overlapped operations
    // are issued through this.
    reg_source: RegisteredOverlappedSource,
    // The underlying IO source. MUST have been opened in overlapped mode (see `new`).
    source: F,
    // True for non-seekable sources (e.g. named pipes). When set, operations that
    // take an explicit file offset are rejected with `Error::IoSeekError`.
    seek_forbidden: bool,
}
92
93
impl<F: AsRawDescriptor> OverlappedSource<F> {
    /// Create a new `OverlappedSource` from the given IO source. The source MUST be opened in
    /// overlapped mode or undefined behavior will result.
    ///
    /// seek_forbidden should be set for non seekable types like named pipes.
    ///
    /// # Errors
    ///
    /// Fails if the source cannot be registered with the executor's reactor.
    pub fn new(
        source: F,
        ex: &Arc<RawExecutor<HandleReactor>>,
        seek_forbidden: bool,
    ) -> AsyncResult<Self> {
        Ok(Self {
            blocking_pool: BlockingPool::default(),
            // Register the handle with the reactor before `source` is moved into Self.
            reg_source: ex.reactor.register_overlapped_source(ex, &source)?,
            source,
            seek_forbidden,
        })
    }
}
111
112
/// SAFETY:
113
/// Safety requirements:
114
/// Same as base::windows::read_file.
115
unsafe fn read(
116
file: RawDescriptor,
117
buf: *mut u8,
118
buf_len: usize,
119
overlapped: &mut OVERLAPPED,
120
) -> AsyncResult<()> {
121
Ok(
122
base::windows::read_file(&Descriptor(file), buf, buf_len, Some(overlapped))
123
.map(|_len| ())
124
.map_err(Error::StdIoReadError)?,
125
)
126
}
127
128
/// SAFETY:
129
/// Safety requirements:
130
/// Same as base::windows::write_file.
131
unsafe fn write(
132
file: RawDescriptor,
133
buf: *const u8,
134
buf_len: usize,
135
overlapped: &mut OVERLAPPED,
136
) -> AsyncResult<()> {
137
Ok(
138
base::windows::write_file(&Descriptor(file), buf, buf_len, Some(overlapped))
139
.map(|_len| ())
140
.map_err(Error::StdIoWriteError)?,
141
)
142
}
143
144
impl<F: AsRawDescriptor> OverlappedSource<F> {
145
/// Reads from the iosource at `file_offset` and fill the given `vec`.
///
/// Returns the number of bytes read and the buffer. Passing `Some(offset)` on a
/// non-seekable source fails with `Error::IoSeekError`.
pub async fn read_to_vec(
    &self,
    file_offset: Option<u64>,
    mut vec: Vec<u8>,
) -> AsyncResult<(usize, Vec<u8>)> {
    // Explicit offsets are meaningless on non-seekable handles (e.g. named pipes).
    if self.seek_forbidden && file_offset.is_some() {
        return Err(Error::IoSeekError(io::Error::new(
            io::ErrorKind::InvalidInput,
            "seek on non-seekable handle",
        ))
        .into());
    }
    // Register an OVERLAPPED operation with the reactor; awaiting it yields the
    // completion result once the IO finishes.
    let mut overlapped_op = self.reg_source.register_overlapped_operation(file_offset)?;

    // SAFETY:
    // Safe because we pass a pointer to a valid vec and that same vector's length.
    unsafe {
        read(
            self.source.as_raw_descriptor(),
            vec.as_mut_ptr(),
            vec.len(),
            overlapped_op.get_overlapped(),
        )?
    };
    // Wait for completion, then surface the per-operation IO result (bytes read).
    let overlapped_result = overlapped_op.await?;
    let bytes_read = overlapped_result.result.map_err(Error::IoReadError)?;
    Ok((bytes_read, vec))
}
174
175
/// Reads to the given `mem` at the given offsets from the file starting at `file_offset`.
///
/// Issues one overlapped read per `MemRegion`, sequentially, and returns the total
/// number of bytes read.
pub async fn read_to_mem(
    &self,
    file_offset: Option<u64>,
    mem: Arc<dyn BackingMemory + Send + Sync>,
    mem_offsets: impl IntoIterator<Item = MemRegion>,
) -> AsyncResult<usize> {
    let mut total_bytes_read = 0;
    let mut offset = match file_offset {
        Some(offset) if !self.seek_forbidden => Some(offset),
        None if self.seek_forbidden => None,
        // For devices that are seekable (files), we have to track the offset otherwise
        // subsequent read calls will just read the same bytes into each of the memory regions.
        None => Some(0),
        // Remaining case: Some(offset) with seek_forbidden set — an explicit offset
        // on a non-seekable handle is an error.
        _ => {
            return Err(Error::IoSeekError(io::Error::new(
                io::ErrorKind::InvalidInput,
                "seek on non-seekable handle",
            ))
            .into())
        }
    };

    for region in mem_offsets.into_iter() {
        // Each region is a separate overlapped operation at the tracked offset.
        let mut overlapped_op = self.reg_source.register_overlapped_operation(offset)?;

        let slice = mem
            .get_volatile_slice(region)
            .map_err(Error::BackingMemoryVolatileSliceFetchFailed)?;

        // SAFETY:
        // Safe because we're passing a volatile slice (valid ptr), and the size of the memory
        // region it refers to.
        unsafe {
            read(
                self.source.as_raw_descriptor(),
                slice.as_mut_ptr(),
                slice.size(),
                overlapped_op.get_overlapped(),
            )?
        };
        let overlapped_result = overlapped_op.await?;
        let bytes_read = overlapped_result.result.map_err(Error::IoReadError)?;
        // Advance the tracked offset (no-op for non-seekable sources where it is None).
        offset = offset.map(|offset| offset + bytes_read as u64);
        total_bytes_read += bytes_read;
    }
    Ok(total_bytes_read)
}
223
224
/// Wait for the handle of `self` to be readable.
///
/// Always panics: not implemented for the overlapped backend.
pub async fn wait_readable(&self) -> AsyncResult<()> {
    unimplemented!()
}
228
229
/// Reads a single u64 from the current offset.
///
/// Always panics: not implemented for the overlapped backend.
pub async fn read_u64(&self) -> AsyncResult<u64> {
    unimplemented!()
}
233
234
/// Writes from the given `vec` to the file starting at `file_offset`.
///
/// Returns the number of bytes written and the buffer. Passing `Some(offset)` on a
/// non-seekable source fails with `Error::IoSeekError`.
pub async fn write_from_vec(
    &self,
    file_offset: Option<u64>,
    vec: Vec<u8>,
) -> AsyncResult<(usize, Vec<u8>)> {
    // Explicit offsets are meaningless on non-seekable handles (e.g. named pipes).
    if self.seek_forbidden && file_offset.is_some() {
        return Err(Error::IoSeekError(io::Error::new(
            io::ErrorKind::InvalidInput,
            "seek on non-seekable handle",
        ))
        .into());
    }
    // Register an OVERLAPPED operation with the reactor; awaiting it yields the
    // completion result once the IO finishes.
    let mut overlapped_op = self.reg_source.register_overlapped_operation(file_offset)?;

    // SAFETY:
    // Safe because we pass a pointer to a valid vec and that same vector's length.
    unsafe {
        write(
            self.source.as_raw_descriptor(),
            vec.as_ptr(),
            vec.len(),
            overlapped_op.get_overlapped(),
        )?
    };

    // Wait for completion, then surface the per-operation IO result (bytes written).
    let bytes_written = overlapped_op.await?.result.map_err(Error::IoWriteError)?;
    Ok((bytes_written, vec))
}
263
264
/// Writes from the given `mem` from the given offsets to the file starting at `file_offset`.
265
pub async fn write_from_mem(
266
&self,
267
file_offset: Option<u64>,
268
mem: Arc<dyn BackingMemory + Send + Sync>,
269
mem_offsets: impl IntoIterator<Item = MemRegion>,
270
) -> AsyncResult<usize> {
271
let mut total_bytes_written = 0;
272
let mut offset = match file_offset {
273
Some(offset) if !self.seek_forbidden => Some(offset),
274
None if self.seek_forbidden => None,
275
// For devices that are seekable (files), we have to track the offset otherwise
276
// subsequent read calls will just read the same bytes into each of the memory regions.
277
None => Some(0),
278
_ => {
279
return Err(Error::IoSeekError(io::Error::new(
280
io::ErrorKind::InvalidInput,
281
"seek on non-seekable handle",
282
))
283
.into())
284
}
285
};
286
287
for region in mem_offsets.into_iter() {
288
let mut overlapped_op = self.reg_source.register_overlapped_operation(offset)?;
289
290
let slice = mem
291
.get_volatile_slice(region)
292
.map_err(Error::BackingMemoryVolatileSliceFetchFailed)?;
293
294
// SAFETY:
295
// Safe because we're passing a volatile slice (valid ptr), and the size of the memory
296
// region it refers to.
297
unsafe {
298
write(
299
self.source.as_raw_descriptor(),
300
slice.as_ptr(),
301
slice.size(),
302
overlapped_op.get_overlapped(),
303
)?
304
};
305
let bytes_written = overlapped_op.await?.result.map_err(Error::IoReadError)?;
306
offset = offset.map(|offset| offset + bytes_written as u64);
307
total_bytes_written += bytes_written;
308
}
309
Ok(total_bytes_written)
310
}
311
312
/// Deallocates the given range of a file.
///
/// TODO(nkgold): currently this is sync on the executor, which is bad / very hacky. With a
/// little wrapper work, we can make overlapped DeviceIoControl calls instead.
pub async fn punch_hole(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
    // Hole punching is offset-based, so it cannot work on non-seekable handles.
    if self.seek_forbidden {
        return Err(Error::IoSeekError(io::Error::new(
            io::ErrorKind::InvalidInput,
            "fallocate cannot be called on a non-seekable handle",
        ))
        .into());
    }
    // SAFETY:
    // Safe because self.source lives as long as file.
    // ManuallyDrop prevents the temporary `File` from closing the descriptor it
    // borrows from `self.source`.
    let file = ManuallyDrop::new(unsafe {
        File::from_raw_descriptor(self.source.as_raw_descriptor())
    });
    file.punch_hole(file_offset, len)
        .map_err(Error::IoPunchHoleError)?;
    Ok(())
}
333
334
/// Fills the given range with zeroes.
///
/// TODO(nkgold): currently this is sync on the executor, which is bad / very hacky. With a
/// little wrapper work, we can make overlapped DeviceIoControl calls instead.
pub async fn write_zeroes_at(&self, file_offset: u64, len: u64) -> AsyncResult<()> {
    // Zeroing is offset-based, so it cannot work on non-seekable handles.
    if self.seek_forbidden {
        return Err(Error::IoSeekError(io::Error::new(
            io::ErrorKind::InvalidInput,
            "write_zeroes_at cannot be called on a non-seekable handle",
        ))
        .into());
    }
    // SAFETY:
    // Safe because self.source lives as long as file.
    // ManuallyDrop prevents the temporary `File` from closing the descriptor it
    // borrows from `self.source`.
    let file = ManuallyDrop::new(unsafe {
        File::from_raw_descriptor(self.source.as_raw_descriptor())
    });
    // ZeroRange calls `punch_hole` which doesn't extend the File size if it needs to.
    // Will fix if it becomes a problem.
    file.write_zeroes_at(file_offset, len as usize)
        .map_err(Error::IoWriteZeroesError)?;
    Ok(())
}
357
358
/// Sync all completed write operations to the backing storage.
///
/// The flush runs on the blocking pool so the executor thread is not blocked.
pub async fn fsync(&self) -> AsyncResult<()> {
    // SAFETY:
    // Safe because self.source lives at least as long as the blocking pool thread. Note that
    // if the blocking pool stalls and shutdown fails, the thread could outlive the file;
    // however, this would mean things are already badly broken and we have a similar risk in
    // HandleSource.
    //
    // `try_clone` duplicates the handle so the clone can be moved into the pool task;
    // ManuallyDrop keeps the borrowed original descriptor from being closed here.
    let mut file = unsafe {
        ManuallyDrop::new(File::from_raw_descriptor(self.source.as_raw_descriptor()))
            .try_clone()
            .map_err(Error::HandleDuplicationFailed)?
    };

    Ok(self
        .blocking_pool
        .spawn(move || file.flush().map_err(Error::IoFlushError))
        .await?)
}
376
377
/// Sync all data of completed write operations to the backing storage. Currently, the
/// implementation is equivalent to fsync.
pub async fn fdatasync(&self) -> AsyncResult<()> {
    // TODO(b/282003931): Provide a true fdatasync; until then this falls back to
    // the regular fsync implementation above.
    self.fsync().await
}
383
384
/// Yields the underlying IO source.
pub fn into_source(self) -> F {
    self.source
}

/// Provides a mutable ref to the underlying IO source.
pub fn as_source_mut(&mut self) -> &mut F {
    &mut self.source
}

/// Provides a ref to the underlying IO source.
///
/// In the multi-source case, the 0th source will be returned. If sources are not
/// interchangeable, behavior is undefined.
// NOTE(review): this struct holds a single `source`; the "multi-source" wording above
// looks inherited from another backend (e.g. HandleSource) — confirm and update.
pub fn as_source(&self) -> &F {
    &self.source
}
401
402
/// Resolves once a single-object wait on the underlying handle completes
/// (delegates to `base::sys::windows::async_wait_for_single_object`).
pub async fn wait_for_handle(&self) -> AsyncResult<()> {
    base::sys::windows::async_wait_for_single_object(&self.source)
        .await
        // Reuse the HandleSource error variant for wait failures.
        .map_err(super::handle_source::Error::HandleWaitFailed)?;
    Ok(())
}
408
}
409
410
// NOTE: Prefer adding tests to io_source.rs if not backend specific.
411
#[cfg(test)]
412
mod tests {
413
use std::fs::OpenOptions;
414
use std::io::Read;
415
use std::os::windows::fs::OpenOptionsExt;
416
use std::path::PathBuf;
417
418
use tempfile::TempDir;
419
use winapi::um::winbase::FILE_FLAG_OVERLAPPED;
420
421
use super::*;
422
use crate::mem::VecIoWrapper;
423
use crate::ExecutorTrait;
424
425
fn tempfile_path() -> (PathBuf, TempDir) {
426
let dir = tempfile::TempDir::new().unwrap();
427
let mut file_path = PathBuf::from(dir.path());
428
file_path.push("test");
429
(file_path, dir)
430
}
431
432
fn open_overlapped(path: &PathBuf) -> File {
433
OpenOptions::new()
434
.read(true)
435
.write(true)
436
.custom_flags(FILE_FLAG_OVERLAPPED)
437
.open(path)
438
.unwrap()
439
}
440
441
fn create_overlapped(path: &PathBuf) -> File {
442
OpenOptions::new()
443
.create_new(true)
444
.read(true)
445
.write(true)
446
.custom_flags(FILE_FLAG_OVERLAPPED)
447
.open(path)
448
.unwrap()
449
}
450
451
#[test]
fn test_read_vec() {
    let (file_path, _tmpdir) = tempfile_path();
    // Seed the file with known contents to read back.
    std::fs::write(&file_path, "data").unwrap();

    async fn read_data(src: &OverlappedSource<File>) {
        let buf: Vec<u8> = vec![0; 4];
        let (bytes_read, buf) = src.read_to_vec(Some(0), buf).await.unwrap();
        assert_eq!(bytes_read, 4);
        assert_eq!(std::str::from_utf8(buf.as_slice()).unwrap(), "data");
    }

    let ex = RawExecutor::<HandleReactor>::new().unwrap();
    // seek_forbidden = false: files are seekable, so explicit offsets are allowed.
    let src = OverlappedSource::new(open_overlapped(&file_path), &ex, false).unwrap();
    ex.run_until(read_data(&src)).unwrap();
}
467
468
#[test]
fn test_read_mem() {
    let (file_path, _tmpdir) = tempfile_path();
    // Seed the file with known contents to read back.
    std::fs::write(&file_path, "data").unwrap();

    async fn read_data(src: &OverlappedSource<File>) {
        let mem = Arc::new(VecIoWrapper::from(vec![0; 4]));
        // Split the 4-byte read across two 2-byte regions to exercise the
        // offset-tracking loop in read_to_mem.
        let bytes_read = src
            .read_to_mem(
                Some(0),
                Arc::<VecIoWrapper>::clone(&mem),
                [
                    MemRegion { offset: 0, len: 2 },
                    MemRegion { offset: 2, len: 2 },
                ],
            )
            .await
            .unwrap();
        assert_eq!(bytes_read, 4);
        // Unwrap the Arc to get the backing Vec and check the bytes landed in order.
        let vec: Vec<u8> = match Arc::try_unwrap(mem) {
            Ok(v) => v.into(),
            Err(_) => panic!("Too many vec refs"),
        };
        assert_eq!(std::str::from_utf8(vec.as_slice()).unwrap(), "data");
    }

    let ex = RawExecutor::<HandleReactor>::new().unwrap();
    let src = OverlappedSource::new(open_overlapped(&file_path), &ex, false).unwrap();
    ex.run_until(read_data(&src)).unwrap();
}
498
499
#[test]
fn test_write_vec() {
    let (file_path, _tmpdir) = tempfile_path();

    async fn write_data(src: &OverlappedSource<File>) {
        let mut buf: Vec<u8> = Vec::new();
        buf.extend_from_slice("data".as_bytes());

        let (bytes_written, _) = src.write_from_vec(Some(0), buf).await.unwrap();
        assert_eq!(bytes_written, 4);
    }

    let ex = RawExecutor::<HandleReactor>::new().unwrap();
    let f = create_overlapped(&file_path);
    let src = OverlappedSource::new(f, &ex, false).unwrap();
    ex.run_until(write_data(&src)).unwrap();
    // Close the overlapped handle before re-reading the file through std.
    drop(src);

    let buf = std::fs::read(&file_path).unwrap();
    assert_eq!(buf, b"data");
}
520
521
#[test]
fn test_write_mem() {
    let (file_path, _tmpdir) = tempfile_path();

    async fn write_data(src: &OverlappedSource<File>) {
        let mut buf: Vec<u8> = Vec::new();
        buf.extend_from_slice("data".as_bytes());
        let mem = Arc::new(VecIoWrapper::from(buf));
        // Split the 4-byte write across two 2-byte regions to exercise the
        // offset-tracking loop in write_from_mem.
        let bytes_written = src
            .write_from_mem(
                Some(0),
                Arc::<VecIoWrapper>::clone(&mem),
                [
                    MemRegion { offset: 0, len: 2 },
                    MemRegion { offset: 2, len: 2 },
                ],
            )
            .await
            .unwrap();
        assert_eq!(bytes_written, 4);
        // Verify no lingering references hold the backing memory.
        match Arc::try_unwrap(mem) {
            Ok(_) => (),
            Err(_) => panic!("Too many vec refs"),
        };
    }

    let ex = RawExecutor::<HandleReactor>::new().unwrap();
    let f = create_overlapped(&file_path);
    let src = OverlappedSource::new(f, &ex, false).unwrap();
    ex.run_until(write_data(&src)).unwrap();
    // Close the overlapped handle before re-reading the file through std.
    drop(src);

    let buf = std::fs::read(&file_path).unwrap();
    assert_eq!(buf, b"data");
}
556
557
#[cfg_attr(all(target_os = "windows", target_env = "gnu"), ignore)]
#[test]
fn test_punch_holes() {
    let (file_path, _tmpdir) = tempfile_path();
    std::fs::write(&file_path, "abcdefghijk").unwrap();

    async fn punch_hole(src: &OverlappedSource<File>) {
        let offset = 1;
        let len = 3;
        src.punch_hole(offset, len).await.unwrap();
    }

    let ex = RawExecutor::<HandleReactor>::new().unwrap();
    let f = open_overlapped(&file_path);
    let src = OverlappedSource::new(f, &ex, false).unwrap();
    ex.run_until(punch_hole(&src)).unwrap();
    // Close the overlapped handle before re-reading the file through std.
    drop(src);

    // Bytes 1..4 ("bcd") should read back as zeroes; the rest is untouched.
    let buf = std::fs::read(&file_path).unwrap();
    assert_eq!(buf, b"a\0\0\0efghijk");
}
578
579
/// Test should fail because punch hole should not be allowed to allocate more memory
#[cfg_attr(all(target_os = "windows", target_env = "gnu"), ignore)]
#[test]
fn test_punch_holes_fail_out_of_bounds() {
    let (file_path, _tmpdir) = tempfile_path();
    // 11-byte file; punching offset 9 + len 4 would reach offset 13, past EOF.
    std::fs::write(&file_path, "abcdefghijk").unwrap();

    async fn punch_hole(src: &OverlappedSource<File>) {
        let offset = 9;
        let len = 4;
        src.punch_hole(offset, len).await.unwrap();
    }

    let ex = RawExecutor::<HandleReactor>::new().unwrap();
    let f = open_overlapped(&file_path);
    let src = OverlappedSource::new(f, &ex, false).unwrap();
    ex.run_until(punch_hole(&src)).unwrap();
    // Close the overlapped handle before re-reading the file through std.
    drop(src);

    // A 13-byte read_exact only succeeds if punch_hole (incorrectly) extended the
    // file to cover the punched range; on the original 11-byte file it must fail.
    let mut buf = vec![0; 13];
    let mut f = OpenOptions::new()
        .read(true)
        .write(true)
        .open(&file_path)
        .unwrap();
    assert!(f.read_exact(&mut buf).is_err());
}
606
607
// TODO(b/194338842): "ZeroRange" is supposed to allocate more memory if it goes out of the
608
// bounds of the file. Determine if we need to support this, since Windows doesn't do this yet.
609
// use tempfile::NamedTempFile;
610
// #[test]
611
// fn test_write_zeroes() {
612
// let mut temp_file = NamedTempFile::new().unwrap();
613
// temp_file.write("abcdefghijk".as_bytes()).unwrap();
614
// temp_file.flush().unwrap();
615
// temp_file.seek(SeekFrom::Start(0)).unwrap();
616
617
// async fn punch_hole(src: &OverlappedSource<File>) {
618
// let offset = 9;
619
// let len = 4;
620
// src
621
// .fallocate(offset, len, AllocateMode::ZeroRange)
622
// .await
623
// .unwrap();
624
// }
625
626
// let ex = RawExecutor::<HandleReactor>::new();
627
// let f = fs::OpenOptions::new()
628
// .write(true)
629
// .open(temp_file.path())
630
// .unwrap();
631
// let src = OverlappedSource::new(vec![f].into_boxed_slice()).unwrap();
632
// ex.run_until(punch_hole(&src)).unwrap();
633
634
// let mut buf = vec![0; 13];
635
// temp_file.read_exact(&mut buf).unwrap();
636
// assert_eq!(
637
// std::str::from_utf8(buf.as_slice()).unwrap(),
638
// "abcdefghi\0\0\0\0"
639
// );
640
// }
641
}
642
643