# Copyright VyOS maintainers and contributors <maintainers@vyos.io>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library.  If not, see <http://www.gnu.org/licenses/>.

# Migrate OpenVPN key, certificate, CRL and DH parameter files to PKI configuration
# Migrate Wireguard to store keys in CLI
# Migrate EAPoL to PKI configuration

import os

from vyos.configtree import ConfigTree
from vyos.pki import CERT_BEGIN
from vyos.pki import load_certificate
from vyos.pki import load_crl
from vyos.pki import load_dh_parameters
from vyos.pki import load_private_key
from vyos.pki import encode_certificate
from vyos.pki import encode_dh_parameters
from vyos.pki import encode_private_key
from vyos.pki import verify_crl
from vyos.utils.process import run
from vyos.utils.file import read_file

def wrapped_pem_to_config_value(pem):
    """Collapse a PEM blob into the single-line value stored in the config.

    Surrounding whitespace is stripped first. After that, empty lines, lines
    starting with ``-----`` (the BEGIN/END armour) and lines starting with
    ``#`` are dropped. The remaining lines are concatenated without any
    separator.
    """
    payload_lines = (
        row for row in pem.strip().split("\n")
        if row and not row.startswith("-----") and not row.startswith('#')
    )
    return "".join(payload_lines)

def read_auth_file(config_auth_path):
    """Return the contents of an auth file referenced from the config, or None.

    Relative names are resolved against AUTH_DIR; absolute names are used as
    given (os.path.join discards AUTH_DIR for them). The result is normalised.
    A path under AUTH_DIR that is not (yet) a regular file is retried under
    AUTH_DIR_FALLBACK, because `/config` may not be bind-mounted while
    migrations run during early boot. If the file is not readable by the
    current user, it is chmod'ed to 644 through sudo before being read.
    Returns None when no regular file is found.
    """
    resolved = os.path.normpath(os.path.join(AUTH_DIR, config_auth_path))

    # Early boot: `/config` may not be bind-mounted yet, but the same tree is
    # reachable below `/opt/vyatta/etc/config` at every boot stage.
    under_auth_dir = resolved.startswith(AUTH_DIR + '/')
    if under_auth_dir and not os.path.isfile(resolved):
        resolved = AUTH_DIR_FALLBACK + resolved[len(AUTH_DIR):]

    if not os.path.isfile(resolved):
        return None

    if not os.access(resolved, os.R_OK):
        # NOTE(review): this leaves the file (possibly a private key)
        # world-readable after migration - confirm that is intended.
        run(['sudo', 'chmod', '644', resolved])

    return read_file(resolved)


# Directory that holds the auth files (keys, certificates, ...) named in the config.
AUTH_DIR = '/config/auth'
# The same directory reached below /opt/vyatta/etc; read_auth_file() falls back
# to it when /config has not been bind-mounted yet during early boot.
AUTH_DIR_FALLBACK = '/opt/vyatta/etc' + AUTH_DIR

# Root of the config subtree that migrated PKI material is written into.
pki_base = ['pki']

def _ensure_tag_node(config: ConfigTree, path: list) -> None:
    """Create *path* and mark it as a tag node unless it already exists."""
    if not config.exists(path):
        config.set(path)
        config.set_tag(path)


def _find_existing(config: ConfigTree, parent: list, leaf: list, value: str):
    """Return the first child name of *parent* whose *leaf* value equals *value*.

    Used to de-duplicate PKI objects: when an identical key/certificate/CRL/DH
    blob is already stored, callers reference that name instead of adding a
    copy. Children without *leaf* are skipped. Returns None when nothing
    matches, or when *parent* does not exist. The guard is needed because
    ConfigTree.list_nodes() raises on a missing path by default.
    """
    if not config.exists(parent):
        return None
    for name in config.list_nodes(parent):
        path = parent + [name] + leaf
        if config.exists(path) and config.return_value(path) == value:
            return name
    return None


