diff --git a/vfio_user/src/lib.rs b/vfio_user/src/lib.rs index de927adb7..4f27f83eb 100644 --- a/vfio_user/src/lib.rs +++ b/vfio_user/src/lib.rs @@ -272,6 +272,8 @@ pub enum Error { SocketAccept(#[source] std::io::Error), #[error("Unsupported command: {0:?}")] UnsupportedCommand(Command), + #[error("Error from backend: {0:?}")] + Backend(#[source] std::io::Error), } impl Client { @@ -854,6 +856,331 @@ impl Server { }) } + fn handle_command( + &self, + backend: &mut dyn ServerBackend, + stream: &mut UnixStream, + header: Header, + fds: Vec, + ) -> Result<(), Error> { + match header.command { + Command::Unknown + | Command::GetRegionIoFds + | Command::DmaRead + | Command::DmaWrite + | Command::UserDirtyPages => { + return Err(Error::UnsupportedCommand(header.command)); + } + Command::Version => { + let mut client_version = Version { + header, + ..Default::default() + }; + stream + .read_exact(&mut client_version.as_mut_slice()[size_of::
()..]) + .map_err(Error::StreamRead)?; + + let mut raw_version_data = Vec::new(); + raw_version_data.resize(header.message_size as usize - size_of::(), 0u8); + stream + .read_exact(&mut raw_version_data) + .map_err(Error::StreamRead)?; + let version_data = CString::from_vec_with_nul(raw_version_data) + .unwrap() + .to_string_lossy() + .into_owned(); + let client_capabilities: Capabilities = + serde_json::from_str(&version_data).map_err(Error::DeserializeCapabilites)?; + + info!( + "Received client version: major = {} minor = {} capabilities = {:?}", + client_version.major, client_version.minor, client_capabilities + ); + + let version = Version { + header: Header { + message_id: client_version.header.message_id, + command: Command::Version, + flags: HeaderFlags::Reply as u32, + message_size: (size_of::() + version_data.len() + 1) as u32, + ..Default::default() + }, + major: 0, + minor: 1, + }; + + let server_capabilities = Capabilities::default(); + let version_data = serde_json::to_string(&server_capabilities) + .map_err(Error::SerializeCapabilites)?; + let version_data = CString::new(version_data.as_bytes()).unwrap(); + + let bufs = vec![ + IoSlice::new(version.as_slice()), + IoSlice::new(version_data.as_bytes_with_nul()), + ]; + + // TODO: Use write_all_vectored() when ready + let _ = stream.write_vectored(&bufs).map_err(Error::StreamWrite)?; + + info!( + "Sent server version: major = {} minor = {} capabilities = {:?}", + version.major, version.minor, server_capabilities + ); + } + Command::DmaMap => { + let mut cmd = DmaMap { + header, + ..Default::default() + }; + stream + .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) + .map_err(Error::StreamRead)?; + + backend + .dma_map( + DmaMapFlags::from_bits_truncate(cmd.flags), + cmd.offset, + cmd.address, + cmd.size, + Some(&fds[0]), + ) + .map_err(Error::Backend)?; + + let reply = Header { + message_id: cmd.header.message_id, + command: Command::DmaMap, + flags: HeaderFlags::Reply as u32, + message_size: size_of::
() as u32, + ..Default::default() + }; + stream + .write_all(reply.as_slice()) + .map_err(Error::StreamWrite)?; + } + Command::DmaUnmap => { + let mut cmd = DmaUnmap { + header, + ..Default::default() + }; + stream + .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) + .map_err(Error::StreamRead)?; + + backend + .dma_unmap( + DmaUnmapFlags::from_bits_truncate(cmd.flags), + cmd.address, + cmd.size, + ) + .map_err(Error::Backend)?; + + let reply = DmaUnmap { + header: Header { + message_id: cmd.header.message_id, + command: Command::DmaUnmap, + flags: HeaderFlags::Reply as u32, + message_size: size_of::
() as u32, + ..Default::default() + }, + argsz: cmd.argsz, + flags: cmd.flags, + address: cmd.address, + size: cmd.size, + }; + stream + .write_all(reply.as_slice()) + .map_err(Error::StreamWrite)?; + } + Command::DeviceGetInfo => { + let mut cmd = DeviceGetInfo { + header, + ..Default::default() + }; + stream + .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) + .map_err(Error::StreamRead)?; + + let reply = DeviceGetInfo { + header: Header { + message_id: cmd.header.message_id, + command: Command::DeviceGetInfo, + flags: HeaderFlags::Reply as u32, + message_size: size_of::() as u32, + ..Default::default() + }, + argsz: size_of::() as u32, + flags: VFIO_DEVICE_FLAGS_PCI + | if self.resettable { + VFIO_DEVICE_FLAGS_RESET + } else { + 0 + }, + num_regions: self.regions.len() as u32, + num_irqs: self.irqs.len() as u32, + }; + stream + .write_all(reply.as_slice()) + .map_err(Error::StreamWrite)?; + } + Command::DeviceGetRegionInfo => { + let mut cmd = DeviceGetRegionInfo { + header, + ..Default::default() + }; + stream + .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) + .map_err(Error::StreamRead)?; + + let reply = DeviceGetRegionInfo { + header: Header { + message_id: cmd.header.message_id, + command: Command::DeviceGetRegionInfo, + flags: HeaderFlags::Reply as u32, + message_size: size_of::() as u32, + ..Default::default() + }, + region_info: self.regions[cmd.region_info.index as usize], + }; + stream + .write_all(reply.as_slice()) + .map_err(Error::StreamWrite)?; + } + Command::GetIrqInfo => { + let mut cmd = GetIrqInfo { + header, + ..Default::default() + }; + stream + .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) + .map_err(Error::StreamRead)?; + + let irq = &self.irqs[cmd.index as usize]; + + let reply = GetIrqInfo { + header: Header { + message_id: cmd.header.message_id, + command: Command::GetIrqInfo, + flags: HeaderFlags::Reply as u32, + message_size: size_of::() as u32, + ..Default::default() + }, + argsz: (size_of::() - size_of::
()) as u32, + index: irq.index, + flags: irq.flags, + count: irq.count, + }; + stream + .write_all(reply.as_slice()) + .map_err(Error::StreamWrite)?; + } + Command::SetIrqs => { + let mut cmd = SetIrqs { + header, + ..Default::default() + }; + stream + .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) + .map_err(Error::StreamRead)?; + + backend + .set_irqs(cmd.index, cmd.flags, cmd.start, cmd.count, fds) + .map_err(Error::Backend)?; + + let reply = Header { + message_id: cmd.header.message_id, + command: Command::SetIrqs, + flags: HeaderFlags::Reply as u32, + message_size: size_of::
() as u32, + ..Default::default() + }; + stream + .write_all(reply.as_slice()) + .map_err(Error::StreamWrite)?; + } + Command::RegionRead => { + let mut cmd = RegionAccess { + header, + ..Default::default() + }; + stream + .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) + .map_err(Error::StreamRead)?; + + let (region, offset, count) = (cmd.region, cmd.offset, cmd.count); + + let mut data = vec![0u8; count as usize]; + backend + .region_read(region, offset, &mut data) + .map_err(Error::Backend)?; + + let reply = RegionAccess { + header: Header { + message_id: cmd.header.message_id, + command: Command::RegionRead, + flags: HeaderFlags::Reply as u32, + message_size: size_of::() as u32, + ..Default::default() + }, + region, + offset, + count, + }; + stream + .write_all(reply.as_slice()) + .map_err(Error::StreamWrite)?; + stream.write_all(&data).map_err(Error::StreamWrite)?; + } + Command::RegionWrite => { + let mut cmd = RegionAccess { + header, + ..Default::default() + }; + stream + .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) + .map_err(Error::StreamRead)?; + + let (region, offset, count) = (cmd.region, cmd.offset, cmd.count); + + let mut data = vec![0u8; count as usize]; + stream.read_exact(&mut data).map_err(Error::StreamRead)?; + backend + .region_write(region, offset, &data) + .map_err(Error::Backend)?; + + let reply = RegionAccess { + header: Header { + message_id: cmd.header.message_id, + command: Command::RegionWrite, + flags: HeaderFlags::Reply as u32, + message_size: size_of::() as u32, + ..Default::default() + }, + region, + offset, + count, + }; + stream + .write_all(reply.as_slice()) + .map_err(Error::StreamWrite)?; + } + Command::DeviceReset => { + backend.reset().map_err(Error::Backend)?; + let reply = Header { + message_id: header.message_id, + command: Command::DeviceReset, + flags: HeaderFlags::Reply as u32, + message_size: size_of::
() as u32, + ..Default::default() + }; + stream + .write_all(reply.as_slice()) + .map_err(Error::StreamWrite)?; + } + } + + Ok(()) + } + pub fn run(&self, backend: &mut dyn ServerBackend) -> Result<(), Error> { let (mut stream, _) = self.listener.accept().map_err(Error::SocketAccept)?; @@ -888,376 +1215,18 @@ impl Server { .map(|fd| unsafe { File::from_raw_fd(*fd) }) .collect(); - match header.command { - Command::Unknown - | Command::GetRegionIoFds - | Command::DmaRead - | Command::DmaWrite - | Command::UserDirtyPages => { - return Err(Error::UnsupportedCommand(header.command)); - } - Command::Version => { - let mut client_version = Version { - header, - ..Default::default() - }; - stream - .read_exact(&mut client_version.as_mut_slice()[size_of::
()..]) - .map_err(Error::StreamRead)?; - - let mut raw_version_data = Vec::new(); - raw_version_data - .resize(header.message_size as usize - size_of::(), 0u8); - stream - .read_exact(&mut raw_version_data) - .map_err(Error::StreamRead)?; - let version_data = CString::from_vec_with_nul(raw_version_data) - .unwrap() - .to_string_lossy() - .into_owned(); - let client_capabilities: Capabilities = serde_json::from_str(&version_data) - .map_err(Error::DeserializeCapabilites)?; - - info!( - "Received client version: major = {} minor = {} capabilities = {:?}", - client_version.major, client_version.minor, client_capabilities - ); - - let version = Version { - header: Header { - message_id: client_version.header.message_id, - command: Command::Version, - flags: HeaderFlags::Reply as u32, - message_size: (size_of::() + version_data.len() + 1) as u32, - ..Default::default() - }, - major: 0, - minor: 1, - }; - - let server_capabilities = Capabilities::default(); - let version_data = serde_json::to_string(&server_capabilities) - .map_err(Error::SerializeCapabilites)?; - let version_data = CString::new(version_data.as_bytes()).unwrap(); - - let bufs = vec![ - IoSlice::new(version.as_slice()), - IoSlice::new(version_data.as_bytes_with_nul()), - ]; - - // TODO: Use write_all_vectored() when ready - let _ = stream.write_vectored(&bufs).map_err(Error::StreamWrite)?; - - info!( - "Sent server version: major = {} minor = {} capabilities = {:?}", - version.major, version.minor, server_capabilities - ); - } - Command::DmaMap => { - let mut cmd = DmaMap { - header, - ..Default::default() - }; - stream - .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) - .map_err(Error::StreamRead)?; - - let reply = match backend.dma_map( - DmaMapFlags::from_bits_truncate(cmd.flags), - cmd.offset, - cmd.address, - cmd.size, - Some(&fds[0]), - ) { - Ok(()) => Header { - message_id: cmd.header.message_id, - command: Command::DmaMap, - flags: HeaderFlags::Reply as u32, - message_size: size_of::
() as u32, - ..Default::default() - }, - Err(e) => Header { - message_id: cmd.header.message_id, - command: Command::DmaMap, - flags: HeaderFlags::Error as u32, - message_size: size_of::
() as u32, - error: e.raw_os_error().unwrap_or_default() as u32, - }, - }; - - stream - .write_all(reply.as_slice()) - .map_err(Error::StreamWrite)?; - } - Command::DmaUnmap => { - let mut cmd = DmaUnmap { - header, - ..Default::default() - }; - stream - .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) - .map_err(Error::StreamRead)?; - - match backend.dma_unmap( - DmaUnmapFlags::from_bits_truncate(cmd.flags), - cmd.address, - cmd.size, - ) { - Ok(()) => { - let reply = DmaUnmap { - header: Header { - message_id: cmd.header.message_id, - command: Command::DmaUnmap, - flags: HeaderFlags::Reply as u32, - message_size: size_of::
() as u32, - ..Default::default() - }, - argsz: cmd.argsz, - flags: cmd.flags, - address: cmd.address, - size: cmd.size, - }; - stream - .write_all(reply.as_slice()) - .map_err(Error::StreamWrite)?; - } - Err(e) => { - let reply = Header { - message_id: cmd.header.message_id, - command: Command::DmaUnmap, - flags: HeaderFlags::Error as u32, - message_size: size_of::
() as u32, - error: e.raw_os_error().unwrap_or_default() as u32, - }; - stream - .write_all(reply.as_slice()) - .map_err(Error::StreamWrite)?; - } - }; - } - Command::DeviceGetInfo => { - let mut cmd = DeviceGetInfo { - header, - ..Default::default() - }; - stream - .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) - .map_err(Error::StreamRead)?; - - let reply = DeviceGetInfo { - header: Header { - message_id: cmd.header.message_id, - command: Command::DeviceGetInfo, - flags: HeaderFlags::Reply as u32, - message_size: size_of::() as u32, - ..Default::default() - }, - argsz: size_of::() as u32, - flags: VFIO_DEVICE_FLAGS_PCI - | if self.resettable { - VFIO_DEVICE_FLAGS_RESET - } else { - 0 - }, - num_regions: self.regions.len() as u32, - num_irqs: self.irqs.len() as u32, - }; - stream - .write_all(reply.as_slice()) - .map_err(Error::StreamWrite)?; - } - Command::DeviceGetRegionInfo => { - let mut cmd = DeviceGetRegionInfo { - header, - ..Default::default() - }; - stream - .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) - .map_err(Error::StreamRead)?; - - let reply = DeviceGetRegionInfo { - header: Header { - message_id: cmd.header.message_id, - command: Command::DeviceGetRegionInfo, - flags: HeaderFlags::Reply as u32, - message_size: size_of::() as u32, - ..Default::default() - }, - region_info: self.regions[cmd.region_info.index as usize], - }; - stream - .write_all(reply.as_slice()) - .map_err(Error::StreamWrite)?; - } - Command::GetIrqInfo => { - let mut cmd = GetIrqInfo { - header, - ..Default::default() - }; - stream - .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) - .map_err(Error::StreamRead)?; - - let irq = &self.irqs[cmd.index as usize]; - - let reply = GetIrqInfo { - header: Header { - message_id: cmd.header.message_id, - command: Command::GetIrqInfo, - flags: HeaderFlags::Reply as u32, - message_size: size_of::() as u32, - ..Default::default() - }, - argsz: (size_of::() - size_of::
()) as u32, - index: irq.index, - flags: irq.flags, - count: irq.count, - }; - stream - .write_all(reply.as_slice()) - .map_err(Error::StreamWrite)?; - } - Command::SetIrqs => { - let mut cmd = SetIrqs { - header, - ..Default::default() - }; - stream - .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) - .map_err(Error::StreamRead)?; - - let reply = - match backend.set_irqs(cmd.index, cmd.flags, cmd.start, cmd.count, fds) { - Ok(()) => Header { - message_id: cmd.header.message_id, - command: Command::SetIrqs, - flags: HeaderFlags::Reply as u32, - message_size: size_of::
() as u32, - ..Default::default() - }, - Err(e) => Header { - message_id: cmd.header.message_id, - command: Command::SetIrqs, - flags: HeaderFlags::Error as u32, - message_size: size_of::
() as u32, - error: e.raw_os_error().unwrap_or_default() as u32, - }, - }; - stream - .write_all(reply.as_slice()) - .map_err(Error::StreamWrite)?; - } - Command::RegionRead => { - let mut cmd = RegionAccess { - header, - ..Default::default() - }; - stream - .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) - .map_err(Error::StreamRead)?; - - let (region, offset, count) = (cmd.region, cmd.offset, cmd.count); - - let mut data = vec![0u8; count as usize]; - match backend.region_read(region, offset, &mut data) { - Ok(()) => { - let reply = RegionAccess { - header: Header { - message_id: cmd.header.message_id, - command: Command::RegionRead, - flags: HeaderFlags::Reply as u32, - message_size: size_of::() as u32, - ..Default::default() - }, - region, - offset, - count, - }; - stream - .write_all(reply.as_slice()) - .map_err(Error::StreamWrite)?; - stream.write_all(&data).map_err(Error::StreamWrite)?; - } - Err(e) => { - let reply = Header { - message_id: cmd.header.message_id, - command: Command::RegionRead, - flags: HeaderFlags::Error as u32, - message_size: size_of::
() as u32, - error: e.raw_os_error().unwrap_or_default() as u32, - }; - stream - .write_all(reply.as_slice()) - .map_err(Error::StreamWrite)?; - } - } - } - Command::RegionWrite => { - let mut cmd = RegionAccess { - header, - ..Default::default() - }; - stream - .read_exact(&mut cmd.as_mut_slice()[size_of::
()..]) - .map_err(Error::StreamRead)?; - - let (region, offset, count) = (cmd.region, cmd.offset, cmd.count); - - let mut data = vec![0u8; count as usize]; - stream.read_exact(&mut data).map_err(Error::StreamRead)?; - match backend.region_write(region, offset, &data) { - Ok(()) => { - let reply = RegionAccess { - header: Header { - message_id: cmd.header.message_id, - command: Command::RegionWrite, - flags: HeaderFlags::Reply as u32, - message_size: size_of::() as u32, - ..Default::default() - }, - region, - offset, - count, - }; - stream - .write_all(reply.as_slice()) - .map_err(Error::StreamWrite)?; - } - Err(e) => { - let reply = Header { - message_id: cmd.header.message_id, - command: Command::RegionWrite, - flags: HeaderFlags::Error as u32, - message_size: size_of::
() as u32, - error: e.raw_os_error().unwrap_or_default() as u32, - }; - stream - .write_all(reply.as_slice()) - .map_err(Error::StreamWrite)?; - } - } - } - Command::DeviceReset => { - let reply = match backend.reset() { - Ok(()) => Header { - message_id: header.message_id, - command: Command::DeviceReset, - flags: HeaderFlags::Reply as u32, - message_size: size_of::
() as u32, - ..Default::default() - }, - Err(e) => Header { - message_id: header.message_id, - command: Command::DeviceReset, - flags: HeaderFlags::Error as u32, - message_size: size_of::
() as u32, - error: e.raw_os_error().unwrap_or_default() as u32, - }, - }; - stream - .write_all(reply.as_slice()) - .map_err(Error::StreamWrite)?; - } + if let Err(e) = self.handle_command(backend, &mut stream, header, fds) { + error!("Error handling command: {:?}: {e}", header.command); + let reply = Header { + message_id: header.message_id, + command: header.command, + flags: HeaderFlags::Error as u32, + message_size: size_of::
() as u32, + error: 0, + }; + stream + .write_all(reply.as_slice()) + .map_err(Error::StreamWrite)?; } }