#!/usr/bin/env python3
#
# Copyright VyOS maintainers and contributors <maintainers@vyos.io>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 or later as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
import json
import time
import logging
import os

from vyos.configdict import dict_merge
from vyos.configquery import ConfigTreeQuery
from vyos.firewall import fqdn_config_parse
from vyos.firewall import fqdn_resolve
from vyos.ifconfig import WireGuardIf
from vyos.remote import download
from vyos.utils.commit import commit_in_progress
from vyos.utils.dict import dict_search_args
from vyos.utils.kernel import WIREGUARD_REKEY_AFTER_TIME
from vyos.utils.file import makedir, chmod_775, write_file, read_file
from vyos.utils.network import is_valid_ipv4_address_or_range, is_valid_ipv6_address_or_range
from vyos.utils.process import cmd
from vyos.utils.process import run
from vyos.xml_ref import get_defaults

# NOTE(review): 'base' is not referenced anywhere in the visible code and
# duplicates base_firewall - confirm whether it can be dropped.
base = ['firewall']
# Seconds to sleep between two update rounds of the main loop; replaced by the
# firewall 'resolver_interval' global option in get_config() when it is set.
timeout = 300
# When True, resolve() remembers the last successful lookup per domain and
# falls back to it if a later lookup returns nothing. Enabled in get_config()
# when the firewall 'resolver_cache' global option is present.
cache = False
# Configuration paths handed to get_config() by the script entry point.
base_firewall = ['firewall']
base_nat = ['nat']
base_interfaces = ['interfaces']

# Directory holding the downloaded list files of firewall remote groups.
firewall_config_dir = "/config/firewall"

# Last resolved addresses, keyed by domain name (IPv6 results use the
# '<domain>_ipv6' key). Only written by resolve() while 'cache' is enabled.
domain_state = {}

# nftables tables ('<family> <table>') searched for IPv4 sets to refresh.
ipv4_tables = {
    'ip vyos_mangle',
    'ip vyos_filter',
    'ip vyos_nat',
    'ip vyos_wanloadbalance',
    'ip raw'
}

# nftables tables ('<family> <table>') searched for IPv6 sets to refresh.
ipv6_tables = {
    'ip6 vyos_mangle',
    'ip6 vyos_filter',
    'ip6 raw'
}

# Module logger: INFO and above, emitted through a StreamHandler on its
# default stream.
logger = logging.getLogger(__name__)
logs_handler = logging.StreamHandler()
logger.addHandler(logs_handler)
logger.setLevel(logging.INFO)

def get_config(conf, node):
    """Return the configuration below ``node`` with defaults merged in.

    The dictionary is read through ``conf.get_config_dict()``, merged on top
    of the defaults for ``node`` and then passed to ``fqdn_config_parse()``
    together with the first path element of ``node``.

    Side effect: when ``node`` is the firewall path and the merged config has
    a ``global_options`` section, the module globals ``timeout`` (from
    ``resolver_interval``, converted to int) and ``cache`` (set to True when
    ``resolver_cache`` is present) are updated.

    Args:
        conf: configuration query object providing ``get_config_dict()``.
        node: configuration path as a list, e.g. ``['firewall']``.

    Returns:
        The merged configuration dictionary.
    """
    global timeout, cache

    configured = conf.get_config_dict(
        node,
        key_mangling=('-', '_'),
        get_first_key=True,
        no_tag_node_value_mangle=True,
    )
    merged = dict_merge(get_defaults(node, get_first_key=True), configured)

    # Resolver tuning only lives below 'firewall global-options'
    if node == base_firewall and 'global_options' in merged:
        options = merged['global_options']

        if 'resolver_interval' in options:
            timeout = int(options['resolver_interval'])

        if 'resolver_cache' in options:
            cache = True

    # presumably adds the 'ip_fqdn'/'ip6_fqdn' keys consumed by update_fqdn()
    # - confirm against vyos.firewall.fqdn_config_parse
    fqdn_config_parse(merged, node[0])

    return merged

