diff --git a/virtio-devices/src/iommu.rs b/virtio-devices/src/iommu.rs index d72817ab7..8f86e0cce 100644 --- a/virtio-devices/src/iommu.rs +++ b/virtio-devices/src/iommu.rs @@ -403,221 +403,230 @@ impl Request { // Create the reply let mut reply: Vec = Vec::new(); + let mut status = VIRTIO_IOMMU_S_OK; + let mut hdr_len = 0; - let hdr_len = match req_head.type_ { - VIRTIO_IOMMU_T_ATTACH => { - if desc_size_left != size_of::() { - return Err(Error::InvalidAttachRequest); - } - - let req: VirtioIommuReqAttach = desc_chain - .memory() - .read_obj(req_addr as GuestAddress) - .map_err(Error::GuestMemory)?; - debug!("Attach request {:?}", req); - - // Copy the value to use it as a proper reference. - let domain_id = req.domain; - let endpoint = req.endpoint; - let bypass = - (req.flags & VIRTIO_IOMMU_ATTACH_F_BYPASS) == VIRTIO_IOMMU_ATTACH_F_BYPASS; - - // Add endpoint associated with specific domain - mapping - .endpoints - .write() - .unwrap() - .insert(endpoint, domain_id); - - // Add new domain with no mapping if the entry didn't exist yet - let mut domains = mapping.domains.write().unwrap(); - let domain = Domain { - mappings: BTreeMap::new(), - bypass, - }; - domains.entry(domain_id).or_insert_with(|| domain); - - 0 - } - VIRTIO_IOMMU_T_DETACH => { - if desc_size_left != size_of::() { - return Err(Error::InvalidDetachRequest); - } - - let req: VirtioIommuReqDetach = desc_chain - .memory() - .read_obj(req_addr as GuestAddress) - .map_err(Error::GuestMemory)?; - debug!("Detach request {:?}", req); - - // Copy the value to use it as a proper reference. - let domain_id = req.domain; - let endpoint = req.endpoint; - - // Remove endpoint associated with specific domain - mapping.endpoints.write().unwrap().remove(&endpoint); - - // After all endpoints have been successfully detached from a - // domain, the domain can be removed. This means we must remove - // the mappings associated with this domain. - if mapping - .endpoints - .write() - .unwrap() - .iter() - .filter(|(_, &d)| d == domain_id) - .count() - == 0 - { - mapping.domains.write().unwrap().remove(&domain_id); - } - - 0 - } - VIRTIO_IOMMU_T_MAP => { - if desc_size_left != size_of::() { - return Err(Error::InvalidMapRequest); - } - - let req: VirtioIommuReqMap = desc_chain - .memory() - .read_obj(req_addr as GuestAddress) - .map_err(Error::GuestMemory)?; - debug!("Map request {:?}", req); - - // Copy the value to use it as a proper reference. - let domain_id = req.domain; - - if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) { - if domain.bypass { - return Err(Error::InvalidMapRequestBypassDomain); + let result = (|| { + match req_head.type_ { + VIRTIO_IOMMU_T_ATTACH => { + if desc_size_left != size_of::() { + status = VIRTIO_IOMMU_S_INVAL; + return Err(Error::InvalidAttachRequest); } - } else { - return Err(Error::InvalidMapRequestMissingDomain); + + let req: VirtioIommuReqAttach = desc_chain + .memory() + .read_obj(req_addr as GuestAddress) + .map_err(Error::GuestMemory)?; + debug!("Attach request {:?}", req); + + // Copy the value to use it as a proper reference. + let domain_id = req.domain; + let endpoint = req.endpoint; + let bypass = + (req.flags & VIRTIO_IOMMU_ATTACH_F_BYPASS) == VIRTIO_IOMMU_ATTACH_F_BYPASS; + + // Add endpoint associated with specific domain + mapping + .endpoints + .write() + .unwrap() + .insert(endpoint, domain_id); + + // Add new domain with no mapping if the entry didn't exist yet + let mut domains = mapping.domains.write().unwrap(); + let domain = Domain { + mappings: BTreeMap::new(), + bypass, + }; + domains.entry(domain_id).or_insert_with(|| domain); } + VIRTIO_IOMMU_T_DETACH => { + if desc_size_left != size_of::() { + status = VIRTIO_IOMMU_S_INVAL; + return Err(Error::InvalidDetachRequest); + } - // Find the list of endpoints attached to the given domain. - let endpoints: Vec = mapping - .endpoints - .write() - .unwrap() - .iter() - .filter(|(_, &d)| d == domain_id) - .map(|(&e, _)| e) - .collect(); + let req: VirtioIommuReqDetach = desc_chain + .memory() + .read_obj(req_addr as GuestAddress) + .map_err(Error::GuestMemory)?; + debug!("Detach request {:?}", req); - // Trigger external mapping if necessary. - for endpoint in endpoints { - if let Some(ext_map) = ext_mapping.get(&endpoint) { - let size = req.virt_end - req.virt_start + 1; - ext_map - .map(req.virt_start, req.phys_start, size) - .map_err(Error::ExternalMapping)?; + // Copy the value to use it as a proper reference. + let domain_id = req.domain; + let endpoint = req.endpoint; + + // Remove endpoint associated with specific domain + mapping.endpoints.write().unwrap().remove(&endpoint); + + // After all endpoints have been successfully detached from a + // domain, the domain can be removed. This means we must remove + // the mappings associated with this domain. + if mapping + .endpoints + .write() + .unwrap() + .iter() + .filter(|(_, &d)| d == domain_id) + .count() + == 0 + { + mapping.domains.write().unwrap().remove(&domain_id); } } - - // Add new mapping associated with the domain - mapping - .domains - .write() - .unwrap() - .get_mut(&domain_id) - .unwrap() - .mappings - .insert( - req.virt_start, - Mapping { - gpa: req.phys_start, - size: req.virt_end - req.virt_start + 1, - }, - ); - - 0 - } - VIRTIO_IOMMU_T_UNMAP => { - if desc_size_left != size_of::() { - return Err(Error::InvalidUnmapRequest); - } - - let req: VirtioIommuReqUnmap = desc_chain - .memory() - .read_obj(req_addr as GuestAddress) - .map_err(Error::GuestMemory)?; - debug!("Unmap request {:?}", req); - - // Copy the value to use it as a proper reference. - let domain_id = req.domain; - let virt_start = req.virt_start; - - if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) { - if domain.bypass { - return Err(Error::InvalidUnmapRequestBypassDomain); + VIRTIO_IOMMU_T_MAP => { + if desc_size_left != size_of::() { + status = VIRTIO_IOMMU_S_INVAL; + return Err(Error::InvalidMapRequest); } - } else { - return Err(Error::InvalidUnmapRequestMissingDomain); - } - // Find the list of endpoints attached to the given domain. - let endpoints: Vec = mapping - .endpoints - .write() - .unwrap() - .iter() - .filter(|(_, &d)| d == domain_id) - .map(|(&e, _)| e) - .collect(); + let req: VirtioIommuReqMap = desc_chain + .memory() + .read_obj(req_addr as GuestAddress) + .map_err(Error::GuestMemory)?; + debug!("Map request {:?}", req); - // Trigger external unmapping if necessary. - for endpoint in endpoints { - if let Some(ext_map) = ext_mapping.get(&endpoint) { - let size = req.virt_end - virt_start + 1; - ext_map - .unmap(virt_start, size) - .map_err(Error::ExternalUnmapping)?; + // Copy the value to use it as a proper reference. + let domain_id = req.domain; + + if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) { + if domain.bypass { + status = VIRTIO_IOMMU_S_INVAL; + return Err(Error::InvalidMapRequestBypassDomain); + } + } else { + status = VIRTIO_IOMMU_S_INVAL; + return Err(Error::InvalidMapRequestMissingDomain); } + + // Find the list of endpoints attached to the given domain. + let endpoints: Vec = mapping + .endpoints + .write() + .unwrap() + .iter() + .filter(|(_, &d)| d == domain_id) + .map(|(&e, _)| e) + .collect(); + + // Trigger external mapping if necessary. + for endpoint in endpoints { + if let Some(ext_map) = ext_mapping.get(&endpoint) { + let size = req.virt_end - req.virt_start + 1; + ext_map + .map(req.virt_start, req.phys_start, size) + .map_err(Error::ExternalMapping)?; + } + } + + // Add new mapping associated with the domain + mapping + .domains + .write() + .unwrap() + .get_mut(&domain_id) + .unwrap() + .mappings + .insert( + req.virt_start, + Mapping { + gpa: req.phys_start, + size: req.virt_end - req.virt_start + 1, + }, + ); } + VIRTIO_IOMMU_T_UNMAP => { + if desc_size_left != size_of::() { + status = VIRTIO_IOMMU_S_INVAL; + return Err(Error::InvalidUnmapRequest); + } - // Remove mapping associated with the domain - mapping - .domains - .write() - .unwrap() - .get_mut(&domain_id) - .unwrap() - .mappings - .remove(&virt_start); + let req: VirtioIommuReqUnmap = desc_chain + .memory() + .read_obj(req_addr as GuestAddress) + .map_err(Error::GuestMemory)?; + debug!("Unmap request {:?}", req); - 0 - } - VIRTIO_IOMMU_T_PROBE => { - if desc_size_left != size_of::() { - return Err(Error::InvalidProbeRequest); + // Copy the value to use it as a proper reference. + let domain_id = req.domain; + let virt_start = req.virt_start; + + if let Some(domain) = mapping.domains.read().unwrap().get(&domain_id) { + if domain.bypass { + status = VIRTIO_IOMMU_S_INVAL; + return Err(Error::InvalidUnmapRequestBypassDomain); + } + } else { + status = VIRTIO_IOMMU_S_INVAL; + return Err(Error::InvalidUnmapRequestMissingDomain); + } + + // Find the list of endpoints attached to the given domain. + let endpoints: Vec = mapping + .endpoints + .write() + .unwrap() + .iter() + .filter(|(_, &d)| d == domain_id) + .map(|(&e, _)| e) + .collect(); + + // Trigger external unmapping if necessary. + for endpoint in endpoints { + if let Some(ext_map) = ext_mapping.get(&endpoint) { + let size = req.virt_end - virt_start + 1; + ext_map + .unmap(virt_start, size) + .map_err(Error::ExternalUnmapping)?; + } + } + + // Remove mapping associated with the domain + mapping + .domains + .write() + .unwrap() + .get_mut(&domain_id) + .unwrap() + .mappings + .remove(&virt_start); } + VIRTIO_IOMMU_T_PROBE => { + if desc_size_left != size_of::() { + status = VIRTIO_IOMMU_S_INVAL; + return Err(Error::InvalidProbeRequest); + } - let req: VirtioIommuReqProbe = desc_chain - .memory() - .read_obj(req_addr as GuestAddress) - .map_err(Error::GuestMemory)?; - debug!("Probe request {:?}", req); + let req: VirtioIommuReqProbe = desc_chain + .memory() + .read_obj(req_addr as GuestAddress) + .map_err(Error::GuestMemory)?; + debug!("Probe request {:?}", req); - let probe_prop = VirtioIommuProbeProperty { - type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM, - length: size_of::() as u16, - }; - reply.extend_from_slice(probe_prop.as_slice()); + let probe_prop = VirtioIommuProbeProperty { + type_: VIRTIO_IOMMU_PROBE_T_RESV_MEM, + length: size_of::() as u16, + }; + reply.extend_from_slice(probe_prop.as_slice()); - let resv_mem = VirtioIommuProbeResvMem { - subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI, - start: msi_iova_start, - end: msi_iova_end, - ..Default::default() - }; - reply.extend_from_slice(resv_mem.as_slice()); + let resv_mem = VirtioIommuProbeResvMem { + subtype: VIRTIO_IOMMU_RESV_MEM_T_MSI, + start: msi_iova_start, + end: msi_iova_end, + ..Default::default() + }; + reply.extend_from_slice(resv_mem.as_slice()); - PROBE_PROP_SIZE + hdr_len = PROBE_PROP_SIZE; + } + _ => { + status = VIRTIO_IOMMU_S_INVAL; + return Err(Error::InvalidRequest); + } } - _ => return Err(Error::InvalidRequest), - }; + Ok(()) + })(); let status_desc = desc_chain.next().ok_or(Error::DescriptorChainTooShort)?; @@ -631,16 +640,21 @@ impl Request { } let tail = VirtioIommuReqTail { - status: VIRTIO_IOMMU_S_OK, + status, ..Default::default() }; reply.extend_from_slice(tail.as_slice()); + // Make sure we return the result of the request to the guest before + // we return a potential error internally. desc_chain .memory() .write_slice(reply.as_slice(), status_desc.addr()) .map_err(Error::GuestMemory)?; + // Return the error if the result was not Ok(). + result?; + Ok((hdr_len as usize) + size_of::()) } }