#!/usr/bin/env python3
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.

import atexit
import base64
import contextlib
import functools
import glob
import inspect
import json
import ntpath
import optparse
import os
import posixpath
import re
import socket
import subprocess as sp
import sys
import time
import uuid
import traceback
from configparser import ConfigParser

from pbkdf2 import pbkdf2_hex

# Random 32-character hex string generated once per run. It is the default
# salt for hashify() and is also written as the [chttpd_auth] secret by
# hack_local_ini(), so every node configured in one run shares the same value.
COMMON_SALT = uuid.uuid4().hex

try:
    from urllib.request import urlopen
except ImportError:
    from urllib.request import urlopen

try:
    import http.client as httpclient
except ImportError:
    import http.client as httpclient


def toposixpath(path):
    """Return *path* using forward slashes as separators.

    Only platforms whose native separator is the backslash are affected;
    everywhere else the path is returned untouched.
    """
    if os.sep != ntpath.sep:
        return path
    return path.replace(ntpath.sep, posixpath.sep)


def log(msg):
    """Decorator factory that reports the progress of a call on stdout.

    *msg* is a ``str.format`` template rendered with the arguments actually
    passed to the decorated function (positional ones are mapped to their
    parameter names). The line ``[ * ] <msg> ... `` is printed before the
    call and completed with ``ok`` or ``failed: <error>`` afterwards.

    A ``KeyboardInterrupt`` raised by the wrapped function is reported as
    ``ok`` and swallowed (the wrapper then returns ``None``); any other
    exception is reported and re-raised. Nothing is printed while
    ``log.verbose`` is false.
    """

    def emit(text):
        # The switch is read on every write so it can be flipped at runtime.
        if log.verbose:
            sys.stdout.write(text)
            sys.stdout.flush()

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            param_names = inspect.signature(func).parameters
            bound = {name: value for name, value in zip(param_names, args)}
            bound.update(kwargs)
            emit("[ * ] " + msg.format(**bound) + " ... ")
            try:
                result = func(*args, **kwargs)
            except KeyboardInterrupt:
                emit("ok\n")
                return None
            except Exception as err:
                emit("failed: %s\n" % err)
                raise
            emit("ok\n")
            return result

        return wrapper

    return decorator


# Progress output is on by default.
log.verbose = True


def main():
    """Set up and boot the cluster, then run the given command or wait.

    When positional arguments were supplied they are executed as a command
    via ``run_command``; otherwise the script blocks in ``join``.
    """
    ctx = setup()
    startup(ctx)
    command = ctx["cmd"]
    if command:
        run_command(ctx, command)
        return
    lead_port = cluster_port(ctx, 1)
    join(ctx, lead_port, *ctx["admin"])


def setup():
    """Parse the command line and prepare everything needed before booting.

    Returns:
        dict: the context built by ``setup_context`` after all preparation
        steps have run against it.
    """
    opts, args = setup_argparse()
    ctx = setup_context(opts, args)
    # The steps run strictly in this order.
    preparation_steps = (
        setup_logging,
        setup_dirs,
        check_beams,
        check_boot_script,
        setup_configs,
    )
    for step in preparation_steps:
        step(ctx)
    return ctx


def setup_logging(ctx):
    """Switch progress output on or off according to ``ctx["verbose"]``."""
    verbose = ctx["verbose"]
    log.verbose = verbose


def setup_argparse():
    """Parse the process command line.

    Returns:
        tuple: ``(options, positional_args)`` as produced by
        ``optparse.OptionParser.parse_args``.
    """
    return get_args_parser().parse_args()