def resolve(domains, ipv6=False):
    """Resolve every name in ``domains`` and return the union of the results.

    Each name is looked up with ``fqdn_resolve()``. While the module-level
    ``cache`` flag is enabled, a successful answer is stored in
    ``domain_state``; an empty answer is replaced by the stored one when it
    exists. Names with neither a fresh nor a stored answer contribute nothing.

    Args:
        domains: iterable of domain names.
        ipv6: look up IPv6 instead of IPv4 addresses.

    Returns:
        A set with all addresses collected for the given names.
    """
    addresses = set()

    for name in domains:
        answer = fqdn_resolve(name, ipv6=ipv6)

        # IPv4 and IPv6 answers of one name are stored under distinct keys
        state_key = f'{name}_ipv6' if ipv6 else name

        if answer:
            if cache:
                domain_state[state_key] = answer
        else:
            # Lookup failed - fall back to the last known answer, if any
            try:
                answer = domain_state[state_key]
            except KeyError:
                continue

        addresses |= answer

    return addresses

def nft_output(table, set_name, ip_list):
    """Build the nft commands that replace the content of one set.

    Args:
        table: table specifier in the form '<family> <table>'.
        set_name: name of the set inside that table.
        ip_list: iterable of address strings; may be empty.

    Returns:
        A list of command lines: always a 'flush set', followed by a single
        'add element' line when ``ip_list`` is not empty.
    """
    if not ip_list:
        # Nothing to add - the set is only emptied
        return [f'flush set {table} {set_name}']

    members = ','.join(ip_list)
    return [
        f'flush set {table} {set_name}',
        f'add element {table} {set_name} {{ {members} }}',
    ]

def nft_valid_sets():
    """Return the sets currently known to nftables.

    Runs ``nft --json list sets`` and extracts every 'set' object from the
    JSON reply.

    Returns:
        A list of ``('<family> <table>', '<set name>')`` tuples. An empty list
        is returned when the command fails or its output cannot be
        interpreted; the reason is logged as a warning.
    """
    try:
        sets_obj = json.loads(cmd('nft --json list sets'))
        return [
            (f"{entry['set']['family']} {entry['set']['table']}", entry['set']['name'])
            for entry in sets_obj['nftables']
            if 'set' in entry
        ]
    except Exception as err:
        # Best effort: callers treat an empty list as "no set to refresh".
        # 'Exception' (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate so the daemon can still be stopped.
        logger.warning(f'Unable to list nftables sets: {err}')
        return []

