From a42d38385cb2a5abf4f97b3ebdafec7214ba2585 Mon Sep 17 00:00:00 2001
From: Anders Blomdell <anders.blomdell@control.lth.se>
Date: Thu, 31 Oct 2013 14:25:33 +0100
Subject: [PATCH] Version 2013-10-31 14:25

M  src/hostinfo.py
M  src/hostinfo/ifconfig.py
M  src/hostinfo/named.py
M  src/hostinfo/util.py
---
 src/hostinfo.py          |   2 +-
 src/hostinfo/ifconfig.py |   7 +-
 src/hostinfo/named.py    | 443 ++++++++++++++++++++-------------------
 src/hostinfo/util.py     |   7 +-
 4 files changed, 239 insertions(+), 220 deletions(-)

diff --git a/src/hostinfo.py b/src/hostinfo.py
index 4d3b341..772f031 100755
--- a/src/hostinfo.py
+++ b/src/hostinfo.py
@@ -314,7 +314,7 @@ if __name__ == '__main__':
         print mio.encode("iso8859-1")
         
     if options.named:
-        for (f, c) in hostinfo.named.generate(tree, host):
+        for (f, c) in hostinfo.named.generate(tree, options):
             file["%s/%s" % (options.named, f)] = c
 
     if options.netgroup:
diff --git a/src/hostinfo/ifconfig.py b/src/hostinfo/ifconfig.py
index 4936e11..a541e05 100755
--- a/src/hostinfo/ifconfig.py
+++ b/src/hostinfo/ifconfig.py
@@ -141,7 +141,12 @@ def generate_ifcfg(tree, interface):
                          nameservers, [])
     config.extend(indexed_assign('DNS', nameservers, 1))
     if search:
-        config.append('SEARCH="%s"' % (' '.join(search)))
+        def reverse_order(a, b):
+            a_v = list(reversed(a.split('.')))
+            b_v = list(reversed(b.split('.')))
+            return cmp(a_v, b_v)
+        tmp = sorted(search, cmp=reverse_order)
+        config.append('SEARCH="%s"' % (' '.join(tmp)))
         pass
     if interface.defroute[0]:
         config.append('DEFROUTE=%s' % interface.defroute[0])
diff --git a/src/hostinfo/named.py b/src/hostinfo/named.py
index 54ed6d1..a17a33f 100755
--- a/src/hostinfo/named.py
+++ b/src/hostinfo/named.py
@@ -1,110 +1,90 @@
 import copy
 import hostinfo.parser
-from hostinfo.util import fqn, by_ip, network, address
 import ipaddr
-import itertools
+import hostinfo.util as util
+import re
 
-def reverse_addr(addr):
-    if addr == None:
-        return None
-    if addr.version == 4:
-        def join(l):
-            return '.'.join(reversed(l))+'.in-addr.arpa'
-        if isinstance(addr, ipaddr._BaseNet):
-            assert addr.prefixlen % 8 == 0
-            n = addr.prefixlen / 8
-            return join(addr.exploded.split('.')[0:n])
-        else:
-            return join(addr.exploded.split('.'))
-        pass
-    elif addr.version == 6:
-        def join(l):
-            return '.'.join(reversed(l))+'.ip6.arpa'
-        if isinstance(addr, ipaddr._BaseNet):
-            assert addr.prefixlen % 4 == 0
-            n = addr.prefixlen / 4
-            return join(map(None, addr.exploded.replace(':', ''))[0:n])
-        else:
-            return join(map(None, addr.exploded.replace(':', '')))
-    else:
-	raise Exception('Unknown address version %s' % addr)
-
-def generate(tree, host):
+def generate(tree, options):
     #
     # A. Check if host is a nameserver
     #
     emit = False
     for ns in tree._host_._interface_._nameserver_:
-        if ns.name[0:] == host:
+        if ns.name[0:] == options.host:
             # Given host is a nameserver
             emit = True
             pass
         pass
     
     if not emit:
-        raise Exception("%s is not a nameserver" % host)
+        raise Exception("%s is not a nameserver" % options.host)
     
     #
-    # B. Read named.conf header
+    # B. Get named.conf header
     #
+    conf = util.StringArray()
+    
     for h in tree._nameserver_:
