# Copyright VyOS maintainers and contributors <maintainers@vyos.io>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this library.  If not, see <http://www.gnu.org/licenses/>.

# T2199: Remove unavailable nodes due to XML/Python implementation using nftables
#        monthdays: nftables does not have a monthdays equivalent
#        utc: nftables userspace uses localtime and calculates the UTC offset automatically
#        icmp/v6: migrate previously available `type-name` to valid type/code
# T4178: Update tcp flags to use multi value node
# T6071: CLI description limit of 256 characters

import re

from vyos.configtree import ConfigTree

# Number of leading characters kept when a description value is truncated
max_len_description = 255

# Root path of the configuration subtree handled by this migration
base = ['firewall']

# Legacy ICMP `type-name` values that are deleted without a replacement
icmp_remove = ['any']

# Legacy ICMP `type-name` values and what they are migrated to:
#   str  -> replacement `type-name`
#   list -> numeric [type, code] pair
icmp_translations = {
    'ping': 'echo-request',
    'pong': 'echo-reply',
    'ttl-exceeded': 'time-exceeded',
    # Destination Unreachable (type 3)
    'network-unreachable': [3, 0],
    'host-unreachable': [3, 1],
    'protocol-unreachable': [3, 2],
    'port-unreachable': [3, 3],
    'fragmentation-needed': [3, 4],
    'source-route-failed': [3, 5],
    'network-unknown': [3, 6],
    'host-unknown': [3, 7],
    'network-prohibited': [3, 9],
    'host-prohibited': [3, 10],
    'TOS-network-unreachable': [3, 11],
    'TOS-host-unreachable': [3, 12],
    'communication-prohibited': [3, 13],
    'host-precedence-violation': [3, 14],
    'precedence-cutoff': [3, 15],
    # Redirect (type 5)
    'network-redirect': [5, 0],
    'host-redirect': [5, 1],
    'TOS-network-redirect': [5, 2],
    'TOS-host-redirect': [5, 3],
    # Earlier spelling of the key above (space instead of hyphen); kept so
    # any lookup that relied on it still resolves.
    'TOS host-redirect': [5, 3],
    # Time Exceeded (type 11)
    'ttl-zero-during-transit': [11, 0],
    'ttl-zero-during-reassembly': [11, 1],
    # Parameter Problem (type 12)
    'ip-header-bad': [12, 0],
    'required-option-missing': [12, 1],
}

# Legacy ICMPv6 `type` values that are deleted without a replacement
icmpv6_remove = []

# Legacy ICMPv6 `type` values and what they are migrated to:
#   str  -> `type-name` value
#   list -> numeric [type, code] pair
icmpv6_translations = {
    'ping': 'echo-request',
    'pong': 'echo-reply',
    # Destination Unreachable (type 1)
    'no-route': [1, 0],
    'communication-prohibited': [1, 1],
    'address-unreachable': [1, 3],
    # Misspelled form of the key above; kept so any lookup that relied on
    # it still resolves.
    'address-unreachble': [1, 3],
    'port-unreachable': [1, 4],
    # Neighbor discovery
    'redirect': 'nd-redirect',
    'router-solicitation': 'nd-router-solicit',
    'router-advertisement': 'nd-router-advert',
    'neighbour-solicitation': 'nd-neighbor-solicit',
    'neighbor-solicitation': 'nd-neighbor-solicit',
    'neighbour-advertisement': 'nd-neighbor-advert',
    'neighbor-advertisement': 'nd-neighbor-advert',
    # Time Exceeded (type 3)
    'ttl-zero-during-transit': [3, 0],
    'ttl-zero-during-reassembly': [3, 1],
    # Parameter Problem (type 4)
    'bad-header': [4, 0],
    'unknown-header-type': [4, 1],
    'unknown-option': [4, 2],
}

# Group types that, when one of their groups is renamed, enable rewriting of
# group references in `name` rules
v4_groups = ["address-group", "network-group", "port-group"]
# Group types that, when one of their groups is renamed, enable rewriting of
# group references in `ipv6-name` rules
v6_groups = ["ipv6-address-group", "ipv6-network-group", "port-group"]

# Matches a name built only from ASCII letters, digits, '_', '.' and '-'
valid_group_name_pattern = re.compile(r'^[A-Za-z0-9_.-]+$')
# Matches any single character outside of that allowed set
invalid_group_name_pattern = re.compile(r'[^A-Za-z0-9_.-]')

