import os
import json
import shap
import copy
import requests
from typing import List, Literal
from enum import Enum
import numpy as np
import pandas as pd
import aiohttp
from aiohttp import ClientTimeout
import asyncio
import nest_asyncio
from functools import partial
from dateutil.parser import parse
from aiohttp.client_exceptions import ServerDisconnectedError
from plotly import express as px
from plotly import graph_objects as go
import plotly.io as pio
from aiexpts.plot_utils import (
    override_help,
    plot_global_pdp,
    plot_local_groupedce,
    plot_global_shap,
)

# Configure plotly's default renderer at import time. This is a process-wide
# side effect: it applies to every figure displayed afterwards without an
# explicit renderer. Multiple renderer names are combined with "+".
pio.renderers.default = "iframe+notebook+jupyterlab"


class Explanation_Scope(Enum):
    """Scope of the explanations an explainer produces.

    The integer ``.value`` of a member is what gets sent as
    ``explanation_scope`` in service requests.
    """

    LOCAL = 0  # local explanations only
    GLOBAL = 1  # global explanations only
    LOCAL_AND_GLOBAL = 2  # both local and global explanations


class Explanation_Type(Enum):
    """Kind of explainer run, sent as ``explanation_type`` to the service.

    ``AIXTSClient.run_explainers`` validates the timestamps it is given
    according to this kind.
    """

    SINGLE_INSTANCE = 0  # one timestamp: from_ts and to_ts identical
    BATCH = 1  # a from_ts..to_ts range in chronological order
    FULL = 2  # the whole dataset; timestamps may be None


