import pydoc
from plotly import express as px
from plotly import graph_objects as go
import plotly.io as pio
from plotly.subplots import make_subplots
from typing import List, Union, Tuple
import numpy as np
import shap

# Import-time side effect: set plotly's process-wide default renderer.
# NOTE(review): the "+"-joined value presumably enables several renderers at
# once (iframe, classic notebook, JupyterLab) - confirm against plotly docs.
pio.renderers.default = "iframe+notebook+jupyterlab"


def override_help(func):
    """Build a decorator that replaces a callable's docstring with help on ``func``.

    The returned decorator overwrites ``__doc__`` of the callable it is applied
    to with a short banner naming ``func`` followed by the pydoc plain-text
    rendering of ``func``. The decorated callable itself is returned unchanged
    (no wrapper function is created).

    Args:
        func: The callable whose name and pydoc help text are embedded.

    Returns:
        A decorator taking a callable and returning that same callable with
        its ``__doc__`` replaced.
    """

    def decorator(f):
        # Name lookup and rendering happen when the decorator is applied.
        wrapped_name = func.__name__
        rendered_help = pydoc.text.document(func)
        banner = "This method wraps the method {}. Below is help on {}:\n\n{}"
        f.__doc__ = banner.format(wrapped_name, wrapped_name, rendered_help)
        return f

    return decorator


