diff --git a/virtio-devices/Cargo.toml b/virtio-devices/Cargo.toml index ec07906fb..8a3c321c2 100644 --- a/virtio-devices/Cargo.toml +++ b/virtio-devices/Cargo.toml @@ -24,6 +24,7 @@ rate_limiter = { path = "../rate_limiter" } seccompiler = "0.2.0" serde = { version = "1.0.144", features = ["derive"] } serde_json = "1.0.85" +serial_buffer = { path = "../serial_buffer" } thiserror = "1.0.32" versionize = "0.1.6" versionize_derive = "0.1.4" diff --git a/virtio-devices/src/console.rs b/virtio-devices/src/console.rs index ca0376139..ae2a7df6e 100644 --- a/virtio-devices/src/console.rs +++ b/virtio-devices/src/console.rs @@ -14,6 +14,7 @@ use crate::VirtioInterrupt; use anyhow::anyhow; use libc::{EFD_NONBLOCK, TIOCGWINSZ}; use seccompiler::SeccompAction; +use serial_buffer::SerialBuffer; use std::cmp; use std::collections::VecDeque; use std::fs::File; @@ -89,11 +90,15 @@ struct ConsoleEpollHandler { kill_evt: EventFd, pause_evt: EventFd, access_platform: Option>, + out: Option>, + write_out: Option>, + file_event_registered: bool, } pub enum Endpoint { File(File), FilePair(File, File), + PtyPair(File, File), Null, } @@ -102,6 +107,7 @@ impl Endpoint { match self { Self::File(f) => Some(f), Self::FilePair(f, _) => Some(f), + Self::PtyPair(f, _) => Some(f), Self::Null => None, } } @@ -110,9 +116,14 @@ impl Endpoint { match self { Self::File(_) => None, Self::FilePair(_, f) => Some(f), + Self::PtyPair(_, f) => Some(f), Self::Null => None, } } + + fn is_pty(&self) -> bool { + matches!(self, Self::PtyPair(_, _)) + } } impl Clone for Endpoint { @@ -122,12 +133,68 @@ impl Clone for Endpoint { Self::FilePair(f_out, f_in) => { Self::FilePair(f_out.try_clone().unwrap(), f_in.try_clone().unwrap()) } + Self::PtyPair(f_out, f_in) => { + Self::PtyPair(f_out.try_clone().unwrap(), f_in.try_clone().unwrap()) + } Self::Null => Self::Null, } } } impl ConsoleEpollHandler { + #[allow(clippy::too_many_arguments)] + fn new( + mem: GuestMemoryAtomic, + queues: Vec, + interrupt_cb: Arc, + in_buffer: Arc>>, + resizer: Arc, + endpoint: Endpoint, + input_queue_evt: EventFd, + output_queue_evt: EventFd, + input_evt: EventFd, + config_evt: EventFd, + resize_pipe: Option, + kill_evt: EventFd, + pause_evt: EventFd, + access_platform: Option>, + ) -> Self { + let out_file = endpoint.out_file(); + let (out, write_out) = if let Some(out_file) = out_file { + let writer = out_file.try_clone().unwrap(); + if endpoint.is_pty() { + let pty_write_out = Arc::new(AtomicBool::new(false)); + let write_out = Some(pty_write_out.clone()); + let buffer = SerialBuffer::new(Box::new(writer), pty_write_out); + (Some(Box::new(buffer) as Box), write_out) + } else { + (Some(Box::new(writer) as Box), None) + } + } else { + (None, None) + }; + + ConsoleEpollHandler { + mem, + queues, + interrupt_cb, + in_buffer, + resizer, + endpoint, + input_queue_evt, + output_queue_evt, + input_evt, + config_evt, + resize_pipe, + kill_evt, + pause_evt, + access_platform, + out, + write_out, + file_event_registered: false, + } + } + /* * Each port of virtio console device has one receive * queue. One or more empty buffers are placed by the @@ -184,7 +251,7 @@ impl ConsoleEpollHandler { while let Some(mut desc_chain) = trans_queue.pop_descriptor_chain(self.mem.memory()) { let desc = desc_chain.next().unwrap(); - if let Some(ref mut out) = self.endpoint.out_file() { + if let Some(out) = &mut self.out { let _ = desc_chain.memory().write_to( desc.addr() .translate_gva(self.access_platform.as_ref(), desc.len() as usize), @@ -225,9 +292,62 @@ impl ConsoleEpollHandler { helper.add_event(resize_pipe.as_raw_fd(), RESIZE_EVENT)?; } if let Some(in_file) = self.endpoint.in_file() { - helper.add_event(in_file.as_raw_fd(), FILE_EVENT)?; + let mut events = epoll::Events::EPOLLIN; + if self.endpoint.is_pty() { + events |= epoll::Events::EPOLLONESHOT; + } + helper.add_event_custom(in_file.as_raw_fd(), FILE_EVENT, events)?; + self.file_event_registered = true; } - helper.run(paused, paused_sync, self)?; + + // In case of PTY, we want to be able to detect a connection on the + // other end of the PTY. This is done by detecting there's no event + // triggered on the epoll, which is the reason why we want the + // epoll_wait() function to return after the timeout expired. + // In case of TTY, we don't expect to detect such behavior, which is + // why we can afford to block until an actual event is triggered. + let (timeout, enable_event_list) = if self.endpoint.is_pty() { + (500, true) + } else { + (-1, false) + }; + helper.run_with_timeout(paused, paused_sync, self, timeout, enable_event_list)?; + + Ok(()) + } + + // This function should be called when the other end of the PTY is + // connected. It verifies if this is the first time it's been invoked + // after the connection happened, and if that's the case it flushes + // all output from the console to the PTY. Otherwise, it's a no-op. + fn trigger_pty_flush(&mut self) -> result::Result<(), anyhow::Error> { + if let (Some(pty_write_out), Some(out)) = (&self.write_out, &mut self.out) { + if pty_write_out.load(Ordering::Acquire) { + return Ok(()); + } + pty_write_out.store(true, Ordering::Release); + out.flush() + .map_err(|e| anyhow!("Failed to flush PTY: {:?}", e)) + } else { + Ok(()) + } + } + + fn register_file_event( + &mut self, + helper: &mut EpollHelper, + ) -> result::Result<(), EpollHelperError> { + if self.file_event_registered { + return Ok(()); + } + + // Re-arm the file event. + helper.mod_event_custom( + self.endpoint.in_file().unwrap().as_raw_fd(), + FILE_EVENT, + epoll::Events::EPOLLIN | epoll::Events::EPOLLONESHOT, + )?; + self.file_event_registered = true; Ok(()) } @@ -236,10 +356,11 @@ impl ConsoleEpollHandler { impl EpollHelperHandler for ConsoleEpollHandler { fn handle_event( &mut self, - _helper: &mut EpollHelper, + helper: &mut EpollHelper, event: &epoll::Event, ) -> result::Result<(), EpollHelperError> { let ev_type = event.data as u16; + match ev_type { INPUT_QUEUE_EVENT => { self.input_queue_evt.read().map_err(|e| { @@ -307,20 +428,40 @@ impl EpollHelperHandler for ConsoleEpollHandler { self.resizer.update_console_size(); } FILE_EVENT => { - let mut input = [0u8; 64]; - if let Some(ref mut in_file) = self.endpoint.in_file() { - if let Ok(count) = in_file.read(&mut input) { - let mut in_buffer = self.in_buffer.lock().unwrap(); - in_buffer.extend(&input[..count]); - } + if event.events & libc::EPOLLIN as u32 != 0 { + let mut input = [0u8; 64]; + if let Some(ref mut in_file) = self.endpoint.in_file() { + if let Ok(count) = in_file.read(&mut input) { + let mut in_buffer = self.in_buffer.lock().unwrap(); + in_buffer.extend(&input[..count]); + } - if self.process_input_queue() { - self.signal_used_queue(0).map_err(|e| { - EpollHelperError::HandleEvent(anyhow!( - "Failed to signal used queue: {:?}", - e - )) - })?; + if self.process_input_queue() { + self.signal_used_queue(0).map_err(|e| { + EpollHelperError::HandleEvent(anyhow!( + "Failed to signal used queue: {:?}", + e + )) + })?; + } + } + } + if self.endpoint.is_pty() { + self.file_event_registered = false; + if event.events & libc::EPOLLHUP as u32 != 0 { + if let Some(pty_write_out) = &self.write_out { + if pty_write_out.load(Ordering::Acquire) { + pty_write_out.store(false, Ordering::Release); + } + } + } else { + // If the EPOLLHUP flag is not up on the associated event, we + // can assume the other end of the PTY is connected and therefore + // we can flush the output of the serial to it. + self.trigger_pty_flush() + .map_err(EpollHelperError::HandleTimeout)?; + + self.register_file_event(helper)?; } } } @@ -332,6 +473,55 @@ impl EpollHelperHandler for ConsoleEpollHandler { } Ok(()) } + + // This function will be invoked whenever the timeout is reached before + // any other event was triggered while waiting for the epoll. + fn handle_timeout(&mut self, helper: &mut EpollHelper) -> Result<(), EpollHelperError> { + if !self.endpoint.is_pty() { + return Ok(()); + } + + if self.file_event_registered { + // This very specific case happens when the console is connected + // to a PTY. We know EPOLLHUP is always present when there's nothing + // connected at the other end of the PTY. That's why getting no event + // means we can flush the output of the console through the PTY. + self.trigger_pty_flush() + .map_err(EpollHelperError::HandleTimeout)?; + } + + // Every time we hit the timeout, let's register the FILE_EVENT to give + // us a chance to catch a possible event that might have been triggered. + self.register_file_event(helper) + } + + // This function returns the full list of events found on the epoll before + // iterating through it calling handle_event(). It allows the detection of + // the PTY connection even when the timeout is not being triggered, which + // happens when there are other events preventing the timeout from being + // reached. This is an additional way of detecting a PTY connection. + fn event_list( + &mut self, + helper: &mut EpollHelper, + events: &[epoll::Event], + ) -> Result<(), EpollHelperError> { + if self.file_event_registered { + for event in events { + if event.data as u16 == FILE_EVENT && (event.events & libc::EPOLLHUP as u32) != 0 { + return Ok(()); + } + } + + // This very specific case happens when the console is connected + // to a PTY. We know EPOLLHUP is always present when there's nothing + // connected at the other end of the PTY. That's why getting no event + // means we can flush the output of the console through the PTY. + self.trigger_pty_flush() + .map_err(EpollHelperError::HandleTimeout)?; + } + + self.register_file_event(helper) + } } /// Resize handler @@ -532,22 +722,22 @@ impl VirtioDevice for Console { virtqueues.push(queue); let output_queue_evt = queue_evt; - let mut handler = ConsoleEpollHandler { + let mut handler = ConsoleEpollHandler::new( mem, - queues: virtqueues, + virtqueues, interrupt_cb, - in_buffer: self.in_buffer.clone(), - endpoint: self.endpoint.clone(), + self.in_buffer.clone(), + Arc::clone(&self.resizer), + self.endpoint.clone(), input_queue_evt, output_queue_evt, input_evt, - config_evt: self.resizer.config_evt.try_clone().unwrap(), - resize_pipe: self.resize_pipe.as_ref().map(|p| p.try_clone().unwrap()), - resizer: Arc::clone(&self.resizer), + self.resizer.config_evt.try_clone().unwrap(), + self.resize_pipe.as_ref().map(|p| p.try_clone().unwrap()), kill_evt, pause_evt, - access_platform: self.common.access_platform.clone(), - }; + self.common.access_platform.clone(), + ); let paused = self.common.paused.clone(); let paused_sync = self.common.paused_sync.clone();