#!/usr/bin/python3

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

# Copyright © 2017 David Sinquin

"""
Module for nftables set management.
"""

# Dependencies: python3-netaddr, python3-requests, nftables, sudo (optional)
#
# For sudo configuration, create a file in /etc/sudoers.d/ with:
# "<user> ALL = (root) NOPASSWD: /usr/sbin/nft"
#
# netaddr :
#   - https://pypi.python.org/pypi/netaddr/
#   - https://github.com/drkjam/netaddr/
#   - https://netaddr.readthedocs.io/en/latest/

import logging
import re
import subprocess
# collections.Iterable was removed in Python 3.10; collections.abc is the
# supported location. Not used in this module.
from collections.abc import Iterable  # noqa: F401
from configparser import ConfigParser

import netaddr  # MAC, IPv4, IPv6
import requests  # noqa: F401  (not used in this module)


class ExecError(Exception):
    """Simple class to indicate an error in a process execution."""


class CommandExec:
    """Simple class to start a command, logging and returning errors if any."""

    @staticmethod
    def run_check_output(command, allowed_return_codes=(0,), timeout=15):
        """
        Run a command, logging output in case of an error.

        Args:
            command: argument list, executed directly (no shell involved).
            allowed_return_codes: return codes not treated as errors.
            timeout: seconds to wait for the command to finish.

        Returns:
            A (return_code, stdout, stderr) tuple, outputs being strings.

        Raises:
            ExecError: on timeout, or if the return code is not in
                allowed_return_codes (outputs are logged in that case).
        """
        logging.debug("Command to be run: '%s'", "' '".join(command))
        with subprocess.Popen(
                command, shell=False, stdout=subprocess.PIPE,
                stderr=subprocess.PIPE, universal_newlines=True) as process:
            try:
                result = process.communicate(timeout=timeout)
            except subprocess.TimeoutExpired as err:
                # communicate() does not kill the child on timeout: kill it,
                # then reap it so no zombie process / open pipe is left.
                process.kill()
                process.communicate()
                raise ExecError(
                    'Command timed out after {} seconds: "{}".'.format(
                        timeout, '" "'.join(command))) from err
        return_code = process.returncode
        if return_code not in allowed_return_codes:
            stdout, stderr = result
            error_message = ('Error running command: "{}", return code: {}.\n'
                             'Stderr:\n{}\nStdout:\n{}'.format(
                                 '" "'.join(command), return_code,
                                 stderr, stdout))
            logging.error(error_message)
            raise ExecError(error_message)
        return (return_code, *result)

    @classmethod
    def run(cls, *args, **kwargs):
        """Run a command without checking outputs, returning its return code."""
        returncode, _, _ = cls.run_check_output(*args, **kwargs)
        return returncode


