firewall/firewall.py
Hugo Levy-Falk ec80954927 MAC-IP table
2019-03-12 22:06:28 +01:00

651 lines
26 KiB
Python
Executable file

#!/usr/bin/python3
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
# Copyright © 2017 David Sinquin <david.re2o@sinquin.eu>
"""
Module for nftables set management.
"""
# Dependencies: python3-netaddr, python3-requests, nftables, sudo (optional)
#
# For sudo configuration, create a file in /etc/sudoers.d/ with:
# "<python_user> ALL = (root) NOPASSWD: /usr/sbin/nft"
# netaddr :
# - https://pypi.python.org/pypi/netaddr/
# - https://github.com/drkjam/netaddr/
# - https://netaddr.readthedocs.io/en/latest/
import logging
import re
import subprocess
from collections.abc import Iterable
from configparser import ConfigParser

import netaddr  # MAC, IPv4, IPv6
import requests
class ExecError(Exception):
    """Raised by CommandExec when a command times out or exits with a
    return code that the caller did not allow."""
class CommandExec:
    """Simple class to start a command, logging and returning errors if any."""

    @staticmethod
    def run_check_output(command, allowed_return_codes=(0,), timeout=15):
        """Run ``command`` and return ``(return_code, stdout, stderr)``.

        Args:
            command: argument list, executed without a shell.
            allowed_return_codes: return codes treated as success.
            timeout: seconds to wait for the command before killing it.

        Returns:
            Tuple ``(return_code, stdout, stderr)``; outputs are text.

        Raises:
            ExecError: the command timed out, or its return code is not in
                ``allowed_return_codes`` (stdout/stderr are then logged).
        """
        logging.debug("Command to be run: '%s'", "' '".join(command))
        try:
            # subprocess.run kills *and reaps* the child when the timeout
            # expires, so no zombie process is left behind.
            completed = subprocess.run(
                command,
                shell=False,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                universal_newlines=True,
                timeout=timeout)
        except subprocess.TimeoutExpired as err:
            raise ExecError('Command "{}" timed out after {} seconds.'.format(
                '" "'.join(command), timeout)) from err
        return_code = completed.returncode
        result = (completed.stdout, completed.stderr)
        if return_code not in allowed_return_codes:
            error_message = ('Error running command: "{}", return code: {}.\n'
                             'Stderr:\n{}\nStdout:\n{}'.format(
                                 '" "'.join(command), return_code, *result))
            logging.error(error_message)
            raise ExecError(error_message)
        return (return_code, *result)

    @classmethod
    def run(cls, *args, **kwargs):
        """Run a command and return only its return code (see run_check_output)."""
        returncode, _, _ = cls.run_check_output(*args, **kwargs)
        return returncode