class AIXTSClient(object):
    """Client for the AIX explainability service.

    Requests go to ``<url>[:<port>]/ibm/aix/service/v1``. Simple queries use
    ``requests``; the long-running jobs (explainer setup, payload logging,
    explanation logging, explainer runs) and the health check use a shared
    ``aiohttp`` session driven on the current asyncio event loop. TLS
    certificate verification is disabled for both.

    Methods do not raise on HTTP errors. Instead they return the error body:
    parsed JSON for ``application/json``, text for ``text/html``, and
    otherwise the raw response object.
    """

    def __init__(self, url: str, port: int = 5001, auth: dict = None):
        """Create a client and open its aiohttp session.

        Args:
            url: Base URL of the service (scheme and host). The port and the
                API prefix are appended to it verbatim.
            port: Appended as ``:<port>``. Pass ``None`` to append nothing.
            auth: Mapping with ``APM_ID`` and ``APM_API_KEY``, which are sent
                as request headers. A warning is printed if either is missing
                or empty. ``None`` is treated as an empty mapping.
        """
        auth = {} if auth is None else auth
        prefix = "/ibm/aix/service/v1"
        self._service_url = url
        if port is not None:
            self._service_url += ":" + str(port) + prefix
        else:
            self._service_url += prefix

        if not auth.get("APM_ID", "") or not auth.get("APM_API_KEY", ""):
            print(
                "Please include APM_ID and APM_API_KEY in auth object in init parameters."
            )

        self.headers = {
            "APM_ID": auth.get("APM_ID", ""),
            "APM_API_KEY": auth.get("APM_API_KEY", ""),
        }
        self.connect()
        # Allow loop.run_until_complete() to be called while an event loop is
        # already running (e.g. inside a notebook).
        nest_asyncio.apply()

    def connect(self):
        """Start a new aiohttp session that sends ``self.headers`` on every request.

        TLS certificate verification is disabled for this session.
        """
        # ``ssl=False`` replaces the deprecated ``verify_ssl=False`` argument.
        conn = aiohttp.TCPConnector(ssl=False)
        self.session = aiohttp.ClientSession(connector=conn, headers=self.headers)

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _enum_value(value):
        """Return ``value.value`` for Enum members and ``value`` unchanged otherwise.

        This lets callers pass either ``Explanation_Scope.LOCAL`` or
        ``Explanation_Scope.LOCAL.value``. Enum members are not JSON
        serializable, so they must be unwrapped before being sent.
        """
        return value.value if isinstance(value, Enum) else value

    @staticmethod
    def _error_body(resp):
        """Decode the body of a failed ``requests`` response.

        Returns parsed JSON for ``application/json``, the text for
        ``text/html``, and the response object itself otherwise. Media-type
        parameters such as ``; charset=utf-8`` are ignored, and a missing
        ``Content-Type`` header is tolerated.
        """
        content_type = resp.headers.get("content-type", "")
        media_type = content_type.split(";", 1)[0].strip().lower()
        if media_type == "application/json":
            return resp.json()
        if media_type == "text/html":
            return resp.text
        return resp

    @staticmethod
    async def _async_error_body(resp):
        """Decode the body of a failed aiohttp response.

        Returns parsed JSON for ``application/json``, the text for
        ``text/html``, and the response object itself otherwise.
        """
        if resp.content_type == "application/json":
            return await resp.json()
        if resp.content_type == "text/html":
            return await resp.text()
        return resp

    async def _post_job(self, endpoint, request, headers=None):
        """POST ``request`` as JSON to ``endpoint`` with no client-side timeout.

        Returns:
            ``"SUCCESS"`` on HTTP 200, and the decoded error body otherwise.
            A dropped server connection is reported as ``{"message": ...}``.
            Any other exception is printed and returned as
            ``{"error": exception}``.
        """
        try:
            resp = await self.session.post(
                self._service_url + endpoint,
                json=request,
                timeout=ClientTimeout(total=None),
                headers=headers,
            )
            if resp.status == 200:
                return "SUCCESS"
            return await self._async_error_body(resp)
        except ServerDisconnectedError:
            msg = "Request is taking time. The job is running in background."
            print(msg)
            return {"message": msg}
        except Exception as ex:
            print("error {}".format(ex))
            return {"error": ex}

    def _dispatch(self, job, request, async_mode, background_message):
        """Run the coroutine ``job(request)`` on the current event loop.

        With ``async_mode`` the coroutine is scheduled as a task and
        ``(background_message, 202)`` is returned immediately. Otherwise the
        call blocks and returns the coroutine's result.
        """
        loop = asyncio.get_event_loop()
        coro = job(request)
        if not async_mode:
            return loop.run_until_complete(coro)
        task = loop.create_task(coro)
        # The event loop keeps only weak references to tasks, so hold a strong
        # reference until the task finishes. The set is created lazily so this
        # also works if __init__ was bypassed.
        pending = vars(self).setdefault("_background_tasks", set())
        pending.add(task)
        task.add_done_callback(pending.discard)
        return (background_message, 202)

    # ------------------------------------------------------------------ #
    # Service endpoints
    # ------------------------------------------------------------------ #

    async def _get_service_health(self):
        """Fetch ``/health`` and return the service version and model count.

        Returns:
            ``{"service_version": ..., "num_configured_models": ...}`` on
            HTTP 200, the decoded error body on other statuses, or
            ``{"error": exception}`` if the request fails.
        """
        try:
            resp = await self.session.get(
                self._service_url + "/health", timeout=ClientTimeout(total=None)
            )
            if resp.status == 200:
                res_json = await resp.json()
                return {
                    "service_version": res_json["service_version"],
                    "num_configured_models": res_json["num_configured_models"],
                }
            return await self._async_error_body(resp)
        except Exception as ex:
            print("error {}".format(ex))
            return {"error": ex}

    def get_service_health(self):
        """Blocking wrapper around ``_get_service_health``."""
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(self._get_service_health())

    def get_available_explainers(self):
        """Return the ``available_explainers`` entry reported by the service.

        Per the service, each entry describes an explainer: its name, its type
        (local vs global) and its parameter keys. On an HTTP error the decoded
        error body is returned instead.
        """
        # TODO: a generic way to fetch all classes and names from the
        # explainers module.
        resp = requests.get(
            self._service_url + "/available_explainers",
            verify=False,
            headers=self.headers,
        )
        if resp.ok:
            return resp.json()["available_explainers"]
        return self._error_body(resp)

    async def _set_explainer(self, request):
        """POST ``request`` to ``/set_explainer`` (see ``_post_job``)."""
        return await self._post_job("/set_explainer", request)

    def set_explainer(
        self,
        model_deployment_id: str,
        space_id: str,
        explainer_name: str,
        feature_columns: list = None,
        prediction_columns: list = None,
        explainer_args: dict = None,
        model_language: Literal["python", "spss"] = "python",
        explainer_language: Literal["python", "spss"] = "python",
        predict_proba_column: str = "predict_proba",
        time_column="local_ts",
        time_format="%Y-%m-%d %H:%M:%S",
        group_columns=None,
        training_data=None,
        training_data_columns=None,
        explanation_scope=Explanation_Scope.LOCAL.value,
        hardware_spec_name="S",
        async_mode=False,
    ):
        """Configure an explainer for a model deployment.

        The explainer instance is initialized with the provided arguments and
        configured to explain the given model deployment. All arguments are
        forwarded in the JSON body of a ``/set_explainer`` request.
        ``explanation_scope`` may be an ``Explanation_Scope`` member or its
        integer ``.value``.

        Returns:
            With ``async_mode``, a ``(message, 202)`` tuple returned
            immediately while the request runs on the event loop. Otherwise
            ``"SUCCESS"`` or the error description from ``_post_job``.
        """
        request = {
            "model_deployment_id": model_deployment_id,
            "space_id": space_id,
            "model_language": model_language,
            "explainer_language": explainer_language,
            "feature_columns": feature_columns,
            "prediction_columns": prediction_columns,
            "predict_proba_column": predict_proba_column,
            "explainer_name": explainer_name,
            "explainer_args": explainer_args,
            "time_column": time_column,
            "time_format": time_format,
            "group_columns": group_columns,
            "training_data": training_data,
            "training_data_columns": training_data_columns,
            "explanation_scope": self._enum_value(explanation_scope),
            "hardware_spec_name": hardware_spec_name,
            "async_mode": async_mode,
        }
        return self._dispatch(
            self._set_explainer,
            request,
            async_mode,
            "running in background.. please check back approximately after 3-5 minutes.",
        )

    def get_configured_explainers(self, model_deployment_id: str, space_id: str):
        """Return the explainers configured for the given model deployment.

        On an HTTP error the decoded error body is returned instead.
        """
        request = {"model_deployment_id": model_deployment_id, "space_id": space_id}
        response = requests.get(
            self._service_url + "/configured_explainers",
            params=request,
            verify=False,
            headers=self.headers,
        )
        if response.ok:
            return response.json()["configured_explainers"]
        return self._error_body(response)

    def get_configured_models(self):
        """Return the models configured for explainability.

        On an HTTP error the decoded error body is returned instead.
        """
        resp = requests.get(
            self._service_url + "/configured_models", verify=False, headers=self.headers
        )
        if resp.ok:
            return resp.json()["configured_models"]
        return self._error_body(resp)

    def delete_explainers(self, model_deployment_id: str, space_id: str):
        """Delete all explainers configured for a model deployment.

        This would also remove the model from the configured list or
        subscriptions.

        Raises:
            NotImplementedError: Always; this operation is not implemented yet.
        """
        raise NotImplementedError

    def delete_explainer(
        self, model_deployment_id: str, space_id: str, explainer_name: str
    ):
        """Disable one explainer for a model deployment.

        Returns:
            ``"SUCCESS"`` on a 2xx response, otherwise the decoded error body.
        """
        request = {
            "model_deployment_id": model_deployment_id,
            "space_id": space_id,
            "explainer_name": explainer_name,
        }
        resp = requests.post(
            self._service_url + "/delete_explainer",
            json=request,
            verify=False,
            headers=self.headers,
        )
        if resp.ok:
            return "SUCCESS"
        return self._error_body(resp)

    def delete_model(self, model_deployment_id: str, space_id: str):
        """Disable a model deployment.

        Returns:
            ``"SUCCESS"`` on a 2xx response, otherwise the decoded error body.
        """
        request = {"model_deployment_id": model_deployment_id, "space_id": space_id}
        resp = requests.post(
            self._service_url + "/delete_model",
            json=request,
            verify=False,
            headers=self.headers,
        )
        if resp.ok:
            return "SUCCESS"
        return self._error_body(resp)

    def delete_explanations(
        self,
        model_deployment_id,
        space_id,
        explainer_name,
        from_ts=None,
        to_ts=None,
        explanation_scope=Explanation_Scope.LOCAL.value,
    ):
        """Delete explanations for a model and explainer over a time range.

        ``from_ts`` defaults to ``""`` and ``to_ts`` defaults to ``from_ts``.
        ``explanation_scope`` may be an ``Explanation_Scope`` member or its
        integer ``.value``.

        Returns:
            ``"SUCCESS"`` on a 2xx response, otherwise the decoded error body.
        """
        from_ts = "" if from_ts is None else from_ts
        request = {
            "model_deployment_id": model_deployment_id,
            "space_id": space_id,
            "explainer_name": explainer_name,
            "from_ts": from_ts,
            "to_ts": from_ts if to_ts is None else to_ts,
            "explanation_scope": self._enum_value(explanation_scope),
        }
        resp = requests.post(
            self._service_url + "/delete_explanations",
            json=request,
            verify=False,
            headers=self.headers,
        )
        if resp.ok:
            return "SUCCESS"
        return self._error_body(resp)

    async def _payload_logging(self, request):
        """POST ``request`` to ``/payload_logging`` (see ``_post_job``)."""
        return await self._post_job(
            "/payload_logging", request, headers=self.headers
        )

    def payload_logging(
        self,
        model_deployment_id: str,
        space_id: str,
        data: list,
        fields: list = [],
        predictions: list = [],
        pred_fields: list = [],
        async_mode=False,
    ):
        """Log a scoring request and its predictions.

        The list defaults are never mutated here; they are only serialized.

        Returns:
            ``None`` (after printing a message) if the first prediction row
            has fewer entries than ``pred_fields``. With ``async_mode``,
            otherwise a ``(message, 202)`` tuple. Without it, ``"SUCCESS"`` or
            the error description from ``_post_job``.
        """
        if (
            pred_fields is not None
            and predictions is not None
            and len(pred_fields) > 0
            and len(predictions) > 0
            and len(predictions[0]) < len(pred_fields)
        ):
            print("Shapes of predictions and prediction fields should match.")
            return

        request = {
            "model_deployment_id": model_deployment_id,
            "space_id": space_id,
            "data": data,
            "fields": fields,
            "predictions": predictions,
            "pred_fields": pred_fields,
            "async_mode": async_mode,
        }
        return self._dispatch(
            self._payload_logging,
            request,
            async_mode,
            "running in background.. please check back approximately after 5-10 minutes. Will take longer for large payload.",
        )

    async def _log_explanations(self, request):
        """POST ``request`` to ``/log_explanations`` (see ``_post_job``)."""
        return await self._post_job(
            "/log_explanations", request, headers=self.headers
        )

    def log_explanations(
        self,
        model_deployment_id: str,
        space_id: str,
        explainer_name: str,
        explanations_list: List[dict],
        from_timestamps_list: List[str],
        groups_list: List[dict] = None,
        to_timestamps_list: List[str] = None,
        async_mode=False,
    ):
        """Log precomputed explanations for an explainer.

        ``groups_list`` and ``to_timestamps_list`` default to empty lists.

        Returns:
            With ``async_mode``, a ``(message, 202)`` tuple. Otherwise
            ``"SUCCESS"`` or the error description from ``_post_job``.
        """
        request = {
            "model_deployment_id": model_deployment_id,
            "space_id": space_id,
            "explainer_name": explainer_name,
            "explanations_list": explanations_list,
            "from_timestamps_list": from_timestamps_list,
            "groups_list": [] if groups_list is None else groups_list,
            "to_timestamps_list": [] if to_timestamps_list is None else to_timestamps_list,
            "async_mode": async_mode,
        }
        return self._dispatch(
            self._log_explanations,
            request,
            async_mode,
            "running in background.. please check back approximately after 5-10 minutes. Will take longer for large set of explanations.",
        )

    async def _run_explainers(self, request):
        """POST ``request`` to ``/run_explainers`` (see ``_post_job``)."""
        return await self._post_job(
            "/run_explainers", request, headers=self.headers
        )

    def run_explainers(
        self,
        model_deployment_id,
        space_id,
        from_ts=None,
        to_ts=None,
        groups={},
        explanation_type=Explanation_Type.SINGLE_INSTANCE.value,
        async_mode=False,
    ):
        """Ask the service to run the configured explainers over a time range.

        ``explanation_type`` may be an ``Explanation_Type`` member or its
        integer ``.value``. The timestamps are validated by type:

        * ``SINGLE_INSTANCE``: at least one of ``from_ts`` and ``to_ts`` must
          be given. If both are given they must be equal. A missing one is
          filled from the other.
        * ``BATCH``: both are required, and ``from_ts`` must not be later than
          ``to_ts`` (both are parsed with ``dateutil.parser.parse``).
        * Any other type (e.g. ``FULL``, the entire dataset): no checks; the
          values are sent as given.

        If validation fails, a message is printed and ``None`` is returned.
        The ``groups`` default is never mutated here.

        Returns:
            With ``async_mode``, a ``(message, 202)`` tuple. Otherwise
            ``"SUCCESS"`` or the error description from ``_post_job``.
        """
        explanation_type = self._enum_value(explanation_type)

        # Compare with ``==``: ``is`` on ints only worked by accident of
        # CPython's small-int cache.
        if explanation_type == Explanation_Type.SINGLE_INSTANCE.value:
            if (from_ts is None and to_ts is None) or (
                from_ts is not None and to_ts is not None and from_ts != to_ts
            ):
                print(
                    "from and to timestamps should be same or atleast one should have a value for single instance explanation type."
                )
                return
            from_ts = from_ts or to_ts
            to_ts = to_ts or from_ts
        elif explanation_type == Explanation_Type.BATCH.value:
            if from_ts is None or to_ts is None:
                print("from and to timestamps should be provided")
                return
            if parse(from_ts) > parse(to_ts):
                print("from_ts should not be greater than to_ts")
                return

        request = {
            "model_deployment_id": model_deployment_id,
            "space_id": space_id,
            "from_ts": from_ts,
            "to_ts": to_ts,
            "groups": groups,
            "explanation_type": explanation_type,
            "async_mode": async_mode,
        }
        return self._dispatch(
            self._run_explainers,
            request,
            async_mode,
            "running in background.. please check back approximately after 5-10 minutes. Will take longer for larger batch size.",
        )

    def get_explanations(
        self, model_deployment_id, space_id, from_ts=None, to_ts=None, groups=None
    ):
        """Return explanations from all configured explainers for a time range.

        ``to_ts`` defaults to ``from_ts``. ``groups`` is sent JSON-encoded as
        a query parameter. On an HTTP error the decoded error body is returned
        instead.
        """
        request = {
            "model_deployment_id": model_deployment_id,
            "space_id": space_id,
            "from_ts": from_ts,
            "to_ts": from_ts if to_ts is None else to_ts,
            "groups": json.dumps(groups),
        }
        resp = requests.get(
            self._service_url + "/explanations",
            params=request,
            verify=False,
            headers=self.headers,
        )
        if resp.ok:
            return resp.json()["explanations"]
        return self._error_body(resp)

    def get_payload(
        self, model_deployment_id, space_id, from_ts, to_ts=None, groups={}
    ):
        """Return logged payload data for a time range.

        ``to_ts`` defaults to ``from_ts``. ``groups`` is sent JSON-encoded as
        a query parameter; its default is never mutated. On an HTTP error the
        decoded error body is returned instead.
        """
        request = {
            "model_deployment_id": model_deployment_id,
            "space_id": space_id,
            "from_ts": from_ts,
            "to_ts": from_ts if to_ts is None else to_ts,
            "groups": json.dumps(groups),
        }
        resp = requests.get(
            self._service_url + "/payload",
            params=request,
            verify=False,
            headers=self.headers,
        )
        if resp.ok:
            return resp.json()["payload"]
        return self._error_body(resp)

    def _return_result(self, future, result):
        """Return ``result.result()`` if ``result`` finished without an exception, else ``None``.

        ``future`` is unused.
        """
        if result.exception() is None:
            return result.result()
        return None

    async def _disconnect(self):
        """Close the aiohttp session."""
        await self.session.close()

    def disconnect(self):
        """Blocking wrapper around ``_disconnect``."""
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self._disconnect())

    # ------------------------------------------------------------------ #
    # Plotting
    # ------------------------------------------------------------------ #

    def plot_ice_local(
        self,
        explanation,
        plot_width: int = 500,
        plot_height: int = 500,
        plot_bgcolor: str = "white",
        plot_linecolor: str = "gray",
        plot_linewidth: int = 15,
        plot_instance_color: str = "firebrick",
        renderer="notebook",
        **kwargs,
    ):
        """Plot a local ICE curve for one feature.

        Args:
            explanation: Dict whose ``"data"`` entry has ``feature_name``,
                ``feature_value`` and ``ice_value``. It may also have
                ``current_value``, which is drawn as a dashed vertical line.
            plot_width, plot_height: Figure size in pixels.
            plot_bgcolor: Plot background colour.
            plot_linecolor: Colour of the ICE curve.
            plot_linewidth: Currently has no visible effect (see NOTE below).
            plot_instance_color: Colour of the ``current_value`` line.
            renderer: Unused; kept for API compatibility.

        Returns:
            A plotly ``Figure``. The input ``explanation`` is not modified.
        """
        explanation = copy.deepcopy(explanation)
        data = explanation["data"]
        if isinstance(data["ice_value"][0], list):
            # Look for predict probas from SROM models: each point's list is
            # averaged. NOTE(review): if these are class probabilities summing
            # to 1, the mean is constant; confirm the intended reduction.
            data["ice_value"] = np.mean(data["ice_value"], axis=1).tolist()
        df = pd.DataFrame(
            {
                "feature_value": data["feature_value"],
                "target_value": data["ice_value"],
            }
        )

        # NOTE(review): px.line's ``width`` is the *figure* width in pixels,
        # not the line width, and update_layout(width=plot_width) below
        # overrides it. plot_linewidth therefore has no visible effect.
        # Applying it via fig.update_traces(line_width=...) would fix that,
        # but the default of 15 px would then draw a very thick line.
        fig = px.line(
            df,
            x="feature_value",
            y="target_value",
            title=data["feature_name"],
            color_discrete_sequence=[plot_linecolor],
            width=plot_linewidth,
        )
        if "current_value" in data:
            x_local = data["current_value"]
            reference_line = go.Scatter(
                x=[x_local, x_local],
                y=[np.min(df["target_value"]), np.max(df["target_value"])],
                mode="lines",
                line=go.scatter.Line(color=plot_instance_color, dash="dash"),
                showlegend=False,
            )
            fig.add_trace(reference_line, row=1, col=1)

        fig.update_layout(
            width=plot_width,
            height=plot_height,
            xaxis_title=f"{data['feature_name']}",
            title_text="ICE (local)",
            yaxis_title="prediction",
            plot_bgcolor=plot_bgcolor,
        )
        return fig

    def plot_shap_local(
        self,
        explanation,
        plot_type: str = "forceplot",
        **kwargs,
    ):
        """Render a SHAP force plot for one local explanation.

        Args:
            explanation: Dict whose ``"data"`` entry has ``model_average``,
                ``feature_contrib``, ``feature_values`` and ``feature_names``.
            plot_type: Only ``"forceplot"`` is supported.

        Returns:
            The object returned by ``shap.force_plot``.

        Raises:
            ValueError: If ``plot_type`` is not ``"forceplot"``.
        """
        if plot_type != "forceplot":
            # ValueError subclasses Exception, so existing handlers still match.
            raise ValueError("plot type {} is not supported.".format(plot_type))
        shap.initjs()
        data = explanation["data"]
        return shap.force_plot(
            data["model_average"],
            np.asarray(data["feature_contrib"]),
            np.asarray(data["feature_values"]),
            data["feature_names"],
        )

    def plot_saliency_local(
        self,
        explanation,
        top_k: int = 10,
        eps: float = 1e-3,
        plot_width: int = 500,
        plot_height: int = 500,
        positive_color: str = "#33e594",
        negative_color: str = "#d5852a",
        plot_bgcolor: str = "white",
        plot_title: str = "Feature Sensitivity",
        plot_yaxis_title: str = "sensitivity",
        renderer="notebook",
        **kwargs,
    ):
        """Horizontal bar chart of the most sensitive features.

        Features are ranked by ``|sensitivity|``. Those below ``eps`` are
        dropped, and at most ``top_k`` of the rest are shown. Negative values
        use ``negative_color`` and all others ``positive_color``. ``renderer``
        is unused.

        Args:
            explanation: Dict whose ``"data"`` entry has ``sensitivity`` and
                ``feature_names`` of the same length.

        Returns:
            A plotly ``Figure``.
        """
        sensitivity = np.array(explanation["data"]["sensitivity"])
        feat_names = explanation["data"]["feature_names"]

        magnitude = np.abs(sensitivity)
        index_order = np.argsort(magnitude)[::-1]  # decreasing |sensitivity|
        n_negligible = int(np.count_nonzero(magnitude < eps))
        n = max(0, min(top_k, len(feat_names) - n_negligible))

        top_names = [feat_names[i] for i in index_order[:n]]
        top_sensitivity = sensitivity[index_order[:n]]
        # A plain list rather than a fixed-width numpy string array, so a
        # negative_color longer than positive_color is not truncated.
        colors = [negative_color if s < 0 else positive_color for s in top_sensitivity]

        fig = go.Figure()
        # Add bars from least to most sensitive, so the most sensitive
        # feature ends up at the top of the categorical y-axis.
        for i in reversed(range(n)):
            fig.add_trace(
                go.Bar(
                    x=[top_sensitivity[i]],
                    y=[top_names[i]],
                    orientation="h",
                    marker=dict(color=colors[i]),
                    showlegend=False,
                )
            )

        fig.update_layout(
            width=plot_width,
            height=plot_height,
            barmode="relative",
            title_text=f"{plot_title}",
            xaxis_title=f"{plot_yaxis_title}",
            plot_bgcolor=plot_bgcolor,
        )
        return fig

    def plot_macem_local(
        self,
        explanation,
        prediction,
        favored_label=0,
        select_imp_features=False,
        delta=0.1,
        plot_width: int = None,
        plot_height: int = 600,
        renderer="notebook",
        **kwargs,
    ):
        """Grouped bar chart of actual feature values against the MACEM values.

        The comparison uses ``explanation["data"]["pp"]`` (legend "similar")
        when ``prediction == favored_label``, and ``["pn"]`` (legend
        "contrastive") otherwise.

        Args:
            explanation: Dict whose ``"data"`` entry has ``feature_names``,
                ``current_value``, and ``pp`` or ``pn``.
            prediction: Predicted label, compared with ``favored_label``.
            select_imp_features: If true, keep only features whose relative
                change ``|actual - other| / |actual|`` exceeds ``delta``.
            plot_width, plot_height: Figure size passed to ``px.bar``.
            renderer: Unused; kept for API compatibility.

        Returns:
            A plotly ``Figure``.
        """
        feature_names = explanation["data"]["feature_names"]
        if prediction != favored_label:
            explanation_type, explanation_type_legend = "pn", "contrastive"
        else:
            explanation_type, explanation_type_legend = "pp", "similar"

        df = pd.DataFrame(
            {
                "readings": explanation["data"]["current_value"],
                "reading_type": "actual",
                "features": feature_names,
            }
        )
        df_2 = pd.DataFrame(
            {
                "readings": explanation["data"][explanation_type],
                "reading_type": explanation_type_legend,
                "features": feature_names,
            }
        )

        interesting_features = feature_names
        if select_imp_features:
            # abs() on the denominator as well; otherwise features with
            # negative actual readings could never be selected.
            relative_change = (df["readings"] - df_2["readings"]).abs() / df[
                "readings"
            ].abs()
            interesting_features = df.loc[relative_change > delta, "features"].values

        # pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
        df = pd.concat(
            [
                df[df["features"].isin(interesting_features)],
                df_2[df_2["features"].isin(interesting_features)],
            ]
        )
        fig = px.bar(
            df,
            x="features",
            y="readings",
            color="reading_type",
            barmode="group",
            width=plot_width,
            height=plot_height,
        )
        return fig

    # Thin wrappers that delegate to aiexpts.plot_utils.
    @override_help(plot_global_pdp)
    def plot_pdp_global(self, *args, **kwargs):
        return plot_global_pdp(*args, **kwargs)

    @override_help(plot_local_groupedce)
    def plot_groupedce_local(self, *args, **kwargs):
        return plot_local_groupedce(*args, **kwargs)

    @override_help(plot_global_shap)
    def plot_shap_global(self, *args, **kwargs):
        return plot_global_shap(*args, **kwargs)
