# ZFS snapshot server

import commands
import errno
import os
import shelve
import socket
import SocketServer
import stat
import sys
import threading
import time
import traceback
from popen2 import popen3

import freebsd

# (host, port) the public TCP server (TCPhandler) binds to.
ZSERVER = ('gohan18.freebsd.org', 8888)
# Path of the local UNIX-domain socket served by UNIXhandler, which also
# accepts the privileged REGISTER/UNREGISTER commands.
ZFSLOCAL = '/tmp/.zserver'

# Only serve snapshots from a subtree for security
SNAPPREFIX = 'a/snap/'

# shelve database of registered (exported) filesystems, keyed by full dataset
# name; the stored values are not read by the handlers in this file.
storepath = '/usr/local/etc/zfs/snaps'
# Opened by main(); module-global so the command handlers can reach it.
store = None

def issafepath(s):
    """Return True if every character of *s* is alphanumeric, '_', '-' or '/'.

    Used to keep client-supplied names free of shell metacharacters before
    they are interpolated into zfs command lines. An empty string counts as
    safe.
    """
    return all(ch.isalnum() or ch in '_-/' for ch in s)

def isprivileged(sock):
    """Return True if freebsd.getpeerid(sock) reports 0 as its first field.

    Presumably that field is the connected peer's uid, i.e. the peer is root.
    """
    peer_ids = freebsd.getpeerid(sock)
    return peer_ids[0] == 0

def zfs_getallfs():
    """Return the raw output of ``zfs list -Ht filesystem``.

    Raises:
        OSError: if the zfs command exits with a non-zero status. The status
        and command output are included in the message.
    """
    (err, out) = commands.getstatusoutput("zfs list -Ht filesystem")
    if err:
        # Construct the exception explicitly: raising a tuple in Python 2
        # raises only its first element, so the status would be lost.
        raise OSError("zfs list failed (status %d): %s" % (err, out))
    return out

def zfs_getfs(fs):
    """Return the output of ``zfs list -Ht filesystem <fs>``.

    Returns None if *fs* contains unsafe characters or the dataset does not
    exist.

    Raises:
        OSError: if zfs fails for any other reason. The status and command
        output are included in the message.
    """
    if not issafepath(fs):
        return None

    (err, out) = commands.getstatusoutput("zfs list -Ht filesystem '%s'" % fs)
    if err:
        if "dataset does not exist" in out:
            return None
        print("err = %s, out = %s" % (err, out))
        # Construct the exception explicitly: raising a tuple in Python 2
        # raises only its first element, so the status would be lost.
        raise OSError("zfs list %s failed (status %d): %s" % (fs, err, out))
    return out

def zfs_getsnaps(fs):
    """Return the snapshots of filesystem *fs*.

    Returns:
        A tuple of rows, one per snapshot named ``<fs>@...``. Each row is the
        tuple of tab-separated fields from ``zfs list -Ht snapshot``. Returns
        None if *fs* contains unsafe characters or has no snapshots.

    Raises:
        OSError: if ``zfs list`` itself fails.
    """
    if not issafepath(fs):
        return None

    # Filter in Python rather than piping through grep. grep exits with
    # status 1 when nothing matches, which made "no snapshots" look like a
    # command failure and raised instead of returning None.
    (err, out) = commands.getstatusoutput("zfs list -Ht snapshot")
    if err:
        print("err = %s, out = %s" % (err, out))
        raise OSError("zfs list failed (status %d): %s" % (err, out))

    prefix = fs + '@'
    snaps = tuple(tuple(line.split('\t'))
                  for line in out.split('\n')
                  if line.startswith(prefix))
    return snaps or None

def zfs_validate(fs):
    """ Make sure that the filesystem exists; return list of snapshots otherwise delete from store.

    Returns:
    None if fs is unsafe, lies outside SNAPPREFIX, or does not exist
         (in the latter two cases fs is also removed from the store)
    [] if fs exists but has no snapshots
    List of snapshot rows (tuples of the tab-separated `zfs list -H` fields)

    Raises:
    OSError if a zfs command fails for an unexpected reason.
"""
    if not issafepath(fs):
        return None

    # Anything outside the public subtree, or no longer present in ZFS, is
    # dropped from the store and never served. The SNAPPREFIX test comes
    # first, so zfs is not consulted for out-of-tree names.
    if not fs.startswith(SNAPPREFIX) or not zfs_getfs(fs):
        try:
            del store[fs]
        except KeyError:
            pass  # not registered; nothing to forget
        return None

    return list(zfs_getsnaps(fs) or ())

