diff --git a/vmm/src/lib.rs b/vmm/src/lib.rs index eef29e98d..8a197b732 100644 --- a/vmm/src/lib.rs +++ b/vmm/src/lib.rs @@ -11,6 +11,7 @@ extern crate log; use std::collections::HashMap; use std::fs::File; use std::io::{stdout, Read, Write}; +use std::net::{TcpListener, TcpStream}; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::net::{UnixListener, UnixStream}; use std::panic::AssertUnwindSafe; @@ -36,8 +37,8 @@ use serde::{Deserialize, Serialize}; use signal_hook::iterator::{Handle, Signals}; use thiserror::Error; use tracer::trace_scoped; -use vm_memory::bitmap::AtomicBitmap; -use vm_memory::{ReadVolatile, WriteVolatile}; +use vm_memory::bitmap::{AtomicBitmap, BitmapSlice}; +use vm_memory::{ReadVolatile, VolatileMemoryError, VolatileSlice, WriteVolatile}; use vm_migration::protocol::*; use vm_migration::{Migratable, MigratableError, Pausable, Snapshot, Snapshottable, Transportable}; use vmm_sys_util::eventfd::EventFd; @@ -235,6 +236,89 @@ impl From for EpollDispatch { } } +enum SocketStream { + Unix(UnixStream), + Tcp(TcpStream), +} + +impl Read for SocketStream { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + match self { + SocketStream::Unix(stream) => stream.read(buf), + SocketStream::Tcp(stream) => stream.read(buf), + } + } +} + +impl Write for SocketStream { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + match self { + SocketStream::Unix(stream) => stream.write(buf), + SocketStream::Tcp(stream) => stream.write(buf), + } + } + + fn flush(&mut self) -> std::io::Result<()> { + match self { + SocketStream::Unix(stream) => stream.flush(), + SocketStream::Tcp(stream) => stream.flush(), + } + } +} + +impl AsRawFd for SocketStream { + fn as_raw_fd(&self) -> RawFd { + match self { + SocketStream::Unix(s) => s.as_raw_fd(), + SocketStream::Tcp(s) => s.as_raw_fd(), + } + } +} + +impl ReadVolatile for SocketStream { + fn read_volatile( + &mut self, + buf: &mut VolatileSlice, + ) -> std::result::Result { + match self { + SocketStream::Unix(s) => s.read_volatile(buf), + SocketStream::Tcp(s) => s.read_volatile(buf), + } + } + + fn read_exact_volatile( + &mut self, + buf: &mut VolatileSlice, + ) -> std::result::Result<(), VolatileMemoryError> { + match self { + SocketStream::Unix(s) => s.read_exact_volatile(buf), + SocketStream::Tcp(s) => s.read_exact_volatile(buf), + } + } +} + +impl WriteVolatile for SocketStream { + fn write_volatile( + &mut self, + buf: &VolatileSlice, + ) -> std::result::Result { + match self { + SocketStream::Unix(s) => s.write_volatile(buf), + SocketStream::Tcp(s) => s.write_volatile(buf), + } + } + + fn write_all_volatile( + &mut self, + buf: &VolatileSlice, + ) -> std::result::Result<(), VolatileMemoryError> { + match self { + SocketStream::Unix(s) => s.write_all_volatile(buf), + SocketStream::Tcp(s) => s.write_all_volatile(buf), + } + } +} + pub struct EpollContext { epoll_file: File, } @@ -921,14 +1005,72 @@ impl Vmm { .map(|s| s.into()) } + fn send_migration_socket( + destination_url: &str, + ) -> std::result::Result { + if let Some(address) = destination_url.strip_prefix("tcp:") { + info!("Connecting to TCP socket at {}", address); + + let socket = TcpStream::connect(address).map_err(|e| { + MigratableError::MigrateSend(anyhow!("Error connecting to TCP socket: {}", e)) + })?; + + Ok(SocketStream::Tcp(socket)) + } else { + let path = Vmm::socket_url_to_path(destination_url)?; + info!("Connecting to UNIX socket at {:?}", path); + + let socket = UnixStream::connect(&path).map_err(|e| { + MigratableError::MigrateSend(anyhow!("Error connecting to UNIX socket: {}", e)) + })?; + + Ok(SocketStream::Unix(socket)) + } + } + + fn receive_migration_socket( + receiver_url: &str, + ) -> std::result::Result { + if let Some(address) = receiver_url.strip_prefix("tcp:") { + let listener = TcpListener::bind(address).map_err(|e| { + MigratableError::MigrateReceive(anyhow!("Error binding to TCP socket: {}", e)) + })?; + + let (socket, _addr) = listener.accept().map_err(|e| { + MigratableError::MigrateReceive(anyhow!( + "Error accepting connection on TCP socket: {}", + e + )) + })?; + + Ok(SocketStream::Tcp(socket)) + } else { + let path = Vmm::socket_url_to_path(receiver_url)?; + let listener = UnixListener::bind(&path).map_err(|e| { + MigratableError::MigrateReceive(anyhow!("Error binding to UNIX socket: {}", e)) + })?; + + let (socket, _addr) = listener.accept().map_err(|e| { + MigratableError::MigrateReceive(anyhow!( + "Error accepting connection on UNIX socket: {}", + e + )) + })?; + + // Remove the UNIX socket file after accepting the connection + std::fs::remove_file(&path).map_err(|e| { + MigratableError::MigrateReceive(anyhow!("Error removing UNIX socket file: {}", e)) + })?; + + Ok(SocketStream::Unix(socket)) + } + } + // Returns true if there were dirty pages to send - fn vm_maybe_send_dirty_pages( + fn vm_maybe_send_dirty_pages( vm: &mut Vm, - socket: &mut T, - ) -> result::Result - where - T: Read + Write + WriteVolatile, - { + socket: &mut SocketStream, + ) -> result::Result { // Send (dirty) memory table let table = vm.dirty_log()?; @@ -956,10 +1098,8 @@ impl Vmm { >, send_data_migration: VmSendMigrationData, ) -> result::Result<(), MigratableError> { - let path = Self::socket_url_to_path(&send_data_migration.destination_url)?; - let mut socket = UnixStream::connect(path).map_err(|e| { - MigratableError::MigrateSend(anyhow!("Error connecting to UNIX socket: {}", e)) - })?; + // Set up the socket connection + let mut socket = Self::send_migration_socket(&send_data_migration.destination_url)?; // Start the migration Request::start().write_to(&mut socket)?; @@ -999,7 +1139,17 @@ impl Vmm { }; if send_data_migration.local { - vm.send_memory_fds(&mut socket)?; + match &mut socket { + SocketStream::Unix(unix_socket) => { + // Proceed with sending memory file descriptors over UNIX socket + vm.send_memory_fds(unix_socket)?; + } + SocketStream::Tcp(_tcp_socket) => { + return Err(MigratableError::MigrateSend(anyhow!( + "--local option is not supported with TCP sockets", + ))); + } + } } let vm_migration_config = VmMigrationConfig { @@ -1960,16 +2110,8 @@ impl RequestHandler for Vmm { receive_data_migration.receiver_url ); - let path = Self::socket_url_to_path(&receive_data_migration.receiver_url)?; - let listener = UnixListener::bind(&path).map_err(|e| { - MigratableError::MigrateReceive(anyhow!("Error binding to UNIX socket: {}", e)) - })?; - let (mut socket, _addr) = listener.accept().map_err(|e| { - MigratableError::MigrateReceive(anyhow!("Error accepting on UNIX socket: {}", e)) - })?; - std::fs::remove_file(&path).map_err(|e| { - MigratableError::MigrateReceive(anyhow!("Error unlinking UNIX socket: {}", e)) - })?; + // Accept the connection and get the socket + let mut socket = Vmm::receive_migration_socket(&receive_data_migration.receiver_url)?; let mut started = false; let mut memory_manager: Option>> = None; @@ -2037,24 +2179,35 @@ impl RequestHandler for Vmm { continue; } - let mut buf = [0u8; 4]; - let (_, file) = socket.recv_with_fd(&mut buf).map_err(|e| { - MigratableError::MigrateReceive(anyhow!( - "Error receiving slot from socket: {}", - e - )) - })?; + match &mut socket { + SocketStream::Unix(unix_socket) => { + let mut buf = [0u8; 4]; + let (_, file) = unix_socket.recv_with_fd(&mut buf).map_err(|e| { + MigratableError::MigrateReceive(anyhow!( + "Error receiving slot from socket: {}", + e + )) + })?; - if existing_memory_files.is_none() { - existing_memory_files = Some(HashMap::default()) + if existing_memory_files.is_none() { + existing_memory_files = Some(HashMap::default()) + } + + if let Some(ref mut existing_memory_files) = existing_memory_files { + let slot = u32::from_le_bytes(buf); + existing_memory_files.insert(slot, file.unwrap()); + } + + Response::ok().write_to(&mut socket)?; + } + SocketStream::Tcp(_tcp_socket) => { + // For TCP sockets, we cannot transfer file descriptors + warn!( + "MemoryFd command received over TCP socket, which is not supported" + ); + Response::error().write_to(&mut socket)?; + } } - - if let Some(ref mut existing_memory_files) = existing_memory_files { - let slot = u32::from_le_bytes(buf); - existing_memory_files.insert(slot, file.unwrap()); - } - - Response::ok().write_to(&mut socket)?; } Command::Complete => { info!("Complete Command Received");