def is_valid_group_name(name: str) -> bool:
    """Function to check that the whole name matches valid_group_name_pattern

    fullmatch() is used rather than match(): the pattern ends in `$`, which
    also matches just before a trailing newline, so match() would accept a
    name that ends in a newline character.
    """
    return valid_group_name_pattern.fullmatch(name) is not None

def sanitize_group_name(name: str) -> str:
    """Function returning the name with every disallowed character turned into '_'"""
    cleaned = re.sub(invalid_group_name_pattern, '_', name)
    return cleaned

def get_unique_group_name(config: ConfigTree, path: list[str], name: str) -> str:
    """Function to derive a sanitized group name that is not yet present below path"""

    candidate = sanitize_group_name(name)
    # Append underscores until the candidate no longer collides with an
    # existing node under path.
    while True:
        if not config.exists(path + [candidate]):
            return candidate
        candidate += '_'

def _truncate_description(config: ConfigTree, path: list[str]) -> None:
    """Cut the value at path (if present) down to max_len_description characters"""
    if config.exists(path):
        config.set(path, value=config.return_value(path)[:max_len_description])


def _migrate_rule_common(config: ConfigTree, rule_path: list[str]) -> None:
    """Apply the rewrites that are identical for `name` and `ipv6-name` rules"""
    _truncate_description(config, rule_path + ['description'])

    # T2199: `time monthdays` and `time utc` are removed
    for leaf in ('monthdays', 'utc'):
        if config.exists(rule_path + ['time', leaf]):
            config.delete(rule_path + ['time', leaf])

    # `recent time` changes from a number to a unit keyword chosen from the
    # magnitude of the old number.
    recent_time = rule_path + ['recent', 'time']
    if config.exists(recent_time):
        old_value = int(config.return_value(recent_time))
        if old_value > 600:
            unit = 'hour'
        elif old_value < 10:
            unit = 'second'
        else:
            unit = 'minute'
        config.set(recent_time, value=unit)

    # T4178: the comma separated flag string becomes one child node per flag,
    # negated flags ("!FLAG") go below `not`.
    tcp_flags = rule_path + ['tcp', 'flags']
    if config.exists(tcp_flags):
        flags = config.return_value(tcp_flags)
        config.delete(tcp_flags)
        for flag in flags.split(','):
            if not flag:
                # Empty token (e.g. from a trailing comma): nothing to set,
                # and indexing it would raise IndexError.
                continue
            if flag.startswith('!'):
                config.set(tcp_flags + ['not', flag[1:].lower()])
            else:
                config.set(tcp_flags + [flag.lower()])


def _migrate_icmp(config: ConfigTree, rule_path: list[str]) -> None:
    """Translate a legacy `icmp type-name` of a `name` rule"""
    type_name = rule_path + ['icmp', 'type-name']
    if not config.exists(type_name):
        return

    old_value = config.return_value(type_name)
    if old_value in icmp_remove:
        config.delete(type_name)
    elif old_value in icmp_translations:
        translate = icmp_translations[old_value]
        if isinstance(translate, str):
            config.set(type_name, value=translate)
        elif isinstance(translate, list):
            # [type, code] pair replaces the type-name node
            config.delete(type_name)
            config.set(rule_path + ['icmp', 'type'], value=translate[0])
            config.set(rule_path + ['icmp', 'code'], value=translate[1])


def _migrate_icmpv6(config: ConfigTree, rule_path: list[str]) -> None:
    """Translate protocol `icmpv6` and a legacy `icmpv6 type` of an `ipv6-name` rule"""
    proto_path = rule_path + ['protocol']
    if config.exists(proto_path) and config.return_value(proto_path) == 'icmpv6':
        config.set(proto_path, value='ipv6-icmp')

    type_path = rule_path + ['icmpv6', 'type']
    code_path = rule_path + ['icmpv6', 'code']
    if not config.exists(type_path):
        return

    old_value = config.return_value(type_path)
    # Numeric "TYPE" or "TYPE/CODE" is split into separate type and code nodes
    type_code_match = re.match(r'^(\d+)(?:/(\d+))?$', old_value)

    if type_code_match:
        config.set(type_path, value=type_code_match[1])
        if type_code_match[2]:
            config.set(code_path, value=type_code_match[2])
    elif old_value in icmpv6_remove:
        config.delete(type_path)
    elif old_value in icmpv6_translations:
        translate = icmpv6_translations[old_value]
        if isinstance(translate, str):
            config.delete(type_path)
            config.set(rule_path + ['icmpv6', 'type-name'], value=translate)
        elif isinstance(translate, list):
            config.set(type_path, value=translate[0])
            config.set(code_path, value=translate[1])
    else:
        # Any other value is kept as-is, but stored under `type-name`
        config.rename(type_path, 'type-name')


