Script de management de nftables
This commit is contained in:
parent
adb973b5af
commit
f3ac887ad9
1 changed files with 386 additions and 0 deletions
386
firewall.py
Normal file
386
firewall.py
Normal file
|
@ -0,0 +1,386 @@
|
|||
#!/usr/bin/python3
|
||||
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
# Copyright © 2017 David Sinquin <david.re2o@sinquin.eu>
|
||||
|
||||
|
||||
"""
|
||||
Module for nftables set management.
|
||||
"""
|
||||
|
||||
# Dependencies: python3-netaddr, python3-requests, nftables, sudo (optionnal)
|
||||
#
|
||||
# For sudo configuration, create a file in /etc/sudoers.d/ with:
|
||||
# "<python_user> ALL = (root) NOPASSWD: /sbin/nftables"
|
||||
|
||||
# netaddr :
|
||||
# - https://pypi.python.org/pypi/netaddr/
|
||||
# - https://github.com/drkjam/netaddr/
|
||||
# - https://netaddr.readthedocs.io/en/latest/
|
||||
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
|
||||
import netaddr # MAC, IPv4, IPv6
|
||||
import requests
|
||||
|
||||
from collections import Iterable
|
||||
|
||||
from config import Config
|
||||
|
||||
|
||||
class ExecError(Exception):
|
||||
"""Simple class to indicate an error in a process execution."""
|
||||
pass
|
||||
|
||||
class CommandExec:
|
||||
"""Simple class to start a command, logging and returning errors if any."""
|
||||
@staticmethod
|
||||
def run_check_output(command, allowed_return_codes=(0,), timeout=15):
|
||||
"""
|
||||
Run a command, logging output in case of an error.
|
||||
|
||||
Actual timeout may be twice the given value in seconds."""
|
||||
logging.debug("Command to be run: '%s'", "' '".join(command))
|
||||
process = subprocess.Popen(
|
||||
command,
|
||||
shell=False,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True)
|
||||
try:
|
||||
result = process.communicate(timeout=timeout)
|
||||
return_code = process.wait(timeout=timeout)
|
||||
except subprocess.TimeoutExpired as err:
|
||||
process.kill()
|
||||
raise ExecError from err
|
||||
if return_code not in allowed_return_codes:
|
||||
error_message = ('Error running command: "{}", return code: {}.\n'
|
||||
'Stderr:\n{}\nStdout:\n{}'.format(
|
||||
'" "'.join(command), return_code, *result))
|
||||
logging.error(error_message)
|
||||
raise ExecError(error_message)
|
||||
return (return_code, *result)
|
||||
|
||||
@classmethod
|
||||
def run(cls, *args, **kwargs):
|
||||
"""Run a command without checking outputs."""
|
||||
returncode, _, _ = cls.run_check_output(*args, **kwargs)
|
||||
return returncode
|
||||
|
||||
class Parser:
|
||||
"""Parsers for commonly used formats."""
|
||||
@staticmethod
|
||||
def MAC(mac):
|
||||
"""Check a MAC validity."""
|
||||
return netaddr.EUI(mac, dialect=netaddr.mac_unix_expanded)
|
||||
@staticmethod
|
||||
def IPv4(ip):
|
||||
"""Check an IPv4 validity."""
|
||||
return netaddr.IPAddress(ip, version=4)
|
||||
@staticmethod
|
||||
def IPv6(ip):
|
||||
"""Check a IPv6 validity."""
|
||||
return netaddr.IPAddress(ip, version=6)
|
||||
@staticmethod
|
||||
def protocol(protocol):
|
||||
"""Check a protocol validity."""
|
||||
if protocol in ('tcp', 'udp', 'icmp'):
|
||||
return protocol
|
||||
raise ValueError('Invalid protocol: "{}".'.format(protocol))
|
||||
@staticmethod
|
||||
def port_number(port):
|
||||
"""Check a port validity."""
|
||||
port_number = int(port)
|
||||
if 0 <= port_number < 65536:
|
||||
return port_number
|
||||
raise ValueError('Invalid port number: "{}".'.format(port))
|
||||
|
||||
class NetfilterSet:
|
||||
"""Manage a netfilter set using nftables."""
|
||||
|
||||
TYPES = {'IPv4': 'ipv4_addr', 'IPv6': 'ipv6_addr', 'MAC': 'ether_addr',
|
||||
'protocol': 'inet_proto', 'port': 'inet_service'}
|
||||
|
||||
FILTERS = {'IPv4': Parser.IPv4, 'IPv6': Parser.IPv6, 'MAC': Parser.MAC,
|
||||
'protocol': Parser.protocol, 'port': Parser.port_number}
|
||||
|
||||
ADDRESS_FAMILIES = {'ip', 'ip6', 'inet', 'arp', 'bridge', 'netdev'}
|
||||
|
||||
def __init__(self,
|
||||
name,
|
||||
type_, # e.g.: ('MAC', 'IPv4')
|
||||
target_content=None,
|
||||
use_sudo=True,
|
||||
address_family='inet', # Manage both IPv4 and IPv6.
|
||||
table_name='filter'):
|
||||
self.name = name
|
||||
self.content = set()
|
||||
# self.type
|
||||
self.set_type(type_)
|
||||
self.filters = tuple(self.FILTERS[i] for i in self.type)
|
||||
# self.address_family
|
||||
self.set_address_family(address_family)
|
||||
self.table = table_name
|
||||
sudo = ["/usr/bin/sudo"] * int(bool(use_sudo))
|
||||
self.nft = [*sudo, "/usr/sbin/nft"]
|
||||
if target_content:
|
||||
self._target_content = self.validate_set_data(target_content)
|
||||
else:
|
||||
self._target_content = set()
|
||||
|
||||
@property
|
||||
def target_content(self):
|
||||
return self._target_content.copy() # Forbid in-place modification
|
||||
|
||||
@target_content.setter
|
||||
def target_content(self, target_content):
|
||||
self._target_content = self.validate_set_data(target_content)
|
||||
|
||||
def filter(self, elements):
|
||||
return (self.filters[i](element) for i, element in enumerate(elements))
|
||||
|
||||
def set_type(self, type_):
|
||||
"""Check set type validity and store it along with a type checker."""
|
||||
for element_type in type_:
|
||||
if element_type not in self.TYPES:
|
||||
raise ValueError('Invalid type: "{}".'.format(element_type))
|
||||
self.type = type_
|
||||
|
||||
def set_address_family(self, address_family='ip'):
|
||||
"""Set set addres_family, defaulting to "ip" like nftables."""
|
||||
if address_family not in self.ADDRESS_FAMILIES:
|
||||
raise ValueError(
|
||||
'Invalid address_family: "{}".'.format(address_family))
|
||||
self.address_family = address_family
|
||||
|
||||
def create_in_kernel(self):
|
||||
"""Create the set, removing existing set if needed."""
|
||||
# Delete set if it exists with wrong type
|
||||
current_set = self._get_raw_netfilter_set(parse_elements=False)
|
||||
logging.info(current_set)
|
||||
if current_set is None:
|
||||
self._create_new_set_in_kernel()
|
||||
elif not self.has_type(current_set['type']):
|
||||
self._delete_in_kernel()
|
||||
self._create_new_set_in_kernel()
|
||||
|
||||
def _delete_in_kernel(self):
|
||||
"""Delete the set, table and set must exist."""
|
||||
CommandExec.run([
|
||||
*self.nft,
|
||||
'delete set {addr_family} {table} {set_}'.format(
|
||||
addr_family=self.address_family, table=self.table,
|
||||
set_=self.name)
|
||||
])
|
||||
|
||||
def _create_new_set_in_kernel(self):
|
||||
"""Create the non-existing set, creating table if needed."""
|
||||
create_set = [
|
||||
*self.nft,
|
||||
'add set {addr_family} {table} {set_} {{ type {type_} ; }}'.format(
|
||||
addr_family=self.address_family, table=self.table,
|
||||
set_=self.name,
|
||||
type_=' . '.join(self.TYPES[i] for i in self.type))
|
||||
]
|
||||
return_code = CommandExec.run(create_set, allowed_return_codes=(0, 1))
|
||||
if return_code == 0:
|
||||
return # Set creation successful.
|
||||
# return_code was 1, one error was detected in the rules.
|
||||
# Attempt to create the table first.
|
||||
create_table = [*self.nft, 'add table {addr_family} {table}'.format(
|
||||
addr_family=self.address_family, table=self.table)]
|
||||
CommandExec.run(create_table)
|
||||
CommandExec.run(create_set)
|
||||
|
||||
def validate_set_data(self, set_data):
|
||||
"""
|
||||
Validate data, returning it or raising a ValueError.
|
||||
|
||||
For MAC-IPv4 set, data must be an iterable of (MAC, IPv4) iterables.
|
||||
"""
|
||||
set_ = set()
|
||||
errors = []
|
||||
for n_uplet in set_data:
|
||||
try:
|
||||
set_.add(tuple(self.filter(n_uplet)))
|
||||
except Exception as err:
|
||||
errors.append(err)
|
||||
if errors:
|
||||
raise ValueError(
|
||||
'Error parsing data, encountered the folowing {} errors.\n"{}"'
|
||||
.format(len(errors), '",\n"'.join(map(str, errors))))
|
||||
return set_
|
||||
|
||||
def _apply_target_content(self):
|
||||
"""Change netfilter set content to target set."""
|
||||
current_set = self.get_netfilter_set_content()
|
||||
if current_set is None:
|
||||
raise ValueError('Cannot change "{}" netfilter set content: set '
|
||||
'do not exist in "{}" "{}".'.format(
|
||||
self.name, self.address_family, self.table))
|
||||
to_delete = current_set - self._target_content
|
||||
to_add = self._target_content - current_set
|
||||
self._change_set_content(delete=to_delete, add=to_add)
|
||||
|
||||
def _change_set_content(self, delete=None, add=None):
|
||||
todo = [tuple_ for tuple_ in (('add', add), ('delete', delete))
|
||||
if tuple_[1]]
|
||||
for action, elements in todo:
|
||||
content = ', '.join(' . '.join(str(element) for element in tuple_)
|
||||
for tuple_ in elements)
|
||||
command = [
|
||||
*self.nft,
|
||||
'{action} element {addr_family} {table} {set_} {{{content}}}' \
|
||||
.format(action=action, addr_family=self.address_family,
|
||||
table=self.table, set_=self.name, content=content)
|
||||
]
|
||||
CommandExec.run(command)
|
||||
|
||||
def _get_raw_netfilter_set(self, parse_elements=True):
|
||||
"""Return a dict describing the netfilter set matching self or None."""
|
||||
_, stdout, _ = CommandExec.run_check_output(
|
||||
[*self.nft, '-nn', 'list set {addr_family} {table} {set_}'.format(
|
||||
addr_family=self.address_family, table=self.table,
|
||||
set_=self.name)],
|
||||
allowed_return_codes=(0, 1) # In case table do not exist
|
||||
)
|
||||
if not stdout:
|
||||
return None
|
||||
else:
|
||||
netfilter_set = self._parse_netfilter_set_string(stdout)
|
||||
if netfilter_set['name'] != self.name \
|
||||
or netfilter_set['address_family'] != self.address_family \
|
||||
or netfilter_set['table'] != self.table \
|
||||
or netfilter_set['type'] != [
|
||||
self.TYPES[type_] for type_ in self.type]:
|
||||
raise ValueError('Did not get the right set, too wrong to fix.')
|
||||
if parse_elements:
|
||||
if netfilter_set['raw_content']:
|
||||
netfilter_set['content'] = self.validate_set_data((
|
||||
(element.strip() for element in n_uplet.split(' . '))
|
||||
for n_uplet in netfilter_set['raw_content'].split(',')))
|
||||
else:
|
||||
netfilter_set['content'] = set()
|
||||
return netfilter_set
|
||||
|
||||
@staticmethod
|
||||
def _parse_netfilter_set_string(set_string):
|
||||
"""
|
||||
Parse netfilter set definition and return set as dict.
|
||||
|
||||
Do not validate content type against detected set type.
|
||||
Return a dict with 'name', 'address_family', 'table', 'type',
|
||||
'raw_content' keys (all strings, 'raw_content' can be None).
|
||||
Raise ValueError in case of unexpected syntax.
|
||||
"""
|
||||
# Fragile code since using lexer / parser would be quite heavy
|
||||
lines = [line.lstrip('\t ') for line in set_string.strip().splitlines()]
|
||||
errors = []
|
||||
# 5 lines when empty, 6 with elements = { … }
|
||||
if len(lines) not in (5, 6):
|
||||
errors.append('Error, expecting 5 or 6 lines for set definition, '
|
||||
'got "{}".'.format(set_string))
|
||||
|
||||
line_iterator = iter(lines)
|
||||
set_definition = {}
|
||||
|
||||
line = next(line_iterator).split(' ') # line #1
|
||||
# 'table <address_family> <chain> {'
|
||||
if len(line) != 4 or line[0] != 'table' or line[3] != '{':
|
||||
errors.append(
|
||||
'Cannot parse table definition, expecting "type <addr_family> '
|
||||
'<table> {{", got "{}".'.format(' '.join(line)))
|
||||
else:
|
||||
set_definition['address_family'] = line[1]
|
||||
set_definition['table'] = line[2]
|
||||
|
||||
line = next(line_iterator).split(' ') # line #2
|
||||
# 'set <name> {'
|
||||
if len(line) != 3 or line[0] != 'set' or line[2] != '{':
|
||||
errors.append('Cannot parse set definition, expecting "set <name> '
|
||||
'{{", got "{}".' .format(' '.join(line)))
|
||||
else:
|
||||
set_definition['name'] = line[1]
|
||||
|
||||
line = next(line_iterator).split(' ') # line #3
|
||||
# 'type <type> [. <type>]...'
|
||||
if len(line) < 2 or len(line) % 2 != 0 or line[0] != 'type' \
|
||||
or any(element != '.' for element in line[2::2]):
|
||||
errors.append(
|
||||
'Cannot parse type definition, expecting "type <type> '
|
||||
'[. <type>]...", got "{}".'.format(' '.join(line)))
|
||||
else:
|
||||
set_definition['type'] = line[1::2]
|
||||
|
||||
if len(lines) == 6:
|
||||
# set is not empty, getting raw elements
|
||||
line = next(line_iterator) # Unsplit line #4
|
||||
if line[:13] != 'elements = { ' or line[-1] != '}':
|
||||
errors.append('Cannot parse set elements, expecting "elements '
|
||||
'= {{ <…>}}", got "{}".'.format(line))
|
||||
else:
|
||||
set_definition['raw_content'] = line[13:-1].strip()
|
||||
else:
|
||||
set_definition['raw_content'] = None
|
||||
|
||||
# last two lines
|
||||
for i in range(2):
|
||||
line = next(line_iterator).split(' ')
|
||||
if line != ['}']:
|
||||
errors.append(
|
||||
'No normal end to set definition, expecting "}}" on line '
|
||||
'{}, got "{}".'.format(i+5, ' '.join(line)))
|
||||
if errors:
|
||||
raise ValueError('The following error(s) were encountered while '
|
||||
'parsing set.\n"{}"'.format('",\n"'.join(errors)))
|
||||
return set_definition
|
||||
|
||||
def get_netfilter_set_content(self):
|
||||
"""Return current set content from netfilter."""
|
||||
netfilter_set = self._get_raw_netfilter_set(parse_elements=True)
|
||||
if netfilter_set is None:
|
||||
return None
|
||||
else:
|
||||
return netfilter_set['content']
|
||||
|
||||
def has_type(self, type_):
|
||||
"""Check if some type match the set's one."""
|
||||
return tuple(self.TYPES[t] for t in self.type) == tuple(type_)
|
||||
|
||||
def manage(self):
|
||||
"""Create set if needed and populate it with target content."""
|
||||
self.create_in_kernel()
|
||||
self._apply_target_content()
|
||||
|
||||
|
||||
class Firewall:
|
||||
"""Manages the firewall using nftables."""
|
||||
|
||||
@staticmethod
|
||||
def manage_sets(sets, address_family=None, table=None, use_sudo=None):
|
||||
CONFIG = Config()
|
||||
address_family = address_family or CONFIG['address_family'] or 'inet'
|
||||
table = table or CONFIG['table'] or 'filter'
|
||||
sudo = use_sudo or (use_sudo is None and CONFIG['use_sudo'])
|
||||
for set_ in sets:
|
||||
NetfilterSet(
|
||||
name=set_['name'],
|
||||
type_=set_['type'],
|
||||
target_content=set_['content'],
|
||||
use_sudo=sudo,
|
||||
address_family=address_family,
|
||||
table_name=table).manage()
|
Loading…
Reference in a new issue