#!/usr/bin/env python3
import argparse
import atexit
import hashlib
import http.client
import io
import json
import os
import shutil
import socket
import subprocess
import sys
import tarfile
import tempfile
import urllib.parse
import uuid


class FriendlyParser(argparse.ArgumentParser):
    """Argument parser that shows the full help text on a usage error."""

    def error(self, message):
        """Report *message* on stderr, print the help, and exit with status 2."""
        report = "\nerror: %s\n" % message
        sys.stderr.write(report)
        self.print_help()
        raise SystemExit(2)


class UnixHTTPConnection(http.client.HTTPConnection):
    """HTTP connection that talks over a unix domain socket.

    The host is fixed to "localhost"; ``path`` is the filesystem path of the
    socket that ``connect`` opens.
    """

    def __init__(self, path):
        super().__init__("localhost")
        self.path = path

    def connect(self):
        """Open a stream socket to ``self.path`` and install it as ``self.sock``."""
        unix_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        unix_sock.connect(self.path)
        self.sock = unix_sock


class LXD:
    """Small client for the LXD REST API served over a unix socket.

    Every request carries a ``project`` query parameter set to the project
    given at construction time. A private temporary directory is created for
    staging multipart upload bodies and is removed at interpreter exit.
    """

    # Class-level default so cleanup() is safe on a partially built instance.
    workdir = None

    def __init__(self, path, project="default"):
        """Connect to the LXD unix socket at *path*, scoped to *project*."""
        self.lxd = UnixHTTPConnection(path)
        self.project = project

        # Create our workdir
        self.workdir = tempfile.mkdtemp()
        atexit.register(self.cleanup)

    def cleanup(self):
        """Remove the temporary working directory, if one is still present."""
        if self.workdir:
            shutil.rmtree(self.workdir)
            # Forget the path so a second call (e.g. manual + atexit) is a no-op.
            self.workdir = None

    def rest_call(self, path, data=None, method="GET", headers=None):
        """Issue one request and return ``(http_status, decoded_json_body)``.

        For a GET with a non-empty *data* mapping, the mapping is sent as
        URL-encoded query parameters. In every other case *data* is passed
        through unchanged as the request body. The ``project`` parameter is
        always added to the query string.
        """
        headers = {} if headers is None else headers
        # The path may already carry a query string (e.g. "?recursion=1").
        separator = "&" if "?" in path else "?"

        if method == "GET" and data:
            # Work on a copy so the caller's mapping is not modified.
            query = dict(data)
            query["project"] = self.project
            url = path + separator + urllib.parse.urlencode(query)
            # headers must be passed by keyword: the third positional
            # argument of HTTPConnection.request() is the body.
            self.lxd.request(method, url, headers=headers)
        else:
            url = path + separator + urllib.parse.urlencode({"project": self.project})
            self.lxd.request(method, url, data, headers)

        response = self.lxd.getresponse()
        payload = json.loads(response.read().decode("utf-8"))
        return response.status, payload

    def aliases_create(self, name, target):
        """Create image alias *name* pointing at *target*.

        Raises Exception when the server answers with anything but 200/201.
        """
        body = json.dumps({"target": target, "name": name}).encode("utf-8")
        headers = {"Content-Type": "application/json"}

        status, _ = self.rest_call("/1.0/images/aliases", body, "POST", headers)

        if status not in (200, 201):
            raise Exception("Failed to create alias: %s" % name)

    def aliases_remove(self, name):
        """Delete image alias *name*; raises Exception unless the status is 200."""
        # Percent-encode characters that are not valid in a request path;
        # "/" is left untouched by quote()'s default safe set.
        alias_path = "/1.0/images/aliases/%s" % urllib.parse.quote(name)
        status, _ = self.rest_call(alias_path, method="DELETE")

        if status != 200:
            raise Exception("Failed to remove alias: %s" % name)

    def aliases_list(self):
        """Return the alias names found in the server's alias listing."""
        _, data = self.rest_call("/1.0/images/aliases")

        return [alias.split("/1.0/images/aliases/")[-1] for alias in data["metadata"]]

    def images_list(self, recursive=False):
        """List images.

        With *recursive* the response metadata is returned as-is; otherwise
        each entry is reduced to the part after ``/1.0/images/``.
        """
        if recursive:
            _, data = self.rest_call("/1.0/images?recursion=1")
            return data["metadata"]

        _, data = self.rest_call("/1.0/images")
        return [image.split("/1.0/images/")[-1] for image in data["metadata"]]

    def images_upload(self, path, public, filename=None):
        """Upload an image and wait for the import operation to finish.

        *path* is either a single tarball path (``str``), posted as
        ``application/octet-stream``, or a ``(meta_path, rootfs_path)`` pair,
        posted as a ``multipart/form-data`` body staged in the workdir.
        *public* adds the ``X-LXD-public`` header. *filename*, when given, is
        added to each part's Content-Disposition line.

        Returns the ``metadata`` entry nested inside the wait response's
        ``metadata``. Raises Exception when the upload is not accepted (202),
        when the wait call fails, or when the reported status code is not 200;
        raises RuntimeError when no workdir is available for a split upload.
        """
        headers = {}
        if public:
            headers["X-LXD-public"] = "1"

        if isinstance(path, str):
            headers["Content-Type"] = "application/octet-stream"

            with open(path, "rb") as image:
                status, data = self.rest_call("/1.0/images", image, "POST", headers)
        else:
            meta_path, rootfs_path = path
            if not isinstance(self.workdir, str) or not self.workdir:
                raise RuntimeError("Temporary working directory is not available")

            boundary = str(uuid.uuid1())
            filename_entry = " filename=%s" % filename if filename else ""
            upload_path = os.path.join(self.workdir, "upload")
            parts = (("metadata", meta_path), ("rootfs", rootfs_path))

            # Stage the multipart body on disk so large images are streamed
            # rather than held in memory.
            with open(upload_path, "wb") as body:
                for part_name, part_path in parts:
                    body.write(bytes("--%s\r\n" % boundary, "utf-8"))
                    body.write(
                        bytes(
                            "Content-Disposition: form-data; "
                            "name=%s;%s\r\n" % (part_name, filename_entry),
                            "utf-8",
                        )
                    )
                    body.write(b"Content-Type: application/octet-stream\r\n")
                    body.write(b"\r\n")
                    with open(part_path, "rb") as part:
                        shutil.copyfileobj(part, body)
                    body.write(b"\r\n")

                body.write(bytes("--%s--\r\n" % boundary, "utf-8"))
                body.write(b"\r\n")

            headers["Content-Type"] = "multipart/form-data; boundary=%s" % boundary

            with open(upload_path, "rb") as staged:
                status, data = self.rest_call("/1.0/images", staged, "POST", headers)

        if status != 202:
            raise Exception("Failed to upload the image: %s %s" % (status, data.get("error", "")))

        # Block until the background import operation completes.
        status, data = self.rest_call(data["operation"] + "/wait", "", "GET", {})
        if status != 200:
            raise Exception("Failed to query the operation: %s" % status)

        if data["status_code"] != 200:
            raise Exception("Failed to import the image: %s" % data["metadata"])

        return data["metadata"]["metadata"]


