diff --git a/virtio-devices/src/vhost_user/net.rs b/virtio-devices/src/vhost_user/net.rs index 8955f91c3..a4738a22b 100644 --- a/virtio-devices/src/vhost_user/net.rs +++ b/virtio-devices/src/vhost_user/net.rs @@ -110,19 +110,22 @@ impl Net { let avail_protocol_features = VhostUserProtocolFeatures::MQ | VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS | VhostUserProtocolFeatures::REPLY_ACK; - let backend_protocol_features = + let acked_protocol_features = if acked_features & VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits() != 0 { - vhost_user_net + let backend_protocol_features = vhost_user_net .get_protocol_features() - .map_err(Error::VhostUserGetProtocolFeatures)? - } else { - return Err(Error::VhostUserProtocolNotSupport); - }; - let acked_protocol_features = avail_protocol_features & backend_protocol_features; + .map_err(Error::VhostUserGetProtocolFeatures)?; - vhost_user_net - .set_protocol_features(acked_protocol_features) - .map_err(Error::VhostUserSetProtocolFeatures)?; + let acked_protocol_features = avail_protocol_features & backend_protocol_features; + + vhost_user_net + .set_protocol_features(acked_protocol_features) + .map_err(Error::VhostUserSetProtocolFeatures)?; + + acked_protocol_features.bits() + } else { + 0 + }; // If the control queue feature has not been negotiated, let's decrease // the number of queues. @@ -131,7 +134,7 @@ impl Net { } let backend_num_queues = - if acked_protocol_features.bits() & VhostUserProtocolFeatures::MQ.bits() != 0 { + if acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0 { vhost_user_net .get_queue_num() .map_err(Error::VhostUserGetQueueMaxNum)? as usize @@ -160,8 +163,8 @@ impl Net { common: VirtioCommon { device_type: VirtioDeviceType::Net as u32, queue_sizes: vec![vu_cfg.queue_size; num_queues], - avail_features, - acked_features, + avail_features: acked_features, + acked_features: 0, paused_sync: Some(Arc::new(Barrier::new(1))), min_queues: DEFAULT_QUEUE_NUMBER as u16, ..Default::default() @@ -169,7 +172,7 @@ impl Net { vhost_user_net, config, guest_memory: None, - acked_protocol_features: acked_protocol_features.bits(), + acked_protocol_features, socket_path, }) }