diff --git a/vmm/src/device_manager.rs b/vmm/src/device_manager.rs index 19ce15be3..8fa420083 100644 --- a/vmm/src/device_manager.rs +++ b/vmm/src/device_manager.rs @@ -777,7 +777,6 @@ struct DeviceManagerState { #[derive(Debug)] pub struct PtyPair { pub main: File, - pub sub: File, pub path: PathBuf, } @@ -785,7 +784,6 @@ impl Clone for PtyPair { fn clone(&self) -> Self { PtyPair { main: self.main.try_clone().unwrap(), - sub: self.sub.try_clone().unwrap(), path: self.path.clone(), } } @@ -1915,7 +1913,7 @@ impl DeviceManager { let file = main.try_clone().unwrap(); assert!(resize_pipe.is_none()); self.listen_for_sigwinch_on_tty(&sub).unwrap(); - self.console_pty = Some(Arc::new(Mutex::new(PtyPair { main, sub, path }))); + self.console_pty = Some(Arc::new(Mutex::new(PtyPair { main, path }))); Endpoint::FilePair(file.try_clone().unwrap(), file) } } @@ -2016,7 +2014,7 @@ impl DeviceManager { self.set_raw_mode(&mut sub) .map_err(DeviceManagerError::SetPtyRaw)?; self.config.lock().unwrap().serial.file = Some(path.clone()); - self.serial_pty = Some(Arc::new(Mutex::new(PtyPair { main, sub, path }))); + self.serial_pty = Some(Arc::new(Mutex::new(PtyPair { main, path }))); } None } diff --git a/vmm/src/serial_buffer.rs b/vmm/src/serial_buffer.rs index 94b9dc59d..53b4b2842 100644 --- a/vmm/src/serial_buffer.rs +++ b/vmm/src/serial_buffer.rs @@ -3,173 +3,95 @@ // SPDX-License-Identifier: Apache-2.0 // -use crate::serial_manager::EpollDispatch; - -use std::io::Write; -use std::os::unix::io::RawFd; +use std::{ + collections::VecDeque, + io::Write, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; // Circular buffer implementation for serial output. // Read from head; push to tail pub(crate) struct SerialBuffer { - buffer: Vec, - head: usize, - tail: usize, + buffer: VecDeque, out: Box, - buffering: bool, - out_fd: Option, - epoll_fd: Option, + write_out: Arc, } -const MAX_BUFFER_SIZE: usize = 16 << 10; - impl SerialBuffer { - pub(crate) fn new(out: Box) -> Self { + pub(crate) fn new(out: Box, write_out: Arc) -> Self { Self { - buffer: vec![], - head: 0, - tail: 0, + buffer: VecDeque::new(), out, - buffering: false, - out_fd: None, - epoll_fd: None, + write_out, } } +} - pub(crate) fn add_out_fd(&mut self, out_fd: RawFd) { - self.out_fd = Some(out_fd); - } +impl Write for SerialBuffer { + fn write(&mut self, buf: &[u8]) -> Result { + // Simply fill the buffer if we're not allowed to write to the out + // device. + if !self.write_out.load(Ordering::Acquire) { + self.buffer.extend(buf); + return Ok(buf.len()); + } - pub(crate) fn add_epoll_fd(&mut self, epoll_fd: RawFd) { - self.epoll_fd = Some(epoll_fd); - } + // In case we're allowed to write to the out device, we flush the + // content of the buffer. + self.flush()?; - pub fn flush_buffer(&mut self) -> Result<(), std::io::Error> { - if self.tail <= self.head { - // The buffer to be written is in two parts - let buf = &self.buffer[self.head..]; - match self.out.write(buf) { - Ok(bytes_written) => { - if bytes_written == buf.len() { - self.head = 0; - // Can now proceed to write the other part of the buffer - } else { - self.head += bytes_written; - self.out.flush()?; - return Ok(()); + // If after flushing the buffer, it's still not empty, that means + // only a subset of the bytes was written and we should fill the buffer + // with what's coming from the serial. + if !self.buffer.is_empty() { + self.buffer.extend(buf); + return Ok(buf.len()); + } + + // We reach this point if we're allowed to write to the out device + // and we know there's nothing left in the buffer. + let mut offset = 0; + loop { + match self.out.write(&buf[offset..]) { + Ok(written_bytes) => { + if written_bytes < buf.len() - offset { + offset += written_bytes; + continue; } } Err(e) => { if !matches!(e.kind(), std::io::ErrorKind::WouldBlock) { return Err(e); } - self.add_out_poll()?; - return Ok(()); + self.buffer.extend(&buf[offset..]); } } + break; } - let buf = &self.buffer[self.head..self.tail]; - match self.out.write(buf) { - Ok(bytes_written) => { - if bytes_written == buf.len() { - self.buffer.clear(); - self.buffer.shrink_to_fit(); - self.head = 0; - self.tail = 0; - self.remove_out_poll()?; - } else { - self.head += bytes_written; - } - self.out.flush()?; - } - Err(e) => { - if !matches!(e.kind(), std::io::ErrorKind::WouldBlock) { - return Err(e); - } - self.add_out_poll()?; - } - } - - Ok(()) - } - - fn add_out_poll(&mut self) -> Result<(), std::io::Error> { - if self.out_fd.is_some() && self.epoll_fd.is_some() && !self.buffering { - self.buffering = true; - let out_fd = self.out_fd.as_ref().unwrap(); - let epoll_fd = self.epoll_fd.as_ref().unwrap(); - epoll::ctl( - *epoll_fd, - epoll::ControlOptions::EPOLL_CTL_MOD, - *out_fd, - epoll::Event::new( - epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT, - EpollDispatch::File as u64, - ), - )?; - } - Ok(()) - } - - fn remove_out_poll(&mut self) -> Result<(), std::io::Error> { - if self.out_fd.is_some() && self.epoll_fd.is_some() && self.buffering { - self.buffering = false; - let out_fd = self.out_fd.as_ref().unwrap(); - let epoll_fd = self.epoll_fd.as_ref().unwrap(); - epoll::ctl( - *epoll_fd, - epoll::ControlOptions::EPOLL_CTL_MOD, - *out_fd, - epoll::Event::new(epoll::Events::EPOLLIN, EpollDispatch::File as u64), - )?; - } - Ok(()) - } -} -impl Write for SerialBuffer { - fn write(&mut self, buf: &[u8]) -> Result { - // The serial output only writes one byte at a time - for v in buf { - if self.buffer.is_empty() { - // This case exists to avoid allocating the buffer if it's not needed - if let Err(e) = self.out.write(&[*v]) { - if !matches!(e.kind(), std::io::ErrorKind::WouldBlock) { - return Err(e); - } - self.add_out_poll()?; - self.buffer.push(*v); - self.tail += 1; - } else { - self.out.flush()?; - } - } else { - // Buffer is completely full, lose the oldest byte by moving head forward - if self.head == self.tail { - self.head = self.tail + 1; - if self.head == MAX_BUFFER_SIZE { - self.head = 0; - } - } - - if self.buffer.len() < MAX_BUFFER_SIZE { - self.buffer.push(*v); - } else { - self.buffer[self.tail] = *v; - } - - self.tail += 1; - if self.tail == MAX_BUFFER_SIZE { - self.tail = 0; - } - - self.flush_buffer()?; - } - } + // Make sure we flush anything that might have been written to the + // out device. + self.out.flush()?; Ok(buf.len()) } + // This function flushes the content of the buffer to the out device if + // it is allowed to, otherwise this is a no-op. fn flush(&mut self) -> Result<(), std::io::Error> { - self.flush_buffer() + if !self.write_out.load(Ordering::Acquire) { + return Ok(()); + } + + while let Some(byte) = self.buffer.pop_front() { + if self.out.write_all(&[byte]).is_err() { + self.buffer.push_front(byte); + break; + } + } + self.out.flush() } } diff --git a/vmm/src/serial_manager.rs b/vmm/src/serial_manager.rs index d5ead8bff..9d31d98f3 100644 --- a/vmm/src/serial_manager.rs +++ b/vmm/src/serial_manager.rs @@ -15,6 +15,7 @@ use std::fs::File; use std::io::Read; use std::os::unix::io::{AsRawFd, FromRawFd}; use std::panic::AssertUnwindSafe; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use std::{io, result, thread}; use thiserror::Error; @@ -84,6 +85,7 @@ pub struct SerialManager { in_file: File, kill_evt: EventFd, handle: Option>, + pty_write_out: Option>, } impl SerialManager { @@ -147,11 +149,12 @@ impl SerialManager { ) .map_err(Error::Epoll)?; + let mut pty_write_out = None; if mode == ConsoleOutputMode::Pty { + let write_out = Arc::new(AtomicBool::new(false)); + pty_write_out = Some(write_out.clone()); let writer = in_file.try_clone().map_err(Error::FileClone)?; - let mut buffer = SerialBuffer::new(Box::new(writer)); - buffer.add_out_fd(in_file.as_raw_fd()); - buffer.add_epoll_fd(epoll_fd); + let buffer = SerialBuffer::new(Box::new(writer), write_out); serial.as_ref().lock().unwrap().set_out(Box::new(buffer)); } @@ -164,9 +167,36 @@ impl SerialManager { in_file, kill_evt, handle: None, + pty_write_out, })) } + // This function should be called when the other end of the PTY is + // connected. It verifies if this is the first time it's been invoked + // after the connection happened, and if that's the case it flushes + // all output from the serial to the PTY. Otherwise, it's a no-op. + fn trigger_pty_flush( + #[cfg(target_arch = "x86_64")] serial: &Arc>, + #[cfg(target_arch = "aarch64")] serial: &Arc>, + pty_write_out: Option<&Arc>, + ) -> Result<()> { + if let Some(pty_write_out) = &pty_write_out { + if pty_write_out.load(Ordering::Acquire) { + return Ok(()); + } + + pty_write_out.store(true, Ordering::Release); + + serial + .lock() + .unwrap() + .flush_output() + .map_err(Error::FlushOutput)?; + } + + Ok(()) + } + pub fn start_thread(&mut self, exit_evt: EventFd) -> Result<()> { // Don't allow this to be run if the handle exists if self.handle.is_some() { @@ -177,6 +207,15 @@ impl SerialManager { let epoll_fd = self.epoll_file.as_raw_fd(); let mut in_file = self.in_file.try_clone().map_err(Error::FileClone)?; let serial = self.serial.clone(); + let pty_write_out = self.pty_write_out.clone(); + + // In case of PTY, we want to be able to detect a connection on the + // other end of the PTY. This is done by detecting there's no event + // triggered on the epoll, which is the reason why we want the + // epoll_wait() function to return after the timeout expired. + // In case of TTY, we don't expect to detect such behavior, which is + // why we can afford to block until an actual event is triggered. + let timeout = if pty_write_out.is_some() { 500 } else { -1 }; let thread = thread::Builder::new() .name("serial-manager".to_string()) @@ -189,7 +228,7 @@ impl SerialManager { vec![epoll::Event::new(epoll::Events::empty(), 0); EPOLL_EVENTS_LEN]; loop { - let num_events = match epoll::wait(epoll_fd, -1, &mut events[..]) { + let num_events = match epoll::wait(epoll_fd, timeout, &mut events[..]) { Ok(res) => res, Err(e) => { if e.kind() == io::ErrorKind::Interrupted { @@ -200,12 +239,22 @@ impl SerialManager { // returns an error of type EINTR, but this should not // be considered as a regular error. Instead it is more // appropriate to retry, by calling into epoll_wait(). - continue; + 0 + } else { + return Err(Error::Epoll(e)); } - return Err(Error::Epoll(e)); } }; + if num_events == 0 { + // This very specific case happens when the serial is connected + // to a PTY. We know EPOLLHUP is always present when there's nothing + // connected at the other end of the PTY. That's why getting no event + // means we can flush the output of the serial through the PTY. + Self::trigger_pty_flush(&serial, pty_write_out.as_ref())?; + continue; + } + for event in events.iter().take(num_events) { let dispatch_event: EpollDispatch = event.data.into(); match dispatch_event { @@ -214,14 +263,6 @@ impl SerialManager { warn!("Unknown serial manager loop event: {}", event); } EpollDispatch::File => { - if event.events & libc::EPOLLOUT as u32 != 0 { - serial - .as_ref() - .lock() - .unwrap() - .flush_output() - .map_err(Error::FlushOutput)?; - } if event.events & libc::EPOLLIN as u32 != 0 { let mut input = [0u8; 64]; let count = @@ -239,6 +280,20 @@ impl SerialManager { .queue_input_bytes(&input[..count]) .map_err(Error::QueueInput)?; } + if event.events & libc::EPOLLHUP as u32 != 0 { + if let Some(pty_write_out) = &pty_write_out { + pty_write_out.store(false, Ordering::Release); + } + // It's really important to sleep here as this will prevent + // the current thread from consuming 100% of the CPU cycles + // when waiting for someone to connect to the PTY. + std::thread::sleep(std::time::Duration::from_millis(500)); + } else { + // If the EPOLLHUP flag is not up on the associated event, we + // can assume the other end of the PTY is connected and therefore + // we can flush the output of the serial to it. + Self::trigger_pty_flush(&serial, pty_write_out.as_ref())?; + } } EpollDispatch::Kill => { info!("KILL event received, stopping epoll loop");