740 lines
28 KiB
Python
Executable file
740 lines
28 KiB
Python
Executable file
#!/usr/bin/python3
|
|
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU General Public License
|
|
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
|
|
# Copyright © 2017 David Sinquin <david.re2o@sinquin.eu>
|
|
# Copyright © 2018-2019 Hugo Levy-Falk <hugo@klafyvel.me>
|
|
|
|
|
|
"""
|
|
Module for nftables set management.
|
|
"""
|
|
|
|
# Dependencies: python3-netaddr, python3-requests, nftables, sudo (optional)
|
|
#
|
|
# For sudo configuration, create a file in /etc/sudoers.d/ with:
|
|
# "<python_user> ALL = (root) NOPASSWD: /usr/sbin/nft"
|
|
|
|
# netaddr :
|
|
# - https://pypi.python.org/pypi/netaddr/
|
|
# - https://github.com/drkjam/netaddr/
|
|
# - https://netaddr.readthedocs.io/en/latest/
|
|
|
|
|
|
import logging
import re
import subprocess
from collections.abc import Iterable
from configparser import ConfigParser

import netaddr  # MAC, IPv4, IPv6
import requests
|
|
|
|
|
|
class ExecError(Exception):
    """Raised when a command times out or exits with an unexpected code."""


class CommandExec:
    """Start external commands, logging and raising on failure."""

    @staticmethod
    def run_check_output(command, allowed_return_codes=(0,), timeout=15):
        """
        Run a command and return ``(return_code, stdout, stderr)``.

        Args:
            command: argument list handed to ``subprocess.Popen``; no shell
                is involved.
            allowed_return_codes: exit codes that are not treated as errors.
            timeout: maximum number of seconds to wait for the command.

        Raises:
            ExecError: if the command does not finish within ``timeout`` or
                exits with a code outside ``allowed_return_codes``.
        """
        logging.debug("Command to be run: '%s'", "' '".join(command))
        process = subprocess.Popen(
            command,
            shell=False,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True)
        try:
            stdout, stderr = process.communicate(timeout=timeout)
        except subprocess.TimeoutExpired as err:
            process.kill()
            # Reap the killed child and drain/close its pipes so that no
            # zombie process or open file descriptor is left behind.
            process.communicate()
            error_message = 'Command timed out after {} seconds: "{}".'.format(
                timeout, '" "'.join(command))
            logging.error(error_message)
            raise ExecError(error_message) from err
        # communicate() only returns once the process has terminated, so the
        # return code is already available without waiting a second time.
        return_code = process.returncode
        if return_code not in allowed_return_codes:
            error_message = (
                'Error running command: "{command}", return code: {code}.\n'
                'Stderr:\n{stderr}\nStdout:\n{stdout}'.format(
                    command='" "'.join(command), code=return_code,
                    stderr=stderr, stdout=stdout))
            logging.error(error_message)
            raise ExecError(error_message)
        return (return_code, stdout, stderr)

    @classmethod
    def run(cls, *args, **kwargs):
        """Run a command like ``run_check_output`` but return only its code."""
        returncode, _, _ = cls.run_check_output(*args, **kwargs)
        return returncode
