#!/usr/bin/python3

import atexit
import command
import hashtoc
import loghandler
import os
import socket
import subprocess
import threading
import time
import shutil
import sys
import tar_stream

def cond_unlink(path, log):
    """Remove *path* if present and log the removal at DEBUG level.

    A path that does not exist is silently ignored; any other OSError
    from the unlink propagates to the caller.
    """
    try:
        os.unlink(path)
    except FileNotFoundError:
        # Nothing to remove -- not an error.
        return
    else:
        log.DEBUG('removed %s' % path)

def cond_kill(p):
    """Best-effort kill of process *p* (registered as an atexit hook).

    Errors from signalling a process that has already gone away are
    ignored. Unlike a bare ``except:``, this does not swallow
    KeyboardInterrupt or SystemExit.
    """
    try:
        p.kill()
    except OSError:
        # e.g. ProcessLookupError when the child has already exited.
        pass

class Status:
    """Counters for one backup pass, summarised through the log at exit.

    Callers bump the counters while walking the TOCs; when the
    interpreter exits, a single ``STATUS ...`` line is emitted via
    ``log.MESSAGE``.
    """

    def __init__(self, log):
        # Per-entry tallies, all starting at zero.
        self.checked = self.added = self.deleted = 0
        self.replaced = self.unchanged = self.metadata = 0
        # Not updated anywhere in the visible code; reported as-is.
        self.extract_OK = -1

        def _summary():
            counts = (self.checked, self.added, self.deleted,
                      self.replaced, self.metadata, self.extract_OK)
            log.MESSAGE('STATUS %d = +%d -%d =%d ?%d (%d)' % counts)

        # NOTE(review): self.unchanged is tracked but not part of the
        # summary line -- confirm whether that is intended.
        atexit.register(_summary)
        
