vhost_rs: remove unused crate

As we switched to external vhost_rs crate, delete this internal one.
All vhost changes should go to cloud-hypervisor/vhost.

Signed-off-by: Eryu Guan <eguan@linux.alibaba.com>
This commit is contained in:
Eryu Guan 2020-02-25 11:49:14 +08:00 committed by Rob Bradford
parent 5200bf3c59
commit 81c2294c11
21 changed files with 0 additions and 5692 deletions

View File

@ -1,25 +0,0 @@
# Manifest for the internal vhost_rs crate (vhost backend drivers for virtio devices).
[package]
name = "vhost_rs"
version = "0.1.0"
authors = ["Liu Jiang <gerry@linux.alibaba.com>"]
repository = "https://github.com/rust-vmm/vhost"
license = "Apache-2.0 or BSD-3-Clause"

# All optional functionality is behind features; nothing is enabled by default.
[features]
default = []
# vhost-vsock device support.
vhost-vsock = []
# In-kernel (ioctl-based) vhost backend; needs vm-memory for guest memory access.
vhost-kern = ["vm-memory"]
# Master side of the vhost-user protocol (the application sharing its virtqueues).
vhost-user-master = []
# Slave side of the vhost-user protocol (the consumer of the virtqueues).
vhost-user-slave = []

[dependencies]
bitflags = "1.1.0"
libc = "0.2.67"
vmm-sys-util = ">=0.3.1"

# Only pulled in when the vhost-kern feature is enabled (see [features] above).
[dependencies.vm-memory]
git = "https://github.com/rust-vmm/vm-memory"
optional = true

[dev-dependencies]
tempfile = "3.1.0"

View File

@ -1,202 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -1,27 +0,0 @@
// Copyright 2017 The Chromium OS Authors. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

View File

@ -1,24 +0,0 @@
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

View File

@ -1,10 +0,0 @@
# vHost
A crate to support vhost backend drivers for virtio devices.
## Kernel-based vHost Backend Drivers
The vhost drivers in Linux provide in-kernel virtio device emulation. Normally the hypervisor userspace process emulates I/O accesses from the guest. Vhost puts virtio emulation code into the kernel, taking hypervisor userspace out of the picture. This allows device emulation code to directly call into kernel subsystems instead of performing system calls from userspace. The hypervisor relies on ioctl based interfaces to control those in-kernel vhost drivers, such as vhost-net, vhost-scsi and vhost-vsock etc.
## vHost-user Backend Drivers
The vhost-user protocol aims to implement vhost backend drivers in userspace, complementing the ioctl interface used to control the vhost implementation in the Linux kernel. It implements the control plane needed to establish virtqueue sharing with a user space process on the same host. It uses communication over a Unix domain socket to share file descriptors in the ancillary data of the message.
The protocol defines two sides of the communication, master and slave. Master is the application that shares its virtqueues, slave is the consumer of the virtqueues. Master and slave can be either a client (i.e. connecting) or server (listening) in the socket communication.

View File

@ -1,130 +0,0 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
//
// Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE-BSD file.
//! Common traits and structs for vhost-kern and vhost-user backend drivers.
use super::Result;
use std::os::unix::io::RawFd;
use vmm_sys_util::eventfd::EventFd;
/// Maximum number of memory regions supported.
// NOTE(review): presumably mirrors the kernel's vhost memory-region limit — confirm
// against the kernel's vhost UAPI before relying on this value.
pub const VHOST_MAX_MEMORY_REGIONS: usize = 255;
/// Vring/virtqueue configuration data.
pub struct VringConfigData {
    /// Maximum queue size supported by the driver.
    pub queue_max_size: u16,
    /// Actual queue size negotiated by the driver.
    pub queue_size: u16,
    /// Bitmask of vring flags.
    pub flags: u32,
    /// Descriptor table address.
    pub desc_table_addr: u64,
    /// Used ring buffer address.
    pub used_ring_addr: u64,
    /// Available ring buffer address.
    pub avail_ring_addr: u64,
    /// Optional address for logging.
    pub log_addr: Option<u64>,
}
/// Memory region configuration data.
#[derive(Default, Clone, Copy)]
pub struct VhostUserMemoryRegionInfo {
    /// Guest physical address of the memory region.
    pub guest_phys_addr: u64,
    /// Size of the memory region.
    pub memory_size: u64,
    /// Virtual address in the current process.
    pub userspace_addr: u64,
    /// Optional offset where region starts in the mapped memory.
    pub mmap_offset: u64,
    /// Optional file descriptor for mmap.
    pub mmap_handle: RawFd,
}
/// An interface for setting up vhost-based backend drivers.
///
/// Vhost-based virtio devices are different from regular virtio devices because the vhost
/// backend takes care of handling all the data transfer. The device itself only needs to deal with
/// setting up the backend driver and managing the control channel.
pub trait VhostBackend: std::marker::Sized {
    /// Get a bitmask of supported virtio/vhost features.
    fn get_features(&mut self) -> Result<u64>;

    /// Inform the vhost subsystem which features to enable.
    /// This should be a subset of supported features from get_features().
    ///
    /// # Arguments
    /// * `features` - Bitmask of features to set.
    fn set_features(&mut self, features: u64) -> Result<()>;

    /// Set the current process as the owner of the vhost backend.
    /// This must be run before any other vhost commands.
    fn set_owner(&mut self) -> Result<()>;

    /// Used to be sent to request disabling all rings.
    /// This is no longer used.
    fn reset_owner(&mut self) -> Result<()>;

    /// Set the guest memory mappings for vhost to use.
    ///
    /// # Arguments
    /// * `regions` - Memory region descriptions to install.
    fn set_mem_table(&mut self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()>;

    /// Set base address for page modification logging.
    ///
    /// # Arguments
    /// * `base` - Base address of the log.
    /// * `fd` - Optional file descriptor backing the log.
    fn set_log_base(&mut self, base: u64, fd: Option<RawFd>) -> Result<()>;

    /// Specify an eventfd file descriptor to signal on log write.
    fn set_log_fd(&mut self, fd: RawFd) -> Result<()>;

    /// Set the number of descriptors in the vring.
    ///
    /// # Arguments
    /// * `queue_index` - Index of the queue to set descriptor count for.
    /// * `num` - Number of descriptors in the queue.
    fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()>;

    /// Set the addresses for a given vring.
    ///
    /// # Arguments
    /// * `queue_index` - Index of the queue to set addresses for.
    /// * `config_data` - Configuration data for a vring.
    fn set_vring_addr(&mut self, queue_index: usize, config_data: &VringConfigData) -> Result<()>;

    /// Set the first index to look for available descriptors.
    ///
    /// # Arguments
    /// * `queue_index` - Index of the queue to modify.
    /// * `base` - Index where available descriptors start.
    fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()>;

    /// Get the available vring base offset.
    ///
    /// # Arguments
    /// * `queue_index` - Index of the queue to query.
    fn get_vring_base(&mut self, queue_index: usize) -> Result<u32>;

    /// Set the eventfd to trigger when buffers have been used by the host.
    ///
    /// # Arguments
    /// * `queue_index` - Index of the queue to modify.
    /// * `fd` - EventFd to trigger.
    fn set_vring_call(&mut self, queue_index: usize, fd: &EventFd) -> Result<()>;

    /// Set the eventfd that will be signaled by the guest when buffers are
    /// available for the host to process.
    ///
    /// # Arguments
    /// * `queue_index` - Index of the queue to modify.
    /// * `fd` - EventFd that will be signaled from guest.
    fn set_vring_kick(&mut self, queue_index: usize, fd: &EventFd) -> Result<()>;

    /// Set the eventfd that will be signaled by the guest when error happens.
    ///
    /// # Arguments
    /// * `queue_index` - Index of the queue to modify.
    /// * `fd` - EventFd that will be signaled from guest.
    fn set_vring_err(&mut self, queue_index: usize, fd: &EventFd) -> Result<()>;
}

View File

@ -1,120 +0,0 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
//
// Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE-BSD file.
//! Virtio Vhost Backend Drivers
//!
//! Virtio devices use virtqueues to transport data efficiently. Virtqueue is a set of three
//! different single-producer, single-consumer ring structures designed to store generic
//! scatter-gather I/O.
//!
//! Vhost is a mechanism to improve the performance of Virtio devices by delegating data plane operations
//! to dedicated IO service processes. Only the configuration, I/O submission notification, and I/O
//! completion interruption are piped through the hypervisor.
//! It uses the same virtqueue layout as Virtio to allow Vhost devices to be mapped directly to
//! Virtio devices. This allows a Vhost device to be accessed directly by a guest OS inside a
//! hypervisor process with an existing Virtio (PCI) driver.
//!
//! The initial vhost implementation is a part of the Linux kernel and uses ioctl interface to
//! communicate with userspace applications. Dedicated kernel worker threads are created to handle
//! IO requests from the guest.
//!
//! Later Vhost-user protocol is introduced to complement the ioctl interface used to control the
//! vhost implementation in the Linux kernel. It implements the control plane needed to establish
//! virtqueues sharing with a user space process on the same host. It uses communication over a
//! Unix domain socket to share file descriptors in the ancillary data of the message.
//! The protocol defines 2 sides of the communication, master and slave. Master is the application
//! that shares its virtqueues. Slave is the consumer of the virtqueues. Master and slave can be
//! either a client (i.e. connecting) or server (listening) in the socket communication.
#![deny(missing_docs)]
#[cfg_attr(
any(feature = "vhost-user-master", feature = "vhost-user-slave"),
macro_use
)]
extern crate bitflags;
extern crate libc;
#[cfg(feature = "vhost-kern")]
extern crate vm_memory;
#[cfg_attr(feature = "vhost-kern", macro_use)]
extern crate vmm_sys_util;
mod backend;
pub use backend::*;
#[cfg(feature = "vhost-kern")]
pub mod vhost_kern;
#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))]
pub mod vhost_user;
#[cfg(feature = "vhost-vsock")]
pub mod vsock;
/// Error codes for vhost operations.
///
/// Variants guarded by `#[cfg(...)]` only exist when the corresponding cargo
/// feature is enabled.
#[derive(Debug)]
pub enum Error {
    /// Invalid operations.
    InvalidOperation,
    /// Invalid guest memory.
    InvalidGuestMemory,
    /// Invalid guest memory region.
    InvalidGuestMemoryRegion,
    /// Invalid queue.
    InvalidQueue,
    /// Invalid descriptor table address.
    DescriptorTableAddress,
    /// Invalid used address.
    UsedAddress,
    /// Invalid available address.
    AvailAddress,
    /// Invalid log address.
    LogAddress,
    #[cfg(feature = "vhost-kern")]
    /// Error opening the vhost backend driver.
    VhostOpen(std::io::Error),
    #[cfg(feature = "vhost-kern")]
    /// Error while running ioctl.
    IoctlError(std::io::Error),
    /// Error from IO subsystem.
    IOError(std::io::Error),
    #[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))]
    /// Error from the vhost-user subsystem.
    VhostUserProtocol(vhost_user::Error),
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Error::InvalidOperation => write!(f, "invalid vhost operations"),
Error::InvalidGuestMemory => write!(f, "invalid guest memory object"),
Error::InvalidGuestMemoryRegion => write!(f, "invalid guest memory region"),
Error::InvalidQueue => write!(f, "invalid virtque"),
Error::DescriptorTableAddress => write!(f, "invalid virtque descriptor talbe address"),
Error::UsedAddress => write!(f, "invalid virtque used talbe address"),
Error::AvailAddress => write!(f, "invalid virtque available talbe address"),
Error::LogAddress => write!(f, "invalid virtque log address"),
Error::IOError(e) => write!(f, "IO error: {}", e),
#[cfg(feature = "vhost-kern")]
Error::VhostOpen(e) => write!(f, "failure in opening vhost file: {}", e),
#[cfg(feature = "vhost-kern")]
Error::IoctlError(e) => write!(f, "failure in vhost ioctl: {}", e),
#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))]
Error::VhostUserProtocol(e) => write!(f, "vhost-user: {}", e),
}
}
}
#[cfg(any(feature = "vhost-user-master", feature = "vhost-user-slave"))]
impl From<vhost_user::Error> for Error {
    /// Wraps a vhost-user protocol error into the crate-level error type.
    fn from(err: vhost_user::Error) -> Self {
        Self::VhostUserProtocol(err)
    }
}
/// Result of vhost operations, carrying the crate-level [`Error`] on failure.
pub type Result<T> = std::result::Result<T, Error>;

View File

@ -1,320 +0,0 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
//
// Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE-BSD file.
//! Traits and structs to control Linux in-kernel vhost drivers.
//!
//! The initial vhost implementation is a part of the Linux kernel and uses ioctl interface to
//! communicate with userspace applications. This sub module provides ioctl based interfaces to
//! control the in-kernel net, scsi, vsock vhost drivers.
use std::os::unix::io::{AsRawFd, RawFd};
use std::ptr::null;
use vm_memory::{Address, GuestAddress, GuestMemory, GuestUsize};
use vmm_sys_util::eventfd::EventFd;
use vmm_sys_util::ioctl::{ioctl, ioctl_with_mut_ref, ioctl_with_ptr, ioctl_with_ref};
use super::{
Error, Result, VhostBackend, VhostUserMemoryRegionInfo, VringConfigData,
VHOST_MAX_MEMORY_REGIONS,
};
pub mod vhost_binding;
use self::vhost_binding::*;
#[cfg(feature = "vhost-vsock")]
pub mod vsock;
/// Translate a raw ioctl return code into a `Result`: yields `res` when the
/// call succeeded, or the last OS error wrapped in `Error::IoctlError` otherwise.
#[inline]
fn ioctl_result<T>(rc: i32, res: T) -> Result<T> {
    if rc >= 0 {
        Ok(res)
    } else {
        Err(Error::IoctlError(std::io::Error::last_os_error()))
    }
}
// Shorthand for wrapping a raw u64 guest physical address into a GuestAddress.
fn guest_addr(addr: u64) -> GuestAddress {
    GuestAddress::new(addr)
}
/// Represent an in-kernel vhost device backend.
pub trait VhostKernBackend<'a>: AsRawFd {
    /// Associated type to access the guest's memory.
    type M: GuestMemory<'a>;

    /// Get the object to access the guest's memory.
    fn mem(&self) -> &Self::M;

    /// Check whether the ring configuration is valid.
    ///
    /// Validates that `queue_size` is a non-zero power of two no larger than
    /// `queue_max_size`, and that the descriptor table, available ring and
    /// used ring each fit entirely within guest memory.
    #[allow(clippy::if_same_then_else)]
    #[allow(clippy::needless_bool)]
    fn is_valid(
        &self,
        queue_max_size: u16,
        queue_size: u16,
        desc_addr: GuestAddress,
        avail_addr: GuestAddress,
        used_addr: GuestAddress,
    ) -> bool {
        // Ring footprints as a function of queue size: 16 bytes per descriptor;
        // the avail ring is 6 bytes of header/footer plus 2 bytes per entry;
        // the used ring is 6 bytes plus 8 bytes per entry.
        let desc_table_size = 16 * u64::from(queue_size) as GuestUsize;
        let avail_ring_size = 6 + 2 * u64::from(queue_size) as GuestUsize;
        let used_ring_size = 6 + 8 * u64::from(queue_size) as GuestUsize;
        if queue_size > queue_max_size || queue_size == 0 || (queue_size & (queue_size - 1)) != 0 {
            // Queue size must be a non-zero power of two within the supported maximum.
            false
        } else if desc_addr
            .checked_add(desc_table_size)
            .map_or(true, |v| !self.mem().address_in_range(v))
        {
            // Descriptor table overflows or falls outside guest memory.
            false
        } else if avail_addr
            .checked_add(avail_ring_size)
            .map_or(true, |v| !self.mem().address_in_range(v))
        {
            // Available ring overflows or falls outside guest memory.
            false
        } else if used_addr
            .checked_add(used_ring_size)
            .map_or(true, |v| !self.mem().address_in_range(v))
        {
            // Used ring overflows or falls outside guest memory.
            false
        } else {
            true
        }
    }
}
impl<'a, T: VhostKernBackend<'a>> VhostBackend for T {
    /// Set the current process as the owner of this file descriptor.
    /// This must be run before any other vhost ioctls.
    fn set_owner(&mut self) -> Result<()> {
        // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl(self, VHOST_SET_OWNER()) };
        ioctl_result(ret, ())
    }

    /// Used to be sent to request disabling all rings; no longer used.
    fn reset_owner(&mut self) -> Result<()> {
        // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl(self, VHOST_RESET_OWNER()) };
        ioctl_result(ret, ())
    }

    /// Get a bitmask of supported virtio/vhost features.
    fn get_features(&mut self) -> Result<u64> {
        let mut avail_features: u64 = 0;
        // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl_with_mut_ref(self, VHOST_GET_FEATURES(), &mut avail_features) };
        ioctl_result(ret, avail_features)
    }

    /// Inform the vhost subsystem which features to enable. This should be a subset of
    /// supported features from VHOST_GET_FEATURES.
    ///
    /// # Arguments
    /// * `features` - Bitmask of features to set.
    fn set_features(&mut self, features: u64) -> Result<()> {
        // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl_with_ref(self, VHOST_SET_FEATURES(), &features) };
        ioctl_result(ret, ())
    }
    /// Set the guest memory mappings for vhost to use.
    ///
    /// # Arguments
    /// * `regions` - Between 1 and VHOST_MAX_MEMORY_REGIONS region descriptions.
    fn set_mem_table(&mut self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
        if regions.is_empty() || regions.len() > VHOST_MAX_MEMORY_REGIONS {
            return Err(Error::InvalidGuestMemory);
        }
        // Build the variable-length vhost_memory structure the kernel expects,
        // one entry per caller-supplied region.
        let mut vhost_memory = VhostMemory::new(regions.len() as u16);
        for (index, region) in regions.iter().enumerate() {
            vhost_memory.set_region(
                index as u32,
                &vhost_memory_region {
                    guest_phys_addr: region.guest_phys_addr,
                    memory_size: region.memory_size,
                    userspace_addr: region.userspace_addr,
                    flags_padding: 0u64,
                },
            )?;
        }
        // SAFETY: This ioctl is called with a pointer that is valid for the lifetime
        // of this function. The kernel will make its own copy of the memory
        // tables. As always, check the return value.
        let ret = unsafe { ioctl_with_ptr(self, VHOST_SET_MEM_TABLE(), vhost_memory.as_ptr()) };
        ioctl_result(ret, ())
    }

    /// Set base address for page modification logging.
    ///
    /// # Arguments
    /// * `base` - Base address for page modification logging.
    /// * `fd` - Must be `None`; the in-kernel backend does not take a log fd here.
    fn set_log_base(&mut self, base: u64, fd: Option<RawFd>) -> Result<()> {
        if fd.is_some() {
            return Err(Error::LogAddress);
        }
        // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl_with_ref(self, VHOST_SET_LOG_BASE(), &base) };
        ioctl_result(ret, ())
    }

    /// Specify an eventfd file descriptor to signal on log write.
    fn set_log_fd(&mut self, fd: RawFd) -> Result<()> {
        // The kernel interface takes the fd as an i32 value.
        let val: i32 = fd;
        // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl_with_ref(self, VHOST_SET_LOG_FD(), &val) };
        ioctl_result(ret, ())
    }
    /// Set the number of descriptors in the vring.
    ///
    /// # Arguments
    /// * `queue_index` - Index of the queue to set descriptor count for.
    /// * `num` - Number of descriptors in the queue.
    fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()> {
        let vring_state = vhost_vring_state {
            index: queue_index as u32,
            num: u32::from(num),
        };
        // SAFETY: This ioctl is called on a valid vhost fd and has its return value checked.
        let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_NUM(), &vring_state) };
        ioctl_result(ret, ())
    }
    /// Set the addresses for a given vring.
    ///
    /// # Arguments
    /// * `queue_index` - Index of the queue to set addresses for.
    /// * `config_data` - Configuration data for the vring (sizes, flags,
    ///   descriptor/used/available ring addresses and optional log address).
    fn set_vring_addr(&mut self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
        // Reject configurations whose rings do not fit in guest memory.
        if !self.is_valid(
            config_data.queue_max_size,
            config_data.queue_size,
            guest_addr(config_data.desc_table_addr),
            guest_addr(config_data.used_ring_addr),
            guest_addr(config_data.avail_ring_addr),
        ) {
            return Err(Error::InvalidQueue);
        }
        // Translate each guest address to a host virtual address for the kernel.
        let desc_addr = self
            .mem()
            .get_host_address(guest_addr(config_data.desc_table_addr))
            .ok_or(Error::DescriptorTableAddress)?;
        let used_addr = self
            .mem()
            .get_host_address(guest_addr(config_data.used_ring_addr))
            .ok_or(Error::UsedAddress)?;
        let avail_addr = self
            .mem()
            .get_host_address(guest_addr(config_data.avail_ring_addr))
            .ok_or(Error::AvailAddress)?;
        // A null log address means logging is disabled for this vring.
        let log_addr = match config_data.log_addr {
            None => null(),
            Some(a) => self
                .mem()
                .get_host_address(guest_addr(a))
                .ok_or(Error::LogAddress)?,
        };
        let vring_addr = vhost_vring_addr {
            index: queue_index as u32,
            flags: config_data.flags,
            desc_user_addr: desc_addr as u64,
            used_user_addr: used_addr as u64,
            avail_user_addr: avail_addr as u64,
            log_guest_addr: log_addr as u64,
        };
        // SAFETY: This ioctl is called on a valid vhost fd and has its
        // return value checked.
        let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_ADDR(), &vring_addr) };
        ioctl_result(ret, ())
    }
/// Set the first index to look for available descriptors.
///
/// # Arguments
/// * `queue_index` - Index of the queue to modify.
/// * `num` - Index where available descriptors start.
fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()> {
let vring_state = vhost_vring_state {
index: queue_index as u32,
num: u32::from(base),
};
// This ioctl is called on a valid vhost fd and has its return value checked.
let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_BASE(), &vring_state) };
ioctl_result(ret, ())
}
/// Get a bitmask of supported virtio/vhost features.
fn get_vring_base(&mut self, queue_index: usize) -> Result<u32> {
let vring_state = vhost_vring_state {
index: queue_index as u32,
num: 0,
};
// This ioctl is called on a valid vhost fd and has its return value checked.
let ret = unsafe { ioctl_with_ref(self, VHOST_GET_VRING_BASE(), &vring_state) };
ioctl_result(ret, vring_state.num)
}
/// Set the eventfd to trigger when buffers have been used by the host.
///
/// # Arguments
/// * `queue_index` - Index of the queue to modify.
/// * `fd` - EventFd to trigger.
fn set_vring_call(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
let vring_file = vhost_vring_file {
index: queue_index as u32,
fd: fd.as_raw_fd(),
};
// This ioctl is called on a valid vhost fd and has its return value checked.
let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_CALL(), &vring_file) };
ioctl_result(ret, ())
}
/// Set the eventfd that will be signaled by the guest when buffers are
/// available for the host to process.
///
/// # Arguments
/// * `queue_index` - Index of the queue to modify.
/// * `fd` - EventFd that will be signaled from guest.
fn set_vring_kick(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
let vring_file = vhost_vring_file {
index: queue_index as u32,
fd: fd.as_raw_fd(),
};
// This ioctl is called on a valid vhost fd and has its return value checked.
let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_KICK(), &vring_file) };
ioctl_result(ret, ())
}
/// Set the eventfd to signal an error from the vhost backend.
///
/// # Arguments
/// * `queue_index` - Index of the queue to modify.
/// * `fd` - EventFd that will be signaled from the backend.
fn set_vring_err(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
let vring_file = vhost_vring_file {
index: queue_index as u32,
fd: fd.as_raw_fd(),
};
// This ioctl is called on a valid vhost fd and has its return value checked.
let ret = unsafe { ioctl_with_ref(self, VHOST_SET_VRING_ERR(), &vring_file) };
ioctl_result(ret, ())
}
}

View File