def _migrate_rule_groups(config: ConfigTree, rule_path: list[str],
                         translate_groups: bool,
                         translated_dict: dict[str, str]) -> None:
    """Update group references and the port-group protocol of one rule"""
    proto_path = rule_path + ['protocol']

    for direction in ['destination', 'source']:
        if not config.exists(rule_path + [direction]):
            continue

        group_path = rule_path + [direction, 'group']
        if translate_groups and config.exists(group_path):
            for group_type in config.list_nodes(group_path):
                group_name = config.return_value(group_path + [group_type])

                # A leading '!' negates the reference; it is kept in front
                # of the translated name.
                negated = group_name.startswith('!')
                plain_name = group_name[1:] if negated else group_name
                if plain_name in translated_dict:
                    new_group_name = translated_dict[plain_name]
                    if negated:
                        new_group_name = '!' + new_group_name
                    config.set(group_path + [group_type], value=new_group_name)

        # A port-group combined with protocol 'all' is switched to 'tcp_udp'
        if (config.exists(group_path + ['port-group'])
                and config.exists(proto_path)
                and config.return_value(proto_path) == 'all'):
            config.set(proto_path, value='tcp_udp')


def _rename_plus_chain(config: ConfigTree, family_path: list[str], name: str) -> None:
    """Rename a chain whose name contains '+', replacing each '+' with underscores"""
    if '+' not in name:
        return

    replacement_string = '_'
    new_name = name.replace('+', replacement_string)
    # Lengthen the replacement until the new name does not collide
    while config.exists(family_path + [new_name]):
        replacement_string += '_'
        new_name = name.replace('+', replacement_string)
    config.copy(family_path + [name], family_path + [new_name])
    config.delete(family_path + [name])


def _migrate_chains(config: ConfigTree, ipv6: bool, translate_groups: bool,
                    translated_dict: dict[str, str]) -> None:
    """Migrate every chain below `firewall name` or `firewall ipv6-name`"""
    family_path = base + ['ipv6-name' if ipv6 else 'name']
    if not config.exists(family_path):
        return

    migrate_icmp = _migrate_icmpv6 if ipv6 else _migrate_icmp

    for name in config.list_nodes(family_path):
        chain_path = family_path + [name]
        _truncate_description(config, chain_path + ['description'])

        rules_path = chain_path + ['rule']
        # No early `continue` for a chain without rules: the '+' rename
        # below has to run for such chains too.
        if config.exists(rules_path):
            for rule in config.list_nodes(rules_path):
                rule_path = rules_path + [rule]
                _migrate_rule_common(config, rule_path)
                migrate_icmp(config, rule_path)
                _migrate_rule_groups(config, rule_path, translate_groups, translated_dict)

        _rename_plus_chain(config, family_path, name)


def migrate(config: ConfigTree) -> None:
    """Migrate the `firewall` subtree of config in place.

    Steps, in order:
      * groups: truncate descriptions and rename groups whose name fails
        is_valid_group_name(), remembering old -> new names
      * `name` and `ipv6-name` chains: truncate descriptions, remove
        `time monthdays` / `time utc`, convert `recent time`, split
        `tcp flags`, translate ICMP/ICMPv6 types, update references to
        renamed groups, switch protocol 'all' to 'tcp_udp' when a
        port-group is used, and rename chains containing '+'

    Does nothing when `firewall` is not configured.
    """
    if not config.exists(base):
        # Nothing to do
        return

    v4_found = False
    v6_found = False
    # NOTE(review): keyed by group name only, so groups of different types
    # that share an invalid name also share one entry here; the last rename
    # processed wins if their new names differ.
    translated_dict = {}

    group_base = base + ['group']
    if config.exists(group_base):
        for group_type in config.list_nodes(group_base):
            type_path = group_base + [group_type]
            for group_name in config.list_nodes(type_path):
                _truncate_description(config, type_path + [group_name, 'description'])

                if is_valid_group_name(group_name):
                    continue

                # Remember which rule families need their references rewritten
                if group_type in v4_groups:
                    v4_found = True
                if group_type in v6_groups:
                    v6_found = True

                new_group_name = get_unique_group_name(config, type_path, group_name)
                translated_dict[group_name] = new_group_name
                config.copy(type_path + [group_name], type_path + [new_group_name])
                config.delete(type_path + [group_name])

    _migrate_chains(config, False, v4_found, translated_dict)
    _migrate_chains(config, True, v6_found, translated_dict)