-        conf = h.conf[0].strip() + '\n'
-        conf = conf.replace('{ ', '{\n')
-        conf = conf.replace('; ', ';\n')
+        for l in re.sub('([{;]\s)', '\g<1>\n', h.conf[0]).split('\n'):
+            if l[0] == ' ':
+                l = l[1:]
+                pass
+            conf.append(l)
+            pass
         pass
-
+    
     #
-    # C. Append to named.conf and create named/* files
+    # C. Create reverse mapping for localhost
     #
     result = []
 
-    done = {}
-    rzone = "0.0.127.in-addr.arpa"
-    done[rzone] = 1
-    conf += 'zone "0.0.127.in-addr.arpa" {\n'
-    conf += '  type master;\n'
-    conf += '  file "0.0.127.in-addr.arpa";\n'
-    conf += '};\n'
-    result.append(("named/%s" % rzone, reverse_local(tree, host)))
-
-    forward = {}
-    reverse = {}
-    for s in tree._subnet_:
-        if not s.domain[0] in forward:
-            forward[s.domain[0]] = generate_forward(tree, s)
-            pass
-        r = reverse_addr(network(s))
-        if r and not r in reverse:
-            reverse[r] = generate_reverse(tree, s)
-            pass
-
-    for s in tree._subnet_ipv6_:
-        if not s.domain[0] in forward:
-            forward[s.domain[0]] = generate_forward(tree, s)
-            pass
-        r = reverse_addr(network(s))
-        if r and not r in reverse:
-            reverse[r] = generate_reverse(tree, s)
-            pass
-        pass
+    forward = generate_forward(tree, get_hosts(tree))
+    reverse = generate_reverse(tree, get_hosts(tree, with_alias=False))
 
     for f in filter(lambda f: forward[f], sorted(forward)):
-        conf += "zone \"%s\" { \n" % f
-        conf += "  type master; file \"hosts-%s\"; \n" % f
-        conf += "};\n"
-        result.append(("named/hosts-%s" % f, forward[f]))
+        conf.append_lines("""
+          |zone \"%(name)s\" {
+          |  type master; file \"hosts-%(name)s\";
+          |};""" % dict(name=f))
+        result.append(("named/hosts-%s" % f, str(forward[f])))
         pass
     for r in filter(lambda r: reverse[r], sorted(reverse)):
-        conf += "zone \"%s\" { \n" % r
-        conf += "  type master; file \"%s\"; \n" % r
-        conf += "};\n"
-        result.append(("named/%s" % r, reverse[r])) 
+        conf.append_lines("""
+          |zone \"%(name)s\" {
+          |  type master; file \"%(name)s\";
+          |};""" % dict(name=r))
+        result.append(("named/%s" % r, str(reverse[r])))
         pass
 
-    result.append(("named.conf", conf))
+    result.append(("named.conf", str(conf)))
     
     return result
 
+def reverse_addr(addr):
+    if addr == None:
+        return None
+    if addr.version == 4:
+        def join(l):
+            return '.'.join(reversed(l))+'.in-addr.arpa'
+        if isinstance(addr, ipaddr._BaseNet):
+            assert addr.prefixlen % 8 == 0
+            n = addr.prefixlen / 8
+            return join(addr.exploded.split('.')[0:n])
+        else:
+            return join(addr.exploded.split('.'))
+        pass
+    elif addr.version == 6:
+        def join(l):
+            return '.'.join(reversed(l))+'.ip6.arpa'
+        if isinstance(addr, ipaddr._BaseNet):
+            assert addr.prefixlen % 4 == 0
+            n = addr.prefixlen / 4
+            return join(map(None, addr.exploded.replace(':', ''))[0:n])
+        else:
+            return join(map(None, addr.exploded.replace(':', '')))
+    else:
+	raise Exception('Unknown address version %s' % addr)
+
 def header(tree, domain, origin=None):
     soa = None
     for s in tree._soa_:
@@ -122,185 +102,214 @@ def header(tree, domain, origin=None):
     if not filter(lambda ns: ns.domain[0] == domain and ns.primary[0] == 'yes',
                   tree._host_._interface_._nameserver_):
         return None
-    result = ""
+    result = util.StringArray()
     if origin:
-        result = "$ORIGIN %s.\n" % origin
+        result += "$ORIGIN %s." % origin
         pass
     if soa.ttl[0]:
-        result += "$TTL %s\n" % soa.ttl[0]
+        result += "$TTL %s" % soa.ttl[0]
         pass
