"""Utility functions for RNNs"""

# pylint: disable=unused-variable, invalid-name, too-many-locals
import math
import os
import sys

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import MinMaxScaler

# Global matplotlib font settings, applied process-wide when this module is imported.
font = dict(size=12)
matplotlib.rc("font", **font)


def make_train_test(data, train_fraction=0.67, rescale=True):
    """Create train and test data sets from a data vector.

    The data is passed through a fitted ``MinMaxScaler``, flattened, and split
    in order (no shuffling): the first ``int(len(data) * train_fraction)``
    values form the training part and the rest form the test part.

    Args:
      data (np.array): data values, either a 1-D vector or a single-column
        2-D array of shape ``(n, 1)``
      train_fraction (float): fraction of data devoted to training
      rescale (bool): if True, scale values to the range [0, 1]. If False,
        fit a scaler whose feature range equals the data range, so the values
        are unchanged (up to floating-point rounding) but the returned scaler
        still supports ``inverse_transform``.

    Returns:
      (train, test, scaler): 3-tuple of the flattened training part, the
      flattened test part, and the fitted scaler
    """
    values = np.asarray(data)
    if values.ndim == 1:
        # MinMaxScaler expects a 2-D (samples, features) array.
        values = values.reshape(-1, 1)
    split = int(len(values) * train_fraction)
    if rescale:
        feature_range = (0, 1)
    else:
        # Scalar bounds: maps [min, max] onto itself for single-column input.
        feature_range = (float(values.min()), float(values.max()))
    scaler = MinMaxScaler(feature_range=feature_range)
    scaled = scaler.fit_transform(values).flatten()
    train = scaled[:split]
    test = scaled[split:]
    return train, test, scaler


def make_xy(data, window_size=1, step_size=1):
    """Create sliding-window X, Y data pairs from a dataset vector.

    Each target ``Y[i] = data[Y_indices[i]]`` is paired with the
    ``window_size`` values immediately preceding it. Target indices start at
    ``window_size`` and advance by ``step_size``.

    Args:
      data (np.array): dataset vector (any sequence accepted by ``np.asarray``)
      window_size (int): window size; number of dataset points looking back
      step_size (int): step size between windows

    Returns:
      (X, Y, X_indices, Y_indices): the windows reshaped to
      ``(len(Y), window_size, 1)``, the targets, the
      ``(len(Y), window_size)`` array of indices in each window, and the
      index of each target
    """
    data = np.asarray(data)
    Y_indices = np.arange(window_size, len(data), step_size)
    Y = data[Y_indices]
    # Row i holds the indices Y_indices[i] - window_size ... Y_indices[i] - 1.
    X_indices = (Y_indices - window_size)[:, np.newaxis] + np.arange(window_size)
    X = np.reshape(data[X_indices], (len(Y), window_size, 1))
    return X, Y, X_indices, Y_indices


