use std::cmp::min;
use std::mem::size_of;
use std::ptr::copy;
use std::ptr::read_volatile;
use std::ptr::write_bytes;
use std::ptr::write_volatile;
use std::result;
use std::slice;
use remain::sorted;
use thiserror::Error;
use zerocopy::FromBytes;
use zerocopy::IntoBytes;
use crate::IoBufMut;
/// Errors that can occur when creating or manipulating volatile memory
/// ranges.
#[sorted]
#[derive(Error, Eq, PartialEq, Debug)]
pub enum VolatileMemoryError {
    /// The requested address `addr` lies outside the memory region.
    #[error("address 0x{addr:x} is out of bounds")]
    OutOfBounds { addr: usize },
    /// Computing `base + offset` would overflow `usize`.
    #[error("address 0x{base:x} offset by 0x{offset:x} would overflow")]
    Overflow { base: usize, offset: usize },
}
/// Result type returned by volatile memory operations.
pub type VolatileMemoryResult<T> = result::Result<T, VolatileMemoryError>;

// Short aliases used internally by this module.
use crate::VolatileMemoryError as Error;
type Result<T> = VolatileMemoryResult<T>;
/// A region of memory that can hand out volatile-access views into itself.
pub trait VolatileMemory {
    /// Gets a `VolatileSlice` covering `count` bytes starting at `offset`,
    /// or an error if the range is out of bounds or overflows.
    fn get_slice(&self, offset: usize, count: usize) -> Result<VolatileSlice>;
}
/// A slice of raw memory accessed with volatile operations.
// repr(transparent) over IoBufMut is relied upon by as_iobufs /
// as_iobufs_mut, which reinterpret &[VolatileSlice] as &[IoBufMut].
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct VolatileSlice<'a>(IoBufMut<'a>);
impl<'a> VolatileSlice<'a> {
    /// Creates a volatile view over the given byte buffer.
    pub fn new(buf: &mut [u8]) -> VolatileSlice {
        VolatileSlice(IoBufMut::new(buf))
    }

