author     Johan Lundberg <lundberg@nordu.net>  2015-04-02 10:43:33 +0200
committer  Johan Lundberg <lundberg@nordu.net>  2015-04-02 10:43:33 +0200
commit     bd611ac59f7c4db885a2f8631ef0bcdcd1901ca0 (patch)
tree       e60f5333a7699cd021b33c7f5292af55b774001b /lib/policy.py
Diffstat (limited to 'lib/policy.py')
-rw-r--r--  lib/policy.py  1821
1 files changed, 1821 insertions, 0 deletions
diff --git a/lib/policy.py b/lib/policy.py
new file mode 100644
index 0000000..a6b8ad6
--- /dev/null
+++ b/lib/policy.py
@@ -0,0 +1,1821 @@
+#!/usr/bin/python
+#
+# Copyright 2011 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Parses the generic policy files and return a policy object for acl rendering.
+"""
+
+import datetime
+import os
+import sys
+
+import logging
+import nacaddr
+import naming
+
+from third_party.ply import lex
+from third_party.ply import yacc
+
+
+DEFINITIONS = None
+DEFAULT_DEFINITIONS = './def'
+_ACTIONS = set(('accept', 'deny', 'reject', 'next', 'reject-with-tcp-rst'))
+_LOGGING = set(('true', 'True', 'syslog', 'local', 'disable'))
+_OPTIMIZE = True
+_SHADE_CHECK = False
+
+
+class Error(Exception):
+ """Generic error class."""
+
+
+class FileNotFoundError(Error):
+  """Policy file could not be found."""
+
+
+class FileReadError(Error):
+ """Policy file unable to be read."""
+
+
+class RecursionTooDeepError(Error):
+ """Included files exceed maximum recursion depth."""
+
+
+class ParseError(Error):
+ """ParseError in the input."""
+
+
+class TermAddressExclusionError(Error):
+ """Excluded address block is not contained in the accepted address block."""
+
+
+class TermObjectTypeError(Error):
+ """Error with an object passed to Term."""
+
+
+class TermPortProtocolError(Error):
+ """Error when a requested protocol doesn't have any of the requested ports."""
+
+
+class TermProtocolEtherTypeError(Error):
+ """Error when both ether-type & upper-layer protocol matches are requested."""
+
+
+class TermNoActionError(Error):
+ """Error when a term hasn't defined an action."""
+
+
+class TermInvalidIcmpType(Error):
+ """Error when a term has invalid icmp-types specified."""
+
+
+class InvalidTermActionError(Error):
+ """Error when an action is invalid."""
+
+
+class InvalidTermLoggingError(Error):
+  """Error when an invalid option is set for logging."""
+
+
+class UndefinedAddressError(Error):
+ """Error when an undefined address is referenced."""
+
+
+class NoTermsError(Error):
+ """Error when no terms were found."""
+
+
+class ShadingError(Error):
+ """Error when a term is shaded by a prior term."""
+
+
+def TranslatePorts(ports, protocols, term_name):
+ """Return all ports of all protocols requested.
+
+ Args:
+ ports: list of ports, eg ['SMTP', 'DNS', 'HIGH_PORTS']
+ protocols: list of protocols, eg ['tcp', 'udp']
+ term_name: name of current term, used for warning messages
+
+ Returns:
+ ret_array: list of ports tuples such as [(25,25), (53,53), (1024,65535)]
+
+ Note:
+ Duplication will be taken care of in Term.CollapsePortList
+ """
+ ret_array = []
+ for proto in protocols:
+ for port in ports:
+ service_by_proto = DEFINITIONS.GetServiceByProto(port, proto)
+ if not service_by_proto:
+        logging.warn('Term %s has service %s which is not defined with '
+                     'protocol %s, but will be permitted. Unless intended, '
+                     'you should consider splitting the protocols into '
+                     'separate terms!', term_name, port, proto)
+
+ for p in [x.split('-') for x in service_by_proto]:
+ if len(p) == 1:
+ ret_array.append((int(p[0]), int(p[0])))
+ else:
+ ret_array.append((int(p[0]), int(p[1])))
+ return ret_array
+
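+# Illustrative sketch (hypothetical service names; assumes the loaded naming
+# definitions map SMTP to 25/tcp and HIGH_PORTS to 1024-65535/tcp):
+#
+#   TranslatePorts(['SMTP', 'HIGH_PORTS'], ['tcp'], 'allow-mail')
+#   # -> [(25, 25), (1024, 65535)]
+#
+# Overlapping or duplicate tuples are collapsed later by Term.CollapsePortList.
+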
+
+# classes for storing the object types in the policy files.
+class Policy(object):
+ """The policy object contains everything found in a given policy file."""
+
+ def __init__(self, header, terms):
+    """Initializer for the Policy object.
+
+ Args:
+      header: __main__.Header object. Contains comments which should be passed
+        on to the rendered acls as well as the type of acls this policy file
+        should render to.
+
+      terms: list of __main__.Term objects which must be rendered in each of
+        the rendered acls.
+
+ Attributes:
+ filters: list of tuples containing (header, terms).
+ """
+ self.filters = []
+ self.AddFilter(header, terms)
+
+ def AddFilter(self, header, terms):
+ """Add another header & filter."""
+ self.filters.append((header, terms))
+ self._TranslateTerms(terms)
+ if _SHADE_CHECK:
+ self._DetectShading(terms)
+
+ def _TranslateTerms(self, terms):
+    """Translate ports and perform cleanup and sanity checks on each term."""
+ if not terms:
+ raise NoTermsError('no terms found')
+ for term in terms:
+ # TODO(pmoody): this probably belongs in Term.SanityCheck(),
+ # or at the very least, in some method under class Term()
+ if term.translated:
+ continue
+ if term.port:
+ term.port = TranslatePorts(term.port, term.protocol, term.name)
+ if not term.port:
+ raise TermPortProtocolError(
+ 'no ports of the correct protocol for term %s' % (
+ term.name))
+ if term.source_port:
+ term.source_port = TranslatePorts(term.source_port, term.protocol,
+ term.name)
+ if not term.source_port:
+ raise TermPortProtocolError(
+ 'no source ports of the correct protocol for term %s' % (
+ term.name))
+ if term.destination_port:
+ term.destination_port = TranslatePorts(term.destination_port,
+ term.protocol, term.name)
+ if not term.destination_port:
+ raise TermPortProtocolError(
+ 'no destination ports of the correct protocol for term %s' % (
+ term.name))
+
+ # If argument is true, we optimize, otherwise just sort addresses
+ term.AddressCleanup(_OPTIMIZE)
+ # Reset _OPTIMIZE global to default value
+ globals()['_OPTIMIZE'] = True
+ term.SanityCheck()
+ term.translated = True
+
+ @property
+ def headers(self):
+ """Returns the headers from each of the configured filters.
+
+ Returns:
+ headers
+ """
+ return [x[0] for x in self.filters]
+
+ def _DetectShading(self, terms):
+ """Finds terms which are shaded (impossible to reach).
+
+ Iterate through each term, looking at each prior term. If a prior term
+ contains every component of the current term then the current term would
+ never be hit and is thus shaded. This can be a mistake.
+
+ Args:
+ terms: list of Term objects.
+
+ Raises:
+ ShadingError: When a term is impossible to reach.
+ """
+    # Reset _SHADE_CHECK global to default value
+ globals()['_SHADE_CHECK'] = False
+ shading_errors = []
+ for index, term in enumerate(terms):
+ for prior_index in xrange(index):
+ # Check each term that came before for shading. Terms with next as an
+ # action do not terminate evaluation, so cannot shade.
+ if (term in terms[prior_index]
+ and 'next' not in terms[prior_index].action):
+ shading_errors.append(
+ ' %s is shaded by %s.' % (
+ term.name, terms[prior_index].name))
+ if shading_errors:
+ raise ShadingError('\n'.join(shading_errors))
+
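+# Illustrative sketch of shading (hypothetical policy snippet; HTTP is assumed
+# to be a defined service): with shade_check enabled, the second term below is
+# unreachable because the first already matches all tcp traffic, so
+# Policy._DetectShading raises ShadingError for it.
+#
+#   term allow-all-tcp {
+#     protocol:: tcp
+#     action:: accept
+#   }
+#   term allow-web {
+#     protocol:: tcp
+#     destination-port:: HTTP
+#     action:: accept
+#   }
+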
+
+class Term(object):
+ """The Term object is used to store each of the terms.
+
+ Args:
+ obj: an object of type VarType or a list of objects of type VarType
+
+ members:
+ address/source_address/destination_address/: list of
+ VarType.(S|D)?ADDRESS's
+ address_exclude/source_address_exclude/destination_address_exclude: list of
+ VarType.(S|D)?ADDEXCLUDE's
+ port/source_port/destination_port: list of VarType.(S|D)?PORT's
+ options: list of VarType.OPTION's.
+ protocol: list of VarType.PROTOCOL's.
+ counter: VarType.COUNTER
+ action: list of VarType.ACTION's
+ comments: VarType.COMMENT
+ expiration: VarType.EXPIRATION
+ verbatim: VarType.VERBATIM
+ logging: VarType.LOGGING
+ qos: VarType.QOS
+ policer: VarType.POLICER
+ """
+ ICMP_TYPE = {4: {'echo-reply': 0,
+ 'unreachable': 3,
+ 'source-quench': 4,
+ 'redirect': 5,
+ 'alternate-address': 6,
+ 'echo-request': 8,
+ 'router-advertisement': 9,
+ 'router-solicitation': 10,
+ 'time-exceeded': 11,
+ 'parameter-problem': 12,
+ 'timestamp-request': 13,
+ 'timestamp-reply': 14,
+ 'information-request': 15,
+ 'information-reply': 16,
+ 'mask-request': 17,
+ 'mask-reply': 18,
+ 'conversion-error': 31,
+ 'mobile-redirect': 32,
+ },
+ 6: {'destination-unreachable': 1,
+ 'packet-too-big': 2,
+ 'time-exceeded': 3,
+ 'parameter-problem': 4,
+ 'echo-request': 128,
+ 'echo-reply': 129,
+ 'multicast-listener-query': 130,
+ 'multicast-listener-report': 131,
+ 'multicast-listener-done': 132,
+ 'router-solicit': 133,
+ 'router-advertisement': 134,
+ 'neighbor-solicit': 135,
+ 'neighbor-advertisement': 136,
+ 'redirect-message': 137,
+ 'router-renumbering': 138,
+ 'icmp-node-information-query': 139,
+ 'icmp-node-information-response': 140,
+ 'inverse-neighbor-discovery-solicitation': 141,
+ 'inverse-neighbor-discovery-advertisement': 142,
+ 'version-2-multicast-listener-report': 143,
+ 'home-agent-address-discovery-request': 144,
+ 'home-agent-address-discovery-reply': 145,
+ 'mobile-prefix-solicitation': 146,
+ 'mobile-prefix-advertisement': 147,
+ 'certification-path-solicitation': 148,
+ 'certification-path-advertisement': 149,
+ 'multicast-router-advertisement': 151,
+ 'multicast-router-solicitation': 152,
+ 'multicast-router-termination': 153,
+ },
+ }
+
+ def __init__(self, obj):
+ self.name = None
+
+ self.action = []
+ self.address = []
+ self.address_exclude = []
+ self.comment = []
+ self.counter = None
+ self.expiration = None
+ self.destination_address = []
+ self.destination_address_exclude = []
+ self.destination_port = []
+ self.destination_prefix = []
+ self.logging = []
+ self.loss_priority = None
+ self.option = []
+ self.owner = None
+ self.policer = None
+ self.port = []
+ self.precedence = []
+ self.principals = []
+ self.protocol = []
+ self.protocol_except = []
+ self.qos = None
+ self.routing_instance = None
+ self.source_address = []
+ self.source_address_exclude = []
+ self.source_port = []
+ self.source_prefix = []
+ self.verbatim = []
+ # juniper specific.
+ self.packet_length = None
+ self.fragment_offset = None
+ self.icmp_type = []
+ self.ether_type = []
+ self.traffic_type = []
+ self.translated = False
+ # iptables specific
+ self.source_interface = None
+ self.destination_interface = None
+ self.platform = []
+ self.platform_exclude = []
+ self.timeout = None
+ self.AddObject(obj)
+ self.flattened = False
+ self.flattened_addr = None
+ self.flattened_saddr = None
+ self.flattened_daddr = None
+
+ def __contains__(self, other):
+ """Determine if other term is contained in this term."""
+ if self.verbatim or other.verbatim:
+ # short circuit these
+ if sorted(self.verbatim) != sorted(other.verbatim):
+ return False
+
+    # check protocols
+ # either protocol or protocol-except may be used, not both at the same time.
+ if self.protocol:
+ if other.protocol:
+ if not self.CheckProtocolIsContained(other.protocol, self.protocol):
+ return False
+ # this term has protocol, other has protocol_except.
+ elif other.protocol_except:
+ return False
+ else:
+        # other does not have protocol or protocol_except. Since we do, other
+        # cannot be contained in self.
+ return False
+ elif self.protocol_except:
+ if other.protocol_except:
+ if self.CheckProtocolIsContained(
+ self.protocol_except, other.protocol_except):
+ return False
+ elif other.protocol:
+ for proto in other.protocol:
+ if proto in self.protocol_except:
+ return False
+ else:
+ return False
+
+ # combine addresses with exclusions for proper contains comparisons.
+ if not self.flattened:
+ self.FlattenAll()
+ if not other.flattened:
+ other.FlattenAll()
+
+ # flat 'address' is compared against other flat (saddr|daddr).
+ # if NONE of these evaluate to True other is not contained.
+ if not (
+ self.CheckAddressIsContained(
+ self.flattened_addr, other.flattened_addr)
+ or self.CheckAddressIsContained(
+ self.flattened_addr, other.flattened_saddr)
+ or self.CheckAddressIsContained(
+ self.flattened_addr, other.flattened_daddr)):
+ return False
+
+ # compare flat address from other to flattened self (saddr|daddr).
+ if not (
+ # other's flat address needs both self saddr & daddr to contain in order
+ # for the term to be contained. We already compared the flattened_addr
+ # attributes of both above, which was not contained.
+ self.CheckAddressIsContained(
+ other.flattened_addr, self.flattened_saddr)
+ and self.CheckAddressIsContained(
+ other.flattened_addr, self.flattened_daddr)):
+ return False
+
+ # basic saddr/daddr check.
+ if not (
+ self.CheckAddressIsContained(
+ self.flattened_saddr, other.flattened_saddr)):
+ return False
+ if not (
+ self.CheckAddressIsContained(
+ self.flattened_daddr, other.flattened_daddr)):
+ return False
+
+ if not (
+ self.CheckPrincipalsContained(
+ self.principals, other.principals)):
+ return False
+
+ # check ports
+ # like the address directive, the port directive is special in that it can
+ # be either source or destination.
+ if self.port:
+ if not (self.CheckPortIsContained(self.port, other.port) or
+              self.CheckPortIsContained(self.port, other.source_port) or
+              self.CheckPortIsContained(self.port, other.destination_port)):
+ return False
+ if not self.CheckPortIsContained(self.source_port, other.source_port):
+ return False
+ if not self.CheckPortIsContained(self.destination_port,
+ other.destination_port):
+ return False
+
+ # prefix lists
+ if self.source_prefix:
+ if sorted(self.source_prefix) != sorted(other.source_prefix):
+ return False
+ if self.destination_prefix:
+ if sorted(self.destination_prefix) != sorted(
+ other.destination_prefix):
+ return False
+
+ # check precedence
+ if self.precedence:
+ if not other.precedence:
+ return False
+ for precedence in other.precedence:
+ if precedence not in self.precedence:
+ return False
+ # check various options
+ if self.option:
+ if not other.option:
+ return False
+ for opt in other.option:
+ if opt not in self.option:
+ return False
+ if self.fragment_offset:
+      # fragment_offset looks like 'integer-integer' or just 'integer'
+ sfo = [int(x) for x in self.fragment_offset.split('-')]
+ if other.fragment_offset:
+ ofo = [int(x) for x in other.fragment_offset.split('-')]
+ if sfo[0] < ofo[0] or sorted(sfo[1:]) > sorted(ofo[1:]):
+ return False
+ else:
+ return False
+ if self.packet_length:
+      # packet_length looks like 'integer-integer' or just 'integer'
+ spl = [int(x) for x in self.packet_length.split('-')]
+ if other.packet_length:
+ opl = [int(x) for x in other.packet_length.split('-')]
+ if spl[0] < opl[0] or sorted(spl[1:]) > sorted(opl[1:]):
+ return False
+ else:
+ return False
+ if self.icmp_type:
+      if sorted(self.icmp_type) != sorted(other.icmp_type):
+ return False
+
+ # check platform
+ if self.platform:
+      if sorted(self.platform) != sorted(other.platform):
+ return False
+ if self.platform_exclude:
+      if sorted(self.platform_exclude) != sorted(other.platform_exclude):
+ return False
+
+ # we have containment
+ return True
+
+ def __str__(self):
+ ret_str = []
+ ret_str.append(' name: %s' % self.name)
+ if self.address:
+ ret_str.append(' address: %s' % self.address)
+ if self.address_exclude:
+ ret_str.append(' address_exclude: %s' % self.address_exclude)
+ if self.source_address:
+ ret_str.append(' source_address: %s' % self.source_address)
+ if self.source_address_exclude:
+ ret_str.append(' source_address_exclude: %s' %
+ self.source_address_exclude)
+ if self.destination_address:
+ ret_str.append(' destination_address: %s' % self.destination_address)
+ if self.destination_address_exclude:
+ ret_str.append(' destination_address_exclude: %s' %
+ self.destination_address_exclude)
+ if self.source_prefix:
+ ret_str.append(' source_prefix: %s' % self.source_prefix)
+ if self.destination_prefix:
+ ret_str.append(' destination_prefix: %s' % self.destination_prefix)
+ if self.protocol:
+ ret_str.append(' protocol: %s' % self.protocol)
+ if self.protocol_except:
+ ret_str.append(' protocol-except: %s' % self.protocol_except)
+ if self.owner:
+ ret_str.append(' owner: %s' % self.owner)
+ if self.port:
+ ret_str.append(' port: %s' % self.port)
+ if self.source_port:
+ ret_str.append(' source_port: %s' % self.source_port)
+ if self.destination_port:
+ ret_str.append(' destination_port: %s' % self.destination_port)
+ if self.action:
+ ret_str.append(' action: %s' % self.action)
+ if self.option:
+ ret_str.append(' option: %s' % self.option)
+ if self.qos:
+ ret_str.append(' qos: %s' % self.qos)
+ if self.logging:
+ ret_str.append(' logging: %s' % self.logging)
+ if self.counter:
+ ret_str.append(' counter: %s' % self.counter)
+ if self.source_interface:
+ ret_str.append(' source_interface: %s' % self.source_interface)
+ if self.destination_interface:
+ ret_str.append(' destination_interface: %s' % self.destination_interface)
+ if self.expiration:
+ ret_str.append(' expiration: %s' % self.expiration)
+ if self.platform:
+ ret_str.append(' platform: %s' % self.platform)
+ if self.platform_exclude:
+ ret_str.append(' platform_exclude: %s' % self.platform_exclude)
+ if self.timeout:
+ ret_str.append(' timeout: %s' % self.timeout)
+ return '\n'.join(ret_str)
+
+ def __eq__(self, other):
+ # action
+ if sorted(self.action) != sorted(other.action):
+ return False
+
+ # addresses.
+ if not (sorted(self.address) == sorted(other.address) and
+ sorted(self.source_address) == sorted(other.source_address) and
+ sorted(self.source_address_exclude) ==
+ sorted(other.source_address_exclude) and
+ sorted(self.destination_address) ==
+ sorted(other.destination_address) and
+ sorted(self.destination_address_exclude) ==
+ sorted(other.destination_address_exclude)):
+ return False
+
+ # prefix lists
+ if not (sorted(self.source_prefix) == sorted(other.source_prefix) and
+ sorted(self.destination_prefix) ==
+ sorted(other.destination_prefix)):
+ return False
+
+ # ports
+ if not (sorted(self.port) == sorted(other.port) and
+ sorted(self.source_port) == sorted(other.source_port) and
+ sorted(self.destination_port) == sorted(other.destination_port)):
+ return False
+
+ # protocol
+ if not (sorted(self.protocol) == sorted(other.protocol) and
+ sorted(self.protocol_except) == sorted(other.protocol_except)):
+ return False
+
+ # option
+ if sorted(self.option) != sorted(other.option):
+ return False
+
+ # qos
+ if self.qos != other.qos:
+ return False
+
+ # verbatim
+ if self.verbatim != other.verbatim:
+ return False
+
+ # policer
+ if self.policer != other.policer:
+ return False
+
+ # interface
+ if self.source_interface != other.source_interface:
+ return False
+
+ if self.destination_interface != other.destination_interface:
+ return False
+
+ if sorted(self.logging) != sorted(other.logging):
+ return False
+ if self.qos != other.qos:
+ return False
+ if self.packet_length != other.packet_length:
+ return False
+ if self.fragment_offset != other.fragment_offset:
+ return False
+ if sorted(self.icmp_type) != sorted(other.icmp_type):
+ return False
+ if sorted(self.ether_type) != sorted(other.ether_type):
+ return False
+ if sorted(self.traffic_type) != sorted(other.traffic_type):
+ return False
+
+ # platform
+ if not (sorted(self.platform) == sorted(other.platform) and
+ sorted(self.platform_exclude) == sorted(other.platform_exclude)):
+ return False
+
+ # timeout
+ if self.timeout != other.timeout:
+ return False
+
+ return True
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def FlattenAll(self):
+ """Reduce source, dest, and address fields to their post-exclude state.
+
+ Populates the self.flattened_addr, self.flattened_saddr,
+ self.flattened_daddr by removing excludes from includes.
+ """
+ # No excludes, set flattened attributes and move along.
+ self.flattened = True
+ if not (self.source_address_exclude or self.destination_address_exclude or
+ self.address_exclude):
+ self.flattened_saddr = self.source_address
+ self.flattened_daddr = self.destination_address
+ self.flattened_addr = self.address
+ return
+
+ if self.source_address_exclude:
+ self.flattened_saddr = self._FlattenAddresses(
+ self.source_address, self.source_address_exclude)
+ if self.destination_address_exclude:
+ self.flattened_daddr = self._FlattenAddresses(
+ self.destination_address, self.destination_address_exclude)
+ if self.address_exclude:
+ self.flattened_addr = self._FlattenAddresses(
+ self.address, self.address_exclude)
+
+
+ @staticmethod
+ def _FlattenAddresses(include, exclude):
+ """Reduce an include and exclude list to a single include list.
+
+ Using recursion, whittle away exclude addresses from address include
+ addresses which contain the exclusion.
+
+ Args:
+ include: list of include addresses.
+ exclude: list of exclude addresses.
+ Returns:
+ a single flattened list of nacaddr objects.
+ """
+ if not exclude:
+ return include
+
+ for index, in_addr in enumerate(include):
+ for ex_addr in exclude:
+ if ex_addr in in_addr:
+ reduced_list = in_addr.address_exclude(ex_addr)
+ include.pop(index)
+ include.extend(
+ Term._FlattenAddresses(reduced_list, exclude[1:]))
+ return include
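+
+  # Illustrative sketch (assumes nacaddr.IP wraps ipaddr-style network objects,
+  # which the address_exclude() call above relies on):
+  #
+  #   include = [nacaddr.IP('10.0.0.0/8')]
+  #   exclude = [nacaddr.IP('10.1.0.0/16')]
+  #   Term._FlattenAddresses(include, exclude)
+  #   # -> the subnets of 10.0.0.0/8 that do not overlap 10.1.0.0/16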
+
+ def GetAddressOfVersion(self, addr_type, af=None):
+ """Returns addresses of the appropriate Address Family.
+
+ Args:
+ addr_type: string, this will be either
+ 'source_address', 'source_address_exclude',
+ 'destination_address' or 'destination_address_exclude'
+ af: int or None, either Term.INET4 or Term.INET6
+
+ Returns:
+ list of addresses of the correct family.
+ """
+ if not af:
+      return getattr(self, addr_type)
+
+    return filter(lambda x: x.version == af, getattr(self, addr_type))
+
+ def AddObject(self, obj):
+ """Add an object of unknown type to this term.
+
+ Args:
+ obj: single or list of either
+ [Address, Port, Option, Protocol, Counter, Action, Comment, Expiration]
+
+ Raises:
+ InvalidTermActionError: if the action defined isn't an accepted action.
+ eg, action:: godofoobar
+ TermObjectTypeError: if AddObject is called with an object it doesn't
+ understand.
+ InvalidTermLoggingError: when a option is set for logging not known.
+ """
+ if type(obj) is list:
+ for x in obj:
+ # do we have a list of addresses?
+ # expanded address fields consolidate naked address fields with
+ # saddr/daddr.
+ if x.var_type is VarType.SADDRESS:
+ saddr = DEFINITIONS.GetNetAddr(x.value)
+ self.source_address.extend(saddr)
+ elif x.var_type is VarType.DADDRESS:
+ daddr = DEFINITIONS.GetNetAddr(x.value)
+ self.destination_address.extend(daddr)
+ elif x.var_type is VarType.ADDRESS:
+ addr = DEFINITIONS.GetNetAddr(x.value)
+ self.address.extend(addr)
+ # do we have address excludes?
+ elif x.var_type is VarType.SADDREXCLUDE:
+ saddr_exclude = DEFINITIONS.GetNetAddr(x.value)
+ self.source_address_exclude.extend(saddr_exclude)
+ elif x.var_type is VarType.DADDREXCLUDE:
+ daddr_exclude = DEFINITIONS.GetNetAddr(x.value)
+ self.destination_address_exclude.extend(daddr_exclude)
+ elif x.var_type is VarType.ADDREXCLUDE:
+ addr_exclude = DEFINITIONS.GetNetAddr(x.value)
+ self.address_exclude.extend(addr_exclude)
+ # do we have a list of ports?
+ elif x.var_type is VarType.PORT:
+ self.port.append(x.value)
+ elif x.var_type is VarType.SPORT:
+ self.source_port.append(x.value)
+ elif x.var_type is VarType.DPORT:
+ self.destination_port.append(x.value)
+ # do we have a list of protocols?
+ elif x.var_type is VarType.PROTOCOL:
+ self.protocol.append(x.value)
+ # do we have a list of protocol-exceptions?
+ elif x.var_type is VarType.PROTOCOL_EXCEPT:
+ self.protocol_except.append(x.value)
+ # do we have a list of options?
+ elif x.var_type is VarType.OPTION:
+ self.option.append(x.value)
+ elif x.var_type is VarType.PRINCIPALS:
+ self.principals.append(x.value)
+ elif x.var_type is VarType.SPFX:
+ self.source_prefix.append(x.value)
+ elif x.var_type is VarType.DPFX:
+ self.destination_prefix.append(x.value)
+ elif x.var_type is VarType.ETHER_TYPE:
+ self.ether_type.append(x.value)
+ elif x.var_type is VarType.TRAFFIC_TYPE:
+ self.traffic_type.append(x.value)
+ elif x.var_type is VarType.PRECEDENCE:
+ self.precedence.append(x.value)
+ elif x.var_type is VarType.PLATFORM:
+ self.platform.append(x.value)
+ elif x.var_type is VarType.PLATFORMEXCLUDE:
+ self.platform_exclude.append(x.value)
+ else:
+ raise TermObjectTypeError(
+ '%s isn\'t a type I know how to deal with (contains \'%s\')' % (
+ type(x), x.value))
+ else:
+ # stupid no switch statement in python
+ if obj.var_type is VarType.COMMENT:
+ self.comment.append(str(obj))
+ elif obj.var_type is VarType.OWNER:
+ self.owner = obj.value
+ elif obj.var_type is VarType.EXPIRATION:
+ self.expiration = obj.value
+ elif obj.var_type is VarType.LOSS_PRIORITY:
+ self.loss_priority = obj.value
+ elif obj.var_type is VarType.ROUTING_INSTANCE:
+ self.routing_instance = obj.value
+ elif obj.var_type is VarType.PRECEDENCE:
+ self.precedence = obj.value
+ elif obj.var_type is VarType.VERBATIM:
+ self.verbatim.append(obj)
+ elif obj.var_type is VarType.ACTION:
+ if str(obj) not in _ACTIONS:
+ raise InvalidTermActionError('%s is not a valid action' % obj)
+ self.action.append(obj.value)
+ elif obj.var_type is VarType.COUNTER:
+ self.counter = obj
+ elif obj.var_type is VarType.ICMP_TYPE:
+ self.icmp_type.extend(obj.value)
+ elif obj.var_type is VarType.LOGGING:
+ if str(obj) not in _LOGGING:
+ raise InvalidTermLoggingError('%s is not a valid logging option' %
+ obj)
+ self.logging.append(obj)
+ # police man, tryin'a take you jail
+ elif obj.var_type is VarType.POLICER:
+ self.policer = obj.value
+ # qos?
+ elif obj.var_type is VarType.QOS:
+ self.qos = obj.value
+ elif obj.var_type is VarType.PACKET_LEN:
+ self.packet_length = obj.value
+ elif obj.var_type is VarType.FRAGMENT_OFFSET:
+ self.fragment_offset = obj.value
+ elif obj.var_type is VarType.SINTERFACE:
+ self.source_interface = obj.value
+ elif obj.var_type is VarType.DINTERFACE:
+ self.destination_interface = obj.value
+ elif obj.var_type is VarType.TIMEOUT:
+ self.timeout = obj.value
+ else:
+ raise TermObjectTypeError(
+ '%s isn\'t a type I know how to deal with' % (type(obj)))
+
+ def SanityCheck(self):
+ """Sanity check the definition of the term.
+
+ Raises:
+ ParseError: if term has both verbatim and non-verbatim tokens
+ TermInvalidIcmpType: if term has invalid icmp-types specified
+ TermNoActionError: if the term doesn't have an action defined.
+ TermPortProtocolError: if the term has a service/protocol definition pair
+ which don't match up, eg. SNMP and tcp
+ TermAddressExclusionError: if one of the *-exclude directives is defined,
+ but that address isn't contained in the non *-exclude directive. eg:
+ source-address::CORP_INTERNAL source-exclude:: LOCALHOST
+ TermProtocolEtherTypeError: if the term has both ether-type and
+ upper-layer protocol restrictions
+ InvalidTermActionError: action and routing-instance both defined
+
+ This should be called when the term is fully formed, and
+ all of the options are set.
+
+ """
+ if self.verbatim:
+ if (self.action or self.source_port or self.destination_port or
+ self.port or self.protocol or self.option):
+ raise ParseError(
+ 'term "%s" has both verbatim and non-verbatim tokens.' % self.name)
+ else:
+ if not self.action and not self.routing_instance:
+ raise TermNoActionError('no action specified for term %s' % self.name)
+ # have we specified a port with a protocol that doesn't support ports?
+ if self.source_port or self.destination_port or self.port:
+ if 'tcp' not in self.protocol and 'udp' not in self.protocol:
+ raise TermPortProtocolError(
+ 'ports specified with a protocol that doesn\'t support ports. '
+ 'Term: %s ' % self.name)
+ # TODO(pmoody): do we have mutually exclusive options?
+ # eg. tcp-established + tcp-initial?
+
+ if self.ether_type and (
+ self.protocol or
+ self.address or
+ self.destination_address or
+ self.destination_address_exclude or
+ self.destination_port or
+ self.destination_prefix or
+ self.source_address or
+ self.source_address_exclude or
+ self.source_port or
+ self.source_prefix):
+ raise TermProtocolEtherTypeError(
+ 'ether-type not supported when used with upper-layer protocol '
+ 'restrictions. Term: %s' % self.name)
+ # validate icmp-types if specified, but addr_family will have to be checked
+ # in the generators as policy module doesn't know about that at this point.
+ if self.icmp_type:
+ for icmptype in self.icmp_type:
+ if (icmptype not in self.ICMP_TYPE[4] and icmptype not in
+ self.ICMP_TYPE[6]):
+          raise TermInvalidIcmpType('Term %s contains an invalid icmp-type: '
+ '%s' % (self.name, icmptype))
+
+ def AddressCleanup(self, optimize=True):
+ """Do Address and Port collapsing.
+
+ Notes:
+ Collapses both the address definitions and the port definitions
+ to their smallest possible length.
+
+ Args:
+ optimize: boolean value indicating whether to optimize addresses
+ """
+ if optimize:
+ cleanup = nacaddr.CollapseAddrList
+ else:
+ cleanup = nacaddr.SortAddrList
+
+ # address collapsing.
+ if self.address:
+ self.address = cleanup(self.address)
+ if self.source_address:
+ self.source_address = cleanup(self.source_address)
+ if self.source_address_exclude:
+ self.source_address_exclude = cleanup(self.source_address_exclude)
+ if self.destination_address:
+ self.destination_address = cleanup(self.destination_address)
+ if self.destination_address_exclude:
+ self.destination_address_exclude = cleanup(
+ self.destination_address_exclude)
+
+ # port collapsing.
+ if self.port:
+ self.port = self.CollapsePortList(self.port)
+ if self.source_port:
+ self.source_port = self.CollapsePortList(self.source_port)
+ if self.destination_port:
+ self.destination_port = self.CollapsePortList(self.destination_port)
+
+ def CollapsePortListRecursive(self, ports):
+ """Given a sorted list of ports, collapse to the smallest required list.
+
+ Args:
+ ports: sorted list of port tuples
+
+ Returns:
+ ret_ports: collapsed list of ports
+ """
+ optimized = False
+ ret_ports = []
+ for port in ports:
+ if not ret_ports:
+ ret_ports.append(port)
+ # we should be able to count on ret_ports[-1][0] <= port[0]
+ elif ret_ports[-1][1] >= port[1]:
+ # (10, 20) and (12, 13) -> (10, 20)
+ optimized = True
+ elif port[0] < ret_ports[-1][1] < port[1]:
+ # (10, 20) and (15, 30) -> (10, 30)
+ ret_ports[-1] = (ret_ports[-1][0], port[1])
+ optimized = True
+ elif ret_ports[-1][1] + 1 == port[0]:
+ # (10, 20) and (21, 30) -> (10, 30)
+ ret_ports[-1] = (ret_ports[-1][0], port[1])
+ optimized = True
+ else:
+ # (10, 20) and (22, 30) -> (10, 20), (22, 30)
+ ret_ports.append(port)
+
+ if optimized:
+ return self.CollapsePortListRecursive(ret_ports)
+ return ret_ports
+
+ def CollapsePortList(self, ports):
+ """Given a list of ports, Collapse to the smallest required.
+
+ Args:
+ ports: a list of port strings eg: [(80,80), (53,53) (2000, 2009),
+ (1024,65535)]
+
+ Returns:
+ ret_array: the collapsed sorted list of ports, eg: [(53,53), (80,80),
+ (1024,65535)]
+ """
+ return self.CollapsePortListRecursive(sorted(ports))
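+
+  # Illustrative sketch of the collapsing behaviour:
+  #
+  #   self.CollapsePortList([(80, 80), (15, 30), (10, 20), (31, 40)])
+  #   # -> [(10, 40), (80, 80)]   (overlapping and adjacent ranges are merged)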
+
+ def CheckPrincipalsContained(self, superset, subset):
+    """Check if the given list of principals is wholly contained.
+
+ Args:
+ superset: list of principals
+ subset: list of principals
+
+ Returns:
+ bool: True if subset is contained in superset. false otherwise.
+ """
+ # Skip set comparison if neither term has principals.
+ if not superset and not subset:
+ return True
+
+ # Convert these lists to sets to use set comparison.
+ sup = set(superset)
+ sub = set(subset)
+ return sub.issubset(sup)
+
+ def CheckProtocolIsContained(self, superset, subset):
+ """Check if the given list of protocols is wholly contained.
+
+ Args:
+ superset: list of protocols
+ subset: list of protocols
+
+ Returns:
+ bool: True if subset is contained in superset. false otherwise.
+ """
+ if not superset:
+ return True
+ if not subset:
+ return False
+
+ # Convert these lists to sets to use set comparison.
+ sup = set(superset)
+ sub = set(subset)
+ return sub.issubset(sup)
+
+ def CheckPortIsContained(self, superset, subset):
+ """Check if the given list of ports is wholly contained.
+
+ Args:
+ superset: list of port tuples
+ subset: list of port tuples
+
+ Returns:
+ bool: True if subset is contained in superset, false otherwise
+ """
+ if not superset:
+ return True
+ if not subset:
+ return False
+
+ for sub_port in subset:
+ not_contains = True
+ for sup_port in superset:
+ if (int(sub_port[0]) >= int(sup_port[0])
+ and int(sub_port[1]) <= int(sup_port[1])):
+ not_contains = False
+ break
+ if not_contains:
+ return False
+ return True
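+
+  # Illustrative sketch:
+  #
+  #   self.CheckPortIsContained([(1024, 65535)], [(8080, 8080)])  # -> True
+  #   self.CheckPortIsContained([(1024, 65535)], [(80, 80)])      # -> False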
+
+ def CheckAddressIsContained(self, superset, subset):
+    """Check if subset is wholly contained by superset.
+
+ Args:
+ superset: list of the superset addresses
+ subset: list of the subset addresses
+
+ Returns:
+ True or False.
+ """
+ if not superset:
+ return True
+ if not subset:
+ return False
+
+ for sub_addr in subset:
+ sub_contained = False
+ for sup_addr in superset:
+ # ipaddr ensures that version numbers match for inclusion.
+ if sub_addr in sup_addr:
+ sub_contained = True
+ break
+ if not sub_contained:
+ return False
+ return True
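+
+  # Illustrative sketch (assumes nacaddr.IP yields ipaddr-style network
+  # objects, so the 'in' test above checks full network containment):
+  #
+  #   self.CheckAddressIsContained([nacaddr.IP('10.0.0.0/8')],
+  #                                [nacaddr.IP('10.1.1.0/24')])    # -> True
+  #   self.CheckAddressIsContained([nacaddr.IP('10.0.0.0/8')],
+  #                                [nacaddr.IP('192.168.0.0/24')]) # -> False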
+
+
+class VarType(object):
+ """Generic object meant to store lots of basic policy types."""
+
+ COMMENT = 0
+ COUNTER = 1
+ ACTION = 2
+ SADDRESS = 3
+ DADDRESS = 4
+ ADDRESS = 5
+ SPORT = 6
+ DPORT = 7
+ PROTOCOL_EXCEPT = 8
+ OPTION = 9
+ PROTOCOL = 10
+ SADDREXCLUDE = 11
+ DADDREXCLUDE = 12
+ LOGGING = 13
+ QOS = 14
+ POLICER = 15
+ PACKET_LEN = 16
+ FRAGMENT_OFFSET = 17
+ ICMP_TYPE = 18
+ SPFX = 19
+ DPFX = 20
+ ETHER_TYPE = 21
+ TRAFFIC_TYPE = 22
+ VERBATIM = 23
+ LOSS_PRIORITY = 24
+ ROUTING_INSTANCE = 25
+ PRECEDENCE = 26
+ SINTERFACE = 27
+ EXPIRATION = 28
+ DINTERFACE = 29
+ PLATFORM = 30
+ PLATFORMEXCLUDE = 31
+ PORT = 32
+ TIMEOUT = 33
+ OWNER = 34
+ PRINCIPALS = 35
+ ADDREXCLUDE = 36
+
+ def __init__(self, var_type, value):
+ self.var_type = var_type
+ if self.var_type == self.COMMENT:
+ # remove the double quotes
+ comment = value.strip('"')
+ # make all of the lines start w/o leading whitespace.
+ self.value = '\n'.join([x.lstrip() for x in comment.splitlines()])
+ else:
+ self.value = value
+
+ def __str__(self):
+ return self.value
+
+ def __eq__(self, other):
+ return self.var_type == other.var_type and self.value == other.value
+
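+# Illustrative sketch of how parsed tokens are wrapped (hypothetical values):
+#
+#   v = VarType(VarType.ACTION, 'accept')
+#   str(v)    # -> 'accept'
+#   c = VarType(VarType.COMMENT, '"  leading whitespace\n  is stripped"')
+#   c.value   # -> 'leading whitespace\nis stripped'
+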
+
+class Header(object):
+ """The header of the policy file contains the targets and a global comment."""
+
+ def __init__(self):
+ self.target = []
+ self.comment = []
+
+ def AddObject(self, obj):
+    """Add an object to the Header.
+
+ Args:
+ obj: of type VarType.COMMENT or Target
+ """
+ if type(obj) == Target:
+ self.target.append(obj)
+ elif obj.var_type == VarType.COMMENT:
+ self.comment.append(str(obj))
+
+ @property
+ def platforms(self):
+ """The platform targets of this particular header."""
+ return map(lambda x: x.platform, self.target)
+
+ def FilterOptions(self, platform):
+ """Given a platform return the options.
+
+ Args:
+ platform: string
+
+ Returns:
+ list or None
+ """
+ for target in self.target:
+ if target.platform == platform:
+ return target.options
+ return []
+
+ def FilterName(self, platform):
+ """Given a filter_type, return the filter name.
+
+ Args:
+ platform: string
+
+ Returns:
+ filter_name: string or None
+
+ Notes:
+ !! Deprecated in favor of Header.FilterOptions(platform) !!
+ """
+ for target in self.target:
+ if target.platform == platform:
+ if target.options:
+ return target.options[0]
+ return None
+
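+# Illustrative sketch (hypothetical targets and options): for a header parsed
+# from 'target:: juniper edge-filter inet' and 'target:: iptables INPUT',
+#
+#   header.platforms                  # -> ['juniper', 'iptables']
+#   header.FilterOptions('juniper')   # -> ['edge-filter', 'inet']
+#   header.FilterName('juniper')      # -> 'edge-filter'  (deprecated helper)
+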
+
+# This could be a VarType object, but I'm keeping it as its own class
+# b/c we're almost certainly going to have to do something more exotic with
+# it shortly to account for various rendering options like default iptables
+# policies or output file names, etc. etc.
+class Target(object):
+ """The type of acl to be rendered from this policy file."""
+
+ def __init__(self, target):
+ self.platform = target[0]
+ if len(target) > 1:
+ self.options = target[1:]
+ else:
+ self.options = None
+
+ def __str__(self):
+ return self.platform
+
+ def __eq__(self, other):
+ return self.platform == other.platform and self.options == other.options
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+# Lexing/Parsing starts here
+tokens = (
+ 'ACTION',
+ 'ADDR',
+ 'ADDREXCLUDE',
+ 'COMMENT',
+ 'COUNTER',
+ 'DADDR',
+ 'DADDREXCLUDE',
+ 'DPFX',
+ 'DPORT',
+ 'DINTERFACE',
+ 'DQUOTEDSTRING',
+ 'ETHER_TYPE',
+ 'EXPIRATION',
+ 'FRAGMENT_OFFSET',
+ 'HEADER',
+ 'ICMP_TYPE',
+ 'INTEGER',
+ 'LOGGING',
+ 'LOSS_PRIORITY',
+ 'OPTION',
+ 'OWNER',
+ 'PACKET_LEN',
+ 'PLATFORM',
+ 'PLATFORMEXCLUDE',
+ 'POLICER',
+ 'PORT',
+ 'PRECEDENCE',
+ 'PRINCIPALS',
+ 'PROTOCOL',
+ 'PROTOCOL_EXCEPT',
+ 'QOS',
+ 'ROUTING_INSTANCE',
+ 'SADDR',
+ 'SADDREXCLUDE',
+ 'SINTERFACE',
+ 'SPFX',
+ 'SPORT',
+ 'STRING',
+ 'TARGET',
+ 'TERM',
+ 'TIMEOUT',
+ 'TRAFFIC_TYPE',
+ 'VERBATIM',
+)
+
+literals = r':{},-'
+t_ignore = ' \t'
+
+reserved = {
+ 'action': 'ACTION',
+ 'address': 'ADDR',
+ 'address-exclude': 'ADDREXCLUDE',
+ 'comment': 'COMMENT',
+ 'counter': 'COUNTER',
+ 'destination-address': 'DADDR',
+ 'destination-exclude': 'DADDREXCLUDE',
+ 'destination-interface': 'DINTERFACE',
+ 'destination-prefix': 'DPFX',
+ 'destination-port': 'DPORT',
+ 'ether-type': 'ETHER_TYPE',
+ 'expiration': 'EXPIRATION',
+ 'fragment-offset': 'FRAGMENT_OFFSET',
+ 'header': 'HEADER',
+ 'icmp-type': 'ICMP_TYPE',
+ 'logging': 'LOGGING',
+ 'loss-priority': 'LOSS_PRIORITY',
+ 'option': 'OPTION',
+ 'owner': 'OWNER',
+ 'packet-length': 'PACKET_LEN',
+ 'platform': 'PLATFORM',
+ 'platform-exclude': 'PLATFORMEXCLUDE',
+ 'policer': 'POLICER',
+ 'port': 'PORT',
+ 'precedence': 'PRECEDENCE',
+ 'principals': 'PRINCIPALS',
+ 'protocol': 'PROTOCOL',
+ 'protocol-except': 'PROTOCOL_EXCEPT',
+ 'qos': 'QOS',
+ 'routing-instance': 'ROUTING_INSTANCE',
+ 'source-address': 'SADDR',
+ 'source-exclude': 'SADDREXCLUDE',
+ 'source-interface': 'SINTERFACE',
+ 'source-prefix': 'SPFX',
+ 'source-port': 'SPORT',
+ 'target': 'TARGET',
+ 'term': 'TERM',
+ 'timeout': 'TIMEOUT',
+ 'traffic-type': 'TRAFFIC_TYPE',
+ 'verbatim': 'VERBATIM',
+}
+
+
+# disable linting warnings for lexx/yacc code
+# pylint: disable-msg=W0613,C6102,C6104,C6105,C6108,C6409
+
+
+def t_IGNORE_COMMENT(t):
+ r'\#.*'
+ pass
+
+
+def t_DQUOTEDSTRING(t):
+ r'"[^"]*?"'
+ t.lexer.lineno += str(t.value).count('\n')
+ return t
+
+
+def t_newline(t):
+ r'\n+'
+ t.lexer.lineno += len(t.value)
+
+
+def t_error(t):
+  print("Illegal character '%s' on line %s" % (t.value[0], t.lineno))
+ t.lexer.skip(1)
+
+
+def t_INTEGER(t):
+ r'\d+'
+ return t
+
+
+def t_STRING(t):
+ r'\w+([-_+.@]\w*)*'
+ # we have an identifier; let's check if it's a keyword or just a string.
+ t.type = reserved.get(t.value, 'STRING')
+ return t
+
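+# Illustrative lexing sketch: t_STRING matches identifiers such as
+# 'destination-port' or 'WEB_SERVERS' (a hypothetical token name) and then
+# reclassifies reserved words via the reserved table above, so:
+#
+#   'destination-port'  -> token type DPORT
+#   'WEB_SERVERS'       -> token type STRING
+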
+
+###
+## parser starts here
+###
+def p_target(p):
+ """ target : target header terms
+ | """
+ if len(p) > 1:
+ if type(p[1]) is Policy:
+ p[1].AddFilter(p[2], p[3])
+ p[0] = p[1]
+ else:
+ p[0] = Policy(p[2], p[3])
+
+
+def p_header(p):
+ """ header : HEADER '{' header_spec '}' """
+ p[0] = p[3]
+
+
+def p_header_spec(p):
+ """ header_spec : header_spec target_spec
+ | header_spec comment_spec
+ | """
+ if len(p) > 1:
+ if type(p[1]) == Header:
+ p[1].AddObject(p[2])
+ p[0] = p[1]
+ else:
+ p[0] = Header()
+ p[0].AddObject(p[2])
+
+
+# we may want to change this at some point if we want to be clever with things
+# like being able to set a default input/output policy for iptables policies.
+def p_target_spec(p):
+ """ target_spec : TARGET ':' ':' strings_or_ints """
+ p[0] = Target(p[4])
+
+
+def p_terms(p):
+ """ terms : terms TERM STRING '{' term_spec '}'
+ | """
+ if len(p) > 1:
+ p[5].name = p[3]
+ if type(p[1]) == list:
+ p[1].append(p[5])
+ p[0] = p[1]
+ else:
+ p[0] = [p[5]]
+
+
+def p_term_spec(p):
+ """ term_spec : term_spec action_spec
+ | term_spec addr_spec
+ | term_spec comment_spec
+ | term_spec counter_spec
+ | term_spec ether_type_spec
+ | term_spec exclude_spec
+ | term_spec expiration_spec
+ | term_spec fragment_offset_spec
+ | term_spec icmp_type_spec
+ | term_spec interface_spec
+ | term_spec logging_spec
+ | term_spec losspriority_spec
+ | term_spec option_spec
+ | term_spec owner_spec
+ | term_spec packet_length_spec
+ | term_spec platform_spec
+ | term_spec policer_spec
+ | term_spec port_spec
+ | term_spec precedence_spec
+ | term_spec principals_spec
+ | term_spec prefix_list_spec
+ | term_spec protocol_spec
+ | term_spec qos_spec
+ | term_spec routinginstance_spec
+ | term_spec timeout_spec
+ | term_spec traffic_type_spec
+ | term_spec verbatim_spec
+ | """
+ if len(p) > 1:
+ if type(p[1]) == Term:
+ p[1].AddObject(p[2])
+ p[0] = p[1]
+ else:
+ p[0] = Term(p[2])
+
+
+def p_routinginstance_spec(p):
+ """ routinginstance_spec : ROUTING_INSTANCE ':' ':' STRING """
+ p[0] = VarType(VarType.ROUTING_INSTANCE, p[4])
+
+
+def p_losspriority_spec(p):
+ """ losspriority_spec : LOSS_PRIORITY ':' ':' STRING """
+ p[0] = VarType(VarType.LOSS_PRIORITY, p[4])
+
+
+def p_precedence_spec(p):
+ """ precedence_spec : PRECEDENCE ':' ':' one_or_more_ints """
+ p[0] = VarType(VarType.PRECEDENCE, p[4])
+
+
+def p_icmp_type_spec(p):
+ """ icmp_type_spec : ICMP_TYPE ':' ':' one_or_more_strings """
+ p[0] = VarType(VarType.ICMP_TYPE, p[4])
+
+
+def p_packet_length_spec(p):
+ """ packet_length_spec : PACKET_LEN ':' ':' INTEGER
+ | PACKET_LEN ':' ':' INTEGER '-' INTEGER """
+ if len(p) == 4:
+ p[0] = VarType(VarType.PACKET_LEN, str(p[4]))
+ else:
+ p[0] = VarType(VarType.PACKET_LEN, str(p[4]) + '-' + str(p[6]))
+
+
+def p_fragment_offset_spec(p):
+ """ fragment_offset_spec : FRAGMENT_OFFSET ':' ':' INTEGER
+ | FRAGMENT_OFFSET ':' ':' INTEGER '-' INTEGER """
+ if len(p) == 4:
+ p[0] = VarType(VarType.FRAGMENT_OFFSET, str(p[4]))
+ else:
+ p[0] = VarType(VarType.FRAGMENT_OFFSET, str(p[4]) + '-' + str(p[6]))
+
+
+def p_exclude_spec(p):
+ """ exclude_spec : SADDREXCLUDE ':' ':' one_or_more_strings
+ | DADDREXCLUDE ':' ':' one_or_more_strings
+ | ADDREXCLUDE ':' ':' one_or_more_strings
+ | PROTOCOL_EXCEPT ':' ':' one_or_more_strings """
+
+ p[0] = []
+ for ex in p[4]:
+ if p[1].find('source-exclude') >= 0:
+ p[0].append(VarType(VarType.SADDREXCLUDE, ex))
+ elif p[1].find('destination-exclude') >= 0:
+ p[0].append(VarType(VarType.DADDREXCLUDE, ex))
+ elif p[1].find('address-exclude') >= 0:
+ p[0].append(VarType(VarType.ADDREXCLUDE, ex))
+ elif p[1].find('protocol-except') >= 0:
+ p[0].append(VarType(VarType.PROTOCOL_EXCEPT, ex))
+
+
+def p_prefix_list_spec(p):
+ """ prefix_list_spec : DPFX ':' ':' one_or_more_strings
+ | SPFX ':' ':' one_or_more_strings """
+ p[0] = []
+ for pfx in p[4]:
+ if p[1].find('source-prefix') >= 0:
+ p[0].append(VarType(VarType.SPFX, pfx))
+ elif p[1].find('destination-prefix') >= 0:
+ p[0].append(VarType(VarType.DPFX, pfx))
+
+
+def p_addr_spec(p):
+ """ addr_spec : SADDR ':' ':' one_or_more_strings
+ | DADDR ':' ':' one_or_more_strings
+ | ADDR ':' ':' one_or_more_strings """
+ p[0] = []
+ for addr in p[4]:
+ if p[1].find('source-address') >= 0:
+ p[0].append(VarType(VarType.SADDRESS, addr))
+ elif p[1].find('destination-address') >= 0:
+ p[0].append(VarType(VarType.DADDRESS, addr))
+ else:
+ p[0].append(VarType(VarType.ADDRESS, addr))
+
+
+def p_port_spec(p):
+ """ port_spec : SPORT ':' ':' one_or_more_strings
+ | DPORT ':' ':' one_or_more_strings
+ | PORT ':' ':' one_or_more_strings """
+ p[0] = []
+ for port in p[4]:
+ if p[1].find('source-port') >= 0:
+ p[0].append(VarType(VarType.SPORT, port))
+ elif p[1].find('destination-port') >= 0:
+ p[0].append(VarType(VarType.DPORT, port))
+ else:
+ p[0].append(VarType(VarType.PORT, port))
+
+
+def p_protocol_spec(p):
+ """ protocol_spec : PROTOCOL ':' ':' strings_or_ints """
+ p[0] = []
+ for proto in p[4]:
+ p[0].append(VarType(VarType.PROTOCOL, proto))
+
+
+def p_ether_type_spec(p):
+ """ ether_type_spec : ETHER_TYPE ':' ':' one_or_more_strings """
+ p[0] = []
+ for proto in p[4]:
+ p[0].append(VarType(VarType.ETHER_TYPE, proto))
+
+
+def p_traffic_type_spec(p):
+ """ traffic_type_spec : TRAFFIC_TYPE ':' ':' one_or_more_strings """
+ p[0] = []
+ for proto in p[4]:
+ p[0].append(VarType(VarType.TRAFFIC_TYPE, proto))
+
+
+def p_policer_spec(p):
+ """ policer_spec : POLICER ':' ':' STRING """
+ p[0] = VarType(VarType.POLICER, p[4])
+
+
+def p_logging_spec(p):
+ """ logging_spec : LOGGING ':' ':' STRING """
+ p[0] = VarType(VarType.LOGGING, p[4])
+
+
+def p_option_spec(p):
+ """ option_spec : OPTION ':' ':' one_or_more_strings """
+ p[0] = []
+ for opt in p[4]:
+ p[0].append(VarType(VarType.OPTION, opt))
+
+def p_principals_spec(p):
+ """ principals_spec : PRINCIPALS ':' ':' one_or_more_strings """
+ p[0] = []
+ for opt in p[4]:
+ p[0].append(VarType(VarType.PRINCIPALS, opt))
+
+def p_action_spec(p):
+ """ action_spec : ACTION ':' ':' STRING """
+ p[0] = VarType(VarType.ACTION, p[4])
+
+
+def p_counter_spec(p):
+ """ counter_spec : COUNTER ':' ':' STRING """
+ p[0] = VarType(VarType.COUNTER, p[4])
+
+
+def p_expiration_spec(p):
+ """ expiration_spec : EXPIRATION ':' ':' INTEGER '-' INTEGER '-' INTEGER """
+ p[0] = VarType(VarType.EXPIRATION, datetime.date(int(p[4]),
+ int(p[6]),
+ int(p[8])))
+
+
+def p_comment_spec(p):
+ """ comment_spec : COMMENT ':' ':' DQUOTEDSTRING """
+ p[0] = VarType(VarType.COMMENT, p[4])
+
+
+def p_owner_spec(p):
+ """ owner_spec : OWNER ':' ':' STRING """
+ p[0] = VarType(VarType.OWNER, p[4])
+
+
+def p_verbatim_spec(p):
+ """ verbatim_spec : VERBATIM ':' ':' STRING DQUOTEDSTRING """
+ p[0] = VarType(VarType.VERBATIM, [p[4], p[5].strip('"')])
+
+
+def p_qos_spec(p):
+ """ qos_spec : QOS ':' ':' STRING """
+ p[0] = VarType(VarType.QOS, p[4])
+
+
+def p_interface_spec(p):
+ """ interface_spec : SINTERFACE ':' ':' STRING
+ | DINTERFACE ':' ':' STRING """
+ if p[1].find('source-interface') >= 0:
+ p[0] = VarType(VarType.SINTERFACE, p[4])
+ elif p[1].find('destination-interface') >= 0:
+ p[0] = VarType(VarType.DINTERFACE, p[4])
+
+
+def p_platform_spec(p):
+ """ platform_spec : PLATFORM ':' ':' one_or_more_strings
+ | PLATFORMEXCLUDE ':' ':' one_or_more_strings """
+ p[0] = []
+ for platform in p[4]:
+ if p[1].find('platform-exclude') >= 0:
+ p[0].append(VarType(VarType.PLATFORMEXCLUDE, platform))
+ elif p[1].find('platform') >= 0:
+ p[0].append(VarType(VarType.PLATFORM, platform))
+
+
+def p_timeout_spec(p):
+ """ timeout_spec : TIMEOUT ':' ':' INTEGER """
+ p[0] = VarType(VarType.TIMEOUT, p[4])
+
+
+def p_one_or_more_strings(p):
+ """ one_or_more_strings : one_or_more_strings STRING
+ | STRING
+ | """
+ if len(p) > 1:
+ if type(p[1]) == type([]):
+ p[1].append(p[2])
+ p[0] = p[1]
+ else:
+ p[0] = [p[1]]
+
+
+def p_one_or_more_ints(p):
+ """ one_or_more_ints : one_or_more_ints INTEGER
+ | INTEGER
+ | """
+ if len(p) > 1:
+ if type(p[1]) == type([]):
+ p[1].append(p[2])
+ p[0] = p[1]
+ else:
+ p[0] = [p[1]]
+
+
+def p_strings_or_ints(p):
+ """ strings_or_ints : strings_or_ints STRING
+ | strings_or_ints INTEGER
+ | STRING
+ | INTEGER
+ | """
+ if len(p) > 1:
+ if type(p[1]) is list:
+ p[1].append(p[2])
+ p[0] = p[1]
+ else:
+ p[0] = [p[1]]
+
+
+def p_error(p):
+ """."""
+ next_token = yacc.token()
+ if next_token is None:
+ use_token = 'EOF'
+ else:
+ use_token = repr(next_token.value)
+
+ if p:
+ raise ParseError(' ERROR on "%s" (type %s, line %d, Next %s)'
+ % (p.value, p.type, p.lineno, use_token))
+ else:
+    raise ParseError(' ERROR you likely have unbalanced "{"\'s')
+
+# pylint: enable-msg=W0613,C6102,C6104,C6105,C6108,C6409
+
+
+def _ReadFile(filename):
+ """Read data from a file if it exists.
+
+ Args:
+ filename: str - Filename
+
+ Returns:
+ data: str contents of file.
+
+ Raises:
+ FileNotFoundError: if requested file does not exist.
+ FileReadError: Any error resulting from trying to open/read file.
+ """
+ if os.path.exists(filename):
+ try:
+ data = open(filename, 'r').read()
+ return data
+ except IOError:
+ raise FileReadError('Unable to open or read file %s' % filename)
+ else:
+ raise FileNotFoundError('Unable to open policy file %s' % filename)
+
+
+def _Preprocess(data, max_depth=5, base_dir=''):
+ """Search input for include statements and import specified include file.
+
+ Search input for include statements and if found, import specified file
+ and recursively search included data for includes as well up to max_depth.
+
+ Args:
+ data: A string of Policy file data.
+ max_depth: Maximum depth of included files
+ base_dir: Base path string where to look for policy or include files
+
+ Returns:
+ A string containing result of the processed input data
+
+ Raises:
+ RecursionTooDeepError: nested include files exceed maximum
+ """
+ if not max_depth:
+    raise RecursionTooDeepError(
+        'Included files exceed maximum recursion depth of %s.' % max_depth)
+ rval = []
+ lines = [x.rstrip() for x in data.splitlines()]
+ for index, line in enumerate(lines):
+ words = line.split()
+ if len(words) > 1 and words[0] == '#include':
+ # remove any quotes around included filename
+ include_file = words[1].strip('\'"')
+ data = _ReadFile(os.path.join(base_dir, include_file))
+ # recursively handle includes in included data
+ inc_data = _Preprocess(data, max_depth - 1, base_dir=base_dir)
+ rval.extend(inc_data)
+ else:
+ rval.append(line)
+ return rval
+
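+# Illustrative sketch (hypothetical file name): a line such as
+#
+#   #include 'includes/common-terms.inc'
+#
+# is replaced by the contents of base_dir/includes/common-terms.inc, and any
+# #include statements inside that file are expanded recursively, up to
+# max_depth levels deep.
+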
+
+def ParseFile(filename, definitions=None, optimize=True, base_dir='',
+ shade_check=False):
+ """Parse the policy contained in file, optionally provide a naming object.
+
+ Read specified policy file and parse into a policy object.
+
+ Args:
+ filename: Name of policy file to parse.
+ definitions: optional naming library definitions object.
+ optimize: bool - whether to summarize networks and services.
+ base_dir: base path string to look for acls or include files.
+ shade_check: bool - whether to raise an exception when a term is shaded.
+
+ Returns:
+ policy object.
+ """
+ data = _ReadFile(filename)
+ p = ParsePolicy(data, definitions, optimize, base_dir=base_dir,
+ shade_check=shade_check)
+ return p
+
+
+def ParsePolicy(data, definitions=None, optimize=True, base_dir='',
+ shade_check=False):
+ """Parse the policy in 'data', optionally provide a naming object.
+
+ Parse a blob of policy text into a policy object.
+
+ Args:
+ data: a string blob of policy data to parse.
+ definitions: optional naming library definitions object.
+ optimize: bool - whether to summarize networks and services.
+ base_dir: base path string to look for acls or include files.
+ shade_check: bool - whether to raise an exception when a term is shaded.
+
+ Returns:
+ policy object.
+ """
+ try:
+ if definitions:
+ globals()['DEFINITIONS'] = definitions
+ else:
+ globals()['DEFINITIONS'] = naming.Naming(DEFAULT_DEFINITIONS)
+ if not optimize:
+ globals()['_OPTIMIZE'] = False
+ if shade_check:
+ globals()['_SHADE_CHECK'] = True
+
+ lexer = lex.lex()
+
+ preprocessed_data = '\n'.join(_Preprocess(data, base_dir=base_dir))
+ p = yacc.yacc(write_tables=False, debug=0, errorlog=yacc.NullLogger())
+
+ return p.parse(preprocessed_data, lexer=lexer)
+
+ except IndexError:
+ return False
+
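+# Illustrative usage sketch (hypothetical definitions directory and service
+# name; assumes DNS is defined as 53/udp in the naming definitions):
+#
+#   import naming
+#   defs = naming.Naming('./def')
+#   pol = ParsePolicy(
+#       'header { target:: juniper edge-filter }\n'
+#       'term allow-dns { protocol:: udp destination-port:: DNS'
+#       ' action:: accept }\n',
+#       definitions=defs)
+#   for header, terms in pol.filters:
+#     print header.platforms, [t.name for t in terms]
+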
+
+# if you call this from the command line, you can specify a jcl file for it to
+# read.
+if __name__ == '__main__':
+ ret = 0
+ if len(sys.argv) > 1:
+ try:
+ ret = ParsePolicy(open(sys.argv[1], 'r').read())
+ except IOError:
+ print('ERROR: \'%s\' either does not exist or is not readable' %
+ (sys.argv[1]))
+ ret = 1
+ else:
+ # default to reading stdin
+ ret = ParsePolicy(sys.stdin.read())
+ sys.exit(ret)