def update_remote_group(config):
    """Refresh the nftables sets of all firewall remote groups.

    For every remote group with a 'url', the list file below
    ``firewall_config_dir`` is (re-)downloaded; if the download fails the
    file already on disk is used instead. The first whitespace-separated
    token of each non-empty line is sorted into IPv4 or IPv6 entries, and the
    ``R_<name>`` / ``R6_<name>`` sets that exist in nftables are flushed and
    refilled. All generated commands are applied with one ``nft --file -``
    call.

    Args:
        config: firewall configuration dictionary as returned by get_config().

    Returns:
        None. The number of processed groups and the nft return code are
        logged.
    """
    conf_lines = []
    count = 0
    # A set makes the repeated (table, name) membership tests O(1)
    valid_sets = set(nft_valid_sets())

    remote_groups = dict_search_args(config, 'group', 'remote_group')
    if remote_groups:
        # Create directory for list files if necessary
        if not os.path.isdir(firewall_config_dir):
            makedir(firewall_config_dir, group='vyattacfg')
            chmod_775(firewall_config_dir)

        for set_name, remote_config in remote_groups.items():
            if 'url' not in remote_config:
                continue
            nft_ip_set_name = f'R_{set_name}'
            nft_ip6_set_name = f'R6_{set_name}'

            # Create list file if necessary
            list_file = os.path.join(firewall_config_dir, f"{nft_ip_set_name}.txt")
            if not os.path.exists(list_file):
                write_file(list_file, '', user="root", group="vyattacfg", mode=0o644)

            # Attempt to download file, use cached version if download fails.
            # 'Exception' rather than a bare except keeps KeyboardInterrupt
            # and SystemExit working.
            try:
                download(list_file, remote_config['url'], raise_error=True)
            except Exception:
                logger.error(f'Failed to download list-file for {set_name} remote group')
                logger.info(f'Using cached list-file for {set_name} remote group')

            # Read list file
            ip_list = []
            ip6_list = []
            invalid_list = []
            for line in read_file(list_file).splitlines():
                words = line.split()
                if not words:
                    # Blank or whitespace-only line - nothing to classify.
                    # Indexing the first character of an empty token would
                    # raise IndexError and abort the whole update.
                    continue
                line_first_word = words[0]

                if is_valid_ipv4_address_or_range(line_first_word):
                    ip_list.append(line_first_word)
                elif is_valid_ipv6_address_or_range(line_first_word):
                    ip6_list.append(line_first_word)
                elif line_first_word[0].isalnum():
                    # Tokens starting with punctuation (e.g. '#') are treated
                    # as comments and are not reported
                    invalid_list.append(line_first_word)

            # Load ip tables
            for table in ipv4_tables:
                if (table, nft_ip_set_name) in valid_sets:
                    conf_lines += nft_output(table, nft_ip_set_name, ip_list)

            # Load ip6 tables
            for table in ipv6_tables:
                if (table, nft_ip6_set_name) in valid_sets:
                    conf_lines += nft_output(table, nft_ip6_set_name, ip6_list)

            invalid_str = ", ".join(invalid_list)
            if invalid_str:
                logger.info(f'Invalid address for set {set_name}: {invalid_str}')

            count += 1

    nft_conf_str = "\n".join(conf_lines) + "\n"
    code = run('nft --file -', input=nft_conf_str)

    logger.info(f'Updated {count} remote-groups in firewall - result: {code}')


def update_fqdn(config, node):
    """Resolve configured domain names and refresh the matching nftables sets.

    For ``node == 'firewall'`` this covers the domain groups (sets named
    ``D_<group>`` in every IPv4 and IPv6 table) and the per-rule names found
    in ``config['ip_fqdn']`` / ``config['ip6_fqdn']`` (sets named
    ``FQDN_<id>`` in the vyos_filter tables). Any other ``node`` is handled
    as NAT, where ``config['ip_fqdn']`` feeds the ``FQDN_nat_<id>`` sets of
    'ip vyos_nat'. Only sets reported by nft_valid_sets() are touched; all
    commands are applied with one ``nft --file -`` call.

    Args:
        config: configuration dictionary as returned by get_config().
        node: 'firewall' for firewall handling, anything else for NAT.

    Returns:
        None. The number of processed sets and the nft return code are logged.
    """
    existing_sets = nft_valid_sets()
    commands = []
    updated = 0

    def refresh(table, nft_set, addresses):
        # Queue a flush/refill only for sets nftables actually has
        if (table, nft_set) in existing_sets:
            commands.extend(nft_output(table, nft_set, addresses))

    if node == 'firewall':
        groups = dict_search_args(config, 'group', 'domain_group') or {}
        for group_name, group in groups.items():
            if 'address' not in group:
                continue
            nft_set = f'D_{group_name}'
            names = group['address']

            v4_addresses = resolve(names, ipv6=False)
            for table in ipv4_tables:
                refresh(table, nft_set, v4_addresses)

            v6_addresses = resolve(names, ipv6=True)
            for table in ipv6_tables:
                refresh(table, nft_set, v6_addresses)

            updated += 1

        # Per-rule names: (config key, nftables table, resolve as IPv6?)
        per_rule_sources = (
            ('ip_fqdn', 'ip vyos_filter', False),
            ('ip6_fqdn', 'ip6 vyos_filter', True),
        )
        for config_key, table, want_ipv6 in per_rule_sources:
            for rule_id, fqdn in config[config_key].items():
                refresh(table, f'FQDN_{rule_id}', resolve([fqdn], ipv6=want_ipv6))
                updated += 1

    else:
        # It's NAT
        for rule_id, fqdn in config['ip_fqdn'].items():
            refresh('ip vyos_nat', f'FQDN_nat_{rule_id}', resolve([fqdn], ipv6=False))
            updated += 1

    script = "\n".join(commands) + "\n"
    code = run('nft --file -', input=script)

    logger.info(f'Updated {updated} sets in {node} - result: {code}')