def get_args_parser():
    """Build the command-line option parser for the dev cluster runner.

    Returns:
        optparse.OptionParser: a parser with every supported option
        registered.
    """
    # (flags, settings) pairs, registered in this order so that the generated
    # --help text keeps its layout. The table is rebuilt on every call, which
    # keeps the list defaults of the "append" options private to one parser.
    option_specs = [
        (
            ("-a", "--admin"),
            {
                "metavar": "USER:PASS",
                "default": None,
                "help": "Add an admin account to the development cluster",
            },
        ),
        (
            ("-n", "--nodes"),
            {
                "metavar": "nodes",
                "default": 3,
                "type": int,
                "help": "Number of development nodes to be spun up",
            },
        ),
        (
            ("-q", "--quiet"),
            {
                "action": "store_false",
                "dest": "verbose",
                "default": True,
                "help": "Don't print anything to STDOUT",
            },
        ),
        (
            ("--with-admin-party-please",),
            {
                "dest": "with_admin_party",
                "default": False,
                "action": "store_true",
                "help": "Runs a dev cluster with admin party mode on",
            },
        ),
        (
            ("--enable-erlang-views",),
            {
                "action": "store_true",
                "help": "Enables the Erlang view server",
            },
        ),
        (
            ("--no-join",),
            {
                "dest": "no_join",
                "default": False,
                "action": "store_true",
                "help": "Do not join nodes on boot",
            },
        ),
        (
            ("--with-haproxy",),
            {
                "dest": "with_haproxy",
                "default": False,
                "action": "store_true",
                "help": "Use HAProxy",
            },
        ),
        (
            ("--haproxy",),
            {
                "dest": "haproxy",
                "default": "haproxy",
                "help": "HAProxy executable path",
            },
        ),
        (
            ("--haproxy-port",),
            {
                "dest": "haproxy_port",
                "default": "5984",
                "help": "HAProxy port",
            },
        ),
        (
            ("--node-number",),
            {
                "dest": "node_number",
                "type": int,
                "default": 1,
                "help": "The node number to seed them when creating the node(s)",
            },
        ),
        (
            ("-c", "--config-overrides"),
            {
                "action": "append",
                "default": [],
                "help": "Optional key=val config overrides. Can be repeated",
            },
        ),
        (
            ("--erlang-config",),
            {
                "dest": "erlang_config",
                "default": "rel/files/sys.config",
                "help": "Specify an alternative Erlang application configuration",
            },
        ),
        (
            ("--degrade-cluster",),
            {
                "dest": "degrade_cluster",
                "type": int,
                "default": 0,
                "help": "The number of nodes that should be stopped after cluster config",
            },
        ),
        (
            ("--no-eval",),
            {
                "action": "store_true",
                "default": False,
                "help": "Do not eval subcommand output",
            },
        ),
        (
            ("--auto-ports",),
            {
                "dest": "auto_ports",
                "default": False,
                "action": "store_true",
                "help": "Select available ports for nodes automatically",
            },
        ),
        (
            ("--extra_args",),
            {
                "dest": "extra_args",
                "default": None,
                "help": "Extra arguments to pass to beam process",
            },
        ),
        (
            ("-l", "--locald-config"),
            {
                "dest": "locald_configs",
                "action": "append",
                "default": [],
                "help": "Path to config to place in 'local.d'. Can be repeated",
            },
        ),
    ]

    parser = optparse.OptionParser(description="Runs CouchDB 2.0 dev cluster")
    for flags, settings in option_specs:
        parser.add_option(*flags, **settings)
    return parser


def setup_context(opts, args):
    """Translate parsed options into the context dict used by the script.

    Args:
        opts: parsed option values (attributes as defined by the option
            parser of this script).
        args: positional arguments; they are joined with spaces into the
            command stored under ``"cmd"``.

    Returns:
        dict: the run context. ``"devdir"`` is the directory containing this
        file and ``"rootdir"`` is its parent. ``"admin"`` is ``None`` or the
        ``USER:PASS`` value split at the first colon.
    """
    script_path = os.path.abspath(__file__)
    dev_dir = os.path.dirname(script_path)
    root_dir = os.path.dirname(dev_dir)

    admin = opts.admin.split(":", 1) if opts.admin else None
    first = opts.node_number
    node_names = ["node%d" % number for number in range(first, first + opts.nodes)]

    ctx = {}
    ctx["N"] = opts.nodes
    ctx["no_join"] = opts.no_join
    ctx["with_admin_party"] = opts.with_admin_party
    ctx["enable_erlang_views"] = opts.enable_erlang_views
    ctx["admin"] = admin
    ctx["nodes"] = node_names
    ctx["node_number"] = opts.node_number
    ctx["degrade_cluster"] = opts.degrade_cluster
    ctx["devdir"] = dev_dir
    ctx["rootdir"] = root_dir
    ctx["cmd"] = " ".join(args)
    ctx["verbose"] = opts.verbose
    ctx["with_haproxy"] = opts.with_haproxy
    ctx["haproxy"] = opts.haproxy
    ctx["haproxy_port"] = opts.haproxy_port
    ctx["config_overrides"] = opts.config_overrides
    ctx["erlang_config"] = opts.erlang_config
    ctx["no_eval"] = opts.no_eval
    ctx["extra_args"] = opts.extra_args
    # Logs are truncated on the first boot; reboot_nodes() turns this off.
    ctx["reset_logs"] = True
    ctx["procs"] = []
    ctx["auto_ports"] = opts.auto_ports
    ctx["locald_configs"] = opts.locald_configs
    return ctx


@log("Setup environment")
def setup_dirs(ctx):
    """Make sure the ``logs`` directory exists inside the dev directory."""
    dev_dir = ctx["devdir"]
    ensure_dir_exists(dev_dir, "logs")


def ensure_dir_exists(root, *segments):
    """Create the directory ``root/segments...`` if needed and return its path.

    Args:
        root: base directory.
        *segments: path components appended to *root*.

    Returns:
        str: the joined path, which is a directory when the call returns.

    Raises:
        FileExistsError: if the path exists but is not a directory.
    """
    path = os.path.join(root, *segments)
    # exist_ok avoids the check-then-create race of testing for existence
    # first, and refuses to silently accept a regular file at this path.
    os.makedirs(path, exist_ok=True)
    return path


@log("Ensure CouchDB is built")
def check_beams(ctx):
    """Compile every ``*.erl`` file of the dev directory with ``erlc``.

    The compiled output is written back into the dev directory. A failing
    compiler run raises ``subprocess.CalledProcessError``.
    """
    dev_dir = ctx["devdir"]
    output_dir = dev_dir + os.sep
    for source in glob.glob(os.path.join(dev_dir, "*.erl")):
        sp.check_call(["erlc", "-o", output_dir, source])


