#!/usr/bin/env python3

# Copyright (c) 2016, Antonio SJ Musumeci <trapexit@spawn.link>
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import argparse
import ctypes
import errno
import fnmatch
import io
import os
import shlex
import stat
import subprocess
import sys


# Bind lgetxattr(2) from glibc. use_errno=True makes ctypes keep a private
# copy of errno that ctypes.get_errno() can read after each call.
_libc = ctypes.CDLL("libc.so.6",use_errno=True)
_lgetxattr = _libc.lgetxattr
_lgetxattr.argtypes = [ctypes.c_char_p,ctypes.c_char_p,ctypes.c_void_p,ctypes.c_size_t]
# lgetxattr(2) returns ssize_t; without an explicit restype ctypes would
# interpret the result as a C int.
_lgetxattr.restype = ctypes.c_ssize_t


def lgetxattr(path,name):
    """Read the extended attribute *name* of *path* without following symlinks.

    path -- str or bytes. A str is converted with os.fsencode() so that
            names containing undecodable bytes (surrogate-escaped by
            os.walk() and friends) map back to the exact on-disk bytes.
    name -- str or bytes attribute name.

    Returns the attribute value as bytes, or None when the call fails with
    ENODATA. Raises IOError(errno, message, path) for any other failure.
    """
    if isinstance(path,str):
        path = os.fsencode(path)
    if isinstance(name,str):
        name = name.encode(errors='backslashreplace')
    length = 64
    while True:
        buf = ctypes.create_string_buffer(length)
        res = _lgetxattr(path,name,buf,ctypes.c_size_t(length))
        if res >= 0:
            return buf.raw[0:res]
        err = ctypes.get_errno()
        if err == errno.ERANGE:
            # Buffer too small for the value: retry with twice the space.
            length *= 2
        elif err == errno.ENODATA:
            return None
        else:
            raise IOError(err,os.strerror(err),path)


def xattr_relpath(fullpath):
    """Return the 'user.mergerfs.relpath' attribute of *fullpath* as a str.

    The value is decoded with os.fsdecode() so that a path containing
    non-UTF-8 bytes survives the round trip when it is later passed to a
    subprocess as an argument.

    Raises IOError(ENODATA) when the attribute is absent, and whatever
    lgetxattr() raises for other failures.
    """
    raw = lgetxattr(fullpath,'user.mergerfs.relpath')
    if raw is None:
        raise IOError(errno.ENODATA,os.strerror(errno.ENODATA),fullpath)
    return os.fsdecode(raw)


def xattr_basepath(fullpath):
    """Return the 'user.mergerfs.basepath' attribute of *fullpath* as a str.

    The value is decoded with os.fsdecode() so that a path containing
    non-UTF-8 bytes survives the round trip when it is later passed to a
    subprocess as an argument.

    Raises IOError(ENODATA) when the attribute is absent, and whatever
    lgetxattr() raises for other failures.
    """
    raw = lgetxattr(fullpath,'user.mergerfs.basepath')
    if raw is None:
        raise IOError(errno.ENODATA,os.strerror(errno.ENODATA),fullpath)
    return os.fsdecode(raw)


def ismergerfs(path):
    """Return True when *path* carries the 'user.mergerfs.version' attribute.

    path -- file to probe, or None (no candidate file), which yields False.

    lgetxattr() returns None rather than raising when the attribute is
    simply absent, so the result has to be inspected: merely surviving the
    call does not prove the attribute exists.
    """
    if path is None:
        return False
    try:
        return lgetxattr(path,'user.mergerfs.version') is not None
    except IOError:
        return False


def mergerfs_control_file(basedir):
    """Locate the nearest '.mergerfs' entry at or above *basedir*.

    Starting at *basedir* and moving up one directory at a time, return the
    path of the first existing '<dir>/.mergerfs'. The root directory '/'
    itself is never probed. Returns None when nothing is found.

    The walk also stops when os.path.dirname() no longer shortens the path
    (as happens for '' or '//'), so relative or unusual inputs terminate
    instead of recursing without end.
    """
    current = basedir
    while current != '/':
        ctrlfile = os.path.join(current,'.mergerfs')
        if os.path.exists(ctrlfile):
            return ctrlfile
        parent = os.path.dirname(current)
        if parent == current:
            # No further ancestor to try.
            return None
        current = parent
    return None


