virtio-devices: iommu: Report request error back to guest

Improve the request parsing/handling code by allowing an error status
to be returned to the guest driver before the error is propagated
internally.

Signed-off-by: Sebastien Boeuf <sebastien.boeuf@intel.com>
commit 6df8f0bbf3
parent bbd0667b98
Author: Sebastien Boeuf <sebastien.boeuf@intel.com>
Date:   2022-04-26 11:52:19 +02:00

@@ -403,221 +403,230 @@ impl Request {
         // Create the reply
         let mut reply: Vec<u8> = Vec::new();
+        let mut status = VIRTIO_IOMMU_S_OK;
+        let mut hdr_len = 0;
 
-        let hdr_len = match req_head.type_ {
-            VIRTIO_IOMMU_T_ATTACH => {
-                if desc_size_left != size_of::<VirtioIommuReqAttach>() {
-                    return Err(Error::InvalidAttachRequest);
-                }
-
-                let req: VirtioIommuReqAttach = desc_chain
-                    .memory()
-                    .read_obj(req_addr as GuestAddress)
-                    .map_err(Error::GuestMemory)?;
-                debug!("Attach request {:?}", req);
-
-                // Copy the value to use it as a proper reference.
-                let domain_id = req.domain;
-                let endpoint = req.endpoint;
-                let bypass =
-                    (req.flags & VIRTIO_IOMMU_ATTACH_F_BYPASS) == VIRTIO_IOMMU_ATTACH_F_BYPASS;
-
-                // Add endpoint associated with specific domain
-                mapping
-                    .endpoints
-                    .write()
-                    .unwrap()
-                    .insert(endpoint, domain_id);
-
-                // Add new domain with no mapping if the entry didn't exist yet
-                let mut domains = mapping.domains.write().unwrap();
-                let domain = Domain {
-                    mappings: BTreeMap::new(),
-                    bypass,
-                };
-                domains.entry(domain_id).or_insert_with(|| domain);
-
-                0
-            }
+        let result = (|| {
+            match req_head.type_ {
+                VIRTIO_IOMMU_T_ATTACH => {
+                    if desc_size_left != size_of::<VirtioIommuReqAttach>() {
+                        status = VIRTIO_IOMMU_S_INVAL;
+                        return Err(Error::InvalidAttachRequest);
+                    }
+
+                    let req: VirtioIommuReqAttach = desc_chain
+                        .memory()
+                        .read_obj(req_addr as GuestAddress)
+                        .map_err(Error::GuestMemory)?;
+                    debug!("Attach request {:?}", req);
+
+                    // Copy the value to use it as a proper reference.
+                    let domain_id = req.domain;
+                    let endpoint = req.endpoint;
+                    let bypass =
+                        (req.flags & VIRTIO_IOMMU_ATTACH_F_BYPASS) == VIRTIO_IOMMU_ATTACH_F_BYPASS;
+
+                    // Add endpoint associated with specific domain
+                    mapping
+                        .endpoints
+                        .write()
+                        .unwrap()
+                        .insert(endpoint, domain_id);
+
+                    // Add new domain with no mapping if the entry didn't exist yet
+                    let mut domains = mapping.domains.write().unwrap();
+                    let domain = Domain {
+                        mappings: BTreeMap::new(),
+                        bypass,
+                    };
+                    domains.entry(domain_id).or_insert_with(|| domain);
+                }
-            VIRTIO_IOMMU_T_DETACH => {
-                if desc_size_left != size_of::<VirtioIommuReqDetach>() {
-                    return Err(Error::InvalidDetachRequest);
-                }
-
-                let req: VirtioIommuReqDetach = desc_chain
-                    .memory()
-                    .read_obj(req_addr as GuestAddress)
-                    .map_err(Error::GuestMemory)?;
-                debug!("Detach request {:?}", req);
-
-                // Copy the value to use it as a proper reference.
-                let domain_id = req.domain;
-                let endpoint = req.endpoint;
-
-                // Remove endpoint associated with specific domain
-                mapping.endpoints.write().unwrap().remove(&endpoint);
-
-                // After all endpoints have been successfully detached from a
-                // domain, the domain can be removed. This means we must remove
-                // the mappings associated with this domain.
-                if mapping
-                    .endpoints
-                    .write()
-                    .unwrap()
-                    .iter()
-                    .filter(|(_, &d)| d == domain_id)
-                    .count()
-                    == 0
-                {
-                    mapping.domains.write().unwrap().remove(&domain_id);
-                }
-
-                0
-            }
+                VIRTIO_IOMMU_T_DETACH => {
+                    if desc_size_left != size_of::<VirtioIommuReqDetach>() {
+                        status = VIRTIO_IOMMU_S_INVAL;
+                        return Err(Error::InvalidDetachRequest);
+                    }
+
+                    let req: VirtioIommuReqDetach = desc_chain
+                        .memory()
+                        .read_obj(req_addr as GuestAddress)
+                        .map_err(Error::GuestMemory)?;
+                    debug!("Detach request {:?}", req);
+
+                    // Copy the value to use it as a proper reference.
+                    let domain_id = req.domain;
+                    let endpoint = req.endpoint;
+
+                    // Remove endpoint associated with specific domain
+                    mapping.endpoints.write().unwrap().remove(&endpoint);
+
+                    // After all endpoints have been successfully detached from a
+                    // domain, the domain can be removed. This means we must remove
+                    // the mappings associated with this domain.
+                    if mapping
+                        .endpoints
+                        .write()
+                        .unwrap()
+                        .iter()
+                        .filter(|(_, &d)| d == domain_id)
+                        .count()
+                        == 0
+                    {
+                        mapping.domains.write().unwrap().remove(&domain_id);
+                    }
+                }
-            VIRTIO_IOMMU_T_MAP => {
-                if desc_size_left != size_of::<VirtioIommuReqMap>() {
-                    return Err(Error::InvalidMapRequest);
-                }
-
-                let req: VirtioIommuReqMap = desc_chain
-                    .memory()
-                    .read_obj(req_addr as GuestAddress)
-                    .map_err(Error::GuestMemory)?;
-                debug!("Map request {:?}", req);
-
-                // Copy the value to use it as a proper reference.
-                let domain_id = req.domain;
-
-                if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
-                    if domain.bypass {
-                        return Err(Error::InvalidMapRequestBypassDomain);
-                    }
-                } else {
-                    return Err(Error::InvalidMapRequestMissingDomain);
-                }
-
-                // Find the list of endpoints attached to the given domain.
-                let endpoints: Vec<u32> = mapping
-                    .endpoints
-                    .write()
-                    .unwrap()
-                    .iter()
-                    .filter(|(_, &d)| d == domain_id)
-                    .map(|(&e, _)| e)
-                    .collect();
-
-                // Trigger external mapping if necessary.
-                for endpoint in endpoints {
-                    if let Some(ext_map) = ext_mapping.get(&endpoint) {
-                        let size = req.virt_end - req.virt_start + 1;
-                        ext_map
-                            .map(req.virt_start, req.phys_start, size)
-                            .map_err(Error::ExternalMapping)?;
-                    }
-                }
-
-                // Add new mapping associated with the domain
-                mapping
-                    .domains
-                    .write()
-                    .unwrap()
-                    .get_mut(&domain_id)
-                    .unwrap()
-                    .mappings
-                    .insert(
-                        req.virt_start,
-                        Mapping {
-                            gpa: req.phys_start,
-                            size: req.virt_end - req.virt_start + 1,
-                        },
-                    );
-
-                0
-            }
+                VIRTIO_IOMMU_T_MAP => {
+                    if desc_size_left != size_of::<VirtioIommuReqMap>() {
+                        status = VIRTIO_IOMMU_S_INVAL;
+                        return Err(Error::InvalidMapRequest);
+                    }
+
+                    let req: VirtioIommuReqMap = desc_chain
+                        .memory()
+                        .read_obj(req_addr as GuestAddress)
+                        .map_err(Error::GuestMemory)?;
+                    debug!("Map request {:?}", req);
+
+                    // Copy the value to use it as a proper reference.
+                    let domain_id = req.domain;
+
+                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
+                        if domain.bypass {
+                            status = VIRTIO_IOMMU_S_INVAL;
+                            return Err(Error::InvalidMapRequestBypassDomain);
+                        }
+                    } else {
+                        status = VIRTIO_IOMMU_S_INVAL;
+                        return Err(Error::InvalidMapRequestMissingDomain);
+                    }
+
+                    // Find the list of endpoints attached to the given domain.
+                    let endpoints: Vec<u32> = mapping
+                        .endpoints
+                        .write()
+                        .unwrap()
+                        .iter()
+                        .filter(|(_, &d)| d == domain_id)
+                        .map(|(&e, _)| e)
+                        .collect();
+
+                    // Trigger external mapping if necessary.
+                    for endpoint in endpoints {
+                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
+                            let size = req.virt_end - req.virt_start + 1;
+                            ext_map
+                                .map(req.virt_start, req.phys_start, size)
+                                .map_err(Error::ExternalMapping)?;
+                        }
+                    }
+
+                    // Add new mapping associated with the domain
+                    mapping
+                        .domains
+                        .write()
+                        .unwrap()
+                        .get_mut(&domain_id)
+                        .unwrap()
+                        .mappings
+                        .insert(
+                            req.virt_start,
+                            Mapping {
+                                gpa: req.phys_start,
+                                size: req.virt_end - req.virt_start + 1,
+                            },
+                        );
+                }
-            VIRTIO_IOMMU_T_UNMAP => {
-                if desc_size_left != size_of::<VirtioIommuReqUnmap>() {
-                    return Err(Error::InvalidUnmapRequest);
-                }
-
-                let req: VirtioIommuReqUnmap = desc_chain
-                    .memory()
-                    .read_obj(req_addr as GuestAddress)
-                    .map_err(Error::GuestMemory)?;
-                debug!("Unmap request {:?}", req);
-
-                // Copy the value to use it as a proper reference.
-                let domain_id = req.domain;
-                let virt_start = req.virt_start;
-
-                if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
-                    if domain.bypass {
-                        return Err(Error::InvalidUnmapRequestBypassDomain);
-                    }
-                } else {
-                    return Err(Error::InvalidUnmapRequestMissingDomain);
-                }
-
-                // Find the list of endpoints attached to the given domain.
-                let endpoints: Vec<u32> = mapping
-                    .endpoints
-                    .write()
-                    .unwrap()
-                    .iter()
-                    .filter(|(_, &d)| d == domain_id)
-                    .map(|(&e, _)| e)
-                    .collect();
-
-                // Trigger external unmapping if necessary.
-                for endpoint in endpoints {
-                    if let Some(ext_map) = ext_mapping.get(&endpoint) {
-                        let size = req.virt_end - virt_start + 1;
-                        ext_map
-                            .unmap(virt_start, size)
-                            .map_err(Error::ExternalUnmapping)?;
-                    }
-                }
-
-                // Remove mapping associated with the domain
-                mapping
-                    .domains
-                    .write()
-                    .unwrap()
-                    .get_mut(&domain_id)
-                    .unwrap()
-                    .mappings
-                    .remove(&virt_start);
-
-                0
-            }
+                VIRTIO_IOMMU_T_UNMAP => {
+                    if desc_size_left != size_of::<VirtioIommuReqUnmap>() {
+                        status = VIRTIO_IOMMU_S_INVAL;
+                        return Err(Error::InvalidUnmapRequest);
+                    }
+
+                    let req: VirtioIommuReqUnmap = desc_chain
+                        .memory()
+                        .read_obj(req_addr as GuestAddress)
+                        .map_err(Error::GuestMemory)?;
+                    debug!("Unmap request {:?}", req);
+
+                    // Copy the value to use it as a proper reference.
+                    let domain_id = req.domain;
+                    let virt_start = req.virt_start;
+
+                    if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) {
+                        if domain.bypass {
+                            status = VIRTIO_IOMMU_S_INVAL;
+                            return Err(Error::InvalidUnmapRequestBypassDomain);
+                        }
+                    } else {
+                        status = VIRTIO_IOMMU_S_INVAL;
+                        return Err(Error::InvalidUnmapRequestMissingDomain);
+                    }
+
+                    // Find the list of endpoints attached to the given domain.
+                    let endpoints: Vec<u32> = mapping
+                        .endpoints
+                        .write()
+                        .unwrap()
+                        .iter()
+                        .filter(|(_, &d)| d == domain_id)
+                        .map(|(&e, _)| e)
+                        .collect();
+
+                    // Trigger external unmapping if necessary.
+                    for endpoint in endpoints {
+                        if let Some(ext_map) = ext_mapping.get(&endpoint) {
+                            let size = req.virt_end - virt_start + 1;
+                            ext_map
+                                .unmap(virt_start, size)
+                                .map_err(Error::ExternalUnmapping)?;
+                        }
+                    }
+
+                    // Remove mapping associated with the domain
+                    mapping
+                        .domains
+                        .write()
+                        .unwrap()
+                        .get_mut(&domain_id)
+                        .unwrap()
+                        .mappings
+                        .remove(&virt_start);
+                }
-            VIRTIO_IOMMU_T_PROBE => {
-                if desc_size_left != size_of::<VirtioIommuReqProbe>() {
-                    return Err(Error::InvalidProbeRequest);
-                }
-
-                let req: VirtioIommuReqProbe = desc_chain
-                    .memory()
-                    .read_obj(req_addr as GuestAddress)
-                    .map_err(Error::GuestMemory)?;
-                debug!("Probe request {:?}", req);
-
-                let probe_prop = VirtioIommuProbeProperty {
-                    type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM,
-                    length: size_of::<VirtioIommuProbeResvMem>() as u16,
-                };
-                reply.extend_from_slice(probe_prop.as_slice());
-
-                let resv_mem = VirtioIommuProbeResvMem {
-                    subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
-                    start: msi_iova_start,
-                    end: msi_iova_end,
-                    ..Default::default()
-                };
-                reply.extend_from_slice(resv_mem.as_slice());
-
-                PROBE_PROP_SIZE
-            }
-            _ => return Err(Error::InvalidRequest),
-        };
+                VIRTIO_IOMMU_T_PROBE => {
+                    if desc_size_left != size_of::<VirtioIommuReqProbe>() {
+                        status = VIRTIO_IOMMU_S_INVAL;
+                        return Err(Error::InvalidProbeRequest);
+                    }
+
+                    let req: VirtioIommuReqProbe = desc_chain
+                        .memory()
+                        .read_obj(req_addr as GuestAddress)
+                        .map_err(Error::GuestMemory)?;
+                    debug!("Probe request {:?}", req);
+
+                    let probe_prop = VirtioIommuProbeProperty {
+                        type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM,
+                        length: size_of::<VirtioIommuProbeResvMem>() as u16,
+                    };
+                    reply.extend_from_slice(probe_prop.as_slice());
+
+                    let resv_mem = VirtioIommuProbeResvMem {
+                        subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI,
+                        start: msi_iova_start,
+                        end: msi_iova_end,
+                        ..Default::default()
+                    };
+                    reply.extend_from_slice(resv_mem.as_slice());
+
+                    hdr_len = PROBE_PROP_SIZE;
+                }
+                _ => {
+                    status = VIRTIO_IOMMU_S_INVAL;
+                    return Err(Error::InvalidRequest);
+                }
+            }
+
+            Ok(())
+        })();
 
         let status_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?;
@@ -631,16 +640,21 @@ impl Request {
         }
 
         let tail = VirtioIommuReqTail {
-            status: VIRTIO_IOMMU_S_OK,
+            status,
             ..Default::default()
         };
         reply.extend_from_slice(tail.as_slice());
 
+        // Make sure we return the result of the request to the guest before
+        // we return a potential error internally.
         desc_chain
             .memory()
             .write_slice(reply.as_slice(), status_desc.addr())
             .map_err(Error::GuestMemory)?;
 
+        // Return the error if the result was not Ok().
+        result?;
+
         Ok((hdr_len as usize) + size_of::<VirtioIommuReqTail>())
     }
 }
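
The core of the change is the immediately-invoked closure: an early `return Err(...)` now exits only the closure, so the function can still write a reply carrying the error status into the status descriptor before propagating the error. A minimal, self-contained sketch of the same pattern follows; the names (`process`, `send_reply`, `S_OK`, `S_INVAL`) are illustrative stand-ins, not items from the cloud-hypervisor codebase.

    // Status codes modeled on virtio-iommu: 0 = OK, non-zero = failure.
    const S_OK: u8 = 0;
    const S_INVAL: u8 = 1;

    #[derive(Debug)]
    enum Error {
        InvalidRequest,
    }

    // Stand-in for writing the reply into the status descriptor.
    fn send_reply(reply: &[u8]) {
        println!("reply bytes: {reply:?}");
    }

    fn process(request_len: usize) -> Result<usize, Error> {
        let mut status = S_OK;

        // The closure is invoked immediately; `?` and `return Err(...)`
        // inside it exit the closure, not `process` itself.
        let result = (|| {
            if request_len == 0 {
                status = S_INVAL;
                return Err(Error::InvalidRequest);
            }
            // ... handle the request ...
            Ok(())
        })();

        // The guest always gets a reply with the final status first ...
        let reply = [status];
        send_reply(&reply);

        // ... and only then is the internal error surfaced to the caller.
        result?;

        Ok(reply.len())
    }

    fn main() {
        assert!(process(0).is_err()); // guest still saw status = S_INVAL
        assert!(process(64).is_ok());
    }

The closure borrows `status` mutably only for the duration of the call, so the later reads of `status` and `result` compile without any cloning; the same property is what lets the device code above keep using `?` inside the request handlers while the reply write still happens on every path.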