def do_list(sock, wfile, arg):
    """LIST: write a "200" header, then one line per snapshot of every
    registered filesystem.

    Each line is ``<fs relative to SNAPPREFIX> <snapshot name> <used> <refer>``.
    Filesystems for which zfs_validate() yields nothing are omitted.
    ``sock`` and ``arg`` are unused.
    """
    wfile.write("200 Filesystem list\n")

    # Walk an explicit copy of the keys: zfs_validate() may delete entries
    # from the store while we iterate. The stored values are never needed,
    # so don't unpickle them.
    for fs in list(store.keys()):
        for (snap, used, _, refer, _) in zfs_validate(fs) or ():
            wfile.write("%s %s %s %s\n" % (fs[len(SNAPPREFIX):], snap.split('@')[1], used, refer))

def do_diff(sock, wfile, arg):
    """DIFF <fs> <from> <to>: stream an incremental ZFS send between two
    snapshots of a registered filesystem.

    ``fs`` is relative to SNAPPREFIX. On success, writes a "200" line followed
    by the gzip-compressed output of ``zfs send -i``. On failure, writes or
    returns a "400" line (returned strings are written by the caller).
    """
    if len(arg) < 3 or not (issafepath(arg[0]) and issafepath(arg[1])
                            and issafepath(arg[2])):
        wfile.write("400 argument error\n")
        return

    path = SNAPPREFIX + arg[0]
    snap1 = path + '@' + arg[1]
    snap2 = path + '@' + arg[2]

    if path not in store:
        return "400 filesystem %s not found\n" % arg[0]

    # zfs_validate() may return None (e.g. when the filesystem has gone away).
    # Treat that as "no snapshots" instead of crashing while iterating.
    snaps = [i[0] for i in (zfs_validate(path) or ())]

    if snap1 not in snaps:
        return "400 snapshot %s does not exist\n" % arg[1]

    if snap2 not in snaps:
        return "400 snapshot %s does not exist\n" % arg[2]

    wfile.write("200 Here it comes!\n")

    # Names are limited by issafepath() to alphanumerics, '_', '-' and '/';
    # the quotes are extra protection against shell interpretation. Start the
    # pipeline outside the try so the cleanup below always has pipes to close.
    (stdout, stdin, stderr) = popen3("zfs send -i '%s' '%s' | gzip -c9" % (snap1, snap2))
    try:
        while True:
            buf = stdout.read(32*1024)
            if not buf:
                break
            wfile.write(buf)
    except (EnvironmentError, socket.error):
        # Best effort: most likely the client disconnected mid-stream.
        pass
    finally:
        stdout.close()
        stdin.close()
        stderr.close()

def do_get(sock, wfile, arg):
    """GET <fs> <snap>: stream a full ZFS send of one snapshot of a registered
    filesystem.

    ``fs`` is relative to SNAPPREFIX. On success, writes a "200" line followed
    by the gzip-compressed output of ``zfs send``. On failure, writes or
    returns a "400" line (returned strings are written by the caller).
    """
    if len(arg) < 2 or not issafepath(arg[0]) or not issafepath(arg[1]):
        wfile.write("400 argument error\n")
        return

    path = SNAPPREFIX + arg[0]
    snap = path + "@" + arg[1]

    if path not in store:
        return "400 filesystem %s not found\n" % arg[0]

    # zfs_validate() may return None (e.g. when the filesystem has gone away).
    # Treat that as "no snapshots" instead of crashing while iterating.
    snaps = [i[0] for i in (zfs_validate(path) or ())]
    if snap not in snaps:
        return "400 snapshot %s does not exist\n" % arg[1]

    wfile.write("200 Here it comes\n")

    # Names are limited by issafepath() to alphanumerics, '_', '-' and '/';
    # the quotes are extra protection against shell interpretation. Start the
    # pipeline outside the try so the cleanup below always has pipes to close.
    (stdout, stdin, stderr) = popen3("zfs send '%s' | gzip -c9" % snap)
    try:
        while True:
            buf = stdout.read(32*1024)
            if not buf:
                break
            wfile.write(buf)
    except (EnvironmentError, socket.error):
        # Best effort: most likely the client disconnected mid-stream.
        pass
    finally:
        stdout.close()
        stdin.close()
        stderr.close()

def do_reg(sock, wfile, arg):
    """REGISTER <fs>: export a filesystem under SNAPPREFIX (privileged).

    ``fs`` is the full dataset name. It must lie strictly below SNAPPREFIX,
    must not already be registered, and must exist in ZFS. Replies with a
    "200" line on success, otherwise a "400" line.
    """
    if not isprivileged(sock):
        return "400 access denied\n"

    if not arg or not issafepath(arg[0]):
        wfile.write("400 argument error\n")
        return

    path = arg[0]

    if not path.startswith(SNAPPREFIX) or path == SNAPPREFIX:
        return "400 not in public snapshot tree\n"

    if path in store:
        return "400 filesystem already registered\n"

    if zfs_getfs(path):
        wfile.write("200 Exporting %s\n" % path)
        # Only key membership matters to the handlers here; the value is a
        # placeholder. Sync so the registration survives a crash.
        store[path] = []
        store.sync()
    else:
        wfile.write("400 no such filesystem\n")

