""" A class for creating plots with matplotlib. Please note that the formats are partially incoherent
and not the best example.
Comes with capabilities for various x-y plots, surface plots, heatmaps, and saving figures.

Usage example:
    plot_frame = Plotter()
    plot_frame.resolution = 600  # dpi used for newly created figures
    x_data = np.arange(10)
    y_data = np.random.rand(1, 10)
    plot_frame.make_x_y_plot(x_data, y_data)
    plot_frame.save_figure(save_fig_dir="C:/temp/plots/example.png")

"""

try:
    import os
except Exception as problem:
    print("ExceptionERROR: Missing fundamental packages (required: os).")
    print(problem)

try:
    import numpy as np
    import matplotlib
    import matplotlib.pyplot as plt
    import matplotlib.font_manager as font_manager
    from mpl_toolkits.mplot3d import Axes3D  # imports 3D projection
    from matplotlib.ticker import LinearLocator, FormatStrFormatter
except Exception as problem:
    print("ExceptionERROR: Could not import matplotlib and/or numpy.")
    print(problem)


class Plotter:
    def __init__(self, *args):
        """
        :param args: dummy
        """
        
        self.set_default_parameters()
        self.dir = os.path.abspath(os.path.dirname(__file__)) + os.sep

    def create_figure(self):
        """ Open a new figure with the current size, resolution and color settings

        The font definitions are refreshed (update_fonts) before the figure is created.

        :return: matplotlib.pyplot.figure (-1 if the figure could not be created)
        """
        try:
            self.update_fonts()
            figure_options = {
                "figsize": (self.width, self.height),
                "dpi": self.resolution,
                "facecolor": self.face_color,
                "edgecolor": self.edge_color,
            }
            new_figure = plt.figure(**figure_options)
        except Exception as problem:
            print("ERROR: Could not create figure.")
            print(problem)
            return -1
        return new_figure

    def get_color_map(self, plot_data_size):
        """ Create a colormap of size plot_data_size

        The colormap name is taken from self.color_map_type.

        :param plot_data_size: INT (for x-y plot - the number of y-graphs, for surface plots: size(Z) elements)
        :return: matplotlib colormap with plot_data_size + 1 colors; None if the colormap could not be created
        """

        try:
            # pyplot.get_cmap is used because matplotlib.cm.get_cmap (= plt.cm.get_cmap) is no longer
            # available in recent matplotlib releases.
            # +1 because lowest color is nearby white
            return plt.get_cmap(self.color_map_type, plot_data_size + 1)
        except Exception as problem:
            print(
                "Could not create colormap.\n Valid color_map_types are: jet, Oranges,inferno, plasma, Greens, Blues, PuRd, RdPu, Reds, Greys")
            print("         More options are available: http://matplotlib.org/users/colormaps.html")
            print(problem)
            return None

    def annotate_heatmap(
            self,
            im,
            data=None,
            valfmt="{x:.2f}",
            textcolors=["black", "white"],
            threshold=None,
            **textkw
    ):
        """ Annotate a heatmap function from matplotlib.org

        Arguments:
            im         : The AxesImage to be labeled.
        Optional arguments:
            data       : Data used to annotate. If None, the image"s data is used.
            valfmt     : The format of the annotations inside the heatmap.
                         This should either use the string format method, e.g.
                         "$ {x:.2f}", or be a :class:`matplotlib.ticker.Formatter`.
            textcolors : A list or array of two color specifications. The first is
                         used for values below a threshold, the second for those
                         above.
            threshold  : Value in data units according to which the colors from
                         textcolors are applied. If None (the default) uses the
                         middle of the colormap as separation.
        """

        if not isinstance(data, (list, np.ndarray)):
            data = im.get_array()

        # Normalize the threshold to the images color range.
        if threshold is not None:
            threshold = im.norm(threshold)
        else:
            threshold = im.norm(data.max()) / 2.

        # Set default alignment to center, but allow it to be overwritten by textkw.
        kw = dict(horizontalalignment="center", verticalalignment="center")
        kw.update(textkw)

        # Get the formatter in case a string is supplied
        if isinstance(valfmt, str):
            valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)

        # Loop over the data and create a `Text` for each "pixel". Change the text"s color depending on the data.
        texts = []
        for i in range(data.shape[0]):
            for j in range(data.shape[1]):
                try:
                    kw.update(color=textcolors[im.norm(data[i, j]) > threshold])
                    text = im.axes.text(j, i, valfmt(data[i, j], None), **kw)
                    texts.append(text)
                except Exception as problem:
                    print("WARNING: Heatmap annotation failed.")
                    print(problem)
        return texts

    def heatmap(self, data, row_labels, col_labels, ax=None, cbar_kw={}, cbarlabel="", **kwargs):
        """ Create a heatmap from a numpy array and two lists of labels (source: matplolib.org).

        Arguments:
            data       : A 2D numpy array of shape (N,M)
            row_labels : A list or array of length N with the labels
                         for the rows
            col_labels : A list or array of length M with the labels
                         for the columns
        Optional arguments:
            axe        : A matplotlib.axes.Axes instance to which the heatmap
                         is plotted. If not provided, use current axes or
                         create a new one.
            cbar_kw    : A dictionary with arguments to
                         :meth: matplotlib.Figure.colorbar.
            cbarlabel  : The label for the colorbar
        """
        try:
            if not ax:
                ax = plt.gca()
        except Exception as problem:
            print("WARNING: Heatmap axe identification failed.")
            print(problem)

        # Plot the heatmap
        try:
            im = ax.imshow(data, **kwargs)  # uses vmin and vmax
        except Exception as problem:
            print("ERROR: Heatmap creation failed (ax.imshow(data)).")
            print(problem)
            return -1

        # Create colorbar
        try:
            cbar = ax.figure.colorbar(
                im,
                ax=ax,
                shrink=self.colorbar_shrink,
                aspect=self.colorbar_aspect,
                format=self.colorbar_format,
                **cbar_kw
            )
        except Exception as problem:
            print("WARNING: Heatmap colorbar creation failed.")
            print(problem)
        try:
            cbar.ax.set_ylabel(self.colorbar_label, rotation=-90, va="bottom", **self.hfont)
        except Exception as problem:
            print("WARNING: Heatmap colorbar arrangement failed.")
            print(problem)

        # We want to show all ticks...
        try:
            ax.set_xticks(np.arange(data.shape[1]))
            ax.set_yticks(np.arange(data.shape[0]))
        except Exception as problem:
            print("WARNING: Heatmap tick setting failed.")
            print(problem)

        # Let the horizontal axes labeling appear on top
        try:
            # ... and label them with the respective list entries
            ax.set_xticklabels(col_labels, **self.hfont)
            ax.set_yticklabels(row_labels, **self.hfont)
            ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
        except Exception as problem:
            print("WARNING: Heatmap tick labeling and arrangement failed.")
            print(problem)

        # Rotate the tick labels and set their alignment
        try:
            plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor", **self.hfont)
        except Exception as problem:
            print("WARNING: Heatmap tick label rotation failed.")
            print(problem)

        # Turn spines off and create white grid
        try:
            for edge, spine in ax.spines.items():
                spine.set_visible(False)
        except Exception as problem:
            print("WARNING: Heatmap spine and grid modifications failed.")
            print(problem)

        try:
            ax.set_xticks(np.arange(data.shape[1] + 1) - .5, minor=True)
            ax.set_yticks(np.arange(data.shape[0] + 1) - .5, minor=True)
            ax.grid(which="minor", color="w", linestyle=self.data_line_style[0], linewidth=self.data_line_width)
            ax.tick_params(which="minor", bottom=False, left=False)
        except Exception as problem:
            print("WARNING: Heatmap line-up failed.")
            print(problem)

        if self.plot_title_mode:
            try:
                ax.set_title(self.plot_title)
            except Exception as problem:
                print("WARNING: Failed to make plot title.")
                print(problem)

        return im, cbar

    def make_heatmap(self, Z, x_labels, y_labels):
        """ Make a heatmap plot
        Read more at https://matplotlib.org/gallery/images_contours_and_fields/image_annotated_heatmap.html

        A new figure is created with the current size, resolution and color settings. If
        self.heatmap_annotation is a non-empty format string (e.g., "{x:.1f} t"), every cell is
        labeled with its value using that format.

        :param Z: NESTED LIST with size = (y_size, x_size) -- [y_size*[x_size elements]]
        :param x_labels: list of strings
        :param y_labels: list of strings
        :return: -1 if the figure or the heatmap could not be created (None otherwise)
        """
        self.update_fonts()
        try:
            fig, axe = plt.subplots(
                figsize=(self.width, self.height),
                dpi=self.resolution,
                facecolor=self.face_color,
                edgecolor=self.edge_color
            )
        except Exception as problem:
            print("ERROR: Could not initiate figure.")
            print(problem)
            return -1
        Z = np.array(Z)

        try:
            result = self.heatmap(Z, y_labels, x_labels, ax=axe,
                                  vmin=self.colorbar_min_val,
                                  vmax=self.colorbar_max_val,
                                  cmap=self.color_map_type,
                                  cbarlabel=self.colorbar_label)
        except Exception as problem:
            print("ERROR: Could not create heatmap.")
            print(problem)
            return -1

        # heatmap() signals failure with -1 instead of the (im, cbar) tuple
        if not isinstance(result, tuple):
            print("ERROR: Could not create heatmap.")
            return -1
        im = result[0]

        if self.heatmap_annotation:
            try:
                self.annotate_heatmap(im, data=Z, valfmt=self.heatmap_annotation)
            except Exception as problem:
                print("WARNING: Heatmap annotation failed.")
                print(problem)

    def make_surface_plot(self, x_data, y_data, Z, *args, **kwargs):
        """Make a surface plot of Z values on an x-y array

        The plot types "surface", "trisurf" and "scatter3D" are drawn on a 3D axe; all other plot
        types are drawn on a 2D axe.

        :param x_data: list or 1d array of x coordinates
        :param y_data: list or 1d array of y coordinates
        :param Z: NESTED LIST with size = (y_size, x_size) -- [y_size*[x_size elements]]
                  (for "streamplot", Z[1] and Z[2] are used as the x- and y-velocity arrays)
        :param args: not used
        :param kwargs:
              plot_type = STR: "surface" (default), "scatter3D", "trisurf", "contour", "contourf",
                               "pcolormesh", "streamplot"
        :return: -1 if the axe or the coordinate grid could not be created (None otherwise)
        """

        plot_type = kwargs.get("plot_type", "surface")
        projection_type = "3D" if plot_type in ("surface", "trisurf", "scatter3D") else "2D"

        fig = self.create_figure()

        try:
            if projection_type == "2D":
                axe = fig.add_subplot(self.subplot_rows, self.subplot_cols, self.subplot_index)
            else:
                # add_subplot replaces fig.gca(projection="3d"), which recent matplotlib releases
                # no longer accept
                axe = fig.add_subplot(projection="3d")
        except Exception as problem:
            print("ERROR: Could not create axe (add_subplot failed).")
            print(problem)
            return -1

        color_map = self.get_color_map(len(Z))

        try:
            X, Y = np.meshgrid(x_data, y_data)
        except Exception as problem:
            print("ERROR: Could not create np.meshgrid with x_data and y_data.")
            print(problem)
            return -1

        surf = None  # stays None if plotting fails, so the colorbar step can be skipped safely
        try:
            if plot_type == "surface":  # verified
                surf = axe.plot_surface(X, Y, Z, cmap=color_map, linewidth=self.data_line_width, antialiased=False)
            elif plot_type == "contour":  # verified
                surf = axe.contour(X, Y, Z, cmap=color_map, linewidths=self.data_line_width,
                                   linestyles=self.data_line_style[0], alpha=self.alpha_value, antialiased=False)
            elif plot_type == "contourf":  # verified
                surf = axe.contourf(X, Y, Z, self.contour_interval_no, cmap=color_map, alpha=self.alpha_value,
                                    antialiased=False)
            elif plot_type == "pcolormesh":  # not yet verified
                surf = axe.pcolormesh(X, Y, Z, cmap=color_map, linewidth=self.data_line_width, antialiased=False,
                                      alpha=self.alpha_value)
            elif plot_type == "trisurf":  # not yet verified
                # the triangulation works on flat point lists, not on the 2D mesh arrays
                surf = axe.plot_trisurf(np.ravel(X), np.ravel(Y), np.ravel(Z), cmap=color_map,
                                        linewidth=self.data_line_width)
            elif plot_type == "scatter3D":  # not yet verified
                surf = axe.scatter3D(X, Y, Z)
            elif plot_type == "streamplot":  # plot vectors of velocities -- not yet verified
                # Z[1] = 2D array of x-velocites (u)
                # Z[2] = 2D array of y-velocites (v)
                surf = axe.streamplot(X, Y, Z[1], Z[2], cmap=color_map, linewidth=self.data_line_width,
                                      arrowsize=self.stream_arrow_size, arrowstyle=self.stream_arrow_style)
            else:
                print("WARNING: Unknown plot_type (" + str(plot_type) + ") - no data plotted.")
        except Exception as problem:
            print("WARNING: Plotting failed.")
            print(problem)

        if projection_type == "3D":
            axe.zaxis.set_major_formatter(FormatStrFormatter(self.number_format))

        # Labels
        try:
            axe.set_xlabel(self.x_label, **self.hfont)
            axe.set_ylabel(self.y_label, **self.hfont)
            if projection_type == "3D":
                axe.set_zlabel(self.z_label, **self.hfont)
        except Exception as problem:
            print("WARNING: Undefined x, y and/or z axis labels.")
            print(problem)

        if self.legend_active and surf is not None:
            try:
                fig.colorbar(surf, shrink=self.colorbar_shrink, aspect=self.colorbar_aspect)
            except Exception as problem:
                print("WARNING: Could not create colorbar.")
                print(problem)

        self.setup_figure(fig, axe)

    def make_x_y_plot(self, x_data, y_data, *args, **kwargs):
        """ Plot one or more y series against an x series

        Every series in y_data is drawn with the requested plot type. If that fails (typically
        because fewer data_line_style / data_labels entries than graphs are defined), the series
        is drawn as a plain line with a default label instead.

        :param (list) x_data: [x_series]
        :param (nested list) y_data: [[y_series1], [y_series2], ... ]
        :param args: not used
        :param kwargs:
            plot_type = STR: "plot" (default), "bar", "barh", "hist", "scatter"
        :return: -1 if the axe could not be created (None otherwise)
        """

        plot_type = kwargs.get("plot_type", "plot")
        if plot_type not in ("plot", "bar", "barh", "hist", "scatter"):
            print("WARNING: Unknown plot_type (" + str(plot_type) + ") - no data will be plotted.")

        fig = self.create_figure()

        try:
            axe = fig.add_subplot(self.subplot_rows, self.subplot_cols, self.subplot_index)
        except Exception as problem:
            print("ERROR: Could not create axe (add_subplot failed).")
            print(problem)
            return -1

        color_map = self.get_color_map(len(y_data))

        for graph_no, y in enumerate(y_data):
            try:
                # color index is shifted by one because the lowest colormap entry is nearly white
                if plot_type == "plot":  # verified
                    axe.plot(x_data, y, linestyle=self.data_line_style[graph_no], color=color_map(graph_no + 1),
                             label=self.data_labels[graph_no])
                elif plot_type == "bar":  # not yet verified
                    axe.bar(x_data, y, color=self.bar_color, yerr=self.y_err_type, label=self.data_labels[graph_no])
                elif plot_type == "barh":  # horizontal bar plot -- not yet verified
                    axe.barh(x_data, y, color=self.bar_color, xerr=self.x_err_type, label=self.data_labels[graph_no])
                elif plot_type == "hist":  # histogram -- not yet verified
                    # NOTE(review): the histogram is built from x_data, not from y - confirm this is intended
                    axe.hist(x_data, self.hist_class_numbers)
                elif plot_type == "scatter":  # not yet verified
                    axe.scatter(x_data, y, color=color_map(graph_no + 1), label=self.data_labels[graph_no])
            except Exception:
                # best-effort fallback: plain line with default style and label
                try:
                    axe.plot(x_data, y, linestyle="-", color=color_map(graph_no + 1), label="series" + str(graph_no))
                    print(
                        "WARNING: Used default graph labels. Consider setting as many data_line_styles and data_labels as there are graphs.")
                except Exception as problem:
                    print("ERROR: Plotting of graph no. " + str(graph_no) + " failed.")
                    print(problem)

        # Labels
        try:
            axe.set_xlabel(self.x_label, **self.hfont)
            axe.set_ylabel(self.y_label, **self.hfont)
        except Exception as problem:
            print("WARNING: Undefined x and/or y axis labels.")
            print(problem)

        if self.legend_active:
            axe.legend(loc=self.legend_loc, prop=self.font, facecolor=self.legend_face_color,
                       edgecolor=self.legend_edge_color, framealpha=self.legend_frame_alpha, fancybox=0)

        self.setup_figure(fig, axe)

    def save_figure(self, save_fig_dir):
        """ Save figure to disk

        :param matplotlib.pyplot.figure fig: instance of matplotlib.pyplot
        :param str save_fig_dir: directory where to save the figure
        :return:
        """
        try:
            plt.savefig(save_fig_dir, bbox_inches=self.fig_boxes)
            print(" * Saved figure as: " + save_fig_dir)
        except Exception as problem:
            print("WARNING: Could not save figure as path:\n  " + save_fig_dir)
            print("         Hint: .JPG is not supported (use .pdf or .png).")
            print(problem)

    def set_default_parameters(self):
        """ Instantiate font and style definitions

        Assigns the default value of every plot setting (fonts, axes, data layout, colors, legend,
        figure, colorbar and plot-specific options) as an instance attribute. The font objects are
        built right after the font settings are assigned (update_fonts).
        """
        self.font_family = "sans-serif"
        self.font_name = "Arial"
        self.font_size = 10.0
        self.font_style = "normal"
        self.font_weight = "medium"
        self.update_fonts()
        # NOTE(review): "%02f" prints six decimals; "%.2f" (as in colorbar_format) may have been intended
        self.number_format = "%02f"

        # AXES DEFINITIONS
        self.grid_line_color = "gray"
        self.grid_line_style = "-"
        self.grid_line_width = 0.5

        self.x_label = "undefined x-label"
        self.x_lim = (0, 1)
        self.x_lim_mode = False
        self.x_tick_mode = False
        self.x_ticks = []
        self.x_err_type = []  # list of numbers corresponding to no. of x-points

        self.y_label = "undefined y-label"
        self.y_lim = (0, 1)
        self.y_lim_mode = False
        self.y_ticks = []
        self.y_tick_mode = False
        self.y_err_type = []  # list of numbers corresponding to no. of y-points

        self.z_label = "undefined z-label"
        self.z_lim = (0, 1)
        self.z_lim_mode = False
        self.z_ticks = []
        self.z_tick_mode = False

        # DATA LABELLING AND LAYOUT
        self.data_labels = [""]  # list of strings
        self.data_line_style = ["-"]
        self.data_line_width = 0.5
        self.plot_title = "Title"
        self.plot_title_mode = False

        # COLOR DEFINITIONS
        self.alpha_value = 1.0  # FLOAT between 0 and 1
        self.face_color = "w"
        self.edge_color = "k"
        # define colormap -- more options: http://matplotlib.org/users/colormaps.html
        # alternative cmaps: jet, Oranges,inferno, plasma, Greens, Blues, PuRd, RdPu, Reds, Greys
        self.color_map_type = "Greys"
        self.bar_color = "green"

        # LEGEND SETTINGS
        # more legend options: https://matplotlib.org/api/_as_gen/matplotlib.pyplot.legend.html#matplotlib.pyplot.legend
        self.legend_active = False
        self.legend_edge_color = "gray"
        self.legend_face_color = "w"
        self.legend_frame_alpha = 1
        # matplotlib location strings are written with a blank ("upper right"), not an underscore
        self.legend_loc = "upper right"

        # FIGURE CONFIGURATION
        self.width = 6.0  # FLOAT defining inches
        self.height = 4.0  # FLOAT defining inches
        self.resolution = 300  # INT defining dpi
        self.fig_boxes = "tight"  # or INT in inches
        self.show_fig = False

        self.subplot_cols = 1
        self.subplot_rows = 1
        self.subplot_index = 1
        self.subplot_share_x = False
        self.subplot_share_y = False

        # COLORBAR SETTINGS
        # https://matplotlib.org/api/_as_gen/matplotlib.pyplot.colorbar.html#matplotlib.pyplot.colorbar
        self.colorbar_aspect = 10
        self.colorbar_draw_edges = False
        self.colorbar_format = "%.2f"
        self.colorbar_label = ""
        self.colorbar_shrink = 1.0
        self.colorbar_ticks = None  # LIST of ticks
        self.colorbar_min_val = 0
        self.colorbar_max_val = 1

        # PLOT SPECIFIC
        self.contour_interval_no = 10
        self.heatmap_annotation = ""  # STR of format: "{x:.1f} t"
        self.hist_class_numbers = 10  # INT required for histogram plots
        self.stream_arrow_size = 2
        self.stream_arrow_style = "-|>"  # other options: ->, -, -[, <-, <->, <|-|>, ]-, ]-[, |-| (Bar)

    def setup_figure(self, fig, axe):
        """ Figure setup and rendering

        :param fig: instance of matplotlib.plt
        :param axe: instance of matplotlib.plt
        """
        # Ticks
        if self.x_tick_mode:
            try:
                plt.xticks(self.x_ticks, **self.hfont)
            except Exception as problem:
                print("WARNING: x_tick_mode active but no valid x_ticks and/or hfont set.")
                print(problem)
        if self.y_tick_mode:
            try:
                plt.yticks(self.y_ticks, **self.hfont)
            except Exception as problem:
                print("WARNING: y_tick_mode active but no valid y_ticks and/or hfont set.")
                print(problem)
        if self.z_tick_mode:
            try:
                plt.yticks(self.z_ticks, **self.hfont)
            except Exception as problem:
                print("WARNING: z_tick_mode active but no valid z_ticks and/or hfont set.")
                print(problem)

                # control grid
        axe.grid(color=self.grid_line_color, linestyle=self.grid_line_style, linewidth=self.grid_line_width)
        if self.x_lim_mode:
            try:
                axe.set_xlim(self.x_lim)
            except Exception as problem:
                print("WARNING: x_lim_mode active but no valid x_lim (tuple= (min, max)) defined.")
                print(problem)
        if self.y_lim_mode:
            try:
                axe.set_ylim(self.y_lim)
            except Exception as problem:
                print("WARNING: y_lim_mode active but no valid y_lim (tuple= (min, max)) defined.")
                print(problem)
        if self.z_lim_mode:
            try:
                axe.set_zlim(self.z_lim)
            except Exception as problem:
                print("WARNING: z_lim_mode active but no valid z_lim (tuple= (min, max)) defined.")
                print(problem)

        if self.plot_title_mode:
            axe.set_title(self.plot_title)

        if self.show_fig:
            fig.show()
            plt.show()

    def update_fonts(self):
        """ Rebuild the font objects from the current font settings

        Refreshes self.hfont (keyword dictionary for text elements), self.font (FontProperties
        instance, e.g., for legends) and the global matplotlib rcParams font entries.
        Make sure the requested font is installed on your system.
        More about font settings at https://matplotlib.org/users/customizing.html
        """
        font_settings = {
            "family": self.font_family,
            "weight": self.font_weight,
            "size": self.font_size,
            "style": self.font_style,
            "fontname": self.font_name,
        }
        self.hfont = font_settings
        self.font = font_manager.FontProperties(
            family=font_settings["fontname"],
            weight=font_settings["weight"],
            style=font_settings["style"],
            size=font_settings["size"],
        )
        # one update call; the entries are applied in the listed order
        matplotlib.rcParams.update({
            "font.family": self.font_family,
            "font.weight": self.font_weight,
            "font.size": self.font_size,
            "font.style": self.font_style,
            "font.sans-serif": self.font_name,
            "font.serif": "Times",
        })

    def __call__(self):
        print("Class Info: <type> = Plotter (uses matplotlib library)")