-    result += "@ IN SOA %s %s ( \n" % (soa.nameserver[0], soa.email[0])
-    result += "                %-15s ; Serial\n" % tree._mtime
-    result += "                %-15s ; Refresh\n" % soa.refresh[0]
-    result += "                %-15s ; Retry\n" % soa.retry[0]
-    result += "                %-15s ; Expire\n" % soa.expire[0]
-    result += "                %-15s ; Minimum\n" % soa.minimum[0]
-    result += "                )\n"
-    result += ";\n"
-
+    result.append_lines("""
+      |@ IN SOA %(nameserver)s %(email)s (
+      |                %(mtime)-15s ; Serial
+      |                %(refresh)-15s ; Refresh
+      |                %(retry)-15s ; Retry
+      |                %(expire)-15s ; Expire
+      |                %(minimum)-15s ; Minimum
+      |                )
+      |;""" % dict(nameserver=soa.nameserver[0],
+                   email=soa.email[0],
+                   mtime=tree._mtime,
+                   refresh=soa.refresh[0],
+                   retry=soa.retry[0],
+                   expire=soa.expire[0],
+                   minimum=soa.minimum[0]))
     for ns in tree._host_._interface_._nameserver_:
         if ns.domain[0] == domain and ns.primary[0] == 'yes':
-            result += "                IN      NS      %s\n" % (
-                fqn(tree, ns._parent))
+            result += "                IN      NS      %s" % (
+                util.fqn(tree, ns._parent))
             pass
         pass
-    result += ";\n"
+    result += ";"
     return result
 
-def generate_forward(tree, domain):
-    result = header(tree, domain.domain[0], domain.domain[0])
-    if not result:
-        return None
-    net = []
-    for s in tree._subnet_:
-        if network(s) and s.domain[0] == domain.domain[0]:
-            net.append(network(s))
-            pass
-        pass
+class DomainDict:
+    
+    class Domain:
 
-    for m in tree._host_._interface_._mailhost_:
-        if m.domain[0] == domain.domain[0]:
-            pri = int(m.priority[0] or 0)
-            result += "%-16sIN      MX      %d %s\n" % ("", pri,
-                                                        fqn(tree, m._parent))
+        def __init__(self, header):
+            self.header = header
+            self.host = {}
             pass
-        pass
-    result += ";\n"
 
-    # Add domain TXT entries
-    newline = False
-    for t in tree._subnet_._txt_:
-        if t.domain[1] == domain.domain[0]:
-            result += '                IN      TXT     "%s"\n' % (
-                t.value[0])
-            newline = True
+        def add_host(self, name, kind, value):
+            if not name in self.host:
+                self.host[name] = set()
+                pass
+            self.host[name].add((kind, value))
             pass
+
+        def value(self, cmp=None):
+            result = util.StringArray()
+            result += self.header
+            for name in sorted(self.host, cmp):
+                for kind, value in sorted(self.host[name]):
+                    result += ('%(name)-18s IN      %(kind)-7s %(value)s' % 
+                               dict(name=name, kind=kind, value=value))
+                    pass
+                pass
+            return result
+        
         pass
-    if newline:
-        result += ";\n"
+
+    def __init__(self, callback):
+        self.callback = callback
+        self.domain = {}
         pass
+
+    def value(self, cmp=None):
+        result = {}
+        for d in self.domain:
+            result[d] = self.domain[d].value(cmp)
+            pass
+        return result
+
+    def __getitem__(self, key):
+        if not key in self.domain:
+            self.domain[key] = self.Domain(self.callback(key))
+            pass
+        return self.domain[key]
         
-    # Add a localhost entry
-    result += "localhost       IN      A       127.0.0.1\n"
-    result += ";\n"
-                
-    host = {}
-    def add_entry(name, kind, value):
-        if not name in host:
-            host[name] = []
+def generate_forward(tree, hosts):
+    def callback(domain):
+        result = header(tree, domain, domain)
+        for mx in [ m for m in tree._host_._interface_._mailhost_ 
+                    if m.domain[0] == domain]:
+            pri = int(mx.priority[0] or 0)
+            result += ('                IN      MX      %d %s' %
+                       (pri, util.fqn(tree, mx._parent)))
             pass
