#!/usr/bin/env python3

import argparse
import array
import fcntl
import json
import os
import re
import socket
import sys
import traceback

import libvirt

import libvirt_qemu

# Module-wide debug switch. main() overwrites it from the --debug command
# line flag; when true, diagnostic messages and tracebacks are printed.
debug = False


def get_domain(uri, domstr):
    """Open a libvirt connection and resolve ``domstr`` to a running domain.

    ``domstr`` is tried, in order, as a numeric domain ID, as a UUID string
    (36 characters with dashes or 32 bare hex digits) and finally as a
    domain name. The first lookup that succeeds wins.

    :param uri: libvirt connection URI passed to ``libvirt.open``
    :param domstr: domain ID, UUID or name as given by the user
    :returns: tuple ``(conn, dom)`` of the open connection and the domain
    :raises libvirt.libvirtError: the error from the first failed lookup
        when no lookup succeeds
    :raises Exception: if the domain exists but is not running
    """
    conn = libvirt.open(uri)

    dom = None
    saveex = None

    def try_lookup(cb):
        # Run one lookup, remembering only the first failure so the error
        # reported to the user relates to the most specific attempt.
        nonlocal saveex
        try:
            return cb()
        except libvirt.libvirtError as ex:
            if saveex is None:
                saveex = ex
            return None

    if re.match(r'^\d+$', domstr):
        dom = try_lookup(lambda: conn.lookupByID(int(domstr)))

    # fullmatch anchors *both* alternatives at both ends. Written as
    # '^A|B$' the anchors bind tighter than '|', so any string merely
    # starting with 36 matching characters was treated as a UUID.
    if dom is None and re.fullmatch(r'[-a-f0-9]{36}|[a-f0-9]{32}', domstr):
        dom = try_lookup(lambda: conn.lookupByUUIDString(domstr))

    if dom is None:
        dom = try_lookup(lambda: conn.lookupByName(domstr))

    if dom is None:
        raise saveex

    if not dom.isActive():
        raise Exception("Domain must be running to use QMP")

    return conn, dom