def plot_global_pdp(
    explanation,
    plot_width: int = 1000,
    plot_height: int = None,
    plot_bgcolor: str = "white",
    plot_linecolor: str = "firebrick",
    plot_linewidth: int = 4,
    plot_lineopacity: float = 0.3,
    plot_quantile_width: float = 0.5,
    plot_quantile_color: str = "skyblue",
    plot_min_linecolor: str = "orange",
    plot_max_linecolor: str = "darkred",
    plot_grid: bool = False,
    plot_grid_color: str = "gray",
    plot_hist_color: str = "gray",
    renderer="notebook",
):
    """Draw one partial-dependence subplot per feature, stacked vertically.

    ``explanation["data"]`` is read as a mapping from feature name to a
    mapping with the keys ``feature_value``, ``pdp_upper``, ``pdp_lower``,
    ``pdp_mean``, ``pdp_max`` and ``pdp_min`` and, optionally,
    ``feature_dist``. Each subplot shows the five curves on the primary
    y-axis; when ``feature_dist`` is present a percent-normalised histogram of
    it is drawn on the secondary y-axis.

    Args:
        explanation: Mapping with a ``"data"`` entry as described above.
        plot_width: Figure width.
        plot_height: Figure height; ``None`` means 250 per feature.
        plot_bgcolor: Plot background colour.
        plot_linecolor: Colour of the mean PDP line.
        plot_linewidth: Width of the mean PDP line.
        plot_lineopacity: Opacity of the upper/lower quantile traces.
        plot_quantile_width: Line width of the quantile, minimum and maximum
            traces.
        plot_quantile_color: Colour of the upper/lower quantile traces.
        plot_min_linecolor: Colour of the minimum line.
        plot_max_linecolor: Colour of the maximum line.
        plot_grid: Whether grid lines are shown on the x and primary y axes.
        plot_grid_color: Grid line colour.
        plot_hist_color: Histogram marker colour.
        renderer: Accepted but not used by this function.

    Returns:
        The assembled plotly figure.
    """
    exp_data = explanation["data"]
    feature_names = tuple(exp_data.keys())
    n_plots = len(feature_names)

    fig = make_subplots(
        rows=n_plots,
        cols=1,
        specs=[[{"secondary_y": True}] for _ in range(n_plots)],
        subplot_titles=feature_names,
    )

    hover = "<b>%{y:.2f}</b><br>"

    for row, feat in enumerate(feature_names, start=1):
        feat_data = exp_data[feat]
        # Legend entries are only emitted once, by the first subplot.
        in_legend = row == 1

        # (data key, legend name, trace-specific styling). Order matters:
        # "Lower Quantile" fills towards the trace added just before it.
        curves = (
            (
                "pdp_upper",
                "Upper Quantile",
                dict(
                    mode="lines",
                    opacity=plot_lineopacity,
                    line=dict(color=plot_quantile_color, width=plot_quantile_width),
                ),
            ),
            (
                "pdp_lower",
                "Lower Quantile",
                dict(
                    fill="tonexty",
                    mode="lines",
                    opacity=plot_lineopacity,
                    line=dict(color=plot_quantile_color, width=plot_quantile_width),
                ),
            ),
            (
                "pdp_mean",
                "PDP",
                dict(line=dict(color=plot_linecolor, width=plot_linewidth)),
            ),
            (
                "pdp_max",
                "Maximum",
                dict(
                    mode="lines",
                    line=dict(color=plot_max_linecolor, width=plot_quantile_width),
                ),
            ),
            (
                "pdp_min",
                "Minimum",
                dict(
                    mode="lines",
                    line=dict(color=plot_min_linecolor, width=plot_quantile_width),
                ),
            ),
        )

        for data_key, legend_name, styling in curves:
            fig.add_trace(
                go.Scatter(
                    x=feat_data["feature_value"],
                    y=feat_data[data_key],
                    name=legend_name,
                    showlegend=in_legend,
                    hovertemplate=hover,
                    **styling,
                ),
                secondary_y=False,
                row=row,
                col=1,
            )

        if "feature_dist" in feat_data:
            fig.add_trace(
                go.Histogram(
                    x=feat_data["feature_dist"],
                    marker_color=plot_hist_color,
                    histnorm="percent",
                    name="feature distribution",
                    showlegend=in_legend,
                ),
                secondary_y=True,
                row=row,
                col=1,
            )

    fig.update_layout(
        width=plot_width,
        height=n_plots * 250 if plot_height is None else plot_height,
        title_text="Partial Dependency Plots",
        xaxis_title="PDP Plot",
        plot_bgcolor=plot_bgcolor,
        hovermode="x unified",
    )

    grid_style = dict(showgrid=plot_grid, gridwidth=0.1, gridcolor=plot_grid_color)
    fig.update_xaxes(**grid_style)
    fig.update_yaxes(
        title_text="<b>Partial Dependency Plot</b>",
        secondary_y=False,
        **grid_style,
    )
    # The secondary axis range is fixed at 0-500 with ticks only up to 100.
    # NOTE(review): presumably this keeps the percent histogram in the bottom
    # part of each subplot - confirm intent.
    fig.update_yaxes(
        title_text="<b>Feature distribution</b>",
        range=[0, 500],
        tickvals=[0, 50, 100],
        secondary_y=True,
    )
    return fig