|
|
|
|
class Parser:
    """Parsers (validators) for the element formats used in nftables sets."""

    @staticmethod
    def MAC(mac):
        """Parse a MAC address into a ``netaddr.EUI`` (unix expanded dialect)."""
        return netaddr.EUI(mac, dialect=netaddr.mac_unix_expanded)

    @staticmethod
    def IPv4(ip):
        """Check an IPv4 validity.

        Args:
            ip: an already parsed netaddr object (returned unchanged), a
                ``(begin, end)`` tuple (returns an IPRange), a single IP
                address, an IP network, or a ``"begin-end"`` range string.
        """
        if isinstance(ip, (netaddr.IPAddress, netaddr.IPNetwork,
                           netaddr.IPRange, netaddr.IPGlob)):
            return ip
        if isinstance(ip, tuple):
            begin, end = ip
            return netaddr.IPRange(netaddr.IPAddress(begin, version=4),
                                   netaddr.IPAddress(end, version=4))
        # Depending on the input (and the netaddr version) a rejected string
        # surfaces as either ValueError or AddrFormatError; both mean "try
        # the next representation".
        try:
            return netaddr.IPAddress(ip, version=4)
        except (ValueError, netaddr.core.AddrFormatError):
            pass
        try:
            return netaddr.IPNetwork(ip, version=4)
        except (ValueError, netaddr.core.AddrFormatError):
            pass
        begin, end = ip.split('-')
        return netaddr.IPRange(begin, end)

    @staticmethod
    def IPv6(ip):
        """Check an IPv6 validity.

        Args:
            ip: a ``(begin, end)`` tuple (returns an IPRange), a single IP
                address or an IP network.
        """
        if isinstance(ip, tuple):
            begin, end = ip
            # The address family is enforced on each bound rather than
            # through a ``version`` keyword on IPRange.
            return netaddr.IPRange(netaddr.IPAddress(begin, version=6),
                                   netaddr.IPAddress(end, version=6))
        try:
            return netaddr.IPAddress(ip, version=6)
        except (ValueError, netaddr.core.AddrFormatError):
            return netaddr.IPNetwork(ip, version=6)

    @staticmethod
    def protocol(protocol):
        """Return ``protocol`` if it is 'tcp', 'udp' or 'icmp'.

        Raises:
            ValueError: for any other value.
        """
        if protocol in ('tcp', 'udp', 'icmp'):
            return protocol
        raise ValueError('Invalid protocol: "{}".'.format(protocol))

    @staticmethod
    def port_number(port):
        """Check a port validity.

        Args:
            port: a single port (anything ``int()`` accepts) or a
                ``"begin-end"`` range string.

        Returns:
            The port as an ``int`` for a single port, or the original
            string for a valid range (0 <= begin < end <= 65535).

        Raises:
            ValueError: if the port or range is malformed or out of bounds.
        """
        invalid = 'Invalid port number: "{}".'.format(port)
        try:
            number = int(port)
        except ValueError:
            pass  # Not a single number: try to read it as a range below.
        else:
            if 0 <= number < 65536:
                return number
            raise ValueError(invalid)
        try:
            begin, end = (int(bound) for bound in port.split('-'))
        except ValueError:
            raise ValueError(invalid) from None
        # 65535 is the highest valid port, so the upper bound is exclusive.
        if 0 <= begin < end < 65536:
            return port
        raise ValueError(invalid)