-        if not (kind, value) in host[name]:
-            host[name].append((kind, value))
+        for txt in [ t for t in tree._subnet_._txt_ if t.domain[1] == domain]:
+            result += ('                IN      TXT     "%s"' % (txt.value[0]))
             pass
+        result.append_lines("""
+          |;
+          |localhost       IN      A       127.0.0.1
+          |;""")
+        return result
+    result = DomainDict(callback)
+    
+    # Add cname hosts
+    for c in tree._subnet_._cname_:
+        result[c.domain[1]].add_host(c.alias[0], 'CNAME', c.name[0])
         pass
-    for i in filter(address, tree._host_._interface_._ip_):
-        # Find all hosts that belong to this network
-        for n in net:
-            if address(i) in n:
-                add_entry(i.name[1:], 'A', '%s' % i.address[0])
-                for a in i._alias_:
-                    add_entry(a.name[0], 'A', '%s' % i.address[0])
+    # Add numbered hosts
+    for domain,net in [ (s.domain[0],util.network(s)) 
+                        for s in tree._subnet_ 
+                        if s.domain[0] and util.network(s)]:
+        for name,address in hosts:
+            if address in net:
+                if address.version == 4:
+                    result[domain].add_host(name, 'A', str(address.exploded))
                     pass
-                for s in i._srv_:
-                    port = int(s.port[0] or 0)
-                    priority = int(s.priority[0] or 0)
-                    weight = int(s.weight[0] or 0)
-                    add_entry(s.name[0], 'SRV', '%d %d %d %s' % (
-                        priority, weight, port, s.name[1:]))
+                elif address.version == 6:
+                    result[domain].add_host(name, 'AAAA', str(address.exploded))
                     pass
                 pass
             pass
         pass
+    
+    return result.value()
 
-    for i in filter(address, tree._host_._interface_._ipv6_):
-        for n in net:
-            if address(i) in n:
-                add_entry(i.name[1:], 'AAAA', '%s' % i.address[0])
-                for a in i._alias_:
-                    add_entry(a.name[0], 'AAAA', '%s' % i.address[0])
+def generate_reverse(tree, hosts):
+    def callback(origin):
+        result = header(tree, origin_to_domain[origin], origin)
+        return result
+    result = DomainDict(callback)
+    net_to_origin = {}
+    origin_to_domain = {}
+    for s in filter(util.network, tree._subnet_):
+        net = util.network(s)
+        origin = reverse_addr(net)
+        net_to_origin[net] = origin
+        origin_to_domain[origin] = s.domain[0]
+        pass
+    for net in net_to_origin:
+        origin = net_to_origin[net]
+        domain = origin_to_domain[origin]
+        for name,address in hosts:
+            if address in net:
+                reverse = reverse_addr(address).replace('.%s' % origin, '')
+                fqn = name
+                if fqn[-1] != '.':
+                    fqn += '.' + domain + '.'
                     pass
+                result[origin].add_host(reverse, 'PTR', fqn)
                 pass
             pass
         pass
+    def by_reverse(a, b):
+        a_v = map(lambda i: int(i,16), reversed(a.split('.')))
+        b_v = map(lambda i: int(i,16), reversed(b.split('.')))
+        return cmp(a_v, b_v)
+    return result.value(cmp=by_reverse)
 
-    for c in domain._cname_:
-        # Emit cnames defined in subnet
-        add_entry(c.alias[0], 'CNAME', '%s' % c.name[0])
+def get_hosts(tree, with_alias=True):
+    result = []
+    seen = {}
+    def add (name, address, check=None):
+        if check and address in seen:
+            old_name = seen[address][0]
+            old_check = seen[address][1]
+            if check.duplicate[0] != 'ok' or old_check.duplicate[0] != 'ok':
+                raise util.HostinfoException('Duplicate address %s %s %s' 
+                                             % (address, old_name, name),
+                                             where=[old_check, check])
+            pass
+        seen[address] = (name, check)
+        result.append((name,address))
         pass
-        
-    # Add hosts
-    for h1 in sorted(host):
-        for k,v  in sorted(host[h1]):
-            result += "%-18s IN      %-7s %s\n" % (h1, k, v)
+    # IPv4 static addresses
+    for i in filter(util.address, tree._host_._interface_._ip_):
+        add(i.name[0:], util.address(i), check=i)
+        if with_alias:
+            for a in i._alias_:
+                add(a.name[0:], util.address(i))
+                pass
             pass
         pass