class Parser:
    """Parsers for commonly used formats.

    Each parser returns a normalised value, or raises an exception when the
    input is invalid.
    """

    @staticmethod
    def MAC(mac):
        """Check a MAC validity, returning a netaddr.EUI (unix expanded dialect)."""
        return netaddr.EUI(mac, dialect=netaddr.mac_unix_expanded)

    @staticmethod
    def _ip_range(begin, end, version):
        """Build an IPRange whose bounds are both of the given IP version."""
        # netaddr.IPRange() accepts no ``version`` argument, so the version is
        # pinned by parsing each bound as an IPAddress first.
        return netaddr.IPRange(netaddr.IPAddress(begin, version=version),
                               netaddr.IPAddress(end, version=version))

    @staticmethod
    def _ip(ip, version):
        """Parse an address, network or range of the given IP version."""
        if isinstance(ip, tuple):
            begin, end = ip
            return Parser._ip_range(begin, end, version)
        for parse in (netaddr.IPAddress, netaddr.IPNetwork):
            try:
                return parse(ip, version=version)
            except (ValueError, netaddr.core.AddrFormatError):
                # IPAddress raises ValueError (not AddrFormatError) for
                # strings containing a '/' prefix: try the next parser.
                pass
        begin, separator, end = str(ip).partition('-')
        if not separator:
            raise ValueError('Invalid IPv{} address, network or range: "{}".'
                             .format(version, ip))
        return Parser._ip_range(begin.strip(), end.strip(), version)

    @staticmethod
    def IPv4(ip):
        """Check an IPv4 validity.
        Args:
            ip: either a tuple (begin, end) or a "begin-end" string (returns an
                IPRange), a single IP address (returns an IPAddress) or an IP
                network such as "10.0.0.0/8" (returns an IPNetwork).
        """
        return Parser._ip(ip, 4)

    @staticmethod
    def IPv6(ip):
        """Check an IPv6 validity.
        Args:
            ip: either a tuple (begin, end) or a "begin-end" string (returns an
                IPRange), a single IP address (returns an IPAddress) or an IP
                network such as "2001:db8::/32" (returns an IPNetwork).
        """
        return Parser._ip(ip, 6)

    @staticmethod
    def protocol(protocol):
        """Check a protocol validity (tcp, udp or icmp)."""
        if protocol in ('tcp', 'udp', 'icmp'):
            return protocol
        raise ValueError('Invalid protocol: "{}".'.format(protocol))

    @staticmethod
    def _port_range(port):
        """Validate a "begin-end" port range string, returning it unchanged."""
        begin, separator, end = port.partition('-')
        try:
            first, last = int(begin), int(end)
        except ValueError:
            first = last = None
        if separator and first is not None and 0 <= first < last <= 65535:
            return port
        raise ValueError('Invalid port number: "{}".'.format(port))

    @staticmethod
    def port_number(port):
        """Check a port validity.

        Returns the port as an int for a single port (0-65535), or the
        original "begin-end" string for a range (begin < end <= 65535).

        Raises:
            ValueError: the port or range is invalid.
        """
        try:
            port_number = int(port)
        except ValueError:
            # Not a single number: it may be a "begin-end" range.
            return Parser._port_range(port)
        if 0 <= port_number <= 65535:
            return port_number
        raise ValueError('Invalid port number: "{}".'.format(port))
