diff --git a/firewall.py b/firewall.py
new file mode 100644
index 0000000..621362a
--- /dev/null
+++ b/firewall.py
@@ -0,0 +1,386 @@
+#!/usr/bin/python3
+
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+# Copyright © 2017 David Sinquin
+
+
+"""
+Module for nftables set management.
+"""
+
+# Dependencies: python3-netaddr, python3-requests, nftables, sudo (optionnal)
+#
+# For sudo configuration, create a file in /etc/sudoers.d/ with:
+# " ALL = (root) NOPASSWD: /sbin/nftables"
+
+# netaddr :
+# - https://pypi.python.org/pypi/netaddr/
+# - https://github.com/drkjam/netaddr/
+# - https://netaddr.readthedocs.io/en/latest/
+
+
+import logging
+import subprocess
+
+import netaddr # MAC, IPv4, IPv6
+import requests
+
+from collections import Iterable
+
+from config import Config
+
+
+class ExecError(Exception):
+ """Simple class to indicate an error in a process execution."""
+ pass
+
+class CommandExec:
+ """Simple class to start a command, logging and returning errors if any."""
+ @staticmethod
+ def run_check_output(command, allowed_return_codes=(0,), timeout=15):
+ """
+ Run a command, logging output in case of an error.
+
+ Actual timeout may be twice the given value in seconds."""
+ logging.debug("Command to be run: '%s'", "' '".join(command))
+ process = subprocess.Popen(
+ command,
+ shell=False,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ universal_newlines=True)
+ try:
+ result = process.communicate(timeout=timeout)
+ return_code = process.wait(timeout=timeout)
+ except subprocess.TimeoutExpired as err:
+ process.kill()
+ raise ExecError from err
+ if return_code not in allowed_return_codes:
+ error_message = ('Error running command: "{}", return code: {}.\n'
+ 'Stderr:\n{}\nStdout:\n{}'.format(
+ '" "'.join(command), return_code, *result))
+ logging.error(error_message)
+ raise ExecError(error_message)
+ return (return_code, *result)
+
+ @classmethod
+ def run(cls, *args, **kwargs):
+ """Run a command without checking outputs."""
+ returncode, _, _ = cls.run_check_output(*args, **kwargs)
+ return returncode
+
+class Parser:
+ """Parsers for commonly used formats."""
+ @staticmethod
+ def MAC(mac):
+ """Check a MAC validity."""
+ return netaddr.EUI(mac, dialect=netaddr.mac_unix_expanded)
+ @staticmethod
+ def IPv4(ip):
+ """Check an IPv4 validity."""
+ return netaddr.IPAddress(ip, version=4)
+ @staticmethod
+ def IPv6(ip):
+ """Check a IPv6 validity."""
+ return netaddr.IPAddress(ip, version=6)
+ @staticmethod
+ def protocol(protocol):
+ """Check a protocol validity."""
+ if protocol in ('tcp', 'udp', 'icmp'):
+ return protocol
+ raise ValueError('Invalid protocol: "{}".'.format(protocol))
+ @staticmethod
+ def port_number(port):
+ """Check a port validity."""
+ port_number = int(port)
+ if 0 <= port_number < 65536:
+ return port_number
+ raise ValueError('Invalid port number: "{}".'.format(port))
+
+class NetfilterSet:
+ """Manage a netfilter set using nftables."""
+
+ TYPES = {'IPv4': 'ipv4_addr', 'IPv6': 'ipv6_addr', 'MAC': 'ether_addr',
+ 'protocol': 'inet_proto', 'port': 'inet_service'}
+
+ FILTERS = {'IPv4': Parser.IPv4, 'IPv6': Parser.IPv6, 'MAC': Parser.MAC,
+ 'protocol': Parser.protocol, 'port': Parser.port_number}
+
+ ADDRESS_FAMILIES = {'ip', 'ip6', 'inet', 'arp', 'bridge', 'netdev'}
+
+ def __init__(self,
+ name,
+ type_, # e.g.: ('MAC', 'IPv4')
+ target_content=None,
+ use_sudo=True,
+ address_family='inet', # Manage both IPv4 and IPv6.
+ table_name='filter'):
+ self.name = name
+ self.content = set()
+ # self.type
+ self.set_type(type_)
+ self.filters = tuple(self.FILTERS[i] for i in self.type)
+ # self.address_family
+ self.set_address_family(address_family)
+ self.table = table_name
+ sudo = ["/usr/bin/sudo"] * int(bool(use_sudo))
+ self.nft = [*sudo, "/usr/sbin/nft"]
+ if target_content:
+ self._target_content = self.validate_set_data(target_content)
+ else:
+ self._target_content = set()
+
+ @property
+ def target_content(self):
+ return self._target_content.copy() # Forbid in-place modification
+
+ @target_content.setter
+ def target_content(self, target_content):
+ self._target_content = self.validate_set_data(target_content)
+
+ def filter(self, elements):
+ return (self.filters[i](element) for i, element in enumerate(elements))
+
+ def set_type(self, type_):
+ """Check set type validity and store it along with a type checker."""
+ for element_type in type_:
+ if element_type not in self.TYPES:
+ raise ValueError('Invalid type: "{}".'.format(element_type))
+ self.type = type_
+
+ def set_address_family(self, address_family='ip'):
+ """Set set addres_family, defaulting to "ip" like nftables."""
+ if address_family not in self.ADDRESS_FAMILIES:
+ raise ValueError(
+ 'Invalid address_family: "{}".'.format(address_family))
+ self.address_family = address_family
+
+ def create_in_kernel(self):
+ """Create the set, removing existing set if needed."""
+ # Delete set if it exists with wrong type
+ current_set = self._get_raw_netfilter_set(parse_elements=False)
+ logging.info(current_set)
+ if current_set is None:
+ self._create_new_set_in_kernel()
+ elif not self.has_type(current_set['type']):
+ self._delete_in_kernel()
+ self._create_new_set_in_kernel()
+
+ def _delete_in_kernel(self):
+ """Delete the set, table and set must exist."""
+ CommandExec.run([
+ *self.nft,
+ 'delete set {addr_family} {table} {set_}'.format(
+ addr_family=self.address_family, table=self.table,
+ set_=self.name)
+ ])
+
+ def _create_new_set_in_kernel(self):
+ """Create the non-existing set, creating table if needed."""
+ create_set = [
+ *self.nft,
+ 'add set {addr_family} {table} {set_} {{ type {type_} ; }}'.format(
+ addr_family=self.address_family, table=self.table,
+ set_=self.name,
+ type_=' . '.join(self.TYPES[i] for i in self.type))
+ ]
+ return_code = CommandExec.run(create_set, allowed_return_codes=(0, 1))
+ if return_code == 0:
+ return # Set creation successful.
+ # return_code was 1, one error was detected in the rules.
+ # Attempt to create the table first.
+ create_table = [*self.nft, 'add table {addr_family} {table}'.format(
+ addr_family=self.address_family, table=self.table)]
+ CommandExec.run(create_table)
+ CommandExec.run(create_set)
+
+ def validate_set_data(self, set_data):
+ """
+ Validate data, returning it or raising a ValueError.
+
+ For MAC-IPv4 set, data must be an iterable of (MAC, IPv4) iterables.
+ """
+ set_ = set()
+ errors = []
+ for n_uplet in set_data:
+ try:
+ set_.add(tuple(self.filter(n_uplet)))
+ except Exception as err:
+ errors.append(err)
+ if errors:
+ raise ValueError(
+ 'Error parsing data, encountered the folowing {} errors.\n"{}"'
+ .format(len(errors), '",\n"'.join(map(str, errors))))
+ return set_
+
+ def _apply_target_content(self):
+ """Change netfilter set content to target set."""
+ current_set = self.get_netfilter_set_content()
+ if current_set is None:
+ raise ValueError('Cannot change "{}" netfilter set content: set '
+ 'do not exist in "{}" "{}".'.format(
+ self.name, self.address_family, self.table))
+ to_delete = current_set - self._target_content
+ to_add = self._target_content - current_set
+ self._change_set_content(delete=to_delete, add=to_add)
+
+ def _change_set_content(self, delete=None, add=None):
+ todo = [tuple_ for tuple_ in (('add', add), ('delete', delete))
+ if tuple_[1]]
+ for action, elements in todo:
+ content = ', '.join(' . '.join(str(element) for element in tuple_)
+ for tuple_ in elements)
+ command = [
+ *self.nft,
+ '{action} element {addr_family} {table} {set_} {{{content}}}' \
+ .format(action=action, addr_family=self.address_family,
+ table=self.table, set_=self.name, content=content)
+ ]
+ CommandExec.run(command)
+
+ def _get_raw_netfilter_set(self, parse_elements=True):
+ """Return a dict describing the netfilter set matching self or None."""
+ _, stdout, _ = CommandExec.run_check_output(
+ [*self.nft, '-nn', 'list set {addr_family} {table} {set_}'.format(
+ addr_family=self.address_family, table=self.table,
+ set_=self.name)],
+ allowed_return_codes=(0, 1) # In case table do not exist
+ )
+ if not stdout:
+ return None
+ else:
+ netfilter_set = self._parse_netfilter_set_string(stdout)
+ if netfilter_set['name'] != self.name \
+ or netfilter_set['address_family'] != self.address_family \
+ or netfilter_set['table'] != self.table \
+ or netfilter_set['type'] != [
+ self.TYPES[type_] for type_ in self.type]:
+ raise ValueError('Did not get the right set, too wrong to fix.')
+ if parse_elements:
+ if netfilter_set['raw_content']:
+ netfilter_set['content'] = self.validate_set_data((
+ (element.strip() for element in n_uplet.split(' . '))
+ for n_uplet in netfilter_set['raw_content'].split(',')))
+ else:
+ netfilter_set['content'] = set()
+ return netfilter_set
+
+ @staticmethod
+ def _parse_netfilter_set_string(set_string):
+ """
+ Parse netfilter set definition and return set as dict.
+
+ Do not validate content type against detected set type.
+ Return a dict with 'name', 'address_family', 'table', 'type',
+ 'raw_content' keys (all strings, 'raw_content' can be None).
+ Raise ValueError in case of unexpected syntax.
+ """
+ # Fragile code since using lexer / parser would be quite heavy
+ lines = [line.lstrip('\t ') for line in set_string.strip().splitlines()]
+ errors = []
+ # 5 lines when empty, 6 with elements = { … }
+ if len(lines) not in (5, 6):
+ errors.append('Error, expecting 5 or 6 lines for set definition, '
+ 'got "{}".'.format(set_string))
+
+ line_iterator = iter(lines)
+ set_definition = {}
+
+ line = next(line_iterator).split(' ') # line #1
+ # 'table {'
+ if len(line) != 4 or line[0] != 'table' or line[3] != '{':
+ errors.append(
+ 'Cannot parse table definition, expecting "type '
+ ' {{", got "{}".'.format(' '.join(line)))
+ else:
+ set_definition['address_family'] = line[1]
+ set_definition['table'] = line[2]
+
+ line = next(line_iterator).split(' ') # line #2
+ # 'set {'
+ if len(line) != 3 or line[0] != 'set' or line[2] != '{':
+ errors.append('Cannot parse set definition, expecting "set '
+ '{{", got "{}".' .format(' '.join(line)))
+ else:
+ set_definition['name'] = line[1]
+
+ line = next(line_iterator).split(' ') # line #3
+ # 'type [. ]...'
+ if len(line) < 2 or len(line) % 2 != 0 or line[0] != 'type' \
+ or any(element != '.' for element in line[2::2]):
+ errors.append(
+ 'Cannot parse type definition, expecting "type '
+ '[. ]...", got "{}".'.format(' '.join(line)))
+ else:
+ set_definition['type'] = line[1::2]
+
+ if len(lines) == 6:
+ # set is not empty, getting raw elements
+ line = next(line_iterator) # Unsplit line #4
+ if line[:13] != 'elements = { ' or line[-1] != '}':
+ errors.append('Cannot parse set elements, expecting "elements '
+ '= {{ <…>}}", got "{}".'.format(line))
+ else:
+ set_definition['raw_content'] = line[13:-1].strip()
+ else:
+ set_definition['raw_content'] = None
+
+ # last two lines
+ for i in range(2):
+ line = next(line_iterator).split(' ')
+ if line != ['}']:
+ errors.append(
+ 'No normal end to set definition, expecting "}}" on line '
+ '{}, got "{}".'.format(i+5, ' '.join(line)))
+ if errors:
+ raise ValueError('The following error(s) were encountered while '
+ 'parsing set.\n"{}"'.format('",\n"'.join(errors)))
+ return set_definition
+
+ def get_netfilter_set_content(self):
+ """Return current set content from netfilter."""
+ netfilter_set = self._get_raw_netfilter_set(parse_elements=True)
+ if netfilter_set is None:
+ return None
+ else:
+ return netfilter_set['content']
+
+ def has_type(self, type_):
+ """Check if some type match the set's one."""
+ return tuple(self.TYPES[t] for t in self.type) == tuple(type_)
+
+ def manage(self):
+ """Create set if needed and populate it with target content."""
+ self.create_in_kernel()
+ self._apply_target_content()
+
+
+class Firewall:
+ """Manages the firewall using nftables."""
+
+ @staticmethod
+ def manage_sets(sets, address_family=None, table=None, use_sudo=None):
+ CONFIG = Config()
+ address_family = address_family or CONFIG['address_family'] or 'inet'
+ table = table or CONFIG['table'] or 'filter'
+ sudo = use_sudo or (use_sudo is None and CONFIG['use_sudo'])
+ for set_ in sets:
+ NetfilterSet(
+ name=set_['name'],
+ type_=set_['type'],
+ target_content=set_['content'],
+ use_sudo=sudo,
+ address_family=address_family,
+ table_name=table).manage()