|
|
|
|
class NetfilterSet:
    """Manage a netfilter set using nftables."""

    # Maps the type names used by this module to nftables type names.
    TYPES = {'IPv4': 'ipv4_addr', 'IPv6': 'ipv6_addr', 'MAC': 'ether_addr',
             'protocol': 'inet_proto', 'port': 'inet_service'}

    # Maps the type names used by this module to their element parser.
    FILTERS = {'IPv4': Parser.IPv4, 'IPv6': Parser.IPv6, 'MAC': Parser.MAC,
               'protocol': Parser.protocol, 'port': Parser.port_number}

    ADDRESS_FAMILIES = {'ip', 'ip6', 'inet', 'arp', 'bridge', 'netdev'}

    FLAGS = {'constant', 'interval', 'timeout'}

    NFT_TYPE = {'set', 'map'}

    # Parses the output of `nft -nn list set <family> <table> <name>`.
    # The element character class lists '-' last so that it is a literal
    # dash (needed for "begin-end" interval elements) and not a range.
    pattern = re.compile(
        r"table (?P<address_family>\w+)+ (?P<table>\w+) \{\n"
        r"\s*set (?P<name>\w+) \{\n"
        r"\s*type (?P<type>(\w+( \. )?)+)\n"
        r"(\s*flags (?P<flags>(\w+(, )?)+)\n)?"
        r"(\s*elements = \{ "
        r"(?P<elements>((\n?\s*)?([\w:./-]+( \. )?)+,?)*) "
        r"\n?\s*\}\n)?"
        r"\s*\}\n"
        r"\s*\}"
    )

    def __init__(self,
                 name,
                 type_,  # e.g.: ('MAC', 'IPv4')
                 target_content=None,
                 use_sudo=True,
                 address_family='inet',  # Manage both IPv4 and IPv6.
                 table_name='filter',
                 flags=(),
                 ):
        """Describe a set; nothing is sent to the kernel until ``manage()``.

        Args:
            name: name of the set.
            type_: iterable of keys of ``TYPES``, one per element component.
            target_content: iterable of n-uplets the set should contain.
            use_sudo: whether nft is run through /usr/bin/sudo.
            address_family: one of ``ADDRESS_FAMILIES``.
            table_name: name of the table holding the set.
            flags: iterable of values from ``FLAGS``.

        Raises:
            ValueError: on invalid type, flag, address family or content.
        """
        self.name = name
        self.content = set()
        self.set_type(type_)  # sets self.type
        self.filters = tuple(self.FILTERS[i] for i in self.type)
        self.set_flags(flags)
        self.set_address_family(address_family)  # sets self.address_family
        self.table = table_name
        sudo = ["/usr/bin/sudo"] if use_sudo else []
        self.nft = [*sudo, "/usr/sbin/nft"]
        if target_content:
            self._target_content = self.validate_data(target_content)
        else:
            self._target_content = set()

    @property
    def target_content(self):
        """A copy of the validated target content."""
        return self._target_content.copy()  # Forbid in-place modification

    @target_content.setter
    def target_content(self, target_content):
        self._target_content = self.validate_data(target_content)

    def filter(self, elements):
        """Lazily parse each component of an n-uplet with its type parser."""
        return (self.filters[i](element) for i, element in enumerate(elements))

    def set_type(self, type_):
        """Check set type validity and store it."""
        for element_type in type_:
            if element_type not in self.TYPES:
                raise ValueError('Invalid type: "{}".'.format(element_type))
        self.type = type_

    def set_address_family(self, address_family='ip'):
        """Set the set address_family, defaulting to "ip" like nftables."""
        if address_family not in self.ADDRESS_FAMILIES:
            raise ValueError(
                'Invalid address_family: "{}".'.format(address_family))
        self.address_family = address_family

    def set_flags(self, flags_):
        """Check set flags validity before saving them (None if empty)."""
        for f in flags_:
            if f not in self.FLAGS:
                raise ValueError('Invalid flag: "{}".'.format(f))
        self.flags = set(flags_) or None

    def create_in_kernel(self):
        """Create the set, removing existing set if needed."""
        current_set = self._get_raw_netfilter(parse_elements=False)
        if current_set is None:
            self._create_new_set_in_kernel()
        # NOTE(review): _get_raw_netfilter() already raises ValueError when
        # the kernel set has another type, so as written here this branch
        # is never taken; confirm whether a mismatching set should really
        # be replaced or reported.
        elif not self.has_type(current_set['type']):
            self._delete_in_kernel()
            self._create_new_set_in_kernel()

    def _delete_in_kernel(self, nft_type='set'):
        """Delete the set, table and set must exist."""
        CommandExec.run([
            *self.nft,
            'delete {nft_type} {addr_family} {table} {set_}'.format(
                addr_family=self.address_family, table=self.table,
                nft_type=nft_type,
                set_=self.name)
        ])

    def _create_new_set_in_kernel(self, nft_type='set'):
        """Create the non-existing set, creating table if needed."""
        definition = 'type {} ;'.format(self.format_type())
        if self.flags:
            # Sorted so that the generated command is deterministic.
            definition += ' flags {};'.format(', '.join(sorted(self.flags)))
        nft_command = (
            'add {nft_type} {addr_family} {table} {set_} {{ {definition}}}'
            .format(nft_type=nft_type, addr_family=self.address_family,
                    table=self.table, set_=self.name, definition=definition))
        create_set = [*self.nft, nft_command]
        return_code = CommandExec.run(create_set, allowed_return_codes=(0, 1))
        if return_code == 0:
            return  # Set creation successful.
        # return_code was 1, one error was detected in the rules.
        # Attempt to create the table first.
        create_table = [*self.nft, 'add table {addr_family} {table}'.format(
            addr_family=self.address_family, table=self.table)]
        CommandExec.run(create_table)
        CommandExec.run(create_set)

    def validate_data(self, set_data):
        """
        Validate data, returning it as a set of tuples.

        For a MAC-IPv4 set, data must be an iterable of (MAC, IPv4)
        iterables. Every n-uplet is checked before failing so that all
        errors are reported at once.

        Raises:
            ValueError: if at least one n-uplet could not be parsed.
        """
        set_ = set()
        errors = []
        for n_uplet in set_data:
            try:
                set_.add(tuple(self.filter(n_uplet)))
            except Exception as err:
                errors.append(err)
        if errors:
            raise ValueError(
                'Error parsing data, encountered the folowing {} errors.\n"{}"'
                .format(len(errors), '",\n"'.join(map(str, errors))))
        return set_

    def _apply_target_content(self):
        """Change netfilter content to target set."""
        current_set = self.get_netfilter_content()
        if current_set is None:
            raise ValueError('Cannot change "{}" netfilter set content: set '
                             'do not exist in "{}" "{}".'.format(
                                 self.name, self.address_family, self.table))
        to_delete = current_set - self._target_content
        to_add = self._target_content - current_set
        self._change_content(delete=to_delete, add=to_add)

    def _change_content(self, delete=None, add=None):
        """Add then delete the given n-uplets, one nft command per action."""
        for action, elements in (('add', add), ('delete', delete)):
            if not elements:
                continue
            content = ', '.join(' . '.join(str(element) for element in tuple_)
                                for tuple_ in elements)
            command = [
                *self.nft,
                '{action} element {addr_family} {table} {set_} {{{content}}}'
                .format(action=action, addr_family=self.address_family,
                        table=self.table, set_=self.name, content=content)
            ]
            CommandExec.run(command)

    def _get_raw_netfilter(self, parse_elements=True):
        """Return a dict describing the netfilter set matching self or None.

        None is returned when nft printed nothing on stdout.

        Raises:
            ValueError: if the listed set does not have the expected name,
                address family, table, type or flags.
        """
        _, stdout, _ = CommandExec.run_check_output(
            [*self.nft, '-nn', 'list set {addr_family} {table} {set_}'.format(
                addr_family=self.address_family, table=self.table,
                set_=self.name)],
            allowed_return_codes=(0, 1)  # In case table do not exist
        )
        if not stdout:
            return None
        netfilter_set = self._parse_netfilter_string(stdout)
        checks = {
            'name': self.name == netfilter_set['name'],
            'family': self.address_family == netfilter_set['address_family'],
            'table': self.table == netfilter_set['table'],
            'flags': self.flags == netfilter_set.get('flags', set()),
            'types': self.has_type(netfilter_set['type']),
        }
        if not all(checks.values()):
            mark = {True: 'v', False: 'x'}
            raise ValueError(
                'Did not get the right set, too wrong to fix. Got '
                + str(netfilter_set)
                + ("\nExpected : "
                   "\n\tname: \t{name} \t[{name_check}]"
                   "\n\taddress_family: \t{family} \t[{family_check}]"
                   "\n\ttable: \t{table} \t[{table_check}]"
                   "\n\tflags: \t{flags} \t[{flags_check}]"
                   "\n\ttypes: \t{types} \t[{types_check}]"
                   ).format(
                       name=self.name,
                       family=self.address_family,
                       table=self.table,
                       flags=self.flags,
                       types=tuple(self.TYPES[t] for t in self.type),
                       name_check=mark[checks['name']],
                       family_check=mark[checks['family']],
                       table_check=mark[checks['table']],
                       flags_check=mark[checks['flags']],
                       types_check=mark[checks['types']],
                   )
            )
        if parse_elements:
            raw_content = netfilter_set['raw_content']
            if raw_content:
                netfilter_set['content'] = self.validate_data(
                    (element.strip() for element in n_uplet.split(' . '))
                    for n_uplet in raw_content.split(',')
                    if n_uplet.strip()  # Ignore empty fragments.
                )
            else:
                netfilter_set['content'] = set()
        return netfilter_set

    @classmethod
    def _parse_netfilter_string(cls, set_string):
        """
        Parse netfilter set definition and return set as dict.

        Do not validate content type against detected set type.
        Return a dict with 'name', 'address_family', 'table', 'type',
        'flags' and 'raw_content' keys. 'type' is a list of strings,
        'flags' a set of strings or None, 'raw_content' a string or None.
        Raise ValueError in case of unexpected syntax.
        """
        match = cls.pattern.match(set_string)
        if match is None:
            raise ValueError("Malformed expression :\n" + set_string)
        values = match.groupdict()
        return {
            'address_family': values['address_family'],
            'table': values['table'],
            'name': values['name'],
            'type': values['type'].split(' . '),
            'raw_content': values['elements'],
            'flags': set(values['flags'].split(', ')) if values['flags'] else None,
        }

    def get_netfilter_content(self):
        """Return current set content from netfilter, or None if not found."""
        netfilter_set = self._get_raw_netfilter(parse_elements=True)
        if netfilter_set is None:
            return None
        return netfilter_set['content']

    def has_type(self, type_):
        """Check if some nftables type sequence matches the set's one."""
        return tuple(self.TYPES[t] for t in self.type) == tuple(type_)

    def manage(self):
        """Create set if needed and populate it with target content."""
        self.create_in_kernel()
        self._apply_target_content()

    def format_type(self):
        """Return the set type in nftables syntax, e.g. 'a . b'."""
        return ' . '.join(self.TYPES[i] for i in self.type)