class Parser:
    """Parsers for commonly used formats."""

    @staticmethod
    def MAC(mac):
        """Check a MAC validity."""
        return netaddr.EUI(mac, dialect=netaddr.mac_unix_expanded)

    @staticmethod
    def _ip(ip, version):
        """Parse an IP address, network or range of the given IP version.

        Args:
            ip: a (first, last) tuple or a 'first-last' string (returns an
                IPRange), a single address (returns an IPAddress) or a
                network such as 'a.b.c.d/n' (returns an IPNetwork).
            version: 4 or 6.
        """
        if isinstance(ip, tuple):
            begin, end = ip
        else:
            # netaddr raises ValueError for 'addr/prefix' given to IPAddress
            # and AddrFormatError for unparsable strings: catch both.
            try:
                return netaddr.IPAddress(ip, version=version)
            except (ValueError, netaddr.core.AddrFormatError):
                pass
            try:
                return netaddr.IPNetwork(ip, version=version)
            except (ValueError, netaddr.core.AddrFormatError):
                pass
            begin, end = (part.strip() for part in ip.split('-'))
        # Bounds are parsed as IPAddress first so the IP version is enforced.
        return netaddr.IPRange(netaddr.IPAddress(begin, version=version),
                               netaddr.IPAddress(end, version=version))

    @staticmethod
    def IPv4(ip):
        """Check an IPv4 validity.

        Args:
            ip: can either be a (first, last) tuple or a 'first-last' string
                (in these cases returns an IPRange), a single IP address or
                an IP network.
        """
        return Parser._ip(ip, 4)

    @staticmethod
    def IPv6(ip):
        """Check an IPv6 validity.

        Args:
            ip: can either be a (first, last) tuple or a 'first-last' string
                (in these cases returns an IPRange), a single IP address or
                an IP network.
        """
        return Parser._ip(ip, 6)

    @staticmethod
    def protocol(protocol):
        """Check a protocol validity."""
        if protocol in ('tcp', 'udp', 'icmp'):
            return protocol
        raise ValueError('Invalid protocol: "{}".'.format(protocol))

    @staticmethod
    def port_number(port):
        """Check a port validity.

        Returns an int for a single port (0-65535), or a normalized
        'first-last' string for a port range (first < last <= 65535).

        Raises:
            ValueError: if the port or port range is invalid.
        """
        try:
            port_number = int(port)
        except ValueError:
            begin, end = (int(part) for part in port.split('-'))
            if 0 <= begin < end < 65536:
                return '{}-{}'.format(begin, end)
        else:
            if 0 <= port_number < 65536:
                return port_number
        raise ValueError('Invalid port number: "{}".'.format(port))