class QMPProxy(object):
    """Relay QMP traffic between one UNIX socket client and a libvirt domain.

    The proxy listens on ``serversock`` and serves a single client at a
    time. Commands read from the client are executed through the libvirt
    QEMU monitor API and the replies, as well as monitor events, are
    written back to the client. All I/O is driven by libvirt event loop
    handle callbacks registered by this class.
    """

    def __init__(self, conn, dom, serversock, verbose):
        """Register the listening socket and the QMP event callback.

        :param conn: open libvirt connection
        :param dom: libvirt domain whose monitor is proxied
        :param serversock: bound, listening socket accepting clients
        :param verbose: when true, print all QMP documents exchanged
        """
        self.conn = conn
        self.dom = dom
        self.verbose = verbose

        self.serversock = serversock
        self.serverwatch = 0

        self.clientsock = None
        self.clientwatch = 0

        # Data (and file objects) queued for sending to the client
        self.api2sock = b""
        self.api2sockfds = []

        # Data (and file objects) received from the client that has not
        # yet formed a complete command
        self.sock2api = b""
        self.sock2apifds = []

        self.serverwatch = libvirt.virEventAddHandle(
            self.serversock.fileno(), libvirt.VIR_EVENT_HANDLE_READABLE,
            self.handle_server_io, self)

        libvirt_qemu.qemuMonitorEventRegister(self.conn, self.dom,
                                              None,
                                              self.handle_qmp_event,
                                              self)

    @staticmethod
    def _decode_event_details(details):
        """Turn the event details delivered by libvirt into a JSON value.

        libvirt documents the details as JSON text; embedding that text
        directly would emit ``"data"`` as a quoted string rather than an
        object. Anything that is not parseable text is returned unchanged.
        """
        if not isinstance(details, (str, bytes)):
            return details
        try:
            return json.loads(details)
        except ValueError:
            return details

    @staticmethod
    def handle_qmp_event(conn, dom, event, secs, usecs, details, self):
        """Queue a monitor event for the connected client.

        Registered as the libvirt QEMU monitor event callback, with the
        proxy instance passed as the opaque ``self`` argument. Events that
        arrive while no client is connected are dropped.
        """
        # Without a client there is nobody to deliver to. Buffering here
        # would grow without bound and would place stale events ahead of
        # the greeting sent to the next client.
        if self.clientsock is None:
            return

        evdoc = {
            "event": event,
            "timestamp": {"seconds": secs, "microseconds": usecs}
        }
        if details is not None:
            evdoc["data"] = self._decode_event_details(details)

        ev = json.dumps(evdoc)
        if self.verbose:
            print(ev)
        ev += "\r\n"
        self.api2sock += ev.encode("utf-8")
        self.update_client_events()

    def recv_with_fds(self):
        """Read up to 1024 bytes plus any passed FDs from the client.

        :returns: tuple ``(data, files)`` where ``files`` are file objects
            wrapping the received descriptors
        :raises Exception: if ancillary data other than SCM_RIGHTS arrives
        """
        # Match VIR_NET_MESSAGE_NUM_FDS_MAX in virnetprotocol.x
        maxfds = 32
        fds = array.array('i')
        cmsgdatalen = socket.CMSG_LEN(maxfds * fds.itemsize)

        data, cmsgdata, _flags, _addr = self.clientsock.recvmsg(1024,
                                                                cmsgdatalen)
        for cmsg_level, cmsg_type, cmsg_data in cmsgdata:
            if (cmsg_level != socket.SOL_SOCKET or
                    cmsg_type != socket.SCM_RIGHTS):
                raise Exception("Unexpected CMSGDATA level %d type %d" % (
                    cmsg_level, cmsg_type))

            # Ignore a trailing partial int if the data was truncated
            usable = len(cmsg_data) - (len(cmsg_data) % fds.itemsize)
            fds.frombytes(cmsg_data[:usable])

        return data, [self.make_file(fd) for fd in fds]

    def send_with_fds(self, data, fds):
        """Send ``data`` to the client, passing ``fds`` alongside it.

        :param data: bytes to send
        :param fds: objects providing ``fileno()`` to pass via SCM_RIGHTS
        :returns: number of bytes of ``data`` actually sent
        """
        if not fds:
            # Nothing to pass: skip the ancillary message entirely
            return self.clientsock.sendmsg([data])

        cfds = array.array("i", [fd.fileno() for fd in fds])
        cmsgdata = [(socket.SOL_SOCKET, socket.SCM_RIGHTS, cfds)]

        return self.clientsock.sendmsg([data], cmsgdata)

    @staticmethod
    def make_file(fd):
        """Wrap a raw descriptor in a binary file object.

        The open mode is derived from the descriptor's access and append
        flags as reported by ``F_GETFL``.

        :param fd: integer file descriptor; ownership passes to the result
        :raises ValueError: if the flags do not map to a file mode, in
            which case the descriptor is closed
        """
        flags = fcntl.fcntl(fd, fcntl.F_GETFL)

        mask = os.O_RDONLY | os.O_WRONLY | os.O_RDWR | os.O_APPEND
        modes = {
            os.O_RDONLY: "rb",
            # Append has no effect on reads, so treat it as read-only
            os.O_RDONLY | os.O_APPEND: "rb",
            os.O_WRONLY: "wb",
            os.O_RDWR: "r+b",
            os.O_WRONLY | os.O_APPEND: "ab",
            os.O_RDWR | os.O_APPEND: "a+b",
        }
        mode = modes.get(flags & mask)
        if mode is None:
            # Don't leak a descriptor nobody will own
            os.close(fd)
            raise ValueError("Unsupported file flags 0x%x" % (flags & mask))

        return os.fdopen(fd, mode)

    def add_client(self, sock):
        """Adopt ``sock`` as the client and queue the QMP greeting.

        The greeting reports the version from ``conn.getVersion()`` and
        advertises the "oob" capability. Accepting further clients is
        suspended while this one is connected.
        """
        ver = self.conn.getVersion()
        major = ver // 1000000 % 1000
        minor = ver // 1000 % 1000
        micro = ver % 1000

        greetingobj = {
            "QMP": {
                "version": {
                    "qemu": {
                        "major": major,
                        "minor": minor,
                        "micro": micro,
                    },
                    "package": f"qemu-{major}.{minor}.{micro}",
                },
                "capabilities": [
                    "oob"
                ],
            }
        }
        greeting = json.dumps(greetingobj)
        if self.verbose:
            print(greeting)
        greeting += "\r\n"

        self.clientsock = sock
        self.clientwatch = libvirt.virEventAddHandle(
            self.clientsock.fileno(), libvirt.VIR_EVENT_HANDLE_WRITABLE,
            self.handle_client_io, self)
        self.api2sock += greeting.encode("utf-8")
        self.update_server_events()

    def remove_client(self):
        """Disconnect the client, discard buffers and accept again."""
        libvirt.virEventRemoveHandle(self.clientwatch)
        self.clientsock.close()
        self.clientsock = None
        self.clientwatch = 0
        self.update_server_events()

        self.api2sock = b""
        self.api2sockfds = []

        self.sock2api = b""
        self.sock2apifds = []

    def update_client_events(self):
        """Watch the client for writing if output is pending, else reading."""
        # For simplicity of tracking distinct QMP cmds and their passed FDs
        # we don't try to support "pipelining", only a single cmd may be
        # inflight
        if self.api2sock:
            events = libvirt.VIR_EVENT_HANDLE_WRITABLE
        else:
            events = libvirt.VIR_EVENT_HANDLE_READABLE

        libvirt.virEventUpdateHandle(self.clientwatch, events)

    def update_server_events(self):
        """Only watch the listening socket while no client is connected."""
        if self.clientsock is not None:
            libvirt.virEventUpdateHandle(self.serverwatch, 0)
        else:
            libvirt.virEventUpdateHandle(self.serverwatch,
                                         libvirt.VIR_EVENT_HANDLE_READABLE)

    def try_command(self):
        """Execute the buffered client data if it forms a complete command.

        Data that does not yet decode as a JSON document is left in the
        buffer to be extended by later reads. ``qmp_capabilities`` is
        answered locally; everything else is sent to the domain's monitor.
        The client supplied "id" is removed before execution and restored
        on the reply.

        :raises Exception: if the command is not a JSON object, or FDs
            were received but the libvirt binding cannot pass them
        """
        try:
            cmdstr = self.sock2api.decode("utf-8")
            cmd = json.loads(cmdstr)
        except ValueError as ex:
            # Covers both truncated UTF-8 and truncated JSON
            if debug:
                print("Incomplete %s: %s" % (self.sock2api, ex))
            return

        if self.verbose:
            print(cmdstr.strip())

        if not isinstance(cmd, dict):
            raise Exception("QMP command must be a JSON object")

        cmd_id = cmd.pop("id", None)

        if cmd.get("execute", "") == "qmp_capabilities":
            resobj = {
                "return": {},
            }
            resfds = []
        elif hasattr(libvirt_qemu, "qemuMonitorCommandWithFiles"):
            # NOTE(review): send_with_fds() calls .fileno() on the returned
            # resfds entries - confirm the binding returns file-like
            # objects rather than raw integers.
            res, resfds = libvirt_qemu.qemuMonitorCommandWithFiles(
                self.dom, json.dumps(cmd),
                [f.fileno() for f in self.sock2apifds])
            resobj = json.loads(res)
        else:
            if self.sock2apifds:
                raise Exception("FD passing not supported")
            res = libvirt_qemu.qemuMonitorCommand(
                self.dom, json.dumps(cmd))
            resfds = []
            resobj = json.loads(res)

        # The reply must carry the client's id, not one set on the way
        resobj.pop("id", None)
        if cmd_id is not None:
            resobj["id"] = cmd_id

        res = json.dumps(resobj)
        if self.verbose:
            print(res)
        res += "\r\n"

        self.sock2api = b""
        self.sock2apifds = []
        self.api2sock += res.encode("utf-8")
        self.api2sockfds = resfds

    @staticmethod
    def handle_client_io(watch, fd, events, self):
        """Event loop callback servicing the client socket.

        Flushes pending output when writable, otherwise reads input and
        attempts to run it as a command. End of file, unexpected events
        and any exception cause the client to be disconnected.
        """
        error = False
        try:
            if events & libvirt.VIR_EVENT_HANDLE_WRITABLE:
                done = self.send_with_fds(self.api2sock, self.api2sockfds)
                if done > 0:
                    self.api2sock = self.api2sock[done:]
                    # FDs travel with the first bytes sent, never resend
                    self.api2sockfds = []
            elif events & libvirt.VIR_EVENT_HANDLE_READABLE:
                data, fds = self.recv_with_fds()
                if not data:
                    # Zero length read means the client went away
                    error = True
                else:
                    self.sock2api += data
                    if fds:
                        self.sock2apifds += fds

                    self.try_command()
            else:
                error = True
        except Exception as e:
            if debug:
                print("%s: %s" % (sys.argv[0], str(e)))
                print(traceback.format_exc())
            error = True

        if error:
            self.remove_client()
        else:
            self.update_client_events()

    @staticmethod
    def handle_server_io(watch, fd, events, self):
        """Event loop callback accepting a client if none is connected."""
        if self.clientsock is None:
            sock, _addr = self.serversock.accept()
            self.add_client(sock)
        else:
            self.update_server_events()


