#!/usr/bin/env python3
# dlitz 2025

import os
import shlex
import subprocess
import tempfile
from argparse import ArgumentParser
from pathlib import Path

from pkcs12_export import PKCS12Exporter


def generate_random_passphrase():
    """Return a one-off passphrase: 64 bytes from ``os.urandom``, as 128 lowercase hex chars."""
    random_bytes = os.urandom(64)
    return "".join(format(octet, "02x") for octet in random_bytes)


def make_arg_parser():
    """Build the command-line parser for this script.

    Options (registered in this order):
        -k/--privkey  private key file (required, parsed as a Path)
        --cert        certificate file (required, parsed as a Path)
        --chain       separate certificate chain file (optional, parsed as a Path)
        --ssh-host    target ssh host (required, kept as a string)
    """
    arg_parser = ArgumentParser(
        description="push TLS privkey & certificate to MikroTik RouterOS router"
    )
    option_specs = (
        (("-k", "--privkey"), {"type": Path, "required": True, "help": "private key file"}),
        (("--cert",), {"type": Path, "required": True, "help": "certificate file"}),
        (
            ("--chain",),
            {"type": Path, "help": "separate certificate chain file (optional)"},
        ),
        (("--ssh-host",), {"required": True, "help": "target ssh host"}),
    )
    for flags, settings in option_specs:
        arg_parser.add_argument(*flags, **settings)
    return arg_parser


def parse_args():
    """Parse and validate the command line.

    Returns:
        ``(args, parser)``: the parsed namespace and the parser itself, so
        callers can report further usage errors through ``parser.error()``.

    Exits with a usage error (``parser.error``, status 2) when ``--ssh-host``
    is empty, contains ``':'`` (the host is joined with a remote path as
    ``HOST:/path`` for scp), or starts with ``'-'`` (scp/ssh would parse it
    as an option rather than a destination).
    """
    parser = make_arg_parser()
    args = parser.parse_args()

    # Use parser.error rather than `assert`: asserts are stripped under
    # `python -O`, and this is user input that deserves a usage message.
    host = args.ssh_host
    if not host:
        parser.error("--ssh-host must not be empty")
    if ":" in host:
        parser.error(
            "--ssh-host must not contain ':' (it is combined with a remote path for scp)"
        )
    if host.startswith("-"):
        parser.error("--ssh-host must not start with '-'")
    return args, parser


def main():
    """Push a TLS private key and certificate to a MikroTik RouterOS router.

    Bundles the key, the certificate and the optional chain into a PKCS#12
    blob encrypted with a freshly generated random passphrase. The blob is
    copied to the router with ``scp``. Then ``/certificate import`` runs over
    ``ssh``, followed by ``/file remove`` of the uploaded file in the same
    RouterOS command line.

    If any step fails, a best-effort extra ``ssh`` call removes the uploaded
    file so the key bundle does not linger on the router. The original error
    is then propagated.

    Raises:
        subprocess.CalledProcessError: ``scp`` or ``ssh`` exited non-zero.
        RuntimeError: the router's output did not contain the expected
            ``files-imported: 1`` marker.
    """
    args, _parser = parse_args()

    # TODO: Check certificate serial number before attempting to copy cert, and at end.

    privkey_data = args.privkey.read_text()
    cert_data = args.cert.read_text()
    chain_data = args.chain.read_text() if args.chain is not None else None

    # Generated per run and never stored locally; it only has to protect the
    # PKCS#12 bundle until the router imports it.
    key_passphrase = generate_random_passphrase()

    pkcs12_data = PKCS12Exporter().export_pkcs12(
        privkey_data=privkey_data,
        cert_data=cert_data,
        chain_data=chain_data,
        passphrase=key_passphrase,
    )

    remote_name = "cert-pusher-data.p12"
    ssh_options = [
        "-oBatchMode=yes",  # never prompt interactively; fail instead
        "-oControlMaster=no",
    ]
    import_command = (
        f"/certificate import name=www_ssl_cert file-name={remote_name} "
        f'no-key-export=yes passphrase="{key_passphrase}"'
    )
    remove_command = f"/file remove [/file find name={remote_name}]"

    imported = False
    try:
        # Stage the bundle under /dev/shm (typically memory-backed on Linux)
        # rather than in a disk-backed temp directory.
        with tempfile.NamedTemporaryFile(dir="/dev/shm") as tf:
            tf.write(pkcs12_data)
            tf.flush()

            cmd = [
                "scp",
                *ssh_options,
                "-q",
                tf.name,
                f"{args.ssh_host}:/{remote_name}",
            ]
            subprocess.run(cmd, check=True)

            # NOTE(review): the passphrase is part of ssh's argv, so other
            # local users can see it in the process list while ssh runs.
            cmd = [
                "ssh",
                *ssh_options,
                args.ssh_host,
                f"{import_command}; {remove_command}",
            ]
            result = subprocess.check_output(cmd, text=True)

        # Explicit check instead of `assert`, which `python -O` strips.
        if " files-imported: 1" not in result:
            raise RuntimeError(
                f"certificate import was not confirmed by the router; output:\n{result}"
            )
        imported = True
    finally:
        if not imported:
            # A failed import may abort the RouterOS command line before the
            # trailing /file remove runs, leaving the key bundle on the
            # router. This cleanup is best-effort. check=False means a failed
            # cleanup does not raise over the original error.
            subprocess.run(
                ["ssh", *ssh_options, args.ssh_host, remove_command], check=False
            )



if __name__ == "__main__":
    # Run the push only when executed as a script, not when imported.
    main()