class BusyBox:
    """Build an image tarball around the busybox binary found in PATH.

    Tarballs are written into a private temporary directory that is removed
    at interpreter exit.
    """

    # Class-level defaults so cleanup()/create_tarball() can detect an
    # instance that was not fully initialised.
    workdir = None
    binary_path = None

    def __init__(self):
        """Locate busybox and create the working directory.

        Raises RuntimeError when no ``busybox`` executable is found in PATH.
        """
        # Look for the binary first: failing after mkdtemp() but before the
        # atexit registration would leave an orphaned temporary directory.
        self.binary_path = shutil.which("busybox")
        if not self.binary_path:
            raise RuntimeError("Unable to locate busybox binary in PATH")

        # Create our workdir
        self.workdir = tempfile.mkdtemp()
        atexit.register(self.cleanup)

    def cleanup(self):
        """Remove the temporary working directory, if one is still present."""
        if self.workdir:
            shutil.rmtree(self.workdir)
            # Forget the path so a repeated call is a no-op.
            self.workdir = None

    def create_tarball(self, split=False, template=None):
        """Write the image tarball(s) and return their path(s).

        With ``split=False`` a single ``busybox.tar`` is produced with the
        root filesystem under ``rootfs/`` and the path is returned. With
        ``split=True`` the root filesystem goes into a separate
        ``busybox.rootfs.tar`` (entries unprefixed) and a
        ``(metadata_tar, rootfs_tar)`` tuple is returned.

        *template*, when non-empty, is stored as the ``when`` list of a
        ``/template`` entry in the metadata and a ``templates/template.tpl``
        file is added to the metadata tarball.

        Entry timestamps come from ``SOURCE_DATE_EPOCH`` when it is set to a
        non-zero value, otherwise from the busybox binary's mtime.

        Raises RuntimeError when the workdir or binary path is unset, and
        subprocess.CalledProcessError when ``busybox --list-full`` fails.
        """
        if self.workdir is None:
            raise RuntimeError("Temporary working directory is not available")
        if self.binary_path is None:
            raise RuntimeError("BusyBox binary path is not set")
        workdir = self.workdir
        binary_path = self.binary_path

        destination_tar = os.path.join(workdir, "busybox.tar")
        destination_tar_rootfs = (
            os.path.join(workdir, "busybox.rootfs.tar") if split else None
        )
        # In a unified image the root filesystem lives under "rootfs/".
        rootfs_prefix = "" if split else "rootfs/"

        source_date_epoch = int(os.environ.get("SOURCE_DATE_EPOCH", "0"))
        busybox_stat = os.stat(binary_path)
        creation_date = (
            source_date_epoch if source_date_epoch else int(busybox_stat.st_mtime)
        )

        def normalized_tarinfo(name, *, mode=0o644, type_=tarfile.REGTYPE, linkname=None, size=0):
            # Fixed owner and timestamp so the tarball does not depend on who
            # builds it or when.
            info = tarfile.TarInfo(name)
            info.mtime = creation_date
            info.uid = 0
            info.gid = 0
            info.uname = ""
            info.gname = ""
            info.mode = mode
            info.type = type_
            info.size = size
            if linkname is not None:
                info.linkname = linkname
            return info

        machine = os.uname()[4]
        metadata = {
            "architecture": machine,
            "creation_date": creation_date,
            "properties": {
                "os": "BusyBox",
                "architecture": machine,
                "description": "BusyBox %s" % machine,
                "name": "busybox-%s" % machine,
            },
        }

        # Ask busybox for its applet paths before any tarball is opened, so a
        # failing binary does not leave a half-written archive behind.
        listing = subprocess.run(
            [binary_path, "--list-full"],
            stdout=subprocess.PIPE,
            universal_newlines=True,
            check=True,
        )
        applets = {line.strip() for line in listing.stdout.splitlines()}
        # The binary itself is added as a regular file, not as a symlink.
        symlinks = sorted(applets - {"", "bin/busybox"})

        with tarfile.open(destination_tar, "w:") as target_tarball:
            if destination_tar_rootfs is not None:
                rootfs_tarball = tarfile.open(destination_tar_rootfs, "w:")
            else:
                rootfs_tarball = target_tarball

            try:
                # Add busybox
                with open(binary_path, "rb") as fd:
                    busybox_file = normalized_tarinfo(
                        rootfs_prefix + "bin/busybox",
                        mode=0o755,
                        size=busybox_stat.st_size,
                    )
                    rootfs_tarball.addfile(busybox_file, fd)

                # Add symlinks
                for applet in symlinks:
                    rootfs_tarball.addfile(
                        normalized_tarinfo(
                            rootfs_prefix + applet,
                            mode=0o777,
                            type_=tarfile.SYMTYPE,
                            linkname="/bin/busybox",
                        )
                    )

                # Add directories
                for directory in sorted(("dev", "mnt", "proc", "root", "sys", "tmp")):
                    rootfs_tarball.addfile(
                        normalized_tarinfo(
                            rootfs_prefix + directory,
                            mode=0o755,
                            type_=tarfile.DIRTYPE,
                        )
                    )

                # Deal with templating
                if template:
                    metadata["templates"] = {
                        "/template": {"when": template, "template": "template.tpl"}
                    }

                    target_tarball.addfile(
                        normalized_tarinfo(
                            "templates", mode=0o755, type_=tarfile.DIRTYPE
                        )
                    )

                    template_content = (
                        "name: {{ container.name }}\n"
                        "architecture: {{ container.architecture }}\n"
                        "privileged: {{ container.privileged }}\n"
                        "ephemeral: {{ container.ephemeral }}\n"
                        "trigger: {{ trigger }}\n"
                        "path: {{ path }}\n"
                        'user.foo: {{ config_get("user.foo", "_unset_") }}\n'
                    )
                    # The tar header size must be the encoded byte count.
                    template_bytes = template_content.encode()
                    target_tarball.addfile(
                        normalized_tarinfo(
                            "templates/template.tpl",
                            mode=0o644,
                            size=len(template_bytes),
                        ),
                        io.BytesIO(template_bytes),
                    )

                # Add the metadata file
                metadata_yaml = (
                    json.dumps(
                        metadata,
                        sort_keys=True,
                        indent=4,
                        separators=(",", ": "),
                        ensure_ascii=False,
                    ).encode("utf-8")
                    + b"\n"
                )
                target_tarball.addfile(
                    normalized_tarinfo(
                        "metadata.yaml", mode=0o644, size=len(metadata_yaml)
                    ),
                    io.BytesIO(metadata_yaml),
                )
            finally:
                # The metadata tarball is closed by the enclosing "with".
                if rootfs_tarball is not target_tarball:
                    rootfs_tarball.close()

        # Pin the file timestamps as well, after the archives are closed.
        os.utime(destination_tar, (creation_date, creation_date))
        if destination_tar_rootfs is not None:
            os.utime(destination_tar_rootfs, (creation_date, creation_date))
            return destination_tar, destination_tar_rootfs

        return destination_tar