def plot_pred(data, scaler=None, rmse=True, plotmarkers=False, show=True, **kw):
    """Plot model predictions on top of the original data.

    The entries of ``data`` are laid out one after another along the x axis
    in iteration order. For example, a "train" entry followed by a "test"
    entry forms a single continuous time line.

    Args:
      data (dict): maps a legend label to ``(Ypred, Y, Y_indices)``; entries
        that are None are skipped. ``Y`` holds the segment's data values and
        ``Y_indices`` the positions within ``Y`` that ``Ypred`` predicts.
      scaler: optional fitted scaler; if given, ``Y`` and ``Ypred`` are
        mapped back with ``scaler.inverse_transform`` before plotting
      rmse (bool): append the RMSE of each prediction to its legend label
      plotmarkers (bool): also draw a marker at every predicted point
      show (bool): call ``plt.show()`` when done
      **kw: ``ticks`` and ``labels`` set the x tick positions/labels; all
        other keywords are passed to ``plt.subplots``

    Returns:
      (fig, ax): the matplotlib figure and axes
    """
    ticks = kw.pop("ticks", None)
    labels = kw.pop("labels", None)
    fig, ax = plt.subplots(**kw)
    markers = ["*", "x", "o"]
    colors = ["black", "steelblue", "darkred", "green"]
    legend = []
    x = []
    y = []
    shift = 0
    n_series = 0
    for name, entry in data.items():
        if entry is None:
            continue
        Ypred, Y, Y_indices = entry
        X = np.arange(len(Y)) + shift
        shift += len(X)
        if scaler is not None:
            Y = scaler.inverse_transform(Y.reshape(-1, 1)).flatten()
            Ypred = scaler.inverse_transform(Ypred.reshape(-1, 1)).flatten()
        if rmse:
            e = math.sqrt(mean_squared_error(Y[Y_indices], Ypred))
            name = f"{name} (RMSE: {e:.4f})"
        legend.append(name)
        # Colors and markers are taken from the end of their lists and wrap
        # around when there are more series than entries.
        col = colors[-1 - n_series % len(colors)]
        ax.plot(X[Y_indices], Ypred, color=col)
        if plotmarkers:
            marker = markers[-1 - n_series % len(markers)]
            ax.plot(X[Y_indices], Ypred, marker, color=col)
        n_series += 1
        x.extend(X)
        y.extend(Y)
    legend.append("Data")
    # The data line uses the next color in the same sequence.
    ax.plot(x, y, "-", color=colors[-1 - n_series % len(colors)])
    ax.set_title("Model prediction")
    if ticks is not None and labels is not None:
        ax.set_xticks(ticks, labels=labels)
    ax.legend(legend)
    if show:
        plt.show()
    return fig, ax


def plot_loss_acc(history):
    """Plot loss and accuracy of history"""
    try:
        plt.plot(history.history["accuracy"])
        plt.plot(history.history["val_accuracy"])
    except KeyError:
        plt.plot(history.history["acc"])
        plt.plot(history.history["val_acc"])

    plt.plot(history.history["loss"])
    plt.plot(history.history["val_loss"])
    plt.title("model accuracy")
    plt.ylabel("accuracy")
    plt.xlabel("epoch")
    plt.legend(["train acc", "val acc", "train loss", "val loss"], loc="upper left")
    plt.show()


def plot_history(history, show=True, xlim=None, ylim=None, **kw):
    """Plot history - plot training and/or test accuracy or loss values"""
    datalabels = ["Training", "Validation"]
    metrics_labels = {
        "loss": "loss",
        "acc": "accuracy",
        "accuracy": "accuracy",
        "mse": "mse",
        "recall": "recall",
    }
    if not isinstance(history, dict):
        history = history.history
    hkeys = history.keys()
    h = np.array([history[k] for k in hkeys])
    labels = [
        f"{x} {y}"
        for x, y in zip(
            [datalabels[u.startswith("val_")] for u in hkeys],
            [metrics_labels[v.replace("val_", "")] for v in hkeys],
        )
    ]
    fig, ax = plt.subplots(**kw)
    ax.plot(np.array(range(0, h.shape[1])), h.T)
    ax.set_title("Model metrics")
    ax.set_ylabel("Metric")
    ax.set_xlabel("Epoch")
    if xlim is not None:
        ax.set_xlim(xlim)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.legend(list(labels), loc="upper left")
    if show:
        plt.show()


def airlines(fn="airline-passengers.csv"):
    """Load and reformat the airline passengers data set.

    Args:
      fn (str): path to the CSV file with ``Month`` and ``Passengers``
        columns; defaults to ``airline-passengers.csv`` in the working
        directory

    Returns:
      pd.DataFrame: columns ``time`` (parsed with format ``%Y-%m``),
      ``passengers``, ``year`` and ``month``

    If the file does not exist, a download hint is printed to stderr and the
    process exits with status 1.
    """
    url = (
        "https://raw.githubusercontent.com/jbrownlee/Datasets/master/"
        f"{os.path.basename(fn)}"
    )
    if not os.path.exists(fn):
        print(
            f"Download airline passenger data: 'wget {url} --no-check-certificate'",
            file=sys.stderr,
        )
        sys.exit(1)
    df = pd.read_csv(fn)
    df = df.rename(columns={"Month": "time", "Passengers": "passengers"})
    df["time"] = pd.to_datetime(df["time"], format="%Y-%m")
    df["year"] = df["time"].dt.year
    df["month"] = df["time"].dt.month
    return df