def update_interfaces(config, node):
    """Reset WireGuard peers whose last handshake is missing or too old.

    Only peers configured with a 'host_name' or an 'address' are watched.
    A watched peer is reset when its latest handshake time is 0 or older
    than three times WIREGUARD_REKEY_AFTER_TIME.

    Args:
        config: interfaces configuration dictionary as returned by
            get_config().
        node: must be 'interfaces'; any other value makes this a no-op.

    Returns:
        None.
    """
    if node != 'interfaces':
        return

    wg_interfaces = dict_search_args(config, 'wireguard')
    if not wg_interfaces:
        return

    # Public keys to watch, per WireGuard interface
    peer_public_keys = {}
    for interface, wireguard in wg_interfaces.items():
        # An interface without any peer, or a peer entry without a public
        # key, must not abort the update with a KeyError
        peers = wireguard.get('peer') or {}
        peer_public_keys[interface] = {
            peer_config['public_key']
            for peer_config in peers.values()
            if ('host_name' in peer_config or 'address' in peer_config)
            and 'public_key' in peer_config
        }

    now_time = time.time()
    max_handshake_age = 3 * WIREGUARD_REKEY_AFTER_TIME
    for interface, check_peer_public_keys in peer_public_keys.items():
        if not check_peer_public_keys:
            continue

        intf = WireGuardIf(interface, create=False, debug=False)
        handshakes = intf.operational.get_latest_handshakes()

        # WireGuard performs a handshake every WIREGUARD_REKEY_AFTER_TIME
        # if data is being transmitted between the peers. If no data is
        # transmitted, the handshake will not be initiated unless new
        # data begins to flow. Each handshake generates a new session
        # key, and the key is rotated at least every 120 seconds or
        # upon data transmission after a prolonged silence.
        for public_key, handshake_time in handshakes.items():
            if public_key not in check_peer_public_keys:
                continue
            if handshake_time == 0 or now_time - handshake_time > max_handshake_age:
                intf.operational.reset_peer(public_key=public_key)

# Script entry point: wait for a running commit to finish, read the
# configuration once, then refresh all sets and peers in an endless loop.
if __name__ == '__main__':
    logger.info('VyOS domain resolver')

    # Poll once per second until no commit is in progress; report roughly
    # once a minute that we are still waiting.
    count = 1
    while commit_in_progress():
        if ( count % 60 == 0 ):
            logger.info(f'Commit still in progress after {count}s - waiting')
        count += 1
        time.sleep(1)

    # The configuration is read a single time here and reused by every
    # iteration below. Reading the firewall node may also update the
    # 'timeout' and 'cache' globals (see get_config).
    conf = ConfigTreeQuery()
    firewall = get_config(conf, base_firewall)
    nat = get_config(conf, base_nat)
    interfaces = get_config(conf, base_interfaces)

    logger.info(f'interval: {timeout}s - cache: {cache}')

    # One update round every 'timeout' seconds; the loop never terminates on
    # its own.
    while True:
        update_fqdn(firewall, 'firewall')
        update_fqdn(nat, 'nat')
        update_remote_group(firewall)
        update_interfaces(interfaces, 'interfaces')
        time.sleep(timeout)