class Backup:
    """Bring the local tree ``mount/path`` in line with the primary.

    Names to fetch are written NUL-terminated to the ``primary_tar``
    socket. A background thread (``self.extractor``) reads a tar stream
    back from the same socket and extracts it into ``mount/path``.
    Entries that are deleted or replaced locally are moved under
    ``mount/TRASH/<unix time>/`` instead of being unlinked. The oldest
    trash entries are purged whenever free space runs short.
    """

    def __init__(self, primary_tar, mount, path, status, log):
        """Wrap the connected tar socket and start the extractor thread.

        Args:
            primary_tar: connected stream socket to the primary.
            mount: backup mount point; ``TRASH`` lives directly below it.
            path: destination tree, relative to *mount*.
            status: counters object updated by check().
            log: logger providing DEBUG() and MESSAGE().
        """
        self.primary_tar = primary_tar
        self.primary_in = primary_tar.makefile('wb')
        self.primary_out = primary_tar.makefile('rb')
        self.mount = mount
        self.path = path
        self.status = status
        self.log = log
        self.dst_root = os.path.join(mount, path).encode('utf-8')
        self.trash_root = os.path.join(mount, 'TRASH').encode('utf-8')
        self.trash = os.path.join(self.trash_root,
                                  str(int(time.time())).encode('utf-8'))
        # make_room() runs on the extractor thread (run) and on the
        # caller's thread (delete); serialize all trash manipulation.
        # Reentrant because delete() calls make_room() while holding it.
        self._trash_lock = threading.RLock()
        self.extractor = threading.Thread(daemon=True, target=self.run)
        self.extractor.start()
        # Make sure that the generated tar archive is not empty
        self.primary_in.write(b'.\0')

    def run(self):
        """Extractor thread body.

        Unpack each member of the tar stream read from the primary into
        ``mount/path``. Trash is purged first if the member would not fit.
        """
        cwd = os.path.join(self.mount, self.path)
        reader = tar_stream.TarReader(self.primary_out)
        for e in reader:
            self.make_room(e.size)
            # NOTE(review): an exception here ends this thread, after which
            # nothing drains primary_out (the primary may then stall) and
            # close() does not report the failure.  Member names are also
            # trusted as sent by the primary (no tarfile extraction filter).
            e.tarfile.extract(e, path=cwd)

    def close(self):
        """End the request stream, wait for extraction, release the socket."""
        self.primary_in.flush()
        # Half-close: tells the primary no more names are coming while the
        # extractor keeps reading the rest of the tar stream.
        self.primary_tar.shutdown(socket.SHUT_WR)
        self.extractor.join()
        # The underlying fd is only released once every makefile() object
        # on the socket has been closed as well.
        self.primary_in.close()
        self.primary_out.close()
        self.primary_tar.close()

    def check(self, src, dst):
        """Reconcile one entry present on both sides (*src* primary, *dst* local).

        If kind, checksum or size differ, the local entry is trashed and
        re-requested. Otherwise, if the local path exists, the following
        are patched in place and counted in ``status``:
        mode (kinds F/D), owner (F/D/L/S) and mtime (F/L).

        Raises:
            Exception: if the two entries do not have the same name.
        """
        if src.name != dst.name:
            raise Exception('Names differ: %s, %s' % (src, dst))
        dst_path = os.path.join(self.dst_root, dst.name)
        if src.kind != dst.kind or src.sum != dst.sum or src.size != dst.size:
            self.log.DEBUG('Replace...', src.name, dst.name,
                           src.sum, dst.sum, src.size, dst.size)
            self.status.replaced += 1
            self.delete(dst)
            self.add(src)
        elif os.path.lexists(dst_path):
            changed = False
            if src.kind in (b'F', b'D') and src.mode != dst.mode:
                self.log.DEBUG('MODE', dst.name, src.mode, dst.mode)
                # mode is carried as an octal string
                os.chmod(dst_path, int(src.mode, 8))
                changed = True
            if (src.kind in (b'F', b'D', b'L', b'S') and
                (src.uid != dst.uid or src.gid != dst.gid)):
                self.log.DEBUG('UID/GID', dst.name, src.uid, src.gid,
                               dst.uid, dst.gid)
                os.lchown(dst_path, int(src.uid), int(src.gid))
                changed = True
            if src.kind in (b'F', b'L') and src.mtime != dst.mtime:
                self.log.DEBUG('MTIME', src.name, src.mtime, dst.mtime)
                # Preserve the current atime; only mtime is synced.
                atime = os.lstat(dst_path).st_atime
                os.utime(dst_path, (int(atime), int(src.mtime)),
                         follow_symlinks=False)
                changed = True
            if changed:
                self.status.metadata += 1
            else:
                self.status.unchanged += 1

    def make_room(self, size):
        """Purge the oldest trash entries until *size* bytes fit.

        One filesystem block of slack is added to *size*, and "fit" is
        measured against the space available on the destination
        filesystem. The current run's own trash directory may be purged
        too.

        Raises:
            Exception: if space is still short once the trash is empty.
        """
        with self._trash_lock:
            while True:
                stat = os.statvfs(self.dst_root)
                free = stat.f_frsize * stat.f_bavail
                need = size + stat.f_frsize
                if free > need:
                    return
                self.log.MESSAGE("Need to free:",
                                 need - free, (need, free), self.trash_root)
                try:
                    entries = os.listdir(self.trash_root)
                except FileNotFoundError:
                    entries = []
                if not entries:
                    raise Exception('Cannot free %d bytes: trash %r is empty'
                                    % (need - free, self.trash_root))
                # Trash entries are named by their creation time, so the
                # smallest name is the oldest (assuming equal digit counts).
                d = os.path.join(self.trash_root, min(entries))
                if os.path.isdir(d):
                    self.log.MESSAGE('Removing dir', d)
                    shutil.rmtree(d)
                else:
                    self.log.MESSAGE('Removing file', d)
                    os.unlink(d)

    def add(self, src):
        """Request *src* from the primary.

        Each ancestor directory is requested first, deepest first, so that
        directories are extracted with the correct modes.
        """
        self.log.DEBUG('Add:', src.name)
        parent = os.path.dirname(src.name)
        # Stop at b'' (relative names) or at a dirname() fixed point such
        # as b'/' (absolute names), which would otherwise loop forever.
        while parent and parent != os.path.dirname(parent):
            # Make sure directories get the correct modes
            self.primary_in.write(parent + b'\0')
            parent = os.path.dirname(parent)
        self.primary_in.write(src.name + b'\0')
        self.primary_in.flush()

    def delete(self, dst):
        """Move the local entry *dst* into this run's trash directory.

        A missing local path is a no-op.
        """
        self.log.DEBUG('Delete:', dst.name)
        dst_path = os.path.join(self.dst_root, dst.name)
        if not os.path.lexists(dst_path):
            return
        trash_path = os.path.join(self.trash, dst.name)
        trash_dir = os.path.dirname(trash_path)
        # Hold the lock so the extractor's make_room() cannot purge
        # trash_dir between creating it and renaming into it.
        with self._trash_lock:
            if not os.path.exists(trash_dir):
                self.make_room(256*1024) # Hack to make sure there is place
                os.makedirs(trash_dir, mode=0o700)
            os.rename(dst_path, trash_path)