@log("Ensure Erlang boot script exists")
def check_boot_script(ctx):
    """Generate ``devnode.boot`` in the dev directory when it is missing.

    The ``make_boot_script`` escript is run from the dev directory with
    ``ERL_LIBS`` pointing at ``<rootdir>/src``.
    """
    boot_file = os.path.join(ctx["devdir"], "devnode.boot")
    if os.path.exists(boot_file):
        return
    env = os.environ.copy()
    env["ERL_LIBS"] = os.path.join(ctx["rootdir"], "src")
    sp.check_call(["escript", "make_boot_script"], env=env, cwd=ctx["devdir"])


@log("Prepare configuration files")
def setup_configs(ctx):
    """Write the configuration of every node, then the HAProxy config.

    For each node the template variables are assembled (paths, node name,
    ports) and handed to ``write_config`` and ``write_locald_configs``.
    The data, view index and state directories all point at the same
    ``<devdir>/lib/<node>/data`` directory, which is created here.
    """
    for node_id, node in enumerate(ctx["nodes"], start=ctx["node_number"]):
        chttpd_port, httpd_port, metrics_port = get_ports(ctx, node_id)
        prefix = toposixpath(ctx["rootdir"])
        data_dir = toposixpath(ensure_dir_exists(ctx["devdir"], "lib", node, "data"))
        env = {
            "prefix": prefix,
            "package_author_name": "The Apache Software Foundation",
            "data_dir": data_dir,
            "view_index_dir": data_dir,
            "state_dir": data_dir,
            "node_name": "-name %s@127.0.0.1" % node,
            "cluster_port": chttpd_port,
            "backend_port": httpd_port,
            "prometheus_port": metrics_port,
            "uuid": "fake_uuid_for_dev",
            "_default": "",
        }
        write_config(ctx, node, env)
        write_locald_configs(ctx, node, env)
    generate_haproxy_config(ctx)


def write_locald_configs(ctx, node, env):
    """Copy the requested extra config files into the node's ``local.d``.

    Each entry of ``ctx["locald_configs"]`` is resolved against
    ``ctx["rootdir"]``; entries that do not exist are skipped silently.
    *env* is accepted for symmetry with ``write_config`` but is not used.
    """
    for relative_path in ctx["locald_configs"]:
        source = os.path.join(ctx["rootdir"], relative_path)
        if not os.path.exists(source):
            continue
        target = os.path.join(
            ctx["devdir"], "lib", node, "etc", "local.d", os.path.basename(source)
        )
        with open(source) as reader:
            text = reader.read()
        with open(target, "w") as writer:
            writer.write(text)


def generate_haproxy_config(ctx):
    """Render ``<devdir>/lib/haproxy.cfg`` from ``<rootdir>/rel/haproxy.cfg``.

    A template line containing a ``<<...>>`` placeholder is emitted once per
    node in ``ctx["nodes"]``: the placeholder body is expanded with
    ``str.format`` (fields ``node_idx`` and ``port``) and put back between
    the text before ``<<`` and the text after ``>>``. All other lines are
    copied unchanged. Every output line ends with exactly one newline.
    """
    template_path = os.path.join(ctx["rootdir"], "rel", "haproxy.cfg")
    output_path = os.path.join(ctx["devdir"], "lib", "haproxy.cfg")

    with open(template_path) as handle:
        # Drop the line terminators here; they are added back uniformly on
        # output so copied lines are not followed by a spurious blank line.
        lines = [line.rstrip("\n") for line in handle]

    out = []
    for line in lines:
        # The trailing group is greedy so that text after ">>" is kept; a
        # lazy group at the end of the pattern would always match nothing.
        match = re.match(r"(.*?)<<(.*?)>>(.*)", line, re.S)
        if match is None:
            out.append(line)
            continue
        prefix, body, suffix = match.groups()
        for node in ctx["nodes"]:
            node_idx = int(node.replace("node", ""))
            text = body.format(node_idx=node_idx, port=cluster_port(ctx, node_idx))
            out.append(prefix + text + suffix)

    with open(output_path, "w") as handle:
        handle.writelines(line + "\n" for line in out)


def apply_config_overrides(ctx, content):
    """Apply the ``key=val`` overrides from ``ctx["config_overrides"]``.

    For each override, every occurrence of the key in *content* — optionally
    preceded by up to two ``;`` or ``=`` characters — is replaced, up to the
    end of that line, by ``key = val``.

    Args:
        ctx: run context; only ``"config_overrides"`` is read.
        content: ini file text.

    Returns:
        str: *content* with all overrides applied.

    Raises:
        ValueError: if an override does not contain ``=``.
    """
    for kv_str in ctx["config_overrides"]:
        # Split on the first "=" only, so values may themselves contain "=".
        key, sep, val = kv_str.partition("=")
        if not sep:
            raise ValueError("Config override must look like key=val: %r" % kv_str)
        key, val = key.strip(), val.strip()
        # The key is matched literally, not as a regular expression.
        pattern = "[;=]{0,2}%s.*" % re.escape(key)
        replacement = "%s = %s" % (key, val)
        # A callable replacement is inserted verbatim, so backslashes in the
        # value are not interpreted as group references or escapes.
        # NOTE(review): the match is not anchored to the start of a line, so
        # a key that is a suffix of another option name also matches there.
        content = re.sub(pattern, lambda _match: replacement, content)
    return content