def _migrate_openvpn_secret(config: ConfigTree, interface: str, file_path: list,
                            key_path: list, key_pki_name: str, label: str) -> None:
    """Move an OpenVPN static key file into ``pki openvpn shared-secret``.

    :param file_path: legacy node holding the key file name (deleted afterwards)
    :param key_path: new interface node that references the PKI secret by name
    :param key_pki_name: PKI name used when the key is not stored yet
    :param label: option name used in the failure message
    """
    if not config.exists(file_path):
        return

    secret_base = pki_base + ['openvpn', 'shared-secret']
    _ensure_tag_node(config, secret_base)

    key = read_auth_file(config.return_value(file_path))
    if key:
        key_value = wrapped_pem_to_config_value(key)
        # Re-use an identical secret that is already stored in PKI
        secret_name = _find_existing(config, secret_base, ['key'], key_value)
        if not secret_name:
            secret_name = key_pki_name
            config.set(secret_base + [secret_name, 'key'], value=key_value)
            config.set(secret_base + [secret_name, 'version'], value='1')
        config.set(key_path, value=secret_name)
    else:
        print(f'Failed to migrate {label} on openvpn interface {interface}')

    config.delete(file_path)


def _migrate_openvpn_ca(config: ConfigTree, interface: str, x509_base: list, pki_name: str) -> dict:
    """Move ``tls ca-cert-file`` into ``pki ca`` entries.

    The file may contain several certificates; each one becomes (or re-uses)
    its own CA entry and is appended to ``tls ca-certificate``.
    Returns a dict mapping the referenced PKI CA name to its loaded
    certificate, so the CRL step can find the issuing CA.
    """
    ca_certs = {}
    if not config.exists(x509_base + ['ca-cert-file']):
        return ca_certs

    ca_base = pki_base + ['ca']
    _ensure_tag_node(config, ca_base)

    certs_str = read_auth_file(config.return_value(x509_base + ['ca-cert-file']))
    if certs_str:
        for index, cert_data in enumerate(certs_str.split(CERT_BEGIN)[1:], start=1):
            cert = load_certificate(CERT_BEGIN + cert_data, wrap_tags=False)
            if not cert:
                print(f'Failed to migrate CA certificate on openvpn interface {interface}')
                continue

            cert_value = wrapped_pem_to_config_value(encode_certificate(cert))
            ca_name = _find_existing(config, ca_base, ['certificate'], cert_value)
            if not ca_name:
                ca_name = f'{pki_name}_{index}'
                config.set(ca_base + [ca_name, 'certificate'], value=cert_value)
            config.set(x509_base + ['ca-certificate'], value=ca_name, replace=False)
            # Key by the name that is actually referenced. A CRL is later
            # written below this name, so it must be the entry holding the cert.
            ca_certs[ca_name] = cert
    else:
        print(f'Failed to migrate CA certificate on openvpn interface {interface}')

    config.delete(x509_base + ['ca-cert-file'])
    return ca_certs


def _migrate_openvpn_crl(config: ConfigTree, interface: str, x509_base: list, ca_certs: dict) -> None:
    """Attach ``tls crl-file`` to the migrated CA (from *ca_certs*) that verifies it."""
    if not config.exists(x509_base + ['crl-file']):
        return

    ca_base = pki_base + ['ca']
    _ensure_tag_node(config, ca_base)

    crl_data = read_auth_file(config.return_value(x509_base + ['crl-file']))
    crl = load_crl(crl_data, wrap_tags=False) if crl_data else None

    crl_ca_name = None
    if crl:
        crl_ca_name = next((name for name, ca_cert in ca_certs.items()
                            if verify_crl(crl, ca_cert)), None)

    if crl_ca_name:
        crl_value = wrapped_pem_to_config_value(encode_certificate(crl))
        # Nothing to do if an identical CRL is already stored on any CA
        if not _find_existing(config, ca_base, ['crl'], crl_value):
            config.set(ca_base + [crl_ca_name, 'crl'], value=crl_value)
    else:
        print(f'Failed to migrate CRL on openvpn interface {interface}')

    config.delete(x509_base + ['crl-file'])


def _migrate_openvpn_cert(config: ConfigTree, interface: str, x509_base: list, pki_name: str) -> None:
    """Move ``tls cert-file`` into ``pki certificate`` and set ``tls certificate``."""
    if not config.exists(x509_base + ['cert-file']):
        return

    cert_base = pki_base + ['certificate']
    _ensure_tag_node(config, cert_base)

    cert_data = read_auth_file(config.return_value(x509_base + ['cert-file']))
    cert = load_certificate(cert_data, wrap_tags=False) if cert_data else None

    if cert:
        cert_value = wrapped_pem_to_config_value(encode_certificate(cert))
        cert_name = _find_existing(config, cert_base, ['certificate'], cert_value)
        if not cert_name:
            cert_name = pki_name
            config.set(cert_base + [cert_name, 'certificate'], value=cert_value)
        config.set(x509_base + ['certificate'], value=cert_name)
    else:
        print(f'Failed to migrate certificate on openvpn interface {interface}')

    config.delete(x509_base + ['cert-file'])