|
|
|
|
|
|
class NetfilterMap(NetfilterSet):
    """Manage a netfilter map (key n-uplet -> single value) using nftables."""

    # Parses the output of `nft -nn list map <family> <table> <name>`.
    # The element character class lists '-' last so that it is a literal
    # dash (needed for "begin-end" interval keys) and not a range.
    pattern = re.compile(
        r"table (?P<address_family>\w+)+ (?P<table>\w+) \{\n"
        r"\s*map (?P<name>\w+) \{\n"
        r"\s*type (?P<type_from>(\w+( \. )?)+) : (?P<type>\w+)\n"
        r"(\s*flags (?P<flags>(\w+(, )?)+)\n)?"
        r"(\s*elements = \{ "
        r"(?P<elements>(\n?\s*([\w:./-]+( \. )?)+ : [\w:./-]+,?)*)"
        r"\n?\s*\}\n)?"
        r"\s*\}"
        r"\n\s*\}"
    )

    def __init__(self,
                 name,
                 type_,
                 type_from,
                 target_content=None,
                 use_sudo=True,
                 address_family='inet',
                 table_name='filter',
                 flags=()
                 ):
        """Describe a map; nothing is sent to the kernel until ``manage()``.

        Args:
            name: name of the map.
            type_: iterable of keys of ``TYPES`` describing the value.
            type_from: iterable of keys of ``TYPES`` describing the key.
            target_content: dict mapping key n-uplets to values.
            use_sudo: whether nft is run through /usr/bin/sudo.
            address_family: one of ``ADDRESS_FAMILIES``.
            table_name: name of the table holding the map.
            flags: iterable of values from ``FLAGS``.
        """
        # The target content is validated below, once the key parsers
        # exist, so it is not handed to the parent constructor.
        super().__init__(name, type_, use_sudo=use_sudo,
                         address_family=address_family, table_name=table_name,
                         flags=flags)
        self.set_type_from(type_from)
        self.key_filters = tuple(self.FILTERS[i] for i in self.type_from)
        if target_content:
            self._target_content = self.validate_data(target_content)
        else:
            self._target_content = {}

    def filter_key(self, elements):
        """Lazily parse each component of a key with its type parser."""
        return (self.key_filters[i](element) for i, element in enumerate(elements))

    def set_type_from(self, type_):
        """Check key type validity and store it."""
        for element_type in type_:
            if element_type not in self.TYPES:
                raise ValueError('Invalid type: "{}".'.format(element_type))
        self.type_from = type_

    def create_in_kernel(self):
        """Create the map if it does not exist yet.

        The inherited implementation hands the bare value type to
        ``has_type``, which for a map expects a ``(key types, value type)``
        pair; the comparison therefore always failed and every existing map
        was deleted and recreated. ``_get_raw_netfilter`` already raises
        ValueError when the existing map has other types, so an existing
        map that reaches this point is kept as is.
        """
        current_map = self._get_raw_netfilter(parse_elements=False)
        if current_map is None:
            self._create_new_set_in_kernel()

    def _delete_in_kernel(self, nft_type='map'):
        """Delete the map, table and map must exist."""
        super()._delete_in_kernel(nft_type=nft_type)

    def _create_new_set_in_kernel(self, nft_type='map'):
        """Create the non-existing map, creating table if needed."""
        super()._create_new_set_in_kernel(nft_type=nft_type)

    def validate_data(self, dict_data):
        """
        Validate data, returning it as a dict of tuples.

        ``dict_data`` maps key n-uplets (one component per ``type_from``
        entry) to a value. Every entry is checked before failing so that
        all errors are reported at once.

        Raises:
            ValueError: if at least one entry could not be parsed.
        """
        set_ = {}
        errors = []
        for key in dict_data:
            try:
                set_[tuple(self.filter_key(key))] = tuple(self.filter(dict_data[key]))
            except Exception as err:
                errors.append(err)
        if errors:
            raise ValueError(
                'Error parsing data, encountered the folowing {} errors.\n"{}"'
                .format(len(errors), '",\n"'.join(map(str, errors))))
        return set_

    def _apply_target_content(self):
        """Change netfilter map content to target map."""
        current_map = self.get_netfilter_content()
        if current_map is None:
            raise ValueError('Cannot change "{}" netfilter map content: map '
                             'do not exist in "{}" "{}".'.format(
                                 self.name, self.address_family, self.table))
        keys_to_delete = current_map.keys() - self._target_content.keys()
        keys_to_add = self._target_content.keys() - current_map.keys()
        # A key whose value changed must be removed and inserted again.
        for k in current_map.keys() & self._target_content.keys():
            if current_map[k] != self._target_content[k]:
                keys_to_add.add(k)
                keys_to_delete.add(k)
        to_add = {k: self._target_content[k] for k in keys_to_add}
        self._change_content(delete=keys_to_delete, add=to_add)

    def _change_content(self, delete=None, add=None):
        """Delete the given keys, then add the given key -> value entries."""
        if delete:
            content = ', '.join(' . '.join(str(element) for element in tuple_)
                                for tuple_ in delete)
            command = [
                *self.nft,
                'delete element {addr_family} {table} {set_} {{{content}}}'
                .format(addr_family=self.address_family,
                        table=self.table, set_=self.name, content=content)
            ]
            CommandExec.run(command)
        if add:
            content = ', '.join(
                ' . '.join(str(element) for element in tuple_)
                + ' : '
                + ' . '.join(str(element) for element in add[tuple_])
                for tuple_ in add
            )
            command = [
                *self.nft,
                'add element {addr_family} {table} {set_} {{{content}}}'
                .format(addr_family=self.address_family,
                        table=self.table, set_=self.name, content=content)
            ]
            CommandExec.run(command)

    def _get_raw_netfilter(self, parse_elements=True):
        """Return a dict describing the netfilter map matching self or None.

        None is returned when nft printed nothing on stdout.

        Raises:
            ValueError: if the listed map does not have the expected name,
                address family, table or types.
        """
        _, stdout, _ = CommandExec.run_check_output(
            [*self.nft, '-nn', 'list map {addr_family} {table} {set_}'.format(
                addr_family=self.address_family, table=self.table,
                set_=self.name)],
            allowed_return_codes=(0, 1)  # In case table do not exist
        )
        if not stdout:
            return None
        netfilter_set = self._parse_netfilter_string(stdout)
        if netfilter_set['name'] != self.name \
                or netfilter_set['address_family'] != self.address_family \
                or netfilter_set['table'] != self.table \
                or not self.has_type((netfilter_set['type_from'], netfilter_set['type'])):
            raise ValueError('Did not get the right map, too wrong to fix.')
        if parse_elements:
            raw_content = netfilter_set['raw_content']
            if raw_content:
                entries = {}
                for n_uplet in raw_content.split(','):
                    if not n_uplet.strip():
                        continue  # Ignore empty fragments.
                    key_part, _, value_part = n_uplet.partition(' : ')
                    key = tuple(element.strip()
                                for element in key_part.split(' . '))
                    entries[key] = value_part.strip()
                netfilter_set['content'] = self.validate_data(entries)
            else:
                netfilter_set['content'] = {}
        return netfilter_set

    @classmethod
    def _parse_netfilter_string(cls, set_string):
        """
        Parse netfilter map definition and return map as dict.

        Do not validate content type against detected map type.
        Return a dict with 'name', 'address_family', 'table', 'type',
        'flags', 'raw_content' and 'type_from' keys. 'type_from' is a list
        of strings, the others are strings ('raw_content' and 'flags' can
        be None). Raise ValueError in case of unexpected syntax.
        """
        match = cls.pattern.match(set_string)
        if match is None:
            raise ValueError("Malformed expression :\n" + set_string)
        values = match.groupdict()
        return {
            'address_family': values['address_family'],
            'table': values['table'],
            'name': values['name'],
            'type': values['type'],
            'type_from': values['type_from'].split(' . '),
            'raw_content': values['elements'],
            'flags': values['flags'],
        }

    def has_type(self, type_):
        """Check if a ``(key types, value type)`` pair matches the map's one."""
        return tuple(self.TYPES[t] for t in self.type) == (type_[1],) and \
            tuple(self.TYPES[t] for t in self.type_from) == tuple(type_[0])

    def format_type(self):
        """Return the map type in nftables syntax, e.g. 'a . b : c'."""
        return ' . '.join(self.TYPES[i] for i in self.type_from) + ' : ' + ' . '.join(self.TYPES[i] for i in self.type)

    def filter(self, elements):
        """Parse a map value with the first value parser, as a 1-tuple."""
        return (self.filters[0](elements),)