def get_ports(ctx, idnode):
    """Return the three ports assigned to node number *idnode*.

    The caller in this file unpacks the result as
    ``(cluster_port, backend_port, prometheus_port)``. For nodes 1-5 without
    ``ctx["auto_ports"]`` the ports are ``10000 * idnode`` plus 5984, 5986
    and 7986; otherwise free ports are obtained from
    ``get_available_ports``. *idnode* must be non-zero.
    """
    assert idnode
    if idnode > 5 or ctx["auto_ports"]:
        return tuple(get_available_ports(2))
    base = 10000 * idnode
    return (base + 5984, base + 5986, base + 7986)


def get_available_ports(num):
    """Return a list of distinct ports that were free on localhost.

    NOTE: the list holds ``num + 1`` ports, not ``num``. Each port is found
    by binding a TCP socket to port 0, reading back the port the OS picked,
    and closing the socket again, so a port is not reserved once returned.
    """
    wanted = num + 1
    found = []
    while len(found) < wanted:
        soc = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            soc.bind(("localhost", 0))
            port = soc.getsockname()[1]
        finally:
            soc.close()
        if port not in found:
            found.append(port)
    return found


def get_node_config(ctx, node_idx):
    """Load the ini configuration written for node number *node_idx*.

    ``default.ini`` is read first and ``local.ini`` second from
    ``<devdir>/lib/node<node_idx>/etc``; files that are missing are skipped
    by the parser.

    Returns:
        ConfigParser: the parsed configuration.
    """
    etc_dir = os.path.join(ctx["devdir"], "lib", "node{}".format(node_idx), "etc")
    ini_files = [os.path.join(etc_dir, name) for name in ("default.ini", "local.ini")]
    parser = ConfigParser()
    parser.read(ini_files)
    return parser


def backend_port(ctx, n):
    """Return the ``[httpd] port`` configured for node number *n* as int."""
    config = get_node_config(ctx, n)
    return config.getint("httpd", "port")


def cluster_port(ctx, n):
    """Return the ``[chttpd] port`` configured for node number *n* as int."""
    config = get_node_config(ctx, n)
    return config.getint("chttpd", "port")


def write_config(ctx, node, env):
    """Render the config templates of ``rel/overlay/etc`` for one node.

    Every regular file in ``<rootdir>/rel/overlay/etc`` is copied to
    ``<devdir>/lib/<node>/etc`` with each ``{{key}}`` placeholder replaced
    by ``str(env[key])``. ``default.ini`` and ``local.ini`` additionally go
    through the dev-specific adjustments. Sub-directories are skipped, and
    an empty ``local.d`` directory is created in the target.

    Args:
        ctx: run context.
        node: node name, e.g. ``"node1"``.
        env: mapping of placeholder name to value.
    """
    etc_src = os.path.join(ctx["rootdir"], "rel", "overlay", "etc")
    etc_tgt = ensure_dir_exists(ctx["devdir"], "lib", node, "etc")

    for fname in glob.glob(os.path.join(etc_src, "*")):
        if os.path.isdir(fname):
            continue
        base = os.path.basename(fname)

        with open(fname) as handle:
            content = handle.read()

        for key, value in env.items():
            # Plain substring replacement: a regex substitution would treat
            # backslashes in the value (e.g. in a path) as escape sequences.
            content = content.replace("{{%s}}" % key, str(value))

        if base == "default.ini":
            content = hack_default_ini(ctx, node, content)
            content = apply_config_overrides(ctx, content)
        elif base == "local.ini":
            content = hack_local_ini(ctx, content)

        with open(os.path.join(etc_tgt, base), "w") as handle:
            handle.write(content)

    ensure_dir_exists(etc_tgt, "local.d")


def boot_haproxy(ctx):
    """Start HAProxy with the generated config, if it was requested.

    Output of the process goes to ``<devdir>/logs/haproxy.log``. The
    ``HAPROXY_PORT`` environment variable is set to ``ctx["haproxy_port"]``
    unless the caller's environment already defines it.

    Returns:
        subprocess.Popen or None: the HAProxy process, or ``None`` when
        ``ctx["with_haproxy"]`` is false.

    Raises:
        OSError: if the executable named by ``ctx["haproxy"]`` cannot be
            started.
    """
    if not ctx["with_haproxy"]:
        return None
    config = os.path.join(ctx["devdir"], "lib", "haproxy.cfg")
    logfname = os.path.join(ctx["devdir"], "logs", "haproxy.log")
    env = os.environ.copy()
    env.setdefault("HAPROXY_PORT", ctx["haproxy_port"])
    # Pass an argument list instead of a shell string: paths containing
    # spaces stay intact and the returned handle refers to HAProxy itself
    # rather than to an intermediate shell, so kill() reaches the right
    # process.
    cmd = [ctx["haproxy"], "-f", config]
    # The child receives its own copy of the log descriptor, so the parent's
    # handle can be closed as soon as the process has been spawned.
    with open(logfname, "w") as logfile:
        return sp.Popen(
            cmd, stdin=sp.PIPE, stdout=logfile, stderr=sp.STDOUT, env=env
        )


