#!/usr/bin/python3

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# Copyright © 2017 David Sinquin

"""
Module for nftables set management.
"""

# Dependencies: python3-netaddr, python3-requests, nftables, sudo (optional)
#
# For sudo configuration, create a file in /etc/sudoers.d/ with:
# "<user> ALL = (root) NOPASSWD: /usr/sbin/nft"
# (sudo matches the command path exactly: it must be the nft path used in
# NetfilterSet.__init__).

# netaddr :
#  - https://pypi.python.org/pypi/netaddr/
#  - https://github.com/drkjam/netaddr/
#  - https://netaddr.readthedocs.io/en/latest/

import logging
import subprocess
from collections.abc import Iterable  # collections.Iterable: gone in 3.10

import netaddr  # MAC, IPv4, IPv6
import requests

from config import Config


class ExecError(Exception):
    """Simple class to indicate an error in a process execution."""


class CommandExec:
    """Simple class to start a command, logging and returning errors if any."""

    @staticmethod
    def run_check_output(command, allowed_return_codes=(0,), timeout=15):
        """
        Run a command, logging output in case of an error.

        command: list of program arguments, run without a shell.
        allowed_return_codes: return codes not considered as errors.
        timeout: seconds to wait for completion. Actual timeout may be twice
            the given value in seconds (a killed process is given the same
            delay to be reaped).

        Return a (return_code, stdout, stderr) tuple, outputs being str.
        Raise ExecError on timeout or on a return code not allowed.
        """
        logging.debug("Command to be run: '%s'", "' '".join(command))
        with subprocess.Popen(
                command, shell=False, stdout=subprocess.PIPE,
                stderr=subprocess.PIPE, universal_newlines=True) as process:
            try:
                stdout, stderr = process.communicate(timeout=timeout)
            except subprocess.TimeoutExpired as err:
                # communicate() does not kill the child on timeout.
                process.kill()
                try:
                    # Reap it, bounded: a grandchild (e.g. nft run by sudo)
                    # may survive the kill and keep the pipes open.
                    process.communicate(timeout=timeout)
                except subprocess.TimeoutExpired:
                    pass  # Best effort, the timeout error is raised below.
                raise ExecError('Command timed out after {}s: "{}".'.format(
                    timeout, '" "'.join(command))) from err
            return_code = process.returncode
        if return_code not in allowed_return_codes:
            error_message = ('Error running command: "{}", return code: {}.\n'
                             'Stderr:\n{}\nStdout:\n{}'.format(
                                 '" "'.join(command), return_code,
                                 stderr, stdout))
            logging.error(error_message)
            raise ExecError(error_message)
        return return_code, stdout, stderr

    @classmethod
    def run(cls, *args, **kwargs):
        """
        Run a command like run_check_output, returning only its return code.

        The return code is still checked; stdout and stderr are discarded.
        """
        return_code, _, _ = cls.run_check_output(*args, **kwargs)
        return return_code


class Parser:
    """Parsers for commonly used formats, raising on invalid input."""

    @staticmethod
    def MAC(mac):
        """Check a MAC validity, returning a netaddr.EUI."""
        # str() of the result is what gets sent to nft, hence the
        # colon-separated mac_unix_expanded dialect.
        return netaddr.EUI(mac, dialect=netaddr.mac_unix_expanded)

    @staticmethod
    def IPv4(ip):
        """Check an IPv4 validity, returning a netaddr.IPAddress."""
        return netaddr.IPAddress(ip, version=4)

    @staticmethod
    def IPv6(ip):
        """Check an IPv6 validity, returning a netaddr.IPAddress."""
        return netaddr.IPAddress(ip, version=6)

    @staticmethod
    def protocol(protocol):
        """Check a protocol validity ('tcp', 'udp' or 'icmp')."""
        if protocol in ('tcp', 'udp', 'icmp'):
            return protocol
        raise ValueError('Invalid protocol: "{}".'.format(protocol))

    @staticmethod
    def port_number(port):
        """Check a port validity, returning it as an int in [0, 65535]."""
        port_number = int(port)
        if 0 <= port_number < 65536:
            return port_number
        raise ValueError('Invalid port number: "{}".'.format(port))