class NetfilterSet:
    """Manage a netfilter set (or map, when type_from is given) using nftables."""

    TYPES = {'IPv4': 'ipv4_addr',
             'IPv6': 'ipv6_addr',
             'MAC': 'ether_addr',
             'protocol': 'inet_proto',
             'port': 'inet_service'}

    FILTERS = {'IPv4': Parser.IPv4,
               'IPv6': Parser.IPv6,
               'MAC': Parser.MAC,
               'protocol': Parser.protocol,
               'port': Parser.port_number}

    ADDRESS_FAMILIES = {'ip', 'ip6', 'inet', 'arp', 'bridge', 'netdev'}
    FLAGS = {'constant', 'interval', 'timeout'}
    NFT_TYPE = {'set', 'map'}

    # Expected output of "nft -nn list set <family> <table> <name>".
    # Elements are captured as raw text (possibly spanning several lines)
    # and split later on ',' and ' . '.
    _SET_REGEXP = re.compile(
        r"table (?P<address_family>\w+) (?P<table>\w+) \{\n"
        r"\s*set (?P<name>\w+) \{\n"
        r"\s*type (?P<type>\w+(?: \. \w+)*)\n"
        r"(?:\s*flags (?P<flags>\w+(?:, ?\w+)*)\n)?"
        r"(?:\s*elements = \{ (?P<elements>[^}]*?)\s*\}\n)?"
        r"\s*\}\n"
        r"\s*\}")

    def __init__(self,
                 name,
                 type_,  # e.g.: ('MAC', 'IPv4')
                 target_content=None,
                 use_sudo=True,
                 address_family='inet',  # Manage both IPv4 and IPv6.
                 table_name='filter',
                 flags=(),
                 type_from=None
                ):
        """
        Args:
            name: netfilter set (or map) name.
            type_: value types, keys of TYPES, e.g. ('MAC', 'IPv4').
            target_content: for a set, an iterable of tuples (one value per
                type); for a map, a dict of key tuples to value tuples.
            use_sudo: prefix nft calls with /usr/bin/sudo.
            address_family: one of ADDRESS_FAMILIES.
            table_name: table containing the set.
            flags: iterable of FLAGS.
            type_from: key types; when given, a map is managed instead of a
                set, type_ being the map value types.
        """
        self.name = name
        self.content = set()
        self.set_type(type_)
        if type_from:
            self.set_type_from(type_from)
            self.nft_type = 'map'
            self.key_filters = tuple(self.FILTERS[i] for i in self.type_from)
        else:
            self.nft_type = 'set'
        self.filters = tuple(self.FILTERS[i] for i in self.type)
        self.set_flags(flags)
        self.set_address_family(address_family)
        self.table = table_name
        sudo = ["/usr/bin/sudo"] if use_sudo else []
        self.nft = [*sudo, "/usr/sbin/nft"]
        # Validated through the property setter (set or map data).
        self.target_content = target_content

    @property
    def target_content(self):
        """Validated target content (a copy, to forbid in-place modification)."""
        return self._target_content.copy()

    @target_content.setter
    def target_content(self, target_content):
        # Empty/None targets become an empty set, or an empty dict for maps,
        # so that later set/dict operations work on both kinds.
        if self.nft_type == 'map':
            self._target_content = self.validate_map_data(target_content or {})
        else:
            self._target_content = self.validate_set_data(target_content or ())

    @staticmethod
    def _apply_filters(filters, elements):
        """Return a generator parsing each value with the matching filter.

        A bare str/int is accepted as the single value of one-type elements.

        Raises:
            ValueError: if the number of values differs from len(filters).
        """
        if len(filters) == 1 and isinstance(elements, (str, int)):
            elements = (elements,)
        else:
            elements = tuple(elements)
        if len(elements) != len(filters):
            raise ValueError('Expected {} value(s), got {}: "{}".'.format(
                len(filters), len(elements), elements))
        return (filter_(element)
                for filter_, element in zip(filters, elements))

    def filter(self, elements):
        """Parse one element (values matching self.type)."""
        return self._apply_filters(self.filters, elements)

    def filter_key(self, elements):
        """Parse one map key (values matching self.type_from)."""
        return self._apply_filters(self.key_filters, elements)

    def set_type(self, type_):
        """Check set type validity and store it."""
        for element_type in type_:
            if element_type not in self.TYPES:
                raise ValueError('Invalid type: "{}".'.format(element_type))
        self.type = type_

    def set_type_from(self, type_):
        """Check map key type validity and store it."""
        for element_type in type_:
            if element_type not in self.TYPES:
                raise ValueError('Invalid type: "{}".'.format(element_type))
        self.type_from = type_

    def set_address_family(self, address_family='ip'):
        """Set set address_family, defaulting to "ip" like nftables."""
        if address_family not in self.ADDRESS_FAMILIES:
            raise ValueError(
                'Invalid address_family: "{}".'.format(address_family))
        self.address_family = address_family

    def set_flags(self, flags_):
        """Check set flags validity before saving them."""
        for f in flags_:
            if f not in self.FLAGS:
                raise ValueError('Invalid flag: "{}".'.format(f))
        self.flags = set(flags_)

    def create_in_kernel(self):
        """Create the set (or map), replacing an existing one whose type or
        flags differ from the expected ones."""
        current_set = self._get_raw_netfilter_object(parse_elements=False)
        logging.debug('Current netfilter %s "%s": %s',
                      self.nft_type, self.name, current_set)
        if current_set is None:
            self._create_new_set_in_kernel()
        elif (not self.has_type(current_set['type'])
              or current_set.get('flags', set()) != self.flags):
            self._delete_in_kernel()
            self._create_new_set_in_kernel()

    def _delete_in_kernel(self):
        """Delete the set, table and set must exist."""
        CommandExec.run([
            *self.nft,
            'delete {nft_type} {addr_family} {table} {set_}'.format(
                nft_type=self.nft_type, addr_family=self.address_family,
                table=self.table, set_=self.name)
        ])

    def _create_new_set_in_kernel(self):
        """Create the non-existing set, creating table if needed."""
        definition = 'type {} ;'.format(self.format_type())
        if self.flags:
            # Sorted so the generated command is deterministic.
            definition += ' flags {} ;'.format(', '.join(sorted(self.flags)))
        create_set = [
            *self.nft,
            'add {nft_type} {addr_family} {table} {set_} {{ {definition} }}'
            .format(nft_type=self.nft_type, addr_family=self.address_family,
                    table=self.table, set_=self.name, definition=definition)
        ]
        return_code = CommandExec.run(create_set, allowed_return_codes=(0, 1))
        if return_code == 0:
            return  # Set creation successful.
        # return_code was 1, one error was detected in the rules.
        # Attempt to create the table first.
        create_table = [*self.nft, 'add table {addr_family} {table}'.format(
            addr_family=self.address_family, table=self.table)]
        CommandExec.run(create_table)
        CommandExec.run(create_set)

    def validate_set_data(self, set_data):
        """
        Validate data, returning it as a set of tuples or raising a ValueError.

        For MAC-IPv4 set, data must be an iterable of (MAC, IPv4) iterables.
        """
        set_ = set()
        errors = []
        for n_uplet in set_data:
            try:
                set_.add(tuple(self.filter(n_uplet)))
            except Exception as err:  # Collect every error before raising.
                errors.append(err)
        if errors:
            raise ValueError(
                'Error parsing data, encountered the folowing {} errors.\n"{}"'
                .format(len(errors), '",\n"'.join(map(str, errors))))
        return set_

    def validate_map_data(self, dict_data):
        """
        Validate data, returning it as a dict or raising a ValueError.

        Data must be a mapping of key iterables (matching type_from) to value
        iterables (matching type), e.g. {(IPv4,): (MAC,)}.
        """
        map_ = {}
        errors = []
        for key, value in dict_data.items():
            try:
                map_[tuple(self.filter_key(key))] = tuple(self.filter(value))
            except Exception as err:  # Collect every error before raising.
                errors.append(err)
        if errors:
            raise ValueError(
                'Error parsing data, encountered the folowing {} errors.\n"{}"'
                .format(len(errors), '",\n"'.join(map(str, errors))))
        return map_

    def _apply_target_content(self):
        """Change netfilter content to target content."""
        if self.nft_type == 'set':
            self._apply_target_content_set()
        else:
            self._apply_target_content_map()

    def _apply_target_content_set(self):
        """Change netfilter set content to target set."""
        current_set = self.get_netfilter_set_content()
        if current_set is None:
            raise ValueError('Cannot change "{}" netfilter set content: set '
                             'do not exist in "{}" "{}".'.format(
                                 self.name, self.address_family, self.table))
        to_delete = current_set - self._target_content
        to_add = self._target_content - current_set
        self._change_set_content(delete=to_delete, add=to_add)

    def _apply_target_content_map(self):
        """Change netfilter map content to target map."""
        current_map = self.get_netfilter_map_content()
        if current_map is None:
            raise ValueError('Cannot change "{}" netfilter map content: map '
                             'do not exist in "{}" "{}".'.format(
                                 self.name, self.address_family, self.table))
        keys_to_delete = current_map.keys() - self._target_content.keys()
        keys_to_add = self._target_content.keys() - current_map.keys()
        # Keys whose value changed are deleted then re-added.
        for key in current_map.keys() & self._target_content.keys():
            if current_map[key] != self._target_content[key]:
                keys_to_add.add(key)
                keys_to_delete.add(key)
        to_add = {key: self._target_content[key] for key in keys_to_add}
        self._change_map_content(delete=keys_to_delete, add=to_add)

    @staticmethod
    def _format_tuple(tuple_):
        """Format a tuple of values as an nftables concatenation 'a . b'."""
        return ' . '.join(str(element) for element in tuple_)

    def _run_element_command(self, action, content):
        """Run 'nft <action> element ... { <content> }' on this set/map."""
        CommandExec.run([
            *self.nft,
            '{action} element {addr_family} {table} {set_} {{{content}}}'
            .format(action=action, addr_family=self.address_family,
                    table=self.table, set_=self.name, content=content)
        ])

    def _change_set_content(self, delete=None, add=None):
        """Delete then add the given element tuples; empty args are skipped.

        Deleting first avoids nft rejecting added intervals that overlap
        elements about to be removed.
        """
        for action, elements in (('delete', delete), ('add', add)):
            if elements:
                self._run_element_command(action, ', '.join(
                    self._format_tuple(tuple_) for tuple_ in elements))

    def _change_map_content(self, delete=None, add=None):
        """Delete the given keys, then add the given {key: value} entries."""
        if delete:
            self._run_element_command('delete', ', '.join(
                self._format_tuple(key) for key in delete))
        if add:
            self._run_element_command('add', ', '.join(
                self._format_tuple(key) + ' : ' + self._format_tuple(value)
                for key, value in add.items()))

    def _list_in_kernel(self):
        """Return nft listing of this set/map, '' if it (or its table) is absent."""
        _, stdout, _ = CommandExec.run_check_output(
            [*self.nft, '-nn',
             'list {nft_type} {addr_family} {table} {set_}'.format(
                 nft_type=self.nft_type, addr_family=self.address_family,
                 table=self.table, set_=self.name)],
            allowed_return_codes=(0, 1)  # In case table do not exist
        )
        return stdout

    def _check_netfilter_object(self, netfilter_set, check_content_type):
        """Raise ValueError if the parsed object is not the one managed here.

        Type and flags are only compared when check_content_type is true, so
        that create_in_kernel can still detect and replace a wrongly typed set.
        """
        mismatch = (netfilter_set.get('name') != self.name
                    or netfilter_set.get('address_family') != self.address_family
                    or netfilter_set.get('table') != self.table)
        if check_content_type:
            mismatch = (mismatch
                        or not self.has_type(netfilter_set['type'])
                        or netfilter_set.get('flags', set()) != self.flags)
        if mismatch:
            raise ValueError(
                'Did not get the right {}, too wrong to fix. Got '.format(
                    self.nft_type)
                + str(netfilter_set)
                + ("\nExpected : "
                   "\n\tname: {name}"
                   "\n\taddress_family: {family}"
                   "\n\ttable: {table}"
                   "\n\tflags: {flags}"
                   "\n\ttypes: {types}"
                  ).format(name=self.name, family=self.address_family,
                           table=self.table, flags=self.flags,
                           types=self.format_type()))

    def _get_raw_netfilter_object(self, parse_elements=True):
        """Return the dict describing the netfilter set or map, or None."""
        if self.nft_type == 'map':
            return self._get_raw_netfilter_map(parse_elements=parse_elements)
        return self._get_raw_netfilter_set(parse_elements=parse_elements)

    def _get_raw_netfilter_set(self, parse_elements=True):
        """Return a dict describing the netfilter set matching self or None.

        With parse_elements, type and flags are checked and a validated
        'content' set is added to the dict.
        """
        stdout = self._list_in_kernel()
        if not stdout:
            return None
        netfilter_set = self._parse_netfilter_set_string(stdout)
        self._check_netfilter_object(netfilter_set, parse_elements)
        if parse_elements:
            raw_content = netfilter_set['raw_content'] or ''
            netfilter_set['content'] = self.validate_set_data(
                (element.strip() for element in n_uplet.split(' . '))
                for n_uplet in raw_content.split(',') if n_uplet.strip())
        return netfilter_set

    def _get_raw_netfilter_map(self, parse_elements=True):
        """Return a dict describing the netfilter map matching self or None.

        With parse_elements, type and flags are checked and a validated
        'content' dict is added to the dict.
        """
        stdout = self._list_in_kernel()
        if not stdout:
            return None
        netfilter_map = self._parse_netfilter_map_string(stdout)
        self._check_netfilter_object(netfilter_map, parse_elements)
        if parse_elements:
            raw_content = netfilter_map['raw_content'] or ''
            entries = {}
            for n_uplet in raw_content.split(','):
                if not n_uplet.strip():
                    continue
                key, _, value = n_uplet.partition(' : ')
                entries[tuple(e.strip() for e in key.split(' . '))] = \
                    tuple(e.strip() for e in value.split(' . '))
            netfilter_map['content'] = self.validate_map_data(entries)
        return netfilter_map

    @staticmethod
    def _parse_netfilter_set_string(set_string):
        """
        Parse netfilter set definition and return set as dict.

        Do not validate content type against detected set type.

        Return a dict with 'name', 'address_family', 'table' (strings), 'type'
        (list of strings), 'raw_content' (string or None) keys, plus 'flags'
        (set of strings) when the set has flags.

        Raise ValueError in case of unexpected syntax.
        """
        match = NetfilterSet._SET_REGEXP.match(set_string)
        if match is None:
            raise ValueError(
                'Unexpected netfilter set definition: "{}".'.format(set_string))
        values = match.groupdict()
        set_definition = {
            'address_family': values['address_family'],
            'table': values['table'],
            'name': values['name'],
            'type': values['type'].split(' . '),
            'raw_content': values['elements'],
        }
        if values['flags'] is not None:
            set_definition['flags'] = {
                flag.strip() for flag in values['flags'].split(',')}
        return set_definition

    @staticmethod
    def _parse_netfilter_map_string(set_string):
        """
        Parse netfilter map definition and return map as dict.

        Do not validate content type against detected map type.

        Return a dict with 'name', 'address_family', 'table' (strings), 'type'
        ((key types, value types) lists), 'raw_content' (string or None)
        keys, plus 'flags' (set of strings) when the map has flags.

        Raise ValueError in case of unexpected syntax.
        """
        # Fragile code since using lexer / parser would be quite heavy.
        lines = []
        for line in set_string.strip().splitlines():
            line = line.strip()
            # nft wraps long element lists: glue continuation lines back.
            if lines and lines[-1].startswith('elements = {') \
                    and not lines[-1].endswith('}'):
                lines[-1] += ' ' + line
            else:
                lines.append(line)

        # table, map, type, [flags], [elements], '}', '}'
        if not 5 <= len(lines) <= 7:
            raise ValueError('Error, expecting 5 to 7 lines for map '
                             'definition, got "{}".'.format(set_string))

        errors = []
        set_definition = {'raw_content': None}

        words = lines[0].split(' ')  # 'table <address_family> <table> {'
        if len(words) != 4 or words[0] != 'table' or words[3] != '{':
            errors.append(
                'Cannot parse table definition, expecting "table '
                '<address_family> <table> {{", got "{}".'.format(lines[0]))
        else:
            set_definition['address_family'] = words[1]
            set_definition['table'] = words[2]

        words = lines[1].split(' ')  # 'map <name> {'
        if len(words) != 3 or words[0] != 'map' or words[2] != '{':
            errors.append('Cannot parse map definition, expecting "map '
                          '<name> {{", got "{}".'.format(lines[1]))
        else:
            set_definition['name'] = words[1]

        # 'type <type>[ . <type>]... : <type>[ . <type>]...'
        keys_part, separator, values_part = lines[2].partition(' : ')
        keys_words = keys_part.split(' ')
        keys_type = keys_words[1:]
        values_type = values_part.split(' ')
        if (not separator or keys_words[0] != 'type'
                or len(keys_type) % 2 != 1 or len(values_type) % 2 != 1
                or any(word != '.' for word in keys_type[1::2])
                or any(word != '.' for word in values_type[1::2])):
            errors.append(
                'Cannot parse type definition, expecting "type '
                '<type>[ . <type>]... : <type>[ . <type>]...", got "{}".'
                .format(lines[2]))
        else:
            set_definition['type'] = (keys_type[::2], values_type[::2])

        optional_lines = lines[3:-2]
        # 'flags <flag>[,<flag>]...'
        if optional_lines and optional_lines[0].startswith('flags '):
            set_definition['flags'] = {
                flag.strip()
                for flag in optional_lines[0][len('flags '):].split(',')}
            optional_lines = optional_lines[1:]
        # 'elements = { <…> }'
        if optional_lines:
            line = optional_lines[0]
            if len(optional_lines) > 1 or line[:13] != 'elements = { ' \
                    or line[-1] != '}':
                errors.append('Cannot parse map elements, expecting "elements '
                              '= {{ <…>}}", got "{}".'.format(
                                  '\n'.join(optional_lines)))
            else:
                set_definition['raw_content'] = line[13:-1].strip()

        # Last two lines close the map and the table.
        for line_number, line in enumerate(lines[-2:], start=len(lines) - 1):
            if line != '}':
                errors.append(
                    'No normal end to map definition, expecting "}}" on line '
                    '{}, got "{}".'.format(line_number, line))

        if errors:
            raise ValueError('The following error(s) were encountered while '
                             'parsing map.\n"{}"'.format('",\n"'.join(errors)))
        return set_definition

    def get_netfilter_set_content(self):
        """Return current set content from netfilter, None if set is absent."""
        netfilter_set = self._get_raw_netfilter_set(parse_elements=True)
        if netfilter_set is None:
            return None
        return netfilter_set['content']

    def get_netfilter_map_content(self):
        """Return current map content from netfilter, None if map is absent."""
        netfilter_map = self._get_raw_netfilter_map(parse_elements=True)
        if netfilter_map is None:
            return None
        return netfilter_map['content']

    def has_type(self, type_):
        """Check if some nft type (as parsed from nft output) matches ours.

        For a set, type_ is a sequence of nft type names; for a map, it is a
        (key types, value types) pair.
        """
        if self.nft_type == 'set':
            return tuple(self.TYPES[t] for t in self.type) == tuple(type_)
        return (tuple(self.TYPES[t] for t in self.type) == tuple(type_[1])
                and tuple(self.TYPES[t] for t in self.type_from)
                == tuple(type_[0]))

    def manage(self):
        """Create set if needed and populate it with target content."""
        self.create_in_kernel()
        self._apply_target_content()

    def format_type(self):
        """Return the nft type declaration, e.g. 'ether_addr . ipv4_addr'."""
        if self.nft_type == 'set':
            return ' . '.join(self.TYPES[i] for i in self.type)
        return (' . '.join(self.TYPES[i] for i in self.type_from) + ' : '
                + ' . '.join(self.TYPES[i] for i in self.type))


