# Licensed Materials - Property of IBM
# 5737-M66, 5900-AAA
# (C) Copyright IBM Corp. 2019, 2025 All Rights Reserved.
# US Government Users Restricted Rights - Use, duplication, or disclosure
# restricted by GSA ADP Schedule Contract with IBM Corp.

"""This module includes various utility functions."""

# Public plotting helpers exported by ``from <module> import *``.
__all__ = [
    'plot_variables',
    'plot_timeseries',
    'plot_variables_with_vertical_bars',
]


import copy
import logging
import logging.config
import os
import re
from typing import Any, Union

import numpy as np
import pandas as pd
import datetime as dttime

try:
    import matplotlib
    from matplotlib import pyplot as pyplt
except Exception as ex:
    print("Exception raised while attempting to import matplotlib", '\n', \
    ex, '\n', \
    "Environment does not support matplotlib. This function will be disabled in environments other\
         than IPython / Jupyter notebooks. This is a harmless exception and can be safely ignored in \
            non-IPython environment")


# Logger named after this module; the plotting helpers emit debug messages through it.
logger = logging.getLogger(__name__)
# Module default log level. It is not referenced by the functions in this module;
# presumably read by callers or logging setup elsewhere (verify before removing).
DEFAULT_LOG_LEVEL = logging.ERROR

def plot_variables(df: pd.DataFrame, variables_to_plot: list, group_by_column_name: str = 'id',
                   timeseries_plot: bool = False, timestamp_column_name: str = 'evt_timestamp',
                   figsize: tuple = (15, 15)):
    """
    Plot the selected variables of every group as a simple line chart.

    Each group of ``df`` is plotted with its own ``DataFrame.plot(subplots=True)``
    call, i.e. one subplot per variable, titled ``'Asset ID = <group id>'``.

    Parameters
    ----------
    df : pd.DataFrame
        Source data. If its index is a ``pd.MultiIndex`` the rows are grouped by
        the index level named ``group_by_column_name``; otherwise by the column
        of that name.
    variables_to_plot : list
        Names of the columns to plot.
    group_by_column_name : str
        Index-level or column name used to split ``df`` into groups.
    timeseries_plot : bool
        If True, ``timestamp_column_name`` becomes the x-axis. When that name is
        not a column of ``df`` and ``df`` has a MultiIndex, the index levels are
        first reset into columns so the timestamp can be selected.
    timestamp_column_name : str
        Name of the timestamp column / index level (only used when
        ``timeseries_plot`` is True).
    figsize : tuple
        Figure size passed to each per-group plot.
    """
    grouping_level = None
    group_by = None
    if isinstance(df.index, pd.MultiIndex):
        grouping_level = group_by_column_name
    else:
        group_by = group_by_column_name

    for group_id, grouped_df in df.groupby(by=group_by, level=grouping_level):
        if timeseries_plot:
            if timestamp_column_name not in df.columns and grouping_level is not None:
                # The timestamp lives in the MultiIndex: expose it as a column first.
                grouped_df = grouped_df.reset_index(drop=False)
            grouped_df = grouped_df.set_index(timestamp_column_name)
        # str() so that non-string group ids (e.g. integers) do not raise TypeError.
        axes = grouped_df[variables_to_plot].plot(
            subplots=True, figsize=figsize, title='Asset ID = ' + str(group_id))
        # Label this group's subplots (previously only the last group's were labelled).
        for ax, var_name in zip(axes, variables_to_plot):
            ax.legend()
            ax.set_xlabel("Distribution over time")
            ax.set_ylabel(var_name)
        pyplt.tight_layout()

    pyplt.show()

def plot_timeseries(df, variables_to_plot, group_by_column_name='id',
                    timestamp_column_name='evt_timestamp', figsize=(15, 15)):
    """
    Plot the selected variables of every group as a simple line chart.

    Each group of ``df`` is plotted with its own ``DataFrame.plot(subplots=True)``
    call against the group's existing index, titled ``'Asset ID = <group id>'``.

    Parameters
    ----------
    df : pd.DataFrame
        Source data. If its index is a ``pd.MultiIndex`` the rows are grouped by
        the index level named ``group_by_column_name``; otherwise by the column
        of that name.
    variables_to_plot : list
        Names of the columns to plot.
    group_by_column_name : str
        Index-level or column name used to split ``df`` into groups.
    timestamp_column_name : str
        Currently unused; kept for interface compatibility.
        NOTE(review): the x-axis is the group's index as-is; use
        ``plot_variables(..., timeseries_plot=True)`` to plot against a timestamp.
    figsize : tuple
        Figure size passed to each per-group plot.
    """
    grouping_level = None
    group_by = None
    if isinstance(df.index, pd.MultiIndex):
        grouping_level = group_by_column_name
    else:
        group_by = group_by_column_name

    for asset_id, grouped_df in df.groupby(by=group_by, level=grouping_level):
        # str() so that non-string asset ids (e.g. integers) do not raise TypeError.
        axes = grouped_df[variables_to_plot].plot(
            subplots=True, figsize=figsize, title='Asset ID = ' + str(asset_id))
        # Label this group's subplots (previously only the last group's were labelled).
        for ax, var_name in zip(axes, variables_to_plot):
            ax.legend()
            ax.set_xlabel("Distribution over time")
            ax.set_ylabel(var_name)
        pyplt.tight_layout()

    pyplt.show()

