virtio-devices: fix reading vsock connect command

The socket is nonblocking, so it's not guaranteed that it will be
possible to read the whole connect command in a single iteration of
the event loop.  To reproduce:

	(echo -n 'CONNECT '; sleep 1; echo 1234; cat) | socat STDIO UNIX-CONNECT:vsock.sock

This would produce the error:

	cloud-hypervisor: 5.509209s: <_vsock4> INFO:virtio-devices/src/vsock/unix/muxer.rs:446 -- vsock: error adding local-init connection: UnixRead(Os { code: 11, kind: WouldBlock, message: "Resource temporarily unavailable" })

To fix this, if we only get a partial command, we need to save it for
future iterations of the event loop, and only proceed once we've read
a complete command.

Signed-off-by: Alyssa Ross <hi@alyssa.is>
This commit is contained in:
Alyssa Ross 2024-01-09 14:46:46 +01:00 committed by Rob Bradford
parent dc68a6e30f
commit 48de800756

View File

@ -40,7 +40,7 @@
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::fs::File; use std::fs::File;
use std::io::{self, Read}; use std::io::{self, ErrorKind, Read};
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::{UnixListener, UnixStream}; use std::os::unix::net::{UnixListener, UnixStream};
@ -92,6 +92,15 @@ enum EpollListener {
LocalStream(UnixStream), LocalStream(UnixStream),
} }
/// A partially read "CONNECT" command.
#[derive(Default)]
struct PartiallyReadCommand {
/// The bytes of the command that have been read so far.
buf: [u8; 32],
/// How much of `buf` has been used.
len: usize,
}
/// The vsock connection multiplexer. /// The vsock connection multiplexer.
/// ///
pub struct VsockMuxer { pub struct VsockMuxer {
@ -101,6 +110,8 @@ pub struct VsockMuxer {
conn_map: HashMap<ConnMapKey, MuxerConnection>, conn_map: HashMap<ConnMapKey, MuxerConnection>,
/// A hash map used to store epoll event listeners / handlers. /// A hash map used to store epoll event listeners / handlers.
listener_map: HashMap<RawFd, EpollListener>, listener_map: HashMap<RawFd, EpollListener>,
/// A hash map used to store partially read "connect" commands.
partial_command_map: HashMap<RawFd, PartiallyReadCommand>,
/// The RX queue. Items in this queue are consumed by `VsockMuxer::recv_pkt()`, and /// The RX queue. Items in this queue are consumed by `VsockMuxer::recv_pkt()`, and
/// produced /// produced
/// - by `VsockMuxer::send_pkt()` (e.g. RST in response to a connection request packet); /// - by `VsockMuxer::send_pkt()` (e.g. RST in response to a connection request packet);
@ -358,6 +369,7 @@ impl VsockMuxer {
rxq: MuxerRxQ::new(), rxq: MuxerRxQ::new(),
conn_map: HashMap::with_capacity(defs::MAX_CONNECTIONS), conn_map: HashMap::with_capacity(defs::MAX_CONNECTIONS),
listener_map: HashMap::with_capacity(defs::MAX_CONNECTIONS + 1), listener_map: HashMap::with_capacity(defs::MAX_CONNECTIONS + 1),
partial_command_map: Default::default(),
killq: MuxerKillQ::new(), killq: MuxerKillQ::new(),
local_port_last: (1u32 << 30) - 1, local_port_last: (1u32 << 30) - 1,
local_port_set: HashSet::with_capacity(defs::MAX_CONNECTIONS), local_port_set: HashSet::with_capacity(defs::MAX_CONNECTIONS),
@ -424,10 +436,23 @@ impl VsockMuxer {
// Data is ready to be read from a host-initiated connection. That would be the // Data is ready to be read from a host-initiated connection. That would be the
// "connect" command that we're expecting. // "connect" command that we're expecting.
Some(EpollListener::LocalStream(_)) => { Some(EpollListener::LocalStream(_)) => {
if let Some(EpollListener::LocalStream(mut stream)) = self.remove_listener(fd) { if let Some(EpollListener::LocalStream(stream)) = self.listener_map.get_mut(&fd) {
Self::read_local_stream_port(&mut stream) let port = Self::read_local_stream_port(&mut self.partial_command_map, stream);
.map(|peer_port| (self.allocate_local_port(), peer_port))
.and_then(|(local_port, peer_port)| { if let Err(Error::UnixRead(ref e)) = port {
if e.kind() == ErrorKind::WouldBlock {
return;
}
}
let stream = match self.remove_listener(fd) {
Some(EpollListener::LocalStream(s)) => s,
_ => unreachable!(),
};
port.and_then(|peer_port| {
let local_port = self.allocate_local_port();
self.add_connection( self.add_connection(
ConnMapKey { ConnMapKey {
local_port, local_port,
@ -459,30 +484,36 @@ impl VsockMuxer {
/// Parse a host "connect" command, and extract the destination vsock port. /// Parse a host "connect" command, and extract the destination vsock port.
/// ///
fn read_local_stream_port(stream: &mut UnixStream) -> Result<u32> { fn read_local_stream_port(
let mut buf = [0u8; 32]; partial_command_map: &mut HashMap<RawFd, PartiallyReadCommand>,
stream: &mut UnixStream,
) -> Result<u32> {
let command = partial_command_map.entry(stream.as_raw_fd()).or_default();
// This is the minimum number of bytes that we should be able to read, when parsing a // This is the minimum number of bytes that we should be able to read, when parsing a
// valid connection request. I.e. `b"connect 0\n".len()`. // valid connection request. I.e. `b"connect 0\n".len()`.
const MIN_READ_LEN: usize = 10; const MIN_COMMAND_LEN: usize = 10;
// Bring in the minimum number of bytes that we should be able to read. // Bring in the minimum number of bytes that we should be able to read.
stream if command.len < MIN_COMMAND_LEN {
.read_exact(&mut buf[..MIN_READ_LEN]) command.len += stream
.read(&mut command.buf[command.len..MIN_COMMAND_LEN])
.map_err(Error::UnixRead)?; .map_err(Error::UnixRead)?;
}
// Now, finish reading the destination port number, by bringing in one byte at a time, // Now, finish reading the destination port number, by bringing in one byte at a time,
// until we reach an EOL terminator (or our buffer space runs out). Yeah, not // until we reach an EOL terminator (or our buffer space runs out). Yeah, not
// particularly proud of this approach, but it will have to do for now. // particularly proud of this approach, but it will have to do for now.
let mut blen = MIN_READ_LEN; while command.buf[command.len - 1] != b'\n' && command.len < command.buf.len() {
while buf[blen - 1] != b'\n' && blen < buf.len() { command.len += stream
stream .read(&mut command.buf[command.len..=command.len])
.read_exact(&mut buf[blen..=blen])
.map_err(Error::UnixRead)?; .map_err(Error::UnixRead)?;
blen += 1;
} }
let mut word_iter = std::str::from_utf8(&buf[..blen]) let _ = command;
let command = partial_command_map.remove(&stream.as_raw_fd()).unwrap();
let mut word_iter = std::str::from_utf8(&command.buf[..command.len])
.map_err(Error::ConvertFromUtf8)? .map_err(Error::ConvertFromUtf8)?
.split_whitespace(); .split_whitespace();