class NetfilterSet:
    """Manage a netfilter set, or a netfilter map, using nftables.

    A *set* holds tuples whose items are typed by ``type_`` (for instance
    ``('MAC', 'IPv4')`` gives ``ether_addr . ipv4_addr`` elements). When
    ``type_from`` is given, a *map* is managed instead: keys are tuples typed
    by ``type_from`` and values are tuples typed by ``type_``.

    Call :meth:`manage` to create the set/map in the kernel if needed and make
    its content match ``target_content``.
    """

    # Type names accepted by this class -> nftables data type names.
    TYPES = {'IPv4': 'ipv4_addr', 'IPv6': 'ipv6_addr', 'MAC': 'ether_addr',
             'protocol': 'inet_proto', 'port': 'inet_service'}
    # Type names -> function validating and normalising one item.
    FILTERS = {'IPv4': Parser.IPv4, 'IPv6': Parser.IPv6, 'MAC': Parser.MAC,
               'protocol': Parser.protocol, 'port': Parser.port_number}
    ADDRESS_FAMILIES = {'ip', 'ip6', 'inet', 'arp', 'bridge', 'netdev'}
    FLAGS = {'constant', 'interval', 'timeout'}
    NFT_TYPE = {'set', 'map'}

    def __init__(self,
                 name,
                 type_,  # e.g.: ('MAC', 'IPv4')
                 target_content=None,
                 use_sudo=True,
                 address_family='inet',  # Manage both IPv4 and IPv6.
                 table_name='filter',
                 flags=(),
                 type_from=None):
        """Validate the set/map definition and its wanted content.

        Args:
            name: set/map name in nftables.
            type_: tuple of TYPES keys; element type of a set, value type of
                a map.
            target_content: wanted content: an iterable of tuples for a set,
                a mapping {key tuple: value tuple} for a map. None or empty
                means an empty set/map.
            use_sudo: prefix nft invocations with /usr/bin/sudo.
            address_family: one of ADDRESS_FAMILIES.
            table_name: nftables table holding the set/map.
            flags: iterable of FLAGS used when creating the set/map.
            type_from: tuple of TYPES keys; when given, a map keyed by this
                type is managed instead of a set.

        Raises:
            ValueError: invalid type, flag, address family or content.
        """
        self.name = name
        self.content = set()  # Not used by this class; kept for callers.
        self.set_type(type_)
        if type_from:
            self.set_type_from(type_from)
            self.nft_type = 'map'
            self.key_filters = tuple(self.FILTERS[i] for i in self.type_from)
        else:
            self.nft_type = 'set'
        self.filters = tuple(self.FILTERS[i] for i in self.type)
        self.set_flags(flags)
        self.set_address_family(address_family)
        self.table = table_name
        sudo = ["/usr/bin/sudo"] if use_sudo else []
        self.nft = [*sudo, "/usr/sbin/nft"]
        if target_content:
            self._target_content = self._validate_content(target_content)
        else:
            # A map needs an empty dict, not a set: the map update code
            # relies on .keys().
            self._target_content = self._empty_content()

    @property
    def target_content(self):
        """Wanted content (a copy, so in-place modification has no effect)."""
        return self._target_content.copy()

    @target_content.setter
    def target_content(self, target_content):
        self._target_content = self._validate_content(target_content)

    def _empty_content(self):
        """Return an empty container matching the managed nft type."""
        return set() if self.nft_type == 'set' else {}

    def _validate_content(self, content):
        """Validate content as set data or map data depending on nft type."""
        if self.nft_type == 'set':
            return self.validate_set_data(content)
        return self.validate_map_data(content)

    @staticmethod
    def _apply_filters(filters, elements):
        """Check the item count, then lazily apply one filter per item."""
        elements = tuple(elements)
        if len(elements) != len(filters):
            raise ValueError('Expected {} item(s), got {}: {}.'.format(
                len(filters), len(elements), elements))
        return (check(element) for check, element in zip(filters, elements))

    def filter(self, elements):
        """Validate a set element (or map value) against ``self.type``."""
        return self._apply_filters(self.filters, elements)

    def filter_key(self, elements):
        """Validate a map key against ``self.type_from``."""
        return self._apply_filters(self.key_filters, elements)

    def _check_types(self, type_):
        """Raise ValueError if any item of ``type_`` is not a TYPES key."""
        for element_type in type_:
            if element_type not in self.TYPES:
                raise ValueError('Invalid type: "{}".'.format(element_type))

    def set_type(self, type_):
        """Check set (or map value) type validity and store it."""
        self._check_types(type_)
        self.type = type_

    def set_type_from(self, type_):
        """Check map key type validity and store it."""
        self._check_types(type_)
        self.type_from = type_

    def set_address_family(self, address_family='ip'):
        """Set set addres_family, defaulting to "ip" like nftables."""
        if address_family not in self.ADDRESS_FAMILIES:
            raise ValueError(
                'Invalid address_family: "{}".'.format(address_family))
        self.address_family = address_family

    def set_flags(self, flags_):
        """Check set flags validity before saving them (None means no flag)."""
        flags_ = () if flags_ is None else flags_
        for f in flags_:
            if f not in self.FLAGS:
                raise ValueError('Invalid flag: "{}".'.format(f))
        self.flags = set(flags_)

    def create_in_kernel(self):
        """Create the set/map in the kernel if needed.

        An existing set/map whose type or flags differ from this object's is
        deleted and re-created; manage() then refills its content.
        """
        raw = self._list_in_kernel()
        if raw is None:
            self._create_new_set_in_kernel()
            return
        # Parse without the strict checks of _get_raw_netfilter_*: a type or
        # flags mismatch must lead to re-creation, not to an exception.
        if self.nft_type == 'set':
            current = self._parse_netfilter_set_string(raw)
        else:
            current = self._parse_netfilter_map_string(raw)
        logging.debug('Current netfilter %s: %s', self.nft_type, current)
        if (not self.has_type(current['type'])
                or current.get('flags', set()) != self.flags):
            self._delete_in_kernel()
            self._create_new_set_in_kernel()

    def _delete_in_kernel(self):
        """Delete the set, table and set must exist."""
        CommandExec.run([
            *self.nft,
            'delete {nft_type} {addr_family} {table} {set_}'.format(
                nft_type=self.nft_type,
                addr_family=self.address_family, table=self.table,
                set_=self.name)
        ])

    def _create_new_set_in_kernel(self):
        """Create the non-existing set/map, creating the table if needed."""
        definition = 'type {type_} ;'.format(type_=self.format_type())
        if self.flags:
            # Sorted so the generated command is deterministic.
            definition += ' flags {flags};'.format(
                flags=', '.join(sorted(self.flags)))
        create_set = [
            *self.nft,
            'add {nft_type} {addr_family} {table} {set_} {{ {definition}}}'
            .format(nft_type=self.nft_type,
                    addr_family=self.address_family,
                    table=self.table,
                    set_=self.name,
                    definition=definition)
        ]
        return_code = CommandExec.run(create_set, allowed_return_codes=(0, 1))
        if return_code == 0:
            return  # Set creation successful.
        # return_code was 1, one error was detected in the rules.
        # Attempt to create the table first.
        create_table = [*self.nft, 'add table {addr_family} {table}'.format(
            addr_family=self.address_family, table=self.table)]
        CommandExec.run(create_table)
        CommandExec.run(create_set)

    def validate_set_data(self, set_data):
        """
        Validate set data, returning it as a set of tuples or raising a
        ValueError listing every invalid element.
        For MAC-IPv4 set, data must be an iterable of (MAC, IPv4) iterables.
        """
        set_ = set()
        errors = []
        for n_uplet in set_data:
            try:
                set_.add(tuple(self.filter(n_uplet)))
            except Exception as err:  # Collect every error before reporting.
                errors.append(err)
        if errors:
            raise ValueError(
                'Error parsing data, encountered the folowing {} errors.\n"{}"'
                .format(len(errors), '",\n"'.join(map(str, errors))))
        return set_

    def validate_map_data(self, dict_data):
        """
        Validate map data, returning a {key tuple: value tuple} dict or
        raising a ValueError listing every invalid entry.
        Keys are checked against ``type_from`` and values against ``type_``.
        """
        set_ = {}
        errors = []
        for key in dict_data:
            try:
                set_[tuple(self.filter_key(key))] = tuple(
                    self.filter(dict_data[key]))
            except Exception as err:  # Collect every error before reporting.
                errors.append(err)
        if errors:
            raise ValueError(
                'Error parsing data, encountered the folowing {} errors.\n"{}"'
                .format(len(errors), '",\n"'.join(map(str, errors))))
        return set_

    def _apply_target_content(self):
        """Change netfilter content to target content."""
        if self.nft_type == 'set':
            self._apply_target_content_set()
        else:
            self._apply_target_content_map()

    def _apply_target_content_set(self):
        """Change netfilter set content to target set."""
        current_set = self.get_netfilter_set_content()
        if current_set is None:
            raise ValueError('Cannot change "{}" netfilter set content: set '
                             'do not exist in "{}" "{}".'.format(
                                 self.name, self.address_family, self.table))
        to_delete = current_set - self._target_content
        to_add = self._target_content - current_set
        self._change_set_content(delete=to_delete, add=to_add)

    def _apply_target_content_map(self):
        """Change netfilter map content to target map."""
        current_map = self.get_netfilter_map_content()
        if current_map is None:
            raise ValueError('Cannot change "{}" netfilter map content: map '
                             'do not exist in "{}" "{}".'.format(
                                 self.name, self.address_family, self.table))
        keys_to_delete = current_map.keys() - self._target_content.keys()
        keys_to_add = self._target_content.keys() - current_map.keys()
        # Keys present on both sides but with another value are replaced:
        # deleted first, then added back with the target value.
        for k in current_map.keys() & self._target_content.keys():
            if current_map[k] != self._target_content[k]:
                keys_to_add.add(k)
                keys_to_delete.add(k)
        to_add = {k: self._target_content[k] for k in keys_to_add}
        self._change_map_content(delete=keys_to_delete, add=to_add)

    @staticmethod
    def _format_tuple(tuple_):
        """Render a tuple with nft concatenation syntax ("a . b")."""
        return ' . '.join(str(element) for element in tuple_)

    def _run_element_command(self, action, content):
        """Run ``nft <action> element ... {content}`` on this set/map."""
        CommandExec.run([
            *self.nft,
            '{action} element {addr_family} {table} {set_} {{{content}}}'
            .format(action=action, addr_family=self.address_family,
                    table=self.table, set_=self.name, content=content)
        ])

    def _change_set_content(self, delete=None, add=None):
        """Add, then delete, the given element tuples (either may be empty)."""
        for action, elements in (('add', add), ('delete', delete)):
            if elements:
                self._run_element_command(action, ', '.join(
                    self._format_tuple(tuple_) for tuple_ in elements))

    def _change_map_content(self, delete=None, add=None):
        """Delete the given keys, then add the given {key: value} entries."""
        if delete:
            self._run_element_command('delete', ', '.join(
                self._format_tuple(tuple_) for tuple_ in delete))
        if add:
            self._run_element_command('add', ', '.join(
                self._format_tuple(tuple_) + ' : '
                + self._format_tuple(add[tuple_])
                for tuple_ in add))

    def _list_in_kernel(self, nft_type=None):
        """Return ``nft list`` output for this set/map, or None if missing.

        Args:
            nft_type: 'set' or 'map'; defaults to this object's nft type.
        """
        _, stdout, _ = CommandExec.run_check_output(
            [*self.nft, '-nn', 'list {nft_type} {addr_family} {table} {set_}'
             .format(nft_type=nft_type or self.nft_type,
                     addr_family=self.address_family, table=self.table,
                     set_=self.name)],
            allowed_return_codes=(0, 1)  # In case table do not exist
        )
        return stdout or None

    def _get_raw_netfilter_set(self, parse_elements=True):
        """Return a dict describing the netfilter set matching self or None.

        Raises:
            ValueError: the kernel set does not match this object's
                definition, or its content is invalid.
        """
        stdout = self._list_in_kernel('set')
        if stdout is None:
            return None
        netfilter_set = self._parse_netfilter_set_string(stdout)
        if netfilter_set['name'] != self.name \
                or netfilter_set['address_family'] != self.address_family \
                or netfilter_set['table'] != self.table \
                or not self.has_type(netfilter_set['type']) \
                or netfilter_set.get('flags', set()) != self.flags:
            raise ValueError(
                'Did not get the right set, too wrong to fix. Got '
                + str(netfilter_set)
                + ("\nExpected : "
                   "\n\tname: {name}"
                   "\n\taddress_family: {family}"
                   "\n\ttable: {table}"
                   "\n\tflags: {flags}"
                   "\n\ttypes: {types}"
                   ).format(
                       name=self.name,
                       family=self.address_family,
                       table=self.table,
                       flags=self.flags,
                       types=tuple(self.TYPES[t] for t in self.type)
                   )
            )
        if parse_elements:
            raw_content = netfilter_set['raw_content']
            if raw_content:
                netfilter_set['content'] = self.validate_set_data(
                    [element.strip() for element in n_uplet.split(' . ')]
                    for n_uplet in raw_content.split(','))
            else:
                netfilter_set['content'] = set()
        return netfilter_set

    def _get_raw_netfilter_map(self, parse_elements=True):
        """Return a dict describing the netfilter map matching self or None.

        Raises:
            ValueError: the kernel map does not match this object's
                definition, or its content is invalid.
        """
        stdout = self._list_in_kernel('map')
        if stdout is None:
            return None
        netfilter_set = self._parse_netfilter_map_string(stdout)
        if netfilter_set['name'] != self.name \
                or netfilter_set['address_family'] != self.address_family \
                or netfilter_set['table'] != self.table \
                or not self.has_type(netfilter_set['type']):
            raise ValueError('Did not get the right map, too wrong to fix.')
        if parse_elements:
            raw_content = netfilter_set['raw_content']
            if raw_content:
                entries = {}
                for n_uplet in raw_content.split(','):
                    # Element syntax: "<key> [. <key>]... : <value> [. <value>]..."
                    key, _, value = n_uplet.partition(' : ')
                    entries[tuple(e.strip() for e in key.split(' . '))] = \
                        tuple(e.strip() for e in value.split(' . '))
                netfilter_set['content'] = self.validate_map_data(entries)
            else:
                netfilter_set['content'] = {}
        return netfilter_set

    @staticmethod
    def _parse_netfilter_set_string(set_string):
        """
        Parse netfilter set definition and return set as dict.
        Do not validate content type against detected set type.
        Return a dict with 'name', 'address_family', 'table' (strings), 'type'
        (list of nft type names), 'flags' (set of strings, possibly empty) and
        'raw_content' (string or None) keys.
        Raise ValueError in case of unexpected syntax.
        """
        match = re.match(
            r"table (?P<address_family>\w+) (?P<table>\w+) \{\n"
            r"\s*set (?P<name>\w+) \{\n"
            r"\s*type (?P<type>\w+(?: \. \w+)*)\n"
            r"(?:\s*flags (?P<flags>[\w, ]+)\n)?"
            # Elements may be wrapped over several lines and contain prefixes
            # ("10.0.0.0/8") or ranges ("80-90"): take everything up to "}".
            r"(?:\s*elements = \{ (?P<elements>[^}]*?) \}\n)?"
            r"\s*\}\n"
            r"\s*\}",
            set_string)
        if match is None:
            raise ValueError(
                'Cannot parse netfilter set definition: "{}".'.format(set_string))
        values = match.groupdict()
        flags = values['flags'] or ''
        return {
            'address_family': values['address_family'],
            'table': values['table'],
            'name': values['name'],
            'type': values['type'].split(' . '),
            'flags': {flag.strip() for flag in flags.split(',') if flag.strip()},
            'raw_content': values['elements'],
        }

    @staticmethod
    def _parse_netfilter_map_string(set_string):
        """
        Parse netfilter map definition and return map as dict.
        Do not validate content type against detected map type.
        Return a dict with 'name', 'address_family', 'table' (strings), 'type'
        ((key types, value types) lists), 'raw_content' (string or None) and,
        when the map has flags, 'flags' (set of strings) keys.
        Raise ValueError in case of unexpected syntax.
        """
        def fail(errors):
            raise ValueError('The following error(s) were encountered while '
                             'parsing set.\n"{}"'.format('",\n"'.join(errors)))

        # Fragile code since using lexer / parser would be quite heavy
        lines = []
        for line in set_string.strip().splitlines():
            line = line.strip()
            if (lines and lines[-1].startswith('elements = {')
                    and not lines[-1].endswith('}')):
                # nft wraps long element lists: glue continuation lines back.
                lines[-1] += ' ' + line
            else:
                lines.append(line)
        # 5 lines when empty, +1 for "elements = { … }", +1 for "flags …".
        if len(lines) not in (5, 6, 7):
            fail(['Error, expecting 5 to 7 lines for map definition, '
                  'got "{}".'.format(set_string)])
        errors = []
        set_definition = {}
        # Line #1: 'table <address_family> <table> {'
        words = lines[0].split(' ')
        if len(words) != 4 or words[0] != 'table' or words[3] != '{':
            errors.append(
                'Cannot parse table definition, expecting "table <addr_family> '
                '<table> {{", got "{}".'.format(lines[0]))
        else:
            set_definition['address_family'] = words[1]
            set_definition['table'] = words[2]
        # Line #2: 'map <name> {'
        words = lines[1].split(' ')
        if len(words) != 3 or words[0] != 'map' or words[2] != '{':
            errors.append('Cannot parse map definition, expecting "map <name> '
                          '{{", got "{}".'.format(lines[1]))
        else:
            set_definition['name'] = words[1]
        # Line #3: 'type <type> [. <type>]... : <type> [. <type>]...'
        keys_part, separator, values_part = lines[2].partition(' : ')
        words = keys_part.split(' ')
        keys_type, elements_type = words[1:], values_part.split(' ')
        if not separator or words[0] != 'type' or not keys_type \
                or len(elements_type) % 2 != 1 or len(keys_type) % 2 != 1 \
                or any(e != '.' for e in elements_type[1::2]) \
                or any(e != '.' for e in keys_type[1::2]):
            errors.append(
                'Cannot parse type definition, expecting "type <type> '
                '[. <type>]... : <type> [. <type>]...", got "{}".'
                .format(lines[2]))
        else:
            set_definition['type'] = (keys_type[::2], elements_type[::2])
        # Optional 'flags <flag>[,<flag>]...' line, then optional elements.
        middle = lines[3:-2]
        if middle and middle[0].startswith('flags '):
            set_definition['flags'] = {
                f.strip() for f in middle[0][len('flags'):].split(',')
                if f.strip()}
            middle = middle[1:]
        set_definition['raw_content'] = None
        if len(middle) == 1:
            line = middle[0]
            if not line.startswith('elements = { ') or not line.endswith('}'):
                errors.append('Cannot parse set elements, expecting "elements '
                              '= {{ <…>}}", got "{}".'.format(line))
            else:
                set_definition['raw_content'] = line[13:-1].strip()
        elif middle:
            errors.append('Unexpected line(s) in map definition: "{}".'
                          .format('", "'.join(middle)))
        # Last two lines close the map and the table.
        for i, line in enumerate(lines[-2:]):
            if line != '}':
                errors.append(
                    'No normal end to set definition, expecting "}}" on line '
                    '{}, got "{}".'.format(len(lines) - 1 + i, line))
        if errors:
            fail(errors)
        return set_definition

    def get_netfilter_set_content(self):
        """Return current set content from netfilter, or None if missing."""
        netfilter_set = self._get_raw_netfilter_set(parse_elements=True)
        if netfilter_set is None:
            return None
        return netfilter_set['content']

    def get_netfilter_map_content(self):
        """Return current map content from netfilter, or None if missing."""
        netfilter_set = self._get_raw_netfilter_map(parse_elements=True)
        if netfilter_set is None:
            return None
        return netfilter_set['content']

    def has_type(self, type_):
        """Check if some nft type matches the set's one.

        For a set, ``type_`` is a sequence of nft type names; for a map it is
        a (key types, value types) pair.
        """
        if self.nft_type == 'set':
            return tuple(self.TYPES[t] for t in self.type) == tuple(type_)
        return tuple(self.TYPES[t] for t in self.type) == tuple(type_[1]) and \
            tuple(self.TYPES[t] for t in self.type_from) == tuple(type_[0])

    def manage(self):
        """Create set if needed and populate it with target content."""
        self.create_in_kernel()
        self._apply_target_content()

    def format_type(self):
        """Return the nft type declaration ("a . b" or "a . b : c . d")."""
        if self.nft_type == 'set':
            return ' . '.join(self.TYPES[i] for i in self.type)
        return (' . '.join(self.TYPES[i] for i in self.type_from) + ' : '
                + ' . '.join(self.TYPES[i] for i in self.type))