    /// Creates a `VolatileSlice` from a raw pointer and a length.
    ///
    /// # Safety
    ///
    /// `addr` must be valid for reads and writes of `len` bytes for the whole
    /// of lifetime `'a`, and no Rust reference may alias that memory for
    /// writes during that time.
    pub unsafe fn from_raw_parts(addr: *mut u8, len: usize) -> VolatileSlice<'a> {
        VolatileSlice(IoBufMut::from_raw_parts(addr, len))
    }

    /// Returns a const pointer to the start of the slice.
    pub fn as_ptr(&self) -> *const u8 {
        self.0.as_ptr()
    }

    /// Returns a mutable pointer to the start of the slice.
    pub fn as_mut_ptr(&self) -> *mut u8 {
        self.0.as_mut_ptr()
    }

    /// Returns the length of the slice in bytes.
    pub fn size(&self) -> usize {
        self.0.len()
    }

    /// Moves the start of the slice forward by `count` bytes, shrinking it.
    // NOTE(review): behavior when count > size() is whatever
    // IoBufMut::advance does — confirm its contract before relying on it.
    pub fn advance(&mut self, count: usize) {
        self.0.advance(count)
    }

    /// Shortens the slice to at most `len` bytes (delegates to
    /// `IoBufMut::truncate`).
    pub fn truncate(&mut self, len: usize) {
        self.0.truncate(len)
    }

    /// Borrows the underlying `IoBufMut`.
    pub fn as_iobuf(&self) -> &IoBufMut {
        &self.0
    }

    /// Reinterprets a slice of `VolatileSlice`s as a slice of `IoBufMut`s.
    #[allow(clippy::wrong_self_convention)]
    pub fn as_iobufs<'mem, 'slice>(
        iovs: &'slice [VolatileSlice<'mem>],
    ) -> &'slice [IoBufMut<'mem>] {
        // SAFETY: VolatileSlice is #[repr(transparent)] over IoBufMut, so the
        // two slice types have identical layout.
        unsafe { slice::from_raw_parts(iovs.as_ptr() as *const IoBufMut, iovs.len()) }
    }

    /// Mutable variant of [`VolatileSlice::as_iobufs`].
    #[inline]
    pub fn as_iobufs_mut<'mem, 'slice>(
        iovs: &'slice mut [VolatileSlice<'mem>],
    ) -> &'slice mut [IoBufMut<'mem>] {
        // SAFETY: same layout argument as as_iobufs, via #[repr(transparent)].
        unsafe { slice::from_raw_parts_mut(iovs.as_mut_ptr() as *mut IoBufMut, iovs.len()) }
    }

    /// Returns the subslice that skips the first `count` bytes.
    ///
    /// # Errors
    ///
    /// `Overflow` if the new base address would overflow `usize`;
    /// `OutOfBounds` if `count` exceeds the slice's size.
    pub fn offset(self, count: usize) -> Result<VolatileSlice<'a>> {
        let new_addr = (self.as_mut_ptr() as usize).checked_add(count).ok_or(
            VolatileMemoryError::Overflow {
                base: self.as_mut_ptr() as usize,
                offset: count,
            },
        )?;
        let new_size = self
            .size()
            .checked_sub(count)
            .ok_or(VolatileMemoryError::OutOfBounds { addr: new_addr })?;
        // SAFETY: [new_addr, new_addr + new_size) lies entirely within this
        // slice, which is valid for 'a.
        unsafe { Ok(VolatileSlice::from_raw_parts(new_addr as *mut u8, new_size)) }
    }

    /// Returns the `count`-byte subslice starting at `offset`.
    ///
    /// # Errors
    ///
    /// `Overflow` if `offset + count` or the new base address overflows;
    /// `OutOfBounds` if the requested range extends past the end.
    pub fn sub_slice(self, offset: usize, count: usize) -> Result<VolatileSlice<'a>> {
        let mem_end = offset
            .checked_add(count)
            .ok_or(VolatileMemoryError::Overflow {
                base: offset,
                offset: count,
            })?;
        if mem_end > self.size() {
            return Err(Error::OutOfBounds { addr: mem_end });
        }
        let new_addr = (self.as_mut_ptr() as usize).checked_add(offset).ok_or(
            VolatileMemoryError::Overflow {
                base: self.as_mut_ptr() as usize,
                offset,
            },
        )?;
        // SAFETY: the bounds checks above guarantee the new range lies within
        // this slice.
        Ok(unsafe { VolatileSlice::from_raw_parts(new_addr as *mut u8, count) })
    }

    /// Fills every byte of the slice with `value`.
    pub fn write_bytes(&self, value: u8) {
        // SAFETY: the slice is valid for writes of size() bytes.
        unsafe {
            write_bytes(self.as_mut_ptr(), value, self.size());
        }
    }

    /// Copies elements from this slice into `buf`, stopping when either is
    /// exhausted (at most `size() / size_of::<T>()` elements are read).
    pub fn copy_to<T>(&self, buf: &mut [T])
    where
        T: FromBytes + IntoBytes + Copy,
    {
        let mut addr = self.as_mut_ptr() as *const u8;
        for v in buf.iter_mut().take(self.size() / size_of::<T>()) {
            // SAFETY: `take` limits reads to whole T values inside the slice.
            // NOTE(review): `addr` is not necessarily aligned for T, but
            // read_volatile requires an aligned pointer — confirm callers
            // only use T with compatible alignment, or switch to an
            // unaligned read.
            unsafe {
                *v = read_volatile(addr as *const T);
                addr = addr.add(size_of::<T>());
            }
        }
    }

    /// Copies `min(self.size(), slice.size())` bytes from this slice into
    /// `slice`. Uses `ptr::copy` (memmove), so overlapping regions are fine.
    pub fn copy_to_volatile_slice(&self, slice: VolatileSlice) {
        // SAFETY: the length is clamped to the smaller of the two slices, and
        // both are valid for that many bytes.
        unsafe {
            copy(
                self.as_mut_ptr() as *const u8,
                slice.as_mut_ptr(),
                min(self.size(), slice.size()),
            );
        }
    }

    /// Copies elements from `buf` into this slice, stopping when either is
    /// exhausted (at most `size() / size_of::<T>()` elements are written).
    pub fn copy_from<T>(&self, buf: &[T])
    where
        T: IntoBytes + Copy,
    {
        let mut addr = self.as_mut_ptr();
        for v in buf.iter().take(self.size() / size_of::<T>()) {
            // SAFETY: `take` limits writes to whole T values inside the
            // slice.
            // NOTE(review): same unaligned-pointer caveat as `copy_to`.
            unsafe {
                write_volatile(addr as *mut T, *v);
                addr = addr.add(size_of::<T>());
            }
        }
    }

    /// Returns true if every byte of the slice is zero.
    ///
    /// Scans 16-byte-aligned chunks as `u128` for speed and falls back to a
    /// byte-wise scan for the unaligned head and tail — or for the whole
    /// slice when it contains no aligned 16-byte chunk.
    pub fn is_all_zero(&self) -> bool {
        const MASK_4BIT: usize = 0x0f;
        let head_addr = self.as_ptr() as usize;
        // First 16-byte-aligned address at or after the start of the slice.
        let aligned_head_addr = (head_addr + MASK_4BIT) & !MASK_4BIT;
        let tail_addr = head_addr + self.size();
        // Last 16-byte-aligned address at or before the end of the slice.
        let aligned_tail_addr = tail_addr & !MASK_4BIT;

        if aligned_head_addr >= aligned_tail_addr {
            // The slice contains no full aligned 16-byte chunk. The
            // unconditional head/tail scans below would read outside the
            // slice's bounds in this case (past tail_addr and before
            // head_addr), so scan exactly [head_addr, tail_addr) instead.
            // SAFETY: every address in [head_addr, tail_addr) lies within
            // this slice.
            return unsafe { is_all_zero_naive(head_addr, tail_addr) };
        }

        // SAFETY: each aligned_addr is 16-byte aligned and
        // [aligned_addr, aligned_addr + 16) lies within this slice because
        // aligned_head_addr >= head_addr and aligned_tail_addr <= tail_addr.
        if (aligned_head_addr..aligned_tail_addr)
            .step_by(16)
            .any(|aligned_addr| unsafe { *(aligned_addr as *const u128) } != 0)
        {
            return false;
        }

        // SAFETY: both byte ranges lie within this slice (see above).
        unsafe {
            is_all_zero_naive(head_addr, aligned_head_addr)
                && is_all_zero_naive(aligned_tail_addr, tail_addr)
        }
    }
}
/// Byte-by-byte check that the half-open address range
/// `[head_addr, tail_addr)` contains only zeros. An empty (or inverted)
/// range is trivially all zero.
///
/// # Safety
///
/// Every address in `[head_addr, tail_addr)` must be valid for a one-byte
/// read.
unsafe fn is_all_zero_naive(head_addr: usize, tail_addr: usize) -> bool {
    let mut addr = head_addr;
    while addr < tail_addr {
        if *(addr as *const u8) != 0 {
            return false;
        }
        addr += 1;
    }
    true
}
impl VolatileMemory for VolatileSlice<'_> {
    // A VolatileSlice is itself a volatile memory region: slicing it is just
    // sub_slice with the same bounds checks.
    fn get_slice(&self, offset: usize, count: usize) -> Result<VolatileSlice> {
        self.sub_slice(offset, count)
    }
}
impl PartialEq<VolatileSlice<'_>> for VolatileSlice<'_> {
    // Byte-wise equality of the two memory regions: equal sizes and equal
    // contents.
    fn eq(&self, other: &VolatileSlice) -> bool {
        let size = self.size();
        if size != other.size() {
            return false;
        }
        // SAFETY: both pointers are valid for reads of `size` bytes.
        // NOTE(review): memcmp performs ordinary (non-volatile) reads, so the
        // comparison may be torn if another thread is concurrently writing —
        // presumably acceptable for this API's callers; confirm.
        let cmp = unsafe { libc::memcmp(self.as_ptr() as _, other.as_ptr() as _, size) };
        cmp == 0
    }
}