@ -1,405 +0,0 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
//
// Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE-BSD file.
/* Auto-generated by bindgen then manually edited for simplicity */
#![allow(non_upper_case_globals)]
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]
#![allow(missing_docs)]
use std::os::raw;
use {Error, Result};
// ioctl "type" (magic number) used by all vhost ioctls: 0xaf.
pub const VHOST: raw::c_uint = 0xaf;
// Vring flag: enable dirty-page logging for this vring.
pub const VHOST_VRING_F_LOG: raw::c_uint = 0;
// IOTLB access permission values.
pub const VHOST_ACCESS_RO: raw::c_uint = 1;
pub const VHOST_ACCESS_WO: raw::c_uint = 2;
pub const VHOST_ACCESS_RW: raw::c_uint = 3;
// IOTLB message type values.
pub const VHOST_IOTLB_MISS: raw::c_uint = 1;
pub const VHOST_IOTLB_UPDATE: raw::c_uint = 2;
pub const VHOST_IOTLB_INVALIDATE: raw::c_uint = 3;
pub const VHOST_IOTLB_ACCESS_FAIL: raw::c_uint = 4;
pub const VHOST_IOTLB_MSG: raw::c_uint = 1;
pub const VHOST_PAGE_SIZE: raw::c_uint = 4096;
// Same magic as VHOST above, in decimal (175 == 0xaf).
pub const VHOST_VIRTIO: raw::c_uint = 175;
// Vring endianness selectors.
pub const VHOST_VRING_LITTLE_ENDIAN: raw::c_uint = 0;
pub const VHOST_VRING_BIG_ENDIAN: raw::c_uint = 1;
// Feature bit numbers.
pub const VHOST_F_LOG_ALL: raw::c_uint = 26;
pub const VHOST_NET_F_VIRTIO_NET_HDR: raw::c_uint = 27;
pub const VHOST_SCSI_ABI_VERSION: raw::c_uint = 1;
// ioctl definitions mirroring the vhost uapi; the macros generate functions
// returning the encoded ioctl request number.
ioctl_ior_nr!(VHOST_GET_FEATURES, VHOST, 0x00, raw::c_ulonglong);
ioctl_iow_nr!(VHOST_SET_FEATURES, VHOST, 0x00, raw::c_ulonglong);
ioctl_io_nr!(VHOST_SET_OWNER, VHOST, 0x01);
ioctl_io_nr!(VHOST_RESET_OWNER, VHOST, 0x02);
ioctl_iow_nr!(VHOST_SET_MEM_TABLE, VHOST, 0x03, vhost_memory);
ioctl_iow_nr!(VHOST_SET_LOG_BASE, VHOST, 0x04, raw::c_ulonglong);
ioctl_iow_nr!(VHOST_SET_LOG_FD, VHOST, 0x07, raw::c_int);
ioctl_iow_nr!(VHOST_SET_VRING_NUM, VHOST, 0x10, vhost_vring_state);
ioctl_iow_nr!(VHOST_SET_VRING_ADDR, VHOST, 0x11, vhost_vring_addr);
ioctl_iow_nr!(VHOST_SET_VRING_BASE, VHOST, 0x12, vhost_vring_state);
ioctl_iowr_nr!(VHOST_GET_VRING_BASE, VHOST, 0x12, vhost_vring_state);
ioctl_iow_nr!(VHOST_SET_VRING_KICK, VHOST, 0x20, vhost_vring_file);
ioctl_iow_nr!(VHOST_SET_VRING_CALL, VHOST, 0x21, vhost_vring_file);
ioctl_iow_nr!(VHOST_SET_VRING_ERR, VHOST, 0x22, vhost_vring_file);
// NOTE(review): the lowercase `vhost_` prefix looks like a typo for the
// kernel's VHOST_NET_SET_BACKEND (0x30); kept as-is because renaming the
// generated function would break existing callers — confirm before changing.
ioctl_iow_nr!(vhost_SET_BACKEND, VHOST, 0x30, vhost_vring_file);
ioctl_iow_nr!(VHOST_SCSI_SET_ENDPOINT, VHOST, 0x40, vhost_scsi_target);
ioctl_iow_nr!(VHOST_SCSI_CLEAR_ENDPOINT, VHOST, 0x41, vhost_scsi_target);
ioctl_iow_nr!(VHOST_SCSI_GET_ABI_VERSION, VHOST, 0x42, raw::c_int);
ioctl_iow_nr!(VHOST_SCSI_SET_EVENTS_MISSED, VHOST, 0x43, raw::c_uint);
ioctl_iow_nr!(VHOST_SCSI_GET_EVENTS_MISSED, VHOST, 0x44, raw::c_uint);
ioctl_iow_nr!(VHOST_VSOCK_SET_GUEST_CID, VHOST, 0x60, raw::c_ulonglong);
ioctl_iow_nr!(VHOST_VSOCK_SET_RUNNING, VHOST, 0x61, raw::c_int);
// Bindgen-style helper modelling a C flexible array member: a zero-sized field
// placed at the end of a struct whose elements live in memory immediately
// after the struct itself.
#[repr(C)]
#[derive(Default)]
pub struct __IncompleteArrayField<T>(::std::marker::PhantomData<T>);
impl<T> __IncompleteArrayField<T> {
    #[inline]
    pub fn new() -> Self {
        __IncompleteArrayField(::std::marker::PhantomData)
    }
    // Reinterpret the field's own address as a pointer to the trailing
    // elements; safe only when the struct is followed by valid T storage.
    #[inline]
    #[allow(clippy::trivially_copy_pass_by_ref)]
    #[allow(clippy::useless_transmute)]
    pub unsafe fn as_ptr(&self) -> *const T {
        ::std::mem::transmute(self)
    }
    #[inline]
    #[allow(clippy::useless_transmute)]
    pub unsafe fn as_mut_ptr(&mut self) -> *mut T {
        ::std::mem::transmute(self)
    }
    // View `len` trailing elements as a slice; caller must guarantee the
    // backing allocation actually holds that many initialized elements.
    #[inline]
    pub unsafe fn as_slice(&self, len: usize) -> &[T] {
        ::std::slice::from_raw_parts(self.as_ptr(), len)
    }
    #[inline]
    pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [T] {
        ::std::slice::from_raw_parts_mut(self.as_mut_ptr(), len)
    }
}
impl<T> ::std::fmt::Debug for __IncompleteArrayField<T> {
    fn fmt(&self, fmt: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
        fmt.write_str("__IncompleteArrayField")
    }
}
impl<T> ::std::clone::Clone for __IncompleteArrayField<T> {
    #[inline]
    fn clone(&self) -> Self {
        Self::new()
    }
}
impl<T> ::std::marker::Copy for __IncompleteArrayField<T> {}
// C-compatible mirrors of the kernel vhost uapi structures. Their exact
// size/alignment is asserted by the layout tests at the bottom of this file —
// do not reorder or retype fields.
/// Per-vring numeric parameter (queue size, base index, ...).
#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct vhost_vring_state {
    pub index: raw::c_uint,
    pub num: raw::c_uint,
}
/// Associates a file descriptor (eventfd, backend fd) with a vring.
#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct vhost_vring_file {
    pub index: raw::c_uint,
    pub fd: raw::c_int,
}
/// Userspace addresses of the descriptor table and rings for one vring.
#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct vhost_vring_addr {
    pub index: raw::c_uint,
    pub flags: raw::c_uint,
    pub desc_user_addr: raw::c_ulonglong,
    pub used_user_addr: raw::c_ulonglong,
    pub avail_user_addr: raw::c_ulonglong,
    pub log_guest_addr: raw::c_ulonglong,
}
/// One IOTLB translation entry (iova -> uaddr mapping with permissions).
#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct vhost_iotlb_msg {
    pub iova: raw::c_ulonglong,
    pub size: raw::c_ulonglong,
    pub uaddr: raw::c_ulonglong,
    pub perm: raw::c_uchar,
    pub type_: raw::c_uchar,
}
/// Message exchanged with the kernel; the payload is a union padded to 64 bytes.
#[repr(C)]
#[derive(Copy, Clone)]
pub struct vhost_msg {
    pub type_: raw::c_int,
    pub __bindgen_anon_1: vhost_msg__bindgen_ty_1,
}
impl Default for vhost_msg {
    fn default() -> Self {
        // Safe to zero: all fields are plain integers/padding.
        unsafe { ::std::mem::zeroed() }
    }
}
#[repr(C)]
#[derive(Copy, Clone)]
pub union vhost_msg__bindgen_ty_1 {
    pub iotlb: vhost_iotlb_msg,
    pub padding: [raw::c_uchar; 64usize],
    // Forces 8-byte alignment of the union.
    _bindgen_union_align: [u64; 8usize],
}
impl Default for vhost_msg__bindgen_ty_1 {
    fn default() -> Self {
        unsafe { ::std::mem::zeroed() }
    }
}
/// One guest memory region entry used by VHOST_SET_MEM_TABLE.
#[repr(C)]
#[derive(Debug, Default, Copy, Clone)]
pub struct vhost_memory_region {
    pub guest_phys_addr: raw::c_ulonglong,
    pub memory_size: raw::c_ulonglong,
    pub userspace_addr: raw::c_ulonglong,
    pub flags_padding: raw::c_ulonglong,
}
/// Memory table header followed by a flexible array of regions.
#[repr(C)]
#[derive(Debug, Default, Clone)]
pub struct vhost_memory {
    pub nregions: raw::c_uint,
    pub padding: raw::c_uint,
    pub regions: __IncompleteArrayField<vhost_memory_region>,
    // Zero-sized field forcing 8-byte alignment for the trailing regions.
    __force_alignment: [u64; 0],
}
/// Target description for the vhost-scsi endpoint ioctls.
#[repr(C)]
#[derive(Copy, Clone)]
pub struct vhost_scsi_target {
    pub abi_version: raw::c_int,
    pub vhost_wwpn: [raw::c_char; 224usize],
    pub vhost_tpgt: raw::c_ushort,
    pub reserved: raw::c_ushort,
}
impl Default for vhost_scsi_target {
    fn default() -> Self {
        unsafe { ::std::mem::zeroed() }
    }
}
/// Helper to support vhost::set_mem_table().
///
/// Owns a `Vec<vhost_memory>` large enough to hold the `vhost_memory` header
/// plus `entries` trailing `vhost_memory_region` elements (the flexible array),
/// keeping the whole buffer correctly aligned for the ioctl.
pub struct VhostMemory {
    buf: Vec<vhost_memory>,
}
impl VhostMemory {
    // Limit number of regions to u16 to simplify error handling
    pub fn new(entries: u16) -> Self {
        let size = std::mem::size_of::<vhost_memory_region>() * entries as usize;
        // One vhost_memory slot for the header, plus ceil(size / slot) slots
        // for the region array.
        let count = (size + 2 * std::mem::size_of::<vhost_memory>() - 1)
            / std::mem::size_of::<vhost_memory>();
        let mut buf: Vec<vhost_memory> = vec![Default::default(); count];
        buf[0].nregions = u32::from(entries);
        VhostMemory { buf }
    }
    // NOTE(review): `*const char` is a 4-byte Rust char pointer — `*const
    // c_void`/`*const u8` would be the conventional type. Only usable as an
    // opaque address; kept as-is since changing it alters the public interface.
    pub fn as_ptr(&self) -> *const char {
        &self.buf[0] as *const vhost_memory as *const char
    }
    /// Borrow the vhost_memory header (first slot of the buffer).
    pub fn get_header(&self) -> &vhost_memory {
        &self.buf[0]
    }
    /// Return region `index`, or None when out of range.
    pub fn get_region(&self, index: u32) -> Option<&vhost_memory_region> {
        if index >= self.buf[0].nregions {
            return None;
        }
        // Safe because we have allocated enough space nregions
        let regions = unsafe { self.buf[0].regions.as_slice(self.buf[0].nregions as usize) };
        Some(&regions[index as usize])
    }
    /// Overwrite region `index`; fails with InvalidGuestMemory when out of range.
    pub fn set_region(&mut self, index: u32, region: &vhost_memory_region) -> Result<()> {
        if index >= self.buf[0].nregions {
            return Err(Error::InvalidGuestMemory);
        }
        // Safe because we have allocated enough space nregions and checked the index.
        let regions = unsafe { self.buf[0].regions.as_mut_slice(index as usize + 1) };
        regions[index as usize] = *region;
        Ok(())
    }
}
// Layout tests: assert that the hand-maintained bindings above keep the exact
// size and alignment of the kernel's C structures, plus a behavioral check of
// the VhostMemory helper.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn bindgen_test_layout_vhost_vring_state() {
        assert_eq!(
            ::std::mem::size_of::<vhost_vring_state>(),
            8usize,
            concat!("Size of: ", stringify!(vhost_vring_state))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_vring_state>(),
            4usize,
            concat!("Alignment of ", stringify!(vhost_vring_state))
        );
    }
    #[test]
    fn bindgen_test_layout_vhost_vring_file() {
        assert_eq!(
            ::std::mem::size_of::<vhost_vring_file>(),
            8usize,
            concat!("Size of: ", stringify!(vhost_vring_file))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_vring_file>(),
            4usize,
            concat!("Alignment of ", stringify!(vhost_vring_file))
        );
    }
    #[test]
    fn bindgen_test_layout_vhost_vring_addr() {
        assert_eq!(
            ::std::mem::size_of::<vhost_vring_addr>(),
            40usize,
            concat!("Size of: ", stringify!(vhost_vring_addr))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_vring_addr>(),
            8usize,
            concat!("Alignment of ", stringify!(vhost_vring_addr))
        );
    }
    #[test]
    fn bindgen_test_layout_vhost_msg__bindgen_ty_1() {
        assert_eq!(
            ::std::mem::size_of::<vhost_msg__bindgen_ty_1>(),
            64usize,
            concat!("Size of: ", stringify!(vhost_msg__bindgen_ty_1))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_msg__bindgen_ty_1>(),
            8usize,
            concat!("Alignment of ", stringify!(vhost_msg__bindgen_ty_1))
        );
    }
    #[test]
    fn bindgen_test_layout_vhost_msg() {
        assert_eq!(
            ::std::mem::size_of::<vhost_msg>(),
            72usize,
            concat!("Size of: ", stringify!(vhost_msg))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_msg>(),
            8usize,
            concat!("Alignment of ", stringify!(vhost_msg))
        );
    }
    #[test]
    fn bindgen_test_layout_vhost_memory_region() {
        assert_eq!(
            ::std::mem::size_of::<vhost_memory_region>(),
            32usize,
            concat!("Size of: ", stringify!(vhost_memory_region))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_memory_region>(),
            8usize,
            concat!("Alignment of ", stringify!(vhost_memory_region))
        );
    }
    #[test]
    fn bindgen_test_layout_vhost_memory() {
        assert_eq!(
            ::std::mem::size_of::<vhost_memory>(),
            8usize,
            concat!("Size of: ", stringify!(vhost_memory))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_memory>(),
            8usize,
            concat!("Alignment of ", stringify!(vhost_memory))
        );
    }
    #[test]
    fn bindgen_test_layout_vhost_iotlb_msg() {
        assert_eq!(
            ::std::mem::size_of::<vhost_iotlb_msg>(),
            32usize,
            concat!("Size of: ", stringify!(vhost_iotlb_msg))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_iotlb_msg>(),
            8usize,
            concat!("Alignment of ", stringify!(vhost_iotlb_msg))
        );
    }
    #[test]
    fn bindgen_test_layout_vhost_scsi_target() {
        assert_eq!(
            ::std::mem::size_of::<vhost_scsi_target>(),
            232usize,
            concat!("Size of: ", stringify!(vhost_scsi_target))
        );
        assert_eq!(
            ::std::mem::align_of::<vhost_scsi_target>(),
            4usize,
            concat!("Alignment of ", stringify!(vhost_scsi_target))
        );
    }
    // Exercises VhostMemory get/set region bounds checking and round-tripping.
    #[test]
    fn test_vhostmemory() {
        let mut obj = VhostMemory::new(2);
        let region = vhost_memory_region {
            guest_phys_addr: 0x1000u64,
            memory_size: 0x2000u64,
            userspace_addr: 0x300000u64,
            flags_padding: 0u64,
        };
        assert!(obj.get_region(2).is_none());
        {
            let header = obj.get_header();
            assert_eq!(header.nregions, 2u32);
        }
        {
            assert!(obj.set_region(0, &region).is_ok());
            assert!(obj.set_region(1, &region).is_ok());
            assert!(obj.set_region(2, &region).is_err());
        }
        let region1 = obj.get_region(1).unwrap();
        assert_eq!(region1.guest_phys_addr, 0x1000u64);
        assert_eq!(region1.memory_size, 0x2000u64);
        assert_eq!(region1.userspace_addr, 0x300000u64);
    }
}

View File

