diff options
author | Johan Lundberg <lundberg@nordu.net> | 2015-04-02 10:43:33 +0200 |
---|---|---|
committer | Johan Lundberg <lundberg@nordu.net> | 2015-04-02 10:43:33 +0200 |
commit | bd611ac59f7c4db885a2f8631ef0bcdcd1901ca0 (patch) | |
tree | e60f5333a7699cd021b33c7f5292af55b774001b /lib/policy.py |
Diffstat (limited to 'lib/policy.py')
-rw-r--r-- | lib/policy.py | 1821 |
1 files changed, 1821 insertions, 0 deletions
diff --git a/lib/policy.py b/lib/policy.py new file mode 100644 index 0000000..a6b8ad6 --- /dev/null +++ b/lib/policy.py @@ -0,0 +1,1821 @@ +#!/usr/bin/python +# +# Copyright 2011 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Parses the generic policy files and return a policy object for acl rendering. +""" + +import datetime +import os +import sys + +import logging +import nacaddr +import naming + +from third_party.ply import lex +from third_party.ply import yacc + + +DEFINITIONS = None +DEFAULT_DEFINITIONS = './def' +_ACTIONS = set(('accept', 'deny', 'reject', 'next', 'reject-with-tcp-rst')) +_LOGGING = set(('true', 'True', 'syslog', 'local', 'disable')) +_OPTIMIZE = True +_SHADE_CHECK = False + + +class Error(Exception): + """Generic error class.""" + + +class FileNotFoundError(Error): + """Policy file unable to be read.""" + + +class FileReadError(Error): + """Policy file unable to be read.""" + + +class RecursionTooDeepError(Error): + """Included files exceed maximum recursion depth.""" + + +class ParseError(Error): + """ParseError in the input.""" + + +class TermAddressExclusionError(Error): + """Excluded address block is not contained in the accepted address block.""" + + +class TermObjectTypeError(Error): + """Error with an object passed to Term.""" + + +class TermPortProtocolError(Error): + """Error when a requested protocol doesn't have any of the requested ports.""" + + +class TermProtocolEtherTypeError(Error): + """Error when both ether-type & upper-layer protocol matches are requested.""" + + +class TermNoActionError(Error): + """Error when a term hasn't defined an action.""" + + +class TermInvalidIcmpType(Error): + """Error when a term has invalid icmp-types specified.""" + + +class InvalidTermActionError(Error): + """Error when an action is invalid.""" + + +class InvalidTermLoggingError(Error): + """Error when a option is set for logging.""" + + +class UndefinedAddressError(Error): + """Error when an undefined address is referenced.""" + + +class NoTermsError(Error): + """Error when no terms were found.""" + + +class ShadingError(Error): + """Error when a term is shaded by a prior term.""" + + +def TranslatePorts(ports, protocols, term_name): + """Return all ports of all protocols requested. + + Args: + ports: list of ports, eg ['SMTP', 'DNS', 'HIGH_PORTS'] + protocols: list of protocols, eg ['tcp', 'udp'] + term_name: name of current term, used for warning messages + + Returns: + ret_array: list of ports tuples such as [(25,25), (53,53), (1024,65535)] + + Note: + Duplication will be taken care of in Term.CollapsePortList + """ + ret_array = [] + for proto in protocols: + for port in ports: + service_by_proto = DEFINITIONS.GetServiceByProto(port, proto) + if not service_by_proto: + logging.warn('%s %s %s %s %s %s%s %s', 'Term', term_name, + 'has service', port, 'which is not defined with protocol', + proto, + ', but will be permitted. Unless intended, you should', + 'consider splitting the protocols into separate terms!') + + for p in [x.split('-') for x in service_by_proto]: + if len(p) == 1: + ret_array.append((int(p[0]), int(p[0]))) + else: + ret_array.append((int(p[0]), int(p[1]))) + return ret_array + + +# classes for storing the object types in the policy files. +class Policy(object): + """The policy object contains everything found in a given policy file.""" + + def __init__(self, header, terms): + """Initiator for the Policy object. + + Args: + header: __main__.Header object. contains comments which should be passed + on to the rendered acls as well as the type of acls this policy file + should render to. + + terms: list __main__.Term. an array of Term objects which must be rendered + in each of the rendered acls. + + Attributes: + filters: list of tuples containing (header, terms). + """ + self.filters = [] + self.AddFilter(header, terms) + + def AddFilter(self, header, terms): + """Add another header & filter.""" + self.filters.append((header, terms)) + self._TranslateTerms(terms) + if _SHADE_CHECK: + self._DetectShading(terms) + + def _TranslateTerms(self, terms): + """.""" + if not terms: + raise NoTermsError('no terms found') + for term in terms: + # TODO(pmoody): this probably belongs in Term.SanityCheck(), + # or at the very least, in some method under class Term() + if term.translated: + continue + if term.port: + term.port = TranslatePorts(term.port, term.protocol, term.name) + if not term.port: + raise TermPortProtocolError( + 'no ports of the correct protocol for term %s' % ( + term.name)) + if term.source_port: + term.source_port = TranslatePorts(term.source_port, term.protocol, + term.name) + if not term.source_port: + raise TermPortProtocolError( + 'no source ports of the correct protocol for term %s' % ( + term.name)) + if term.destination_port: + term.destination_port = TranslatePorts(term.destination_port, + term.protocol, term.name) + if not term.destination_port: + raise TermPortProtocolError( + 'no destination ports of the correct protocol for term %s' % ( + term.name)) + + # If argument is true, we optimize, otherwise just sort addresses + term.AddressCleanup(_OPTIMIZE) + # Reset _OPTIMIZE global to default value + globals()['_OPTIMIZE'] = True + term.SanityCheck() + term.translated = True + + @property + def headers(self): + """Returns the headers from each of the configured filters. + + Returns: + headers + """ + return [x[0] for x in self.filters] + + def _DetectShading(self, terms): + """Finds terms which are shaded (impossible to reach). + + Iterate through each term, looking at each prior term. If a prior term + contains every component of the current term then the current term would + never be hit and is thus shaded. This can be a mistake. + + Args: + terms: list of Term objects. + + Raises: + ShadingError: When a term is impossible to reach. + """ + # Reset _OPTIMIZE global to default value + globals()['_SHADE_CHECK'] = False + shading_errors = [] + for index, term in enumerate(terms): + for prior_index in xrange(index): + # Check each term that came before for shading. Terms with next as an + # action do not terminate evaluation, so cannot shade. + if (term in terms[prior_index] + and 'next' not in terms[prior_index].action): + shading_errors.append( + ' %s is shaded by %s.' % ( + term.name, terms[prior_index].name)) + if shading_errors: + raise ShadingError('\n'.join(shading_errors)) + + +class Term(object): + """The Term object is used to store each of the terms. + + Args: + obj: an object of type VarType or a list of objects of type VarType + + members: + address/source_address/destination_address/: list of + VarType.(S|D)?ADDRESS's + address_exclude/source_address_exclude/destination_address_exclude: list of + VarType.(S|D)?ADDEXCLUDE's + port/source_port/destination_port: list of VarType.(S|D)?PORT's + options: list of VarType.OPTION's. + protocol: list of VarType.PROTOCOL's. + counter: VarType.COUNTER + action: list of VarType.ACTION's + comments: VarType.COMMENT + expiration: VarType.EXPIRATION + verbatim: VarType.VERBATIM + logging: VarType.LOGGING + qos: VarType.QOS + policer: VarType.POLICER + """ + ICMP_TYPE = {4: {'echo-reply': 0, + 'unreachable': 3, + 'source-quench': 4, + 'redirect': 5, + 'alternate-address': 6, + 'echo-request': 8, + 'router-advertisement': 9, + 'router-solicitation': 10, + 'time-exceeded': 11, + 'parameter-problem': 12, + 'timestamp-request': 13, + 'timestamp-reply': 14, + 'information-request': 15, + 'information-reply': 16, + 'mask-request': 17, + 'mask-reply': 18, + 'conversion-error': 31, + 'mobile-redirect': 32, + }, + 6: {'destination-unreachable': 1, + 'packet-too-big': 2, + 'time-exceeded': 3, + 'parameter-problem': 4, + 'echo-request': 128, + 'echo-reply': 129, + 'multicast-listener-query': 130, + 'multicast-listener-report': 131, + 'multicast-listener-done': 132, + 'router-solicit': 133, + 'router-advertisement': 134, + 'neighbor-solicit': 135, + 'neighbor-advertisement': 136, + 'redirect-message': 137, + 'router-renumbering': 138, + 'icmp-node-information-query': 139, + 'icmp-node-information-response': 140, + 'inverse-neighbor-discovery-solicitation': 141, + 'inverse-neighbor-discovery-advertisement': 142, + 'version-2-multicast-listener-report': 143, + 'home-agent-address-discovery-request': 144, + 'home-agent-address-discovery-reply': 145, + 'mobile-prefix-solicitation': 146, + 'mobile-prefix-advertisement': 147, + 'certification-path-solicitation': 148, + 'certification-path-advertisement': 149, + 'multicast-router-advertisement': 151, + 'multicast-router-solicitation': 152, + 'multicast-router-termination': 153, + }, + } + + def __init__(self, obj): + self.name = None + + self.action = [] + self.address = [] + self.address_exclude = [] + self.comment = [] + self.counter = None + self.expiration = None + self.destination_address = [] + self.destination_address_exclude = [] + self.destination_port = [] + self.destination_prefix = [] + self.logging = [] + self.loss_priority = None + self.option = [] + self.owner = None + self.policer = None + self.port = [] + self.precedence = [] + self.principals = [] + self.protocol = [] + self.protocol_except = [] + self.qos = None + self.routing_instance = None + self.source_address = [] + self.source_address_exclude = [] + self.source_port = [] + self.source_prefix = [] + self.verbatim = [] + # juniper specific. + self.packet_length = None + self.fragment_offset = None + self.icmp_type = [] + self.ether_type = [] + self.traffic_type = [] + self.translated = False + # iptables specific + self.source_interface = None + self.destination_interface = None + self.platform = [] + self.platform_exclude = [] + self.timeout = None + self.AddObject(obj) + self.flattened = False + self.flattened_addr = None + self.flattened_saddr = None + self.flattened_daddr = None + + def __contains__(self, other): + """Determine if other term is contained in this term.""" + if self.verbatim or other.verbatim: + # short circuit these + if sorted(self.verbatim) != sorted(other.verbatim): + return False + + # check prototols + # either protocol or protocol-except may be used, not both at the same time. + if self.protocol: + if other.protocol: + if not self.CheckProtocolIsContained(other.protocol, self.protocol): + return False + # this term has protocol, other has protocol_except. + elif other.protocol_except: + return False + else: + # other does not have protocol or protocol_except. since we do other + # cannot be contained in self. + return False + elif self.protocol_except: + if other.protocol_except: + if self.CheckProtocolIsContained( + self.protocol_except, other.protocol_except): + return False + elif other.protocol: + for proto in other.protocol: + if proto in self.protocol_except: + return False + else: + return False + + # combine addresses with exclusions for proper contains comparisons. + if not self.flattened: + self.FlattenAll() + if not other.flattened: + other.FlattenAll() + + # flat 'address' is compared against other flat (saddr|daddr). + # if NONE of these evaluate to True other is not contained. + if not ( + self.CheckAddressIsContained( + self.flattened_addr, other.flattened_addr) + or self.CheckAddressIsContained( + self.flattened_addr, other.flattened_saddr) + or self.CheckAddressIsContained( + self.flattened_addr, other.flattened_daddr)): + return False + + # compare flat address from other to flattened self (saddr|daddr). + if not ( + # other's flat address needs both self saddr & daddr to contain in order + # for the term to be contained. We already compared the flattened_addr + # attributes of both above, which was not contained. + self.CheckAddressIsContained( + other.flattened_addr, self.flattened_saddr) + and self.CheckAddressIsContained( + other.flattened_addr, self.flattened_daddr)): + return False + + # basic saddr/daddr check. + if not ( + self.CheckAddressIsContained( + self.flattened_saddr, other.flattened_saddr)): + return False + if not ( + self.CheckAddressIsContained( + self.flattened_daddr, other.flattened_daddr)): + return False + + if not ( + self.CheckPrincipalsContained( + self.principals, other.principals)): + return False + + # check ports + # like the address directive, the port directive is special in that it can + # be either source or destination. + if self.port: + if not (self.CheckPortIsContained(self.port, other.port) or + self.CheckPortIsContained(self.port, other.sport) or + self.CheckPortIsContained(self.port, other.dport)): + return False + if not self.CheckPortIsContained(self.source_port, other.source_port): + return False + if not self.CheckPortIsContained(self.destination_port, + other.destination_port): + return False + + # prefix lists + if self.source_prefix: + if sorted(self.source_prefix) != sorted(other.source_prefix): + return False + if self.destination_prefix: + if sorted(self.destination_prefix) != sorted( + other.destination_prefix): + return False + + # check precedence + if self.precedence: + if not other.precedence: + return False + for precedence in other.precedence: + if precedence not in self.precedence: + return False + # check various options + if self.option: + if not other.option: + return False + for opt in other.option: + if opt not in self.option: + return False + if self.fragment_offset: + # fragment_offset looks like 'integer-integer' or just, 'integer' + sfo = [int(x) for x in self.fragment_offset.split('-')] + if other.fragment_offset: + ofo = [int(x) for x in other.fragment_offset.split('-')] + if sfo[0] < ofo[0] or sorted(sfo[1:]) > sorted(ofo[1:]): + return False + else: + return False + if self.packet_length: + # packet_length looks like 'integer-integer' or just, 'integer' + spl = [int(x) for x in self.packet_length.split('-')] + if other.packet_length: + opl = [int(x) for x in other.packet_length.split('-')] + if spl[0] < opl[0] or sorted(spl[1:]) > sorted(opl[1:]): + return False + else: + return False + if self.icmp_type: + if sorted(self.icmp_type) is not sorted(other.icmp_type): + return False + + # check platform + if self.platform: + if sorted(self.platform) is not sorted(other.platform): + return False + if self.platform_exclude: + if sorted(self.platform_exclude) is not sorted(other.platform_exclude): + return False + + # we have containment + return True + + def __str__(self): + ret_str = [] + ret_str.append(' name: %s' % self.name) + if self.address: + ret_str.append(' address: %s' % self.address) + if self.address_exclude: + ret_str.append(' address_exclude: %s' % self.address_exclude) + if self.source_address: + ret_str.append(' source_address: %s' % self.source_address) + if self.source_address_exclude: + ret_str.append(' source_address_exclude: %s' % + self.source_address_exclude) + if self.destination_address: + ret_str.append(' destination_address: %s' % self.destination_address) + if self.destination_address_exclude: + ret_str.append(' destination_address_exclude: %s' % + self.destination_address_exclude) + if self.source_prefix: + ret_str.append(' source_prefix: %s' % self.source_prefix) + if self.destination_prefix: + ret_str.append(' destination_prefix: %s' % self.destination_prefix) + if self.protocol: + ret_str.append(' protocol: %s' % self.protocol) + if self.protocol_except: + ret_str.append(' protocol-except: %s' % self.protocol_except) + if self.owner: + ret_str.append(' owner: %s' % self.owner) + if self.port: + ret_str.append(' port: %s' % self.port) + if self.source_port: + ret_str.append(' source_port: %s' % self.source_port) + if self.destination_port: + ret_str.append(' destination_port: %s' % self.destination_port) + if self.action: + ret_str.append(' action: %s' % self.action) + if self.option: + ret_str.append(' option: %s' % self.option) + if self.qos: + ret_str.append(' qos: %s' % self.qos) + if self.logging: + ret_str.append(' logging: %s' % self.logging) + if self.counter: + ret_str.append(' counter: %s' % self.counter) + if self.source_interface: + ret_str.append(' source_interface: %s' % self.source_interface) + if self.destination_interface: + ret_str.append(' destination_interface: %s' % self.destination_interface) + if self.expiration: + ret_str.append(' expiration: %s' % self.expiration) + if self.platform: + ret_str.append(' platform: %s' % self.platform) + if self.platform_exclude: + ret_str.append(' platform_exclude: %s' % self.platform_exclude) + if self.timeout: + ret_str.append(' timeout: %s' % self.timeout) + return '\n'.join(ret_str) + + def __eq__(self, other): + # action + if sorted(self.action) != sorted(other.action): + return False + + # addresses. + if not (sorted(self.address) == sorted(other.address) and + sorted(self.source_address) == sorted(other.source_address) and + sorted(self.source_address_exclude) == + sorted(other.source_address_exclude) and + sorted(self.destination_address) == + sorted(other.destination_address) and + sorted(self.destination_address_exclude) == + sorted(other.destination_address_exclude)): + return False + + # prefix lists + if not (sorted(self.source_prefix) == sorted(other.source_prefix) and + sorted(self.destination_prefix) == + sorted(other.destination_prefix)): + return False + + # ports + if not (sorted(self.port) == sorted(other.port) and + sorted(self.source_port) == sorted(other.source_port) and + sorted(self.destination_port) == sorted(other.destination_port)): + return False + + # protocol + if not (sorted(self.protocol) == sorted(other.protocol) and + sorted(self.protocol_except) == sorted(other.protocol_except)): + return False + + # option + if sorted(self.option) != sorted(other.option): + return False + + # qos + if self.qos != other.qos: + return False + + # verbatim + if self.verbatim != other.verbatim: + return False + + # policer + if self.policer != other.policer: + return False + + # interface + if self.source_interface != other.source_interface: + return False + + if self.destination_interface != other.destination_interface: + return False + + if sorted(self.logging) != sorted(other.logging): + return False + if self.qos != other.qos: + return False + if self.packet_length != other.packet_length: + return False + if self.fragment_offset != other.fragment_offset: + return False + if sorted(self.icmp_type) != sorted(other.icmp_type): + return False + if sorted(self.ether_type) != sorted(other.ether_type): + return False + if sorted(self.traffic_type) != sorted(other.traffic_type): + return False + + # platform + if not (sorted(self.platform) == sorted(other.platform) and + sorted(self.platform_exclude) == sorted(other.platform_exclude)): + return False + + # timeout + if self.timeout != other.timeout: + return False + + return True + + def __ne__(self, other): + return not self.__eq__(other) + + def FlattenAll(self): + """Reduce source, dest, and address fields to their post-exclude state. + + Populates the self.flattened_addr, self.flattened_saddr, + self.flattened_daddr by removing excludes from includes. + """ + # No excludes, set flattened attributes and move along. + self.flattened = True + if not (self.source_address_exclude or self.destination_address_exclude or + self.address_exclude): + self.flattened_saddr = self.source_address + self.flattened_daddr = self.destination_address + self.flattened_addr = self.address + return + + if self.source_address_exclude: + self.flattened_saddr = self._FlattenAddresses( + self.source_address, self.source_address_exclude) + if self.destination_address_exclude: + self.flattened_daddr = self._FlattenAddresses( + self.destination_address, self.destination_address_exclude) + if self.address_exclude: + self.flattened_addr = self._FlattenAddresses( + self.address, self.address_exclude) + + + @staticmethod + def _FlattenAddresses(include, exclude): + """Reduce an include and exclude list to a single include list. + + Using recursion, whittle away exclude addresses from address include + addresses which contain the exclusion. + + Args: + include: list of include addresses. + exclude: list of exclude addresses. + Returns: + a single flattened list of nacaddr objects. + """ + if not exclude: + return include + + for index, in_addr in enumerate(include): + for ex_addr in exclude: + if ex_addr in in_addr: + reduced_list = in_addr.address_exclude(ex_addr) + include.pop(index) + include.extend( + Term._FlattenAddresses(reduced_list, exclude[1:])) + return include + + def GetAddressOfVersion(self, addr_type, af=None): + """Returns addresses of the appropriate Address Family. + + Args: + addr_type: string, this will be either + 'source_address', 'source_address_exclude', + 'destination_address' or 'destination_address_exclude' + af: int or None, either Term.INET4 or Term.INET6 + + Returns: + list of addresses of the correct family. + """ + if not af: + return eval('self.' + addr_type) + + return filter(lambda x: x.version == af, eval('self.' + addr_type)) + + def AddObject(self, obj): + """Add an object of unknown type to this term. + + Args: + obj: single or list of either + [Address, Port, Option, Protocol, Counter, Action, Comment, Expiration] + + Raises: + InvalidTermActionError: if the action defined isn't an accepted action. + eg, action:: godofoobar + TermObjectTypeError: if AddObject is called with an object it doesn't + understand. + InvalidTermLoggingError: when a option is set for logging not known. + """ + if type(obj) is list: + for x in obj: + # do we have a list of addresses? + # expanded address fields consolidate naked address fields with + # saddr/daddr. + if x.var_type is VarType.SADDRESS: + saddr = DEFINITIONS.GetNetAddr(x.value) + self.source_address.extend(saddr) + elif x.var_type is VarType.DADDRESS: + daddr = DEFINITIONS.GetNetAddr(x.value) + self.destination_address.extend(daddr) + elif x.var_type is VarType.ADDRESS: + addr = DEFINITIONS.GetNetAddr(x.value) + self.address.extend(addr) + # do we have address excludes? + elif x.var_type is VarType.SADDREXCLUDE: + saddr_exclude = DEFINITIONS.GetNetAddr(x.value) + self.source_address_exclude.extend(saddr_exclude) + elif x.var_type is VarType.DADDREXCLUDE: + daddr_exclude = DEFINITIONS.GetNetAddr(x.value) + self.destination_address_exclude.extend(daddr_exclude) + elif x.var_type is VarType.ADDREXCLUDE: + addr_exclude = DEFINITIONS.GetNetAddr(x.value) + self.address_exclude.extend(addr_exclude) + # do we have a list of ports? + elif x.var_type is VarType.PORT: + self.port.append(x.value) + elif x.var_type is VarType.SPORT: + self.source_port.append(x.value) + elif x.var_type is VarType.DPORT: + self.destination_port.append(x.value) + # do we have a list of protocols? + elif x.var_type is VarType.PROTOCOL: + self.protocol.append(x.value) + # do we have a list of protocol-exceptions? + elif x.var_type is VarType.PROTOCOL_EXCEPT: + self.protocol_except.append(x.value) + # do we have a list of options? + elif x.var_type is VarType.OPTION: + self.option.append(x.value) + elif x.var_type is VarType.PRINCIPALS: + self.principals.append(x.value) + elif x.var_type is VarType.SPFX: + self.source_prefix.append(x.value) + elif x.var_type is VarType.DPFX: + self.destination_prefix.append(x.value) + elif x.var_type is VarType.ETHER_TYPE: + self.ether_type.append(x.value) + elif x.var_type is VarType.TRAFFIC_TYPE: + self.traffic_type.append(x.value) + elif x.var_type is VarType.PRECEDENCE: + self.precedence.append(x.value) + elif x.var_type is VarType.PLATFORM: + self.platform.append(x.value) + elif x.var_type is VarType.PLATFORMEXCLUDE: + self.platform_exclude.append(x.value) + else: + raise TermObjectTypeError( + '%s isn\'t a type I know how to deal with (contains \'%s\')' % ( + type(x), x.value)) + else: + # stupid no switch statement in python + if obj.var_type is VarType.COMMENT: + self.comment.append(str(obj)) + elif obj.var_type is VarType.OWNER: + self.owner = obj.value + elif obj.var_type is VarType.EXPIRATION: + self.expiration = obj.value + elif obj.var_type is VarType.LOSS_PRIORITY: + self.loss_priority = obj.value + elif obj.var_type is VarType.ROUTING_INSTANCE: + self.routing_instance = obj.value + elif obj.var_type is VarType.PRECEDENCE: + self.precedence = obj.value + elif obj.var_type is VarType.VERBATIM: + self.verbatim.append(obj) + elif obj.var_type is VarType.ACTION: + if str(obj) not in _ACTIONS: + raise InvalidTermActionError('%s is not a valid action' % obj) + self.action.append(obj.value) + elif obj.var_type is VarType.COUNTER: + self.counter = obj + elif obj.var_type is VarType.ICMP_TYPE: + self.icmp_type.extend(obj.value) + elif obj.var_type is VarType.LOGGING: + if str(obj) not in _LOGGING: + raise InvalidTermLoggingError('%s is not a valid logging option' % + obj) + self.logging.append(obj) + # police man, tryin'a take you jail + elif obj.var_type is VarType.POLICER: + self.policer = obj.value + # qos? + elif obj.var_type is VarType.QOS: + self.qos = obj.value + elif obj.var_type is VarType.PACKET_LEN: + self.packet_length = obj.value + elif obj.var_type is VarType.FRAGMENT_OFFSET: + self.fragment_offset = obj.value + elif obj.var_type is VarType.SINTERFACE: + self.source_interface = obj.value + elif obj.var_type is VarType.DINTERFACE: + self.destination_interface = obj.value + elif obj.var_type is VarType.TIMEOUT: + self.timeout = obj.value + else: + raise TermObjectTypeError( + '%s isn\'t a type I know how to deal with' % (type(obj))) + + def SanityCheck(self): + """Sanity check the definition of the term. + + Raises: + ParseError: if term has both verbatim and non-verbatim tokens + TermInvalidIcmpType: if term has invalid icmp-types specified + TermNoActionError: if the term doesn't have an action defined. + TermPortProtocolError: if the term has a service/protocol definition pair + which don't match up, eg. SNMP and tcp + TermAddressExclusionError: if one of the *-exclude directives is defined, + but that address isn't contained in the non *-exclude directive. eg: + source-address::CORP_INTERNAL source-exclude:: LOCALHOST + TermProtocolEtherTypeError: if the term has both ether-type and + upper-layer protocol restrictions + InvalidTermActionError: action and routing-instance both defined + + This should be called when the term is fully formed, and + all of the options are set. + + """ + if self.verbatim: + if (self.action or self.source_port or self.destination_port or + self.port or self.protocol or self.option): + raise ParseError( + 'term "%s" has both verbatim and non-verbatim tokens.' % self.name) + else: + if not self.action and not self.routing_instance: + raise TermNoActionError('no action specified for term %s' % self.name) + # have we specified a port with a protocol that doesn't support ports? + if self.source_port or self.destination_port or self.port: + if 'tcp' not in self.protocol and 'udp' not in self.protocol: + raise TermPortProtocolError( + 'ports specified with a protocol that doesn\'t support ports. ' + 'Term: %s ' % self.name) + # TODO(pmoody): do we have mutually exclusive options? + # eg. tcp-established + tcp-initial? + + if self.ether_type and ( + self.protocol or + self.address or + self.destination_address or + self.destination_address_exclude or + self.destination_port or + self.destination_prefix or + self.source_address or + self.source_address_exclude or + self.source_port or + self.source_prefix): + raise TermProtocolEtherTypeError( + 'ether-type not supported when used with upper-layer protocol ' + 'restrictions. Term: %s' % self.name) + # validate icmp-types if specified, but addr_family will have to be checked + # in the generators as policy module doesn't know about that at this point. + if self.icmp_type: + for icmptype in self.icmp_type: + if (icmptype not in self.ICMP_TYPE[4] and icmptype not in + self.ICMP_TYPE[6]): + raise TermInvalidIcmpType('Term %s contains an invalid icmp-type:' + '%s' % (self.name, icmptype)) + + def AddressCleanup(self, optimize=True): + """Do Address and Port collapsing. + + Notes: + Collapses both the address definitions and the port definitions + to their smallest possible length. + + Args: + optimize: boolean value indicating whether to optimize addresses + """ + if optimize: + cleanup = nacaddr.CollapseAddrList + else: + cleanup = nacaddr.SortAddrList + + # address collapsing. + if self.address: + self.address = cleanup(self.address) + if self.source_address: + self.source_address = cleanup(self.source_address) + if self.source_address_exclude: + self.source_address_exclude = cleanup(self.source_address_exclude) + if self.destination_address: + self.destination_address = cleanup(self.destination_address) + if self.destination_address_exclude: + self.destination_address_exclude = cleanup( + self.destination_address_exclude) + + # port collapsing. + if self.port: + self.port = self.CollapsePortList(self.port) + if self.source_port: + self.source_port = self.CollapsePortList(self.source_port) + if self.destination_port: + self.destination_port = self.CollapsePortList(self.destination_port) + + def CollapsePortListRecursive(self, ports): + """Given a sorted list of ports, collapse to the smallest required list. + + Args: + ports: sorted list of port tuples + + Returns: + ret_ports: collapsed list of ports + """ + optimized = False + ret_ports = [] + for port in ports: + if not ret_ports: + ret_ports.append(port) + # we should be able to count on ret_ports[-1][0] <= port[0] + elif ret_ports[-1][1] >= port[1]: + # (10, 20) and (12, 13) -> (10, 20) + optimized = True + elif port[0] < ret_ports[-1][1] < port[1]: + # (10, 20) and (15, 30) -> (10, 30) + ret_ports[-1] = (ret_ports[-1][0], port[1]) + optimized = True + elif ret_ports[-1][1] + 1 == port[0]: + # (10, 20) and (21, 30) -> (10, 30) + ret_ports[-1] = (ret_ports[-1][0], port[1]) + optimized = True + else: + # (10, 20) and (22, 30) -> (10, 20), (22, 30) + ret_ports.append(port) + + if optimized: + return self.CollapsePortListRecursive(ret_ports) + return ret_ports + + def CollapsePortList(self, ports): + """Given a list of ports, Collapse to the smallest required. + + Args: + ports: a list of port strings eg: [(80,80), (53,53) (2000, 2009), + (1024,65535)] + + Returns: + ret_array: the collapsed sorted list of ports, eg: [(53,53), (80,80), + (1024,65535)] + """ + return self.CollapsePortListRecursive(sorted(ports)) + + def CheckPrincipalsContained(self, superset, subset): + """Check to if the given list of principals is wholly contained. + + Args: + superset: list of principals + subset: list of principals + + Returns: + bool: True if subset is contained in superset. false otherwise. + """ + # Skip set comparison if neither term has principals. + if not superset and not subset: + return True + + # Convert these lists to sets to use set comparison. + sup = set(superset) + sub = set(subset) + return sub.issubset(sup) + + def CheckProtocolIsContained(self, superset, subset): + """Check if the given list of protocols is wholly contained. + + Args: + superset: list of protocols + subset: list of protocols + + Returns: + bool: True if subset is contained in superset. false otherwise. + """ + if not superset: + return True + if not subset: + return False + + # Convert these lists to sets to use set comparison. + sup = set(superset) + sub = set(subset) + return sub.issubset(sup) + + def CheckPortIsContained(self, superset, subset): + """Check if the given list of ports is wholly contained. + + Args: + superset: list of port tuples + subset: list of port tuples + + Returns: + bool: True if subset is contained in superset, false otherwise + """ + if not superset: + return True + if not subset: + return False + + for sub_port in subset: + not_contains = True + for sup_port in superset: + if (int(sub_port[0]) >= int(sup_port[0]) + and int(sub_port[1]) <= int(sup_port[1])): + not_contains = False + break + if not_contains: + return False + return True + + def CheckAddressIsContained(self, superset, subset): + """Check if subset is wholey contained by superset. + + Args: + superset: list of the superset addresses + subset: list of the subset addresses + + Returns: + True or False. + """ + if not superset: + return True + if not subset: + return False + + for sub_addr in subset: + sub_contained = False + for sup_addr in superset: + # ipaddr ensures that version numbers match for inclusion. + if sub_addr in sup_addr: + sub_contained = True + break + if not sub_contained: + return False + return True + + +class VarType(object): + """Generic object meant to store lots of basic policy types.""" + + COMMENT = 0 + COUNTER = 1 + ACTION = 2 + SADDRESS = 3 + DADDRESS = 4 + ADDRESS = 5 + SPORT = 6 + DPORT = 7 + PROTOCOL_EXCEPT = 8 + OPTION = 9 + PROTOCOL = 10 + SADDREXCLUDE = 11 + DADDREXCLUDE = 12 + LOGGING = 13 + QOS = 14 + POLICER = 15 + PACKET_LEN = 16 + FRAGMENT_OFFSET = 17 + ICMP_TYPE = 18 + SPFX = 19 + DPFX = 20 + ETHER_TYPE = 21 + TRAFFIC_TYPE = 22 + VERBATIM = 23 + LOSS_PRIORITY = 24 + ROUTING_INSTANCE = 25 + PRECEDENCE = 26 + SINTERFACE = 27 + EXPIRATION = 28 + DINTERFACE = 29 + PLATFORM = 30 + PLATFORMEXCLUDE = 31 + PORT = 32 + TIMEOUT = 33 + OWNER = 34 + PRINCIPALS = 35 + ADDREXCLUDE = 36 + + def __init__(self, var_type, value): + self.var_type = var_type + if self.var_type == self.COMMENT: + # remove the double quotes + comment = value.strip('"') + # make all of the lines start w/o leading whitespace. + self.value = '\n'.join([x.lstrip() for x in comment.splitlines()]) + else: + self.value = value + + def __str__(self): + return self.value + + def __eq__(self, other): + return self.var_type == other.var_type and self.value == other.value + + +class Header(object): + """The header of the policy file contains the targets and a global comment.""" + + def __init__(self): + self.target = [] + self.comment = [] + + def AddObject(self, obj): + """Add and object to the Header. + + Args: + obj: of type VarType.COMMENT or Target + """ + if type(obj) == Target: + self.target.append(obj) + elif obj.var_type == VarType.COMMENT: + self.comment.append(str(obj)) + + @property + def platforms(self): + """The platform targets of this particular header.""" + return map(lambda x: x.platform, self.target) + + def FilterOptions(self, platform): + """Given a platform return the options. + + Args: + platform: string + + Returns: + list or None + """ + for target in self.target: + if target.platform == platform: + return target.options + return [] + + def FilterName(self, platform): + """Given a filter_type, return the filter name. + + Args: + platform: string + + Returns: + filter_name: string or None + + Notes: + !! Deprecated in favor of Header.FilterOptions(platform) !! + """ + for target in self.target: + if target.platform == platform: + if target.options: + return target.options[0] + return None + + +# This could be a VarType object, but I'm keeping it as it's class +# b/c we're almost certainly going to have to do something more exotic with +# it shortly to account for various rendering options like default iptables +# policies or output file names, etc. etc. +class Target(object): + """The type of acl to be rendered from this policy file.""" + + def __init__(self, target): + self.platform = target[0] + if len(target) > 1: + self.options = target[1:] + else: + self.options = None + + def __str__(self): + return self.platform + + def __eq__(self, other): + return self.platform == other.platform and self.options == other.options + + def __ne__(self, other): + return not self.__eq__(other) + + +# Lexing/Parsing starts here +tokens = ( + 'ACTION', + 'ADDR', + 'ADDREXCLUDE', + 'COMMENT', + 'COUNTER', + 'DADDR', + 'DADDREXCLUDE', + 'DPFX', + 'DPORT', + 'DINTERFACE', + 'DQUOTEDSTRING', + 'ETHER_TYPE', + 'EXPIRATION', + 'FRAGMENT_OFFSET', + 'HEADER', + 'ICMP_TYPE', + 'INTEGER', + 'LOGGING', + 'LOSS_PRIORITY', + 'OPTION', + 'OWNER', + 'PACKET_LEN', + 'PLATFORM', + 'PLATFORMEXCLUDE', + 'POLICER', + 'PORT', + 'PRECEDENCE', + 'PRINCIPALS', + 'PROTOCOL', + 'PROTOCOL_EXCEPT', + 'QOS', + 'ROUTING_INSTANCE', + 'SADDR', + 'SADDREXCLUDE', + 'SINTERFACE', + 'SPFX', + 'SPORT', + 'STRING', + 'TARGET', + 'TERM', + 'TIMEOUT', + 'TRAFFIC_TYPE', + 'VERBATIM', +) + +literals = r':{},-' +t_ignore = ' \t' + +reserved = { + 'action': 'ACTION', + 'address': 'ADDR', + 'address-exclude': 'ADDREXCLUDE', + 'comment': 'COMMENT', + 'counter': 'COUNTER', + 'destination-address': 'DADDR', + 'destination-exclude': 'DADDREXCLUDE', + 'destination-interface': 'DINTERFACE', + 'destination-prefix': 'DPFX', + 'destination-port': 'DPORT', + 'ether-type': 'ETHER_TYPE', + 'expiration': 'EXPIRATION', + 'fragment-offset': 'FRAGMENT_OFFSET', + 'header': 'HEADER', + 'icmp-type': 'ICMP_TYPE', + 'logging': 'LOGGING', + 'loss-priority': 'LOSS_PRIORITY', + 'option': 'OPTION', + 'owner': 'OWNER', + 'packet-length': 'PACKET_LEN', + 'platform': 'PLATFORM', + 'platform-exclude': 'PLATFORMEXCLUDE', + 'policer': 'POLICER', + 'port': 'PORT', + 'precedence': 'PRECEDENCE', + 'principals': 'PRINCIPALS', + 'protocol': 'PROTOCOL', + 'protocol-except': 'PROTOCOL_EXCEPT', + 'qos': 'QOS', + 'routing-instance': 'ROUTING_INSTANCE', + 'source-address': 'SADDR', + 'source-exclude': 'SADDREXCLUDE', + 'source-interface': 'SINTERFACE', + 'source-prefix': 'SPFX', + 'source-port': 'SPORT', + 'target': 'TARGET', + 'term': 'TERM', + 'timeout': 'TIMEOUT', + 'traffic-type': 'TRAFFIC_TYPE', + 'verbatim': 'VERBATIM', +} + + +# disable linting warnings for lexx/yacc code +# pylint: disable-msg=W0613,C6102,C6104,C6105,C6108,C6409 + + +def t_IGNORE_COMMENT(t): + r'\#.*' + pass + + +def t_DQUOTEDSTRING(t): + r'"[^"]*?"' + t.lexer.lineno += str(t.value).count('\n') + return t + + +def t_newline(t): + r'\n+' + t.lexer.lineno += len(t.value) + + +def t_error(t): + print "Illegal character '%s' on line %s" % (t.value[0], t.lineno) + t.lexer.skip(1) + + +def t_INTEGER(t): + r'\d+' + return t + + +def t_STRING(t): + r'\w+([-_+.@]\w*)*' + # we have an identifier; let's check if it's a keyword or just a string. + t.type = reserved.get(t.value, 'STRING') + return t + + +### +## parser starts here +### +def p_target(p): + """ target : target header terms + | """ + if len(p) > 1: + if type(p[1]) is Policy: + p[1].AddFilter(p[2], p[3]) + p[0] = p[1] + else: + p[0] = Policy(p[2], p[3]) + + +def p_header(p): + """ header : HEADER '{' header_spec '}' """ + p[0] = p[3] + + +def p_header_spec(p): + """ header_spec : header_spec target_spec + | header_spec comment_spec + | """ + if len(p) > 1: + if type(p[1]) == Header: + p[1].AddObject(p[2]) + p[0] = p[1] + else: + p[0] = Header() + p[0].AddObject(p[2]) + + +# we may want to change this at some point if we want to be clever with things +# like being able to set a default input/output policy for iptables policies. +def p_target_spec(p): + """ target_spec : TARGET ':' ':' strings_or_ints """ + p[0] = Target(p[4]) + + +def p_terms(p): + """ terms : terms TERM STRING '{' term_spec '}' + | """ + if len(p) > 1: + p[5].name = p[3] + if type(p[1]) == list: + p[1].append(p[5]) + p[0] = p[1] + else: + p[0] = [p[5]] + + +def p_term_spec(p): + """ term_spec : term_spec action_spec + | term_spec addr_spec + | term_spec comment_spec + | term_spec counter_spec + | term_spec ether_type_spec + | term_spec exclude_spec + | term_spec expiration_spec + | term_spec fragment_offset_spec + | term_spec icmp_type_spec + | term_spec interface_spec + | term_spec logging_spec + | term_spec losspriority_spec + | term_spec option_spec + | term_spec owner_spec + | term_spec packet_length_spec + | term_spec platform_spec + | term_spec policer_spec + | term_spec port_spec + | term_spec precedence_spec + | term_spec principals_spec + | term_spec prefix_list_spec + | term_spec protocol_spec + | term_spec qos_spec + | term_spec routinginstance_spec + | term_spec timeout_spec + | term_spec traffic_type_spec + | term_spec verbatim_spec + | """ + if len(p) > 1: + if type(p[1]) == Term: + p[1].AddObject(p[2]) + p[0] = p[1] + else: + p[0] = Term(p[2]) + + +def p_routinginstance_spec(p): + """ routinginstance_spec : ROUTING_INSTANCE ':' ':' STRING """ + p[0] = VarType(VarType.ROUTING_INSTANCE, p[4]) + + +def p_losspriority_spec(p): + """ losspriority_spec : LOSS_PRIORITY ':' ':' STRING """ + p[0] = VarType(VarType.LOSS_PRIORITY, p[4]) + + +def p_precedence_spec(p): + """ precedence_spec : PRECEDENCE ':' ':' one_or_more_ints """ + p[0] = VarType(VarType.PRECEDENCE, p[4]) + + +def p_icmp_type_spec(p): + """ icmp_type_spec : ICMP_TYPE ':' ':' one_or_more_strings """ + p[0] = VarType(VarType.ICMP_TYPE, p[4]) + + +def p_packet_length_spec(p): + """ packet_length_spec : PACKET_LEN ':' ':' INTEGER + | PACKET_LEN ':' ':' INTEGER '-' INTEGER """ + if len(p) == 4: + p[0] = VarType(VarType.PACKET_LEN, str(p[4])) + else: + p[0] = VarType(VarType.PACKET_LEN, str(p[4]) + '-' + str(p[6])) + + +def p_fragment_offset_spec(p): + """ fragment_offset_spec : FRAGMENT_OFFSET ':' ':' INTEGER + | FRAGMENT_OFFSET ':' ':' INTEGER '-' INTEGER """ + if len(p) == 4: + p[0] = VarType(VarType.FRAGMENT_OFFSET, str(p[4])) + else: + p[0] = VarType(VarType.FRAGMENT_OFFSET, str(p[4]) + '-' + str(p[6])) + + +def p_exclude_spec(p): + """ exclude_spec : SADDREXCLUDE ':' ':' one_or_more_strings + | DADDREXCLUDE ':' ':' one_or_more_strings + | ADDREXCLUDE ':' ':' one_or_more_strings + | PROTOCOL_EXCEPT ':' ':' one_or_more_strings """ + + p[0] = [] + for ex in p[4]: + if p[1].find('source-exclude') >= 0: + p[0].append(VarType(VarType.SADDREXCLUDE, ex)) + elif p[1].find('destination-exclude') >= 0: + p[0].append(VarType(VarType.DADDREXCLUDE, ex)) + elif p[1].find('address-exclude') >= 0: + p[0].append(VarType(VarType.ADDREXCLUDE, ex)) + elif p[1].find('protocol-except') >= 0: + p[0].append(VarType(VarType.PROTOCOL_EXCEPT, ex)) + + +def p_prefix_list_spec(p): + """ prefix_list_spec : DPFX ':' ':' one_or_more_strings + | SPFX ':' ':' one_or_more_strings """ + p[0] = [] + for pfx in p[4]: + if p[1].find('source-prefix') >= 0: + p[0].append(VarType(VarType.SPFX, pfx)) + elif p[1].find('destination-prefix') >= 0: + p[0].append(VarType(VarType.DPFX, pfx)) + + +def p_addr_spec(p): + """ addr_spec : SADDR ':' ':' one_or_more_strings + | DADDR ':' ':' one_or_more_strings + | ADDR ':' ':' one_or_more_strings """ + p[0] = [] + for addr in p[4]: + if p[1].find('source-address') >= 0: + p[0].append(VarType(VarType.SADDRESS, addr)) + elif p[1].find('destination-address') >= 0: + p[0].append(VarType(VarType.DADDRESS, addr)) + else: + p[0].append(VarType(VarType.ADDRESS, addr)) + + +def p_port_spec(p): + """ port_spec : SPORT ':' ':' one_or_more_strings + | DPORT ':' ':' one_or_more_strings + | PORT ':' ':' one_or_more_strings """ + p[0] = [] + for port in p[4]: + if p[1].find('source-port') >= 0: + p[0].append(VarType(VarType.SPORT, port)) + elif p[1].find('destination-port') >= 0: + p[0].append(VarType(VarType.DPORT, port)) + else: + p[0].append(VarType(VarType.PORT, port)) + + +def p_protocol_spec(p): + """ protocol_spec : PROTOCOL ':' ':' strings_or_ints """ + p[0] = [] + for proto in p[4]: + p[0].append(VarType(VarType.PROTOCOL, proto)) + + +def p_ether_type_spec(p): + """ ether_type_spec : ETHER_TYPE ':' ':' one_or_more_strings """ + p[0] = [] + for proto in p[4]: + p[0].append(VarType(VarType.ETHER_TYPE, proto)) + + +def p_traffic_type_spec(p): + """ traffic_type_spec : TRAFFIC_TYPE ':' ':' one_or_more_strings """ + p[0] = [] + for proto in p[4]: + p[0].append(VarType(VarType.TRAFFIC_TYPE, proto)) + + +def p_policer_spec(p): + """ policer_spec : POLICER ':' ':' STRING """ + p[0] = VarType(VarType.POLICER, p[4]) + + +def p_logging_spec(p): + """ logging_spec : LOGGING ':' ':' STRING """ + p[0] = VarType(VarType.LOGGING, p[4]) + + +def p_option_spec(p): + """ option_spec : OPTION ':' ':' one_or_more_strings """ + p[0] = [] + for opt in p[4]: + p[0].append(VarType(VarType.OPTION, opt)) + +def p_principals_spec(p): + """ principals_spec : PRINCIPALS ':' ':' one_or_more_strings """ + p[0] = [] + for opt in p[4]: + p[0].append(VarType(VarType.PRINCIPALS, opt)) + +def p_action_spec(p): + """ action_spec : ACTION ':' ':' STRING """ + p[0] = VarType(VarType.ACTION, p[4]) + + +def p_counter_spec(p): + """ counter_spec : COUNTER ':' ':' STRING """ + p[0] = VarType(VarType.COUNTER, p[4]) + + +def p_expiration_spec(p): + """ expiration_spec : EXPIRATION ':' ':' INTEGER '-' INTEGER '-' INTEGER """ + p[0] = VarType(VarType.EXPIRATION, datetime.date(int(p[4]), + int(p[6]), + int(p[8]))) + + +def p_comment_spec(p): + """ comment_spec : COMMENT ':' ':' DQUOTEDSTRING """ + p[0] = VarType(VarType.COMMENT, p[4]) + + +def p_owner_spec(p): + """ owner_spec : OWNER ':' ':' STRING """ + p[0] = VarType(VarType.OWNER, p[4]) + + +def p_verbatim_spec(p): + """ verbatim_spec : VERBATIM ':' ':' STRING DQUOTEDSTRING """ + p[0] = VarType(VarType.VERBATIM, [p[4], p[5].strip('"')]) + + +def p_qos_spec(p): + """ qos_spec : QOS ':' ':' STRING """ + p[0] = VarType(VarType.QOS, p[4]) + + +def p_interface_spec(p): + """ interface_spec : SINTERFACE ':' ':' STRING + | DINTERFACE ':' ':' STRING """ + if p[1].find('source-interface') >= 0: + p[0] = VarType(VarType.SINTERFACE, p[4]) + elif p[1].find('destination-interface') >= 0: + p[0] = VarType(VarType.DINTERFACE, p[4]) + + +def p_platform_spec(p): + """ platform_spec : PLATFORM ':' ':' one_or_more_strings + | PLATFORMEXCLUDE ':' ':' one_or_more_strings """ + p[0] = [] + for platform in p[4]: + if p[1].find('platform-exclude') >= 0: + p[0].append(VarType(VarType.PLATFORMEXCLUDE, platform)) + elif p[1].find('platform') >= 0: + p[0].append(VarType(VarType.PLATFORM, platform)) + + +def p_timeout_spec(p): + """ timeout_spec : TIMEOUT ':' ':' INTEGER """ + p[0] = VarType(VarType.TIMEOUT, p[4]) + + +def p_one_or_more_strings(p): + """ one_or_more_strings : one_or_more_strings STRING + | STRING + | """ + if len(p) > 1: + if type(p[1]) == type([]): + p[1].append(p[2]) + p[0] = p[1] + else: + p[0] = [p[1]] + + +def p_one_or_more_ints(p): + """ one_or_more_ints : one_or_more_ints INTEGER + | INTEGER + | """ + if len(p) > 1: + if type(p[1]) == type([]): + p[1].append(p[2]) + p[0] = p[1] + else: + p[0] = [p[1]] + + +def p_strings_or_ints(p): + """ strings_or_ints : strings_or_ints STRING + | strings_or_ints INTEGER + | STRING + | INTEGER + | """ + if len(p) > 1: + if type(p[1]) is list: + p[1].append(p[2]) + p[0] = p[1] + else: + p[0] = [p[1]] + + +def p_error(p): + """.""" + next_token = yacc.token() + if next_token is None: + use_token = 'EOF' + else: + use_token = repr(next_token.value) + + if p: + raise ParseError(' ERROR on "%s" (type %s, line %d, Next %s)' + % (p.value, p.type, p.lineno, use_token)) + else: + raise ParseError(' ERROR you likely have unablanaced "{"\'s') + +# pylint: enable-msg=W0613,C6102,C6104,C6105,C6108,C6409 + + +def _ReadFile(filename): + """Read data from a file if it exists. + + Args: + filename: str - Filename + + Returns: + data: str contents of file. + + Raises: + FileNotFoundError: if requested file does not exist. + FileReadError: Any error resulting from trying to open/read file. + """ + if os.path.exists(filename): + try: + data = open(filename, 'r').read() + return data + except IOError: + raise FileReadError('Unable to open or read file %s' % filename) + else: + raise FileNotFoundError('Unable to open policy file %s' % filename) + + +def _Preprocess(data, max_depth=5, base_dir=''): + """Search input for include statements and import specified include file. + + Search input for include statements and if found, import specified file + and recursively search included data for includes as well up to max_depth. + + Args: + data: A string of Policy file data. + max_depth: Maximum depth of included files + base_dir: Base path string where to look for policy or include files + + Returns: + A string containing result of the processed input data + + Raises: + RecursionTooDeepError: nested include files exceed maximum + """ + if not max_depth: + raise RecursionTooDeepError('%s' % ( + 'Included files exceed maximum recursion depth of %s.' % max_depth)) + rval = [] + lines = [x.rstrip() for x in data.splitlines()] + for index, line in enumerate(lines): + words = line.split() + if len(words) > 1 and words[0] == '#include': + # remove any quotes around included filename + include_file = words[1].strip('\'"') + data = _ReadFile(os.path.join(base_dir, include_file)) + # recursively handle includes in included data + inc_data = _Preprocess(data, max_depth - 1, base_dir=base_dir) + rval.extend(inc_data) + else: + rval.append(line) + return rval + + +def ParseFile(filename, definitions=None, optimize=True, base_dir='', + shade_check=False): + """Parse the policy contained in file, optionally provide a naming object. + + Read specified policy file and parse into a policy object. + + Args: + filename: Name of policy file to parse. + definitions: optional naming library definitions object. + optimize: bool - whether to summarize networks and services. + base_dir: base path string to look for acls or include files. + shade_check: bool - whether to raise an exception when a term is shaded. + + Returns: + policy object. + """ + data = _ReadFile(filename) + p = ParsePolicy(data, definitions, optimize, base_dir=base_dir, + shade_check=shade_check) + return p + + +def ParsePolicy(data, definitions=None, optimize=True, base_dir='', + shade_check=False): + """Parse the policy in 'data', optionally provide a naming object. + + Parse a blob of policy text into a policy object. + + Args: + data: a string blob of policy data to parse. + definitions: optional naming library definitions object. + optimize: bool - whether to summarize networks and services. + base_dir: base path string to look for acls or include files. + shade_check: bool - whether to raise an exception when a term is shaded. + + Returns: + policy object. + """ + try: + if definitions: + globals()['DEFINITIONS'] = definitions + else: + globals()['DEFINITIONS'] = naming.Naming(DEFAULT_DEFINITIONS) + if not optimize: + globals()['_OPTIMIZE'] = False + if shade_check: + globals()['_SHADE_CHECK'] = True + + lexer = lex.lex() + + preprocessed_data = '\n'.join(_Preprocess(data, base_dir=base_dir)) + p = yacc.yacc(write_tables=False, debug=0, errorlog=yacc.NullLogger()) + + return p.parse(preprocessed_data, lexer=lexer) + + except IndexError: + return False + + +# if you call this from the command line, you can specify a jcl file for it to +# read. +if __name__ == '__main__': + ret = 0 + if len(sys.argv) > 1: + try: + ret = ParsePolicy(open(sys.argv[1], 'r').read()) + except IOError: + print('ERROR: \'%s\' either does not exist or is not readable' % + (sys.argv[1])) + ret = 1 + else: + # default to reading stdin + ret = ParsePolicy(sys.stdin.read()) + sys.exit(ret) |