def mergerfs_srcmounts(ctrlfile):
    """Return the ':'-separated 'user.mergerfs.srcmounts' value as a list.

    ctrlfile -- path of the file whose attribute is read.

    The value is decoded with os.fsdecode() so each entry can be handed
    straight back to os.* calls even if it contains non-UTF-8 bytes.

    Raises IOError(ENODATA) when the attribute is absent, and whatever
    lgetxattr() raises for other failures.
    """
    raw = lgetxattr(ctrlfile,'user.mergerfs.srcmounts')
    if raw is None:
        raise IOError(errno.ENODATA,os.strerror(errno.ENODATA),ctrlfile)
    return os.fsdecode(raw).split(':')


def match(filename,matches):
    """Return True if *filename* satisfies any fnmatch pattern in *matches*.

    An empty *matches* never matches, so the result is False.
    """
    return any(fnmatch.fnmatch(filename,pattern) for pattern in matches)


def execute_cmd(args):
    """Run the command given as the argument list *args* and wait for it.

    Returns the exit status reported by subprocess.call().
    """
    exit_status = subprocess.call(args)
    return exit_status


def print_args(args):
    """Print the argument list *args* as one shell-quoted command line."""
    print(' '.join(map(shlex.quote,args)))


def human_to_bytes(s):
    """Convert a size such as '512', '64K', '16G' or '2T' to a byte count.

    s -- decimal integer, optionally followed by one of the binary suffixes
         K, M, G or T (powers of 1024). The suffix may be upper or lower
         case; surrounding whitespace is ignored.

    Returns the size as an int. Raises ValueError for empty or malformed
    input; argparse reports a ValueError from a type= callable as a normal
    usage error, which an IndexError on '' would not be.
    """
    text = s.strip()
    if not text:
        raise ValueError('empty size')

    multipliers = {'K': 1024,
                   'M': 1024 ** 2,
                   'G': 1024 ** 3,
                   'T': 1024 ** 4}
    factor = multipliers.get(text[-1].upper())
    if factor is None:
        # No recognised suffix: the whole string must be a plain integer.
        return int(text)
    return int(text[:-1]) * factor


def get_stats(branches):
    """Map each path in *branches* to f_bavail * f_frsize from os.statvfs()."""
    def available_bytes(path):
        info = os.statvfs(path)
        return info.f_bavail * info.f_frsize

    return {branch: available_bytes(branch) for branch in branches}


def build_move_file(src,tgt,rel):
    """Build the rsync argument list that moves *rel* from *src* to *tgt*.

    src -- directory the file currently lives under
    tgt -- directory to move it to
    rel -- path of the file relative to *src*; leading and trailing
           slashes are ignored

    The source is written as '<src>/./<rel>' and combined with --relative;
    see the rsync manual for how the './' marker is interpreted.
    --remove-source-files makes the transfer a move rather than a copy.
    """
    relative = rel.strip('/')
    source = os.path.join(src,'./',relative)
    destination = tgt.rstrip('/') + '/'
    options = ['-avHAXWE',
               '--numeric-ids',
               '--progress',
               '--relative',
               '--remove-source-files']
    return ['rsync'] + options + [source,destination]


def print_help():
    """Write the mergerfs.consolidate usage text to standard output."""
    lines = (
        '',
        'usage: mergerfs.consolidate [<options>] <dir>',
        '',
        'Consolidate files in a single mergerfs directory onto a single drive.',
        '',
        'positional arguments:',
        '  dir                    starting directory',
        '',
        'optional arguments:',
        '  -m, --max-files=       Skip directories with more than N files.',
        '                         (default: 256)',
        '  -M, --max-size=        Skip directories with files adding up to more',
        '                         than N. (default: 16G)',
        '  -I, --include-path=    fnmatch compatible path include filter.',
        '                         Can be used multiple times.',
        '  -E, --exclude-path=    fnmatch compatible path exclude filter.',
        '                         Can be used multiple times.',
        '  -e, --execute          Execute `rsync` commands as well as print them.',
        '  -h, --help             Print this help.',
        '',
    )
    # The empty first and last entries give the text its leading and
    # trailing newline; print() then appends one more.
    print('\n'.join(lines))