def plot_global_shap(
    explanation,
    top_k: int = 8,
    plot_type: str = "scatter",
    plot_width: int = 1000,
    plot_height: int = 150,
    plot_bgcolor: str = "white",
    plot_color: str = "tomato",
    renderer="notebook",
    **kwargs,
):
    """Plot global SHAP contributions as a scatter grid or a SHAP force plot.

    ``explanation["data"]`` is read for the keys ``feature_names``,
    ``feature_values`` and ``shap_values`` (and ``model_average`` for the
    force plot). ``feature_values`` and ``shap_values`` are indexed as
    two-dimensional arrays whose columns line up with ``feature_names``.

    For ``plot_type="scatter"`` the features are ranked by the absolute value
    of their mean contribution, the ``top_k`` highest are kept, and each is
    drawn as feature value (x) against contribution (y) in a grid of at most
    three columns.

    Args:
        explanation: Mapping with a ``"data"`` entry as described above.
        top_k: Maximum number of features shown in the scatter grid.
        plot_type: ``"scatter"`` or ``"forceplot"``.
        plot_width: Scatter figure width.
        plot_height: Height of a single scatter grid row; the figure height
            is this value times the number of rows.
        plot_bgcolor: Scatter plot background colour.
        plot_color: Scatter marker colour.
        renderer: Accepted but not used by this function.
        **kwargs: Accepted but not used by this function.

    Returns:
        A plotly figure for ``"scatter"``, or whatever ``shap.force_plot``
        returns for ``"forceplot"``.

    Raises:
        ValueError: If ``plot_type`` is not supported, or if the scatter grid
            would contain no feature (``top_k < 1`` or no feature names).
    """
    if plot_type not in ("scatter", "forceplot"):
        # Fail loudly instead of silently returning None.
        raise ValueError(
            f"Unsupported plot_type {plot_type!r}; expected 'scatter' or 'forceplot'."
        )

    exp_data = explanation["data"]

    if plot_type == "forceplot":
        # The force plot needs none of the grid layout computed below.
        return shap.force_plot(
            exp_data["model_average"],
            np.asarray(exp_data["shap_values"]),
            np.asarray(exp_data["feature_values"]),
            exp_data["feature_names"],
        )

    feat_names = list(exp_data["feature_names"])
    feat_values = np.asarray(exp_data["feature_values"])
    feat_contrib = np.asarray(exp_data["shap_values"])

    n_feature = min(len(feat_names), top_k)
    if n_feature < 1:
        # Guard the division below and make_subplots against an empty grid.
        raise ValueError(
            "Nothing to plot: need at least one feature and top_k >= 1 "
            f"(got {len(feat_names)} feature(s), top_k={top_k})."
        )

    # Rank by |mean contribution| and keep only the top-k column indices.
    features_avg_contrib = feat_contrib.mean(axis=0)
    ordered_features = np.argsort(np.abs(features_avg_contrib))[::-1][:n_feature]

    # Size the grid from the number of features actually plotted, so no
    # empty trailing rows are created when top_k < number of features.
    n_columns = min(n_feature, 3)
    n_rows = int(np.ceil(n_feature / n_columns))

    ordered_feature_names = [
        feat_names[f_id].replace("_", " ").upper() for f_id in ordered_features
    ]

    # Cells past the last plotted feature are left empty (None).
    spec = [
        [
            {"type": "xy"} if i * n_columns + j < n_feature else None
            for j in range(n_columns)
        ]
        for i in range(n_rows)
    ]

    fig = make_subplots(
        rows=n_rows,
        cols=n_columns,
        specs=spec,
        subplot_titles=ordered_feature_names,
    )

    for rank, feat_idx in enumerate(ordered_features):
        # Fill the grid row by row; use the column index directly so that
        # duplicate feature names cannot select the wrong column.
        grid_row, grid_col = divmod(rank, n_columns)
        fig.add_trace(
            go.Scatter(
                x=feat_values[:, feat_idx],
                y=feat_contrib[:, feat_idx],
                mode="markers",
                marker_color=plot_color,
            ),
            row=grid_row + 1,
            col=grid_col + 1,
        )

    fig.update_layout(
        showlegend=False,
        width=plot_width,
        height=n_rows * plot_height,
        plot_bgcolor=plot_bgcolor,
    )
    return fig