def do_unreg(sock, wfile, arg):
    """UNREGISTER <fs>: stop exporting a filesystem (privileged).

    ``fs`` is the full dataset name as stored. No SNAPPREFIX check is made,
    so stray entries can also be removed. Replies with a "200" line on
    success, otherwise a "400" line.
    """
    if not isprivileged(sock):
        return "400 access denied\n"

    if not arg or not issafepath(arg[0]):
        wfile.write("400 argument error\n")
        return

    path = arg[0]

    try:
        del store[path]
    except KeyError:
        wfile.write("400 no such filesystem\n")
        return

    # Persist the removal right away, as do_reg() does for additions.
    store.sync()
    wfile.write("200 Unexporting %s\n" % path)

class Handler(SocketServer.StreamRequestHandler):
    """Read one command line from the client and dispatch it.

    Subclasses provide ``cmddict``, which maps a command word to a function
    ``f(sock, wfile, args)``. A command may write its reply to ``wfile``
    directly, or return a string, which is then written here.
    """

    def handle(self):
        words = self.rfile.readline().split()

        # Resolve the command separately from running it. Otherwise a KeyError
        # or IndexError raised inside a command is misreported as an unknown
        # command, possibly after the command has already sent a 200 line.
        command = self.cmddict.get(words[0]) if words else None
        if command is None:
            self.wfile.write("300 No such command\n")
            return

        try:
            res = command(self.request, self.wfile, words[1:])
        except IndexError:
            # Backstop for commands that index their argument list without
            # checking its length first.
            res = "400 argument error\n"
        if res:
            self.wfile.write(res)

class TCPhandler(Handler):
    """Command table for TCP clients. REGISTER/UNREGISTER are not offered."""

    cmddict = dict(
        LIST=do_list,   # List available sets
        GET=do_get,     # Receive a snapshot
        DIFF=do_diff,   # Receive a diff between two snapshots
    )

class UNIXhandler(Handler):
    """Command table for local UNIX-socket clients.

    Same commands as TCPhandler, plus the privileged REGISTER/UNREGISTER.
    """

    cmddict = dict(
        LIST=do_list,          # List available sets
        REGISTER=do_reg,       # Register a new snapshot (privileged)
        UNREGISTER=do_unreg,   # Unregister a snapshot (privileged)
        GET=do_get,            # Receive a snapshot
        DIFF=do_diff,          # Receive a diff between two snapshots
    )

class unixworker(threading.Thread):
    """Thread that serves UNIXhandler on the UNIX-domain socket at *sock*."""

    sock = None

    def __init__(self, sock):
        super(unixworker, self).__init__()
        self.sock = sock
        # Don't let this thread keep the process alive once the main (TCP)
        # server has stopped.
        self.setDaemon(True)

        # Remove a stale socket left by a previous run so bind() can succeed.
        # Only real sockets are removed. A bit-mask test against S_IFSOCK also
        # matches regular files and directories, and stat() would follow a
        # symlink, hence lstat() + S_ISSOCK.
        try:
            if stat.S_ISSOCK(os.lstat(self.sock).st_mode):
                os.unlink(self.sock)
        except OSError:
            if sys.exc_info()[1].errno != errno.ENOENT:
                raise

    def run(self):
        print("Starting UNIX socket server on %s" % self.sock)
        server = SocketServer.ThreadingUnixStreamServer(self.sock, UNIXhandler)
        # World read/write (0666): any local user may connect. Privileged
        # commands check the peer's credentials per request.
        os.chmod(self.sock, stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP |
                 stat.S_IWGRP | stat.S_IROTH | stat.S_IWOTH)
        server.serve_forever()

def main():
    """Run the snapshot server.

    Opens the registration store, starts the UNIX-socket control server in a
    background thread, then serves TCP clients on ZSERVER. If the TCP server
    fails unexpectedly, it is recreated after one second.
    """
    global store

    store = shelve.open(storepath, flag='c')

    uworker = unixworker(ZFSLOCAL)
    uworker.start()

    print("Starting server on %s port %d" % ZSERVER)
    while True:
        server = None
        try:
            server = SocketServer.ThreadingTCPServer(ZSERVER, TCPhandler)
            server.serve_forever()
        except KeyboardInterrupt:
            raise SystemExit
        except Exception:
            # SystemExit is not an Exception subclass, so it propagates.
            traceback.print_exc()
            time.sleep(1)
            print("Retrying...")
        finally:
            # Release the listening socket. Otherwise the failed server still
            # holds the port while the retry tries to bind it again.
            if server is not None:
                server.server_close()


if __name__ == "__main__":
    main()