impl Eq for VolatileSlice<'_> {}
impl std::io::Write for VolatileSlice<'_> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let len = buf.len().min(self.size());
self.copy_from(&buf[..len]);
self.advance(len);
Ok(len)
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
#[cfg(test)]
mod tests {
    use std::io::Write;
    use std::sync::Arc;
    use std::sync::Barrier;
    use std::thread::spawn;

    use super::*;

    // Test-only backing store: a shared heap buffer that hands out
    // VolatileSlices over its contents.
    // NOTE(review): slices mutate the buffer through a shared Arc<Vec<u8>>
    // without interior mutability; tolerable in tests of volatile access but
    // it would be an aliasing violation in production code — confirm.
    #[derive(Clone)]
    struct VecMem {
        mem: Arc<Vec<u8>>,
    }

    impl VecMem {
        // Creates a zero-filled buffer of `size` bytes.
        fn new(size: usize) -> VecMem {
            VecMem {
                mem: Arc::new(vec![0u8; size]),
            }
        }
    }

    impl VolatileMemory for VecMem {
        // Bounds-checked slice creation mirroring VolatileSlice::sub_slice.
        fn get_slice(&self, offset: usize, count: usize) -> Result<VolatileSlice> {
            // Reject offset + count overflow.
            let mem_end = offset
                .checked_add(count)
                .ok_or(VolatileMemoryError::Overflow {
                    base: offset,
                    offset: count,
                })?;
            if mem_end > self.mem.len() {
                return Err(Error::OutOfBounds { addr: mem_end });
            }
            // Reject base address + offset overflow.
            let new_addr = (self.mem.as_ptr() as usize).checked_add(offset).ok_or(
                VolatileMemoryError::Overflow {
                    base: self.mem.as_ptr() as usize,
                    offset,
                },
            )?;
            Ok(
                unsafe { VolatileSlice::from_raw_parts(new_addr as *mut u8, count) },
            )
        }
    }