class Firewall:
    """Manages the firewall using nftables."""

    @staticmethod
    def manage_sets(sets, address_family=None, table=None, use_sudo=None):
        """Create/update every netfilter set (or map) described in ``sets``.

        Args:
            sets: iterable of mappings with 'name', 'type' and 'content' keys,
                and optionally 'flags' and 'type_from' (the latter makes it a
                map, see NetfilterSet).
            address_family: explicit value, else config.ini, else 'inet'.
            table: explicit value, else config.ini, else 'filter'.
            use_sudo: explicit value, else config.ini (boolean), else True.
        """
        config = ConfigParser()
        config.read('config.ini')  # A missing file is silently ignored.
        # NOTE(review): indexing the parser by option name is a *section*
        # lookup (KeyError / SectionProxy), so options are read from the
        # [DEFAULT] section instead — confirm against the deployed config.ini.
        defaults = config.defaults()
        address_family = (address_family or defaults.get('address_family')
                          or 'inet')
        table = table or defaults.get('table') or 'filter'
        if use_sudo is None:
            # getboolean turns "no"/"false"/"0" into False (a raw string
            # would always be truthy).
            use_sudo = config.getboolean('DEFAULT', 'use_sudo', fallback=True)
        for set_ in sets:
            NetfilterSet(
                name=set_['name'],
                type_=set_['type'],
                target_content=set_['content'],
                use_sudo=use_sudo,
                address_family=address_family,
                table_name=table,
                flags=set_.get('flags', ()),
                type_from=set_.get('type_from')).manage()