|
|
|
|
|
|
def get_ip_iterable_from_str(ip):
    """Turn a textual IP specification into an iterable netaddr object.

    The string is tried as an IP glob, then as an IP network; when both
    are rejected with AddrFormatError it is read as a "begin-end" range.
    """
    for parse in (netaddr.IPGlob, netaddr.IPNetwork):
        try:
            return parse(ip)
        except netaddr.core.AddrFormatError:
            continue
    first, last = ip.split('-')
    return netaddr.IPRange(first, last)
|
|
|
|
|
|
class NAT:
    """Build the nftables sets, map and rules of a port-range based NAT."""

    PROTOCOLS = (
        'tcp',
        'udp',
        'icmp'
    )

    def __init__(self,
                 name,
                 range_in,
                 range_out,
                 first_port,
                 last_port,
                 use_sudo=True
                 ):
        """Creates a NAT object for the given range of IP-Addresses.

        Args:
            name: name of the sets
            range_in: the private IP addresses (glob, network or
                "begin-end" range string)
            range_out: the public IP addresses (same formats)
            first_port: the first port used for the nat
            last_port: the last port used for the nat
            use_sudo: Should the nft commands be run in sudo ?

        Raises:
            AssertionError: unless 0 <= first_port < last_port < 65536.
        """
        # Explicit raise instead of `assert` so that the check survives
        # `python -O`; AssertionError is kept for existing callers.
        if not 0 <= first_port < last_port < 65536:
            raise AssertionError(
                name + ": ports must satisfy "
                "0 <= first_port < last_port < 65536")
        self.name = name
        self.range_in = get_ip_iterable_from_str(range_in)
        self.range_out = get_ip_iterable_from_str(range_out)
        self.first_port = first_port
        self.last_port = last_port

        # Number of private addresses sharing one public address.
        self.nb_private_by_public = self.range_in.size // self.range_out.size + 1

        sudo = ["/usr/bin/sudo"] if use_sudo else []
        self.nft = [*sudo, "/usr/sbin/nft"]

    def create_nat_rule(self, grp, ports):
        """Create one nat rule per protocol of ``PROTOCOLS`` in the form :
        ip saddr @<self.name>_nat_port_<grp> ip protocol <protocol> snat ip saddr map @<self.name>_nat_address : <ports>

        Args:
            grp: The name of the group
            ports: The port range (str)
        """
        for protocol in self.PROTOCOLS:
            CommandExec.run([
                *self.nft,
                "add rule ip nat {name}_nat ip saddr @{name}_nat_port_{grp} ip protocol {protocol} snat ip saddr map @{name}_nat_address : {ports}".format(
                    protocol=protocol,
                    name=self.name,
                    grp=grp,
                    ports=ports
                )
            ])

    def _port_range(self, index):
        """Return the "first-last" port range of the ``index``-th group."""
        span = self.last_port - self.first_port
        low = int(self.first_port + index / self.nb_private_by_public * span)
        high = int(self.first_port + (index + 1) / self.nb_private_by_public * span - 1)
        return '-'.join([str(low), str(high)])

    @staticmethod
    def _chunks(first, last, size):
        """Split the inclusive interval [first, last] into (start, count)
        runs of at most ``size`` consecutive integers."""
        return [(start, min(size, last - start + 1))
                for start in range(first, last + 1, size)]

    def manage(self):
        """Creates the port sets, ip map and rules.

        Returns:
            The NAT table as text, one
            "<public ip>\\t<ports>\\t<private ip>\\t" line per private
            address; it is also printed.
        """
        ips = {}
        ports = [set() for _ in range(self.nb_private_by_public)]
        log_lines = []
        # `range_in.last` is part of the range: the chunks include it so
        # the last private address is mapped too.
        chunks = self._chunks(self.range_in.first, self.range_in.last,
                              self.nb_private_by_public)
        for ip_out, (start, count) in zip(self.range_out, chunks):
            ips[(netaddr.IPRange(start, start + count - 1),)] = ip_out
            for offset in range(count):
                ip_in = netaddr.IPAddress(start + offset)
                ports[offset].add((ip_in,))
                log_lines.append('\t'.join(
                    (str(ip_out), self._port_range(offset), str(ip_in), '\n')))
        nat_log = ''.join(log_lines)
        print(nat_log)

        ip_map = NetfilterMap(
            target_content=ips,
            type_=('IPv4',),
            name=self.name+'_nat_address',
            table_name='nat',
            flags=('interval',),
            type_from=('IPv4',),
            address_family='ip',
        )
        ip_map.manage()

        for i, grp in enumerate(ports):
            grp_set = NetfilterSet(
                name=self.name+'_nat_port_'+str(i),
                target_content=grp,
                type_=('IPv4',),
                table_name='nat',
                address_family='ip',
            )
            grp_set.manage()
            self.create_nat_rule(
                str(i),
                self._port_range(i)
            )

        return nat_log
