vfio_user: Always generate an error for commands

Any error from the backend or from the protocol handling code will now
result in an error reply being sent. This is cleanly achieved by
splitting the command handling out into its own method and using the
Rust Result<> based error handling to trigger the generation of the
error reply.

Signed-off-by: Rob Bradford <robert.bradford@intel.com>
This commit is contained in:
Rob Bradford 2023-02-14 12:01:17 +00:00 committed by Bo Chen
parent b072eb454e
commit 10531f052b

View File

@ -272,6 +272,8 @@ pub enum Error {
SocketAccept(#[source] std::io::Error), SocketAccept(#[source] std::io::Error),
#[error("Unsupported command: {0:?}")] #[error("Unsupported command: {0:?}")]
UnsupportedCommand(Command), UnsupportedCommand(Command),
#[error("Error from backend: {0:?}")]
Backend(#[source] std::io::Error),
} }
impl Client { impl Client {
@ -854,40 +856,13 @@ impl Server {
}) })
} }
pub fn run(&self, backend: &mut dyn ServerBackend) -> Result<(), Error> { fn handle_command(
let (mut stream, _) = self.listener.accept().map_err(Error::SocketAccept)?; &self,
backend: &mut dyn ServerBackend,
loop { stream: &mut UnixStream,
let mut header = Header::default(); header: Header,
fds: Vec<File>,
// The maximum number of FDs that can be sent is 16 so that is ) -> Result<(), Error> {
// also the maximum that can be received.
let mut fds = vec![0; 16];
let mut iovecs = vec![iovec {
iov_base: header.as_mut_slice().as_mut_ptr() as *mut c_void,
iov_len: header.as_mut_slice().len(),
}];
// SAFETY: Safe as the iovect is correctly initialised and fds is big enough
let (bytes, fds_received) = unsafe {
stream
.recv_with_fds(&mut iovecs, &mut fds)
.map_err(Error::ReceiveWithFd)?
};
// Other end closed connection
if bytes == 0 {
info!("Connection closed");
break;
}
fds.resize(fds_received, 0);
let fds: Vec<File> = fds
.iter()
// SAFETY: Safe as we have only valid FDs in the vector now
.map(|fd| unsafe { File::from_raw_fd(*fd) })
.collect();
match header.command { match header.command {
Command::Unknown Command::Unknown
| Command::GetRegionIoFds | Command::GetRegionIoFds
@ -906,8 +881,7 @@ impl Server {
.map_err(Error::StreamRead)?; .map_err(Error::StreamRead)?;
let mut raw_version_data = Vec::new(); let mut raw_version_data = Vec::new();
raw_version_data raw_version_data.resize(header.message_size as usize - size_of::<Version>(), 0u8);
.resize(header.message_size as usize - size_of::<Version>(), 0u8);
stream stream
.read_exact(&mut raw_version_data) .read_exact(&mut raw_version_data)
.map_err(Error::StreamRead)?; .map_err(Error::StreamRead)?;
@ -915,8 +889,8 @@ impl Server {
.unwrap() .unwrap()
.to_string_lossy() .to_string_lossy()
.into_owned(); .into_owned();
let client_capabilities: Capabilities = serde_json::from_str(&version_data) let client_capabilities: Capabilities =
.map_err(Error::DeserializeCapabilites)?; serde_json::from_str(&version_data).map_err(Error::DeserializeCapabilites)?;
info!( info!(
"Received client version: major = {} minor = {} capabilities = {:?}", "Received client version: major = {} minor = {} capabilities = {:?}",
@ -962,29 +936,23 @@ impl Server {
.read_exact(&mut cmd.as_mut_slice()[size_of::<Header>()..]) .read_exact(&mut cmd.as_mut_slice()[size_of::<Header>()..])
.map_err(Error::StreamRead)?; .map_err(Error::StreamRead)?;
let reply = match backend.dma_map( backend
.dma_map(
DmaMapFlags::from_bits_truncate(cmd.flags), DmaMapFlags::from_bits_truncate(cmd.flags),
cmd.offset, cmd.offset,
cmd.address, cmd.address,
cmd.size, cmd.size,
Some(&fds[0]), Some(&fds[0]),
) { )
Ok(()) => Header { .map_err(Error::Backend)?;
let reply = Header {
message_id: cmd.header.message_id, message_id: cmd.header.message_id,
command: Command::DmaMap, command: Command::DmaMap,
flags: HeaderFlags::Reply as u32, flags: HeaderFlags::Reply as u32,
message_size: size_of::<Header>() as u32, message_size: size_of::<Header>() as u32,
..Default::default() ..Default::default()
},
Err(e) => Header {
message_id: cmd.header.message_id,
command: Command::DmaMap,
flags: HeaderFlags::Error as u32,
message_size: size_of::<Header>() as u32,
error: e.raw_os_error().unwrap_or_default() as u32,
},
}; };
stream stream
.write_all(reply.as_slice()) .write_all(reply.as_slice())
.map_err(Error::StreamWrite)?; .map_err(Error::StreamWrite)?;
@ -998,12 +966,14 @@ impl Server {
.read_exact(&mut cmd.as_mut_slice()[size_of::<Header>()..]) .read_exact(&mut cmd.as_mut_slice()[size_of::<Header>()..])
.map_err(Error::StreamRead)?; .map_err(Error::StreamRead)?;
match backend.dma_unmap( backend
.dma_unmap(
DmaUnmapFlags::from_bits_truncate(cmd.flags), DmaUnmapFlags::from_bits_truncate(cmd.flags),
cmd.address, cmd.address,
cmd.size, cmd.size,
) { )
Ok(()) => { .map_err(Error::Backend)?;
let reply = DmaUnmap { let reply = DmaUnmap {
header: Header { header: Header {
message_id: cmd.header.message_id, message_id: cmd.header.message_id,
@ -1021,20 +991,6 @@ impl Server {
.write_all(reply.as_slice()) .write_all(reply.as_slice())
.map_err(Error::StreamWrite)?; .map_err(Error::StreamWrite)?;
} }
Err(e) => {
let reply = Header {
message_id: cmd.header.message_id,
command: Command::DmaUnmap,
flags: HeaderFlags::Error as u32,
message_size: size_of::<Header>() as u32,
error: e.raw_os_error().unwrap_or_default() as u32,
};
stream
.write_all(reply.as_slice())
.map_err(Error::StreamWrite)?;
}
};
}
Command::DeviceGetInfo => { Command::DeviceGetInfo => {
let mut cmd = DeviceGetInfo { let mut cmd = DeviceGetInfo {
header, header,
@ -1126,22 +1082,16 @@ impl Server {
.read_exact(&mut cmd.as_mut_slice()[size_of::<Header>()..]) .read_exact(&mut cmd.as_mut_slice()[size_of::<Header>()..])
.map_err(Error::StreamRead)?; .map_err(Error::StreamRead)?;
let reply = backend
match backend.set_irqs(cmd.index, cmd.flags, cmd.start, cmd.count, fds) { .set_irqs(cmd.index, cmd.flags, cmd.start, cmd.count, fds)
Ok(()) => Header { .map_err(Error::Backend)?;
let reply = Header {
message_id: cmd.header.message_id, message_id: cmd.header.message_id,
command: Command::SetIrqs, command: Command::SetIrqs,
flags: HeaderFlags::Reply as u32, flags: HeaderFlags::Reply as u32,
message_size: size_of::<Header>() as u32, message_size: size_of::<Header>() as u32,
..Default::default() ..Default::default()
},
Err(e) => Header {
message_id: cmd.header.message_id,
command: Command::SetIrqs,
flags: HeaderFlags::Error as u32,
message_size: size_of::<Header>() as u32,
error: e.raw_os_error().unwrap_or_default() as u32,
},
}; };
stream stream
.write_all(reply.as_slice()) .write_all(reply.as_slice())
@ -1159,8 +1109,10 @@ impl Server {
let (region, offset, count) = (cmd.region, cmd.offset, cmd.count); let (region, offset, count) = (cmd.region, cmd.offset, cmd.count);
let mut data = vec![0u8; count as usize]; let mut data = vec![0u8; count as usize];
match backend.region_read(region, offset, &mut data) { backend
Ok(()) => { .region_read(region, offset, &mut data)
.map_err(Error::Backend)?;
let reply = RegionAccess { let reply = RegionAccess {
header: Header { header: Header {
message_id: cmd.header.message_id, message_id: cmd.header.message_id,
@ -1178,20 +1130,6 @@ impl Server {
.map_err(Error::StreamWrite)?; .map_err(Error::StreamWrite)?;
stream.write_all(&data).map_err(Error::StreamWrite)?; stream.write_all(&data).map_err(Error::StreamWrite)?;
} }
Err(e) => {
let reply = Header {
message_id: cmd.header.message_id,
command: Command::RegionRead,
flags: HeaderFlags::Error as u32,
message_size: size_of::<Header>() as u32,
error: e.raw_os_error().unwrap_or_default() as u32,
};
stream
.write_all(reply.as_slice())
.map_err(Error::StreamWrite)?;
}
}
}
Command::RegionWrite => { Command::RegionWrite => {
let mut cmd = RegionAccess { let mut cmd = RegionAccess {
header, header,
@ -1205,8 +1143,10 @@ impl Server {
let mut data = vec![0u8; count as usize]; let mut data = vec![0u8; count as usize];
stream.read_exact(&mut data).map_err(Error::StreamRead)?; stream.read_exact(&mut data).map_err(Error::StreamRead)?;
match backend.region_write(region, offset, &data) { backend
Ok(()) => { .region_write(region, offset, &data)
.map_err(Error::Backend)?;
let reply = RegionAccess { let reply = RegionAccess {
header: Header { header: Header {
message_id: cmd.header.message_id, message_id: cmd.header.message_id,
@ -1223,42 +1163,71 @@ impl Server {
.write_all(reply.as_slice()) .write_all(reply.as_slice())
.map_err(Error::StreamWrite)?; .map_err(Error::StreamWrite)?;
} }
Err(e) => {
let reply = Header {
message_id: cmd.header.message_id,
command: Command::RegionWrite,
flags: HeaderFlags::Error as u32,
message_size: size_of::<Header>() as u32,
error: e.raw_os_error().unwrap_or_default() as u32,
};
stream
.write_all(reply.as_slice())
.map_err(Error::StreamWrite)?;
}
}
}
Command::DeviceReset => { Command::DeviceReset => {
let reply = match backend.reset() { backend.reset().map_err(Error::Backend)?;
Ok(()) => Header { let reply = Header {
message_id: header.message_id, message_id: header.message_id,
command: Command::DeviceReset, command: Command::DeviceReset,
flags: HeaderFlags::Reply as u32, flags: HeaderFlags::Reply as u32,
message_size: size_of::<Header>() as u32, message_size: size_of::<Header>() as u32,
..Default::default() ..Default::default()
},
Err(e) => Header {
message_id: header.message_id,
command: Command::DeviceReset,
flags: HeaderFlags::Error as u32,
message_size: size_of::<Header>() as u32,
error: e.raw_os_error().unwrap_or_default() as u32,
},
}; };
stream stream
.write_all(reply.as_slice()) .write_all(reply.as_slice())
.map_err(Error::StreamWrite)?; .map_err(Error::StreamWrite)?;
} }
} }
Ok(())
}
pub fn run(&self, backend: &mut dyn ServerBackend) -> Result<(), Error> {
let (mut stream, _) = self.listener.accept().map_err(Error::SocketAccept)?;
loop {
let mut header = Header::default();
// The maximum number of FDs that can be sent is 16 so that is
// also the maximum that can be received.
let mut fds = vec![0; 16];
let mut iovecs = vec![iovec {
iov_base: header.as_mut_slice().as_mut_ptr() as *mut c_void,
iov_len: header.as_mut_slice().len(),
}];
// SAFETY: Safe as the iovect is correctly initialised and fds is big enough
let (bytes, fds_received) = unsafe {
stream
.recv_with_fds(&mut iovecs, &mut fds)
.map_err(Error::ReceiveWithFd)?
};
// Other end closed connection
if bytes == 0 {
info!("Connection closed");
break;
}
fds.resize(fds_received, 0);
let fds: Vec<File> = fds
.iter()
// SAFETY: Safe as we have only valid FDs in the vector now
.map(|fd| unsafe { File::from_raw_fd(*fd) })
.collect();
if let Err(e) = self.handle_command(backend, &mut stream, header, fds) {
error!("Error handling command: {:?}: {e}", header.command);
let reply = Header {
message_id: header.message_id,
command: header.command,
flags: HeaderFlags::Error as u32,
message_size: size_of::<Header>() as u32,
error: 0,
};
stream
.write_all(reply.as_slice())
.map_err(Error::StreamWrite)?;
}
} }
Ok(()) Ok(())