def parse_commandline():
    """Parse ``sys.argv`` and return the resulting argparse namespace.

    The namespace carries ``connect``, ``debug``, ``verbose``, ``domain``
    and ``sockpath``.
    """
    cli = argparse.ArgumentParser(description="Libvirt QMP proxy")

    cli.add_argument("--connect", "-c",
                     help="Libvirt QEMU driver connection URI")

    # Boolean switches: (long option, short option, help text)
    switches = (
        ("--debug", "-d", "Display debugging information"),
        ("--verbose", "-v", "Display QMP traffic"),
    )
    for long_opt, short_opt, text in switches:
        cli.add_argument(long_opt, short_opt, action='store_true', help=text)

    # Required positionals: (destination, name shown in usage, help text)
    positionals = (
        ("domain", "DOMAIN", "Libvirt guest domain ID/UUID/Name"),
        ("sockpath", "QMP-SOCK-PATH", "UNIX socket path for QMP server"),
    )
    for dest, label, text in positionals:
        cli.add_argument(dest, metavar=label, help=text)

    return cli.parse_args()


def main():
    """Set up the proxy for the requested domain and run the event loop.

    Parses the command line, connects to libvirt, verifies the connection
    uses the QEMU driver, binds the UNIX server socket (replacing any
    existing filesystem entry at that path) and then services libvirt
    events forever.

    :raises Exception: if the connection is not to a QEMU driver
    """
    global debug

    args = parse_commandline()
    debug = args.debug

    if not debug:
        # Silence libvirt's default stderr reporting; errors still raise
        libvirt.registerErrorHandler(lambda opaque, error: None, None)

    libvirt.virEventRegisterDefaultImpl()

    conn, dom = get_domain(args.connect, args.domain)

    if conn.getType() != "QEMU":
        raise Exception("QMP proxy requires a QEMU driver connection not %s" %
                        conn.getType())

    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    # Unlink unconditionally rather than testing os.path.exists() first:
    # that check follows symlinks, so a dangling link was left in place
    # and made bind() fail, and check-then-remove is racy anyway.
    try:
        os.unlink(args.sockpath)
    except FileNotFoundError:
        pass
    sock.bind(args.sockpath)
    sock.listen(1)

    # Keep a reference so the proxy lives as long as the event loop
    _ = QMPProxy(conn, dom, sock, args.verbose)

    while True:
        libvirt.virEventRunDefaultImpl()


# Script entry point: run the proxy, reporting any failure as
# "<program>: <message>" (plus a traceback in debug mode) and exiting
# with a non-zero status.
try:
    main()
except Exception as e:
    print("%s: %s" % (sys.argv[0], str(e)))
    if debug:
        print(traceback.format_exc())
    sys.exit(1)
else:
    sys.exit(0)