diff --git a/src/main.rs b/src/main.rs index f4c7745df..148f990cf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -521,7 +521,13 @@ fn start_vmm(cmd_arguments: ArgMatches) -> Result, Error> { // Before we start any threads, mask the signals we'll be // installing handlers for, to make sure they only ever run on the // dedicated signal handling thread we'll start in a bit. - for sig in &vmm::vm::HANDLED_SIGNALS { + for sig in &vmm::vm::Vm::HANDLED_SIGNALS { + if let Err(e) = block_signal(*sig) { + eprintln!("Error blocking signals: {}", e); + } + } + + for sig in &vmm::Vmm::HANDLED_SIGNALS { if let Err(e) = block_signal(*sig) { eprintln!("Error blocking signals: {}", e); } diff --git a/vmm/src/lib.rs b/vmm/src/lib.rs index b531d553b..17435016d 100644 --- a/vmm/src/lib.rs +++ b/vmm/src/lib.rs @@ -26,12 +26,13 @@ use crate::migration::{recv_vm_config, recv_vm_state}; use crate::seccomp_filters::{get_seccomp_filter, Thread}; use crate::vm::{Error as VmError, Vm, VmState}; use anyhow::anyhow; -use libc::EFD_NONBLOCK; +use libc::{EFD_NONBLOCK, SIGINT, SIGTERM}; use memory_manager::MemoryManagerSnapshotData; use pci::PciBdf; use seccompiler::{apply_filter, SeccompAction}; use serde::ser::{SerializeStruct, Serializer}; use serde::{Deserialize, Serialize}; +use signal_hook::iterator::{Handle, Signals}; use std::collections::HashMap; use std::fs::File; use std::io; @@ -39,6 +40,7 @@ use std::io::{Read, Write}; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::net::UnixListener; use std::os::unix::net::UnixStream; +use std::panic::AssertUnwindSafe; use std::path::PathBuf; use std::sync::mpsc::{Receiver, RecvError, SendError, Sender}; use std::sync::{Arc, Mutex}; @@ -48,7 +50,9 @@ use vm_memory::bitmap::AtomicBitmap; use vm_migration::{protocol::*, Migratable}; use vm_migration::{MigratableError, Pausable, Snapshot, Snapshottable, Transportable}; use vmm_sys_util::eventfd::EventFd; +use vmm_sys_util::signal::unblock_signal; use vmm_sys_util::sock_ctrl_msg::ScmSocket; +use vmm_sys_util::terminal::Terminal; mod acpi; pub mod api; @@ -163,6 +167,12 @@ pub enum Error { #[cfg(feature = "gdb")] #[error("Error sending GDB request: {0}")] GdbResponseSend(#[source] SendError), + + #[error("Cannot spawn a signal handler thread: {0}")] + SignalHandlerSpawn(#[source] io::Error), + + #[error("Failed to join on threads: {0:?}")] + ThreadCleanup(std::boxed::Box), } pub type Result = result::Result; @@ -299,6 +309,8 @@ pub fn start_vmm_thread( exit_evt, )?; + vmm.setup_signal_handler()?; + vmm.control_loop( Arc::new(api_receiver), #[cfg(feature = "gdb")] @@ -362,9 +374,78 @@ pub struct Vmm { seccomp_action: SeccompAction, hypervisor: Arc, activate_evt: EventFd, + signals: Option, + threads: Vec>, } impl Vmm { + pub const HANDLED_SIGNALS: [i32; 2] = [SIGTERM, SIGINT]; + + fn signal_handler(mut signals: Signals, on_tty: bool, exit_evt: &EventFd) { + for sig in &Self::HANDLED_SIGNALS { + unblock_signal(*sig).unwrap(); + } + + for signal in signals.forever() { + match signal { + SIGTERM | SIGINT => { + if exit_evt.write(1).is_err() { + // Resetting the terminal is usually done as the VMM exits + if on_tty { + io::stdin() + .lock() + .set_canon_mode() + .expect("failed to restore terminal mode"); + } + std::process::exit(1); + } + } + _ => (), + } + } + } + + fn setup_signal_handler(&mut self) -> Result<()> { + let signals = Signals::new(&Self::HANDLED_SIGNALS); + match signals { + Ok(signals) => { + self.signals = Some(signals.handle()); + let exit_evt = self.exit_evt.try_clone().map_err(Error::EventFdClone)?; + let on_tty = unsafe { libc::isatty(libc::STDIN_FILENO as i32) } != 0; + + let signal_handler_seccomp_filter = + get_seccomp_filter(&self.seccomp_action, Thread::SignalHandler) + .map_err(Error::CreateSeccompFilter)?; + self.threads.push( + thread::Builder::new() + .name("vmm_signal_handler".to_string()) + .spawn(move || { + if !signal_handler_seccomp_filter.is_empty() { + if let Err(e) = apply_filter(&signal_handler_seccomp_filter) + .map_err(Error::ApplySeccompFilter) + { + error!("Error applying seccomp filter: {:?}", e); + exit_evt.write(1).ok(); + return; + } + } + std::panic::catch_unwind(AssertUnwindSafe(|| { + Vmm::signal_handler(signals, on_tty, &exit_evt); + })) + .map_err(|_| { + error!("signal_handler thead panicked"); + exit_evt.write(1).ok() + }) + .ok(); + }) + .map_err(Error::SignalHandlerSpawn)?, + ); + } + Err(e) => error!("Signal not found {}", e), + } + Ok(()) + } + fn new( vmm_version: String, api_evt: EventFd, @@ -414,6 +495,8 @@ impl Vmm { seccomp_action, hypervisor, activate_evt, + signals: None, + threads: vec![], }) } @@ -1855,6 +1938,16 @@ impl Vmm { } } + // Trigger the termination of the signal_handler thread + if let Some(signals) = self.signals.take() { + signals.close(); + } + + // Wait for all the threads to finish + for thread in self.threads.drain(..) { + thread.join().map_err(Error::ThreadCleanup)? + } + Ok(()) } } diff --git a/vmm/src/vm.rs b/vmm/src/vm.rs index 40138fd33..3b9f4a93a 100644 --- a/vmm/src/vm.rs +++ b/vmm/src/vm.rs @@ -64,11 +64,7 @@ use linux_loader::loader::pe::Error::InvalidImageMagicNumber; use linux_loader::loader::KernelLoader; use seccompiler::{apply_filter, SeccompAction}; use serde::{Deserialize, Serialize}; -use signal_hook::{ - consts::{SIGINT, SIGTERM, SIGWINCH}, - iterator::backend::Handle, - iterator::Signals, -}; +use signal_hook::{consts::SIGWINCH, iterator::backend::Handle, iterator::Signals}; use std::cmp; use std::collections::BTreeMap; use std::collections::HashMap; @@ -455,8 +451,6 @@ pub fn physical_bits(max_phys_bits: u8) -> u8 { cmp::min(host_phys_bits, max_phys_bits) } -pub const HANDLED_SIGNALS: [i32; 3] = [SIGWINCH, SIGTERM, SIGINT]; - pub struct Vm { #[cfg(any(target_arch = "aarch64", feature = "tdx"))] kernel: Option, @@ -485,6 +479,8 @@ pub struct Vm { } impl Vm { + pub const HANDLED_SIGNALS: [i32; 1] = [SIGWINCH]; + #[allow(clippy::too_many_arguments)] fn new_from_memory_manager( config: Arc>, @@ -1678,34 +1674,14 @@ impl Vm { Ok(self.device_manager.lock().unwrap().counters()) } - fn os_signal_handler( - mut signals: Signals, - console_input_clone: Arc, - on_tty: bool, - exit_evt: &EventFd, - ) { - for sig in &HANDLED_SIGNALS { + fn signal_handler(mut signals: Signals, console_input_clone: Arc) { + for sig in &Vm::HANDLED_SIGNALS { unblock_signal(*sig).unwrap(); } for signal in signals.forever() { - match signal { - SIGWINCH => { - console_input_clone.update_console_size(); - } - SIGTERM | SIGINT => { - if exit_evt.write(1).is_err() { - // Resetting the terminal is usually done as the VMM exits - if on_tty { - io::stdin() - .lock() - .set_canon_mode() - .expect("failed to restore terminal mode"); - } - std::process::exit(1); - } - } - _ => (), + if signal == SIGWINCH { + console_input_clone.update_console_size(); } } } @@ -1994,18 +1970,17 @@ impl Vm { fn setup_signal_handler(&mut self) -> Result<()> { let console = self.device_manager.lock().unwrap().console().clone(); - let signals = Signals::new(&HANDLED_SIGNALS); + let signals = Signals::new(&Vm::HANDLED_SIGNALS); match signals { Ok(signals) => { self.signals = Some(signals.handle()); let exit_evt = self.exit_evt.try_clone().map_err(Error::EventFdClone)?; - let on_tty = self.on_tty; let signal_handler_seccomp_filter = get_seccomp_filter(&self.seccomp_action, Thread::SignalHandler) .map_err(Error::CreateSeccompFilter)?; self.threads.push( thread::Builder::new() - .name("signal_handler".to_string()) + .name("vm_signal_handler".to_string()) .spawn(move || { if !signal_handler_seccomp_filter.is_empty() { if let Err(e) = apply_filter(&signal_handler_seccomp_filter) @@ -2017,7 +1992,7 @@ impl Vm { } } std::panic::catch_unwind(AssertUnwindSafe(|| { - Vm::os_signal_handler(signals, console, on_tty, &exit_evt); + Vm::signal_handler(signals, console); })) .map_err(|_| { error!("signal_handler thead panicked");