class Firewall:
    """Manages the firewall using nftables."""

    @staticmethod
    def manage_sets(sets, address_family=None, table=None, use_sudo=None,
                    config_file='config.ini'):
        """Create/update every set described in sets.

        Args:
            sets: iterable of dicts with 'name', 'type', 'content' keys and
                optional 'flags' and 'type_from' (map key types) keys.
            address_family, table, use_sudo: override the values read from
                the [DEFAULT] section of config_file ('inet', 'filter' and
                sudo enabled when absent from both).
            config_file: path of the ini file to read.
        """
        config = ConfigParser()
        config.read(config_file)
        # NOTE(review): the original indexed CONFIG['address_family'], which
        # looks up a *section* (KeyError); options are now read from the
        # [DEFAULT] section — confirm against the deployed config.ini.
        defaults = config[config.default_section]
        address_family = (address_family or defaults.get('address_family')
                          or 'inet')
        table = table or defaults.get('table') or 'filter'
        if use_sudo is None:
            # getboolean so that "no"/"false" in the file disables sudo.
            use_sudo = defaults.getboolean('use_sudo', fallback=True)
        for set_ in sets:
            NetfilterSet(
                name=set_['name'],
                type_=set_['type'],
                target_content=set_['content'],
                use_sudo=use_sudo,
                address_family=address_family,
                table_name=table,
                flags=set_.get('flags', ()),
                type_from=set_.get('type_from')).manage()