main: Enable the api-socket to be passed as an fd

To avoid race issues where the api-socket may not be created by the
time a cloud-hypervisor caller is ready to look for it, enable the
caller to pass the api-socket fd directly.

Avoid breaking current callers by allowing the --api-socket path to be
passed as it is now in addition to through the path argument.

Signed-off-by: William Douglas <william.r.douglas@gmail.com>
This commit is contained in:
William Douglas 2021-04-09 23:46:55 +00:00
parent 375382cb08
commit 767b4f0e59
4 changed files with 75 additions and 22 deletions

View File

@ -25,7 +25,7 @@ use signal_hook::{
}; };
use std::env; use std::env;
use std::fs::File; use std::fs::File;
use std::os::unix::io::FromRawFd; use std::os::unix::io::{FromRawFd, RawFd};
use std::sync::mpsc::channel; use std::sync::mpsc::channel;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use std::thread; use std::thread;
@ -62,6 +62,8 @@ enum Error {
ThreadJoin(std::boxed::Box<dyn std::any::Any + std::marker::Send>), ThreadJoin(std::boxed::Box<dyn std::any::Any + std::marker::Send>),
#[error("VMM thread exited with error: {0}")] #[error("VMM thread exited with error: {0}")]
VmmThread(#[source] vmm::Error), VmmThread(#[source] vmm::Error),
#[error("Error parsing --api-socket: {0}")]
ParsingApiSocket(std::num::ParseIntError),
#[error("Error parsing --event-monitor: {0}")] #[error("Error parsing --event-monitor: {0}")]
ParsingEventMonitor(option_parser::OptionParserError), ParsingEventMonitor(option_parser::OptionParserError),
#[error("Error parsing --event-monitor: path or fd required")] #[error("Error parsing --event-monitor: path or fd required")]
@ -324,7 +326,7 @@ fn create_app<'a, 'b>(
.arg( .arg(
Arg::with_name("api-socket") Arg::with_name("api-socket")
.long("api-socket") .long("api-socket")
.help("HTTP API socket path (UNIX domain socket).") .help("HTTP API socket (UNIX domain socket): path=</path/to/a/file> or fd=<fd>.")
.takes_value(true) .takes_value(true)
.min_values(1) .min_values(1)
.group("vmm-config"), .group("vmm-config"),
@ -379,7 +381,7 @@ fn create_app<'a, 'b>(
app app
} }
fn start_vmm(cmd_arguments: ArgMatches, api_socket_path: &Option<String>) -> Result<(), Error> { fn start_vmm(cmd_arguments: ArgMatches) -> Result<Option<String>, Error> {
let log_level = match cmd_arguments.occurrences_of("v") { let log_level = match cmd_arguments.occurrences_of("v") {
0 => LevelFilter::Warn, 0 => LevelFilter::Warn,
1 => LevelFilter::Info, 1 => LevelFilter::Info,
@ -402,6 +404,29 @@ fn start_vmm(cmd_arguments: ArgMatches, api_socket_path: &Option<String>) -> Res
.map(|()| log::set_max_level(log_level)) .map(|()| log::set_max_level(log_level))
.map_err(Error::LoggerSetup)?; .map_err(Error::LoggerSetup)?;
let (api_socket_path, api_socket_fd) =
if let Some(socket_config) = cmd_arguments.value_of("api-socket") {
let mut parser = OptionParser::new();
parser.add("path").add("fd");
parser.parse(socket_config).unwrap_or_default();
if let Some(fd) = parser.get("fd") {
(
None,
Some(fd.parse::<RawFd>().map_err(Error::ParsingApiSocket)?),
)
} else if let Some(path) = parser.get("path") {
(Some(path), None)
} else {
(
cmd_arguments.value_of("api-socket").map(|s| s.to_string()),
None,
)
}
} else {
(None, None)
};
if let Some(monitor_config) = cmd_arguments.value_of("event-monitor") { if let Some(monitor_config) = cmd_arguments.value_of("event-monitor") {
let mut parser = OptionParser::new(); let mut parser = OptionParser::new();
parser.add("path").add("fd"); parser.add("path").add("fd");
@ -474,7 +499,8 @@ fn start_vmm(cmd_arguments: ArgMatches, api_socket_path: &Option<String>) -> Res
let hypervisor = hypervisor::new().map_err(Error::CreateHypervisor)?; let hypervisor = hypervisor::new().map_err(Error::CreateHypervisor)?;
let vmm_thread = vmm::start_vmm_thread( let vmm_thread = vmm::start_vmm_thread(
env!("CARGO_PKG_VERSION").to_string(), env!("CARGO_PKG_VERSION").to_string(),
api_socket_path, &api_socket_path,
api_socket_fd,
api_evt.try_clone().unwrap(), api_evt.try_clone().unwrap(),
http_sender, http_sender,
api_request_receiver, api_request_receiver,
@ -510,7 +536,9 @@ fn start_vmm(cmd_arguments: ArgMatches, api_socket_path: &Option<String>) -> Res
vmm_thread vmm_thread
.join() .join()
.map_err(Error::ThreadJoin)? .map_err(Error::ThreadJoin)?
.map_err(Error::VmmThread) .map_err(Error::VmmThread)?;
Ok(api_socket_path)
} }
fn main() { fn main() {
@ -519,16 +547,17 @@ fn main() {
let (default_vcpus, default_memory, default_rng) = prepare_default_values(); let (default_vcpus, default_memory, default_rng) = prepare_default_values();
let cmd_arguments = create_app(&default_vcpus, &default_memory, &default_rng).get_matches(); let cmd_arguments = create_app(&default_vcpus, &default_memory, &default_rng).get_matches();
let api_socket_path = cmd_arguments.value_of("api-socket").map(|s| s.to_string()); let exit_code = match start_vmm(cmd_arguments) {
Ok(path) => {
let exit_code = if let Err(e) = start_vmm(cmd_arguments, &api_socket_path) { path.map(|s| std::fs::remove_file(s).ok());
0
}
Err(e) => {
eprintln!("{}", e); eprintln!("{}", e);
1 1
} else { }
0
}; };
api_socket_path.map(|s| std::fs::remove_file(s).ok());
std::process::exit(exit_code); std::process::exit(exit_code);
} }

View File

@ -11,6 +11,7 @@ use micro_http::{Body, HttpServer, MediaType, Method, Request, Response, StatusC
use seccomp::{SeccompAction, SeccompFilter}; use seccomp::{SeccompAction, SeccompFilter};
use serde_json::Error as SerdeError; use serde_json::Error as SerdeError;
use std::collections::HashMap; use std::collections::HashMap;
use std::os::unix::io::RawFd;
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::mpsc::Sender; use std::sync::mpsc::Sender;
use std::sync::Arc; use std::sync::Arc;
@ -253,16 +254,12 @@ fn handle_http_request(
response response
} }
pub fn start_http_thread( fn start_http_thread(
path: &str, mut server: HttpServer,
api_notifier: EventFd, api_notifier: EventFd,
api_sender: Sender<ApiRequest>, api_sender: Sender<ApiRequest>,
seccomp_action: &SeccompAction, seccomp_action: &SeccompAction,
) -> Result<thread::JoinHandle<Result<()>>> { ) -> Result<thread::JoinHandle<Result<()>>> {
std::fs::remove_file(path).unwrap_or_default();
let socket_path = PathBuf::from(path);
let mut server = HttpServer::new(socket_path).map_err(Error::CreateApiServer)?;
// Retrieve seccomp filter for API thread // Retrieve seccomp filter for API thread
let api_seccomp_filter = let api_seccomp_filter =
get_seccomp_filter(seccomp_action, Thread::Api).map_err(Error::CreateSeccompFilter)?; get_seccomp_filter(seccomp_action, Thread::Api).map_err(Error::CreateSeccompFilter)?;
@ -299,3 +296,25 @@ pub fn start_http_thread(
}) })
.map_err(Error::HttpThreadSpawn) .map_err(Error::HttpThreadSpawn)
} }
pub fn start_http_path_thread(
path: &str,
api_notifier: EventFd,
api_sender: Sender<ApiRequest>,
seccomp_action: &SeccompAction,
) -> Result<thread::JoinHandle<Result<()>>> {
std::fs::remove_file(path).unwrap_or_default();
let socket_path = PathBuf::from(path);
let server = HttpServer::new(socket_path).map_err(Error::CreateApiServer)?;
start_http_thread(server, api_notifier, api_sender, seccomp_action)
}
pub fn start_http_fd_thread(
fd: RawFd,
api_notifier: EventFd,
api_sender: Sender<ApiRequest>,
seccomp_action: &SeccompAction,
) -> Result<thread::JoinHandle<Result<()>>> {
let server = HttpServer::new_from_fd(fd).map_err(Error::CreateApiServer)?;
start_http_thread(server, api_notifier, api_sender, seccomp_action)
}

View File

@ -31,7 +31,8 @@
extern crate vm_device; extern crate vm_device;
extern crate vmm_sys_util; extern crate vmm_sys_util;
pub use self::http::start_http_thread; pub use self::http::start_http_fd_thread;
pub use self::http::start_http_path_thread;
pub mod http; pub mod http;
pub mod http_endpoint; pub mod http_endpoint;

View File

@ -247,9 +247,11 @@ impl Serialize for PciDeviceInfo {
} }
} }
#[allow(clippy::too_many_arguments)]
pub fn start_vmm_thread( pub fn start_vmm_thread(
vmm_version: String, vmm_version: String,
http_path: &Option<String>, http_path: &Option<String>,
http_fd: Option<RawFd>,
api_event: EventFd, api_event: EventFd,
api_sender: Sender<ApiRequest>, api_sender: Sender<ApiRequest>,
api_receiver: Receiver<ApiRequest>, api_receiver: Receiver<ApiRequest>,
@ -280,9 +282,11 @@ pub fn start_vmm_thread(
}) })
.map_err(Error::VmmThreadSpawn)?; .map_err(Error::VmmThreadSpawn)?;
if let Some(http_path) = http_path {
// The VMM thread is started, we can start serving HTTP requests // The VMM thread is started, we can start serving HTTP requests
api::start_http_thread(http_path, http_api_event, api_sender, seccomp_action)?; if let Some(http_path) = http_path {
api::start_http_path_thread(http_path, http_api_event, api_sender, seccomp_action)?;
} else if let Some(http_fd) = http_fd {
api::start_http_fd_thread(http_fd, http_api_event, api_sender, seccomp_action)?;
} }
Ok(thread) Ok(thread)
} }