def _migrate_openvpn_key(config: ConfigTree, interface: str, x509_base: list, pki_name: str) -> None:
    """Move ``tls key-file`` into ``pki certificate <name> private key``.

    The key is stored on the PKI certificate the interface references, which
    may be a re-used entry. If the interface references none, it falls back to
    *pki_name*.
    """
    if not config.exists(x509_base + ['key-file']):
        return

    cert_base = pki_base + ['certificate']
    key_data = read_auth_file(config.return_value(x509_base + ['key-file']))
    key = load_private_key(key_data, passphrase=None, wrap_tags=False) if key_data else None

    if key:
        key_value = wrapped_pem_to_config_value(encode_private_key(key, passphrase=None))
        # `pki certificate` may not exist here (e.g. no cert-file was configured)
        if not _find_existing(config, cert_base, ['private', 'key'], key_value):
            owner = pki_name
            if config.exists(x509_base + ['certificate']):
                owner = config.return_value(x509_base + ['certificate'])
            _ensure_tag_node(config, cert_base)
            config.set(cert_base + [owner, 'private', 'key'], value=key_value)
    else:
        print(f'Failed to migrate private key on openvpn interface {interface}')

    config.delete(x509_base + ['key-file'])


def _migrate_openvpn_dh(config: ConfigTree, interface: str, x509_base: list, pki_name: str) -> None:
    """Move ``tls dh-file`` into ``pki dh`` and set ``tls dh-params``."""
    if not config.exists(x509_base + ['dh-file']):
        return

    dh_base = pki_base + ['dh']
    _ensure_tag_node(config, dh_base)

    dh_data = read_auth_file(config.return_value(x509_base + ['dh-file']))
    dh = load_dh_parameters(dh_data, wrap_tags=False) if dh_data else None

    if dh:
        dh_value = wrapped_pem_to_config_value(encode_dh_parameters(dh))
        dh_name = _find_existing(config, dh_base, ['parameters'], dh_value)
        if not dh_name:
            dh_name = pki_name
            config.set(dh_base + [dh_name, 'parameters'], value=dh_value)
        config.set(x509_base + ['dh-params'], value=dh_name)
    else:
        print(f'Failed to migrate DH parameters on openvpn interface {interface}')

    config.delete(x509_base + ['dh-file'])


def _migrate_openvpn(config: ConfigTree) -> None:
    """Migrate every OpenVPN interface's file-based secrets and TLS material to PKI."""
    base = ['interfaces', 'openvpn']
    if not config.exists(base):
        return

    for interface in config.list_nodes(base):
        if_base = base + [interface]
        pki_name = f'openvpn_{interface}'

        _migrate_openvpn_secret(config, interface, if_base + ['shared-secret-key-file'],
                                if_base + ['shared-secret-key'], f'{pki_name}_shared',
                                'shared-secret-key')

        x509_base = if_base + ['tls']
        if not config.exists(x509_base):
            continue

        _migrate_openvpn_secret(config, interface, x509_base + ['auth-file'],
                                x509_base + ['auth-key'], f'{pki_name}_auth', 'auth-key')
        _migrate_openvpn_secret(config, interface, x509_base + ['crypt-file'],
                                x509_base + ['crypt-key'], f'{pki_name}_crypt', 'crypt-key')

        ca_certs = _migrate_openvpn_ca(config, interface, x509_base, pki_name)
        _migrate_openvpn_crl(config, interface, x509_base, ca_certs)
        _migrate_openvpn_cert(config, interface, x509_base, pki_name)
        _migrate_openvpn_key(config, interface, x509_base, pki_name)
        _migrate_openvpn_dh(config, interface, x509_base, pki_name)