def _join_query_criteria(event_criteria):
    """Combine ``DataFrame.query`` expressions into one conjunction: ``(a) & (b)``."""
    return ' & '.join('(' + criterion + ')' for criterion in event_criteria)


def _column_or_index(frame, column_name, grouping_level):
    """Return ``frame``'s index when ``column_name == 'index'``, else the values of ``column_name``.

    With a MultiIndex (``grouping_level`` not None) the grouping level is dropped
    first; ``frame`` holds a single group, so that level carries no information.
    Unlike ``.loc[group_id]`` this works for any level position and for empty frames.
    """
    if grouping_level is not None:
        frame = frame.droplevel(grouping_level)
    if column_name == 'index':
        return frame.index
    return frame[column_name].values


def plot_variables_with_vertical_bars(df: pd.DataFrame, x_col: list, y_col: list, bar_length_col: str,
                                      xlim_col: str = 'index', is_xaxis_timestamp: bool = True,
                                      timestamp_column_name: str = 'evt_timestamp',
                                      event_data: dict = None, event_criteria: list = None,
                                      event_col: str = 'index', group_by_column_name: str = 'id',
                                      fig_size: tuple = (27, 5), color_list: list = ['b', 'g', 'c', 'm', 'y'],
                                      **kwargs):
    """
    Plot ``y_col`` against ``x_col`` per group, overlaid with dashed vertical bars at event positions.

    One ``pyplt.figure`` is created, shown and closed per group of ``df``.

    Parameters
    ----------
    df : pd.DataFrame
        Source data. If its index is a ``pd.MultiIndex`` the rows are grouped by
        the index level named ``group_by_column_name``; otherwise by that column.
    x_col, y_col : column key(s)
        Data plotted with ``pyplt.plot(grouped_df[x_col], grouped_df[y_col].values)``.
    bar_length_col : str
        Column whose min / max+1 give the vertical extent of the bars.
    xlim_col : str
        ``'index'`` or a column name; its min/max set the x-axis limits.
    is_xaxis_timestamp : bool
        If True, the x limits are converted to ``datetime.date`` values.
    timestamp_column_name : str
        Currently unused; kept for interface compatibility.
    event_data : dict, optional
        Mapping ``group id -> x positions`` of the vertical bars. If empty/None,
        positions are derived from ``event_criteria`` instead.
    event_criteria : list of str, optional
        ``DataFrame.query`` expressions, combined with ``&``; matching rows give
        the bar positions (taken from ``event_col``). With neither
        ``event_data`` nor ``event_criteria`` no bars are drawn.
    event_col : str
        ``'index'`` or a column name holding the bar positions of matching rows.
    group_by_column_name : str
        Index-level or column name used to split ``df`` into groups.
    fig_size : tuple
        Size of each per-group figure.
    color_list : list
        Currently unused; kept for interface compatibility.
    **kwargs
        ``linestyle`` (default ``'dashed'``), ``linewidth`` (default ``2``) and
        ``vline_color`` (default ``'red'``) of the vertical bars.

    Raises
    ------
    KeyError
        If ``event_data`` is given but has no entry for a group id.
    """
    linestyle = kwargs.get('linestyle', 'dashed')
    linewidth = kwargs.get('linewidth', 2)
    vline_color = kwargs.get('vline_color', 'red')

    grouping_level = None
    group_by = None
    if isinstance(df.index, pd.MultiIndex):
        grouping_level = group_by_column_name
    else:
        group_by = group_by_column_name

    # Loop-invariant: the same criteria apply to every group.
    query_expression = _join_query_criteria(event_criteria or [])

    for group_id, grouped_df in df.groupby(by=group_by, level=grouping_level):
        # Keep the per-group positions in a separate local so that ``event_data``
        # is never overwritten between iterations.
        if event_data:
            event_points = event_data[group_id]
        elif query_expression:
            logger.debug('pmlib.visualization_util::plot_variables_with_vertical_bars() - %s %s',
                         'Query to extract the data points for vertical bars: ', query_expression)
            event_points = _column_or_index(grouped_df.query(query_expression), event_col, grouping_level)
        else:
            # No positions and no criteria: plot the data without vertical bars.
            event_points = []
        logger.debug('pmlib.visualization_util::plot_variables_with_vertical_bars() - %s %s %s',
                     'event data points for vertical bars: length = ', len(event_points), event_points[:5])

        bar_length_range = grouped_df[bar_length_col]

        fig = pyplt.figure(figsize=fig_size)

        axis_range = _column_or_index(grouped_df, xlim_col, grouping_level)
        x_min, x_max = axis_range.min(), axis_range.max()
        if is_xaxis_timestamp:
            # pd.Timestamp accepts both pandas Timestamps (index) and numpy
            # datetime64 values (column ``.values``); the latter lack ``.date()``.
            # NOTE(review): ``.date()`` truncates to midnight, so points later on the
            # last day fall outside the x limits -- confirm this is intended.
            x_min, x_max = pd.Timestamp(x_min).date(), pd.Timestamp(x_max).date()
        pyplt.xlim([x_min, x_max])

        if len(event_points) > 0:
            pyplt.vlines(x=event_points, ymin=bar_length_range.min(), ymax=bar_length_range.max() + 1,
                         colors=vline_color, label=group_id, linestyles=linestyle, linewidths=linewidth)
        pyplt.plot(grouped_df[x_col], grouped_df[y_col].values)
        pyplt.show()
        # Release the figure so that many groups do not accumulate open figures.
        # (Formerly a module-level pyplt.close() that ran once at import time.)
        pyplt.close(fig)
