diff --git a/firewall.py b/firewall.py new file mode 100644 index 0000000..621362a --- /dev/null +++ b/firewall.py @@ -0,0 +1,386 @@ +#!/usr/bin/python3 + +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +# Copyright © 2017 David Sinquin + + +""" +Module for nftables set management. +""" + +# Dependencies: python3-netaddr, python3-requests, nftables, sudo (optionnal) +# +# For sudo configuration, create a file in /etc/sudoers.d/ with: +# " ALL = (root) NOPASSWD: /sbin/nftables" + +# netaddr : +# - https://pypi.python.org/pypi/netaddr/ +# - https://github.com/drkjam/netaddr/ +# - https://netaddr.readthedocs.io/en/latest/ + + +import logging +import subprocess + +import netaddr # MAC, IPv4, IPv6 +import requests + +from collections import Iterable + +from config import Config + + +class ExecError(Exception): + """Simple class to indicate an error in a process execution.""" + pass + +class CommandExec: + """Simple class to start a command, logging and returning errors if any.""" + @staticmethod + def run_check_output(command, allowed_return_codes=(0,), timeout=15): + """ + Run a command, logging output in case of an error. + + Actual timeout may be twice the given value in seconds.""" + logging.debug("Command to be run: '%s'", "' '".join(command)) + process = subprocess.Popen( + command, + shell=False, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + universal_newlines=True) + try: + result = process.communicate(timeout=timeout) + return_code = process.wait(timeout=timeout) + except subprocess.TimeoutExpired as err: + process.kill() + raise ExecError from err + if return_code not in allowed_return_codes: + error_message = ('Error running command: "{}", return code: {}.\n' + 'Stderr:\n{}\nStdout:\n{}'.format( + '" "'.join(command), return_code, *result)) + logging.error(error_message) + raise ExecError(error_message) + return (return_code, *result) + + @classmethod + def run(cls, *args, **kwargs): + """Run a command without checking outputs.""" + returncode, _, _ = cls.run_check_output(*args, **kwargs) + return returncode + +class Parser: + """Parsers for commonly used formats.""" + @staticmethod + def MAC(mac): + """Check a MAC validity.""" + return netaddr.EUI(mac, dialect=netaddr.mac_unix_expanded) + @staticmethod + def IPv4(ip): + """Check an IPv4 validity.""" + return netaddr.IPAddress(ip, version=4) + @staticmethod + def IPv6(ip): + """Check a IPv6 validity.""" + return netaddr.IPAddress(ip, version=6) + @staticmethod + def protocol(protocol): + """Check a protocol validity.""" + if protocol in ('tcp', 'udp', 'icmp'): + return protocol + raise ValueError('Invalid protocol: "{}".'.format(protocol)) + @staticmethod + def port_number(port): + """Check a port validity.""" + port_number = int(port) + if 0 <= port_number < 65536: + return port_number + raise ValueError('Invalid port number: "{}".'.format(port)) + +class NetfilterSet: + """Manage a netfilter set using nftables.""" + + TYPES = {'IPv4': 'ipv4_addr', 'IPv6': 'ipv6_addr', 'MAC': 'ether_addr', + 'protocol': 'inet_proto', 'port': 'inet_service'} + + FILTERS = {'IPv4': Parser.IPv4, 'IPv6': Parser.IPv6, 'MAC': Parser.MAC, + 'protocol': Parser.protocol, 'port': Parser.port_number} + + ADDRESS_FAMILIES = {'ip', 'ip6', 'inet', 'arp', 'bridge', 'netdev'} + + def __init__(self, + name, + type_, # e.g.: ('MAC', 'IPv4') + target_content=None, + use_sudo=True, + address_family='inet', # Manage both IPv4 and IPv6. + table_name='filter'): + self.name = name + self.content = set() + # self.type + self.set_type(type_) + self.filters = tuple(self.FILTERS[i] for i in self.type) + # self.address_family + self.set_address_family(address_family) + self.table = table_name + sudo = ["/usr/bin/sudo"] * int(bool(use_sudo)) + self.nft = [*sudo, "/usr/sbin/nft"] + if target_content: + self._target_content = self.validate_set_data(target_content) + else: + self._target_content = set() + + @property + def target_content(self): + return self._target_content.copy() # Forbid in-place modification + + @target_content.setter + def target_content(self, target_content): + self._target_content = self.validate_set_data(target_content) + + def filter(self, elements): + return (self.filters[i](element) for i, element in enumerate(elements)) + + def set_type(self, type_): + """Check set type validity and store it along with a type checker.""" + for element_type in type_: + if element_type not in self.TYPES: + raise ValueError('Invalid type: "{}".'.format(element_type)) + self.type = type_ + + def set_address_family(self, address_family='ip'): + """Set set addres_family, defaulting to "ip" like nftables.""" + if address_family not in self.ADDRESS_FAMILIES: + raise ValueError( + 'Invalid address_family: "{}".'.format(address_family)) + self.address_family = address_family + + def create_in_kernel(self): + """Create the set, removing existing set if needed.""" + # Delete set if it exists with wrong type + current_set = self._get_raw_netfilter_set(parse_elements=False) + logging.info(current_set) + if current_set is None: + self._create_new_set_in_kernel() + elif not self.has_type(current_set['type']): + self._delete_in_kernel() + self._create_new_set_in_kernel() + + def _delete_in_kernel(self): + """Delete the set, table and set must exist.""" + CommandExec.run([ + *self.nft, + 'delete set {addr_family} {table} {set_}'.format( + addr_family=self.address_family, table=self.table, + set_=self.name) + ]) + + def _create_new_set_in_kernel(self): + """Create the non-existing set, creating table if needed.""" + create_set = [ + *self.nft, + 'add set {addr_family} {table} {set_} {{ type {type_} ; }}'.format( + addr_family=self.address_family, table=self.table, + set_=self.name, + type_=' . '.join(self.TYPES[i] for i in self.type)) + ] + return_code = CommandExec.run(create_set, allowed_return_codes=(0, 1)) + if return_code == 0: + return # Set creation successful. + # return_code was 1, one error was detected in the rules. + # Attempt to create the table first. + create_table = [*self.nft, 'add table {addr_family} {table}'.format( + addr_family=self.address_family, table=self.table)] + CommandExec.run(create_table) + CommandExec.run(create_set) + + def validate_set_data(self, set_data): + """ + Validate data, returning it or raising a ValueError. + + For MAC-IPv4 set, data must be an iterable of (MAC, IPv4) iterables. + """ + set_ = set() + errors = [] + for n_uplet in set_data: + try: + set_.add(tuple(self.filter(n_uplet))) + except Exception as err: + errors.append(err) + if errors: + raise ValueError( + 'Error parsing data, encountered the folowing {} errors.\n"{}"' + .format(len(errors), '",\n"'.join(map(str, errors)))) + return set_ + + def _apply_target_content(self): + """Change netfilter set content to target set.""" + current_set = self.get_netfilter_set_content() + if current_set is None: + raise ValueError('Cannot change "{}" netfilter set content: set ' + 'do not exist in "{}" "{}".'.format( + self.name, self.address_family, self.table)) + to_delete = current_set - self._target_content + to_add = self._target_content - current_set + self._change_set_content(delete=to_delete, add=to_add) + + def _change_set_content(self, delete=None, add=None): + todo = [tuple_ for tuple_ in (('add', add), ('delete', delete)) + if tuple_[1]] + for action, elements in todo: + content = ', '.join(' . '.join(str(element) for element in tuple_) + for tuple_ in elements) + command = [ + *self.nft, + '{action} element {addr_family} {table} {set_} {{{content}}}' \ + .format(action=action, addr_family=self.address_family, + table=self.table, set_=self.name, content=content) + ] + CommandExec.run(command) + + def _get_raw_netfilter_set(self, parse_elements=True): + """Return a dict describing the netfilter set matching self or None.""" + _, stdout, _ = CommandExec.run_check_output( + [*self.nft, '-nn', 'list set {addr_family} {table} {set_}'.format( + addr_family=self.address_family, table=self.table, + set_=self.name)], + allowed_return_codes=(0, 1) # In case table do not exist + ) + if not stdout: + return None + else: + netfilter_set = self._parse_netfilter_set_string(stdout) + if netfilter_set['name'] != self.name \ + or netfilter_set['address_family'] != self.address_family \ + or netfilter_set['table'] != self.table \ + or netfilter_set['type'] != [ + self.TYPES[type_] for type_ in self.type]: + raise ValueError('Did not get the right set, too wrong to fix.') + if parse_elements: + if netfilter_set['raw_content']: + netfilter_set['content'] = self.validate_set_data(( + (element.strip() for element in n_uplet.split(' . ')) + for n_uplet in netfilter_set['raw_content'].split(','))) + else: + netfilter_set['content'] = set() + return netfilter_set + + @staticmethod + def _parse_netfilter_set_string(set_string): + """ + Parse netfilter set definition and return set as dict. + + Do not validate content type against detected set type. + Return a dict with 'name', 'address_family', 'table', 'type', + 'raw_content' keys (all strings, 'raw_content' can be None). + Raise ValueError in case of unexpected syntax. + """ + # Fragile code since using lexer / parser would be quite heavy + lines = [line.lstrip('\t ') for line in set_string.strip().splitlines()] + errors = [] + # 5 lines when empty, 6 with elements = { … } + if len(lines) not in (5, 6): + errors.append('Error, expecting 5 or 6 lines for set definition, ' + 'got "{}".'.format(set_string)) + + line_iterator = iter(lines) + set_definition = {} + + line = next(line_iterator).split(' ') # line #1 + # 'table {' + if len(line) != 4 or line[0] != 'table' or line[3] != '{': + errors.append( + 'Cannot parse table definition, expecting "type ' + ' {{", got "{}".'.format(' '.join(line))) + else: + set_definition['address_family'] = line[1] + set_definition['table'] = line[2] + + line = next(line_iterator).split(' ') # line #2 + # 'set {' + if len(line) != 3 or line[0] != 'set' or line[2] != '{': + errors.append('Cannot parse set definition, expecting "set ' + '{{", got "{}".' .format(' '.join(line))) + else: + set_definition['name'] = line[1] + + line = next(line_iterator).split(' ') # line #3 + # 'type [. ]...' + if len(line) < 2 or len(line) % 2 != 0 or line[0] != 'type' \ + or any(element != '.' for element in line[2::2]): + errors.append( + 'Cannot parse type definition, expecting "type ' + '[. ]...", got "{}".'.format(' '.join(line))) + else: + set_definition['type'] = line[1::2] + + if len(lines) == 6: + # set is not empty, getting raw elements + line = next(line_iterator) # Unsplit line #4 + if line[:13] != 'elements = { ' or line[-1] != '}': + errors.append('Cannot parse set elements, expecting "elements ' + '= {{ <…>}}", got "{}".'.format(line)) + else: + set_definition['raw_content'] = line[13:-1].strip() + else: + set_definition['raw_content'] = None + + # last two lines + for i in range(2): + line = next(line_iterator).split(' ') + if line != ['}']: + errors.append( + 'No normal end to set definition, expecting "}}" on line ' + '{}, got "{}".'.format(i+5, ' '.join(line))) + if errors: + raise ValueError('The following error(s) were encountered while ' + 'parsing set.\n"{}"'.format('",\n"'.join(errors))) + return set_definition + + def get_netfilter_set_content(self): + """Return current set content from netfilter.""" + netfilter_set = self._get_raw_netfilter_set(parse_elements=True) + if netfilter_set is None: + return None + else: + return netfilter_set['content'] + + def has_type(self, type_): + """Check if some type match the set's one.""" + return tuple(self.TYPES[t] for t in self.type) == tuple(type_) + + def manage(self): + """Create set if needed and populate it with target content.""" + self.create_in_kernel() + self._apply_target_content() + + +class Firewall: + """Manages the firewall using nftables.""" + + @staticmethod + def manage_sets(sets, address_family=None, table=None, use_sudo=None): + CONFIG = Config() + address_family = address_family or CONFIG['address_family'] or 'inet' + table = table or CONFIG['table'] or 'filter' + sudo = use_sudo or (use_sudo is None and CONFIG['use_sudo']) + for set_ in sets: + NetfilterSet( + name=set_['name'], + type_=set_['type'], + target_content=set_['content'], + use_sudo=sudo, + address_family=address_family, + table_name=table).manage()