diff --git a/virtio-devices/src/vsock/unix/muxer.rs b/virtio-devices/src/vsock/unix/muxer.rs index 2eca46fc2..0768beae9 100644 --- a/virtio-devices/src/vsock/unix/muxer.rs +++ b/virtio-devices/src/vsock/unix/muxer.rs @@ -40,7 +40,7 @@ use std::collections::{HashMap, HashSet}; use std::fs::File; -use std::io::{self, Read}; +use std::io::{self, ErrorKind, Read}; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::net::{UnixListener, UnixStream}; @@ -92,6 +92,15 @@ enum EpollListener { LocalStream(UnixStream), } +/// A partially read "CONNECT" command. +#[derive(Default)] +struct PartiallyReadCommand { + /// The bytes of the command that have been read so far. + buf: [u8; 32], + /// How much of `buf` has been used. + len: usize, +} + /// The vsock connection multiplexer. /// pub struct VsockMuxer { @@ -101,6 +110,8 @@ pub struct VsockMuxer { conn_map: HashMap, /// A hash map used to store epoll event listeners / handlers. listener_map: HashMap, + /// A hash map used to store partially read "connect" commands. + partial_command_map: HashMap, /// The RX queue. Items in this queue are consumed by `VsockMuxer::recv_pkt()`, and /// produced /// - by `VsockMuxer::send_pkt()` (e.g. RST in response to a connection request packet); @@ -358,6 +369,7 @@ impl VsockMuxer { rxq: MuxerRxQ::new(), conn_map: HashMap::with_capacity(defs::MAX_CONNECTIONS), listener_map: HashMap::with_capacity(defs::MAX_CONNECTIONS + 1), + partial_command_map: Default::default(), killq: MuxerKillQ::new(), local_port_last: (1u32 << 30) - 1, local_port_set: HashSet::with_capacity(defs::MAX_CONNECTIONS), @@ -424,27 +436,40 @@ impl VsockMuxer { // Data is ready to be read from a host-initiated connection. That would be the // "connect" command that we're expecting. Some(EpollListener::LocalStream(_)) => { - if let Some(EpollListener::LocalStream(mut stream)) = self.remove_listener(fd) { - Self::read_local_stream_port(&mut stream) - .map(|peer_port| (self.allocate_local_port(), peer_port)) - .and_then(|(local_port, peer_port)| { - self.add_connection( - ConnMapKey { - local_port, - peer_port, - }, - MuxerConnection::new_local_init( - stream, - uapi::VSOCK_HOST_CID, - self.cid, - local_port, - peer_port, - ), - ) - }) - .unwrap_or_else(|err| { - info!("vsock: error adding local-init connection: {:?}", err); - }) + if let Some(EpollListener::LocalStream(stream)) = self.listener_map.get_mut(&fd) { + let port = Self::read_local_stream_port(&mut self.partial_command_map, stream); + + if let Err(Error::UnixRead(ref e)) = port { + if e.kind() == ErrorKind::WouldBlock { + return; + } + } + + let stream = match self.remove_listener(fd) { + Some(EpollListener::LocalStream(s)) => s, + _ => unreachable!(), + }; + + port.and_then(|peer_port| { + let local_port = self.allocate_local_port(); + + self.add_connection( + ConnMapKey { + local_port, + peer_port, + }, + MuxerConnection::new_local_init( + stream, + uapi::VSOCK_HOST_CID, + self.cid, + local_port, + peer_port, + ), + ) + }) + .unwrap_or_else(|err| { + info!("vsock: error adding local-init connection: {:?}", err); + }) } } @@ -459,30 +484,36 @@ impl VsockMuxer { /// Parse a host "connect" command, and extract the destination vsock port. /// - fn read_local_stream_port(stream: &mut UnixStream) -> Result { - let mut buf = [0u8; 32]; + fn read_local_stream_port( + partial_command_map: &mut HashMap, + stream: &mut UnixStream, + ) -> Result { + let command = partial_command_map.entry(stream.as_raw_fd()).or_default(); // This is the minimum number of bytes that we should be able to read, when parsing a // valid connection request. I.e. `b"connect 0\n".len()`. - const MIN_READ_LEN: usize = 10; + const MIN_COMMAND_LEN: usize = 10; // Bring in the minimum number of bytes that we should be able to read. - stream - .read_exact(&mut buf[..MIN_READ_LEN]) - .map_err(Error::UnixRead)?; + if command.len < MIN_COMMAND_LEN { + command.len += stream + .read(&mut command.buf[command.len..MIN_COMMAND_LEN]) + .map_err(Error::UnixRead)?; + } // Now, finish reading the destination port number, by bringing in one byte at a time, // until we reach an EOL terminator (or our buffer space runs out). Yeah, not // particularly proud of this approach, but it will have to do for now. - let mut blen = MIN_READ_LEN; - while buf[blen - 1] != b'\n' && blen < buf.len() { - stream - .read_exact(&mut buf[blen..=blen]) + while command.buf[command.len - 1] != b'\n' && command.len < command.buf.len() { + command.len += stream + .read(&mut command.buf[command.len..=command.len]) .map_err(Error::UnixRead)?; - blen += 1; } - let mut word_iter = std::str::from_utf8(&buf[..blen]) + let _ = command; + let command = partial_command_map.remove(&stream.as_raw_fd()).unwrap(); + + let mut word_iter = std::str::from_utf8(&command.buf[..command.len]) .map_err(Error::ConvertFromUtf8)? .split_whitespace();