|
|
|
|
|
|
class Firewall:
    """Manages the firewall using nftables."""

    @staticmethod
    def manage_sets(sets, address_family=None, table=None, use_sudo=None):
        """Create and populate every described set.

        Args:
            sets: iterable of dicts with 'name', 'type' and 'content' keys.
            address_family: address family of the sets; when not given it
                is looked up in ``config.ini``, then defaults to 'inet'.
            table: table of the sets; when not given it is looked up in
                ``config.ini``, then defaults to 'filter'.
            use_sudo: whether nft is run through sudo; when None it is
                looked up in ``config.ini``, then defaults to True (the
                default of ``NetfilterSet``).
        """
        config = ConfigParser()
        config.read('config.ini')

        # NOTE(review): indexing a ConfigParser addresses a *section*, not
        # an option; if these settings are options of some section the
        # lookups below should name that section - confirm against the
        # actual config.ini layout.
        def from_config(key, fallback):
            # A missing entry used to raise KeyError, which made the
            # documented defaults unreachable.
            try:
                return config[key] or fallback
            except KeyError:
                return fallback

        address_family = address_family or from_config('address_family', 'inet')
        table = table or from_config('table', 'filter')
        if use_sudo is None:
            try:
                sudo = config['use_sudo']
            except KeyError:
                sudo = True
        else:
            sudo = use_sudo
        for set_ in sets:
            NetfilterSet(
                name=set_['name'],
                type_=set_['type'],
                target_content=set_['content'],
                use_sudo=sudo,
                address_family=address_family,
                table_name=table).manage()
|