def hack_default_ini(ctx, node, contents):
    """Adjust the text of ``default.ini`` for the dev cluster.

    * adds ``enable = true`` right below the ``[httpd]`` section header;
    * when ``ctx["enable_erlang_views"]`` is set, registers the Erlang
      query server below ``[native_query_servers]``;
    * replaces every ``n=3`` with ``n=<ctx["N"]>``.

    *node* is not used. Returns the modified text.
    """
    # Raw strings keep "\[" a regex escape instead of an invalid string
    # escape sequence.
    contents = re.sub(
        r"^\[httpd\]$",
        "[httpd]\nenable = true",
        contents,
        flags=re.MULTILINE,
    )

    if ctx["enable_erlang_views"]:
        contents = re.sub(
            r"^\[native_query_servers\]$",
            "[native_query_servers]\nerlang = {couch_native_process, start_link, []}",
            contents,
            flags=re.MULTILINE,
        )
    return contents.replace("n=3", "n=%s" % ctx["N"])


def hack_local_ini(ctx, contents):
    """Append admin credentials and the auth secret to ``local.ini`` text.

    In admin party mode no admin is written; instead the
    ``COUCHDB_TEST_ADMIN_PARTY_OVERRIDE`` environment variable is set and
    ``ctx["admin"]`` is replaced by placeholder texts. Otherwise the admin
    from ``ctx["admin"]`` is used, or ``root`` with a generated password
    (stored back into ``ctx["admin"]``), and the hashed password is
    appended. In both cases a ``[chttpd_auth]`` section carrying the shared
    secret is added at the end.
    """
    auth_section = "\n\n[chttpd_auth]\nsecret = %s\n" % COMMON_SALT

    if ctx["with_admin_party"]:
        os.environ["COUCHDB_TEST_ADMIN_PARTY_OVERRIDE"] = "1"
        ctx["admin"] = ("Admin Party!", "You do not need any password.")
        return contents + auth_section

    # use the credentials given on the command line, or make some up
    if ctx["admin"] is not None:
        user, pswd = ctx["admin"]
    else:
        user, pswd = "root", gen_password()
        ctx["admin"] = (user, pswd)

    # this relies on [admin] being the last section at the end of the file
    admin_line = "\n%s = %s" % (user, hashify(pswd))
    return contents + admin_line + auth_section


def gen_password():
    """Return a random password: six random bytes, base64-encoded."""
    # TODO: figure how to generate something more friendly here
    random_bytes = os.urandom(6)
    encoded = base64.b64encode(random_bytes)
    return encoded.decode()


def hashify(pwd, salt=COMMON_SALT, iterations=10, keylen=20):
    """
    Hash *pwd* with PBKDF2 and format it for a CouchDB admin entry.

    Implements password hashing according to:
      - https://issues.apache.org/jira/browse/COUCHDB-1060
      - https://issues.apache.org/jira/secure/attachment/12492631/0001-Integrate-PBKDF2.patch

    The result has the form ``-pbkdf2-<derived key>,<salt>,<iterations>``.

    An earlier note on this function quoted, for the password 'candeira'
    with salt d1d2d4d8909c82c81b6c8184429a0739 and 10 iterations:
      -pbkdf2-99eb34d97cdaa581e6ba7b5386e112c265c5c670,d1d2d4d8909c82c81b6c8184429a0739,10
    (not re-verified here; the default salt differs on every run).
    """
    derived_key = pbkdf2_hex(pwd, salt, iterations, keylen)
    return "-pbkdf2-{},{},{}".format(derived_key, salt, iterations)


def startup(ctx):
    """Boot all nodes, wait for them, and (unless disabled) cluster them.

    A cleanup hook that kills the spawned processes is registered first.
    With ``ctx["no_join"]`` the function stops once the nodes answer.
    Afterwards ``ctx["degrade_cluster"]`` nodes are stopped again if that
    number is positive.
    """
    atexit.register(kill_processes, ctx)
    boot_nodes(ctx)
    ensure_all_nodes_alive(ctx)
    if ctx["no_join"]:
        return

    if ctx["with_admin_party"]:
        run_cluster_setup = cluster_setup_with_admin_party
    else:
        run_cluster_setup = cluster_setup
    run_cluster_setup(ctx)

    if ctx["degrade_cluster"] > 0:
        degrade_cluster(ctx)


def kill_processes(ctx):
    """Kill every process in ``ctx["procs"]`` with no recorded exit code.

    Empty slots (``None``) are ignored.
    """
    for proc in ctx["procs"]:
        if not proc:
            continue
        if proc.returncode is not None:
            continue
        proc.kill()


def degrade_cluster(ctx):
    """Stop the last ``ctx["degrade_cluster"]`` node processes.

    The stopped processes are removed from ``ctx["procs"]``. When HAProxy
    is in use its process is the last entry of the list; it is set aside
    first and put back at the end afterwards.
    """
    procs = ctx["procs"]
    behind_haproxy = ctx["with_haproxy"]
    haproxy_proc = procs.pop() if behind_haproxy else None

    for _ in range(ctx["degrade_cluster"]):
        victim = procs.pop()
        if victim is not None:
            kill_process(victim)

    if behind_haproxy:
        procs.append(haproxy_proc)


@log("Stopping proc {proc.pid}")
def kill_process(proc):
    """Kill *proc* unless it is empty or already has an exit code."""
    if not proc or proc.returncode is not None:
        return
    proc.kill()


