diff --git a/vmm/src/device_manager.rs b/vmm/src/device_manager.rs index cf1c6700b..fb9fcaa8b 100644 --- a/vmm/src/device_manager.rs +++ b/vmm/src/device_manager.rs @@ -1869,7 +1869,7 @@ impl DeviceManager { self.modify_mode(f.as_raw_fd(), |t| unsafe { cfmakeraw(t) }) } - fn listen_for_sigwinch_on_tty(&mut self, pty_main: File, pty_sub: File) -> std::io::Result<()> { + fn listen_for_sigwinch_on_tty(&mut self, pty_sub: File) -> std::io::Result<()> { let seccomp_filter = get_seccomp_filter( &self.seccomp_action, Thread::PtyForeground, @@ -1877,7 +1877,7 @@ impl DeviceManager { ) .unwrap(); - match start_sigwinch_listener(seccomp_filter, pty_main, pty_sub) { + match start_sigwinch_listener(seccomp_filter, pty_sub) { Ok(pipe) => { self.console_resize_pipe = Some(Arc::new(pipe)); } @@ -1917,8 +1917,7 @@ impl DeviceManager { self.config.lock().unwrap().console.file = Some(path.clone()); let file = main.try_clone().unwrap(); assert!(resize_pipe.is_none()); - self.listen_for_sigwinch_on_tty(main.try_clone().unwrap(), sub) - .unwrap(); + self.listen_for_sigwinch_on_tty(sub).unwrap(); self.console_pty = Some(Arc::new(Mutex::new(PtyPair { main, path }))); Endpoint::PtyPair(file.try_clone().unwrap(), file) } diff --git a/vmm/src/seccomp_filters.rs b/vmm/src/seccomp_filters.rs index b782319d4..8177d9faf 100644 --- a/vmm/src/seccomp_filters.rs +++ b/vmm/src/seccomp_filters.rs @@ -498,6 +498,7 @@ fn vmm_thread_rules( (libc::SYS_clone, vec![]), (libc::SYS_clone3, vec![]), (libc::SYS_close, vec![]), + (libc::SYS_close_range, vec![]), (libc::SYS_connect, vec![]), (libc::SYS_dup, vec![]), (libc::SYS_epoll_create1, vec![]), diff --git a/vmm/src/sigwinch_listener.rs b/vmm/src/sigwinch_listener.rs index 60188ef57..d789bc445 100644 --- a/vmm/src/sigwinch_listener.rs +++ b/vmm/src/sigwinch_listener.rs @@ -1,16 +1,18 @@ -// Copyright 2021 Alyssa Ross +// Copyright 2021, 2023 Alyssa Ross // SPDX-License-Identifier: Apache-2.0 use crate::clone3::{clone3, clone_args, CLONE_CLEAR_SIGHAND}; use libc::{ c_int, c_void, close, getpgrp, ioctl, pipe2, poll, pollfd, setsid, sigemptyset, siginfo_t, - sigprocmask, tcsetpgrp, O_CLOEXEC, POLLERR, SIGWINCH, SIG_SETMASK, STDIN_FILENO, STDOUT_FILENO, - TIOCSCTTY, + sigprocmask, syscall, tcsetpgrp, SYS_close_range, ENOSYS, O_CLOEXEC, POLLERR, SIGWINCH, + SIG_SETMASK, STDERR_FILENO, TIOCSCTTY, }; use seccompiler::{apply_filter, BpfProgram}; use std::cell::RefCell; -use std::fs::File; +use std::collections::BTreeSet; +use std::fs::{read_dir, File}; use std::io::{self, ErrorKind, Read, Write}; +use std::iter::once; use std::mem::size_of; use std::mem::MaybeUninit; use std::os::unix::prelude::*; @@ -60,17 +62,72 @@ fn unblock_all_signals() -> io::Result<()> { Ok(()) } -fn sigwinch_listener_main(seccomp_filter: BpfProgram, tx: File, pty: File) -> ! { - TX.with(|opt| opt.replace(Some(tx))); +/// # Safety +/// +/// Caller is responsible for ensuring all file descriptors not listed +/// in `keep_fds` are not accessed after this point, and that no other +/// thread is opening file descriptors while this function is +/// running. +unsafe fn close_fds_fallback(keep_fds: &BTreeSet) { + // We collect these instead of iterating through them, because we + // don't want to close the descriptor for /proc/self/fd while + // we're iterating through it. + let open_fds: BTreeSet = read_dir("/proc/self/fd") + .unwrap() + .map(Result::unwrap) + .filter_map(|s| s.file_name().into_string().ok()?.parse().ok()) + .collect(); + for fd in open_fds.difference(keep_fds) { + close(*fd); + } +} + +/// # Safety +/// +/// Caller is responsible for ensuring all file descriptors not listed +/// in `keep_fds` are not accessed after this point, and that no other +/// thread is opening file descriptors while this function is +/// running. +unsafe fn close_unused_fds(keep_fds: &mut [RawFd]) { + keep_fds.sort(); + + // Iterate over the gaps between descriptors we want to keep. + let firsts = keep_fds.iter().map(|fd| fd + 1); + for (i, first) in once(0).chain(firsts).enumerate() { + // The next fd is the one at i, because the indexes in the + // iterator are offset by one due to the initial 0. + let next_keep_fd = keep_fds.get(i); + let last = next_keep_fd.map(|fd| fd - 1).unwrap_or(RawFd::MAX); + + if first > last { + continue; + } + + if syscall(SYS_close_range, first, last, 0) == -1 { + // The kernel might be too old to have close_range, in + // which case we need to fall back to an uglier method. + let e = io::Error::last_os_error(); + if e.raw_os_error() == Some(ENOSYS) { + return close_fds_fallback(&keep_fds.iter().copied().collect()); + } + + panic!("close_range: {e}"); + } + } +} + +fn sigwinch_listener_main(seccomp_filter: BpfProgram, tx: File, pty: File) -> ! { let pty_fd = pty.into_raw_fd(); - // SAFETY: FFI calls + // SAFETY: any references to these file descriptors are + // unreachable, because this function never returns. unsafe { - close(STDIN_FILENO); - close(STDOUT_FILENO); + close_unused_fds(&mut [STDERR_FILENO, tx.as_raw_fd(), pty_fd]); } + TX.with(|opt| opt.replace(Some(tx))); + unblock_all_signals().unwrap(); if !seccomp_filter.is_empty() { @@ -119,11 +176,7 @@ fn sigwinch_listener_main(seccomp_filter: BpfProgram, tx: File, pty: File) -> ! exit(0); } -pub fn start_sigwinch_listener( - seccomp_filter: BpfProgram, - pty_main: File, - pty_sub: File, -) -> io::Result { +pub fn start_sigwinch_listener(seccomp_filter: BpfProgram, tty_sub: File) -> io::Result { let mut pipe = [-1; 2]; // SAFETY: FFI call with valid arguments if unsafe { pipe2(pipe.as_mut_ptr(), O_CLOEXEC) } == -1 { @@ -142,9 +195,7 @@ pub fn start_sigwinch_listener( match unsafe { clone3(&mut args, size_of::()) } { -1 => return Err(io::Error::last_os_error()), 0 => { - drop(rx); - drop(pty_main); - sigwinch_listener_main(seccomp_filter, tx, pty_sub); + sigwinch_listener_main(seccomp_filter, tx, tty_sub); } _ => (), }