    // A write performed through one clone of the memory must be observable
    // through another clone; barriers order the cross-thread write.
    #[test]
    fn observe_mutate() {
        let a = VecMem::new(1);
        let a_clone = a.clone();
        a.get_slice(0, 1).unwrap().write_bytes(99);

        let start_barrier = Arc::new(Barrier::new(2));
        let thread_start_barrier = start_barrier.clone();
        let end_barrier = Arc::new(Barrier::new(2));
        let thread_end_barrier = end_barrier.clone();
        spawn(move || {
            thread_start_barrier.wait();
            a_clone.get_slice(0, 1).unwrap().write_bytes(0);
            thread_end_barrier.wait();
        });

        // Before the other thread runs, the original value is visible.
        let mut byte = [0u8; 1];
        a.get_slice(0, 1).unwrap().copy_to(&mut byte);
        assert_eq!(byte[0], 99);

        start_barrier.wait();
        end_barrier.wait();

        // After the other thread's write, the new value is visible.
        a.get_slice(0, 1).unwrap().copy_to(&mut byte);
        assert_eq!(byte[0], 0);
    }

    // size() reports the requested count, including for nested sub-slices.
    #[test]
    fn slice_size() {
        let a = VecMem::new(100);
        let s = a.get_slice(0, 27).unwrap();
        assert_eq!(s.size(), 27);

        let s = a.get_slice(34, 27).unwrap();
        assert_eq!(s.size(), 27);

        let s = s.get_slice(20, 5).unwrap();
        assert_eq!(s.size(), 5);
    }

    // offset + count overflowing usize yields Error::Overflow.
    #[test]
    fn slice_overflow_error() {
        let a = VecMem::new(1);
        let res = a.get_slice(usize::MAX, 1).unwrap_err();
        assert_eq!(
            res,
            Error::Overflow {
                base: usize::MAX,
                offset: 1,
            }
        );
    }

    // A range ending past the buffer yields Error::OutOfBounds with the
    // offending end address.
    #[test]
    fn slice_oob_error() {
        let a = VecMem::new(100);
        a.get_slice(50, 50).unwrap();
        let res = a.get_slice(55, 50).unwrap_err();
        assert_eq!(res, Error::OutOfBounds { addr: 105 });
    }

    // Whole-buffer slice: zero until any byte inside it is set.
    #[test]
    fn is_all_zero_16bytes_aligned() {
        let a = VecMem::new(1024);
        let slice = a.get_slice(0, 1024).unwrap();
        assert!(slice.is_all_zero());
        a.get_slice(129, 1).unwrap().write_bytes(1);
        assert!(!slice.is_all_zero());
    }

    // Slice starting off-alignment: bytes before the slice must not affect
    // the result; bytes inside must.
    #[test]
    fn is_all_zero_head_not_aligned() {
        let a = VecMem::new(1024);
        let slice = a.get_slice(1, 1023).unwrap();
        assert!(slice.is_all_zero());
        // Byte 0 is outside the slice.
        a.get_slice(0, 1).unwrap().write_bytes(1);
        assert!(slice.is_all_zero());
        a.get_slice(1, 1).unwrap().write_bytes(1);
        assert!(!slice.is_all_zero());
        a.get_slice(1, 1).unwrap().write_bytes(0);
        a.get_slice(129, 1).unwrap().write_bytes(1);
        assert!(!slice.is_all_zero());
    }