def boot_nodes(ctx):
    """Start every node, then HAProxy, recording the processes in ``ctx``.

    Node processes are appended to ``ctx["procs"]`` in node order; the
    HAProxy process, when one was started, comes last.
    """
    procs = ctx["procs"]
    for node in ctx["nodes"]:
        procs.append(boot_node(ctx, node))

    haproxy_proc = boot_haproxy(ctx)
    if haproxy_proc is None:
        return
    procs.append(haproxy_proc)


def ensure_all_nodes_alive(ctx):
    """Wait until every node answers on its clustered HTTP port.

    Up to ten rounds are made, one second apart; in each round every node
    that has not answered yet is probed again. If some node is still
    unreachable after the last round, an error is printed and the script
    exits with status 1.

    NOTE(review): ports are looked up for node numbers 1..N, independent of
    ``ctx["node_number"]`` — confirm this is intended when the node
    numbering does not start at 1.
    """
    status = {num: False for num in range(ctx["N"])}
    for _ in range(10):
        for num in range(ctx["N"]):
            if status[num]:
                continue
            local_port = cluster_port(ctx, num + 1)
            url = "http://127.0.0.1:{0}/".format(local_port)
            try:
                check_node_alive(url)
            except Exception:
                # Best effort: the node is probed again in the next round.
                # Only ordinary errors are ignored, so SystemExit and
                # KeyboardInterrupt are not swallowed here.
                pass
            else:
                status[num] = True
        if all(status.values()):
            return
        time.sleep(1)
    # Reaching this point means at least one node never answered.
    print("Failed to start all the nodes." " Check the dev/logs/*.log for errors.")
    sys.exit(1)


@log("Check node at {url}")
def check_node_alive(url):
    """Probe *url* until it answers, trying up to ten times.

    A one second pause follows every failed attempt. Returns ``None`` on
    the first successful request; after ten failures the exception of the
    last attempt is raised.
    """
    last_error = None
    for _ in range(10):
        try:
            with contextlib.closing(urlopen(url)):
                pass
        except Exception as exc:
            last_error = exc
            time.sleep(1)
        else:
            return
    raise last_error


def set_boot_env(ctx):
    """Export the environment variables the nodes read at boot.

    ``COUCHDB_FAUXTON_DOCROOT`` points at the Fauxton release build when
    that directory exists relative to the current working directory, and at
    ``share/www`` otherwise. The two query server variables are set to the
    ``couchjs`` binary followed by the matching server script, both located
    under ``ctx["rootdir"]``.
    """
    # fudge fauxton path
    release_dir = "src/fauxton/dist/release"
    fauxton_root = release_dir if os.path.exists(release_dir) else "share/www"
    os.environ["COUCHDB_FAUXTON_DOCROOT"] = fauxton_root

    # fudge default query server paths
    couchjs = os.path.join(ctx["rootdir"], "src", "couch", "priv", "couchjs")
    server_dir = os.path.join(ctx["rootdir"], "share", "server")
    query_servers = (
        ("COUCHDB_QUERY_SERVER_JAVASCRIPT", "main.js"),
        ("COUCHDB_QUERY_SERVER_COFFEESCRIPT", "main-coffee.js"),
    )
    commands = {
        variable: toposixpath("%s %s" % (couchjs, os.path.join(server_dir, script)))
        for variable, script in query_servers
    }
    for variable, _ in query_servers:
        os.environ[variable] = commands[variable]


@log("Start node {node}")
def boot_node(ctx, node):
    """Spawn the ``erl`` process for *node* and return its Popen handle.

    The node is started with its own ``vm.args`` and ini files from
    ``<devdir>/lib/<node>/etc``. Its output goes to
    ``<devdir>/logs/<node>.log``, opened with mode ``wb`` while
    ``ctx["reset_logs"]`` is set and ``r+b`` otherwise. Words from
    ``ctx["extra_args"]`` (split on single spaces) are appended to the
    command line.
    """
    set_boot_env(ctx)
    env = os.environ.copy()
    env["ERL_LIBS"] = os.path.join(ctx["rootdir"], "src")

    etc_dir = os.path.join(ctx["devdir"], "lib", node, "etc")
    argv = ["erl"]
    argv += ["-args_file", os.path.join(etc_dir, "vm.args")]
    argv += ["-config", os.path.join(ctx["rootdir"], ctx["erlang_config"])]
    argv += [
        "-couch_ini",
        os.path.join(etc_dir, "default.ini"),
        os.path.join(etc_dir, "local.ini"),
        os.path.join(etc_dir, "local.d"),
    ]
    argv += [
        "-reltool_config",
        os.path.join(ctx["rootdir"], "rel", "reltool.config"),
    ]
    argv += ["-parent_pid", str(os.getpid())]
    argv += ["-boot", os.path.join(ctx["devdir"], "devnode")]
    argv += ["-pa", ctx["devdir"]]
    argv += ["-s monitor_parent"]

    mode = "wb" if ctx["reset_logs"] else "r+b"
    logfile = open(os.path.join(ctx["devdir"], "logs", "%s.log" % node), mode)

    extra_args = ctx.get("extra_args")
    if extra_args:
        argv.extend(extra_args.split(" "))

    return sp.Popen(
        [toposixpath(arg) for arg in argv],
        stdin=sp.PIPE,
        stdout=logfile,
        stderr=sp.STDOUT,
        env=env,
    )