def plot_local_groupedce(
    explanation,
    plot_width: int = 250,
    plot_height: int = 250,
    plot_bgcolor: str = "white",
    plot_line_width: int = 2,
    plot_instance_size: int = 15,
    plot_instance_color: str = "firebrick",
    plot_instance_width: int = 4,
    plot_contour_coloring: str = "heatmap",
    plot_contour_color: Union[str, List[Tuple[float, str]]] = "Portland",
    renderer="notebook",
    **kwargs,
):
    """Draw a triangular grid of pairwise contour plots, one per feature pair.

    ``explanation["data"]`` is read as a nested mapping
    ``data[x_feat][y_feat]`` holding ``values``, ``x_grid`` and ``y_grid``
    and, optionally, ``current_values`` (a mapping with an entry for each of
    the two features). Features are ordered by how many entries their own
    mapping has (ascending); a pair is looked up under the later feature in
    that order, keyed by the earlier one.

    Each pair is drawn as a contour of ``values`` over the two grids, all
    sharing one colour axis. When ``current_values`` is present, an "x"
    marker is added at that point.

    Args:
        explanation: Mapping with a ``"data"`` entry as described above.
        plot_width: Width of a single subplot; the figure width is this value
            times the number of grid columns.
        plot_height: Height of a single subplot, scaled likewise by rows.
        plot_bgcolor: Plot background colour.
        plot_line_width: Contour line width.
        plot_instance_size: Size of the current-instance marker.
        plot_instance_color: Fill and outline colour of that marker.
        plot_instance_width: Outline width of that marker.
        plot_contour_coloring: Passed as the contour ``contours_coloring``.
        plot_contour_color: Colour scale applied to the shared colour axis.
        renderer: Accepted but not used by this function.
        **kwargs: Accepted but not used by this function.

    Returns:
        The assembled plotly figure.
    """
    exp_data = explanation["data"]
    pair_counts = {name: len(exp_data[name].keys()) for name in exp_data.keys()}
    features = sorted(pair_counts, key=pair_counts.get)
    n_feat = len(features)
    grid_size = n_feat - 1

    # Only cells on or above the diagonal hold a subplot.
    specs = [
        [None if row > col else {} for col in range(grid_size)]
        for row in range(grid_size)
    ]

    fig = make_subplots(
        rows=grid_size,
        cols=grid_size,
        specs=specs,
        shared_xaxes="columns",
        shared_yaxes="rows",
        column_titles=features[1:],
        row_titles=features[:-1],
    )

    for x_i in range(1, n_feat):
        x_feat = features[x_i]
        for y_i in range(x_i):
            y_feat = features[y_i]
            pair = exp_data[x_feat][y_feat]
            # x feature k maps to column k, y feature k to row k + 1 (1-based).
            cell = dict(row=y_i + 1, col=x_i)

            z = pair["values"]
            x_g = pair["x_grid"]
            y_g = pair["y_grid"]
            contour_hover = (
                "<b>"
                + str(x_feat)
                + "</b>: %{x:.2f}<br>"
                + "<b>"
                + str(y_feat)
                + "</b>: %{y:.2f}<br>"
                + "<b>prediction</b>: %{z:.2f}<br><extra></extra>"
            )
            fig.add_trace(
                go.Contour(
                    z=z,
                    x=x_g,
                    y=y_g,
                    connectgaps=True,
                    line_smoothing=0.5,
                    contours_coloring=plot_contour_coloring,
                    contours_showlabels=True,
                    line_width=plot_line_width,
                    coloraxis="coloraxis1",
                    hovertemplate=contour_hover,
                ),
                **cell,
            )

            if "current_values" not in pair:
                continue

            current = pair["current_values"]
            x = current[x_feat]
            y = current[y_feat]
            marker_hover = "{}: {:.2f}<br> {}: {:.2f}<extra></extra>".format(
                x_feat, x, y_feat, y
            )
            fig.add_trace(
                go.Scatter(
                    mode="markers",
                    marker_symbol="x",
                    x=[x],
                    y=[y],
                    marker_color=plot_instance_color,
                    marker_line_color=plot_instance_color,
                    marker_size=plot_instance_size,
                    marker_line_width=plot_instance_width,
                    showlegend=False,
                    hovertemplate=marker_hover,
                ),
                **cell,
            )

    fig.update_layout(
        height=grid_size * plot_height,
        width=grid_size * plot_width,
        plot_bgcolor=plot_bgcolor,
        coloraxis_autocolorscale=False,
        coloraxis_colorscale=plot_contour_color,
        title_text="Grouped Conditional Expectation Plots",
    )
    return fig