-    
-    return result
-
-def generate_reverse(tree, subnet):
-    net = network(subnet)
-    origin = reverse_addr(net)
-    result = header(tree, subnet.domain[0], origin)
-    if not result:
-        return None
-    host = {}
-    for i in itertools.chain(tree._host_._interface_._ip_,
-                             tree._host_._interface_._ipv6_):
-        # Find all hosts that belong to this network
-        a = address(i)
-        if not a:
-            continue
-        if a in net:
-            r = reverse_addr(a).replace('.%s' % origin, '')
-            host[r] = "PTR     %s" % fqn(tree, i)
+    # IPv4 dynamic addresses
+    for d in filter(lambda d: d.first[0] and d.last[0], 
+                    tree._host_._interface_._ip_._dhcpserver_):
+        last = util.address(d.last[0])
+        a = util.address(d.first[0])
+        while a <= last:
+            name = '-'.join([ 'dynamic' ] + a.exploded.split('.'))
+            add(name, a, check=d)
+            a = a + 1
             pass
         pass
-    def order(a, b):
-        def int16(v):
-            return int(v, 16)
-        return cmp(map(int16, a.split('.')),
-                   map(int16, b.split('.')))
-    for h in sorted(host, order):
-        result += "%-15s IN      %s\n" % (h, host[h])
+        
+    # IPv6 static addresses
+    for i in filter(util.address, tree._host_._interface_._ipv6_):
+        add(i.name[0:], util.address(i), check=i)
+        if with_alias:
+            for a in i._alias_:
+                add(a.name[0:], util.address(i))
+                pass
+            pass
+        pass
+    # IPv6 dynamic addresses
+    for d in filter(lambda d: d.first[0] and d.last[0], 
+                    tree._host_._interface_._ipv6_._dhcpserver_):
+        last = util.address(d.last[0])
+        a = util.address(d.first[0])
+        while a <= last:
+            name = '-'.join([ 'dynamic' ] + a.exploded.split(':'))
+            add(name, a, check=d)
+            a = a + 1
+            pass
         pass
     return result
-  
-def reverse_local(tree, nameserver):
-    # Synthesize a minimal tree for local domain
-    t = hostinfo.parser.Node("hostinfo")
-    t._mtime = tree._mtime
 
-    # Add SOA(s)
-    for s in tree._soa_:
-        t._add(copy.copy(s))
-    # Add subnets (for fqn in header)
-    for s in tree._subnet_:
-        t._add(copy.copy(s))
-    # Add local subnet
-    net = hostinfo.parser.Node("subnet", 0, { "network" : "127.0.0.0",
-                                              "netmask" : "255.255.255.0" })
-    t._add(net)
-    # Add localhost
-    h = hostinfo.parser.Node("host", 0, { "name" : "localhost." })
-    i = hostinfo.parser.Node("interface", 0)
-    ip = hostinfo.parser.Node("ip", 0, { "address" : "127.0.0.1" })
-    i._add(ip)
-    h._add(i)
-    t._add(h)
-    # Add nameserver (with extra domain entry)
-    for ns in tree._host_._interface_._nameserver_:
-        if ns.name[0:] == nameserver:
-            h = copy.copy(ns._parent._parent)
-            i = copy.copy(ns._parent)
-            i._add(hostinfo.parser.Node("nameserver", attr={"primary":"yes"}))
-            for ip in ns._parent._ip_:
-                if not ip.alias[0] and not ip.vlan[0]:
-                    i._add(copy.copy(ip))
-                    pass
-                pass
-            h._add(i)
-            t._add(h)
-            break
-        pass
-    return generate_reverse(t, net)
-  
+    
diff --git a/src/hostinfo/util.py b/src/hostinfo/util.py
index 3e4c177..06e6de7 100755
--- a/src/hostinfo/util.py
+++ b/src/hostinfo/util.py
@@ -92,7 +92,12 @@ class StringArray(list, object):
     
     def __iadd__(self, other):
         try:
-            self.extend(other.split('\n'))
+            lines = other.split('\n')
+            if len(lines[-1]) == 0:
+                lines.pop()
+                print '#! ', lines
+                pass
+            self.extend(lines)
             pass
         except AttributeError:
             self.extend(other)
-- 
GitLab