def _migrate_wireguard(config: ConfigTree) -> None:
    """Inline WireGuard private keys from files and rename peer ``pubkey``."""
    base = ['interfaces', 'wireguard']
    if not config.exists(base):
        return

    for interface in config.list_nodes(base):
        private_key_path = base + [interface, 'private-key']

        # The legacy `private-key` node names a key directory ('default' if unset)
        key_file = 'default'
        if config.exists(private_key_path):
            key_file = config.return_value(private_key_path)

        key_data = read_auth_file(f'{AUTH_DIR}/wireguard/{key_file}/private.key')
        if key_data:
            config.set(private_key_path, value=key_data.strip())
        else:
            print(f'Could not find wireguard private key for migration on interface "{interface}"')

        # The peer rename does not depend on the private key, so it runs even
        # when the key file is missing. Interfaces without peers, and peers
        # without `pubkey`, are skipped.
        peer_base = base + [interface, 'peer']
        if not config.exists(peer_base):
            continue
        for peer in config.list_nodes(peer_base):
            pubkey_path = peer_base + [peer, 'pubkey']
            if config.exists(pubkey_path):
                config.rename(pubkey_path, 'public-key')


def _migrate_eapol(config: ConfigTree) -> None:
    """Move ethernet ``eapol`` CA/certificate/key files into the ``pki`` subtree."""
    base = ['interfaces', 'ethernet']
    if not config.exists(base):
        return

    for interface in config.list_nodes(base):
        x509_base = base + [interface, 'eapol']
        if not config.exists(x509_base):
            continue

        pki_name = f'eapol_{interface}'
        cert_base = pki_base + ['certificate']

        if config.exists(x509_base + ['ca-cert-file']):
            ca_base = pki_base + ['ca']
            _ensure_tag_node(config, ca_base)

            cert_data = read_auth_file(config.return_value(x509_base + ['ca-cert-file']))
            cert = load_certificate(cert_data, wrap_tags=False) if cert_data else None

            if cert:
                config.set(ca_base + [pki_name, 'certificate'], value=wrapped_pem_to_config_value(encode_certificate(cert)))
                config.set(x509_base + ['ca-certificate'], value=pki_name)
            else:
                print(f'Failed to migrate CA certificate on eapol config for interface {interface}')

            config.delete(x509_base + ['ca-cert-file'])

        if config.exists(x509_base + ['cert-file']):
            _ensure_tag_node(config, cert_base)

            cert_data = read_auth_file(config.return_value(x509_base + ['cert-file']))
            cert = load_certificate(cert_data, wrap_tags=False) if cert_data else None

            if cert:
                config.set(cert_base + [pki_name, 'certificate'], value=wrapped_pem_to_config_value(encode_certificate(cert)))
                config.set(x509_base + ['certificate'], value=pki_name)
            else:
                print(f'Failed to migrate certificate on eapol config for interface {interface}')

            config.delete(x509_base + ['cert-file'])

        if config.exists(x509_base + ['key-file']):
            key_data = read_auth_file(config.return_value(x509_base + ['key-file']))
            key = load_private_key(key_data, passphrase=None, wrap_tags=False) if key_data else None

            if key:
                # Without a cert-file the parent may not exist yet; create it as a
                # tag node instead of letting set() create a plain node.
                _ensure_tag_node(config, cert_base)
                config.set(cert_base + [pki_name, 'private', 'key'], value=wrapped_pem_to_config_value(encode_private_key(key, passphrase=None)))
            else:
                print(f'Failed to migrate private key on eapol config for interface {interface}')

            config.delete(x509_base + ['key-file'])


def migrate(config: ConfigTree) -> None:
    """Move file-based OpenVPN, WireGuard and EAPoL credentials into the config.

    - OpenVPN: static keys, CA/CRL/certificate/private-key and DH files are
      imported into the ``pki`` subtree. Identical objects already stored
      there are re-used, and the legacy ``*-file`` nodes are removed.
    - EAPoL: CA/certificate/key files are imported into ``pki`` the same way.
    - WireGuard: private keys are read from
      ``<AUTH_DIR>/wireguard/<name>/private.key`` into ``private-key``, and
      peer ``pubkey`` nodes are renamed to ``public-key``.

    Files that cannot be read or parsed are reported with print() and skipped.
    """
    _migrate_openvpn(config)
    _migrate_wireguard(config)
    _migrate_eapol(config)