def buildargparser():
    """Create the argparse parser for the command line.

    argparse's built-in help is disabled (add_help=False); -h/--help is
    registered as an ordinary store_true flag so the caller can print its
    own usage text.
    """
    # (flags, keyword arguments) for each add_argument() call, in the
    # order they are registered with the parser.
    arguments = (
        (('dir',),
         dict(type=str,nargs='?',default=None)),
        (('-m','--max-files'),
         dict(dest='max_files',type=int,default=256)),
        (('-M','--max-size'),
         dict(dest='max_size',type=human_to_bytes,default='16G')),
        (('-I','--include-path'),
         dict(dest='includepath',type=str,action='append',default=[])),
        (('-E','--exclude-path'),
         dict(dest='excludepath',type=str,action='append',default=[])),
        (('-e','--execute'),
         dict(dest='execute',action='store_true')),
        (('-h','--help'),
         dict(action='store_true')),
    )

    parser = argparse.ArgumentParser(add_help=False)
    for flags,options in arguments:
        parser.add_argument(*flags,**options)
    return parser


def main():
    """Entry point: plan (and optionally run) the rsync moves.

    Steps:
      1. Re-wrap stdout/stderr as UTF-8 with backslashreplace so that
         printing unusual file names cannot raise an encoding error.
      2. Parse the command line; print the usage text and exit 0 when
         --help is given or no directory is supplied.
      3. Find the '.mergerfs' control file above the directory and exit 1
         if it does not identify a mergerfs mount.
      4. Walk the directory. A directory is skipped when it holds at most
         one entry in `files`, more than --max-files entries, matches an
         exclude pattern, matches no include pattern, or its regular
         files total --max-size bytes or more.
      5. For every remaining directory pick the branch with the largest
         tracked free space and emit one rsync command per regular file
         that is not already on that branch. Commands are run only with
         --execute.

    Free-space bookkeeping is adjusted after a move is planned (dry run)
    or has succeeded (--execute). A failed rsync is reported on stderr and
    not counted, so later target choices are not based on a move that did
    not happen.

    Always exits through sys.exit(); Ctrl-C and a closed output pipe end
    the walk quietly.
    """
    sys.stdout = io.TextIOWrapper(sys.stdout.buffer,
                                  encoding='utf8',
                                  errors='backslashreplace',
                                  line_buffering=True)
    sys.stderr = io.TextIOWrapper(sys.stderr.buffer,
                                  encoding='utf8',
                                  errors='backslashreplace',
                                  line_buffering=True)

    parser = buildargparser()
    args = parser.parse_args()

    if args.help or not args.dir:
        print_help()
        sys.exit(0)

    args.dir = os.path.realpath(args.dir)

    ctrlfile = mergerfs_control_file(args.dir)
    if not ismergerfs(ctrlfile):
        print("%s is not a mergerfs mount" % args.dir)
        sys.exit(1)

    basedir       = args.dir
    execute       = args.execute
    max_files     = args.max_files
    max_size      = args.max_size
    path_includes = ['*'] if not args.includepath else args.includepath
    path_excludes = args.excludepath
    srcmounts     = mergerfs_srcmounts(ctrlfile)

    mount_stats = get_stats(srcmounts)
    try:
        for (root,_dirs,files) in os.walk(basedir):
            if len(files) <= 1:
                continue
            if len(files) > max_files:
                continue
            if match(root,path_excludes):
                continue
            if not match(root,path_includes):
                continue

            # Only regular files are candidates for moving.
            total_size = 0
            file_stats = {}
            for file in files:
                fullpath = os.path.join(root,file)
                st = os.lstat(fullpath)
                if not stat.S_ISREG(st.st_mode):
                    continue
                total_size += st.st_size
                file_stats[fullpath] = st

            if total_size >= max_size:
                continue

            # Branch with the most tracked free space; on a tie the first
            # branch in srcmounts order wins.
            tgtpath = max(mount_stats,key=mount_stats.get)
            for (fullpath,st) in sorted(file_stats.items()):
                srcpath = xattr_basepath(fullpath)
                if srcpath == tgtpath:
                    continue

                relpath = xattr_relpath(fullpath)

                # Named `cmd` so the parsed command-line namespace `args`
                # is not overwritten inside the loop.
                cmd = build_move_file(srcpath,tgtpath,relpath)

                print_args(cmd)
                if execute:
                    status = execute_cmd(cmd)
                    if status != 0:
                        print("warning: rsync exited with status %d; "
                              "%s was not counted as moved" % (status,fullpath),
                              file=sys.stderr)
                        continue

                # The file's bytes leave srcpath and land on tgtpath.
                mount_stats[srcpath] += st.st_size
                mount_stats[tgtpath] -= st.st_size
    except (KeyboardInterrupt,BrokenPipeError):
        pass

    sys.exit(0)


# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()