    // Slice ending off-alignment: bytes past the slice end must not affect
    // the result; bytes inside must.
    #[test]
    fn is_all_zero_tail_not_aligned() {
        let a = VecMem::new(1024);
        let slice = a.get_slice(0, 1023).unwrap();
        assert!(slice.is_all_zero());
        // Byte 1023 is outside the slice.
        a.get_slice(1023, 1).unwrap().write_bytes(1);
        assert!(slice.is_all_zero());
        a.get_slice(1022, 1).unwrap().write_bytes(1);
        assert!(!slice.is_all_zero());
        a.get_slice(1022, 1).unwrap().write_bytes(0);
        a.get_slice(0, 1).unwrap().write_bytes(1);
        assert!(!slice.is_all_zero());
    }

    // Short unaligned slice that may contain no aligned 16-byte chunk: only
    // bytes 1..=16 are inside the slice; 0 and 17 are not.
    #[test]
    fn is_all_zero_no_aligned_16bytes() {
        let a = VecMem::new(1024);
        let slice = a.get_slice(1, 16).unwrap();
        assert!(slice.is_all_zero());
        a.get_slice(0, 1).unwrap().write_bytes(1);
        assert!(slice.is_all_zero());
        for i in 1..17 {
            a.get_slice(i, 1).unwrap().write_bytes(1);
            assert!(!slice.is_all_zero());
            a.get_slice(i, 1).unwrap().write_bytes(0);
        }
        a.get_slice(17, 1).unwrap().write_bytes(1);
        assert!(slice.is_all_zero());
    }

    // Write smaller than the slice: writes the bytes, shrinks the slice, and
    // leaves the rest of the slice untouched.
    #[test]
    fn write_partial() {
        let mem = VecMem::new(1024);
        let mut slice = mem.get_slice(1, 16).unwrap();
        slice.write_bytes(0xCC);
        let write_len = slice.write(&[1, 2, 3, 4]).unwrap();
        assert_eq!(write_len, 4);
        assert_eq!(slice.size(), 16 - 4);
        assert_eq!(mem.mem[1..=4], [1, 2, 3, 4]);
        assert_eq!(mem.mem[5], 0xCC);
    }

    // Consecutive writes append, each advancing past the previous one.
    #[test]
    fn write_multiple() {
        let mem = VecMem::new(1024);
        let mut slice = mem.get_slice(1, 16).unwrap();
        slice.write_bytes(0xCC);
        let write_len = slice.write(&[1, 2, 3, 4]).unwrap();
        assert_eq!(write_len, 4);
        assert_eq!(slice.size(), 16 - 4);
        assert_eq!(mem.mem[5], 0xCC);
        let write2_len = slice.write(&[5, 6, 7, 8]).unwrap();
        assert_eq!(write2_len, 4);
        assert_eq!(slice.size(), 16 - 4 - 4);
        assert_eq!(mem.mem[1..=8], [1, 2, 3, 4, 5, 6, 7, 8]);
        assert_eq!(mem.mem[9], 0xCC);
    }

    // Write exactly the slice size: fills the slice and leaves it empty;
    // the byte after the slice is untouched (still the original zero).
    #[test]
    fn write_exact_slice_size() {
        let mem = VecMem::new(1024);
        let mut slice = mem.get_slice(1, 4).unwrap();
        slice.write_bytes(0xCC);
        let write_len = slice.write(&[1, 2, 3, 4]).unwrap();
        assert_eq!(write_len, 4);
        assert_eq!(slice.size(), 0);
        assert_eq!(mem.mem[1..=4], [1, 2, 3, 4]);
        assert_eq!(mem.mem[5], 0);
    }

    // Write larger than the slice: truncated to the slice size; the excess
    // input byte is never written.
    #[test]
    fn write_more_than_slice_size() {
        let mem = VecMem::new(1024);
        let mut slice = mem.get_slice(1, 4).unwrap();
        slice.write_bytes(0xCC);
        let write_len = slice.write(&[1, 2, 3, 4, 5]).unwrap();
        assert_eq!(write_len, 4);
        assert_eq!(slice.size(), 0);
        assert_eq!(mem.mem[1..=4], [1, 2, 3, 4]);
        assert_eq!(mem.mem[5], 0);
    }

    // Writes to an empty slice always succeed with zero bytes written.
    #[test]
    fn write_empty_slice() {
        let mem = VecMem::new(1024);
        let mut slice = mem.get_slice(1, 0).unwrap();
        assert_eq!(slice.write(&[1, 2, 3, 4]).unwrap(), 0);
        assert_eq!(slice.write(&[5, 6, 7, 8]).unwrap(), 0);
        assert_eq!(slice.write(&[]).unwrap(), 0);
    }
}