#!/usr/bin/env python3

import json
import time
from subprocess import check_output, check_call, STDOUT


def getEtcdStatus():
    """Return the juju status of the etcd application as parsed JSON.

    Runs ``juju status --format json etcd`` through the shell, decodes the
    captured output and parses it with ``json.loads``.
    """
    raw = check_output('juju status --format json etcd', shell=True)
    return json.loads(raw.decode())


def getEtcdUnits():
    """Return the ``units`` entry of the etcd application from juju status."""
    etcdApp = getEtcdStatus()['applications']['etcd']
    return etcdApp['units']


def getMajorVersion():
    """Return the etcd application's major version as an integer.

    The major version is the part of the ``version`` string reported by
    juju status that precedes the first ``.``.
    """
    version = getEtcdStatus()['applications']['etcd']['version']
    major, _, _ = version.partition('.')
    return int(major)


def waitActiveIdle(pollInterval=1):
    """Block until every etcd unit is both active and idle.

    Repeatedly fetches the etcd units from juju status and checks that each
    unit's ``workload-status`` is ``active`` and its ``juju-status`` is
    ``idle``. A progress dot is printed after every poll that finds at least
    one unit not yet settled. An empty set of units counts as settled.

    Args:
        pollInterval: Seconds to sleep between polls, so the loop does not
            issue back-to-back ``juju status`` calls while waiting.
    """
    print("· Waiting for etcd units to be active and idle...", end='', flush=True)
    while True:
        settled = all(
            unit['workload-status']['current'] == 'active'
            and unit['juju-status']['current'] == 'idle'
            for unit in getEtcdUnits().values()
        )
        if settled:
            break
        print(".", end='', flush=True)
        time.sleep(pollInterval)
    print()


def waitEtcdConvergence(pollInterval=1):
    """Block until all etcd units report values that differ by at most 1.

    On every poll, ``etcdctl ... -w table endpoint status`` is run on each
    etcd unit through ``juju run``. For each unit, the fourth line of the
    table output is split on ``|`` and field index 6 is read as an integer
    (presumably a raft counter -- TODO confirm against the table layout of
    the etcdctl version in use). The cluster is considered converged once
    the largest difference between any two units' values is <= 1. A progress
    dot is printed after each poll that has not yet converged.

    Args:
        pollInterval: Seconds to sleep between polls, so the loop does not
            issue back-to-back ``juju run`` calls while waiting.

    Raises:
        ValueError: If juju returns no results, or a field is not an integer.
    """
    print("· Waiting for etcd cluster convergence...", end='', flush=True)
    cmd = (
        "juju run --format json --application etcd "
        "'ETCDCTL_API=3 /snap/bin/etcd.etcdctl "
        "--endpoints http://127.0.0.1:4001 -w table endpoint status'"
    )
    while True:
        results = json.loads(check_output(cmd, shell=True).decode())
        values = [
            int(result['Stdout'].split('\n')[3].split('|')[6].strip())
            for result in results
        ]
        # The largest pairwise difference is simply max - min.
        spread = max(values) - min(values)
        if spread <= 1:
            break
        print('.', end='', flush=True)
        time.sleep(pollInterval)
    print()


def main():
    """Migrate a juju-deployed etcd cluster and its Kubernetes masters to etcd3.

    Steps, in order:

    1. Wait for the etcd units to be active and idle.
    2. If the configured etcd channel's major version is below 3, set the
       channel to ``3.0/stable`` and wait until juju status reports a major
       version of at least 3.
    3. Wait for the units to settle and for the cluster to converge.
    4. Stop the Kubernetes api servers, wait for convergence again, then
       stop etcd.
    5. Run ``etcdctl migrate`` on every etcd unit.
    6. Start etcd, set ``storage-backend=etcd3`` on kubernetes-master and
       wait until every master reports a running kube-apiserver process.

    Any failing juju command raises ``subprocess.CalledProcessError``.
    """
    # Seconds to sleep between polls so the wait loops do not issue
    # back-to-back juju calls.
    pollInterval = 1

    waitActiveIdle()

    print("· Acquiring configured version of etcd...")
    config = check_output('juju config --format json etcd', shell=True).decode()
    config = json.loads(config)
    configuredVersion = config['settings']['channel']['value']

    # Kick off bin upgrade if we need to.
    # NOTE(review): this assumes the channel starts with a numeric major
    # version (e.g. "2.3/stable"); a non-numeric channel would raise
    # ValueError here -- confirm which channel values are possible.
    if int(configuredVersion.split('.')[0]) < 3:
        print("· Upgrading etcd to version 3.0/stable from %s." % configuredVersion)
        check_call('juju config etcd channel=3.0/stable', shell=True)

    # Wait on the bin upgrade if it seems to still be in progress.
    if getMajorVersion() < 3:
        print("· Waiting for etcd upgrade to 3.0/stable...", end='', flush=True)
        while getMajorVersion() < 3:
            print(".", end='', flush=True)
            time.sleep(pollInterval)
        print()

    # Wait for things to settle down.
    waitActiveIdle()
    waitEtcdConvergence()

    # Shut down the k8s api server.
    # check_output with stderr merged is used so the command's output is
    # captured instead of being echoed to the terminal.
    print("· Stopping all Kubernetes api servers...")
    check_output("juju run --application kubernetes-master 'service snap.kube-apiserver.daemon stop'", shell=True, stderr=STDOUT)

    # Wait for the etcd cluster to converge.
    waitEtcdConvergence()

    # Stop the etcd cluster.
    print("· Stopping etcd...")
    check_output("juju run --application etcd 'service snap.etcd.etcd stop'", shell=True, stderr=STDOUT)

    # Migrate each etcd unit.
    # The etcdctl path matches the one waitEtcdConvergence already invoked
    # successfully on these units (/snap/bin/etcd.etcdctl).
    migrateCmd = "juju run --unit etcd/%s 'ETCDCTL_API=3 /snap/bin/etcd.etcdctl migrate --data-dir=/var/snap/etcd/current/etcd%s.etcd'"
    for unit in getEtcdUnits():
        print("· Migrating %s..." % unit)
        # Unit names look like "etcd/<id>"; the id selects both the unit
        # and its data directory.
        unitid = unit.split('/')[1]
        check_output(migrateCmd % (unitid, unitid), shell=True, stderr=STDOUT)

    # Restart the etcd cluster.
    print("· Starting etcd...")
    check_output("juju run --application etcd 'service snap.etcd.etcd start'", shell=True, stderr=STDOUT)

    # Configure storage backend to etcd3.
    print("· Configuring storage-backend to etcd3...")
    check_call("juju config kubernetes-master storage-backend=etcd3", shell=True)

    # Wait for k8s api servers to come back online.
    print("· Waiting for all Kubernetes api servers to restart...", end='', flush=True)
    cmd = 'juju run --format json --application kubernetes-master "ps aux | grep apiserver | grep -v grep"'
    while True:
        ps = check_output(cmd, shell=True, stderr=STDOUT)
        ps = json.loads(ps.decode())
        # A unit whose grep matched nothing may come back without a
        # "Stdout" key (presumably juju omits empty output -- confirm);
        # treat that as "not running yet" instead of crashing.
        if all('kube-apiserver' in s.get('Stdout', '') for s in ps):
            break
        print(".", end='', flush=True)
        time.sleep(pollInterval)
    print()

    # All done!
    print("· Done.")


# Run the migration only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()