def do_backup(hash_name, options, socket_path, mount, path):
    """Run one secondary-side backup pass of ``mount/path`` against the primary.

    Steps:
      1. Send ``mount/TOTALBACKUP.config`` over a first connection to
         *socket_path* and read the primary's hash TOC (``src``) back.
      2. Build the local hash TOC (``dst``) with ``/usr/bin/hashtoc``.
      3. Open a second connection for the tar transfer.
      4. Merge-walk ``src`` and ``dst`` by name, adding, deleting or
         checking each entry through ``Backup``.

    Args:
        hash_name: hash algorithm name; passed to hashtoc as
            ``--<hash_name>`` and renamed to the ``sum`` TOC field.
        options: parsed options; reads ``debug``, ``xattr``, ``max_age``,
            ``jobs`` and ``lookahead``.
        socket_path: UNIX socket of the primary; unlinked at exit.
        mount: backup mount point (holds the config file and ``TRASH``).
        path: tree to back up, relative to *mount*.

    Raises:
        Exception: if the config file does not exist.
    """
    if options.debug:
        log = loghandler.LOG(loghandler.LOG_DEBUG)
    else:
        log = loghandler.LOG(loghandler.LOG_WARNING)
    atexit.register(cond_unlink, socket_path, log)
    status = Status(log)

    config_path = '%s/TOTALBACKUP.config' % (mount)
    if not os.path.exists(config_path):
        raise Exception('"%s" does not exists' % (config_path))

    # Connect to server config/hashtoc socket
    config_hash = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    config_hash.connect(socket_path)
    # Send secondary config to primary.  Close (and thereby flush) the text
    # writer *before* the half-close so no buffered data is left unsent;
    # closing a makefile() object does not close the socket itself.
    with open(config_path) as config_file:
        config = config_file.read()
    with config_hash.makefile('w') as config_out:
        config_out.write(config)
    config_hash.shutdown(socket.SHUT_WR)
    # Make ready to read primary TOC (src)
    src = hashtoc.HashTOC(config_hash.makefile('rb'), rename={hash_name: 'sum'})

    # Create secondary hashtoc (dst) of the local tree
    cmd = ( command.Command('/usr/bin/hashtoc')
            .flag('--%s' % hash_name)
            .flag('--zero-terminated')
            .flag('--xattr', options.xattr)
            .option('--max-age', options.max_age)
            .option('--jobs', options.jobs)
            .option('--lookahead', options.lookahead)
            .arg('.') )
    p = subprocess.Popen(cmd,
                         cwd=os.path.join(mount, path),
                         stdout=subprocess.PIPE)
    atexit.register(cond_kill, p)
    dst = hashtoc.HashTOC(p.stdout, rename={hash_name: 'sum'})

    # Connect to server tar socket
    primary_tar = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    primary_tar.connect(socket_path)

    backup = Backup(primary_tar=primary_tar,
                    mount=mount, path=path, status=status, log=log)
    # Merge-walk; assumes both TOCs yield names in the same ascending order.
    while True:
        if src.name is None and dst.name is None:
            # All done
            break
        status.checked += 1
        if src.name is None:
            # Primary exhausted: remaining local entries are stale.
            status.deleted += 1
            backup.delete(dst)
            dst.next()
        elif dst.name is None:
            # Local side exhausted: remaining primary entries are new.
            status.added += 1
            backup.add(src)
            src.next()
        elif src.name == dst.name:
            backup.check(src=src, dst=dst)
            src.next()
            dst.next()
        elif src.name < dst.name:
            # Only on the primary.
            status.added += 1
            backup.add(src)
            src.next()
        elif src.name > dst.name:
            # Only on the local side.
            status.deleted += 1
            backup.delete(dst)
            dst.next()
        else:
            raise Exception('Cannot order names %r and %r'
                            % (src.name, dst.name))

    backup.close()
    # Release the tar socket; closing an already-closed socket is harmless.
    primary_tar.close()

    log.DEBUG('hashtoc result', p.wait())
    config_hash.shutdown(socket.SHUT_RD)
    config_hash.close()