diff --git a/vm-virtio/src/vsock/csm/connection.rs b/vm-virtio/src/vsock/csm/connection.rs index a226fd1a9..a42794c6a 100644 --- a/vm-virtio/src/vsock/csm/connection.rs +++ b/vm-virtio/src/vsock/csm/connection.rs @@ -631,3 +631,556 @@ where .set_fwd_cnt(self.fwd_cnt.0) } } + +#[cfg(test)] +mod tests { + use libc::EFD_NONBLOCK; + + use std::io::{Error as IoError, ErrorKind, Read, Result as IoResult, Write}; + use std::os::unix::io::RawFd; + use std::time::{Duration, Instant}; + use vmm_sys_util::eventfd::EventFd; + + use super::super::super::defs::uapi; + use super::super::super::tests::TestContext; + use super::super::defs as csm_defs; + use super::*; + + const LOCAL_CID: u64 = 2; + const PEER_CID: u64 = 3; + const LOCAL_PORT: u32 = 1002; + const PEER_PORT: u32 = 1003; + const PEER_BUF_ALLOC: u32 = 64 * 1024; + + enum StreamState { + Closed, + Error(ErrorKind), + Ready, + WouldBlock, + } + + struct TestStream { + fd: EventFd, + read_buf: Vec, + read_state: StreamState, + write_buf: Vec, + write_state: StreamState, + } + impl TestStream { + fn new() -> Self { + Self { + fd: EventFd::new(EFD_NONBLOCK).unwrap(), + read_state: StreamState::Ready, + write_state: StreamState::Ready, + read_buf: Vec::new(), + write_buf: Vec::new(), + } + } + fn new_with_read_buf(buf: &[u8]) -> Self { + let mut stream = Self::new(); + stream.read_buf = buf.to_vec(); + stream + } + } + + impl AsRawFd for TestStream { + fn as_raw_fd(&self) -> RawFd { + self.fd.as_raw_fd() + } + } + + impl Read for TestStream { + fn read(&mut self, data: &mut [u8]) -> IoResult { + match self.read_state { + StreamState::Closed => Ok(0), + StreamState::Error(kind) => Err(IoError::new(kind, "whatevs")), + StreamState::Ready => { + if self.read_buf.is_empty() { + return Err(IoError::new(ErrorKind::WouldBlock, "EAGAIN")); + } + let len = std::cmp::min(data.len(), self.read_buf.len()); + assert_ne!(len, 0); + data[..len].copy_from_slice(&self.read_buf[..len]); + self.read_buf = self.read_buf.split_off(len); + Ok(len) + } + StreamState::WouldBlock => Err(IoError::new(ErrorKind::WouldBlock, "EAGAIN")), + } + } + } + + impl Write for TestStream { + fn write(&mut self, data: &[u8]) -> IoResult { + match self.write_state { + StreamState::Closed => Err(IoError::new(ErrorKind::BrokenPipe, "EPIPE")), + StreamState::Error(kind) => Err(IoError::new(kind, "whatevs")), + StreamState::Ready => { + self.write_buf.extend_from_slice(data); + Ok(data.len()) + } + StreamState::WouldBlock => Err(IoError::new(ErrorKind::WouldBlock, "EAGAIN")), + } + } + fn flush(&mut self) -> IoResult<()> { + Ok(()) + } + } + + fn init_pkt(pkt: &mut VsockPacket, op: u16, len: u32) -> &mut VsockPacket { + for b in pkt.hdr_mut() { + *b = 0; + } + pkt.set_src_cid(PEER_CID) + .set_dst_cid(LOCAL_CID) + .set_src_port(PEER_PORT) + .set_dst_port(LOCAL_PORT) + .set_type(uapi::VSOCK_TYPE_STREAM) + .set_buf_alloc(PEER_BUF_ALLOC) + .set_op(op) + .set_len(len) + } + + // This is the connection state machine test context: a helper struct to provide CSM testing + // primitives. A single `VsockPacket` object will be enough for our testing needs. We'll be + // using it for simulating both packet sends and packet receives. We need to keep the vsock + // testing context alive, since `VsockPacket` is just a pointer-wrapper over some data that + // resides in guest memory. The vsock test context owns the `GuestMemory` object, so we'll make + // it a member here, in order to make sure that guest memory outlives our testing packet. A + // single `VsockConnection` object will also suffice for our testing needs. We'll be using a + // specially crafted `Read + Write + AsRawFd` object as a backing stream, so that we can + // control the various error conditions that might arise. + struct CsmTestContext { + _vsock_test_ctx: TestContext, + pkt: VsockPacket, + conn: VsockConnection, + } + + impl CsmTestContext { + fn new_established() -> Self { + Self::new(ConnState::Established) + } + + fn new(conn_state: ConnState) -> Self { + let vsock_test_ctx = TestContext::new(); + let mut handler_ctx = vsock_test_ctx.create_epoll_handler_context(); + let stream = TestStream::new(); + let mut pkt = VsockPacket::from_rx_virtq_head( + &handler_ctx.handler.queues[0] + .iter(&vsock_test_ctx.mem) + .next() + .unwrap(), + ) + .unwrap(); + let conn = match conn_state { + ConnState::PeerInit => VsockConnection::::new_peer_init( + stream, + LOCAL_CID, + PEER_CID, + LOCAL_PORT, + PEER_PORT, + PEER_BUF_ALLOC, + ), + ConnState::LocalInit => VsockConnection::::new_local_init( + stream, LOCAL_CID, PEER_CID, LOCAL_PORT, PEER_PORT, + ), + ConnState::Established => { + let mut conn = VsockConnection::::new_peer_init( + stream, + LOCAL_CID, + PEER_CID, + LOCAL_PORT, + PEER_PORT, + PEER_BUF_ALLOC, + ); + assert!(conn.has_pending_rx()); + conn.recv_pkt(&mut pkt).unwrap(); + assert_eq!(pkt.op(), uapi::VSOCK_OP_RESPONSE); + conn + } + other => panic!("invalid ctx state: {:?}", other), + }; + assert_eq!(conn.state, conn_state); + Self { + _vsock_test_ctx: vsock_test_ctx, + pkt, + conn, + } + } + + fn set_stream(&mut self, stream: TestStream) { + self.conn.stream = stream; + } + + fn set_peer_credit(&mut self, credit: u32) { + assert!(credit < self.conn.peer_buf_alloc); + self.conn.peer_fwd_cnt = Wrapping(0); + self.conn.rx_cnt = Wrapping(self.conn.peer_buf_alloc - credit); + assert_eq!(self.conn.peer_avail_credit(), credit as usize); + } + + fn send(&mut self) { + self.conn.send_pkt(&self.pkt).unwrap(); + } + + fn recv(&mut self) { + self.conn.recv_pkt(&mut self.pkt).unwrap(); + } + + fn notify_epollin(&mut self) { + self.conn.notify(epoll::Events::EPOLLIN); + assert!(self.conn.has_pending_rx()); + } + + fn notify_epollout(&mut self) { + self.conn.notify(epoll::Events::EPOLLOUT); + } + + fn init_pkt(&mut self, op: u16, len: u32) -> &mut VsockPacket { + init_pkt(&mut self.pkt, op, len) + } + + fn init_data_pkt(&mut self, data: &[u8]) -> &VsockPacket { + assert!(data.len() <= self.pkt.buf().unwrap().len()); + self.init_pkt(uapi::VSOCK_OP_RW, data.len() as u32); + self.pkt.buf_mut().unwrap()[..data.len()].copy_from_slice(data); + &self.pkt + } + } + + #[test] + fn test_peer_request() { + let mut ctx = CsmTestContext::new(ConnState::PeerInit); + assert!(ctx.conn.has_pending_rx()); + ctx.recv(); + // For peer-initiated requests, our connection should always yield a vsock reponse packet, + // in order to establish the connection. + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE); + assert_eq!(ctx.pkt.src_cid(), LOCAL_CID); + assert_eq!(ctx.pkt.dst_cid(), PEER_CID); + assert_eq!(ctx.pkt.src_port(), LOCAL_PORT); + assert_eq!(ctx.pkt.dst_port(), PEER_PORT); + assert_eq!(ctx.pkt.type_(), uapi::VSOCK_TYPE_STREAM); + assert_eq!(ctx.pkt.len(), 0); + // After yielding the response packet, the connection should have transitioned to the + // established state. + assert_eq!(ctx.conn.state, ConnState::Established); + } + + #[test] + fn test_local_request() { + let mut ctx = CsmTestContext::new(ConnState::LocalInit); + // Host-initiated connections should first yield a connection request packet. + assert!(ctx.conn.has_pending_rx()); + // Before yielding the connection request packet, the timeout kill timer shouldn't be + // armed. + assert!(!ctx.conn.will_expire()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_REQUEST); + // Since the request might time-out, the kill timer should now be armed. + assert!(ctx.conn.will_expire()); + assert!(!ctx.conn.has_expired()); + ctx.init_pkt(uapi::VSOCK_OP_RESPONSE, 0); + ctx.send(); + // Upon receiving a connection response, the connection should have transitioned to the + // established state, and the kill timer should've been disarmed. + assert_eq!(ctx.conn.state, ConnState::Established); + assert!(!ctx.conn.will_expire()); + } + + #[test] + fn test_local_request_timeout() { + let mut ctx = CsmTestContext::new(ConnState::LocalInit); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_REQUEST); + assert!(ctx.conn.will_expire()); + assert!(!ctx.conn.has_expired()); + std::thread::sleep(std::time::Duration::from_millis( + defs::CONN_REQUEST_TIMEOUT_MS, + )); + assert!(ctx.conn.has_expired()); + } + + #[test] + fn test_rx_data() { + let mut ctx = CsmTestContext::new_established(); + let data = &[1, 2, 3, 4]; + ctx.set_stream(TestStream::new_with_read_buf(data)); + assert_eq!(ctx.conn.get_polled_fd(), ctx.conn.stream.as_raw_fd()); + ctx.notify_epollin(); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RW); + assert_eq!(ctx.pkt.len() as usize, data.len()); + assert_eq!(ctx.pkt.buf().unwrap()[..ctx.pkt.len() as usize], *data); + + // There's no more data in the stream, so `recv_pkt` should yield `VsockError::NoData`. + match ctx.conn.recv_pkt(&mut ctx.pkt) { + Err(VsockError::NoData) => (), + other => panic!("{:?}", other), + } + + // A recv attempt in an invalid state should yield an instant reset packet. + ctx.conn.state = ConnState::LocalClosed; + ctx.notify_epollin(); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST); + } + + #[test] + fn test_local_close() { + let mut ctx = CsmTestContext::new_established(); + let mut stream = TestStream::new(); + stream.read_state = StreamState::Closed; + ctx.set_stream(stream); + ctx.notify_epollin(); + ctx.recv(); + // When the host-side stream is closed, we can neither send not receive any more data. + // Therefore, the vsock shutdown packet that we'll deliver to the guest must contain both + // the no-more-send and the no-more-recv indications. + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_SHUTDOWN); + assert_ne!(ctx.pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_SEND, 0); + assert_ne!(ctx.pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_RCV, 0); + + // The kill timer should now be armed. + assert!(ctx.conn.will_expire()); + assert!( + ctx.conn.expiry().unwrap() + < Instant::now() + Duration::from_millis(defs::CONN_SHUTDOWN_TIMEOUT_MS) + ); + } + + #[test] + fn test_peer_close() { + // Test that send/recv shutdown indications are handled correctly. + // I.e. once set, an indication cannot be reset. + { + let mut ctx = CsmTestContext::new_established(); + + ctx.init_pkt(uapi::VSOCK_OP_SHUTDOWN, 0) + .set_flags(uapi::VSOCK_FLAGS_SHUTDOWN_RCV); + ctx.send(); + assert_eq!(ctx.conn.state, ConnState::PeerClosed(true, false)); + + // Attempting to reset the no-more-recv indication should not work + // (we are only setting the no-more-send indication here). + ctx.pkt.set_flags(uapi::VSOCK_FLAGS_SHUTDOWN_SEND); + ctx.send(); + assert_eq!(ctx.conn.state, ConnState::PeerClosed(true, true)); + } + + // Test case: + // - reading data from a no-more-send connection should work; and + // - writing data should have no effect. + { + let data = &[1, 2, 3, 4]; + let mut ctx = CsmTestContext::new_established(); + ctx.set_stream(TestStream::new_with_read_buf(data)); + ctx.init_pkt(uapi::VSOCK_OP_SHUTDOWN, 0) + .set_flags(uapi::VSOCK_FLAGS_SHUTDOWN_SEND); + ctx.send(); + ctx.notify_epollin(); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RW); + assert_eq!(&ctx.pkt.buf().unwrap()[..ctx.pkt.len() as usize], data); + + ctx.init_data_pkt(data); + ctx.send(); + assert_eq!(ctx.conn.stream.write_buf.len(), 0); + assert!(ctx.conn.tx_buf.is_empty()); + } + + // Test case: + // - writing data to a no-more-recv connection should work; and + // - attempting to read data from it should yield an RST packet. + { + let mut ctx = CsmTestContext::new_established(); + ctx.init_pkt(uapi::VSOCK_OP_SHUTDOWN, 0) + .set_flags(uapi::VSOCK_FLAGS_SHUTDOWN_RCV); + ctx.send(); + let data = &[1, 2, 3, 4]; + ctx.init_data_pkt(data); + ctx.send(); + assert_eq!(ctx.conn.stream.write_buf, data.to_vec()); + + ctx.notify_epollin(); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST); + } + + // Test case: setting both no-more-send and no-more-recv indications should have the + // connection confirm termination (i.e. yield an RST). + { + let mut ctx = CsmTestContext::new_established(); + ctx.init_pkt(uapi::VSOCK_OP_SHUTDOWN, 0) + .set_flags(uapi::VSOCK_FLAGS_SHUTDOWN_RCV | uapi::VSOCK_FLAGS_SHUTDOWN_SEND); + ctx.send(); + assert!(ctx.conn.has_pending_rx()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST); + } + } + + #[test] + fn test_local_read_error() { + let mut ctx = CsmTestContext::new_established(); + let mut stream = TestStream::new(); + stream.read_state = StreamState::Error(ErrorKind::PermissionDenied); + ctx.set_stream(stream); + ctx.notify_epollin(); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST); + } + + #[test] + fn test_credit_request_to_peer() { + let mut ctx = CsmTestContext::new_established(); + ctx.set_peer_credit(0); + ctx.notify_epollin(); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_CREDIT_REQUEST); + } + + #[test] + fn test_credit_request_from_peer() { + let mut ctx = CsmTestContext::new_established(); + ctx.init_pkt(uapi::VSOCK_OP_CREDIT_REQUEST, 0); + ctx.send(); + assert!(ctx.conn.has_pending_rx()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_CREDIT_UPDATE); + assert_eq!(ctx.pkt.buf_alloc(), csm_defs::CONN_TX_BUF_SIZE as u32); + assert_eq!(ctx.pkt.fwd_cnt(), ctx.conn.fwd_cnt.0); + } + + #[test] + fn test_credit_update_to_peer() { + let mut ctx = CsmTestContext::new_established(); + + // Force a stale state, where the peer hasn't been updated on our credit situation. + ctx.conn.last_fwd_cnt_to_peer = Wrapping(0); + ctx.conn.fwd_cnt = Wrapping(csm_defs::CONN_CREDIT_UPDATE_THRESHOLD as u32); + + // Fake a data send from the peer, to bring us over the credit update threshold. + let data = &[1, 2, 3, 4]; + ctx.init_data_pkt(data); + ctx.send(); + + // The CSM should now have a credit update available for the peer. + assert!(ctx.conn.has_pending_rx()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_CREDIT_UPDATE); + assert_eq!( + ctx.pkt.fwd_cnt() as usize, + csm_defs::CONN_CREDIT_UPDATE_THRESHOLD + data.len() + ); + assert_eq!(ctx.conn.fwd_cnt, ctx.conn.last_fwd_cnt_to_peer); + } + + #[test] + fn test_tx_buffering() { + // Test case: + // - when writing to the backing stream would block, TX data should end up in the TX buf + // - when the CSM is notified that it can write to the backing stream, it should flush + // the TX buf. + { + let mut ctx = CsmTestContext::new_established(); + + let mut stream = TestStream::new(); + stream.write_state = StreamState::WouldBlock; + ctx.set_stream(stream); + + // Send some data through the connection. The backing stream is set to reject writes, + // so the data should end up in the TX buffer. + let data = &[1, 2, 3, 4]; + ctx.init_data_pkt(data); + ctx.send(); + + // When there's data in the TX buffer, the connection should ask to be notified when it + // can write to its backing stream. + assert!(ctx + .conn + .get_polled_evset() + .contains(epoll::Events::EPOLLOUT)); + assert_eq!(ctx.conn.tx_buf.len(), data.len()); + + // Unlock the write stream and notify the connection it can now write its bufferred + // data. + ctx.set_stream(TestStream::new()); + ctx.conn.notify(epoll::Events::EPOLLOUT); + assert!(ctx.conn.tx_buf.is_empty()); + assert_eq!(ctx.conn.stream.write_buf, data); + } + } + + #[test] + fn test_stream_write_error() { + // Test case: sending a data packet to a broken / closed backing stream should kill it. + { + let mut ctx = CsmTestContext::new_established(); + let mut stream = TestStream::new(); + stream.write_state = StreamState::Closed; + ctx.set_stream(stream); + + let data = &[1, 2, 3, 4]; + ctx.init_data_pkt(data); + ctx.send(); + + assert_eq!(ctx.conn.state, ConnState::Killed); + assert!(ctx.conn.has_pending_rx()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST); + } + + // Test case: notifying a connection that it can flush its TX buffer to a broken stream + // should kill the connection. + { + let mut ctx = CsmTestContext::new_established(); + + let mut stream = TestStream::new(); + stream.write_state = StreamState::WouldBlock; + ctx.set_stream(stream); + + // Send some data through the connection. The backing stream is set to reject writes, + // so the data should end up in the TX buffer. + let data = &[1, 2, 3, 4]; + ctx.init_data_pkt(data); + ctx.send(); + + // Set the backing stream to error out on write. + let mut stream = TestStream::new(); + stream.write_state = StreamState::Closed; + ctx.set_stream(stream); + + assert!(ctx + .conn + .get_polled_evset() + .contains(epoll::Events::EPOLLOUT)); + ctx.notify_epollout(); + assert_eq!(ctx.conn.state, ConnState::Killed); + } + } + + #[test] + fn test_peer_credit_misbehavior() { + let mut ctx = CsmTestContext::new_established(); + + let mut stream = TestStream::new(); + stream.write_state = StreamState::WouldBlock; + ctx.set_stream(stream); + + // Fill up the TX buffer. + let data = vec![0u8; ctx.pkt.buf().unwrap().len()]; + ctx.init_data_pkt(data.as_slice()); + for _i in 0..(csm_defs::CONN_TX_BUF_SIZE / data.len()) { + ctx.send(); + } + + // Then try to send more data. + ctx.send(); + + // The connection should've committed suicide. + assert_eq!(ctx.conn.state, ConnState::Killed); + assert!(ctx.conn.has_pending_rx()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST); + } +} diff --git a/vm-virtio/src/vsock/device.rs b/vm-virtio/src/vsock/device.rs index 7d76e24ad..a17e02d8d 100644 --- a/vm-virtio/src/vsock/device.rs +++ b/vm-virtio/src/vsock/device.rs @@ -502,3 +502,331 @@ where Ok(()) } } + +#[cfg(test)] +mod tests { + use super::super::tests::TestContext; + use super::super::*; + use super::*; + use crate::vsock::device::{BACKEND_EVENT, EVT_QUEUE_EVENT, RX_QUEUE_EVENT, TX_QUEUE_EVENT}; + + #[test] + fn test_virtio_device() { + let mut ctx = TestContext::new(); + let avail_features = 1u64 << VIRTIO_F_VERSION_1 | 1u64 << VIRTIO_F_IN_ORDER; + let device_features = avail_features; + let driver_features: u64 = avail_features | 1 | (1 << 32); + let device_pages = [ + (device_features & 0xffff_ffff) as u32, + (device_features >> 32) as u32, + ]; + let driver_pages = [ + (driver_features & 0xffff_ffff) as u32, + (driver_features >> 32) as u32, + ]; + assert_eq!( + ctx.device.device_type(), + VirtioDeviceType::TYPE_VSOCK as u32 + ); + assert_eq!(ctx.device.queue_max_sizes(), QUEUE_SIZES); + assert_eq!(ctx.device.features(0), device_pages[0]); + assert_eq!(ctx.device.features(1), device_pages[1]); + assert_eq!(ctx.device.features(2), 0); + + // Ack device features, page 0. + ctx.device.ack_features(0, driver_pages[0]); + // Ack device features, page 1. + ctx.device.ack_features(1, driver_pages[1]); + // Ack some bogus page (i.e. 2). This should have no side effect. + ctx.device.ack_features(2, 0); + // Attempt to un-ack the first feature page. This should have no side effect. + ctx.device.ack_features(0, !driver_pages[0]); + // Check that no side effect are present, and that the acked features are exactly the same + // as the device features. + assert_eq!(ctx.device.acked_features, device_features & driver_features); + + // Test reading 32-bit chunks. + let mut data = [0u8; 8]; + ctx.device.read_config(0, &mut data[..4]); + assert_eq!( + u64::from(LittleEndian::read_u32(&data)), + ctx.cid & 0xffff_ffff + ); + ctx.device.read_config(4, &mut data[4..]); + assert_eq!( + u64::from(LittleEndian::read_u32(&data[4..])), + (ctx.cid >> 32) & 0xffff_ffff + ); + + // Test reading 64-bit. + let mut data = [0u8; 8]; + ctx.device.read_config(0, &mut data); + assert_eq!(LittleEndian::read_u64(&data), ctx.cid); + + // Check that out-of-bounds reading doesn't mutate the destination buffer. + let mut data = [0u8, 1, 2, 3, 4, 5, 6, 7]; + ctx.device.read_config(2, &mut data); + assert_eq!(data, [0u8, 1, 2, 3, 4, 5, 6, 7]); + + // Just covering lines here, since the vsock device has no writable config. + // A warning is, however, logged, if the guest driver attempts to write any config data. + ctx.device.write_config(0, &data[..4]); + + // Test a bad activation. + let bad_activate = ctx.device.activate( + Arc::new(RwLock::new(ctx.mem.clone())), + Arc::new( + Box::new(move |_: &VirtioInterruptType, _: Option<&Queue>| Ok(())) + as VirtioInterrupt, + ), + Vec::new(), + Vec::new(), + ); + match bad_activate { + Err(ActivateError::BadActivate) => (), + other => panic!("{:?}", other), + } + + // Test a correct activation. + ctx.device + .activate( + Arc::new(RwLock::new(ctx.mem.clone())), + Arc::new( + Box::new(move |_: &VirtioInterruptType, _: Option<&Queue>| Ok(())) + as VirtioInterrupt, + ), + vec![Queue::new(256), Queue::new(256), Queue::new(256)], + vec![ + EventFd::new(EFD_NONBLOCK).unwrap(), + EventFd::new(EFD_NONBLOCK).unwrap(), + EventFd::new(EFD_NONBLOCK).unwrap(), + ], + ) + .unwrap(); + } + + #[test] + fn test_irq() { + // Test case: successful IRQ signaling. + { + let test_ctx = TestContext::new(); + let ctx = test_ctx.create_epoll_handler_context(); + + let queue = Queue::new(256); + assert!(ctx.handler.signal_used_queue(&queue).is_ok()); + } + } + + #[test] + fn test_txq_event() { + // Test case: + // - the driver has something to send (there's data in the TX queue); and + // - the backend has no pending RX data. + { + let test_ctx = TestContext::new(); + let mut ctx = test_ctx.create_epoll_handler_context(); + + ctx.handler.backend.set_pending_rx(false); + ctx.signal_txq_event(); + + // The available TX descriptor should have been used. + assert_eq!(ctx.guest_txvq.used.idx.get(), 1); + // The available RX descriptor should be untouched. + assert_eq!(ctx.guest_rxvq.used.idx.get(), 0); + } + + // Test case: + // - the driver has something to send (there's data in the TX queue); and + // - the backend also has some pending RX data. + { + let test_ctx = TestContext::new(); + let mut ctx = test_ctx.create_epoll_handler_context(); + + ctx.handler.backend.set_pending_rx(true); + ctx.signal_txq_event(); + + // Both available RX and TX descriptors should have been used. + assert_eq!(ctx.guest_txvq.used.idx.get(), 1); + assert_eq!(ctx.guest_rxvq.used.idx.get(), 1); + } + + // Test case: + // - the driver has something to send (there's data in the TX queue); and + // - the backend errors out and cannot process the TX queue. + { + let test_ctx = TestContext::new(); + let mut ctx = test_ctx.create_epoll_handler_context(); + + ctx.handler.backend.set_pending_rx(false); + ctx.handler.backend.set_tx_err(Some(VsockError::NoData)); + ctx.signal_txq_event(); + + // Both RX and TX queues should be untouched. + assert_eq!(ctx.guest_txvq.used.idx.get(), 0); + assert_eq!(ctx.guest_rxvq.used.idx.get(), 0); + } + + // Test case: + // - the driver supplied a malformed TX buffer. + { + let test_ctx = TestContext::new(); + let mut ctx = test_ctx.create_epoll_handler_context(); + + // Invalidate the packet header descriptor, by setting its length to 0. + ctx.guest_txvq.dtable[0].len.set(0); + ctx.signal_txq_event(); + + // The available descriptor should have been consumed, but no packet should have + // reached the backend. + assert_eq!(ctx.guest_txvq.used.idx.get(), 1); + assert_eq!(ctx.handler.backend.tx_ok_cnt, 0); + } + + // Test case: spurious TXQ_EVENT. + { + let test_ctx = TestContext::new(); + let mut ctx = test_ctx.create_epoll_handler_context(); + + match ctx + .handler + .handle_event(TX_QUEUE_EVENT, epoll::Events::EPOLLIN) + { + Err(DeviceError::FailedReadingQueue { .. }) => (), + other => panic!("{:?}", other), + } + } + } + + #[test] + fn test_rxq_event() { + // Test case: + // - there is pending RX data in the backend; and + // - the driver makes RX buffers available; and + // - the backend successfully places its RX data into the queue. + { + let test_ctx = TestContext::new(); + let mut ctx = test_ctx.create_epoll_handler_context(); + + ctx.handler.backend.set_pending_rx(true); + ctx.handler.backend.set_rx_err(Some(VsockError::NoData)); + ctx.signal_rxq_event(); + + // The available RX buffer should've been left untouched. + assert_eq!(ctx.guest_rxvq.used.idx.get(), 0); + } + + // Test case: + // - there is pending RX data in the backend; and + // - the driver makes RX buffers available; and + // - the backend errors out, when attempting to receive data. + { + let test_ctx = TestContext::new(); + let mut ctx = test_ctx.create_epoll_handler_context(); + + ctx.handler.backend.set_pending_rx(true); + ctx.signal_rxq_event(); + + // The available RX buffer should have been used. + assert_eq!(ctx.guest_rxvq.used.idx.get(), 1); + } + + // Test case: the driver provided a malformed RX descriptor chain. + { + let test_ctx = TestContext::new(); + let mut ctx = test_ctx.create_epoll_handler_context(); + + // Invalidate the packet header descriptor, by setting its length to 0. + ctx.guest_rxvq.dtable[0].len.set(0); + + // The chain should've been processed, without employing the backend. + assert!(ctx.handler.process_rx().is_ok()); + assert_eq!(ctx.guest_rxvq.used.idx.get(), 1); + assert_eq!(ctx.handler.backend.rx_ok_cnt, 0); + } + + // Test case: spurious RXQ_EVENT. + { + let test_ctx = TestContext::new(); + let mut ctx = test_ctx.create_epoll_handler_context(); + ctx.handler.backend.set_pending_rx(false); + match ctx + .handler + .handle_event(RX_QUEUE_EVENT, epoll::Events::EPOLLIN) + { + Err(DeviceError::FailedReadingQueue { .. }) => (), + other => panic!("{:?}", other), + } + } + } + + #[test] + fn test_evq_event() { + // Test case: spurious EVQ_EVENT. + { + let test_ctx = TestContext::new(); + let mut ctx = test_ctx.create_epoll_handler_context(); + ctx.handler.backend.set_pending_rx(false); + match ctx + .handler + .handle_event(EVT_QUEUE_EVENT, epoll::Events::EPOLLIN) + { + Err(DeviceError::FailedReadingQueue { .. }) => (), + other => panic!("{:?}", other), + } + } + } + + #[test] + fn test_backend_event() { + // Test case: + // - a backend event is received; and + // - the backend has pending RX data. + { + let test_ctx = TestContext::new(); + let mut ctx = test_ctx.create_epoll_handler_context(); + + ctx.handler.backend.set_pending_rx(true); + ctx.handler + .handle_event(BACKEND_EVENT, epoll::Events::EPOLLIN) + .unwrap(); + + // The backend should've received this event. + assert_eq!(ctx.handler.backend.evset, Some(epoll::Events::EPOLLIN)); + // TX queue processing should've been triggered. + assert_eq!(ctx.guest_txvq.used.idx.get(), 1); + // RX queue processing should've been triggered. + assert_eq!(ctx.guest_rxvq.used.idx.get(), 1); + } + + // Test case: + // - a backend event is received; and + // - the backend doesn't have any pending RX data. + { + let test_ctx = TestContext::new(); + let mut ctx = test_ctx.create_epoll_handler_context(); + + ctx.handler.backend.set_pending_rx(false); + ctx.handler + .handle_event(BACKEND_EVENT, epoll::Events::EPOLLIN) + .unwrap(); + + // The backend should've received this event. + assert_eq!(ctx.handler.backend.evset, Some(epoll::Events::EPOLLIN)); + // TX queue processing should've been triggered. + assert_eq!(ctx.guest_txvq.used.idx.get(), 1); + // The RX queue should've been left untouched. + assert_eq!(ctx.guest_rxvq.used.idx.get(), 0); + } + } + + #[test] + fn test_unknown_event() { + let test_ctx = TestContext::new(); + let mut ctx = test_ctx.create_epoll_handler_context(); + + match ctx.handler.handle_event(0xff, epoll::Events::EPOLLIN) { + Err(DeviceError::UnknownEvent { .. }) => (), + other => panic!("{:?}", other), + } + } +} diff --git a/vm-virtio/src/vsock/mod.rs b/vm-virtio/src/vsock/mod.rs index 24740695b..0bf8ab637 100644 --- a/vm-virtio/src/vsock/mod.rs +++ b/vm-virtio/src/vsock/mod.rs @@ -89,6 +89,30 @@ pub enum VsockError { } type Result = std::result::Result; +#[derive(Debug)] +pub enum VsockEpollHandlerError { + /// The vsock data/buffer virtio descriptor length is smaller than expected. + BufDescTooSmall, + /// The vsock data/buffer virtio descriptor is expected, but missing. + BufDescMissing, + /// Chained GuestMemory error. + GuestMemory, + /// Bounds check failed on guest memory pointer. + GuestMemoryBounds, + /// The vsock header descriptor length is too small. + HdrDescTooSmall(u32), + /// The vsock header `len` field holds an invalid value. + InvalidPktLen(u32), + /// A data fetch was attempted when no data was available. + NoData, + /// A data buffer was expected for the provided packet, but it is missing. + PktBufMissing, + /// Encountered an unexpected write-only virtio descriptor. + UnreadableDescriptor, + /// Encountered an unexpected read-only virtio descriptor. + UnwritableDescriptor, +} + /// A passive, event-driven object, that needs to be notified whenever an epoll-able event occurs. /// An event-polling control loop will use `get_polled_fd()` and `get_polled_evset()` to query /// the listener for the file descriptor and the set of events it's interested in. When such an @@ -131,3 +155,183 @@ pub trait VsockChannel { /// Currently, the only implementation we have is `crate::virtio::unix::muxer::VsockMuxer`, which /// translates guest-side vsock connections to host-side Unix domain socket connections. pub trait VsockBackend: VsockChannel + VsockEpollListener + Send {} + +#[cfg(test)] +mod tests { + use libc::EFD_NONBLOCK; + + use super::device::{VsockEpollHandler, RX_QUEUE_EVENT, TX_QUEUE_EVENT}; + use super::packet::VSOCK_PKT_HDR_SIZE; + use super::*; + + use std::os::unix::io::AsRawFd; + use std::sync::{Arc, RwLock}; + use vmm_sys_util::eventfd::EventFd; + + use crate::device::{VirtioInterrupt, VirtioInterruptType}; + use crate::queue::tests::VirtQueue as GuestQ; + use crate::queue::Queue; + use crate::{VIRTQ_DESC_F_NEXT, VIRTQ_DESC_F_WRITE}; + use vm_memory::{GuestAddress, GuestMemoryMmap}; + + pub struct TestBackend { + pub evfd: EventFd, + pub rx_err: Option, + pub tx_err: Option, + pub pending_rx: bool, + pub rx_ok_cnt: usize, + pub tx_ok_cnt: usize, + pub evset: Option, + } + impl TestBackend { + pub fn new() -> Self { + Self { + evfd: EventFd::new(EFD_NONBLOCK).unwrap(), + rx_err: None, + tx_err: None, + pending_rx: false, + rx_ok_cnt: 0, + tx_ok_cnt: 0, + evset: None, + } + } + pub fn set_rx_err(&mut self, err: Option) { + self.rx_err = err; + } + pub fn set_tx_err(&mut self, err: Option) { + self.tx_err = err; + } + pub fn set_pending_rx(&mut self, prx: bool) { + self.pending_rx = prx; + } + } + impl VsockChannel for TestBackend { + fn recv_pkt(&mut self, _pkt: &mut VsockPacket) -> Result<()> { + match self.rx_err.take() { + None => { + self.rx_ok_cnt += 1; + Ok(()) + } + Some(e) => Err(e), + } + } + fn send_pkt(&mut self, _pkt: &VsockPacket) -> Result<()> { + match self.tx_err.take() { + None => { + self.tx_ok_cnt += 1; + Ok(()) + } + Some(e) => Err(e), + } + } + fn has_pending_rx(&self) -> bool { + self.pending_rx + } + } + impl VsockEpollListener for TestBackend { + fn get_polled_fd(&self) -> RawFd { + self.evfd.as_raw_fd() + } + fn get_polled_evset(&self) -> epoll::Events { + epoll::Events::EPOLLIN + } + fn notify(&mut self, evset: epoll::Events) { + self.evset = Some(evset); + } + } + impl VsockBackend for TestBackend {} + + pub struct TestContext { + pub cid: u64, + pub mem: GuestMemoryMmap, + pub mem_size: usize, + pub device: Vsock, + } + + impl TestContext { + pub fn new() -> Self { + const CID: u64 = 52; + const MEM_SIZE: usize = 1024 * 1024 * 128; + Self { + cid: CID, + mem: GuestMemoryMmap::new(&[(GuestAddress(0), MEM_SIZE)]).unwrap(), + mem_size: MEM_SIZE, + device: Vsock::new(CID, TestBackend::new()).unwrap(), + } + } + + pub fn create_epoll_handler_context(&self) -> EpollHandlerContext { + const QSIZE: u16 = 2; + + let guest_rxvq = GuestQ::new(GuestAddress(0x0010_0000), &self.mem, QSIZE as u16); + let guest_txvq = GuestQ::new(GuestAddress(0x0020_0000), &self.mem, QSIZE as u16); + let guest_evvq = GuestQ::new(GuestAddress(0x0030_0000), &self.mem, QSIZE as u16); + let rxvq = guest_rxvq.create_queue(); + let txvq = guest_txvq.create_queue(); + let evvq = guest_evvq.create_queue(); + + // Set up one available descriptor in the RX queue. + guest_rxvq.dtable[0].set( + 0x0040_0000, + VSOCK_PKT_HDR_SIZE as u32, + VIRTQ_DESC_F_WRITE | VIRTQ_DESC_F_NEXT, + 1, + ); + guest_rxvq.dtable[1].set(0x0040_1000, 4096, VIRTQ_DESC_F_WRITE, 0); + guest_rxvq.avail.ring[0].set(0); + guest_rxvq.avail.idx.set(1); + + // Set up one available descriptor in the TX queue. + guest_txvq.dtable[0].set(0x0050_0000, VSOCK_PKT_HDR_SIZE as u32, VIRTQ_DESC_F_NEXT, 1); + guest_txvq.dtable[1].set(0x0050_1000, 4096, 0, 0); + guest_txvq.avail.ring[0].set(0); + guest_txvq.avail.idx.set(1); + + let queues = vec![rxvq, txvq, evvq]; + let queue_evts = vec![ + EventFd::new(EFD_NONBLOCK).unwrap(), + EventFd::new(EFD_NONBLOCK).unwrap(), + EventFd::new(EFD_NONBLOCK).unwrap(), + ]; + let interrupt_cb = Arc::new(Box::new( + move |_: &VirtioInterruptType, _: Option<&Queue>| Ok(()), + ) as VirtioInterrupt); + + EpollHandlerContext { + guest_rxvq, + guest_txvq, + guest_evvq, + handler: VsockEpollHandler { + mem: Arc::new(RwLock::new(self.mem.clone())), + queues, + queue_evts, + kill_evt: EventFd::new(EFD_NONBLOCK).unwrap(), + interrupt_cb, + backend: TestBackend::new(), + }, + } + } + } + + pub struct EpollHandlerContext<'a> { + pub handler: VsockEpollHandler, + pub guest_rxvq: GuestQ<'a>, + pub guest_txvq: GuestQ<'a>, + pub guest_evvq: GuestQ<'a>, + } + + impl<'a> EpollHandlerContext<'a> { + pub fn signal_txq_event(&mut self) { + self.handler.queue_evts[1].write(1).unwrap(); + self.handler + .handle_event(TX_QUEUE_EVENT, epoll::Events::EPOLLIN) + .unwrap(); + } + pub fn signal_rxq_event(&mut self) { + self.handler.queue_evts[0].write(1).unwrap(); + self.handler + .handle_event(RX_QUEUE_EVENT, epoll::Events::EPOLLIN) + .unwrap(); + } + } +} diff --git a/vm-virtio/src/vsock/packet.rs b/vm-virtio/src/vsock/packet.rs index cd28e3533..7d25ac269 100644 --- a/vm-virtio/src/vsock/packet.rs +++ b/vm-virtio/src/vsock/packet.rs @@ -339,3 +339,319 @@ impl VsockPacket { self } } + +#[cfg(test)] +mod tests { + + use vm_memory::{GuestAddress, GuestMemoryMmap}; + + use super::super::tests::TestContext; + use super::*; + use crate::queue::tests::VirtqDesc as GuestQDesc; + use crate::vsock::defs::MAX_PKT_BUF_SIZE; + use crate::VIRTQ_DESC_F_WRITE; + + macro_rules! create_context { + ($test_ctx:ident, $handler_ctx:ident) => { + let $test_ctx = TestContext::new(); + let mut $handler_ctx = $test_ctx.create_epoll_handler_context(); + // For TX packets, hdr.len should be set to a valid value. + set_pkt_len(1024, &$handler_ctx.guest_txvq.dtable[0], &$test_ctx.mem); + }; + } + + macro_rules! expect_asm_error { + (tx, $test_ctx:expr, $handler_ctx:expr, $err:pat) => { + expect_asm_error!($test_ctx, $handler_ctx, $err, from_tx_virtq_head, 1); + }; + (rx, $test_ctx:expr, $handler_ctx:expr, $err:pat) => { + expect_asm_error!($test_ctx, $handler_ctx, $err, from_rx_virtq_head, 0); + }; + ($test_ctx:expr, $handler_ctx:expr, $err:pat, $ctor:ident, $vq:expr) => { + match VsockPacket::$ctor( + &$handler_ctx.handler.queues[$vq] + .iter(&$test_ctx.mem) + .next() + .unwrap(), + ) { + Err($err) => (), + Ok(_) => panic!("Packet assembly should've failed!"), + Err(other) => panic!("Packet assembly failed with: {:?}", other), + } + }; + } + + fn set_pkt_len(len: u32, guest_desc: &GuestQDesc, mem: &GuestMemoryMmap) { + let hdr_gpa = guest_desc.addr.get(); + let hdr_ptr = mem.get_host_address(GuestAddress(hdr_gpa)).unwrap() as *mut u8; + let len_ptr = unsafe { hdr_ptr.add(HDROFF_LEN) }; + + LittleEndian::write_u32(unsafe { std::slice::from_raw_parts_mut(len_ptr, 4) }, len); + } + + #[test] + #[allow(clippy::cognitive_complexity)] + fn test_tx_packet_assembly() { + // Test case: successful TX packet assembly. + { + create_context!(test_ctx, handler_ctx); + + let pkt = VsockPacket::from_tx_virtq_head( + &handler_ctx.handler.queues[1] + .iter(&test_ctx.mem) + .next() + .unwrap(), + ) + .unwrap(); + assert_eq!(pkt.hdr().len(), VSOCK_PKT_HDR_SIZE); + assert_eq!( + pkt.buf().unwrap().len(), + handler_ctx.guest_txvq.dtable[1].len.get() as usize + ); + } + + // Test case: error on write-only hdr descriptor. + { + create_context!(test_ctx, handler_ctx); + handler_ctx.guest_txvq.dtable[0] + .flags + .set(VIRTQ_DESC_F_WRITE); + expect_asm_error!(tx, test_ctx, handler_ctx, VsockError::UnreadableDescriptor); + } + + // Test case: header descriptor has insufficient space to hold the packet header. + { + create_context!(test_ctx, handler_ctx); + handler_ctx.guest_txvq.dtable[0] + .len + .set(VSOCK_PKT_HDR_SIZE as u32 - 1); + expect_asm_error!(tx, test_ctx, handler_ctx, VsockError::HdrDescTooSmall(_)); + } + + // Test case: zero-length TX packet. + { + create_context!(test_ctx, handler_ctx); + set_pkt_len(0, &handler_ctx.guest_txvq.dtable[0], &test_ctx.mem); + let mut pkt = VsockPacket::from_tx_virtq_head( + &handler_ctx.handler.queues[1] + .iter(&test_ctx.mem) + .next() + .unwrap(), + ) + .unwrap(); + assert!(pkt.buf().is_none()); + assert!(pkt.buf_mut().is_none()); + } + + // Test case: TX packet has more data than we can handle. + { + create_context!(test_ctx, handler_ctx); + set_pkt_len( + MAX_PKT_BUF_SIZE as u32 + 1, + &handler_ctx.guest_txvq.dtable[0], + &test_ctx.mem, + ); + expect_asm_error!(tx, test_ctx, handler_ctx, VsockError::InvalidPktLen(_)); + } + + // Test case: + // - packet header advertises some data length; and + // - the data descriptor is missing. + { + create_context!(test_ctx, handler_ctx); + set_pkt_len(1024, &handler_ctx.guest_txvq.dtable[0], &test_ctx.mem); + handler_ctx.guest_txvq.dtable[0].flags.set(0); + expect_asm_error!(tx, test_ctx, handler_ctx, VsockError::BufDescMissing); + } + + // Test case: error on write-only buf descriptor. + { + create_context!(test_ctx, handler_ctx); + handler_ctx.guest_txvq.dtable[1] + .flags + .set(VIRTQ_DESC_F_WRITE); + expect_asm_error!(tx, test_ctx, handler_ctx, VsockError::UnreadableDescriptor); + } + + // Test case: the buffer descriptor cannot fit all the data advertised by the the + // packet header `len` field. + { + create_context!(test_ctx, handler_ctx); + set_pkt_len(8 * 1024, &handler_ctx.guest_txvq.dtable[0], &test_ctx.mem); + handler_ctx.guest_txvq.dtable[1].len.set(4 * 1024); + expect_asm_error!(tx, test_ctx, handler_ctx, VsockError::BufDescTooSmall); + } + } + + #[test] + fn test_rx_packet_assembly() { + // Test case: successful RX packet assembly. + { + create_context!(test_ctx, handler_ctx); + let pkt = VsockPacket::from_rx_virtq_head( + &handler_ctx.handler.queues[0] + .iter(&test_ctx.mem) + .next() + .unwrap(), + ) + .unwrap(); + assert_eq!(pkt.hdr().len(), VSOCK_PKT_HDR_SIZE); + assert_eq!( + pkt.buf().unwrap().len(), + handler_ctx.guest_rxvq.dtable[1].len.get() as usize + ); + } + + // Test case: read-only RX packet header. + { + create_context!(test_ctx, handler_ctx); + handler_ctx.guest_rxvq.dtable[0].flags.set(0); + expect_asm_error!(rx, test_ctx, handler_ctx, VsockError::UnwritableDescriptor); + } + + // Test case: RX descriptor head cannot fit the entire packet header. + { + create_context!(test_ctx, handler_ctx); + handler_ctx.guest_rxvq.dtable[0] + .len + .set(VSOCK_PKT_HDR_SIZE as u32 - 1); + expect_asm_error!(rx, test_ctx, handler_ctx, VsockError::HdrDescTooSmall(_)); + } + + // Test case: RX descriptor chain is missing the packet buffer descriptor. + { + create_context!(test_ctx, handler_ctx); + handler_ctx.guest_rxvq.dtable[0] + .flags + .set(VIRTQ_DESC_F_WRITE); + expect_asm_error!(rx, test_ctx, handler_ctx, VsockError::BufDescMissing); + } + } + + #[test] + #[allow(clippy::cognitive_complexity)] + fn test_packet_hdr_accessors() { + const SRC_CID: u64 = 1; + const DST_CID: u64 = 2; + const SRC_PORT: u32 = 3; + const DST_PORT: u32 = 4; + const LEN: u32 = 5; + const TYPE: u16 = 6; + const OP: u16 = 7; + const FLAGS: u32 = 8; + const BUF_ALLOC: u32 = 9; + const FWD_CNT: u32 = 10; + + create_context!(test_ctx, handler_ctx); + let mut pkt = VsockPacket::from_rx_virtq_head( + &handler_ctx.handler.queues[0] + .iter(&test_ctx.mem) + .next() + .unwrap(), + ) + .unwrap(); + + // Test field accessors. + pkt.set_src_cid(SRC_CID) + .set_dst_cid(DST_CID) + .set_src_port(SRC_PORT) + .set_dst_port(DST_PORT) + .set_len(LEN) + .set_type(TYPE) + .set_op(OP) + .set_flags(FLAGS) + .set_buf_alloc(BUF_ALLOC) + .set_fwd_cnt(FWD_CNT); + + assert_eq!(pkt.src_cid(), SRC_CID); + assert_eq!(pkt.dst_cid(), DST_CID); + assert_eq!(pkt.src_port(), SRC_PORT); + assert_eq!(pkt.dst_port(), DST_PORT); + assert_eq!(pkt.len(), LEN); + assert_eq!(pkt.type_(), TYPE); + assert_eq!(pkt.op(), OP); + assert_eq!(pkt.flags(), FLAGS); + assert_eq!(pkt.buf_alloc(), BUF_ALLOC); + assert_eq!(pkt.fwd_cnt(), FWD_CNT); + + // Test individual flag setting. + let flags = pkt.flags() | 0b1000; + pkt.set_flag(0b1000); + assert_eq!(pkt.flags(), flags); + + // Test packet header as-slice access. + // + + assert_eq!(pkt.hdr().len(), VSOCK_PKT_HDR_SIZE); + + assert_eq!( + SRC_CID, + LittleEndian::read_u64(&pkt.hdr()[HDROFF_SRC_CID..]) + ); + assert_eq!( + DST_CID, + LittleEndian::read_u64(&pkt.hdr()[HDROFF_DST_CID..]) + ); + assert_eq!( + SRC_PORT, + LittleEndian::read_u32(&pkt.hdr()[HDROFF_SRC_PORT..]) + ); + assert_eq!( + DST_PORT, + LittleEndian::read_u32(&pkt.hdr()[HDROFF_DST_PORT..]) + ); + assert_eq!(LEN, LittleEndian::read_u32(&pkt.hdr()[HDROFF_LEN..])); + assert_eq!(TYPE, LittleEndian::read_u16(&pkt.hdr()[HDROFF_TYPE..])); + assert_eq!(OP, LittleEndian::read_u16(&pkt.hdr()[HDROFF_OP..])); + assert_eq!(FLAGS, LittleEndian::read_u32(&pkt.hdr()[HDROFF_FLAGS..])); + assert_eq!( + BUF_ALLOC, + LittleEndian::read_u32(&pkt.hdr()[HDROFF_BUF_ALLOC..]) + ); + assert_eq!( + FWD_CNT, + LittleEndian::read_u32(&pkt.hdr()[HDROFF_FWD_CNT..]) + ); + + assert_eq!(pkt.hdr_mut().len(), VSOCK_PKT_HDR_SIZE); + for b in pkt.hdr_mut() { + *b = 0; + } + assert_eq!(pkt.src_cid(), 0); + assert_eq!(pkt.dst_cid(), 0); + assert_eq!(pkt.src_port(), 0); + assert_eq!(pkt.dst_port(), 0); + assert_eq!(pkt.len(), 0); + assert_eq!(pkt.type_(), 0); + assert_eq!(pkt.op(), 0); + assert_eq!(pkt.flags(), 0); + assert_eq!(pkt.buf_alloc(), 0); + assert_eq!(pkt.fwd_cnt(), 0); + } + + #[test] + fn test_packet_buf() { + create_context!(test_ctx, handler_ctx); + let mut pkt = VsockPacket::from_rx_virtq_head( + &handler_ctx.handler.queues[0] + .iter(&test_ctx.mem) + .next() + .unwrap(), + ) + .unwrap(); + + assert_eq!( + pkt.buf().unwrap().len(), + handler_ctx.guest_rxvq.dtable[1].len.get() as usize + ); + assert_eq!( + pkt.buf_mut().unwrap().len(), + handler_ctx.guest_rxvq.dtable[1].len.get() as usize + ); + + for i in 0..pkt.buf().unwrap().len() { + pkt.buf_mut().unwrap()[i] = (i % 0x100) as u8; + assert_eq!(pkt.buf().unwrap()[i], (i % 0x100) as u8); + } + } +} diff --git a/vm-virtio/src/vsock/unix/muxer.rs b/vm-virtio/src/vsock/unix/muxer.rs index a094d2b72..2e6f8943b 100644 --- a/vm-virtio/src/vsock/unix/muxer.rs +++ b/vm-virtio/src/vsock/unix/muxer.rs @@ -767,3 +767,546 @@ impl VsockMuxer { } } } + +#[cfg(test)] +mod tests { + use std::io::{Read, Write}; + use std::ops::Drop; + use std::os::unix::net::{UnixListener, UnixStream}; + use std::path::{Path, PathBuf}; + + use super::super::super::csm::defs as csm_defs; + use super::super::super::tests::TestContext as VsockTestContext; + use super::*; + + const PEER_CID: u64 = 3; + const PEER_BUF_ALLOC: u32 = 64 * 1024; + + struct MuxerTestContext { + _vsock_test_ctx: VsockTestContext, + pkt: VsockPacket, + muxer: VsockMuxer, + } + + impl Drop for MuxerTestContext { + fn drop(&mut self) { + std::fs::remove_file(self.muxer.host_sock_path.as_str()).unwrap(); + } + } + + impl MuxerTestContext { + fn new(name: &str) -> Self { + let vsock_test_ctx = VsockTestContext::new(); + let mut handler_ctx = vsock_test_ctx.create_epoll_handler_context(); + let pkt = VsockPacket::from_rx_virtq_head( + &handler_ctx.handler.queues[0] + .iter(&vsock_test_ctx.mem) + .next() + .unwrap(), + ) + .unwrap(); + let uds_path = format!("test_vsock_{}.sock", name); + let muxer = VsockMuxer::new(PEER_CID, uds_path).unwrap(); + + Self { + _vsock_test_ctx: vsock_test_ctx, + pkt, + muxer, + } + } + + fn init_pkt(&mut self, local_port: u32, peer_port: u32, op: u16) -> &mut VsockPacket { + for b in self.pkt.hdr_mut() { + *b = 0; + } + self.pkt + .set_type(uapi::VSOCK_TYPE_STREAM) + .set_src_cid(PEER_CID) + .set_dst_cid(uapi::VSOCK_HOST_CID) + .set_src_port(peer_port) + .set_dst_port(local_port) + .set_op(op) + .set_buf_alloc(PEER_BUF_ALLOC) + } + + fn init_data_pkt( + &mut self, + local_port: u32, + peer_port: u32, + data: &[u8], + ) -> &mut VsockPacket { + assert!(data.len() <= self.pkt.buf().unwrap().len() as usize); + self.init_pkt(local_port, peer_port, uapi::VSOCK_OP_RW) + .set_len(data.len() as u32); + self.pkt.buf_mut().unwrap()[..data.len()].copy_from_slice(data); + &mut self.pkt + } + + fn send(&mut self) { + self.muxer.send_pkt(&self.pkt).unwrap(); + } + + fn recv(&mut self) { + self.muxer.recv_pkt(&mut self.pkt).unwrap(); + } + + fn notify_muxer(&mut self) { + self.muxer.notify(epoll::Events::EPOLLIN); + } + + fn count_epoll_listeners(&self) -> (usize, usize) { + let mut local_lsn_count = 0usize; + let mut conn_lsn_count = 0usize; + for key in self.muxer.listener_map.values() { + match key { + EpollListener::LocalStream(_) => local_lsn_count += 1, + EpollListener::Connection { .. } => conn_lsn_count += 1, + _ => (), + }; + } + (local_lsn_count, conn_lsn_count) + } + + fn create_local_listener(&self, port: u32) -> LocalListener { + LocalListener::new(format!("{}_{}", self.muxer.host_sock_path, port)) + } + + fn local_connect(&mut self, peer_port: u32) -> (UnixStream, u32) { + let (init_local_lsn_count, init_conn_lsn_count) = self.count_epoll_listeners(); + + let mut stream = UnixStream::connect(self.muxer.host_sock_path.clone()).unwrap(); + stream.set_nonblocking(true).unwrap(); + // The muxer would now get notified of a new connection having arrived at its Unix + // socket, so it can accept it. + self.notify_muxer(); + + // Just after having accepted a new local connection, the muxer should've added a new + // `LocalStream` listener to its `listener_map`. + let (local_lsn_count, _) = self.count_epoll_listeners(); + assert_eq!(local_lsn_count, init_local_lsn_count + 1); + + let buf = format!("CONNECT {}\n", peer_port); + stream.write_all(buf.as_bytes()).unwrap(); + // The muxer would now get notified that data is available for reading from the locally + // initiated connection. + self.notify_muxer(); + + // Successfully reading and parsing the connection request should have removed the + // LocalStream epoll listener and added a Connection epoll listener. + let (local_lsn_count, conn_lsn_count) = self.count_epoll_listeners(); + assert_eq!(local_lsn_count, init_local_lsn_count); + assert_eq!(conn_lsn_count, init_conn_lsn_count + 1); + + // A LocalInit connection should've been added to the muxer connection map. A new + // local port should also have been allocated for the new LocalInit connection. + let local_port = self.muxer.local_port_last; + let key = ConnMapKey { + local_port, + peer_port, + }; + assert!(self.muxer.conn_map.contains_key(&key)); + assert!(self.muxer.local_port_set.contains(&local_port)); + + // A connection request for the peer should now be available from the muxer. + assert!(self.muxer.has_pending_rx()); + self.recv(); + assert_eq!(self.pkt.op(), uapi::VSOCK_OP_REQUEST); + assert_eq!(self.pkt.dst_port(), peer_port); + assert_eq!(self.pkt.src_port(), local_port); + + self.init_pkt(local_port, peer_port, uapi::VSOCK_OP_RESPONSE); + self.send(); + + (stream, local_port) + } + } + + struct LocalListener { + path: PathBuf, + sock: UnixListener, + } + impl LocalListener { + fn new + Clone>(path: P) -> Self { + let path_buf = path.clone().as_ref().to_path_buf(); + let sock = UnixListener::bind(path).unwrap(); + sock.set_nonblocking(true).unwrap(); + Self { + path: path_buf, + sock, + } + } + fn accept(&mut self) -> UnixStream { + let (stream, _) = self.sock.accept().unwrap(); + stream.set_nonblocking(true).unwrap(); + stream + } + } + impl Drop for LocalListener { + fn drop(&mut self) { + std::fs::remove_file(&self.path).unwrap(); + } + } + + #[test] + fn test_muxer_epoll_listener() { + let ctx = MuxerTestContext::new("muxer_epoll_listener"); + assert_eq!(ctx.muxer.get_polled_fd(), ctx.muxer.epoll_fd); + assert_eq!(ctx.muxer.get_polled_evset(), epoll::Events::EPOLLIN); + } + + #[test] + fn test_bad_peer_pkt() { + const LOCAL_PORT: u32 = 1026; + const PEER_PORT: u32 = 1025; + const SOCK_DGRAM: u16 = 2; + + let mut ctx = MuxerTestContext::new("bad_peer_pkt"); + ctx.init_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST) + .set_type(SOCK_DGRAM); + ctx.send(); + + // The guest sent a SOCK_DGRAM packet. Per the vsock spec, we need to reply with an RST + // packet, since vsock only supports stream sockets. + assert!(ctx.muxer.has_pending_rx()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST); + assert_eq!(ctx.pkt.src_cid(), uapi::VSOCK_HOST_CID); + assert_eq!(ctx.pkt.dst_cid(), PEER_CID); + assert_eq!(ctx.pkt.src_port(), LOCAL_PORT); + assert_eq!(ctx.pkt.dst_port(), PEER_PORT); + + // Any orphan (i.e. without a connection), non-RST packet, should be replied to with an + // RST. + let bad_ops = [ + uapi::VSOCK_OP_RESPONSE, + uapi::VSOCK_OP_CREDIT_REQUEST, + uapi::VSOCK_OP_CREDIT_UPDATE, + uapi::VSOCK_OP_SHUTDOWN, + uapi::VSOCK_OP_RW, + ]; + for op in bad_ops.iter() { + ctx.init_pkt(LOCAL_PORT, PEER_PORT, *op); + ctx.send(); + assert!(ctx.muxer.has_pending_rx()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST); + assert_eq!(ctx.pkt.src_port(), LOCAL_PORT); + assert_eq!(ctx.pkt.dst_port(), PEER_PORT); + } + + // Any packet addressed to anything other than VSOCK_VHOST_CID should get dropped. + assert!(!ctx.muxer.has_pending_rx()); + ctx.init_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST) + .set_dst_cid(uapi::VSOCK_HOST_CID + 1); + ctx.send(); + assert!(!ctx.muxer.has_pending_rx()); + } + + #[test] + fn test_peer_connection() { + const LOCAL_PORT: u32 = 1026; + const PEER_PORT: u32 = 1025; + + let mut ctx = MuxerTestContext::new("peer_connection"); + + // Test peer connection refused. + ctx.init_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST); + ctx.send(); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST); + assert_eq!(ctx.pkt.len(), 0); + assert_eq!(ctx.pkt.src_cid(), uapi::VSOCK_HOST_CID); + assert_eq!(ctx.pkt.dst_cid(), PEER_CID); + assert_eq!(ctx.pkt.src_port(), LOCAL_PORT); + assert_eq!(ctx.pkt.dst_port(), PEER_PORT); + + // Test peer connection accepted. + let mut listener = ctx.create_local_listener(LOCAL_PORT); + ctx.init_pkt(LOCAL_PORT, PEER_PORT, uapi::VSOCK_OP_REQUEST); + ctx.send(); + assert_eq!(ctx.muxer.conn_map.len(), 1); + let mut stream = listener.accept(); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE); + assert_eq!(ctx.pkt.len(), 0); + assert_eq!(ctx.pkt.src_cid(), uapi::VSOCK_HOST_CID); + assert_eq!(ctx.pkt.dst_cid(), PEER_CID); + assert_eq!(ctx.pkt.src_port(), LOCAL_PORT); + assert_eq!(ctx.pkt.dst_port(), PEER_PORT); + let key = ConnMapKey { + local_port: LOCAL_PORT, + peer_port: PEER_PORT, + }; + assert!(ctx.muxer.conn_map.contains_key(&key)); + + // Test guest -> host data flow. + let data = [1, 2, 3, 4]; + ctx.init_data_pkt(LOCAL_PORT, PEER_PORT, &data); + ctx.send(); + let mut buf = vec![0; data.len()]; + stream.read_exact(buf.as_mut_slice()).unwrap(); + assert_eq!(buf.as_slice(), data); + + // Test host -> guest data flow. + let data = [5u8, 6, 7, 8]; + stream.write_all(&data).unwrap(); + + // When data is available on the local stream, an EPOLLIN event would normally be delivered + // to the muxer's nested epoll FD. For testing only, we can fake that event notification + // here. + ctx.notify_muxer(); + // After being notified, the muxer should've figured out that RX data was available for one + // of its connections, so it should now be reporting that it can fill in an RX packet. + assert!(ctx.muxer.has_pending_rx()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RW); + assert_eq!(ctx.pkt.buf().unwrap()[..data.len()], data); + assert_eq!(ctx.pkt.src_port(), LOCAL_PORT); + assert_eq!(ctx.pkt.dst_port(), PEER_PORT); + + assert!(!ctx.muxer.has_pending_rx()); + } + + #[test] + fn test_local_connection() { + let mut ctx = MuxerTestContext::new("local_connection"); + let peer_port = 1025; + let (mut stream, local_port) = ctx.local_connect(peer_port); + + // Test guest -> host data flow. + let data = [1, 2, 3, 4]; + ctx.init_data_pkt(local_port, peer_port, &data); + ctx.send(); + + let mut buf = vec![0u8; data.len()]; + stream.read_exact(buf.as_mut_slice()).unwrap(); + assert_eq!(buf.as_slice(), &data); + + // Test host -> guest data flow. + let data = [5, 6, 7, 8]; + stream.write_all(&data).unwrap(); + ctx.notify_muxer(); + + assert!(ctx.muxer.has_pending_rx()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RW); + assert_eq!(ctx.pkt.src_port(), local_port); + assert_eq!(ctx.pkt.dst_port(), peer_port); + assert_eq!(ctx.pkt.buf().unwrap()[..data.len()], data); + } + + #[test] + fn test_local_close() { + let peer_port = 1025; + let mut ctx = MuxerTestContext::new("local_close"); + let local_port; + { + let (_stream, local_port_) = ctx.local_connect(peer_port); + local_port = local_port_; + } + // Local var `_stream` was now dropped, thus closing the local stream. After the muxer gets + // notified via EPOLLIN, it should attempt to gracefully shutdown the connection, issuing a + // VSOCK_OP_SHUTDOWN with both no-more-send and no-more-recv indications set. + ctx.notify_muxer(); + assert!(ctx.muxer.has_pending_rx()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_SHUTDOWN); + assert_ne!(ctx.pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_SEND, 0); + assert_ne!(ctx.pkt.flags() & uapi::VSOCK_FLAGS_SHUTDOWN_RCV, 0); + assert_eq!(ctx.pkt.src_port(), local_port); + assert_eq!(ctx.pkt.dst_port(), peer_port); + + // The connection should get removed (and its local port freed), after the peer replies + // with an RST. + ctx.init_pkt(local_port, peer_port, uapi::VSOCK_OP_RST); + ctx.send(); + let key = ConnMapKey { + local_port, + peer_port, + }; + assert!(!ctx.muxer.conn_map.contains_key(&key)); + assert!(!ctx.muxer.local_port_set.contains(&local_port)); + } + + #[test] + fn test_peer_close() { + let peer_port = 1025; + let local_port = 1026; + let mut ctx = MuxerTestContext::new("peer_close"); + + let mut sock = ctx.create_local_listener(local_port); + ctx.init_pkt(local_port, peer_port, uapi::VSOCK_OP_REQUEST); + ctx.send(); + let mut stream = sock.accept(); + + assert!(ctx.muxer.has_pending_rx()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE); + assert_eq!(ctx.pkt.src_port(), local_port); + assert_eq!(ctx.pkt.dst_port(), peer_port); + let key = ConnMapKey { + local_port, + peer_port, + }; + assert!(ctx.muxer.conn_map.contains_key(&key)); + + // Emulate a full shutdown from the peer (no-more-send + no-more-recv). + ctx.init_pkt(local_port, peer_port, uapi::VSOCK_OP_SHUTDOWN) + .set_flag(uapi::VSOCK_FLAGS_SHUTDOWN_SEND) + .set_flag(uapi::VSOCK_FLAGS_SHUTDOWN_RCV); + ctx.send(); + + // Now, the muxer should remove the connection from its map, and reply with an RST. + assert!(ctx.muxer.has_pending_rx()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST); + assert_eq!(ctx.pkt.src_port(), local_port); + assert_eq!(ctx.pkt.dst_port(), peer_port); + let key = ConnMapKey { + local_port, + peer_port, + }; + assert!(!ctx.muxer.conn_map.contains_key(&key)); + + // The muxer should also drop / close the local Unix socket for this connection. + let mut buf = vec![0u8; 16]; + assert_eq!(stream.read(buf.as_mut_slice()).unwrap(), 0); + } + + #[test] + fn test_muxer_rxq() { + let mut ctx = MuxerTestContext::new("muxer_rxq"); + let local_port = 1026; + let peer_port_first = 1025; + let mut listener = ctx.create_local_listener(local_port); + let mut streams: Vec = Vec::new(); + + for peer_port in peer_port_first..peer_port_first + defs::MUXER_RXQ_SIZE { + ctx.init_pkt(local_port, peer_port as u32, uapi::VSOCK_OP_REQUEST); + ctx.send(); + streams.push(listener.accept()); + } + + // The muxer RX queue should now be full (with connection reponses), but still + // synchronized. + assert!(ctx.muxer.rxq.is_synced()); + + // One more queued reply should desync the RX queue. + ctx.init_pkt( + local_port, + (peer_port_first + defs::MUXER_RXQ_SIZE) as u32, + uapi::VSOCK_OP_REQUEST, + ); + ctx.send(); + assert!(!ctx.muxer.rxq.is_synced()); + + // With an out-of-sync queue, an RST should evict any non-RST packet from the queue, and + // take its place. We'll check that by making sure that the last packet popped from the + // queue is an RST. + ctx.init_pkt( + local_port + 1, + peer_port_first as u32, + uapi::VSOCK_OP_REQUEST, + ); + ctx.send(); + + for peer_port in peer_port_first..peer_port_first + defs::MUXER_RXQ_SIZE - 1 { + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE); + // The response order should hold. The evicted response should have been the last + // enqueued. + assert_eq!(ctx.pkt.dst_port(), peer_port as u32); + } + // There should be one more packet in the queue: the RST. + assert_eq!(ctx.muxer.rxq.len(), 1); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST); + + // The queue should now be empty, but out-of-sync, so the muxer should report it has some + // pending RX. + assert!(ctx.muxer.rxq.is_empty()); + assert!(!ctx.muxer.rxq.is_synced()); + assert!(ctx.muxer.has_pending_rx()); + + // The next recv should sync the queue back up. It should also yield one of the two + // responses that are still left: + // - the one that desynchronized the queue; and + // - the one that got evicted by the RST. + ctx.recv(); + assert!(ctx.muxer.rxq.is_synced()); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE); + + assert!(ctx.muxer.has_pending_rx()); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE); + } + + #[test] + fn test_muxer_killq() { + let mut ctx = MuxerTestContext::new("muxer_killq"); + let local_port = 1026; + let peer_port_first = 1025; + let peer_port_last = peer_port_first + defs::MUXER_KILLQ_SIZE; + let mut listener = ctx.create_local_listener(local_port); + + for peer_port in peer_port_first..=peer_port_last { + ctx.init_pkt(local_port, peer_port as u32, uapi::VSOCK_OP_REQUEST); + ctx.send(); + ctx.notify_muxer(); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE); + assert_eq!(ctx.pkt.src_port(), local_port); + assert_eq!(ctx.pkt.dst_port(), peer_port as u32); + { + let _stream = listener.accept(); + } + ctx.notify_muxer(); + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_SHUTDOWN); + assert_eq!(ctx.pkt.src_port(), local_port); + assert_eq!(ctx.pkt.dst_port(), peer_port as u32); + // The kill queue should be synchronized, up until the `defs::MUXER_KILLQ_SIZE`th + // connection we schedule for termination. + assert_eq!( + ctx.muxer.killq.is_synced(), + peer_port < peer_port_first + defs::MUXER_KILLQ_SIZE + ); + } + + assert!(!ctx.muxer.killq.is_synced()); + assert!(!ctx.muxer.has_pending_rx()); + + // Wait for the kill timers to expire. + std::thread::sleep(std::time::Duration::from_millis( + csm_defs::CONN_SHUTDOWN_TIMEOUT_MS, + )); + + // Trigger a kill queue sweep, by requesting a new connection. + ctx.init_pkt( + local_port, + peer_port_last as u32 + 1, + uapi::VSOCK_OP_REQUEST, + ); + ctx.send(); + + // After sweeping the kill queue, it should now be synced (assuming the RX queue is larger + // than the kill queue, since an RST packet will be queued for each killed connection). + assert!(ctx.muxer.killq.is_synced()); + assert!(ctx.muxer.has_pending_rx()); + // There should be `defs::MUXER_KILLQ_SIZE` RSTs in the RX queue, from terminating the + // dying connections in the recent killq sweep. + for _p in peer_port_first..peer_port_last { + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RST); + assert_eq!(ctx.pkt.src_port(), local_port); + } + + // There should be one more packet in the RX queue: the connection response our request + // that triggered the kill queue sweep. + ctx.recv(); + assert_eq!(ctx.pkt.op(), uapi::VSOCK_OP_RESPONSE); + assert_eq!(ctx.pkt.dst_port(), peer_port_last as u32 + 1); + + assert!(!ctx.muxer.has_pending_rx()); + } +}