if __name__ == "__main__":
    # Set below once the socket is located; stays None for --save-image,
    # where no LXD connection is needed.
    lxd = None

    def setup_alias(aliases, fingerprint):
        """Point every alias in *aliases* at *fingerprint*, replacing existing ones."""
        if not aliases:
            return

        existing = set(lxd.aliases_list())

        for alias in aliases:
            if alias in existing:
                lxd.aliases_remove(alias)
            lxd.aliases_create(alias, fingerprint)
            print("Setup alias: %s" % alias)

    def sha256_of_files(*paths):
        """Return the hex SHA-256 of the concatenated contents of *paths*."""
        digest = hashlib.sha256()
        for file_path in paths:
            with open(file_path, "rb") as fd:
                # Hash in chunks so large tarballs are not loaded whole.
                for chunk in iter(lambda: fd.read(1024 * 1024), b""):
                    digest.update(chunk)
        return digest.hexdigest()

    def import_busybox(parser, args):
        """Build the busybox image and either save it locally or import it.

        With --save-image the tarball(s) are copied into the current
        directory and nothing is sent to LXD. Otherwise the image is uploaded
        unless its fingerprint is already listed, and the requested aliases
        are set up afterwards.
        """
        busybox = BusyBox()

        # "".split(",") yields [""], which is truthy; drop empty entries so an
        # unset --template does not embed a template with an empty trigger.
        triggers = [trigger for trigger in args.template.split(",") if trigger]

        if args.split:
            meta_path, rootfs_path = busybox.create_tarball(
                split=True, template=triggers
            )

            if args.save_image:
                shutil.copy(meta_path, "busybox.meta.tar")
                shutil.copy(rootfs_path, "busybox.rootfs.tar")
                return

            # The fingerprint of a split image covers metadata then rootfs.
            fingerprint = sha256_of_files(meta_path, rootfs_path)

            if fingerprint in lxd.images_list():
                parser.exit(1, "This image is already in the store.\n")

            filename = os.path.basename(meta_path) if args.filename else None
            r = lxd.images_upload((meta_path, rootfs_path), args.public, filename)
            print("Image imported as: %s" % r["fingerprint"])
        else:
            path = busybox.create_tarball(template=triggers)
            if not isinstance(path, str):
                raise RuntimeError("Unexpected tarball path type")

            if args.save_image:
                shutil.copy(path, "busybox.tar")
                return

            fingerprint = sha256_of_files(path)

            if fingerprint in lxd.images_list():
                parser.exit(1, "This image is already in the store.\n")

            r = lxd.images_upload(path, args.public)
            print("Image imported as: %s" % r["fingerprint"])

        setup_alias(args.alias, fingerprint)

    parser = FriendlyParser(description="Import a busybox image")
    parser.add_argument(
        "--alias", action="append", default=[], help="Aliases for the image"
    )
    parser.add_argument(
        "--public", action="store_true", default=False, help="Make the image public"
    )
    parser.add_argument(
        "--split",
        action="store_true",
        default=False,
        help="Whether to create a split image",
    )
    parser.add_argument(
        "--filename",
        action="store_true",
        default=False,
        help="Set the split image's filename",
    )
    parser.add_argument(
        "--template", type=str, default="", help="Trigger test template"
    )
    parser.add_argument("--project", type=str, default="default", help="Project to use")

    parser.add_argument(
        "--save-image", action="store_true", default=False, help="Save the image tarball"
    )
    parser.set_defaults(func=import_busybox)

    # Call the function
    args = parser.parse_args()

    if not args.save_image:
        # Socket lookup order: $LXD_DIR, then the snap path, then the
        # /var/lib/lxd default.
        if "LXD_DIR" in os.environ:
            lxd_socket = os.path.join(os.environ["LXD_DIR"], "unix.socket")
        elif os.path.exists("/var/snap/lxd/common/lxd/unix.socket"):
            lxd_socket = "/var/snap/lxd/common/lxd/unix.socket"
        else:
            lxd_socket = "/var/lib/lxd/unix.socket"

        if not os.path.exists(lxd_socket):
            print("LXD isn't running.")
            sys.exit(1)

        lxd = LXD(lxd_socket, project=args.project)

    # Top-level boundary: report any failure as a one-line error and exit 1.
    try:
        args.func(parser, args)
    except Exception as e:
        sys.stderr.write("\nerror: %s\n" % e)
        sys.exit(1)