class NetfilterSet:
    """
    Manage a netfilter set using nftables.

    The set is identified by its name, address family and table, and has a
    (possibly concatenated) type such as ('MAC', 'IPv4'). Its elements are
    tuples holding one parsed component per type.
    """

    # Element type name -> nftables data type.
    TYPES = {'IPv4': 'ipv4_addr', 'IPv6': 'ipv6_addr', 'MAC': 'ether_addr',
             'protocol': 'inet_proto', 'port': 'inet_service'}
    # Element type name -> parser validating one component of that type.
    FILTERS = {'IPv4': Parser.IPv4, 'IPv6': Parser.IPv6, 'MAC': Parser.MAC,
               'protocol': Parser.protocol, 'port': Parser.port_number}
    ADDRESS_FAMILIES = {'ip', 'ip6', 'inet', 'arp', 'bridge', 'netdev'}

    def __init__(self, name, type_,  # e.g.: ('MAC', 'IPv4')
                 target_content=None, use_sudo=True,
                 address_family='inet',  # Manage both IPv4 and IPv6.
                 table_name='filter'):
        """
        Describe a set; nothing is run until manage() or create_in_kernel().

        Raise ValueError on an invalid type, address family or target content.
        """
        self.name = name
        self.content = set()
        self.set_type(type_)  # Defines self.type.
        self.filters = tuple(self.FILTERS[i] for i in self.type)
        self.set_address_family(address_family)  # Defines address_family.
        self.table = table_name
        sudo = ['/usr/bin/sudo'] if use_sudo else []
        self.nft = [*sudo, '/usr/sbin/nft']
        if target_content:
            self._target_content = self.validate_set_data(target_content)
        else:
            self._target_content = set()

    @property
    def target_content(self):
        """Validated elements the kernel set should contain (a copy)."""
        return self._target_content.copy()  # Forbid in-place modification

    @target_content.setter
    def target_content(self, target_content):
        self._target_content = self.validate_set_data(target_content)

    def filter(self, elements):
        """
        Validate one set element, given as an iterable of its components.

        Return a tuple of parsed components, one per type of the set.
        Raise ValueError if the number of components does not match the set
        type, or whatever the component parser raises on invalid input.
        """
        elements = tuple(elements)
        if len(elements) != len(self.filters):
            raise ValueError('Expecting {} component(s) ({}), got {}: "{}".'
                             .format(len(self.filters), ' . '.join(self.type),
                                     len(elements), elements))
        return tuple(parse(element)
                     for parse, element in zip(self.filters, elements))

    def set_type(self, type_):
        """Check set type validity and store it as a tuple in self.type."""
        type_ = tuple(type_)  # Iterated twice: do not keep an iterator.
        for element_type in type_:
            if element_type not in self.TYPES:
                raise ValueError('Invalid type: "{}".'.format(element_type))
        self.type = type_

    def set_address_family(self, address_family='ip'):
        """Set set address_family, defaulting to "ip" like nftables."""
        if address_family not in self.ADDRESS_FAMILIES:
            raise ValueError(
                'Invalid address_family: "{}".'.format(address_family))
        self.address_family = address_family

    def create_in_kernel(self):
        """Create the set, removing existing set if it has the wrong type."""
        current_set = self._get_raw_netfilter_set(parse_elements=False)
        logging.debug('Current "%s" netfilter set: %s', self.name, current_set)
        if current_set is None:
            self._create_new_set_in_kernel()
        elif not self.has_type(current_set['type']):
            # Set exists with another type: replace it.
            self._delete_in_kernel()
            self._create_new_set_in_kernel()

    def _delete_in_kernel(self):
        """Delete the set, table and set must exist."""
        CommandExec.run([
            *self.nft,
            'delete set {addr_family} {table} {set_}'.format(
                addr_family=self.address_family, table=self.table,
                set_=self.name)
        ])

    def _create_new_set_in_kernel(self):
        """Create the non-existing set, creating table if needed."""
        create_set = [
            *self.nft,
            'add set {addr_family} {table} {set_} {{ type {type_} ; }}'.format(
                addr_family=self.address_family, table=self.table,
                set_=self.name,
                type_=' . '.join(self.TYPES[i] for i in self.type))
        ]
        return_code = CommandExec.run(create_set, allowed_return_codes=(0, 1))
        if return_code == 0:
            return  # Set creation successful.
        # return_code was 1, one error was detected in the rules.
        # Attempt to create the table first, then the set again.
        create_table = [*self.nft, 'add table {addr_family} {table}'.format(
            addr_family=self.address_family, table=self.table)]
        CommandExec.run(create_table)
        CommandExec.run(create_set)

    def validate_set_data(self, set_data):
        """
        Validate data, returning it as a set of tuples or raising ValueError.

        set_data is an iterable of elements, each one an iterable with one
        component per set type: for a ('MAC', 'IPv4') set, an iterable of
        (MAC, IPv4) iterables. All invalid elements are reported together.
        """
        set_ = set()
        errors = []
        for n_uplet in set_data:
            try:
                set_.add(self.filter(n_uplet))
            except Exception as err:  # Collect every error, report them all.
                errors.append(err)
        if errors:
            raise ValueError(
                'Error parsing data, encountered the folowing {} errors.\n"{}"'
                .format(len(errors), '",\n"'.join(map(str, errors))))
        return set_

    def _apply_target_content(self):
        """Change netfilter set content to target set, the set must exist."""
        current_set = self.get_netfilter_set_content()
        if current_set is None:
            raise ValueError('Cannot change "{}" netfilter set content: set '
                             'do not exist in "{}" "{}".'.format(
                                 self.name, self.address_family, self.table))
        to_delete = current_set - self._target_content
        to_add = self._target_content - current_set
        self._change_set_content(delete=to_delete, add=to_add)

    def _change_set_content(self, delete=None, add=None):
        """
        Add then delete elements in the kernel set, one nft call per action.

        delete, add: iterables of element tuples, skipped when empty or None.
        """
        for action, elements in (('add', add), ('delete', delete)):
            if not elements:
                continue
            content = ', '.join(' . '.join(str(element) for element in tuple_)
                                for tuple_ in elements)
            CommandExec.run([
                *self.nft,
                '{action} element {addr_family} {table} {set_} {{{content}}}'
                .format(action=action, addr_family=self.address_family,
                        table=self.table, set_=self.name, content=content)
            ])

    def _get_raw_netfilter_set(self, parse_elements=True):
        """
        Return a dict describing the netfilter set matching self or None.

        The dict has the keys returned by _parse_netfilter_set_string plus,
        when parse_elements is true, a 'content' key holding the validated
        elements. A set of another type is returned as-is when parse_elements
        is false so that create_in_kernel can replace it; its elements cannot
        be parsed, so that case raises ValueError when parse_elements is true.
        """
        _, stdout, _ = CommandExec.run_check_output(
            [*self.nft, '-nn', 'list set {addr_family} {table} {set_}'.format(
                addr_family=self.address_family, table=self.table,
                set_=self.name)],
            allowed_return_codes=(0, 1)  # In case table do not exist
        )
        if not stdout:
            return None
        netfilter_set = self._parse_netfilter_set_string(stdout)
        if (netfilter_set['name'] != self.name
                or netfilter_set['address_family'] != self.address_family
                or netfilter_set['table'] != self.table):
            raise ValueError('Did not get the right set, too wrong to fix.')
        if parse_elements:
            if not self.has_type(netfilter_set['type']):
                raise ValueError('Did not get the right set, too wrong to fix.')
            raw_content = netfilter_set['raw_content']
            if raw_content:
                netfilter_set['content'] = self.validate_set_data(
                    (element.strip() for element in n_uplet.split(' . '))
                    for n_uplet in raw_content.split(','))
            else:
                netfilter_set['content'] = set()
        return netfilter_set

    @staticmethod
    def _parse_netfilter_set_string(set_string):
        """
        Parse netfilter set definition and return set as dict.

        Do not validate content type against detected set type.
        Return a dict with 'name', 'address_family', 'table', 'type',
        'raw_content' keys ('type' is a list of nftables type names, the
        others are strings, 'raw_content' can be None).
        Raise ValueError in case of unexpected syntax.
        """
        # Fragile code since using lexer / parser would be quite heavy
        error_template = ('The following error(s) were encountered while '
                          'parsing set.\n"{}"')
        lines = [line.lstrip('\t ')
                 for line in set_string.strip().splitlines()]
        # 5 lines when empty, 6 with elements = { … }
        if len(lines) not in (5, 6):
            # Stop here: the checks below address lines by position.
            raise ValueError(error_template.format(
                'Error, expecting 5 or 6 lines for set definition, '
                'got "{}".'.format(set_string)))
        errors = []
        set_definition = {}

        line = lines[0].split(' ')  # 'table <address_family> <table> {'
        if len(line) != 4 or line[0] != 'table' or line[3] != '{':
            errors.append(
                'Cannot parse table definition, expecting "table '
                '<address_family> <table> {{", got "{}".'
                .format(' '.join(line)))
        else:
            set_definition['address_family'] = line[1]
            set_definition['table'] = line[2]

        line = lines[1].split(' ')  # 'set <name> {'
        if len(line) != 3 or line[0] != 'set' or line[2] != '{':
            errors.append('Cannot parse set definition, expecting "set <name> '
                          '{{", got "{}".'.format(' '.join(line)))
        else:
            set_definition['name'] = line[1]

        line = lines[2].split(' ')  # 'type <type> [. <type>]...'
        if (len(line) < 2 or len(line) % 2 != 0 or line[0] != 'type'
                or any(element != '.' for element in line[2::2])):
            errors.append(
                'Cannot parse type definition, expecting "type <type> '
                '[. <type>]...", got "{}".'.format(' '.join(line)))
        else:
            set_definition['type'] = line[1::2]

        if len(lines) == 6:  # Set is not empty, getting raw elements.
            line = lines[3]  # Unsplit 'elements = { ... }' line.
            if not line.startswith('elements = { ') or not line.endswith('}'):
                errors.append('Cannot parse set elements, expecting "elements '
                              '= {{ <…>}}", got "{}".'.format(line))
            else:
                set_definition['raw_content'] = line[13:-1].strip()
        else:
            set_definition['raw_content'] = None

        # Last two lines close the set and table blocks.
        for line_number, line in enumerate(lines[-2:], start=len(lines) - 1):
            if line != '}':
                errors.append(
                    'No normal end to set definition, expecting "}}" on line '
                    '{}, got "{}".'.format(line_number, line))

        if errors:
            raise ValueError(error_template.format('",\n"'.join(errors)))
        return set_definition

    def get_netfilter_set_content(self):
        """Return current set content from netfilter, None if set is absent."""
        netfilter_set = self._get_raw_netfilter_set(parse_elements=True)
        if netfilter_set is None:
            return None
        return netfilter_set['content']

    def has_type(self, type_):
        """Check if nftables type names type_ match the set's type."""
        return tuple(self.TYPES[t] for t in self.type) == tuple(type_)

    def manage(self):
        """Create set if needed and populate it with target content."""
        self.create_in_kernel()
        self._apply_target_content()


class Firewall:
    """Manages the firewall using nftables."""

    @staticmethod
    def manage_sets(sets, address_family=None, table=None, use_sudo=None):
        """
        Create and populate each netfilter set described in sets.

        sets: iterable of dicts with 'name', 'type' and 'content' keys, used
            as NetfilterSet name, type_ and target_content.
        address_family, table: when falsy, read from Config, then defaulting
            to 'inet' and 'filter'.
        use_sudo: when None, read from Config.
        """
        config = Config()
        address_family = address_family or config['address_family'] or 'inet'
        table = table or config['table'] or 'filter'
        sudo = use_sudo or (use_sudo is None and config['use_sudo'])
        for set_ in sets:
            NetfilterSet(
                name=set_['name'], type_=set_['type'],
                target_content=set_['content'], use_sudo=sudo,
                address_family=address_family, table_name=table).manage()