@ -1,84 +0,0 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 or MIT
//
// Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the THIRD-PARTY file.
//! Kernel-based vsock vhost backend.
use std::fs::{File, OpenOptions};
use std::marker::PhantomData;
use std::os::unix::fs::OpenOptionsExt;
use std::os::unix::io::{AsRawFd, RawFd};
use super::vhost_binding::{VHOST_VSOCK_SET_GUEST_CID, VHOST_VSOCK_SET_RUNNING};
use super::{ioctl_result, Error, Result, VhostKernBackend};
use libc;
use vm_memory::GuestMemory;
use vmm_sys_util::ioctl::ioctl_with_ref;
// Path of the kernel vhost-vsock character device.
const VHOST_PATH: &str = "/dev/vhost-vsock";
/// Handle for running VHOST_VSOCK ioctls.
pub struct Vsock<'a, M: GuestMemory<'a>> {
    // Open handle to /dev/vhost-vsock.
    fd: File,
    // Guest memory used by the VhostKernBackend implementation.
    mem: M,
    _phantomdata: PhantomData<&'a M>, // Get rid of the unused lifetime parameter `'a`.
}
impl<'a, M: GuestMemory<'a>> Vsock<'a, M> {
    /// Open a handle to a new VHOST-VSOCK instance.
    ///
    /// # Arguments
    /// * `mem` - guest memory; cloned into the handle for address validation.
    pub fn new(mem: &M) -> Result<Self> {
        Ok(Vsock {
            fd: OpenOptions::new()
                .read(true)
                .write(true)
                .custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK)
                .open(VHOST_PATH)
                .map_err(Error::VhostOpen)?,
            mem: mem.clone(),
            _phantomdata: PhantomData,
        })
    }
    /// Set the CID for the guest. This number is used for routing all data
    /// destined for programs running in the guest. Each guest on a hypervisor
    /// must have a unique CID.
    ///
    /// # Arguments
    /// * `cid` - CID to assign to the guest
    pub fn set_guest_cid(&self, cid: u64) -> Result<()> {
        // Safe: valid vhost-vsock fd, and the return value is checked.
        let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_GUEST_CID(), &cid) };
        ioctl_result(ret, ())
    }
    /// Tell the VHOST driver to start performing data transfer.
    pub fn start(&self) -> Result<()> {
        self.set_running(true)
    }
    /// Tell the VHOST driver to stop performing data transfer.
    pub fn stop(&self) -> Result<()> {
        self.set_running(false)
    }
    // Toggle the device's running state via VHOST_VSOCK_SET_RUNNING.
    fn set_running(&self, running: bool) -> Result<()> {
        let on: ::std::os::raw::c_int = if running { 1 } else { 0 };
        // Safe: valid vhost-vsock fd, and the return value is checked.
        let ret = unsafe { ioctl_with_ref(&self.fd, VHOST_VSOCK_SET_RUNNING(), &on) };
        ioctl_result(ret, ())
    }
}
impl<'a, M: GuestMemory<'a>> VhostKernBackend<'a> for Vsock<'a, M> {
    type M = M;
    fn mem(&self) -> &Self::M {
        &self.mem
    }
}
impl<'a, M: GuestMemory<'a>> AsRawFd for Vsock<'a, M> {
    fn as_raw_fd(&self) -> RawFd {
        self.fd.as_raw_fd()
    }
}

View File

@ -1,730 +0,0 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Structs for Unix Domain Socket listener and endpoint.
#![allow(dead_code)]
use libc::{c_void, iovec};
use std::io::ErrorKind;
use std::marker::PhantomData;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::{UnixListener, UnixStream};
use std::{mem, slice};
use super::message::*;
use super::sock_ctrl_msg::ScmSocket;
use super::{Error, Result};
/// Unix domain socket listener for accepting incoming connections.
pub struct Listener {
fd: UnixListener,
path: String,
}
impl Listener {
/// Create a unix domain socket listener.
///
/// # Return:
/// * - the new Listener object on success.
/// * - SocketError: failed to create listener socket.
pub fn new(path: &str, unlink: bool) -> Result<Self> {
if unlink {
let _ = std::fs::remove_file(path);
}
let fd = UnixListener::bind(path).map_err(Error::SocketError)?;
Ok(Listener {
fd,
path: path.to_string(),
})
}
/// Accept an incoming connection.
///
/// # Return:
/// * - Some(UnixStream): new UnixStream object if new incoming connection is available.
/// * - None: no incoming connection available.
/// * - SocketError: errors from accept().
pub fn accept(&self) -> Result<Option<UnixStream>> {
loop {
match self.fd.accept() {
Ok((socket, _addr)) => return Ok(Some(socket)),
Err(e) => {
match e.kind() {
// No incoming connection available.
ErrorKind::WouldBlock => return Ok(None),
// New connection closed by peer.
ErrorKind::ConnectionAborted => return Ok(None),
// Interrupted by signals, retry
ErrorKind::Interrupted => continue,
_ => return Err(Error::SocketError(e)),
}
}
}
}
}
/// Change blocking status on the listener.
///
/// # Return:
/// * - () on success.
/// * - SocketError: failure from set_nonblocking().
pub fn set_nonblocking(&self, block: bool) -> Result<()> {
self.fd.set_nonblocking(block).map_err(Error::SocketError)
}
}
impl AsRawFd for Listener {
fn as_raw_fd(&self) -> RawFd {
self.fd.as_raw_fd()
}
}
impl Drop for Listener {
fn drop(&mut self) {
let _ = std::fs::remove_file(self.path.clone());
}
}
/// Unix domain socket endpoint for vhost-user connection.
pub(super) struct Endpoint<R: Req> {
    // Connected stream socket carrying vhost-user messages and fds.
    sock: UnixStream,
    // Pins the request-code type `R` this endpoint speaks.
    _r: PhantomData<R>,
}
impl<R: Req> Endpoint<R> {
    /// Create a new stream by connecting to the server at `path`.
    ///
    /// # Return:
    /// * - the new Endpoint object on success.
    /// * - SocketConnect: failed to connect to peer.
    pub fn connect(path: &str) -> Result<Self> {
        let sock = UnixStream::connect(path).map_err(Error::SocketConnect)?;
        Ok(Self::from_stream(sock))
    }
    /// Create an endpoint from a stream object.
    pub fn from_stream(sock: UnixStream) -> Self {
        Endpoint {
            sock,
            _r: PhantomData,
        }
    }
    /// Sends bytes from scatter-gather vectors over the socket with optional attached file
    /// descriptors.
    ///
    /// # Return:
    /// * - number of bytes sent on success
    /// * - SocketRetry: temporary error caused by signals or short of resources.
    /// * - SocketBroken: the underlying socket is broken.
    /// * - SocketError: other socket related errors.
    pub fn send_iovec(&mut self, iovs: &[&[u8]], fds: Option<&[RawFd]>) -> Result<usize> {
        let rfds = match fds {
            Some(rfds) => rfds,
            _ => &[],
        };
        self.sock.send_with_fds(iovs, rfds).map_err(Into::into)
    }
    /// Sends bytes from a slice over the socket with optional attached file descriptors.
    ///
    /// # Return:
    /// * - number of bytes sent on success
    /// * - SocketRetry: temporary error caused by signals or short of resources.
    /// * - SocketBroken: the underlying socket is broken.
    /// * - SocketError: other socket related errors.
    pub fn send_slice(&mut self, data: &[u8], fds: Option<&[RawFd]>) -> Result<usize> {
        self.send_iovec(&[data], fds)
    }
    /// Sends a header-only message with optional attached file descriptors.
    ///
    /// # Return:
    /// * - () on success
    /// * - SocketRetry: temporary error caused by signals or short of resources.
    /// * - SocketBroken: the underlying socket is broken.
    /// * - SocketError: other socket related errors.
    /// * - PartialMessage: only part of the message was sent.
    pub fn send_header(
        &mut self,
        hdr: &VhostUserMsgHeader<R>,
        fds: Option<&[RawFd]>,
    ) -> Result<()> {
        // Safe because there can't be other mutable reference to hdr.
        let iovs = unsafe {
            [slice::from_raw_parts(
                hdr as *const VhostUserMsgHeader<R> as *const u8,
                mem::size_of::<VhostUserMsgHeader<R>>(),
            )]
        };
        let bytes = self.send_iovec(&iovs[..], fds)?;
        if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
            return Err(Error::PartialMessage);
        }
        Ok(())
    }
    /// Send a message with header and body. Optional file descriptors may be attached to
    /// the message.
    ///
    /// # Return:
    /// * - () on success
    /// * - SocketRetry: temporary error caused by signals or short of resources.
    /// * - SocketBroken: the underlying socket is broken.
    /// * - SocketError: other socket related errors.
    /// * - PartialMessage: only part of the message was sent.
    pub fn send_message<T: Sized>(
        &mut self,
        hdr: &VhostUserMsgHeader<R>,
        body: &T,
        fds: Option<&[RawFd]>,
    ) -> Result<()> {
        // Safe because there can't be other mutable reference to hdr and body.
        let iovs = unsafe {
            [
                slice::from_raw_parts(
                    hdr as *const VhostUserMsgHeader<R> as *const u8,
                    mem::size_of::<VhostUserMsgHeader<R>>(),
                ),
                slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()),
            ]
        };
        let bytes = self.send_iovec(&iovs[..], fds)?;
        if bytes != mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() {
            return Err(Error::PartialMessage);
        }
        Ok(())
    }
    /// Send a message with header, body and payload. Optional file descriptors
    /// may also be attached to the message.
    ///
    /// # Return:
    /// * - () on success
    /// * - SocketRetry: temporary error caused by signals or short of resources.
    /// * - SocketBroken: the underlying socket is broken.
    /// * - SocketError: other socket related errors.
    /// * - OversizedMsg: message size is too big.
    /// * - PartialMessage: only part of the message was sent.
    /// * - IncorrectFds: wrong number of attached fds.
    pub fn send_message_with_payload<T: Sized, P: Sized>(
        &mut self,
        hdr: &VhostUserMsgHeader<R>,
        body: &T,
        payload: &[P],
        fds: Option<&[RawFd]>,
    ) -> Result<()> {
        let len = payload.len() * mem::size_of::<P>();
        if len > MAX_MSG_SIZE - mem::size_of::<T>() {
            return Err(Error::OversizedMsg);
        }
        if let Some(fd_arr) = fds {
            if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES {
                return Err(Error::IncorrectFds);
            }
        }
        // Safe because there can't be other mutable reference to hdr, body and payload.
        let iovs = unsafe {
            [
                slice::from_raw_parts(
                    hdr as *const VhostUserMsgHeader<R> as *const u8,
                    mem::size_of::<VhostUserMsgHeader<R>>(),
                ),
                slice::from_raw_parts(body as *const T as *const u8, mem::size_of::<T>()),
                slice::from_raw_parts(payload.as_ptr() as *const u8, len),
            ]
        };
        let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>() + len;
        let len = self.send_iovec(&iovs, fds)?;
        if len != total {
            return Err(Error::PartialMessage);
        }
        Ok(())
    }
    /// Reads bytes from the socket into the given scatter/gather vectors.
    ///
    /// # Return:
    /// * - (number of bytes received, buf) on success
    /// * - SocketRetry: temporary error caused by signals or short of resources.
    /// * - SocketBroken: the underlying socket is broken.
    /// * - SocketError: other socket related errors.
    pub fn recv_data(&mut self, len: usize) -> Result<(usize, Vec<u8>)> {
        let mut rbuf = vec![0u8; len];
        let mut iovs = [iovec {
            iov_base: rbuf.as_mut_ptr() as *mut c_void,
            iov_len: len,
        }];
        let (bytes, _) = self.sock.recv_with_fds(&mut iovs, &mut [])?;
        Ok((bytes, rbuf))
    }
    /// Reads bytes from the socket into the given scatter/gather vectors with optional attached
    /// file descriptors.
    ///
    /// The underlying communication channel is a Unix domain socket in STREAM mode. It's a little
    /// tricky to pass file descriptors through such a communication channel. Let's assume that a
    /// sender sending a message with some file descriptors attached. To successfully receive those
    /// attached file descriptors, the receiver must obey following rules:
    ///   1) file descriptors are attached to a message.
    ///   2) message(packet) boundaries must be respected on the receive side.
    /// In other words, recvmsg() operations must not cross the packet boundary, otherwise the
    /// attached file descriptors will get lost.
    ///
    /// # Return:
    /// * - (number of bytes received, [received fds]) on success
    /// * - SocketRetry: temporary error caused by signals or short of resources.
    /// * - SocketBroken: the underlying socket is broken.
    /// * - SocketError: other socket related errors.
    pub fn recv_into_iovec(&mut self, iovs: &mut [iovec]) -> Result<(usize, Option<Vec<RawFd>>)> {
        let mut fd_array = vec![0; MAX_ATTACHED_FD_ENTRIES];
        let (bytes, fds) = self.sock.recv_with_fds(iovs, &mut fd_array)?;
        // Trim the fd buffer down to the count actually received; zero means none.
        let rfds = match fds {
            0 => None,
            n => {
                let mut fds = Vec::with_capacity(n);
                fds.extend_from_slice(&fd_array[0..n]);
                Some(fds)
            }
        };
        Ok((bytes, rfds))
    }
    /// Reads bytes from the socket into a new buffer with optional attached
    /// file descriptors. Received file descriptors are set close-on-exec.
    ///
    /// # Return:
    /// * - (number of bytes received, buf, [received fds]) on success.
    /// * - SocketRetry: temporary error caused by signals or short of resources.
    /// * - SocketBroken: the underlying socket is broken.
    /// * - SocketError: other socket related errors.
    pub fn recv_into_buf(
        &mut self,
        buf_size: usize,
    ) -> Result<(usize, Vec<u8>, Option<Vec<RawFd>>)> {
        let mut buf = vec![0u8; buf_size];
        let (bytes, rfds) = {
            let mut iovs = [iovec {
                iov_base: buf.as_mut_ptr() as *mut c_void,
                iov_len: buf_size,
            }];
            self.recv_into_iovec(&mut iovs)?
        };
        Ok((bytes, buf, rfds))
    }
    /// Receive a header-only message with optional attached file descriptors.
    /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
    /// accepted and all other file descriptors will be discarded silently.
    ///
    /// # Return:
    /// * - (message header, [received fds]) on success.
    /// * - SocketRetry: temporary error caused by signals or short of resources.
    /// * - SocketBroken: the underlying socket is broken.
    /// * - SocketError: other socket related errors.
    /// * - PartialMessage: received a partial message.
    /// * - InvalidMessage: received an invalid message.
    pub fn recv_header(&mut self) -> Result<(VhostUserMsgHeader<R>, Option<Vec<RawFd>>)> {
        let mut hdr = VhostUserMsgHeader::default();
        let mut iovs = [iovec {
            iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
            iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
        }];
        let (bytes, rfds) = self.recv_into_iovec(&mut iovs[..])?;
        if bytes != mem::size_of::<VhostUserMsgHeader<R>>() {
            return Err(Error::PartialMessage);
        } else if !hdr.is_valid() {
            return Err(Error::InvalidMessage);
        }
        Ok((hdr, rfds))
    }
    /// Receive a message with optional attached file descriptors.
    /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
    /// accepted and all other file descriptors will be discarded silently.
    ///
    /// # Return:
    /// * - (message header, message body, [received fds]) on success.
    /// * - SocketRetry: temporary error caused by signals or short of resources.
    /// * - SocketBroken: the underlying socket is broken.
    /// * - SocketError: other socket related errors.
    /// * - PartialMessage: received a partial message.
    /// * - InvalidMessage: received an invalid message.
    pub fn recv_body<T: Sized + Default + VhostUserMsgValidator>(
        &mut self,
    ) -> Result<(VhostUserMsgHeader<R>, T, Option<Vec<RawFd>>)> {
        let mut hdr = VhostUserMsgHeader::default();
        let mut body: T = Default::default();
        let mut iovs = [
            iovec {
                iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
                iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
            },
            iovec {
                iov_base: (&mut body as *mut T) as *mut c_void,
                iov_len: mem::size_of::<T>(),
            },
        ];
        let (bytes, rfds) = self.recv_into_iovec(&mut iovs[..])?;
        let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
        if bytes != total {
            return Err(Error::PartialMessage);
        } else if !hdr.is_valid() || !body.is_valid() {
            return Err(Error::InvalidMessage);
        }
        Ok((hdr, body, rfds))
    }
    /// Receive a message with header and optional content. Callers need to
    /// pre-allocate a big enough buffer to receive the message body and
    /// optional payload. If there are attached file descriptors associated
    /// with the message, the first MAX_ATTACHED_FD_ENTRIES file descriptors
    /// will be accepted and all other file descriptors will be discarded
    /// silently.
    ///
    /// # Return:
    /// * - (message header, message size, [received fds]) on success.
    /// * - SocketRetry: temporary error caused by signals or short of resources.
    /// * - SocketBroken: the underlying socket is broken.
    /// * - SocketError: other socket related errors.
    /// * - PartialMessage: received a partial message.
    /// * - InvalidMessage: received an invalid message.
    pub fn recv_body_into_buf(
        &mut self,
        buf: &mut [u8],
    ) -> Result<(VhostUserMsgHeader<R>, usize, Option<Vec<RawFd>>)> {
        let mut hdr = VhostUserMsgHeader::default();
        let mut iovs = [
            iovec {
                iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
                iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
            },
            iovec {
                iov_base: buf.as_mut_ptr() as *mut c_void,
                iov_len: buf.len(),
            },
        ];
        let (bytes, rfds) = self.recv_into_iovec(&mut iovs[..])?;
        if bytes < mem::size_of::<VhostUserMsgHeader<R>>() {
            return Err(Error::PartialMessage);
        } else if !hdr.is_valid() {
            return Err(Error::InvalidMessage);
        }
        Ok((hdr, bytes - mem::size_of::<VhostUserMsgHeader<R>>(), rfds))
    }
    /// Receive a message with optional payload and attached file descriptors.
    /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
    /// accepted and all other file descriptors will be discarded silently.
    ///
    /// # Return:
    /// * - (message header, message body, size of payload, [received fds]) on success.
    /// * - SocketRetry: temporary error caused by signals or short of resources.
    /// * - SocketBroken: the underlying socket is broken.
    /// * - SocketError: other socket related errors.
    /// * - PartialMessage: received a partial message.
    /// * - InvalidMessage: received an invalid message.
    #[cfg_attr(feature = "cargo-clippy", allow(clippy::type_complexity))]
    pub fn recv_payload_into_buf<T: Sized + Default + VhostUserMsgValidator>(
        &mut self,
        buf: &mut [u8],
    ) -> Result<(VhostUserMsgHeader<R>, T, usize, Option<Vec<RawFd>>)> {
        let mut hdr = VhostUserMsgHeader::default();
        let mut body: T = Default::default();
        let mut iovs = [
            iovec {
                iov_base: (&mut hdr as *mut VhostUserMsgHeader<R>) as *mut c_void,
                iov_len: mem::size_of::<VhostUserMsgHeader<R>>(),
            },
            iovec {
                iov_base: (&mut body as *mut T) as *mut c_void,
                iov_len: mem::size_of::<T>(),
            },
            iovec {
                iov_base: buf.as_mut_ptr() as *mut c_void,
                iov_len: buf.len(),
            },
        ];
        let (bytes, rfds) = self.recv_into_iovec(&mut iovs[..])?;
        let total = mem::size_of::<VhostUserMsgHeader<R>>() + mem::size_of::<T>();
        if bytes < total {
            return Err(Error::PartialMessage);
        } else if !hdr.is_valid() || !body.is_valid() {
            return Err(Error::InvalidMessage);
        }
        Ok((hdr, body, bytes - total, rfds))
    }
    /// Close all raw file descriptors.
    pub fn close_rfds(rfds: Option<Vec<RawFd>>) {
        if let Some(fds) = rfds {
            for fd in fds {
                // safe because the rawfds are valid and we don't care about the result.
                let _ = unsafe { libc::close(fd) };
            }
        }
    }
}
impl<T: Req> AsRawFd for Endpoint<T> {
fn as_raw_fd(&self) -> RawFd {
self.sock.as_raw_fd()
}
}
#[cfg(test)]
mod tests {
    extern crate tempfile;
    use self::tempfile::tempfile;
    use super::*;
    use std::fs::File;
    use std::io::{Read, Seek, SeekFrom, Write};
    use std::os::unix::io::FromRawFd;
    // Unix domain socket paths used by the individual test cases below.
    // Each test uses its own path so the cases don't interfere.
    const UNIX_SOCKET_LISTENER: &'static str = "/tmp/vhost_user_test_rust_listener";
    const UNIX_SOCKET_CONNECTION: &'static str = "/tmp/vhost_user_test_rust_connection";
    const UNIX_SOCKET_DATA: &'static str = "/tmp/vhost_user_test_rust_data";
    const UNIX_SOCKET_FD: &'static str = "/tmp/vhost_user_test_rust_fd";
    const UNIX_SOCKET_SEND: &'static str = "/tmp/vhost_user_test_rust_send";
    // Creating a listener on a fresh socket path must succeed.
    #[test]
    fn create_listener() {
        let _ = Listener::new(UNIX_SOCKET_LISTENER, true).unwrap();
    }
    // A non-blocking accept with no pending connection returns Ok(None)
    // instead of blocking or failing.
    #[test]
    fn accept_connection() {
        let listener = Listener::new(UNIX_SOCKET_CONNECTION, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        // accept on a fd without incoming connection
        let conn = listener.accept().unwrap();
        assert!(conn.is_none());
    }
    // Plain data transfer: a 4-byte message can be read in one go or split
    // across two partial reads on the receiving side.
    #[test]
    #[ignore]
    fn send_data() {
        let listener = Listener::new(UNIX_SOCKET_DATA, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_DATA).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);
        let buf1 = vec![0x1, 0x2, 0x3, 0x4];
        let mut len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let (bytes, buf2, _) = slave.recv_into_buf(0x1000).unwrap();
        assert_eq!(bytes, 4);
        assert_eq!(&buf1[..], &buf2[..bytes]);
        len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf2[..]);
        let (bytes, buf2, _) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
    }
    // Exercises which send/receive patterns preserve ancillary fds: fds are
    // delivered with the first recv that consumes part of the message they
    // were attached to, and are lost otherwise.
    #[test]
    #[ignore]
    fn send_fd() {
        let listener = Listener::new(UNIX_SOCKET_FD, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_FD).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);
        let mut fd = tempfile().unwrap();
        write!(fd, "test").unwrap();
        // Normal case for sending/receiving file descriptors
        let buf1 = vec![0x1, 0x2, 0x3, 0x4];
        let len = master
            .send_slice(&buf1[..], Some(&[fd.as_raw_fd()]))
            .unwrap();
        assert_eq!(len, 4);
        let (bytes, buf2, rfds) = slave.recv_into_buf(4).unwrap();
        assert_eq!(bytes, 4);
        assert_eq!(&buf1[..], &buf2[..]);
        assert!(rfds.is_some());
        let fds = rfds.unwrap();
        {
            assert_eq!(fds.len(), 1);
            let mut file = unsafe { File::from_raw_fd(fds[0]) };
            let mut content = String::new();
            file.seek(SeekFrom::Start(0)).unwrap();
            file.read_to_string(&mut content).unwrap();
            assert_eq!(content, "test");
        }
        // Following communication pattern should work:
        // Sending side: data(header, body) with fds
        // Receiving side: data(header) with fds, data(body)
        let len = master
            .send_slice(
                &buf1[..],
                Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
            )
            .unwrap();
        assert_eq!(len, 4);
        let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf2[..]);
        assert!(rfds.is_some());
        let fds = rfds.unwrap();
        {
            assert_eq!(fds.len(), 3);
            let mut file = unsafe { File::from_raw_fd(fds[1]) };
            let mut content = String::new();
            file.seek(SeekFrom::Start(0)).unwrap();
            file.read_to_string(&mut content).unwrap();
            assert_eq!(content, "test");
        }
        let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
        assert!(rfds.is_none());
        // Following communication pattern should not work:
        // Sending side: data(header, body) with fds
        // Receiving side: data(header), data(body) with fds
        let len = master
            .send_slice(
                &buf1[..],
                Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
            )
            .unwrap();
        assert_eq!(len, 4);
        let (bytes, buf4) = slave.recv_data(2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf4[..]);
        let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
        assert!(rfds.is_none());
        // Following communication pattern should work:
        // Sending side: data, data with fds
        // Receiving side: data, data with fds
        let len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let len = master
            .send_slice(
                &buf1[..],
                Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
            )
            .unwrap();
        assert_eq!(len, 4);
        let (bytes, buf2, rfds) = slave.recv_into_buf(0x4).unwrap();
        assert_eq!(bytes, 4);
        assert_eq!(&buf1[..], &buf2[..]);
        assert!(rfds.is_none());
        let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[..2], &buf2[..]);
        assert!(rfds.is_some());
        let fds = rfds.unwrap();
        {
            assert_eq!(fds.len(), 3);
            let mut file = unsafe { File::from_raw_fd(fds[1]) };
            let mut content = String::new();
            file.seek(SeekFrom::Start(0)).unwrap();
            file.read_to_string(&mut content).unwrap();
            assert_eq!(content, "test");
        }
        let (bytes, buf2, rfds) = slave.recv_into_buf(0x2).unwrap();
        assert_eq!(bytes, 2);
        assert_eq!(&buf1[2..], &buf2[..]);
        assert!(rfds.is_none());
        // Following communication pattern should not work:
        // Sending side: data1, data2 with fds
        // Receiving side: data + partial of data2, left of data2 with fds
        let len = master.send_slice(&buf1[..], None).unwrap();
        assert_eq!(len, 4);
        let len = master
            .send_slice(
                &buf1[..],
                Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
            )
            .unwrap();
        assert_eq!(len, 4);
        let (bytes, _) = slave.recv_data(5).unwrap();
        assert_eq!(bytes, 5);
        let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
        assert_eq!(bytes, 3);
        assert!(rfds.is_none());
        // If the target fd array is too small, extra file descriptors will get lost.
        let len = master
            .send_slice(
                &buf1[..],
                Some(&[fd.as_raw_fd(), fd.as_raw_fd(), fd.as_raw_fd()]),
            )
            .unwrap();
        assert_eq!(len, 4);
        let (bytes, _, rfds) = slave.recv_into_buf(0x4).unwrap();
        assert_eq!(bytes, 4);
        assert!(rfds.is_some());
        Endpoint::<MasterReq>::close_rfds(rfds);
        Endpoint::<MasterReq>::close_rfds(None);
    }
    // Round-trips a full header+body message and a header-only message
    // between master and slave endpoints.
    #[test]
    #[ignore]
    fn send_recv() {
        let listener = Listener::new(UNIX_SOCKET_SEND, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Endpoint::<MasterReq>::connect(UNIX_SOCKET_SEND).unwrap();
        let sock = listener.accept().unwrap().unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(sock);
        let mut hdr1 =
            VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, mem::size_of::<u64>() as u32);
        hdr1.set_need_reply(true);
        let features1 = 0x1u64;
        master.send_message(&hdr1, &features1, None).unwrap();
        let mut features2 = 0u64;
        // View the u64 as a byte buffer so recv_body_into_buf can fill it.
        let slice = unsafe {
            slice::from_raw_parts_mut(
                (&mut features2 as *mut u64) as *mut u8,
                mem::size_of::<u64>(),
            )
        };
        let (hdr2, bytes, rfds) = slave.recv_body_into_buf(slice).unwrap();
        assert_eq!(hdr1, hdr2);
        assert_eq!(bytes, 8);
        assert_eq!(features1, features2);
        assert!(rfds.is_none());
        master.send_header(&hdr1, None).unwrap();
        let (hdr2, rfds) = slave.recv_header().unwrap();
        assert_eq!(hdr1, hdr2);
        assert!(rfds.is_none());
    }
}

View File

@ -1,250 +0,0 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::message::*;
use super::*;
use std::os::unix::io::RawFd;
/// Number of queues supported by the dummy slave handler.
pub const MAX_QUEUE_NUM: usize = 2;
/// Upper bound accepted for vring size and vring base values.
pub const MAX_VRING_NUM: usize = 256;
/// Virtio feature bits reported by the dummy slave's get_features().
pub const VIRTIO_FEATURES: u64 = 0x40000003;
/// A dummy vhost-user slave request handler that records the state pushed to
/// it by the master, for exercising the master/slave message paths.
#[derive(Default)]
pub struct DummySlaveReqHandler {
    // True once the master has claimed the session via set_owner().
    pub owned: bool,
    // True once set_features() has been accepted (it may only happen once).
    pub features_acked: bool,
    // Virtio feature bits acknowledged by the master.
    pub acked_features: u64,
    // Vhost-user protocol feature bits acknowledged by the master.
    pub acked_protocol_features: u64,
    // Number of queues this handler accepts vring requests for.
    pub queue_num: usize,
    // Per-queue vring size, set by set_vring_num().
    pub vring_num: [u32; MAX_QUEUE_NUM],
    // Per-queue vring base offset, set by set_vring_base().
    pub vring_base: [u32; MAX_QUEUE_NUM],
    // Per-queue call/kick/error file descriptors received from the master.
    pub call_fd: [Option<RawFd>; MAX_QUEUE_NUM],
    pub kick_fd: [Option<RawFd>; MAX_QUEUE_NUM],
    pub err_fd: [Option<RawFd>; MAX_QUEUE_NUM],
    // Per-queue running state (started on kick, stopped on get_vring_base).
    pub vring_started: [bool; MAX_QUEUE_NUM],
    // Per-queue enabled state, toggled by set_vring_enable().
    pub vring_enabled: [bool; MAX_QUEUE_NUM],
}
impl DummySlaveReqHandler {
pub fn new() -> Self {
DummySlaveReqHandler {
queue_num: MAX_QUEUE_NUM,
..Default::default()
}
}
}
impl VhostUserSlaveReqHandler for DummySlaveReqHandler {
    /// Accept session ownership; only one master may own the session.
    fn set_owner(&mut self) -> Result<()> {
        if self.owned {
            return Err(Error::InvalidOperation);
        }
        self.owned = true;
        Ok(())
    }
    /// Drop ownership and reset all negotiated feature state.
    fn reset_owner(&mut self) -> Result<()> {
        self.owned = false;
        self.features_acked = false;
        self.acked_features = 0;
        self.acked_protocol_features = 0;
        Ok(())
    }
    /// Report the virtio feature bits this dummy slave supports.
    fn get_features(&mut self) -> Result<u64> {
        Ok(VIRTIO_FEATURES)
    }
    /// Record the virtio features acked by the master. May only be called
    /// once per session, after set_owner(), and only with supported bits.
    fn set_features(&mut self, features: u64) -> Result<()> {
        if !self.owned {
            return Err(Error::InvalidOperation);
        } else if self.features_acked {
            return Err(Error::InvalidOperation);
        } else if (features & !VIRTIO_FEATURES) != 0 {
            return Err(Error::InvalidParam);
        }
        self.acked_features = features;
        self.features_acked = true;
        // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated,
        // the ring is initialized in an enabled state.
        // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated,
        // the ring is initialized in a disabled state. Client must not
        // pass data to/from the backend until ring is enabled by
        // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has
        // been disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0.
        let vring_enabled =
            self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0;
        for enabled in &mut self.vring_enabled {
            *enabled = vring_enabled;
        }
        Ok(())
    }
    /// Report all protocol features as supported.
    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
        Ok(VhostUserProtocolFeatures::all())
    }
    fn set_protocol_features(&mut self, features: u64) -> Result<()> {
        // Note: slave that reported VHOST_USER_F_PROTOCOL_FEATURES must
        // support this message even before VHOST_USER_SET_FEATURES was
        // called.
        // What happens if the master calls set_features() with
        // VHOST_USER_F_PROTOCOL_FEATURES cleared after calling this
        // interface?
        self.acked_protocol_features = features;
        Ok(())
    }
    fn set_mem_table(&mut self, _ctx: &[VhostUserMemoryRegion], _fds: &[RawFd]) -> Result<()> {
        // TODO
        Ok(())
    }
    /// Report the number of queues supported by this handler.
    fn get_queue_num(&mut self) -> Result<u64> {
        Ok(MAX_QUEUE_NUM as u64)
    }
    /// Record the vring size for a queue; zero or oversized values are rejected.
    fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()> {
        if index as usize >= self.queue_num || num == 0 || num as usize > MAX_VRING_NUM {
            return Err(Error::InvalidParam);
        }
        self.vring_num[index as usize] = num;
        Ok(())
    }
    /// Validate the queue index; the addresses themselves are ignored here.
    fn set_vring_addr(
        &mut self,
        index: u32,
        _flags: VhostUserVringAddrFlags,
        _descriptor: u64,
        _used: u64,
        _available: u64,
        _log: u64,
    ) -> Result<()> {
        if index as usize >= self.queue_num {
            return Err(Error::InvalidParam);
        }
        Ok(())
    }
    /// Record the base offset in the available vring for a queue.
    fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()> {
        if index as usize >= self.queue_num || base as usize >= MAX_VRING_NUM {
            return Err(Error::InvalidParam);
        }
        self.vring_base[index as usize] = base;
        Ok(())
    }
    fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState> {
        if index as usize >= self.queue_num {
            return Err(Error::InvalidParam);
        }
        // Quotation from vhost-user spec:
        // Client must start ring upon receiving a kick (that is, detecting
        // that file descriptor is readable) on the descriptor specified by
        // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
        // VHOST_USER_GET_VRING_BASE.
        self.vring_started[index as usize] = false;
        Ok(VhostUserVringState::new(
            index,
            self.vring_base[index as usize],
        ))
    }
    fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> {
        // Note: the duplicated `index > queue_num` clause that used to sit
        // here was redundant with the `>=` test and has been dropped.
        if index as usize >= self.queue_num {
            return Err(Error::InvalidParam);
        }
        if self.kick_fd[index as usize].is_some() {
            // Close file descriptor set by previous operations.
            let _ = unsafe { libc::close(self.kick_fd[index as usize].unwrap()) };
        }
        self.kick_fd[index as usize] = fd;
        // Quotation from vhost-user spec:
        // Client must start ring upon receiving a kick (that is, detecting
        // that file descriptor is readable) on the descriptor specified by
        // VHOST_USER_SET_VRING_KICK, and stop ring upon receiving
        // VHOST_USER_GET_VRING_BASE.
        //
        // So we should add fd to event monitor(select, poll, epoll) here.
        self.vring_started[index as usize] = true;
        Ok(())
    }
    fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> {
        if index as usize >= self.queue_num {
            return Err(Error::InvalidParam);
        }
        if self.call_fd[index as usize].is_some() {
            // Close file descriptor set by previous operations.
            let _ = unsafe { libc::close(self.call_fd[index as usize].unwrap()) };
        }
        self.call_fd[index as usize] = fd;
        Ok(())
    }
    fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()> {
        if index as usize >= self.queue_num {
            return Err(Error::InvalidParam);
        }
        if self.err_fd[index as usize].is_some() {
            // Close file descriptor set by previous operations.
            let _ = unsafe { libc::close(self.err_fd[index as usize].unwrap()) };
        }
        self.err_fd[index as usize] = fd;
        Ok(())
    }
    fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()> {
        // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
        // has been negotiated.
        if self.acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
            return Err(Error::InvalidOperation);
        } else if index as usize >= self.queue_num {
            return Err(Error::InvalidParam);
        }
        // Slave must not pass data to/from the backend until ring is
        // enabled by VHOST_USER_SET_VRING_ENABLE with parameter 1,
        // or after it has been disabled by VHOST_USER_SET_VRING_ENABLE
        // with parameter 0.
        self.vring_enabled[index as usize] = enable;
        Ok(())
    }
    fn get_config(
        &mut self,
        offset: u32,
        size: u32,
        _flags: VhostUserConfigFlags,
    ) -> Result<Vec<u8>> {
        // CONFIG is a *protocol* feature bit, so it must be tested against
        // the acked protocol features, not the virtio features.
        if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
            return Err(Error::InvalidOperation);
        } else if offset < VHOST_USER_CONFIG_OFFSET
            || offset >= VHOST_USER_CONFIG_SIZE
            || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
            || size + offset > VHOST_USER_CONFIG_SIZE
        {
            return Err(Error::InvalidParam);
        }
        Ok(vec![0xa5; size as usize])
    }
    fn set_config(&mut self, offset: u32, buf: &[u8], _flags: VhostUserConfigFlags) -> Result<()> {
        let size = buf.len() as u32;
        // Same as get_config(): CONFIG lives in the protocol feature space.
        if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
            return Err(Error::InvalidOperation);
        } else if offset < VHOST_USER_CONFIG_OFFSET
            || offset >= VHOST_USER_CONFIG_SIZE
            || size > VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET
            || size + offset > VHOST_USER_CONFIG_SIZE
        {
            return Err(Error::InvalidParam);
        }
        Ok(())
    }
}

View File

@ -1,784 +0,0 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Traits and Struct for vhost-user master.
use std::mem;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::sync::{Arc, Mutex};
use vmm_sys_util::eventfd::EventFd;
use super::connection::Endpoint;
use super::message::*;
use super::{Error as VhostUserError, Result as VhostUserResult};
use crate::backend::{VhostBackend, VhostUserMemoryRegionInfo, VringConfigData};
use crate::{Error, Result};
/// Trait for vhost-user master to provide extra methods not covered by the VhostBackend yet.
pub trait VhostUserMaster: VhostBackend {
    /// Get the protocol feature bitmask from the underlying vhost implementation.
    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
    /// Enable protocol features in the underlying vhost implementation.
    fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()>;
    /// Query how many queues the backend supports.
    fn get_queue_num(&mut self) -> Result<u64>;
    /// Signal slave to enable or disable corresponding vring.
    ///
    /// Slave must not pass data to/from the backend until ring is enabled by
    /// VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been
    /// disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0.
    fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()>;
    /// Fetch the contents of the virtio device configuration space.
    ///
    /// `buf` is the master-side view of the config region and is sent as the
    /// request payload; the slave's view comes back in the returned payload.
    fn get_config(
        &mut self,
        offset: u32,
        size: u32,
        flags: VhostUserConfigFlags,
        buf: &[u8],
    ) -> Result<(VhostUserConfig, VhostUserConfigPayload)>;
    /// Change the virtio device configuration space. It also can be used for live migration on the
    /// destination host to set readonly configuration space fields.
    fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()>;
    /// Setup slave communication channel.
    fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()>;
}
// Helper to wrap a vhost-user protocol error into the crate-level error type.
fn error_code<T>(err: VhostUserError) -> Result<T> {
    Err(Error::VhostUserProtocol(err))
}
/// Struct for the vhost-user master endpoint.
#[derive(Clone)]
pub struct Master {
    // Shared internal state; cloned `Master` handles refer to the same
    // underlying connection and cached feature state.
    node: Arc<Mutex<MasterInternal>>,
}
impl Master {
    /// Wrap an already-connected endpoint into a new `Master` instance with
    /// all cached feature state zeroed.
    fn new(ep: Endpoint<MasterReq>, max_queue_num: u64) -> Self {
        let internal = MasterInternal {
            main_sock: ep,
            virtio_features: 0,
            acked_virtio_features: 0,
            protocol_features: 0,
            acked_protocol_features: 0,
            protocol_features_ready: false,
            max_queue_num,
            error: None,
        };
        Master {
            node: Arc::new(Mutex::new(internal)),
        }
    }
    /// Create a new instance from a Unix stream socket.
    pub fn from_stream(sock: UnixStream, max_queue_num: u64) -> Self {
        Self::new(Endpoint::<MasterReq>::from_stream(sock), max_queue_num)
    }
    /// Create a new vhost-user master endpoint.
    ///
    /// Will retry as the backend may not be ready to accept the connection.
    ///
    /// # Arguments
    /// * `path` - path of Unix domain socket listener to connect to
    pub fn connect(path: &str, max_queue_num: u64) -> Result<Self> {
        // Up to 5 retries (100ms apart) when the backend refuses the
        // connection because it is not listening yet.
        let mut attempts_left = 5;
        let endpoint = loop {
            match Endpoint::<MasterReq>::connect(path) {
                Ok(ep) => break ep,
                Err(VhostUserError::SocketConnect(ref why))
                    if why.kind() == std::io::ErrorKind::ConnectionRefused
                        && attempts_left > 0 =>
                {
                    attempts_left -= 1;
                    std::thread::sleep(std::time::Duration::from_millis(100));
                }
                Err(e) => return Err(e.into()),
            }
        };
        Ok(Self::new(endpoint, max_queue_num))
    }
}
impl VhostBackend for Master {
    /// Get from the underlying vhost implementation the feature bitmask.
    /// The result is also cached in `virtio_features`.
    fn get_features(&mut self) -> Result<u64> {
        let mut node = self.node.lock().unwrap();
        let hdr = node.send_request_header(MasterReq::GET_FEATURES, None)?;
        let val = node.recv_reply::<VhostUserU64>(&hdr)?;
        node.virtio_features = val.value;
        Ok(node.virtio_features)
    }
    /// Enable features in the underlying vhost implementation using a bitmask.
    fn set_features(&mut self, features: u64) -> Result<()> {
        let mut node = self.node.lock().unwrap();
        let val = VhostUserU64::new(features);
        let _ = node.send_request_with_body(MasterReq::SET_FEATURES, &val, None)?;
        // Don't wait for ACK here because the protocol feature negotiation process hasn't been
        // completed yet.
        node.acked_virtio_features = features & node.virtio_features;
        Ok(())
    }
    /// Set the current Master as an owner of the session.
    fn set_owner(&mut self) -> Result<()> {
        // We unwrap() the return value to assert that we are not expecting threads to ever fail
        // while holding the lock.
        let mut node = self.node.lock().unwrap();
        let _ = node.send_request_header(MasterReq::SET_OWNER, None)?;
        // Don't wait for ACK here because the protocol feature negotiation process hasn't been
        // completed yet.
        Ok(())
    }
    /// Release session ownership on the slave side.
    fn reset_owner(&mut self) -> Result<()> {
        let mut node = self.node.lock().unwrap();
        let _ = node.send_request_header(MasterReq::RESET_OWNER, None)?;
        // Don't wait for ACK here because the protocol feature negotiation process hasn't been
        // completed yet.
        Ok(())
    }
    /// Set the memory map regions on the slave so it can translate the vring
    /// addresses. In the ancillary data there is an array of file descriptors
    fn set_mem_table(&mut self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
        if regions.is_empty() || regions.len() > MAX_ATTACHED_FD_ENTRIES {
            return error_code(VhostUserError::InvalidParam);
        }
        let mut ctx = VhostUserMemoryContext::new();
        for region in regions.iter() {
            // Each region must have a size and a valid backing mmap fd.
            if region.memory_size == 0 || region.mmap_handle < 0 {
                return error_code(VhostUserError::InvalidParam);
            }
            let reg = VhostUserMemoryRegion {
                guest_phys_addr: region.guest_phys_addr,
                memory_size: region.memory_size,
                user_addr: region.userspace_addr,
                mmap_offset: region.mmap_offset,
            };
            ctx.append(&reg, region.mmap_handle);
        }
        let mut node = self.node.lock().unwrap();
        let body = VhostUserMemory::new(ctx.regions.len() as u32);
        let hdr = node.send_request_with_payload(
            MasterReq::SET_MEM_TABLE,
            &body,
            ctx.regions.as_slice(),
            Some(ctx.fds.as_slice()),
        )?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }
    // Clippy doesn't seem to know that if let with && is still experimental
    #[allow(clippy::unnecessary_unwrap)]
    fn set_log_base(&mut self, base: u64, fd: Option<RawFd>) -> Result<()> {
        let mut node = self.node.lock().unwrap();
        let val = VhostUserU64::new(base);
        // The fd is only attached when the LOG_SHMFD protocol feature has
        // been acked; otherwise the base is sent on its own.
        if node.acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0
            && fd.is_some()
        {
            let fds = [fd.unwrap()];
            let _ = node.send_request_with_body(MasterReq::SET_LOG_BASE, &val, Some(&fds))?;
        } else {
            let _ = node.send_request_with_body(MasterReq::SET_LOG_BASE, &val, None)?;
        }
        Ok(())
    }
    fn set_log_fd(&mut self, fd: RawFd) -> Result<()> {
        let mut node = self.node.lock().unwrap();
        let fds = [fd];
        node.send_request_header(MasterReq::SET_LOG_FD, Some(&fds))?;
        Ok(())
    }
    /// Set the size of the queue.
    fn set_vring_num(&mut self, queue_index: usize, num: u16) -> Result<()> {
        let mut node = self.node.lock().unwrap();
        if queue_index as u64 >= node.max_queue_num {
            return error_code(VhostUserError::InvalidParam);
        }
        let val = VhostUserVringState::new(queue_index as u32, num.into());
        let hdr = node.send_request_with_body(MasterReq::SET_VRING_NUM, &val, None)?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }
    /// Sets the addresses of the different aspects of the vring.
    fn set_vring_addr(&mut self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
        let mut node = self.node.lock().unwrap();
        if queue_index as u64 >= node.max_queue_num
            || config_data.flags & !(VhostUserVringAddrFlags::all().bits()) != 0
        {
            return error_code(VhostUserError::InvalidParam);
        }
        let val = VhostUserVringAddr::from_config_data(queue_index as u32, config_data);
        let hdr = node.send_request_with_body(MasterReq::SET_VRING_ADDR, &val, None)?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }
    /// Sets the base offset in the available vring.
    fn set_vring_base(&mut self, queue_index: usize, base: u16) -> Result<()> {
        let mut node = self.node.lock().unwrap();
        if queue_index as u64 >= node.max_queue_num {
            return error_code(VhostUserError::InvalidParam);
        }
        let val = VhostUserVringState::new(queue_index as u32, base.into());
        let hdr = node.send_request_with_body(MasterReq::SET_VRING_BASE, &val, None)?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }
    /// Read back the current base offset of a vring from the slave.
    fn get_vring_base(&mut self, queue_index: usize) -> Result<u32> {
        let mut node = self.node.lock().unwrap();
        if queue_index as u64 >= node.max_queue_num {
            return error_code(VhostUserError::InvalidParam);
        }
        let req = VhostUserVringState::new(queue_index as u32, 0);
        let hdr = node.send_request_with_body(MasterReq::GET_VRING_BASE, &req, None)?;
        let reply = node.recv_reply::<VhostUserVringState>(&hdr)?;
        Ok(reply.num)
    }
    /// Set the event file descriptor to signal when buffers are used.
    /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
    /// is set when there is no file descriptor in the ancillary data. This signals that polling
    /// will be used instead of waiting for the call.
    fn set_vring_call(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
        let mut node = self.node.lock().unwrap();
        if queue_index as u64 >= node.max_queue_num {
            return error_code(VhostUserError::InvalidParam);
        }
        node.send_fd_for_vring(MasterReq::SET_VRING_CALL, queue_index, fd.as_raw_fd())?;
        Ok(())
    }
    /// Set the event file descriptor for adding buffers to the vring.
    /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
    /// is set when there is no file descriptor in the ancillary data. This signals that polling
    /// should be used instead of waiting for a kick.
    fn set_vring_kick(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
        let mut node = self.node.lock().unwrap();
        if queue_index as u64 >= node.max_queue_num {
            return error_code(VhostUserError::InvalidParam);
        }
        node.send_fd_for_vring(MasterReq::SET_VRING_KICK, queue_index, fd.as_raw_fd())?;
        Ok(())
    }
    /// Set the event file descriptor to signal when error occurs.
    /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
    /// is set when there is no file descriptor in the ancillary data.
    fn set_vring_err(&mut self, queue_index: usize, fd: &EventFd) -> Result<()> {
        let mut node = self.node.lock().unwrap();
        if queue_index as u64 >= node.max_queue_num {
            return error_code(VhostUserError::InvalidParam);
        }
        node.send_fd_for_vring(MasterReq::SET_VRING_ERR, queue_index, fd.as_raw_fd())?;
        Ok(())
    }
}
impl VhostUserMaster for Master {
    /// Query the slave's protocol feature bitmask; the raw value is cached
    /// in `protocol_features`. Requires PROTOCOL_FEATURES to have been
    /// offered by the slave and acked by us.
    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures> {
        let mut node = self.node.lock().unwrap();
        let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
        if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 {
            return error_code(VhostUserError::InvalidOperation);
        }
        let hdr = node.send_request_header(MasterReq::GET_PROTOCOL_FEATURES, None)?;
        let val = node.recv_reply::<VhostUserU64>(&hdr)?;
        node.protocol_features = val.value;
        // Should we support forward compatibility?
        // If so just mask out unrecognized flags instead of return errors.
        match VhostUserProtocolFeatures::from_bits(node.protocol_features) {
            Some(val) => Ok(val),
            None => error_code(VhostUserError::InvalidMessage),
        }
    }
    /// Ack protocol features on the slave and cache them locally.
    fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> {
        let mut node = self.node.lock().unwrap();
        let flag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
        if node.virtio_features & flag == 0 || node.acked_virtio_features & flag == 0 {
            return error_code(VhostUserError::InvalidOperation);
        }
        let val = VhostUserU64::new(features.bits());
        let _ = node.send_request_with_body(MasterReq::SET_PROTOCOL_FEATURES, &val, None)?;
        // Don't wait for ACK here because the protocol feature negotiation process hasn't been
        // completed yet.
        node.acked_protocol_features = features.bits();
        node.protocol_features_ready = true;
        Ok(())
    }
    /// Query the number of queues; requires the MQ protocol feature. The
    /// result is cached as the new `max_queue_num`.
    fn get_queue_num(&mut self) -> Result<u64> {
        let mut node = self.node.lock().unwrap();
        if !node.is_feature_mq_available() {
            return error_code(VhostUserError::InvalidOperation);
        }
        let hdr = node.send_request_header(MasterReq::GET_QUEUE_NUM, None)?;
        let val = node.recv_reply::<VhostUserU64>(&hdr)?;
        if val.value > VHOST_USER_MAX_VRINGS {
            return error_code(VhostUserError::InvalidMessage);
        }
        node.max_queue_num = val.value;
        Ok(node.max_queue_num)
    }
    fn set_vring_enable(&mut self, queue_index: usize, enable: bool) -> Result<()> {
        let mut node = self.node.lock().unwrap();
        // set_vring_enable() is supported only when PROTOCOL_FEATURES has been enabled.
        if node.acked_virtio_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() == 0 {
            return error_code(VhostUserError::InvalidOperation);
        } else if queue_index as u64 >= node.max_queue_num {
            return error_code(VhostUserError::InvalidParam);
        }
        let flag = if enable { 1 } else { 0 };
        let val = VhostUserVringState::new(queue_index as u32, flag);
        let hdr = node.send_request_with_body(MasterReq::SET_VRING_ENABLE, &val, None)?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }
    fn get_config(
        &mut self,
        offset: u32,
        size: u32,
        flags: VhostUserConfigFlags,
        buf: &[u8],
    ) -> Result<(VhostUserConfig, VhostUserConfigPayload)> {
        let body = VhostUserConfig::new(offset, size, flags);
        if !body.is_valid() {
            return error_code(VhostUserError::InvalidParam);
        }
        let mut node = self.node.lock().unwrap();
        // depends on VhostUserProtocolFeatures::CONFIG
        if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
            return error_code(VhostUserError::InvalidOperation);
        }
        // vhost-user spec states that:
        // "Master payload: virtio device config space"
        // "Slave payload: virtio device config space"
        let hdr = node.send_request_with_payload(MasterReq::GET_CONFIG, &body, buf, None)?;
        let (body_reply, buf_reply, rfds) =
            node.recv_reply_with_payload::<VhostUserConfig>(&hdr)?;
        // The reply must carry no fds, a non-zero size, and a payload of
        // exactly the requested size.
        if rfds.is_some() {
            Endpoint::<MasterReq>::close_rfds(rfds);
            return error_code(VhostUserError::InvalidMessage);
        } else if body_reply.size == 0 {
            return error_code(VhostUserError::SlaveInternalError);
        } else if body_reply.size != body.size || body_reply.size as usize != buf.len() {
            return error_code(VhostUserError::InvalidMessage);
        }
        Ok((body_reply, buf_reply))
    }
    fn set_config(&mut self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()> {
        if buf.len() > MAX_MSG_SIZE {
            return error_code(VhostUserError::InvalidParam);
        }
        let body = VhostUserConfig::new(offset, buf.len() as u32, flags);
        if !body.is_valid() {
            return error_code(VhostUserError::InvalidParam);
        }
        let mut node = self.node.lock().unwrap();
        // depends on VhostUserProtocolFeatures::CONFIG
        if node.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
            return error_code(VhostUserError::InvalidOperation);
        }
        let hdr = node.send_request_with_payload(MasterReq::SET_CONFIG, &body, buf, None)?;
        node.wait_for_ack(&hdr).map_err(|e| e.into())
    }
    /// Hand the slave an fd for its slave->master request channel; requires
    /// the SLAVE_REQ protocol feature to have been acked.
    fn set_slave_request_fd(&mut self, fd: RawFd) -> Result<()> {
        let mut node = self.node.lock().unwrap();
        if node.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
            return error_code(VhostUserError::InvalidOperation);
        }
        let fds = [fd];
        node.send_request_header(MasterReq::SET_SLAVE_REQ_FD, Some(&fds))?;
        Ok(())
    }
}
// Expose the master's main socket fd so the connection can be polled by an
// event loop.
impl AsRawFd for Master {
    fn as_raw_fd(&self) -> RawFd {
        let node = self.node.lock().unwrap();
        node.main_sock.as_raw_fd()
    }
}
/// Context object to pass guest memory configuration to VhostUserMaster::set_mem_table().
struct VhostUserMemoryContext {
    // Memory region descriptors, sent as the message payload.
    regions: VhostUserMemoryPayload,
    // One mmap fd per region, sent as ancillary data alongside the payload.
    fds: Vec<RawFd>,
}
impl VhostUserMemoryContext {
    /// Create an empty context object.
    pub fn new() -> Self {
        Self {
            regions: VhostUserMemoryPayload::new(),
            fds: Vec::new(),
        }
    }
    /// Record a user memory region together with the RawFd backing it; the
    /// two vectors stay index-aligned.
    pub fn append(&mut self, region: &VhostUserMemoryRegion, fd: RawFd) {
        self.fds.push(fd);
        self.regions.push(*region);
    }
}
// Mutex-protected state shared by all clones of a `Master`.
struct MasterInternal {
    // Used to send requests to the slave.
    main_sock: Endpoint<MasterReq>,
    // Cached virtio features from the slave.
    virtio_features: u64,
    // Cached acked virtio features from the driver.
    acked_virtio_features: u64,
    // Cached vhost-user protocol features from the slave.
    protocol_features: u64,
    // Cached vhost-user protocol features acked by the driver.
    acked_protocol_features: u64,
    // Cached vhost-user protocol features are ready to use.
    protocol_features_ready: bool,
    // Cached maximum number of queues supported from the slave.
    max_queue_num: u64,
    // Internal flag to mark failure state.
    error: Option<i32>,
}
impl MasterInternal {
    // Send a request consisting of only a message header (no body), with
    // optional file descriptors attached as ancillary data. Returns the
    // header so the caller can match the reply against it.
    fn send_request_header(
        &mut self,
        code: MasterReq,
        fds: Option<&[RawFd]>,
    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
        self.check_state()?;
        let hdr = Self::new_request_header(code, 0);
        self.main_sock.send_header(&hdr, fds)?;
        Ok(hdr)
    }
    // Send a request carrying a fixed-size typed body, with optional fds.
    // Rejects bodies larger than the protocol's maximum message size.
    fn send_request_with_body<T: Sized>(
        &mut self,
        code: MasterReq,
        msg: &T,
        fds: Option<&[RawFd]>,
    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
        if mem::size_of::<T>() > MAX_MSG_SIZE {
            return Err(VhostUserError::InvalidParam);
        }
        self.check_state()?;
        let hdr = Self::new_request_header(code, mem::size_of::<T>() as u32);
        self.main_sock.send_message(&hdr, msg, fds)?;
        Ok(hdr)
    }
    // Send a request carrying a typed body plus a variable-length payload,
    // with optional fds. Total size and fd count are bounds-checked against
    // the protocol limits.
    fn send_request_with_payload<T: Sized, P: Sized>(
        &mut self,
        code: MasterReq,
        msg: &T,
        payload: &[P],
        fds: Option<&[RawFd]>,
    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
        let len = mem::size_of::<T>() + payload.len() * mem::size_of::<P>();
        if len > MAX_MSG_SIZE {
            return Err(VhostUserError::InvalidParam);
        }
        if let Some(ref fd_arr) = fds {
            if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES {
                return Err(VhostUserError::InvalidParam);
            }
        }
        self.check_state()?;
        let hdr = Self::new_request_header(code, len as u32);
        self.main_sock
            .send_message_with_payload(&hdr, msg, payload, fds)?;
        Ok(hdr)
    }
    // Send a vring-related request (SET_VRING_CALL/KICK/ERR) whose payload
    // encodes the queue index and which carries exactly one fd.
    fn send_fd_for_vring(
        &mut self,
        code: MasterReq,
        queue_index: usize,
        fd: RawFd,
    ) -> VhostUserResult<VhostUserMsgHeader<MasterReq>> {
        if queue_index as u64 >= self.max_queue_num {
            return Err(VhostUserError::InvalidParam);
        }
        self.check_state()?;
        // Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag.
        // This flag is set when there is no file descriptor in the ancillary data. This signals
        // that polling will be used instead of waiting for the call.
        let msg = VhostUserU64::new(queue_index as u64);
        let hdr = Self::new_request_header(code, mem::size_of::<VhostUserU64>() as u32);
        self.main_sock.send_message(&hdr, &msg, Some(&[fd]))?;
        Ok(hdr)
    }
fn recv_reply<T: Sized + Default + VhostUserMsgValidator>(
&mut self,
hdr: &VhostUserMsgHeader<MasterReq>,
) -> VhostUserResult<T> {
if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() {
return Err(VhostUserError::InvalidParam);
}
self.check_state()?;
let (reply, body, rfds) = self.main_sock.recv_body::<T>()?;
if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() {
Endpoint::<MasterReq>::close_rfds(rfds);
return Err(VhostUserError::InvalidMessage);
}
Ok(body)
}
fn recv_reply_with_payload<T: Sized + Default + VhostUserMsgValidator>(
&mut self,
hdr: &VhostUserMsgHeader<MasterReq>,
) -> VhostUserResult<(T, Vec<u8>, Option<Vec<RawFd>>)> {
if mem::size_of::<T>() > MAX_MSG_SIZE || hdr.is_reply() {
return Err(VhostUserError::InvalidParam);
}
self.check_state()?;
let mut buf: Vec<u8> = vec![0; hdr.get_size() as usize - mem::size_of::<T>()];
let (reply, body, bytes, rfds) = self.main_sock.recv_payload_into_buf::<T>(&mut buf)?;
if !reply.is_reply_for(hdr)
|| reply.get_size() as usize != mem::size_of::<T>() + bytes
|| rfds.is_some()
|| !body.is_valid()
{
Endpoint::<MasterReq>::close_rfds(rfds);
return Err(VhostUserError::InvalidMessage);
} else if bytes > MAX_MSG_SIZE - mem::size_of::<T>() {
return Err(VhostUserError::InvalidMessage);
} else if bytes < buf.len() {
// It's safe because we have checked the buffer size
unsafe { buf.set_len(bytes) };
}
Ok((body, buf, rfds))
}
fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<MasterReq>) -> VhostUserResult<()> {
if self.acked_protocol_features & VhostUserProtocolFeatures::REPLY_ACK.bits() == 0
|| !hdr.is_need_reply()
{
return Ok(());
}
self.check_state()?;
let (reply, body, rfds) = self.main_sock.recv_body::<VhostUserU64>()?;
if !reply.is_reply_for(&hdr) || rfds.is_some() || !body.is_valid() {
Endpoint::<MasterReq>::close_rfds(rfds);
return Err(VhostUserError::InvalidMessage);
}
if body.value != 0 {
return Err(VhostUserError::SlaveInternalError);
}
Ok(())
}
fn is_feature_mq_available(&self) -> bool {
self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0
}
fn check_state(&self) -> VhostUserResult<()> {
match self.error {
Some(e) => Err(VhostUserError::SocketBroken(
std::io::Error::from_raw_os_error(e),
)),
None => Ok(()),
}
}
#[inline]
fn new_request_header(request: MasterReq, size: u32) -> VhostUserMsgHeader<MasterReq> {
// TODO: handle NEED_REPLY flag
VhostUserMsgHeader::new(request, 0x1, size)
}
}
#[cfg(test)]
mod tests {
    use super::super::connection::Listener;
    use super::*;

    // NOTE(review): fixed paths under /tmp make these tests unsafe to run
    // concurrently or on shared machines, which is presumably why the
    // socket-based tests below are all marked #[ignore].
    const UNIX_SOCKET_MASTER: &'static str = "/tmp/vhost_user_test_rust_master";
    const UNIX_SOCKET_MASTER2: &'static str = "/tmp/vhost_user_test_rust_master2";
    const UNIX_SOCKET_MASTER3: &'static str = "/tmp/vhost_user_test_rust_master3";
    const UNIX_SOCKET_MASTER4: &'static str = "/tmp/vhost_user_test_rust_master4";

    // Create a connected (master, slave-endpoint) pair over a Unix socket at `path`.
    fn create_pair(path: &str) -> (Master, Endpoint<MasterReq>) {
        let listener = Listener::new(path, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let master = Master::connect(path, 2).unwrap();
        let slave = listener.accept().unwrap().unwrap();
        (master, Endpoint::from_stream(slave))
    }

    // Verify that SET_OWNER and RESET_OWNER requests arrive in order with
    // empty bodies and protocol version 1.
    #[test]
    #[ignore]
    fn create_master() {
        let listener = Listener::new(UNIX_SOCKET_MASTER, true).unwrap();
        listener.set_nonblocking(true).unwrap();
        let mut master = Master::connect(UNIX_SOCKET_MASTER, 2).unwrap();
        let mut slave = Endpoint::<MasterReq>::from_stream(listener.accept().unwrap().unwrap());
        // Send two messages continuously
        master.set_owner().unwrap();
        master.reset_owner().unwrap();
        let (hdr, rfds) = slave.recv_header().unwrap();
        assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
        assert_eq!(hdr.get_size(), 0);
        assert_eq!(hdr.get_version(), 0x1);
        assert!(rfds.is_none());
        let (hdr, rfds) = slave.recv_header().unwrap();
        assert_eq!(hdr.get_code(), MasterReq::RESET_OWNER);
        assert_eq!(hdr.get_size(), 0);
        assert_eq!(hdr.get_version(), 0x1);
        assert!(rfds.is_none());
    }

    // Connecting without a listener must fail; a second listener on the same
    // path without force-unlink must fail.
    #[test]
    #[ignore]
    fn test_create_failure() {
        let _ = Listener::new(UNIX_SOCKET_MASTER2, true).unwrap();
        let _ = Listener::new(UNIX_SOCKET_MASTER2, false).is_err();
        assert!(Master::connect(UNIX_SOCKET_MASTER2, 2).is_err());
        let listener = Listener::new(UNIX_SOCKET_MASTER2, true).unwrap();
        assert!(Listener::new(UNIX_SOCKET_MASTER2, false).is_err());
        listener.set_nonblocking(true).unwrap();
        let _master = Master::connect(UNIX_SOCKET_MASTER2, 2).unwrap();
        let _slave = listener.accept().unwrap().unwrap();
    }

    // Exercise GET_FEATURES/SET_FEATURES round trips, including a reply with
    // a wrong-sized body (u32 instead of u64) which must be rejected.
    #[test]
    #[ignore]
    fn test_features() {
        let (mut master, mut peer) = create_pair(UNIX_SOCKET_MASTER3);
        master.set_owner().unwrap();
        let (hdr, rfds) = peer.recv_header().unwrap();
        assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
        assert_eq!(hdr.get_size(), 0);
        assert_eq!(hdr.get_version(), 0x1);
        assert!(rfds.is_none());
        let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
        let msg = VhostUserU64::new(0x15);
        peer.send_message(&hdr, &msg, None).unwrap();
        let features = master.get_features().unwrap();
        assert_eq!(features, 0x15u64);
        let (_hdr, rfds) = peer.recv_header().unwrap();
        assert!(rfds.is_none());
        master.set_features(0x15).unwrap();
        let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
        assert!(rfds.is_none());
        let val = msg.value;
        assert_eq!(val, 0x15);
        let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
        let msg = 0x15u32;
        peer.send_message(&hdr, &msg, None).unwrap();
        assert!(master.get_features().is_err());
    }

    // Exercise protocol-feature negotiation: protocol features may only be
    // queried/set after PROTOCOL_FEATURES was acked in the virtio feature set.
    #[test]
    #[ignore]
    fn test_protocol_features() {
        let (mut master, mut peer) = create_pair(UNIX_SOCKET_MASTER4);
        master.set_owner().unwrap();
        let (hdr, rfds) = peer.recv_header().unwrap();
        assert_eq!(hdr.get_code(), MasterReq::SET_OWNER);
        assert!(rfds.is_none());
        assert!(master.get_protocol_features().is_err());
        assert!(master
            .set_protocol_features(VhostUserProtocolFeatures::all())
            .is_err());
        let vfeatures = 0x15 | VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
        let hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0x4, 8);
        let msg = VhostUserU64::new(vfeatures);
        peer.send_message(&hdr, &msg, None).unwrap();
        let features = master.get_features().unwrap();
        assert_eq!(features, vfeatures);
        let (_hdr, rfds) = peer.recv_header().unwrap();
        assert!(rfds.is_none());
        master.set_features(vfeatures).unwrap();
        let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
        assert!(rfds.is_none());
        let val = msg.value;
        assert_eq!(val, vfeatures);
        let pfeatures = VhostUserProtocolFeatures::all();
        let hdr = VhostUserMsgHeader::new(MasterReq::GET_PROTOCOL_FEATURES, 0x4, 8);
        let msg = VhostUserU64::new(pfeatures.bits());
        peer.send_message(&hdr, &msg, None).unwrap();
        let features = master.get_protocol_features().unwrap();
        assert_eq!(features, pfeatures);
        let (_hdr, rfds) = peer.recv_header().unwrap();
        assert!(rfds.is_none());
        master.set_protocol_features(pfeatures).unwrap();
        let (_hdr, msg, rfds) = peer.recv_body::<VhostUserU64>().unwrap();
        assert!(rfds.is_none());
        let val = msg.value;
        assert_eq!(val, pfeatures.bits());
        let hdr = VhostUserMsgHeader::new(MasterReq::SET_PROTOCOL_FEATURES, 0x4, 8);
        let msg = VhostUserU64::new(pfeatures.bits());
        peer.send_message(&hdr, &msg, None).unwrap();
        assert!(master.get_protocol_features().is_err());
    }

    #[test]
    fn test_set_mem_table() {
        // TODO
    }

    #[test]
    fn test_get_ring_num() {
        // TODO
    }
}

View File

@ -1,258 +0,0 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Traits and Structs to handle vhost-user requests from the slave to the master.
use libc;
use std::mem;
use std::os::unix::io::{AsRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::sync::{Arc, Mutex};
use super::connection::Endpoint;
use super::message::*;
use super::{Error, HandlerResult, Result};
/// Trait to handle vhost-user requests from the slave to the master.
///
/// All methods have default implementations that fail with `ENOSYS`, so an
/// implementor only needs to override the requests it actually supports.
pub trait VhostUserMasterReqHandler {
    // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
    // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);

    /// Handle device configuration change notifications from the slave.
    fn handle_config_change(&mut self) -> HandlerResult<()> {
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }

    /// Handle virtio-fs map file requests from the slave.
    ///
    /// The default implementation closes `fd` before failing so the received
    /// descriptor is not leaked.
    fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<()> {
        // Safe because we have just received the rawfd from kernel.
        unsafe { libc::close(fd) };
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }

    /// Handle virtio-fs unmap file requests from the slave.
    fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<()> {
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }

    /// Handle virtio-fs sync file requests from the slave.
    fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<()> {
        Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
    }
}
/// A vhost-user master request endpoint which relays all received requests from the slave to the
/// provided request handler.
pub struct MasterReqHandler<S: VhostUserMasterReqHandler> {
    // underlying Unix domain socket for communication (receive side of the pair)
    sub_sock: Endpoint<SlaveReq>,
    // write end of the socket pair, handed to the slave via SET_SLAVE_REQ_FD
    tx_sock: UnixStream,
    // the VirtIO backend device object
    backend: Arc<Mutex<S>>,
    // whether the endpoint has encountered any failure (raw OS errno)
    error: Option<i32>,
}
impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
    /// Create a vhost-user slave request handler.
    /// This opens a pair of connected anonymous sockets.
    /// Returns Self and the socket that must be sent to the slave via SET_SLAVE_REQ_FD.
    pub fn new(backend: Arc<Mutex<S>>) -> Result<Self> {
        let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?;
        Ok(MasterReqHandler {
            sub_sock: Endpoint::<SlaveReq>::from_stream(rx),
            tx_sock: tx,
            backend,
            error: None,
        })
    }

    /// Get the raw fd to send to the slave as slave communication channel.
    pub fn get_tx_raw_fd(&self) -> RawFd {
        self.tx_sock.as_raw_fd()
    }

    /// Mark endpoint as failed or normal state.
    pub fn set_failed(&mut self, error: i32) {
        self.error = Some(error);
    }

    /// Receive and handle one incoming request message from the slave.
    /// The caller needs to:
    /// . serialize calls to this function
    /// . decide what to do when an error happens
    /// . optionally recover from failure
    pub fn handle_request(&mut self) -> Result<()> {
        // Return error if the endpoint is already in failed state.
        self.check_state()?;

        // The underlying communication channel is a Unix domain socket in
        // stream mode, and recvmsg() is a little tricky here. To successfully
        // receive attached file descriptors, we need to receive messages and
        // corresponding attached file descriptors in this way:
        // . recv message header and optional attached file
        // . validate message header
        // . recv optional message body and payload according size field in
        //   message header
        // . validate message body and optional payload
        let (hdr, rfds) = self.sub_sock.recv_header()?;
        let rfds = self.check_attached_rfds(&hdr, rfds)?;
        let (size, buf) = match hdr.get_size() {
            0 => (0, vec![0u8; 0]),
            len => {
                let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?;
                if size2 != len as usize {
                    return Err(Error::InvalidMessage);
                }
                (size2, rbuf)
            }
        };

        // Dispatch to the backend; unknown request codes are rejected.
        let res = match hdr.get_code() {
            SlaveReq::CONFIG_CHANGE_MSG => {
                self.check_msg_size(&hdr, size, 0)?;
                self.backend
                    .lock()
                    .unwrap()
                    .handle_config_change()
                    .map_err(Error::ReqHandlerError)
            }
            SlaveReq::FS_MAP => {
                let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
                // unwrap() is safe: check_attached_rfds() guarantees FS_MAP
                // arrives with exactly one attached fd.
                self.backend
                    .lock()
                    .unwrap()
                    .fs_slave_map(msg, rfds.unwrap()[0])
                    .map_err(Error::ReqHandlerError)
            }
            SlaveReq::FS_UNMAP => {
                let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
                self.backend
                    .lock()
                    .unwrap()
                    .fs_slave_unmap(msg)
                    .map_err(Error::ReqHandlerError)
            }
            SlaveReq::FS_SYNC => {
                let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
                self.backend
                    .lock()
                    .unwrap()
                    .fs_slave_sync(msg)
                    .map_err(Error::ReqHandlerError)
            }
            _ => Err(Error::InvalidMessage),
        };

        // Ack the request (success/failure) if the slave asked for a reply.
        self.send_ack_message(&hdr, &res)?;

        res
    }

    // Return an error if the endpoint was previously marked as failed.
    fn check_state(&self) -> Result<()> {
        match self.error {
            Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
            None => Ok(()),
        }
    }

    // Validate that both the header's size field and the received byte count
    // match the expected body size, and that header flags/version are sane.
    fn check_msg_size(
        &self,
        hdr: &VhostUserMsgHeader<SlaveReq>,
        size: usize,
        expected: usize,
    ) -> Result<()> {
        if hdr.get_size() as usize != expected
            || hdr.is_reply()
            || hdr.get_version() != 0x1
            || size != expected
        {
            return Err(Error::InvalidMessage);
        }
        Ok(())
    }

    // Validate the attached file descriptors against the request type,
    // closing them on any mismatch so they are never leaked.
    fn check_attached_rfds(
        &self,
        hdr: &VhostUserMsgHeader<SlaveReq>,
        rfds: Option<Vec<RawFd>>,
    ) -> Result<Option<Vec<RawFd>>> {
        match hdr.get_code() {
            SlaveReq::FS_MAP => {
                // Expect an fd set with a single fd.
                match rfds {
                    None => Err(Error::InvalidMessage),
                    Some(fds) => {
                        if fds.len() != 1 {
                            Endpoint::<SlaveReq>::close_rfds(Some(fds));
                            Err(Error::InvalidMessage)
                        } else {
                            Ok(Some(fds))
                        }
                    }
                }
            }
            _ => {
                if rfds.is_some() {
                    Endpoint::<SlaveReq>::close_rfds(rfds);
                    Err(Error::InvalidMessage)
                } else {
                    Ok(rfds)
                }
            }
        }
    }

    // Reinterpret the received byte buffer as a message body of type T and
    // validate it.
    // NOTE(review): the unsafe cast assumes T has alignment 1 (the message
    // types used here are repr(packed)); a T with stricter alignment would
    // create an unaligned reference — confirm before reusing with other types.
    fn extract_msg_body<'a, T: Sized + VhostUserMsgValidator>(
        &self,
        hdr: &VhostUserMsgHeader<SlaveReq>,
        size: usize,
        buf: &'a [u8],
    ) -> Result<&'a T> {
        self.check_msg_size(hdr, size, mem::size_of::<T>())?;
        let msg = unsafe { &*(buf.as_ptr() as *const T) };
        if !msg.is_valid() {
            return Err(Error::InvalidMessage);
        }
        Ok(msg)
    }

    // Build a reply header mirroring the request's code with the REPLY flag set.
    fn new_reply_header<T: Sized>(
        &self,
        req: &VhostUserMsgHeader<SlaveReq>,
    ) -> Result<VhostUserMsgHeader<SlaveReq>> {
        if mem::size_of::<T>() > MAX_MSG_SIZE {
            return Err(Error::InvalidParam);
        }
        self.check_state()?;
        Ok(VhostUserMsgHeader::new(
            req.get_code(),
            VhostUserHeaderFlag::REPLY.bits(),
            mem::size_of::<T>() as u32,
        ))
    }

    // Send an ack (0 = success, 1 = failure) if the request carried the
    // NEED_REPLY flag; otherwise do nothing.
    fn send_ack_message(
        &mut self,
        req: &VhostUserMsgHeader<SlaveReq>,
        res: &Result<()>,
    ) -> Result<()> {
        if req.is_need_reply() {
            let hdr = self.new_reply_header::<VhostUserU64>(req)?;
            let val = match res {
                Ok(_) => 0,
                Err(_) => 1,
            };
            let msg = VhostUserU64::new(val);
            self.sub_sock.send_message(&hdr, &msg, None)?;
        }
        Ok(())
    }
}
// Expose the receive-side socket fd so the handler can be registered with an
// event loop (poll/epoll) and handle_request() called when it becomes readable.
impl<S: VhostUserMasterReqHandler> AsRawFd for MasterReqHandler<S> {
    fn as_raw_fd(&self) -> RawFd {
        self.sub_sock.as_raw_fd()
    }
}

View File

@ -1,812 +0,0 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Define communication messages for the vhost-user protocol.
//!
//! For message definition, please refer to the [vhost-user spec](https://github.com/qemu/qemu/blob/f7526eece29cd2e36a63b6703508b24453095eb8/docs/interop/vhost-user.txt).
#![allow(dead_code)]
#![allow(non_camel_case_types)]
use std::fmt::Debug;
use std::marker::PhantomData;
use VringConfigData;
/// The vhost-user specification uses a field of u32 to store message length.
/// On the other hand, preallocated buffers are needed to receive messages from the Unix domain
/// socket. Preallocating a 4GB buffer for each vhost-user message would be pure overhead.
/// Among all defined vhost-user messages, only the VhostUserConfig and VhostUserMemory have
/// variable message size. For the VhostUserConfig, a maximum size of 4K is enough because the user
/// configuration space for virtio devices is (4K - 0x100) bytes at most. For the VhostUserMemory,
/// 4K should be enough too because it can support 255 memory regions at most.
pub const MAX_MSG_SIZE: usize = 0x1000;

/// The VhostUserMemory message has variable message size and variable number of attached file
/// descriptors. Each user memory region entry in the message payload occupies 32 bytes,
/// so setting maximum number of attached file descriptors based on the maximum message size.
/// But rust only implements Default and AsMut traits for arrays with 0 - 32 entries, so further
/// reduce the maximum number...
// pub const MAX_ATTACHED_FD_ENTRIES: usize = (MAX_MSG_SIZE - 8) / 32;
pub const MAX_ATTACHED_FD_ENTRIES: usize = 32;

/// Starting position (inclusion) of the device configuration space in virtio devices.
pub const VHOST_USER_CONFIG_OFFSET: u32 = 0x100;

/// Ending position (exclusion) of the device configuration space in virtio devices.
pub const VHOST_USER_CONFIG_SIZE: u32 = 0x1000;

/// Maximum number of vrings supported.
pub const VHOST_USER_MAX_VRINGS: u64 = 0xFFu64;
/// Common bound for vhost-user request code enums (master-to-slave and
/// slave-to-master), allowing them to be stored in a shared message header.
pub(super) trait Req:
    Clone + Copy + Debug + PartialEq + Eq + PartialOrd + Ord + Into<u32>
{
    /// Whether the code denotes a real request (not NOOP and below the upper bound).
    fn is_valid(&self) -> bool;
}
/// Type of requests sending from masters to slaves.
///
/// NOTE: the discriminant values are part of the vhost-user wire protocol and
/// must not be changed; `repr(u32)` is also relied upon by
/// `VhostUserMsgHeader::get_code`'s transmute.
#[repr(u32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum MasterReq {
    /// Null operation.
    NOOP = 0,
    /// Get from the underlying vhost implementation the features bit mask.
    GET_FEATURES = 1,
    /// Enable features in the underlying vhost implementation using a bit mask.
    SET_FEATURES = 2,
    /// Set the current Master as an owner of the session.
    SET_OWNER = 3,
    /// No longer used.
    RESET_OWNER = 4,
    /// Set the memory map regions on the slave so it can translate the vring addresses.
    SET_MEM_TABLE = 5,
    /// Set logging shared memory space.
    SET_LOG_BASE = 6,
    /// Set the logging file descriptor, which is passed as ancillary data.
    SET_LOG_FD = 7,
    /// Set the size of the queue.
    SET_VRING_NUM = 8,
    /// Set the addresses of the different aspects of the vring.
    SET_VRING_ADDR = 9,
    /// Set the base offset in the available vring.
    SET_VRING_BASE = 10,
    /// Get the available vring base offset.
    GET_VRING_BASE = 11,
    /// Set the event file descriptor for adding buffers to the vring.
    SET_VRING_KICK = 12,
    /// Set the event file descriptor to signal when buffers are used.
    SET_VRING_CALL = 13,
    /// Set the event file descriptor to signal when error occurs.
    SET_VRING_ERR = 14,
    /// Get the protocol feature bit mask from the underlying vhost implementation.
    GET_PROTOCOL_FEATURES = 15,
    /// Enable protocol features in the underlying vhost implementation.
    SET_PROTOCOL_FEATURES = 16,
    /// Query how many queues the backend supports.
    GET_QUEUE_NUM = 17,
    /// Signal slave to enable or disable corresponding vring.
    SET_VRING_ENABLE = 18,
    /// Ask vhost user backend to broadcast a fake RARP to notify the migration is terminated
    /// for guest that does not support GUEST_ANNOUNCE.
    SEND_RARP = 19,
    /// Set host MTU value exposed to the guest.
    NET_SET_MTU = 20,
    /// Set the socket file descriptor for slave initiated requests.
    SET_SLAVE_REQ_FD = 21,
    /// Send IOTLB messages with struct vhost_iotlb_msg as payload.
    IOTLB_MSG = 22,
    /// Set the endianness of a VQ for legacy devices.
    SET_VRING_ENDIAN = 23,
    /// Fetch the contents of the virtio device configuration space.
    GET_CONFIG = 24,
    /// Change the contents of the virtio device configuration space.
    SET_CONFIG = 25,
    /// Create a session for crypto operation.
    CREATE_CRYPTO_SESSION = 26,
    /// Close a session for crypto operation.
    CLOSE_CRYPTO_SESSION = 27,
    /// Advise slave that a migration with postcopy enabled is underway.
    POSTCOPY_ADVISE = 28,
    /// Advise slave that a transition to postcopy mode has happened.
    POSTCOPY_LISTEN = 29,
    /// Advise that postcopy migration has now completed.
    POSTCOPY_END = 30,
    /// Get a shared buffer from slave.
    GET_INFLIGHT_FD = 31,
    /// Send the shared inflight buffer back to slave
    SET_INFLIGHT_FD = 32,
    /// Upper bound of valid commands.
    MAX_CMD = 33,
}
impl Into<u32> for MasterReq {
fn into(self) -> u32 {
self as u32
}
}
impl Req for MasterReq {
    /// A code is valid when it lies strictly between the NOOP placeholder and
    /// the MAX_CMD upper bound.
    fn is_valid(&self) -> bool {
        MasterReq::NOOP < *self && *self < MasterReq::MAX_CMD
    }
}
/// Type of requests sending from slaves to masters.
///
/// NOTE: the discriminant values are part of the vhost-user wire protocol and
/// must not be changed.
#[repr(u32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum SlaveReq {
    /// Null operation.
    NOOP = 0,
    /// Send IOTLB messages with struct vhost_iotlb_msg as payload.
    IOTLB_MSG = 1,
    /// Notify that the virtio device's configuration space has changed.
    CONFIG_CHANGE_MSG = 2,
    /// Set host notifier for a specified queue.
    VRING_HOST_NOTIFIER_MSG = 3,
    /// Virtio-fs draft: map file content into the window.
    FS_MAP = 4,
    /// Virtio-fs draft: unmap file content from the window.
    FS_UNMAP = 5,
    /// Virtio-fs draft: sync file content.
    FS_SYNC = 6,
    /// Upper bound of valid commands.
    MAX_CMD = 7,
}
impl Into<u32> for SlaveReq {
fn into(self) -> u32 {
self as u32
}
}
impl Req for SlaveReq {
    /// A code is valid when it lies strictly between the NOOP placeholder and
    /// the MAX_CMD upper bound.
    fn is_valid(&self) -> bool {
        SlaveReq::NOOP < *self && *self < SlaveReq::MAX_CMD
    }
}
/// Vhost message Validator.
pub trait VhostUserMsgValidator {
    /// Validate message syntax only.
    /// It doesn't validate message semantics such as protocol version number and dependency
    /// on feature flags etc.
    ///
    /// The default accepts everything; types with invariants override this.
    fn is_valid(&self) -> bool {
        true
    }
}
bitflags! {
    /// Common message flags for vhost-user requests and replies.
    pub struct VhostUserHeaderFlag: u32 {
        /// Bits[0..2] is message version number.
        const VERSION = 0x3;
        /// Mark message as reply.
        const REPLY = 0x4;
        /// Sender anticipates a reply message from the peer.
        const NEED_REPLY = 0x8;
        /// Settable flag bits (REPLY | NEED_REPLY); version bits are handled separately.
        const ALL_FLAGS = 0xc;
        /// All reserved bits (everything above the version and flag bits).
        const RESERVED_BITS = !0xf;
    }
}
/// Common message header for vhost-user requests and replies.
/// A vhost-user message consists of 3 header fields and an optional payload. All numbers are in the
/// machine native byte order.
#[allow(safe_packed_borrows)]
#[repr(packed)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub(super) struct VhostUserMsgHeader<R: Req> {
    // request code, stored as its u32 wire value
    request: u32,
    // version bits plus REPLY/NEED_REPLY flags (see VhostUserHeaderFlag)
    flags: u32,
    // size in bytes of the message body + payload following the header
    size: u32,
    // zero-sized marker tying the header to its request-code enum R
    _r: PhantomData<R>,
}
impl<R: Req> VhostUserMsgHeader<R> {
    /// Create a new instance of `VhostUserMsgHeader`.
    pub fn new(request: R, flags: u32, size: u32) -> Self {
        // Default to protocol version 1
        let fl = (flags & VhostUserHeaderFlag::ALL_FLAGS.bits()) | 0x1;
        VhostUserMsgHeader {
            request: request.into(),
            flags: fl,
            size,
            _r: PhantomData,
        }
    }

    /// Get message type.
    ///
    /// NOTE(review): the transmute is only sound when `request` holds a valid
    /// discriminant of the repr(u32) enum R; callers are expected to validate
    /// headers via is_valid() before trusting the result — confirm at call sites.
    pub fn get_code(&self) -> R {
        // It's safe because R is marked as repr(u32).
        unsafe { std::mem::transmute_copy::<u32, R>(&self.request) }
    }

    /// Set message type.
    pub fn set_code(&mut self, request: R) {
        self.request = request.into();
    }

    /// Get message version number.
    pub fn get_version(&self) -> u32 {
        self.flags & 0x3
    }

    /// Set message version number.
    pub fn set_version(&mut self, ver: u32) {
        self.flags &= !0x3;
        self.flags |= ver & 0x3;
    }

    /// Check whether it's a reply message.
    pub fn is_reply(&self) -> bool {
        (self.flags & VhostUserHeaderFlag::REPLY.bits()) != 0
    }

    /// Mark message as reply.
    pub fn set_reply(&mut self, is_reply: bool) {
        if is_reply {
            self.flags |= VhostUserHeaderFlag::REPLY.bits();
        } else {
            self.flags &= !VhostUserHeaderFlag::REPLY.bits();
        }
    }

    /// Check whether reply for this message is requested.
    pub fn is_need_reply(&self) -> bool {
        (self.flags & VhostUserHeaderFlag::NEED_REPLY.bits()) != 0
    }

    /// Mark that reply for this message is needed.
    pub fn set_need_reply(&mut self, need_reply: bool) {
        if need_reply {
            self.flags |= VhostUserHeaderFlag::NEED_REPLY.bits();
        } else {
            self.flags &= !VhostUserHeaderFlag::NEED_REPLY.bits();
        }
    }

    /// Check whether it's the reply message for the request `req`.
    pub fn is_reply_for(&self, req: &VhostUserMsgHeader<R>) -> bool {
        self.is_reply() && !req.is_reply() && self.get_code() == req.get_code()
    }

    /// Get message size.
    pub fn get_size(&self) -> u32 {
        self.size
    }

    /// Set message size.
    pub fn set_size(&mut self, size: u32) {
        self.size = size;
    }
}
impl<R: Req> Default for VhostUserMsgHeader<R> {
    /// An empty header: NOOP request code (0), protocol version 1, no flags,
    /// zero-length body.
    fn default() -> Self {
        VhostUserMsgHeader {
            _r: PhantomData,
            size: 0,
            // version bits set to 1, REPLY/NEED_REPLY cleared
            flags: 0x1,
            request: 0,
        }
    }
}
impl<T: Req> VhostUserMsgValidator for VhostUserMsgHeader<T> {
    /// A header is syntactically valid when the request code is known, the
    /// declared size fits in a message buffer, the version is 1 and no
    /// reserved flag bits are set.
    fn is_valid(&self) -> bool {
        self.get_code().is_valid()
            && self.size as usize <= MAX_MSG_SIZE
            && self.get_version() == 0x1
            && (self.flags & VhostUserHeaderFlag::RESERVED_BITS.bits()) == 0
    }
}
bitflags! {
    /// Transport specific flags in VirtIO feature set defined by vhost-user.
    pub struct VhostUserVirtioFeatures: u64 {
        /// Feature flag for the protocol feature (must be negotiated before
        /// GET/SET_PROTOCOL_FEATURES may be used).
        const PROTOCOL_FEATURES = 0x4000_0000;
    }
}
bitflags! {
    /// Vhost-user protocol feature flags.
    pub struct VhostUserProtocolFeatures: u64 {
        /// Support multiple queues.
        const MQ = 0x0000_0001;
        /// Support logging through shared memory fd.
        const LOG_SHMFD = 0x0000_0002;
        /// Support broadcasting fake RARP packet.
        const RARP = 0x0000_0004;
        /// Support sending reply messages for requests with NEED_REPLY flag set.
        const REPLY_ACK = 0x0000_0008;
        /// Support setting MTU for virtio-net devices.
        const MTU = 0x0000_0010;
        /// Allow the slave to send requests to the master by an optional communication channel.
        const SLAVE_REQ = 0x0000_0020;
        /// Support setting slave endian by SET_VRING_ENDIAN.
        const CROSS_ENDIAN = 0x0000_0040;
        /// Support crypto operations.
        const CRYPTO_SESSION = 0x0000_0080;
        /// Support sending userfault_fd from slaves to masters.
        const PAGEFAULT = 0x0000_0100;
        /// Support Virtio device configuration.
        const CONFIG = 0x0000_0200;
        /// Allow the slave to send fds (at most 8 descriptors in each message) to the master.
        const SLAVE_SEND_FD = 0x0000_0400;
        /// Allow the slave to register a host notifier.
        const HOST_NOTIFIER = 0x0000_0800;
        /// Support inflight shmfd.
        const INFLIGHT_SHMFD = 0x0000_1000;
    }
}
/// A generic message to encapsulate a 64-bit value.
#[repr(packed)]
#[derive(Default)]
pub struct VhostUserU64 {
    /// The encapsulated 64-bit common value.
    pub value: u64,
}

impl VhostUserU64 {
    /// Create a new instance.
    pub fn new(value: u64) -> Self {
        VhostUserU64 { value }
    }
}

// Any 64-bit value is syntactically valid; semantics are checked by callers.
impl VhostUserMsgValidator for VhostUserU64 {}
/// Memory region descriptor for the SET_MEM_TABLE request.
#[repr(packed)]
#[derive(Default)]
pub struct VhostUserMemory {
    /// Number of memory regions in the payload.
    pub num_regions: u32,
    /// Padding for alignment.
    pub padding1: u32,
}

impl VhostUserMemory {
    /// Create a new instance.
    pub fn new(cnt: u32) -> Self {
        VhostUserMemory {
            num_regions: cnt,
            padding1: 0,
        }
    }
}

impl VhostUserMsgValidator for VhostUserMemory {
    /// Valid when padding is zero and the region count is in
    /// 1..=MAX_ATTACHED_FD_ENTRIES (one attached fd per region).
    #[allow(clippy::if_same_then_else)]
    fn is_valid(&self) -> bool {
        if self.padding1 != 0 {
            return false;
        } else if self.num_regions == 0 || self.num_regions > MAX_ATTACHED_FD_ENTRIES as u32 {
            return false;
        }
        true
    }
}
/// Memory region descriptors as payload for the SET_MEM_TABLE request.
#[repr(packed)]
#[derive(Default, Clone, Copy)]
pub struct VhostUserMemoryRegion {
    /// Guest physical address of the memory region.
    pub guest_phys_addr: u64,
    /// Size of the memory region.
    pub memory_size: u64,
    /// Virtual address in the current process.
    pub user_addr: u64,
    /// Offset where region starts in the mapped memory.
    pub mmap_offset: u64,
}

impl VhostUserMemoryRegion {
    /// Create a new instance.
    pub fn new(guest_phys_addr: u64, memory_size: u64, user_addr: u64, mmap_offset: u64) -> Self {
        VhostUserMemoryRegion {
            guest_phys_addr,
            memory_size,
            user_addr,
            mmap_offset,
        }
    }
}

impl VhostUserMsgValidator for VhostUserMemoryRegion {
    /// Valid when the region is non-empty and none of the address/offset
    /// ranges overflow u64 when extended by the region size.
    fn is_valid(&self) -> bool {
        if self.memory_size == 0
            || self.guest_phys_addr.checked_add(self.memory_size).is_none()
            || self.user_addr.checked_add(self.memory_size).is_none()
            || self.mmap_offset.checked_add(self.memory_size).is_none()
        {
            return false;
        }
        true
    }
}

/// Payload of the VhostUserMemory message.
pub type VhostUserMemoryPayload = Vec<VhostUserMemoryRegion>;
/// Vring state descriptor.
#[repr(packed)]
#[derive(Default)]
pub struct VhostUserVringState {
    /// Vring index.
    pub index: u32,
    /// A common 32bit value to encapsulate vring state etc.
    pub num: u32,
}

impl VhostUserVringState {
    /// Create a new instance.
    pub fn new(index: u32, num: u32) -> Self {
        VhostUserVringState { index, num }
    }
}

// Any index/num pair is syntactically valid.
impl VhostUserMsgValidator for VhostUserVringState {}
bitflags! {
    /// Flags for vring address.
    pub struct VhostUserVringAddrFlags: u32 {
        /// Support log of vring operations.
        /// Modifications to "used" vring should be logged.
        const VHOST_VRING_F_LOG = 0x1;
    }
}

/// Vring address descriptor.
#[repr(packed)]
#[derive(Default)]
pub struct VhostUserVringAddr {
    /// Vring index.
    pub index: u32,
    /// Vring flags defined by VhostUserVringAddrFlags.
    pub flags: u32,
    /// Ring address of the vring descriptor table.
    pub descriptor: u64,
    /// Ring address of the vring used ring.
    pub used: u64,
    /// Ring address of the vring available ring.
    pub available: u64,
    /// Guest address for logging.
    pub log: u64,
}

impl VhostUserVringAddr {
    /// Create a new instance.
    pub fn new(
        index: u32,
        flags: VhostUserVringAddrFlags,
        descriptor: u64,
        used: u64,
        available: u64,
        log: u64,
    ) -> Self {
        VhostUserVringAddr {
            index,
            flags: flags.bits(),
            descriptor,
            used,
            available,
            log,
        }
    }

    /// Create a new instance from `VringConfigData`.
    ///
    /// A missing log address defaults to 0.
    #[cfg_attr(feature = "cargo-clippy", allow(clippy::identity_conversion))]
    pub fn from_config_data(index: u32, config_data: &VringConfigData) -> Self {
        let log_addr = config_data.log_addr.unwrap_or(0);
        VhostUserVringAddr {
            index,
            flags: config_data.flags,
            descriptor: config_data.desc_table_addr,
            used: config_data.used_ring_addr,
            available: config_data.avail_ring_addr,
            log: log_addr,
        }
    }
}

impl VhostUserMsgValidator for VhostUserVringAddr {
    /// Valid when only known flags are set and each ring address meets its
    /// alignment requirement (descriptor 16 bytes, available 2, used 4).
    #[allow(clippy::if_same_then_else)]
    fn is_valid(&self) -> bool {
        if (self.flags & !VhostUserVringAddrFlags::all().bits()) != 0 {
            return false;
        } else if self.descriptor & 0xf != 0 {
            return false;
        } else if self.available & 0x1 != 0 {
            return false;
        } else if self.used & 0x3 != 0 {
            return false;
        }
        true
    }
}
bitflags! {
    /// Flags for the device configuration message.
    pub struct VhostUserConfigFlags: u32 {
        /// Vhost master messages used for writeable fields.
        /// NOTE: this is a zero value, i.e. the absence of LIVE_MIGRATION.
        const WRITABLE = 0x0;
        /// Vhost master messages used for live migration.
        const LIVE_MIGRATION = 0x1;
    }
}

/// Message to read/write device configuration space.
#[repr(packed)]
#[derive(Default)]
pub struct VhostUserConfig {
    /// Offset of virtio device's configuration space.
    pub offset: u32,
    /// Configuration space access size in bytes.
    pub size: u32,
    /// Flags for the device configuration operation.
    pub flags: u32,
}

impl VhostUserConfig {
    /// Create a new instance.
    pub fn new(offset: u32, size: u32, flags: VhostUserConfigFlags) -> Self {
        VhostUserConfig {
            offset,
            size,
            flags: flags.bits(),
        }
    }
}

impl VhostUserMsgValidator for VhostUserConfig {
    /// Valid when only known flags are set and [offset, offset + size) is a
    /// non-empty range inside the configuration space.
    /// (The sum cannot overflow: short-circuiting already bounds both
    /// operands by VHOST_USER_CONFIG_SIZE.)
    #[allow(clippy::if_same_then_else)]
    fn is_valid(&self) -> bool {
        if (self.flags & !VhostUserConfigFlags::all().bits()) != 0 {
            return false;
        } else if self.offset >= VHOST_USER_CONFIG_SIZE
            || self.size == 0
            || self.size > VHOST_USER_CONFIG_SIZE
            || self.size + self.offset > VHOST_USER_CONFIG_SIZE
        {
            return false;
        }
        true
    }
}

/// Payload for the VhostUserConfig message.
pub type VhostUserConfigPayload = Vec<u8>;
/*
* TODO: support dirty log, live migration and IOTLB operations.
#[repr(packed)]
pub struct VhostUserVringArea {
pub index: u32,
pub flags: u32,
pub size: u64,
pub offset: u64,
}
#[repr(packed)]
pub struct VhostUserLog {
pub size: u64,
pub offset: u64,
}
#[repr(packed)]
pub struct VhostUserIotlb {
pub iova: u64,
pub size: u64,
pub user_addr: u64,
pub permission: u8,
pub optype: u8,
}
*/
bitflags! {
    #[derive(Default)]
    /// Flags for virtio-fs slave messages.
    pub struct VhostUserFSSlaveMsgFlags: u64 {
        /// Empty permission.
        const EMPTY = 0x0;
        /// Read permission.
        const MAP_R = 0x1;
        /// Write permission.
        const MAP_W = 0x2;
    }
}

/// Max entries in one virtio-fs slave request.
pub const VHOST_USER_FS_SLAVE_ENTRIES: usize = 8;

/// Slave request message to update the MMIO window.
///
/// Each of the parallel arrays describes one map/unmap entry; unused entries
/// are zeroed.
#[repr(packed)]
#[derive(Default)]
pub struct VhostUserFSSlaveMsg {
    /// Offsets into the backing file for each entry.
    pub fd_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
    /// Offsets into the DAX cache window for each entry.
    pub cache_offset: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
    /// Size of region to map.
    pub len: [u64; VHOST_USER_FS_SLAVE_ENTRIES],
    /// Flags for the mmap operation
    pub flags: [VhostUserFSSlaveMsgFlags; VHOST_USER_FS_SLAVE_ENTRIES],
}

impl VhostUserMsgValidator for VhostUserFSSlaveMsg {
    /// Valid when every entry uses only known flags and neither its file nor
    /// cache range overflows u64.
    fn is_valid(&self) -> bool {
        for i in 0..VHOST_USER_FS_SLAVE_ENTRIES {
            // The braces force a copy out of the packed struct, avoiding an
            // unaligned reference to the flags field.
            if ({ self.flags[i] }.bits() & !VhostUserFSSlaveMsgFlags::all().bits()) != 0
                || self.fd_offset[i].checked_add(self.len[i]).is_none()
                || self.cache_offset[i].checked_add(self.len[i]).is_none()
            {
                return false;
            }
        }
        true
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::mem;
#[test]
fn check_request_code() {
    // NOOP and the MAX_CMD sentinel are both outside the valid range.
    let code = MasterReq::NOOP;
    assert!(!code.is_valid());
    let code = MasterReq::MAX_CMD;
    assert!(!code.is_valid());
    let code = MasterReq::GET_FEATURES;
    assert!(code.is_valid());
}
#[test]
fn msg_header_ops() {
let mut hdr = VhostUserMsgHeader::new(MasterReq::GET_FEATURES, 0, 0x100);
assert_eq!(hdr.get_code(), MasterReq::GET_FEATURES);
hdr.set_code(MasterReq::SET_FEATURES);
assert_eq!(hdr.get_code(), MasterReq::SET_FEATURES);
assert_eq!(hdr.get_version(), 0x1);
assert_eq!(hdr.is_reply(), false);
hdr.set_reply(true);
assert_eq!(hdr.is_reply(), true);
hdr.set_reply(false);
assert_eq!(hdr.is_need_reply(), false);
hdr.set_need_reply(true);
assert_eq!(hdr.is_need_reply(), true);
hdr.set_need_reply(false);
assert_eq!(hdr.get_size(), 0x100);
hdr.set_size(0x200);
assert_eq!(hdr.get_size(), 0x200);
assert_eq!(hdr.is_need_reply(), false);
assert_eq!(hdr.is_reply(), false);
assert_eq!(hdr.get_version(), 0x1);
// Check message length
assert!(hdr.is_valid());
hdr.set_size(0x2000);
assert!(!hdr.is_valid());
hdr.set_size(0x100);
assert_eq!(hdr.get_size(), 0x100);
assert!(hdr.is_valid());
hdr.set_size((MAX_MSG_SIZE - mem::size_of::<VhostUserMsgHeader<MasterReq>>()) as u32);
assert!(hdr.is_valid());
hdr.set_size(0x0);
assert!(hdr.is_valid());
// Check version
hdr.set_version(0x0);
assert!(!hdr.is_valid());
hdr.set_version(0x2);
assert!(!hdr.is_valid());
hdr.set_version(0x1);
assert!(hdr.is_valid());
}
#[test]
fn check_user_memory() {
let mut msg = VhostUserMemory::new(1);
assert!(msg.is_valid());
msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32;
assert!(msg.is_valid());
msg.num_regions += 1;
assert!(!msg.is_valid());
msg.num_regions = 0xFFFFFFFF;
assert!(!msg.is_valid());
msg.num_regions = MAX_ATTACHED_FD_ENTRIES as u32;
msg.padding1 = 1;
assert!(!msg.is_valid());
}
#[test]
fn check_user_memory_region() {
let mut msg = VhostUserMemoryRegion {
guest_phys_addr: 0,
memory_size: 0x1000,
user_addr: 0,
mmap_offset: 0,
};
assert!(msg.is_valid());
msg.guest_phys_addr = 0xFFFFFFFFFFFFEFFF;
assert!(msg.is_valid());
msg.guest_phys_addr = 0xFFFFFFFFFFFFF000;
assert!(!msg.is_valid());
msg.guest_phys_addr = 0xFFFFFFFFFFFF0000;
msg.memory_size = 0;
assert!(!msg.is_valid());
}
#[test]
fn check_user_vring_addr() {
let mut msg =
VhostUserVringAddr::new(0, VhostUserVringAddrFlags::all(), 0x0, 0x0, 0x0, 0x0);
assert!(msg.is_valid());
msg.descriptor = 1;
assert!(!msg.is_valid());
msg.descriptor = 0;
msg.available = 1;
assert!(!msg.is_valid());
msg.available = 0;
msg.used = 1;
assert!(!msg.is_valid());
msg.used = 0;
msg.flags |= 0x80000000;
assert!(!msg.is_valid());
msg.flags &= !0x80000000;
}
#[test]
#[ignore]
fn check_user_config_msg() {
let mut msg = VhostUserConfig::new(
VHOST_USER_CONFIG_OFFSET,
VHOST_USER_CONFIG_SIZE - VHOST_USER_CONFIG_OFFSET,
VhostUserConfigFlags::WRITABLE,
);
assert!(msg.is_valid());
msg.size = 0;
assert!(!msg.is_valid());
msg.size = 1;
assert!(msg.is_valid());
msg.offset = 0;
assert!(!msg.is_valid());
msg.offset = VHOST_USER_CONFIG_SIZE;
assert!(!msg.is_valid());
msg.offset = VHOST_USER_CONFIG_SIZE - 1;
assert!(msg.is_valid());
msg.size = 2;
assert!(!msg.is_valid());
msg.size = 1;
msg.flags |= VhostUserConfigFlags::LIVE_MIGRATION.bits();
assert!(msg.is_valid());
msg.flags |= 0x4;
assert!(!msg.is_valid());
}
}

View File

@ -1,260 +0,0 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! The protocol for vhost-user is based on the existing implementation of vhost for the Linux
//! Kernel. The protocol defines two sides of the communication, master and slave. Master is
//! the application that shares its virtqueues. Slave is the consumer of the virtqueues.
//!
//! The communication channel between the master and the slave includes two sub channels. One is
//! used to send requests from the master to the slave and optional replies from the slave to the
//! master. This sub channel is created on master startup by connecting to the slave service
//! endpoint. The other is used to send requests from the slave to the master and optional replies
//! from the master to the slave. This sub channel is created by the master issuing a
//! VHOST_USER_SET_SLAVE_REQ_FD request to the slave with an auxiliary file descriptor.
//!
//! Unix domain socket is used as the underlying communication channel because the master needs to
//! send file descriptors to the slave.
//!
//! Most messages that can be sent via the Unix domain socket implementing vhost-user have an
//! equivalent ioctl to the kernel implementation.
use libc;
use std::io::Error as IOError;
mod connection;
pub mod message;
pub use self::connection::Listener;
#[cfg(feature = "vhost-user-master")]
mod master;
#[cfg(feature = "vhost-user-master")]
pub use self::master::{Master, VhostUserMaster};
#[cfg(feature = "vhost-user-master")]
mod master_req_handler;
#[cfg(feature = "vhost-user-master")]
pub use self::master_req_handler::{MasterReqHandler, VhostUserMasterReqHandler};
#[cfg(feature = "vhost-user-slave")]
mod slave;
#[cfg(feature = "vhost-user-slave")]
pub use self::slave::SlaveListener;
#[cfg(feature = "vhost-user-slave")]
mod slave_req_handler;
#[cfg(feature = "vhost-user-slave")]
pub use self::slave_req_handler::{SlaveReqHandler, VhostUserSlaveReqHandler};
#[cfg(feature = "vhost-user-slave")]
mod slave_fs_cache;
#[cfg(feature = "vhost-user-slave")]
pub use self::slave_fs_cache::SlaveFsCacheReq;
pub mod sock_ctrl_msg;
/// Errors for vhost-user operations
#[derive(Debug)]
pub enum Error {
    /// Invalid parameters.
    InvalidParam,
    /// Unsupported operation because the required protocol feature hasn't been negotiated.
    InvalidOperation,
    /// Invalid message format, flag or content.
    InvalidMessage,
    /// Only part of a message has been sent or received successfully.
    PartialMessage,
    /// Message is too large.
    OversizedMsg,
    /// Fd array in question is too big or too small.
    IncorrectFds,
    /// Can't connect to peer.
    SocketConnect(std::io::Error),
    /// Generic socket errors.
    SocketError(std::io::Error),
    /// The socket is broken or has been closed.
    SocketBroken(std::io::Error),
    /// Should retry the socket operation again.
    SocketRetry(std::io::Error),
    /// Failure from the slave side.
    SlaveInternalError,
    /// Failure from the master side.
    MasterInternalError,
    /// Virtio/protocol features mismatch.
    FeatureMismatch,
    /// Error from request handler.
    ReqHandlerError(IOError),
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Error::InvalidParam => write!(f, "invalid parameters"),
Error::InvalidOperation => write!(f, "invalid operation"),
Error::InvalidMessage => write!(f, "invalid message"),
Error::PartialMessage => write!(f, "partial message"),
Error::OversizedMsg => write!(f, "oversized message"),
Error::IncorrectFds => write!(f, "wrong number of attached fds"),
Error::SocketError(e) => write!(f, "socket error: {}", e),
Error::SocketConnect(e) => write!(f, "can't connect to peer: {}", e),
Error::SocketBroken(e) => write!(f, "socket is broken: {}", e),
Error::SocketRetry(e) => write!(f, "temporary socket error: {}", e),
Error::SlaveInternalError => write!(f, "slave internal error"),
Error::MasterInternalError => write!(f, "Master internal error"),
Error::FeatureMismatch => write!(f, "virtio/protocol features mismatch"),
Error::ReqHandlerError(e) => write!(f, "handler failed to handle request: {}", e),
}
}
}
impl Error {
/// Determine whether to rebuild the underline communication channel.
pub fn should_reconnect(&self) -> bool {
match *self {
// Should reconnect because it may be caused by temporary network errors.
Error::PartialMessage => true,
// Should reconnect because the underline socket is broken.
Error::SocketBroken(_) => true,
// Slave internal error, hope it recovers on reconnect.
Error::SlaveInternalError => true,
// Master internal error, hope it recovers on reconnect.
Error::MasterInternalError => true,
// Should just retry the IO operation instead of rebuilding the underline connection.
Error::SocketRetry(_) => false,
Error::InvalidParam | Error::InvalidOperation => false,
Error::InvalidMessage | Error::IncorrectFds | Error::OversizedMsg => false,
Error::SocketError(_) | Error::SocketConnect(_) => false,
Error::FeatureMismatch => false,
Error::ReqHandlerError(_) => false,
}
}
}
impl std::convert::From<vmm_sys_util::errno::Error> for Error {
/// Convert raw socket errors into meaningful vhost-user errors.
///
/// The vmm_sys_util::errno::Error is a simple wrapper over the raw errno, which doesn't means much
/// to the vhost-user connection manager. So convert it into meaningful errors to simplify
/// the connection manager logic.
///
/// # Return:
/// * - Error::SocketRetry: temporary error caused by signals or short of resources.
/// * - Error::SocketBroken: the underline socket is broken.
/// * - Error::SocketError: other socket related errors.
#[allow(unreachable_patterns)] // EWOULDBLOCK equals to EGAIN on linux
fn from(err: vmm_sys_util::errno::Error) -> Self {
match err.errno() {
// The socket is marked nonblocking and the requested operation would block.
libc::EAGAIN => Error::SocketRetry(IOError::from_raw_os_error(libc::EAGAIN)),
// The socket is marked nonblocking and the requested operation would block.
libc::EWOULDBLOCK => Error::SocketRetry(IOError::from_raw_os_error(libc::EWOULDBLOCK)),
// A signal occurred before any data was transmitted
libc::EINTR => Error::SocketRetry(IOError::from_raw_os_error(libc::EINTR)),
// The output queue for a network interface was full. This generally indicates
// that the interface has stopped sending, but may be caused by transient congestion.
libc::ENOBUFS => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOBUFS)),
// No memory available.
libc::ENOMEM => Error::SocketRetry(IOError::from_raw_os_error(libc::ENOMEM)),
// Connection reset by peer.
libc::ECONNRESET => Error::SocketBroken(IOError::from_raw_os_error(libc::ECONNRESET)),
// The local end has been shut down on a connection oriented socket. In this case the
// process will also receive a SIGPIPE unless MSG_NOSIGNAL is set.
libc::EPIPE => Error::SocketBroken(IOError::from_raw_os_error(libc::EPIPE)),
// Write permission is denied on the destination socket file, or search permission is
// denied for one of the directories the path prefix.
libc::EACCES => Error::SocketConnect(IOError::from_raw_os_error(libc::EACCES)),
// Catch all other errors
e => Error::SocketError(IOError::from_raw_os_error(e)),
}
}
}
/// Result of vhost-user operations, failing with the crate-level [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Result of request handler callbacks, failing with a plain `io::Error`.
pub type HandlerResult<T> = std::result::Result<T, IOError>;
#[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))]
mod dummy_slave;
#[cfg(all(test, feature = "vhost-user-master", feature = "vhost-user-slave"))]
mod tests {
    use super::dummy_slave::{DummySlaveReqHandler, VIRTIO_FEATURES};
    use super::message::*;
    use super::*;
    use crate::backend::VhostBackend;
    use std::sync::{Arc, Barrier, Mutex};
    use std::thread;
    // Spawn a slave listener on `path` and connect a master endpoint to it,
    // returning both halves of the channel.
    // NOTE(review): tests use fixed paths under /tmp, so concurrent test
    // runs can collide — consider tempfile-based socket paths.
    fn create_slave<S: VhostUserSlaveReqHandler>(
        path: &str,
        backend: Arc<Mutex<S>>,
    ) -> (Master, SlaveReqHandler<S>) {
        let mut slave_listener = SlaveListener::new(path, true, backend).unwrap();
        let master = Master::connect(path, 1).unwrap();
        (master, slave_listener.accept().unwrap().unwrap())
    }
    #[test]
    fn create_dummy_slave() {
        // A second SET_OWNER on an already-owned slave must be rejected.
        let mut slave = DummySlaveReqHandler::new();
        slave.set_owner().unwrap();
        assert!(slave.set_owner().is_err());
    }
    #[test]
    fn test_set_owner() {
        let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
        let (mut master, mut slave) =
            create_slave("/tmp/vhost_user_lib_unit_test_owner", slave_be.clone());
        assert_eq!(slave_be.lock().unwrap().owned, false);
        master.set_owner().unwrap();
        // Each handle_request() call processes exactly one master request.
        slave.handle_request().unwrap();
        assert_eq!(slave_be.lock().unwrap().owned, true);
        master.set_owner().unwrap();
        assert!(slave.handle_request().is_err());
        assert_eq!(slave_be.lock().unwrap().owned, true);
    }
    #[test]
    fn test_set_features() {
        // The barrier keeps the slave thread alive until the master has
        // finished the whole negotiation sequence.
        let mbar = Arc::new(Barrier::new(2));
        let sbar = mbar.clone();
        let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
        let (mut master, mut slave) =
            create_slave("/tmp/vhost_user_lib_unit_test_feature", slave_be.clone());
        thread::spawn(move || {
            // One handle_request() per master call below, in order:
            // SET_OWNER, GET/SET_FEATURES, GET/SET_PROTOCOL_FEATURES.
            slave.handle_request().unwrap();
            assert_eq!(slave_be.lock().unwrap().owned, true);
            slave.handle_request().unwrap();
            slave.handle_request().unwrap();
            assert_eq!(
                slave_be.lock().unwrap().acked_features,
                VIRTIO_FEATURES & !0x1
            );
            slave.handle_request().unwrap();
            slave.handle_request().unwrap();
            assert_eq!(
                slave_be.lock().unwrap().acked_protocol_features,
                VhostUserProtocolFeatures::all().bits()
            );
            sbar.wait();
        });
        master.set_owner().unwrap();
        // set virtio features
        let features = master.get_features().unwrap();
        assert_eq!(features, VIRTIO_FEATURES);
        master.set_features(VIRTIO_FEATURES & !0x1).unwrap();
        // set vhost protocol features
        let features = master.get_protocol_features().unwrap();
        assert_eq!(features.bits(), VhostUserProtocolFeatures::all().bits());
        master.set_protocol_features(features).unwrap();
        mbar.wait();
    }
}

View File

@ -1,48 +0,0 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Traits and Structs for vhost-user slave.
use std::sync::{Arc, Mutex};
use super::connection::{Endpoint, Listener};
use super::message::*;
use super::{Result, SlaveReqHandler, VhostUserSlaveReqHandler};
/// Vhost-user slave side connection listener.
pub struct SlaveListener<S: VhostUserSlaveReqHandler> {
    // Unix domain socket listener waiting for master connections.
    listener: Listener,
    // Backend handed to the accepted connection. `accept` takes it out, so
    // only one SlaveReqHandler can ever be built from this listener.
    backend: Option<Arc<Mutex<S>>>,
}
/// Sets up a listener for incoming master connections, and handles construction
/// of a Slave on success.
impl<S: VhostUserSlaveReqHandler> SlaveListener<S> {
    /// Create a unix domain socket listener for incoming master connections.
    ///
    /// Be careful: the file at `path` will be unlinked first when `unlink`
    /// is true.
    pub fn new(path: &str, unlink: bool, backend: Arc<Mutex<S>>) -> Result<Self> {
        let listener = Listener::new(path, unlink)?;
        Ok(SlaveListener {
            listener,
            backend: Some(backend),
        })
    }
    /// Accept an incoming connection from the master, returning
    /// `Some(SlaveReqHandler)` on success, or `None` if the socket is
    /// nonblocking and no incoming connection was detected.
    ///
    /// The backend is moved into the handler built here, so this listener
    /// serves exactly one master connection.
    pub fn accept(&mut self) -> Result<Option<SlaveReqHandler<S>>> {
        let stream = match self.listener.accept()? {
            Some(stream) => stream,
            None => return Ok(None),
        };
        let endpoint = Endpoint::<MasterReq>::from_stream(stream);
        Ok(Some(SlaveReqHandler::new(
            endpoint,
            self.backend.take().unwrap(),
        )))
    }
    /// Change blocking status on the listener.
    pub fn set_nonblocking(&self, block: bool) -> Result<()> {
        self.listener.set_nonblocking(block)
    }
}

View File

@ -1,94 +0,0 @@
// Copyright (C) 2020 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use super::connection::Endpoint;
use super::message::*;
use super::{Error, HandlerResult, Result, VhostUserMasterReqHandler};
use std::io;
use std::mem;
use std::os::unix::io::RawFd;
use std::os::unix::net::UnixStream;
use std::sync::{Arc, Mutex};
// State shared between all clones of a SlaveFsCacheReq; the mutex around it
// serializes concurrent requests on the one socket.
struct SlaveFsCacheReqInternal {
    sock: Endpoint<SlaveReq>,
}
/// A vhost-user slave endpoint which sends fs cache requests to the master
#[derive(Clone)]
pub struct SlaveFsCacheReq {
    // underlying Unix domain socket for communication
    node: Arc<Mutex<SlaveFsCacheReqInternal>>,
    // whether the endpoint has encountered any failure
    // NOTE(review): unlike `node`, this is cloned by value, so `set_failed`
    // on one clone is not observed by the others — confirm this is intended.
    error: Option<i32>,
}
impl SlaveFsCacheReq {
    // Wrap an already-constructed endpoint into a fresh, non-failed handle.
    fn new(ep: Endpoint<SlaveReq>) -> Self {
        SlaveFsCacheReq {
            node: Arc::new(Mutex::new(SlaveFsCacheReqInternal { sock: ep })),
            error: None,
        }
    }
    /// Create a new instance from a Unix stream connected to the master.
    pub fn from_stream(sock: UnixStream) -> Self {
        Self::new(Endpoint::<SlaveReq>::from_stream(sock))
    }
    /// Send a slave request of type `flags` carrying `fs` (plus optional
    /// attached fds) to the master, then wait for its acknowledgement.
    fn send_message(
        &mut self,
        flags: SlaveReq,
        fs: &VhostUserFSSlaveMsg,
        fds: Option<&[RawFd]>,
    ) -> Result<()> {
        self.check_state()?;
        let len = mem::size_of::<VhostUserFSSlaveMsg>();
        let mut hdr = VhostUserMsgHeader::new(flags, 0, len as u32);
        // Always request an explicit ack so failures surface here.
        hdr.set_need_reply(true);
        self.node.lock().unwrap().sock.send_message(&hdr, fs, fds)?;
        self.wait_for_ack(&hdr)
    }
    /// Receive the master's u64 reply for `hdr`; a non-zero payload means
    /// the master failed to handle the request.
    fn wait_for_ack(&mut self, hdr: &VhostUserMsgHeader<SlaveReq>) -> Result<()> {
        self.check_state()?;
        let (reply, body, rfds) = self.node.lock().unwrap().sock.recv_body::<VhostUserU64>()?;
        // `hdr` is already a reference — pass it directly instead of the
        // original needless `&hdr` double borrow. No fds are expected on an
        // ack, so close any that arrived before rejecting the message.
        if !reply.is_reply_for(hdr) || rfds.is_some() || !body.is_valid() {
            Endpoint::<SlaveReq>::close_rfds(rfds);
            return Err(Error::InvalidMessage);
        }
        if body.value != 0 {
            return Err(Error::MasterInternalError);
        }
        Ok(())
    }
    // Fail fast when this handle was already marked as broken.
    fn check_state(&self) -> Result<()> {
        match self.error {
            Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
            None => Ok(()),
        }
    }
    /// Mark endpoint as failed with specified error code.
    pub fn set_failed(&mut self, error: i32) {
        self.error = Some(error);
    }
}
impl VhostUserMasterReqHandler for SlaveFsCacheReq {
    /// Handle virtio-fs map file requests from the slave.
    ///
    /// Any vhost-user transport error is flattened into an `io::Error` of
    /// kind `Other` carrying the error's display text.
    fn fs_slave_map(&mut self, fs: &VhostUserFSSlaveMsg, fd: RawFd) -> HandlerResult<()> {
        // map_err replaces the original `.or_else(|e| Err(...))`
        // (clippy::bind_instead_of_map); the produced error is identical.
        self.send_message(SlaveReq::FS_MAP, fs, Some(&[fd]))
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
    }
    /// Handle virtio-fs unmap file requests from the slave.
    fn fs_slave_unmap(&mut self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<()> {
        self.send_message(SlaveReq::FS_UNMAP, fs, None)
            .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))
    }
}

View File

@ -1,614 +0,0 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Traits and Structs to handle vhost-user requests from the master to the slave.
use std::mem;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::UnixStream;
use std::slice;
use std::sync::{Arc, Mutex};
use super::connection::Endpoint;
use super::message::*;
use super::slave_fs_cache::SlaveFsCacheReq;
use super::{Error, Result};
/// Trait to handle vhost-user requests from the master to the slave.
///
/// Each method corresponds to one master request code; `SlaveReqHandler`
/// dispatches decoded messages to these callbacks.
#[allow(missing_docs)]
pub trait VhostUserSlaveReqHandler {
    /// Handle VHOST_USER_SET_OWNER.
    fn set_owner(&mut self) -> Result<()>;
    /// Handle VHOST_USER_RESET_OWNER.
    fn reset_owner(&mut self) -> Result<()>;
    /// Handle VHOST_USER_GET_FEATURES: report supported virtio features.
    fn get_features(&mut self) -> Result<u64>;
    /// Handle VHOST_USER_SET_FEATURES: record features acked by the master.
    fn set_features(&mut self, features: u64) -> Result<()>;
    /// Handle VHOST_USER_SET_MEM_TABLE with pre-validated regions and one
    /// fd per region.
    fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], fds: &[RawFd]) -> Result<()>;
    /// Handle VHOST_USER_SET_VRING_NUM.
    fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
    /// Handle VHOST_USER_SET_VRING_ADDR.
    fn set_vring_addr(
        &mut self,
        index: u32,
        flags: VhostUserVringAddrFlags,
        descriptor: u64,
        used: u64,
        available: u64,
        log: u64,
    ) -> Result<()>;
    /// Handle VHOST_USER_SET_VRING_BASE.
    fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
    /// Handle VHOST_USER_GET_VRING_BASE.
    fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
    /// Handle VHOST_USER_SET_VRING_KICK; `fd` is None when the master set
    /// the "invalid fd" flag (polling mode).
    fn set_vring_kick(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>;
    /// Handle VHOST_USER_SET_VRING_CALL; `fd` semantics as for kick.
    fn set_vring_call(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>;
    /// Handle VHOST_USER_SET_VRING_ERR; `fd` semantics as for kick.
    fn set_vring_err(&mut self, index: u8, fd: Option<RawFd>) -> Result<()>;
    /// Handle VHOST_USER_GET_PROTOCOL_FEATURES.
    fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
    /// Handle VHOST_USER_SET_PROTOCOL_FEATURES.
    fn set_protocol_features(&mut self, features: u64) -> Result<()>;
    /// Handle VHOST_USER_GET_QUEUE_NUM (requires the MQ protocol feature).
    fn get_queue_num(&mut self) -> Result<u64>;
    /// Handle VHOST_USER_SET_VRING_ENABLE.
    fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
    /// Handle VHOST_USER_GET_CONFIG: return the requested slice of the
    /// device configuration space (must be exactly `size` bytes on success).
    fn get_config(
        &mut self,
        offset: u32,
        size: u32,
        flags: VhostUserConfigFlags,
    ) -> Result<Vec<u8>>;
    /// Handle VHOST_USER_SET_CONFIG.
    fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
    /// Handle VHOST_USER_SET_SLAVE_REQ_FD; default implementation drops the
    /// slave-to-master channel.
    fn set_slave_req_fd(&mut self, _vu_req: SlaveFsCacheReq) {}
}
/// A vhost-user slave endpoint which relays all received requests from the
/// master to the virtio backend device object.
///
/// The lifetime of the SlaveReqHandler object should be the same as the underlying Unix Domain
/// Socket, so it gets simpler to recover from disconnect.
pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> {
    // underlying Unix domain socket for communication
    main_sock: Endpoint<MasterReq>,
    // the vhost-user backend device object
    backend: Arc<Mutex<S>>,
    // virtio features reported by the backend (cached on GET_FEATURES)
    virtio_features: u64,
    // virtio features acked by the master via SET_FEATURES
    acked_virtio_features: u64,
    // protocol features reported by the backend (cached on GET_PROTOCOL_FEATURES)
    protocol_features: VhostUserProtocolFeatures,
    // protocol features acked by the master via SET_PROTOCOL_FEATURES
    acked_protocol_features: u64,
    // sending ack for messages without payload
    reply_ack_enabled: bool,
    // whether the endpoint has encountered any failure
    error: Option<i32>,
}
impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
/// Create a vhost-user slave endpoint.
pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<Mutex<S>>) -> Self {
SlaveReqHandler {
main_sock,
backend,
virtio_features: 0,
acked_virtio_features: 0,
protocol_features: VhostUserProtocolFeatures::empty(),
acked_protocol_features: 0,
reply_ack_enabled: false,
error: None,
}
}
/// Create a new vhost-user slave endpoint.
///
/// # Arguments
/// * - `path` - path of Unix domain socket listener to connect to
/// * - `backend` - handler for requests from the master to the slave
pub fn connect(path: &str, backend: Arc<Mutex<S>>) -> Result<Self> {
Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend))
}
/// Mark endpoint as failed with specified error code.
pub fn set_failed(&mut self, error: i32) {
self.error = Some(error);
}
/// Receive and handle one incoming request message from the master.
/// The caller needs to:
/// . serialize calls to this function
/// . decide what to do when error happens
/// . optional recover from failure
pub fn handle_request(&mut self) -> Result<()> {
// Return error if the endpoint is already in failed state.
self.check_state()?;
// The underlying communication channel is a Unix domain socket in
// stream mode, and recvmsg() is a little tricky here. To successfully
// receive attached file descriptors, we need to receive messages and
// corresponding attached file descriptors in this way:
// . recv messsage header and optional attached file
// . validate message header
// . recv optional message body and payload according size field in
// message header
// . validate message body and optional payload
let (hdr, rfds) = self.main_sock.recv_header()?;
let rfds = self.check_attached_rfds(&hdr, rfds)?;
let (size, buf) = match hdr.get_size() {
0 => (0, vec![0u8; 0]),
len => {
let (size2, rbuf) = self.main_sock.recv_data(len as usize)?;
if size2 != len as usize {
return Err(Error::InvalidMessage);
}
(size2, rbuf)
}
};
match hdr.get_code() {
MasterReq::SET_OWNER => {
self.check_request_size(&hdr, size, 0)?;
self.backend.lock().unwrap().set_owner()?;
}
MasterReq::RESET_OWNER => {
self.check_request_size(&hdr, size, 0)?;
self.backend.lock().unwrap().reset_owner()?;
}
MasterReq::GET_FEATURES => {
self.check_request_size(&hdr, size, 0)?;
let features = self.backend.lock().unwrap().get_features()?;
let msg = VhostUserU64::new(features);
self.send_reply_message(&hdr, &msg)?;
self.virtio_features = features;
self.update_reply_ack_flag();
}
MasterReq::SET_FEATURES => {
let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
self.backend.lock().unwrap().set_features(msg.value)?;
self.acked_virtio_features = msg.value;
self.update_reply_ack_flag();
}
MasterReq::SET_MEM_TABLE => {
let res = self.set_mem_table(&hdr, size, &buf, rfds);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_NUM => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
let res = self
.backend
.lock()
.unwrap()
.set_vring_num(msg.index, msg.num);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_ADDR => {
let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
Some(val) => val,
None => return Err(Error::InvalidMessage),
};
let res = self.backend.lock().unwrap().set_vring_addr(
msg.index,
flags,
msg.descriptor,
msg.used,
msg.available,
msg.log,
);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_BASE => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
let res = self
.backend
.lock()
.unwrap()
.set_vring_base(msg.index, msg.num);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_VRING_BASE => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
let reply = self.backend.lock().unwrap().get_vring_base(msg.index)?;
self.send_reply_message(&hdr, &reply)?;
}
MasterReq::SET_VRING_CALL => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
let res = self.backend.lock().unwrap().set_vring_call(index, rfds);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_KICK => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
let res = self.backend.lock().unwrap().set_vring_kick(index, rfds);
self.send_ack_message(&hdr, res)?;
}
MasterReq::SET_VRING_ERR => {
self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
let (index, rfds) = self.handle_vring_fd_request(&buf, rfds)?;
let res = self.backend.lock().unwrap().set_vring_err(index, rfds);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_PROTOCOL_FEATURES => {
self.check_request_size(&hdr, size, 0)?;
let features = self.backend.lock().unwrap().get_protocol_features()?;
let msg = VhostUserU64::new(features.bits());
self.send_reply_message(&hdr, &msg)?;
self.protocol_features = features;
self.update_reply_ack_flag();
}
MasterReq::SET_PROTOCOL_FEATURES => {
let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
self.backend
.lock()
.unwrap()
.set_protocol_features(msg.value)?;
self.acked_protocol_features = msg.value;
self.update_reply_ack_flag();
}
MasterReq::GET_QUEUE_NUM => {
if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0 {
return Err(Error::InvalidOperation);
}
self.check_request_size(&hdr, size, 0)?;
let num = self.backend.lock().unwrap().get_queue_num()?;
let msg = VhostUserU64::new(num);
self.send_reply_message(&hdr, &msg)?;
}
MasterReq::SET_VRING_ENABLE => {
let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
if self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() == 0
&& msg.index > 0
{
return Err(Error::InvalidOperation);
}
let enable = match msg.num {
1 => true,
0 => false,
_ => return Err(Error::InvalidParam),
};
let res = self
.backend
.lock()
.unwrap()
.set_vring_enable(msg.index, enable);
self.send_ack_message(&hdr, res)?;
}
MasterReq::GET_CONFIG => {
if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
return Err(Error::InvalidOperation);
}
self.get_config(&hdr, &buf)?;
}
MasterReq::SET_CONFIG => {
if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
return Err(Error::InvalidOperation);
}
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
self.set_config(&hdr, size, &buf)?;
}
MasterReq::SET_SLAVE_REQ_FD => {
if self.acked_protocol_features & VhostUserProtocolFeatures::SLAVE_REQ.bits() == 0 {
return Err(Error::InvalidOperation);
}
self.set_slave_req_fd(&hdr, rfds)?;
}
_ => {
return Err(Error::InvalidMessage);
}
}
Ok(())
}
fn set_mem_table(
&mut self,
hdr: &VhostUserMsgHeader<MasterReq>,
size: usize,
buf: &[u8],
rfds: Option<Vec<RawFd>>,
) -> Result<()> {
self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
// check message size is consistent
let hdrsize = mem::size_of::<VhostUserMemory>();
if size < hdrsize {
Endpoint::<MasterReq>::close_rfds(rfds);
return Err(Error::InvalidMessage);
}
let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) };
if !msg.is_valid() {
Endpoint::<MasterReq>::close_rfds(rfds);
return Err(Error::InvalidMessage);
}
if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() {
Endpoint::<MasterReq>::close_rfds(rfds);
return Err(Error::InvalidMessage);
}
// validate number of fds matching number of memory regions
let fds = match rfds {
None => return Err(Error::InvalidMessage),
Some(fds) => {
if fds.len() != msg.num_regions as usize {
Endpoint::<MasterReq>::close_rfds(Some(fds));
return Err(Error::InvalidMessage);
}
fds
}
};
// Validate memory regions
let regions = unsafe {
slice::from_raw_parts(
buf.as_ptr().add(hdrsize) as *const VhostUserMemoryRegion,
msg.num_regions as usize,
)
};
for region in regions.iter() {
if !region.is_valid() {
Endpoint::<MasterReq>::close_rfds(Some(fds));
return Err(Error::InvalidMessage);
}
}
self.backend.lock().unwrap().set_mem_table(&regions, &fds)
}
fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> {
let msg = unsafe { &*(buf.as_ptr() as *const VhostUserConfig) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
let payload_offset = mem::size_of::<VhostUserConfig>();
if buf.len() - payload_offset != msg.size as usize {
return Err(Error::InvalidMessage);
}
let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
Some(val) => val,
None => return Err(Error::InvalidMessage),
};
let res = self
.backend
.lock()
.unwrap()
.get_config(msg.offset, msg.size, flags);
// vhost-user slave's payload size MUST match master's request
// on success, uses zero length of payload to indicate an error
// to vhost-user master.
match res {
Ok(ref buf) if buf.len() == msg.size as usize => {
let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
self.send_reply_with_payload(&hdr, &reply, buf.as_slice())?;
}
Ok(_) => {
let reply = VhostUserConfig::new(msg.offset, 0, flags);
self.send_reply_message(&hdr, &reply)?;
}
Err(_) => {
let reply = VhostUserConfig::new(msg.offset, 0, flags);
self.send_reply_message(&hdr, &reply)?;
}
}
Ok(())
}
fn set_config(
&mut self,
hdr: &VhostUserMsgHeader<MasterReq>,
size: usize,
buf: &[u8],
) -> Result<()> {
if size < mem::size_of::<VhostUserConfig>() {
return Err(Error::InvalidMessage);
}
let msg = unsafe { &*(buf.as_ptr() as *const VhostUserConfig) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
if size - mem::size_of::<VhostUserConfig>() != msg.size as usize {
return Err(Error::InvalidMessage);
}
let flags: VhostUserConfigFlags;
match VhostUserConfigFlags::from_bits(msg.flags) {
Some(val) => flags = val,
None => return Err(Error::InvalidMessage),
}
let res = self
.backend
.lock()
.unwrap()
.set_config(msg.offset, buf, flags);
self.send_ack_message(&hdr, res)?;
Ok(())
}
fn set_slave_req_fd(
&mut self,
hdr: &VhostUserMsgHeader<MasterReq>,
rfds: Option<Vec<RawFd>>,
) -> Result<()> {
if let Some(fds) = rfds {
if fds.len() == 1 {
let sock = unsafe { UnixStream::from_raw_fd(fds[0]) };
let vu_req = SlaveFsCacheReq::from_stream(sock);
self.backend.lock().unwrap().set_slave_req_fd(vu_req);
self.send_ack_message(&hdr, Ok(()))
} else {
Err(Error::InvalidMessage)
}
} else {
Err(Error::InvalidMessage)
}
}
fn handle_vring_fd_request(
&mut self,
buf: &[u8],
rfds: Option<Vec<RawFd>>,
) -> Result<(u8, Option<RawFd>)> {
let msg = unsafe { &*(buf.as_ptr() as *const VhostUserU64) };
if !msg.is_valid() {
return Err(Error::InvalidMessage);
}
// Bits (0-7) of the payload contain the vring index. Bit 8 is the
// invalid FD flag. This flag is set when there is no file descriptor
// in the ancillary data. This signals that polling will be used
// instead of waiting for the call.
let nofd = match msg.value & 0x100u64 {
0x100u64 => true,
_ => false,
};
let mut rfd = None;
match rfds {
Some(fds) => {
if !nofd && fds.len() == 1 {
rfd = Some(fds[0]);
} else if (nofd && !fds.is_empty()) || (!nofd && fds.len() != 1) {
Endpoint::<MasterReq>::close_rfds(Some(fds));
return Err(Error::InvalidMessage);
}
}
None => {
if !nofd {
return Err(Error::InvalidMessage);
}
}
}
Ok((msg.value as u8, rfd))
}
fn check_state(&self) -> Result<()> {
match self.error {
Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
None => Ok(()),
}
}
fn check_request_size(
&self,
hdr: &VhostUserMsgHeader<MasterReq>,
size: usize,
expected: usize,
) -> Result<()> {
if hdr.get_size() as usize != expected
|| hdr.is_reply()
|| hdr.get_version() != 0x1
|| size != expected
{
return Err(Error::InvalidMessage);
}
Ok(())
}
fn check_attached_rfds(
&self,
hdr: &VhostUserMsgHeader<MasterReq>,
rfds: Option<Vec<RawFd>>,
) -> Result<Option<Vec<RawFd>>> {
match hdr.get_code() {
MasterReq::SET_MEM_TABLE => Ok(rfds),
MasterReq::SET_VRING_CALL => Ok(rfds),
MasterReq::SET_VRING_KICK => Ok(rfds),
MasterReq::SET_VRING_ERR => Ok(rfds),
MasterReq::SET_LOG_BASE => Ok(rfds),
MasterReq::SET_LOG_FD => Ok(rfds),
MasterReq::SET_SLAVE_REQ_FD => Ok(rfds),
MasterReq::SET_INFLIGHT_FD => Ok(rfds),
_ => {
if rfds.is_some() {
Endpoint::<MasterReq>::close_rfds(rfds);
Err(Error::InvalidMessage)
} else {
Ok(rfds)
}
}
}
}
/// Reinterprets `buf` as a request body of type `T`, after verifying that
/// both the header-advertised size and the received size equal
/// `size_of::<T>()`, and that the resulting body passes its own validator.
fn extract_request_body<'a, T: Sized + VhostUserMsgValidator>(
    &self,
    hdr: &VhostUserMsgHeader<MasterReq>,
    size: usize,
    buf: &'a [u8],
) -> Result<&'a T> {
    self.check_request_size(hdr, size, mem::size_of::<T>())?;
    // The size check above pins `size` to size_of::<T>(); the cast assumes
    // the caller passed a `buf` actually holding that many bytes and that
    // the buffer is suitably aligned for T — NOTE(review): neither is
    // verified here, confirm at the call sites.
    let msg = unsafe { &*(buf.as_ptr() as *const T) };
    if !msg.is_valid() {
        return Err(Error::InvalidMessage);
    }
    Ok(msg)
}
/// Recomputes whether REPLY_ACK behavior is active: it requires the
/// PROTOCOL_FEATURES virtio feature to be offered and acked, and the
/// REPLY_ACK protocol feature to be both supported and acked.
fn update_reply_ack_flag(&mut self) {
    let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
    let pflag = VhostUserProtocolFeatures::REPLY_ACK;
    self.reply_ack_enabled = (self.virtio_features & vflag) != 0
        && (self.acked_virtio_features & vflag) != 0
        && self.protocol_features.contains(pflag)
        && (self.acked_protocol_features & pflag.bits()) != 0;
}
/// Builds a reply header for `req` whose size field covers a body of type
/// `T` plus `payload_size` extra bytes.
fn new_reply_header<T: Sized>(
    &self,
    req: &VhostUserMsgHeader<MasterReq>,
    payload_size: usize,
) -> Result<VhostUserMsgHeader<MasterReq>> {
    // Refuse replies whose fixed body already exceeds the protocol limit.
    if mem::size_of::<T>() > MAX_MSG_SIZE {
        return Err(Error::InvalidParam);
    }
    self.check_state()?;

    let total_size = mem::size_of::<T>() + payload_size;
    Ok(VhostUserMsgHeader::new(
        req.get_code(),
        VhostUserHeaderFlag::REPLY.bits(),
        total_size as u32,
    ))
}
/// Sends a REPLY_ACK status for `req` when the feature is enabled;
/// otherwise this is a no-op.
fn send_ack_message(
    &mut self,
    req: &VhostUserMsgHeader<MasterReq>,
    res: Result<()>,
) -> Result<()> {
    // Acks are only produced once REPLY_ACK has been negotiated.
    if !self.reply_ack_enabled {
        return Ok(());
    }

    let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?;
    // Per the vhost-user protocol: zero means success, non-zero failure.
    let val = if res.is_ok() { 0 } else { 1 };
    self.main_sock
        .send_message(&hdr, &VhostUserU64::new(val), None)?;
    Ok(())
}
/// Sends `msg` back to the master as the reply to `req`, with no extra
/// payload beyond the message body.
fn send_reply_message<T>(
    &mut self,
    req: &VhostUserMsgHeader<MasterReq>,
    msg: &T,
) -> Result<()> {
    let reply_hdr = self.new_reply_header::<T>(req, 0)?;
    self.main_sock
        .send_message(&reply_hdr, msg, None)
        .map(|_| ())
}
fn send_reply_with_payload<T, P>(
&mut self,
req: &VhostUserMsgHeader<MasterReq>,
msg: &T,
payload: &[P],
) -> Result<()>
where
T: Sized,
P: Sized,
{
let hdr = self.new_reply_header::<T>(req, payload.len())?;
self.main_sock
.send_message_with_payload(&hdr, msg, payload, None)?;
Ok(())
}
}
// Expose the underlying vhost-user socket so callers can register it with
// an event loop and poll it for incoming master requests.
impl<S: VhostUserSlaveReqHandler> AsRawFd for SlaveReqHandler<S> {
    fn as_raw_fd(&self) -> RawFd {
        self.main_sock.as_raw_fd()
    }
}

View File

@ -1,465 +0,0 @@
// Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! Used to send and receive messages with file descriptors on sockets that accept control messages
//! (e.g. Unix domain sockets).
// TODO: move this file into the vmm-sys-util crate
use std::fs::File;
use std::mem::size_of;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::{UnixDatagram, UnixStream};
use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned};
use libc::{
c_long, c_void, cmsghdr, iovec, msghdr, recvmsg, sendmsg, MSG_NOSIGNAL, SCM_RIGHTS, SOL_SOCKET,
};
use vmm_sys_util::errno::{Error, Result};
// Each of the following macros performs the same function as their C counterparts. They are each
// macros because they are used to size statically allocated arrays.
// Rounds `$len` up to the next multiple of `size_of::<c_long>()`, mirroring
// the kernel's CMSG_ALIGN macro. Must stay a macro so it can size statically
// allocated arrays.
macro_rules! CMSG_ALIGN {
    ($len:expr) => {
        (($len) + size_of::<c_long>() - 1) & !(size_of::<c_long>() - 1)
    };
}
// Total buffer space needed for a control message with `$len` bytes of data:
// one cmsghdr plus the aligned data length (kernel CMSG_SPACE equivalent).
macro_rules! CMSG_SPACE {
    ($len:expr) => {
        size_of::<cmsghdr>() + CMSG_ALIGN!($len)
    };
}
// Value for cmsg_len given `$len` bytes of data: header plus unaligned data
// length (kernel CMSG_LEN equivalent).
macro_rules! CMSG_LEN {
    ($len:expr) => {
        size_of::<cmsghdr>() + ($len)
    };
}
// This function (macro in the C version) is not used in any compile time constant slots, so is just
// an ordinary function. The returned pointer is hard coded to be RawFd because that's all that this
// module supports.
#[allow(non_snake_case)]
#[inline(always)]
fn CMSG_DATA(cmsg_buffer: *mut cmsghdr) -> *mut RawFd {
    // Essentially returns a pointer to just past the header, i.e. the start
    // of the control message's data region.
    cmsg_buffer.wrapping_offset(1) as *mut RawFd
}
// This function is like CMSG_NEXT, but safer because it reads only from references, although it
// does some pointer arithmetic on cmsg_ptr.
#[cfg_attr(feature = "cargo-clippy", allow(clippy::cast_ptr_alignment))]
fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut cmsghdr) -> *mut cmsghdr {
    // Step past the current control message, honoring the kernel's
    // alignment of cmsg_len.
    let next_cmsg = (cmsg_ptr as *mut u8).wrapping_add(CMSG_ALIGN!(cmsg.cmsg_len)) as *mut cmsghdr;
    // If a full cmsghdr would no longer fit within msg_control's
    // msg_controllen bytes, signal end-of-iteration with null.
    if next_cmsg
        .wrapping_offset(1)
        .wrapping_sub(msghdr.msg_control as usize) as usize
        > msghdr.msg_controllen
    {
        null_mut()
    } else {
        next_cmsg
    }
}
// Inline (stack) capacity: enough control-message space for 32 fds.
const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::<RawFd>() * 32);

// Backing storage for a control-message buffer: a fixed stack array for the
// common small case, or a heap slice when more space is required. The inline
// variant uses u64 elements so the buffer is suitably aligned.
enum CmsgBuffer {
    Inline([u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8]),
    Heap(Box<[cmsghdr]>),
}
impl CmsgBuffer {
    /// Allocates a buffer able to hold `capacity` bytes of control-message
    /// data: inline on the stack when it fits, on the heap otherwise.
    fn with_capacity(capacity: usize) -> CmsgBuffer {
        // Round up to whole cmsghdr units for the heap variant.
        let cap_in_cmsghdr_units =
            (capacity.checked_add(size_of::<cmsghdr>()).unwrap() - 1) / size_of::<cmsghdr>();
        if capacity <= CMSG_BUFFER_INLINE_CAPACITY {
            CmsgBuffer::Inline([0u64; (CMSG_BUFFER_INLINE_CAPACITY + 7) / 8])
        } else {
            CmsgBuffer::Heap(
                vec![
                    cmsghdr {
                        cmsg_len: 0,
                        cmsg_level: 0,
                        cmsg_type: 0,
                    };
                    cap_in_cmsghdr_units
                ]
                .into_boxed_slice(),
            )
        }
    }

    /// Raw pointer to the start of the buffer, whichever variant backs it.
    fn as_mut_ptr(&mut self) -> *mut cmsghdr {
        match self {
            CmsgBuffer::Inline(a) => a.as_mut_ptr() as *mut cmsghdr,
            CmsgBuffer::Heap(a) => a.as_mut_ptr(),
        }
    }
}
/// Sends the buffers in `out_data` over socket `fd`, attaching `out_fds` as
/// SCM_RIGHTS ancillary data. Returns the number of data bytes sent.
fn raw_sendmsg<D: IntoIovec>(fd: RawFd, out_data: &[D], out_fds: &[RawFd]) -> Result<usize> {
    let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * out_fds.len());
    let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
    // One iovec per output buffer; IntoIovec guarantees pointer/size are
    // valid for the duration of the call.
    let mut iovecs = Vec::with_capacity(out_data.len());
    for data in out_data {
        iovecs.push(iovec {
            iov_base: data.as_ptr() as *mut c_void,
            iov_len: data.size(),
        });
    }
    let mut msg = msghdr {
        msg_name: null_mut(),
        msg_namelen: 0,
        msg_iov: iovecs.as_mut_ptr(),
        msg_iovlen: iovecs.len(),
        msg_control: null_mut(),
        msg_controllen: 0,
        msg_flags: 0,
    };
    // Only attach a control message when there are fds to pass.
    if !out_fds.is_empty() {
        let cmsg = cmsghdr {
            cmsg_len: CMSG_LEN!(size_of::<RawFd>() * out_fds.len()),
            cmsg_level: SOL_SOCKET,
            cmsg_type: SCM_RIGHTS,
        };
        unsafe {
            // Safe because cmsg_buffer was allocated to be large enough to contain cmsghdr.
            write_unaligned(cmsg_buffer.as_mut_ptr() as *mut cmsghdr, cmsg);
            // Safe because the cmsg_buffer was allocated to be large enough to hold out_fds.len()
            // file descriptors.
            copy_nonoverlapping(
                out_fds.as_ptr(),
                CMSG_DATA(cmsg_buffer.as_mut_ptr()),
                out_fds.len(),
            );
        }
        msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
        msg.msg_controllen = cmsg_capacity;
    }
    // Safe because the msghdr was properly constructed from valid (or null) pointers of the
    // indicated length and we check the return value. MSG_NOSIGNAL avoids
    // SIGPIPE if the peer has closed the connection.
    let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) };
    if write_count == -1 {
        Err(Error::last())
    } else {
        Ok(write_count as usize)
    }
}
/// Receives data into `iovecs` and any SCM_RIGHTS file descriptors into
/// `in_fds` from socket `fd`. Returns `(bytes read, fds received)`.
fn raw_recvmsg(fd: RawFd, iovecs: &mut [iovec], in_fds: &mut [RawFd]) -> Result<(usize, usize)> {
    let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * in_fds.len());
    let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
    let mut msg = msghdr {
        msg_name: null_mut(),
        msg_namelen: 0,
        msg_iov: iovecs.as_mut_ptr(),
        msg_iovlen: iovecs.len(),
        msg_control: null_mut(),
        msg_controllen: 0,
        msg_flags: 0,
    };
    // Only supply a control buffer when the caller can accept fds.
    if !in_fds.is_empty() {
        msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
        msg.msg_controllen = cmsg_capacity;
    }
    // Safe because the msghdr was properly constructed from valid (or null) pointers of the
    // indicated length and we check the return value.
    let total_read = unsafe { recvmsg(fd, &mut msg, libc::MSG_WAITALL) };
    if total_read == -1 {
        return Err(Error::last());
    }
    // When the connection is closed recvmsg() doesn't give an explicit error,
    // it just returns 0 bytes with no control data; map that to ECONNRESET.
    if total_read == 0 && msg.msg_controllen < size_of::<cmsghdr>() {
        return Err(Error::new(libc::ECONNRESET));
    }
    // Walk the control-message chain collecting any SCM_RIGHTS payloads.
    let mut cmsg_ptr = msg.msg_control as *mut cmsghdr;
    let mut in_fds_count = 0;
    while !cmsg_ptr.is_null() {
        // Safe because we checked that cmsg_ptr was non-null, and the loop is constructed such that
        // that only happens when there is at least sizeof(cmsghdr) space after the pointer to read.
        let cmsg = unsafe { (cmsg_ptr as *mut cmsghdr).read_unaligned() };
        if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS {
            let fd_count = (cmsg.cmsg_len - CMSG_LEN!(0)) / size_of::<RawFd>();
            // NOTE(review): the slice index below panics if the kernel hands
            // back more fds than in_fds can hold — confirm callers size
            // in_fds for the worst case.
            unsafe {
                copy_nonoverlapping(
                    CMSG_DATA(cmsg_ptr),
                    in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(),
                    fd_count,
                );
            }
            in_fds_count += fd_count;
        }
        cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr);
    }
    Ok((total_read as usize, in_fds_count))
}
/// Trait for file descriptors that can exchange socket control messages via
/// `sendmsg` and `recvmsg`.
pub trait ScmSocket {
    /// Gets the file descriptor of this socket.
    fn socket_fd(&self) -> RawFd;

    /// Sends the given data and a file descriptor over the socket.
    ///
    /// On success, returns the number of bytes sent.
    ///
    /// # Arguments
    ///
    /// * `buf` - A buffer of data to send on the `socket`.
    /// * `fd` - A file descriptors to be sent.
    fn send_with_fd<D: IntoIovec>(&self, buf: D, fd: RawFd) -> Result<usize> {
        // Delegate to the multi-buffer/multi-fd variant.
        self.send_with_fds(&[buf], &[fd])
    }

    /// Sends the given data buffers and file descriptors over the socket.
    ///
    /// On success, returns the number of bytes sent.
    ///
    /// # Arguments
    ///
    /// * `bufs` - A list of data buffer to send on the `socket`.
    /// * `fds` - A list of file descriptors to be sent.
    fn send_with_fds<D: IntoIovec>(&self, bufs: &[D], fds: &[RawFd]) -> Result<usize> {
        raw_sendmsg(self.socket_fd(), bufs, fds)
    }

    /// Receives data and potentially a file descriptor from the socket.
    ///
    /// On success, returns the number of bytes and an optional file
    /// descriptor.
    ///
    /// # Arguments
    ///
    /// * `buf` - A buffer to receive data from the socket.
    fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option<File>)> {
        let mut fd = [0; 1];
        let mut iovecs = [iovec {
            iov_base: buf.as_mut_ptr() as *mut c_void,
            iov_len: buf.len(),
        }];

        let (read_count, fd_count) = self.recv_with_fds(&mut iovecs[..], &mut fd)?;
        let file = match fd_count {
            0 => None,
            // Safe because the first fd from recv_with_fds is owned by us and
            // valid because this branch was taken.
            _ => Some(unsafe { File::from_raw_fd(fd[0]) }),
        };
        Ok((read_count, file))
    }

    /// Receives data and file descriptors from the socket.
    ///
    /// On success, returns the number of bytes and file descriptors received
    /// as a tuple `(bytes count, files count)`.
    ///
    /// # Arguments
    ///
    /// * `iovecs` - A list of iovec to receive data from the socket.
    /// * `fds` - A slice of `RawFd`s to put the received file descriptors
    ///           into. On success, the number of valid file descriptors is
    ///           indicated by the second element of the returned tuple. The
    ///           caller owns these file descriptors, but they will not be
    ///           closed on drop like a `File`-like type would be. It is
    ///           recommended that each valid file descriptor gets wrapped in
    ///           a drop type that closes it after this returns.
    fn recv_with_fds(&self, iovecs: &mut [iovec], fds: &mut [RawFd]) -> Result<(usize, usize)> {
        raw_recvmsg(self.socket_fd(), iovecs, fds)
    }
}
// Datagram Unix sockets can carry SCM_RIGHTS control messages directly.
impl ScmSocket for UnixDatagram {
    fn socket_fd(&self) -> RawFd {
        self.as_raw_fd()
    }
}
// Stream Unix sockets can carry SCM_RIGHTS control messages directly.
impl ScmSocket for UnixStream {
    fn socket_fd(&self) -> RawFd {
        self.as_raw_fd()
    }
}
/// Trait for types that can be converted into an `iovec` that can be referenced by a syscall for
/// the lifetime of this object.
///
/// This trait is unsafe because interfaces that use this trait depend on the base pointer and size
/// being accurate.
pub unsafe trait IntoIovec {
    /// Gets the base pointer of this `iovec`.
    fn as_ptr(&self) -> *const c_void;

    /// Gets the size in bytes of this `iovec`.
    fn size(&self) -> usize;
}
// Safe because this slice can not have another mutable reference and it's pointer and size are
// guaranteed to be valid.
unsafe impl<'a> IntoIovec for &'a [u8] {
    // Clippy false positive: https://github.com/rust-lang/rust-clippy/issues/3480
    #[cfg_attr(feature = "cargo-clippy", allow(clippy::useless_asref))]
    fn as_ptr(&self) -> *const c_void {
        // NOTE(review): the as_ref() appears deliberate — it reborrows as
        // &[u8] so the inherent slice as_ptr is used rather than this trait
        // method; confirm before "simplifying" it away.
        self.as_ref().as_ptr() as *const c_void
    }

    fn size(&self) -> usize {
        self.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::mem::size_of;
    use std::os::raw::c_long;
    use std::os::unix::net::UnixDatagram;
    use std::slice::from_raw_parts;
    use libc::cmsghdr;
    use vmm_sys_util::eventfd::EventFd;

    // Checks that the CMSG_SPACE! macro matches the kernel's alignment
    // rules for 0..4 fds on both 32-bit and 64-bit RawFd layouts.
    #[test]
    fn buffer_len() {
        assert_eq!(CMSG_SPACE!(0 * size_of::<RawFd>()), size_of::<cmsghdr>());
        assert_eq!(
            CMSG_SPACE!(1 * size_of::<RawFd>()),
            size_of::<cmsghdr>() + size_of::<c_long>()
        );
        if size_of::<RawFd>() == 4 {
            assert_eq!(
                CMSG_SPACE!(2 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>()
            );
            assert_eq!(
                CMSG_SPACE!(3 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 2
            );
            assert_eq!(
                CMSG_SPACE!(4 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 2
            );
        } else if size_of::<RawFd>() == 8 {
            assert_eq!(
                CMSG_SPACE!(2 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 2
            );
            assert_eq!(
                CMSG_SPACE!(3 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 3
            );
            assert_eq!(
                CMSG_SPACE!(4 * size_of::<RawFd>()),
                size_of::<cmsghdr>() + size_of::<c_long>() * 4
            );
        }
    }

    // Round-trips two data buffers with no ancillary fds through a
    // socketpair and verifies the bytes arrive intact.
    #[test]
    fn send_recv_no_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let write_count = s1
            .send_with_fds(&[[1u8, 1, 2].as_ref(), [21u8, 34, 55].as_ref()], &[])
            .expect("failed to send data");
        assert_eq!(write_count, 6);

        let mut buf = [0u8; 6];
        let mut files = [0; 1];
        let mut iovecs = [iovec {
            iov_base: buf.as_mut_ptr() as *mut c_void,
            iov_len: buf.len(),
        }];
        let (read_count, file_count) = s2
            .recv_with_fds(&mut iovecs[..], &mut files)
            .expect("failed to recv data");
        assert_eq!(read_count, 6);
        assert_eq!(file_count, 0);
        assert_eq!(buf, [1, 1, 2, 21, 34, 55]);
    }

    // Sends only an eventfd (zero data bytes) and proves the received fd
    // refers to the same eventfd by writing through it and reading back.
    #[test]
    fn send_recv_only_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let evt = EventFd::new(0).expect("failed to create eventfd");
        let write_count = s1
            .send_with_fd([].as_ref(), evt.as_raw_fd())
            .expect("failed to send fd");
        assert_eq!(write_count, 0);

        let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd");
        let mut file = file_opt.unwrap();
        assert_eq!(read_count, 0);
        // The received fd is a fresh duplicate, distinct from all locals.
        assert!(file.as_raw_fd() >= 0);
        assert_ne!(file.as_raw_fd(), s1.as_raw_fd());
        assert_ne!(file.as_raw_fd(), s2.as_raw_fd());
        assert_ne!(file.as_raw_fd(), evt.as_raw_fd());

        file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
            .expect("failed to write to sent fd");
        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
    }

    // Sends one data byte plus an eventfd together and verifies both the
    // byte and the (duplicated) fd arrive and work.
    #[test]
    fn send_recv_with_fd() {
        let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");

        let evt = EventFd::new(0).expect("failed to create eventfd");
        let write_count = s1
            .send_with_fds(&[[237].as_ref()], &[evt.as_raw_fd()])
            .expect("failed to send fd");
        assert_eq!(write_count, 1);

        let mut files = [0; 2];
        let mut buf = [0u8];
        let mut iovecs = [iovec {
            iov_base: buf.as_mut_ptr() as *mut c_void,
            iov_len: buf.len(),
        }];
        let (read_count, file_count) = s2
            .recv_with_fds(&mut iovecs[..], &mut files)
            .expect("failed to recv fd");
        assert_eq!(read_count, 1);
        assert_eq!(buf[0], 237);
        assert_eq!(file_count, 1);
        // The received fd is a fresh duplicate, distinct from all locals.
        assert!(files[0] >= 0);
        assert_ne!(files[0], s1.as_raw_fd());
        assert_ne!(files[0], s2.as_raw_fd());
        assert_ne!(files[0], evt.as_raw_fd());

        let mut file = unsafe { File::from_raw_fd(files[0]) };

        file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
            .expect("failed to write to sent fd");
        assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
    }
}

View File

@ -1,30 +0,0 @@
// Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
//
// Portions Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE-BSD file.
//! Trait to control vhost-vsock backend drivers.
use crate::backend::VhostBackend;
use crate::Result;
/// Trait to control vhost-vsock backend drivers.
pub trait VhostVsock: VhostBackend {
    /// Set the CID for the guest.
    /// This number is used for routing all data destined for running in the guest.
    /// Each guest on a hypervisor must have a unique CID.
    ///
    /// # Arguments
    /// * `cid` - CID to assign to the guest
    fn set_guest_cid(&mut self, cid: u64) -> Result<()>;

    /// Tell the VHOST driver to start performing data transfer.
    fn start(&mut self) -> Result<()>;

    /// Tell the VHOST driver to stop performing data transfer.
    fn stop(&mut self) -> Result<()>;
}