diff --git a/daemon/remote.c b/daemon/remote.c index 34c63648dc..d5ead81e01 100644 --- a/daemon/remote.c +++ b/daemon/remote.c @@ -2495,7 +2495,7 @@ remoteDispatchDomainEventSend(virNetServerClientPtr client, { virNetMessagePtr msg; - if (!(msg = virNetMessageNew())) + if (!(msg = virNetMessageNew(false))) goto cleanup; msg->header.prog = virNetServerProgramGetID(program); diff --git a/daemon/stream.c b/daemon/stream.c index ba3adc21c9..e3214c2f57 100644 --- a/daemon/stream.c +++ b/daemon/stream.c @@ -207,7 +207,7 @@ daemonStreamEvent(virStreamPtr st, int events, void *opaque) virNetError(VIR_ERR_RPC, "%s", _("stream had I/O failure")); - msg = virNetMessageNew(); + msg = virNetMessageNew(false); if (!msg) { ret = -1; } else { @@ -344,7 +344,7 @@ int daemonFreeClientStream(virNetServerClientPtr client, virNetMessagePtr tmp = msg->next; if (client) { /* Send a dummy reply to free up 'msg' & unblock client rx */ - memset(msg, 0, sizeof(*msg)); + virNetMessageClear(msg); msg->header.type = VIR_NET_REPLY; if (virNetServerClientSendMessage(client, msg) < 0) { virNetServerClientImmediateClose(client); @@ -653,7 +653,7 @@ daemonStreamHandleWrite(virNetServerClientPtr client, * its active request count / throttling */ if (msg->header.status == VIR_NET_CONTINUE) { - memset(msg, 0, sizeof(*msg)); + virNetMessageClear(msg); msg->header.type = VIR_NET_REPLY; if (virNetServerClientSendMessage(client, msg) < 0) { virNetMessageFree(msg); @@ -715,7 +715,7 @@ daemonStreamHandleRead(virNetServerClientPtr client, memset(&rerr, 0, sizeof(rerr)); - if (!(msg = virNetMessageNew())) + if (!(msg = virNetMessageNew(false))) ret = -1; else ret = virNetServerProgramSendStreamError(remoteProgram, @@ -729,7 +729,7 @@ daemonStreamHandleRead(virNetServerClientPtr client, stream->tx = 0; if (ret == 0) stream->recvEOF = 1; - if (!(msg = virNetMessageNew())) + if (!(msg = virNetMessageNew(false))) ret = -1; if (msg) { diff --git a/src/rpc/virnetclientprogram.c b/src/rpc/virnetclientprogram.c index c39520abfd..a07b744d3c 100644 --- a/src/rpc/virnetclientprogram.c +++ b/src/rpc/virnetclientprogram.c @@ -272,7 +272,7 @@ int virNetClientProgramCall(virNetClientProgramPtr prog, { virNetMessagePtr msg; - if (!(msg = virNetMessageNew())) + if (!(msg = virNetMessageNew(false))) return -1; msg->header.prog = prog->program; diff --git a/src/rpc/virnetclientstream.c b/src/rpc/virnetclientstream.c index fe15acdf61..2cc84d40dc 100644 --- a/src/rpc/virnetclientstream.c +++ b/src/rpc/virnetclientstream.c @@ -328,7 +328,7 @@ int virNetClientStreamSendPacket(virNetClientStreamPtr st, bool wantReply; VIR_DEBUG("st=%p status=%d data=%p nbytes=%zu", st, status, data, nbytes); - if (!(msg = virNetMessageNew())) + if (!(msg = virNetMessageNew(false))) return -1; virMutexLock(&st->lock); @@ -390,7 +390,7 @@ int virNetClientStreamRecvPacket(virNetClientStreamPtr st, goto cleanup; } - if (!(msg = virNetMessageNew())) { + if (!(msg = virNetMessageNew(false))) { virReportOOMError(); goto cleanup; } diff --git a/src/rpc/virnetmessage.c b/src/rpc/virnetmessage.c index 072549190b..a1ae9c4d7b 100644 --- a/src/rpc/virnetmessage.c +++ b/src/rpc/virnetmessage.c @@ -32,7 +32,7 @@ virReportErrorHelper(VIR_FROM_THIS, code, __FILE__, \ __FUNCTION__, __LINE__, __VA_ARGS__) -virNetMessagePtr virNetMessageNew(void) +virNetMessagePtr virNetMessageNew(bool tracked) { virNetMessagePtr msg; @@ -41,11 +41,21 @@ virNetMessagePtr virNetMessageNew(void) return NULL; } - VIR_DEBUG("msg=%p", msg); + msg->tracked = tracked; + VIR_DEBUG("msg=%p tracked=%d", msg, tracked); return msg; } + +void virNetMessageClear(virNetMessagePtr msg) +{ + bool tracked = msg->tracked; + memset(msg, 0, sizeof(*msg)); + msg->tracked = tracked; +} + + void virNetMessageFree(virNetMessagePtr msg) { if (!msg) diff --git a/src/rpc/virnetmessage.h b/src/rpc/virnetmessage.h index 2aae3f6499..307a0413ef 100644 --- a/src/rpc/virnetmessage.h +++ b/src/rpc/virnetmessage.h @@ -35,6 +35,8 @@ typedef void (*virNetMessageFreeCallback)(virNetMessagePtr msg, void *opaque); * use virNetMessageNew() to allocate on the heap */ struct _virNetMessage { + bool tracked; + char buffer[VIR_NET_MESSAGE_MAX + VIR_NET_MESSAGE_LEN_MAX]; size_t bufferLength; size_t bufferOffset; @@ -48,7 +50,9 @@ struct _virNetMessage { }; -virNetMessagePtr virNetMessageNew(void); +virNetMessagePtr virNetMessageNew(bool tracked); + +void virNetMessageClear(virNetMessagePtr); void virNetMessageFree(virNetMessagePtr msg); diff --git a/src/rpc/virnetserverclient.c b/src/rpc/virnetserverclient.c index a73b06d692..412814def9 100644 --- a/src/rpc/virnetserverclient.c +++ b/src/rpc/virnetserverclient.c @@ -277,7 +277,7 @@ virNetServerClientCheckAccess(virNetServerClientPtr client) return -1; } - if (!(confirm = virNetMessageNew())) + if (!(confirm = virNetMessageNew(false))) return -1; /* Checks have succeeded. Write a '\1' byte back to the client to @@ -323,7 +323,7 @@ virNetServerClientPtr virNetServerClientNew(virNetSocketPtr sock, virNetTLSContextRef(tls); /* Prepare one for packet receive */ - if (!(client->rx = virNetMessageNew())) + if (!(client->rx = virNetMessageNew(true))) goto error; client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX; client->nrequests = 1; @@ -805,7 +805,7 @@ readmore: /* Possibly need to create another receive buffer */ if (client->nrequests < client->nrequests_max) { - if (!(client->rx = virNetMessageNew())) { + if (!(client->rx = virNetMessageNew(true))) { client->wantClose = true; } else { client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX; @@ -885,16 +885,14 @@ virNetServerClientDispatchWrite(virNetServerClientPtr client) /* Get finished msg from head of tx queue */ msg = virNetMessageQueueServe(&client->tx); - if (msg->header.type == VIR_NET_REPLY || - (msg->header.type == VIR_NET_STREAM && - msg->header.status != VIR_NET_CONTINUE)) { + if (msg->tracked) { client->nrequests--; /* See if the recv queue is currently throttled */ if (!client->rx && client->nrequests < client->nrequests_max) { /* Ready to recv more messages */ + virNetMessageClear(msg); client->rx = msg; - memset(client->rx, 0, sizeof(*client->rx)); client->rx->bufferLength = VIR_NET_MESSAGE_LEN_MAX; msg = NULL; client->nrequests++; diff --git a/src/rpc/virnetserverprogram.c b/src/rpc/virnetserverprogram.c index 2e9e3f7a15..643a97dba0 100644 --- a/src/rpc/virnetserverprogram.c +++ b/src/rpc/virnetserverprogram.c @@ -284,7 +284,7 @@ int virNetServerProgramDispatch(virNetServerProgramPtr prog, VIR_INFO("Ignoring unexpected stream data serial=%d proc=%d status=%d", msg->header.serial, msg->header.proc, msg->header.status); /* Send a dummy reply to free up 'msg' & unblock client rx */ - memset(msg, 0, sizeof(*msg)); + virNetMessageClear(msg); msg->header.type = VIR_NET_REPLY; if (virNetServerClientSendMessage(client, msg) < 0) { ret = -1;