@log("Running cluster setup")
def cluster_setup(ctx):
    """Join all nodes into one cluster through the ``_cluster_setup`` API.

    The first node acts as the lead. If enabling the cluster on it is
    refused, nothing else is done. Otherwise every further node is enabled
    and added to the lead, and the setup is finished on the lead.

    Returns:
        int: the clustered port of the lead node.
    """
    lead_port = cluster_port(ctx, 1)
    if not enable_cluster(ctx["N"], lead_port, *ctx["admin"]):
        return lead_port

    for position in range(1, ctx["N"]):
        port = cluster_port(ctx, position + 1)
        name = ctx["nodes"][position]
        enable_cluster(ctx["N"], port, *ctx["admin"])
        add_node(lead_port, name, port, *ctx["admin"])
    finish_cluster(lead_port, *ctx["admin"])
    return lead_port


def enable_cluster(node_count, port, user, pswd):
    """Send the ``enable_cluster`` action to the node listening on *port*.

    Args:
        node_count: number of nodes the cluster will have.
        port: clustered HTTP port of the node on 127.0.0.1.
        user: admin user name, used for both the payload and basic auth.
        pswd: admin password.

    Returns:
        bool: ``True`` on HTTP 201, ``False`` on HTTP 400.

    Raises:
        AssertionError: on any other status; the response body is the
            message.
    """
    payload = json.dumps(
        {
            "action": "enable_cluster",
            "bind_address": "0.0.0.0",
            "username": user,
            "password": pswd,
            "node_count": node_count,
        }
    )
    headers = {
        "Authorization": basic_auth_header(user, pswd),
        "Content-Type": "application/json",
    }
    conn = httpclient.HTTPConnection("127.0.0.1", port)
    try:
        conn.request("POST", "/_cluster_setup", payload, headers)
        resp = conn.getresponse()
        status, body = resp.status, resp.read()
    finally:
        # Always release the socket, also when the request itself fails.
        conn.close()
    if status == 400:
        return False
    # Raised explicitly so the check also runs under "python -O".
    if status != 201:
        raise AssertionError(body)
    return True


def add_node(lead_port, node_name, node_port, user, pswd):
    """Ask the lead node to add another node to the cluster.

    Args:
        lead_port: clustered HTTP port of the lead node on 127.0.0.1.
        node_name: name of the node to add.
        node_port: clustered HTTP port of the node to add.
        user: admin user name, used for both the payload and basic auth.
        pswd: admin password.

    Raises:
        AssertionError: if the status is neither 201 nor 409; the response
            body is the message.
    """
    payload = json.dumps(
        {
            "action": "add_node",
            "host": "127.0.0.1",
            "port": node_port,
            "name": node_name,
            "username": user,
            "password": pswd,
        }
    )
    headers = {
        "Authorization": basic_auth_header(user, pswd),
        "Content-Type": "application/json",
    }
    conn = httpclient.HTTPConnection("127.0.0.1", lead_port)
    try:
        conn.request("POST", "/_cluster_setup", payload, headers)
        resp = conn.getresponse()
        status, body = resp.status, resp.read()
    finally:
        # Always release the socket, also when the request itself fails.
        conn.close()
    # Raised explicitly so the check also runs under "python -O".
    if status not in (201, 409):
        raise AssertionError(body)


def set_cookie(port, user, pswd):
    """Send a freshly generated cookie to the node listening on *port*.

    Uses the ``receive_cookie`` action of ``/_cluster_setup`` with basic
    auth for *user* / *pswd*.

    Raises:
        AssertionError: if the status is not 201; the response body is the
            message.
    """
    payload = json.dumps({"action": "receive_cookie", "cookie": generate_cookie()})
    headers = {
        "Authorization": basic_auth_header(user, pswd),
        "Content-Type": "application/json",
    }
    conn = httpclient.HTTPConnection("127.0.0.1", port)
    try:
        conn.request("POST", "/_cluster_setup", payload, headers)
        resp = conn.getresponse()
        status, body = resp.status, resp.read()
    finally:
        # Always release the socket, also when the request itself fails.
        conn.close()
    # Raised explicitly so the check also runs under "python -O".
    if status != 201:
        raise AssertionError(body)


def finish_cluster(port, user, pswd):
    """Send the ``finish_cluster`` action to the node listening on *port*.

    HTTP 400 is accepted as well as 201, because it is what an already set
    up cluster answers.

    Raises:
        AssertionError: on any other status; the response body is the
            message.
    """
    payload = json.dumps({"action": "finish_cluster"})
    headers = {
        "Authorization": basic_auth_header(user, pswd),
        "Content-Type": "application/json",
    }
    conn = httpclient.HTTPConnection("127.0.0.1", port)
    try:
        conn.request("POST", "/_cluster_setup", payload, headers)
        resp = conn.getresponse()
        status, body = resp.status, resp.read()
    finally:
        # Always release the socket, also when the request itself fails.
        conn.close()
    # 400 for already set up'ed cluster.
    # Raised explicitly so the check also runs under "python -O".
    if status not in (201, 400):
        raise AssertionError(body)


def basic_auth_header(user, pswd):
    """Return the HTTP ``Authorization`` header value for basic auth."""
    credentials = ":".join((user, pswd))
    token = base64.b64encode(credentials.encode())
    return "Basic " + token.decode()


def generate_cookie():
    """Return a random cookie: twelve random bytes, base64-encoded."""
    random_bytes = os.urandom(12)
    encoded = base64.b64encode(random_bytes)
    return encoded.decode()


def cluster_setup_with_admin_party(ctx):
    """Cluster the nodes without credentials (admin party mode).

    The nodes are connected first; then the system databases are created
    through the clustered port of the first node.
    """
    connect_nodes(ctx)
    create_system_databases("127.0.0.1", cluster_port(ctx, 1))


def connect_nodes(ctx):
    """Register every node in the ``_nodes`` database of the first node.

    The requests go to the backend port of node 1. Statuses 200, 201, 202
    and 409 count as success.
    """
    port = backend_port(ctx, 1)
    accepted = (200, 201, 202, 409)
    for node in ctx["nodes"]:
        try_request(
            "127.0.0.1",
            port,
            "PUT",
            "/_nodes/%s@127.0.0.1" % node,
            accepted,
            body="{}",
            error="Failed to join %s into cluster:\n" % node,
        )


def try_request(
    host, port, meth, path, success_codes, body=None, retries=10, retry_dt=1, error=""
):
    """Send an HTTP request, retrying until it gets an accepted status.

    Args:
        host: server host.
        port: server port.
        meth: HTTP method.
        path: request path.
        success_codes: container of statuses that count as success.
        body: optional request body.
        retries: number of additional attempts after the first one.
        retry_dt: pause in seconds between attempts.
        error: prefix for the failure message.

    Returns:
        tuple: ``(status, response_body)`` of the first accepted response.

    Raises:
        AssertionError: if the last attempt still returns a status outside
            *success_codes*; the message is *error* followed by the body.
    """
    while True:
        conn = httpclient.HTTPConnection(host, port)
        try:
            conn.request(meth, path, body=body)
            resp = conn.getresponse()
            status, payload = resp.status, resp.read()
        finally:
            # One connection per attempt; never leave its socket open.
            conn.close()
        if status in success_codes:
            return status, payload
        if retries <= 0:
            # Raised explicitly: a bare assert would disappear under
            # "python -O" and leave this loop without an exit.
            raise AssertionError("%s%s" % (error, payload))
        retries -= 1
        time.sleep(retry_dt)


def create_system_databases(host, port):
    """Create the ``_users``, ``_replicator`` and ``_global_changes`` dbs.

    Each database is probed with a HEAD request and only created (PUT) when
    the probe answers 404. Statuses 201, 202 and 412 count as a successful
    creation.
    """
    for dbname in ["_users", "_replicator", "_global_changes"]:
        conn = httpclient.HTTPConnection(host, port)
        try:
            conn.request("HEAD", "/" + dbname)
            status = conn.getresponse().status
        finally:
            # Release the probe connection before doing anything else.
            conn.close()
        if status != 404:
            continue
        try_request(
            host,
            port,
            "PUT",
            "/" + dbname,
            (201, 202, 412),
            error="Failed to create '%s' database:\n" % dbname,
        )


@log(
    "Developers cluster is set up at http://127.0.0.1:{lead_port}.\n"
    "Admin username: {user}\n"
    "Password: {password}\n"
    "Time to hack!"
)
def join(ctx, lead_port, user, password):
    """Block while the cluster runs; exit with status 1 if a process dies.

    Every two seconds all processes in ``ctx["procs"]`` are polled.
    *lead_port*, *user* and *password* are only used by the log banner.
    """
    while True:
        for proc in ctx["procs"]:
            # poll() is what refreshes the exit status; reading returncode
            # alone would never notice that a process has terminated.
            if proc is not None and proc.poll() is not None:
                sys.exit(1)
        time.sleep(2)


@log("Exec command {cmd}")
def run_command(ctx, cmd):
    """Run *cmd* through the shell and exit with its return code.

    With ``ctx["no_eval"]`` the command simply runs to completion.
    Otherwise its standard output is read line by line and every line is
    evaluated as a Python expression inside this function, so the command
    can refer to the names visible here (for example ``ctx``). Standard
    error of the command is passed through to this process's stderr.

    This function does not return; it always ends by calling ``exit``.
    """
    if ctx["no_eval"]:
        p = sp.Popen(cmd, shell=True)
        p.wait()
        exit(p.returncode)
    else:
        p = sp.Popen(cmd, shell=True, stdout=sp.PIPE, stderr=sys.stderr)
        while True:
            line = p.stdout.readline()
            if not line:
                # An empty read means the command closed its stdout.
                break
            # SECURITY: eval() executes whatever the command prints. Only
            # run commands whose output is trusted, or pass --no-eval.
            eval(line)
        p.wait()
        exit(p.returncode)


@log("Restart all nodes")
def reboot_nodes(ctx):
    """Kill the running processes, start them again and wait for the nodes.

    ``ctx["reset_logs"]`` is cleared first, so the node log files are not
    opened in truncating mode for the restart.
    """
    ctx["reset_logs"] = False
    for step in (kill_processes, boot_nodes, ensure_all_nodes_alive):
        step(ctx)


if __name__ == "__main__":
    # An interrupt from the keyboard ends the script without a traceback.
    with contextlib.suppress(KeyboardInterrupt):
        main()
