# Licensed Materials - Property of IBM
# 5737-M66, 5900-AAA, 5900-AMG
# (C) Copyright IBM Corp. 2019, 2025 All Rights Reserved.
# US Government Users Restricted Rights - Use, duplication, or disclosure
# restricted by GSA ADP Schedule Contract with IBM Corp.

"""This module includes various utility functions."""

# Public names exported by ``from <this module> import *``.
__all__ = [
    'setup_logging', 'set_log_level', 'get_logger',
    'log_df_info', 'get_as_schema', 'get_table_name',
]


import copy
import logging
import logging.config
import os
import re
from typing import Any, Union
import json

import numpy as np
import pandas as pd
import requests
from requests_futures.sessions import FuturesSession
from sklearn.utils.extmath import log_logistic, softmax

# Keys whose values are treated as credentials. This set is the default key collection used by the
# mask_credential* helpers in this module, which replace matching values with '********'.
KEYS_NEEDING_MASK = {
    'apmapitoken',
    'apikey',
    'maxauth',
    'x-api-key',
    'x-api-token',
    'X-api-key',
    'X-api-token',
    'Authorization',
    'DB_CONNECTION_STRING',
    'MH_USER',
    'MH_PASSWORD',
    'COS_HMAC_ACCESS_KEY_ID',
    'COS_HMAC_SECRET_ACCESS_KEY',
    'API_TOKEN',
    'API_KEY',
    'APM_API_KEY',
    'as_apikey',
    'as_apitoken',
    'mahi_apikey',
    'wml_cred',
}

# Level given to the root logger by setup_logging, and to the project loggers when they are still unset.
DEFAULT_LOG_LEVEL = logging.DEBUG

# Logger for this module.
logger = logging.getLogger(__name__)

def setup_logging():
    """
    Configure logging for PMLIB.

    Installs a console handler on stdout and a file handler that (re)writes ``main.log``,
    relative to the current working directory. Both handlers use a shared timestamped format
    and are attached to the root logger at `DEFAULT_LOG_LEVEL`. This function runs
    automatically when PMLIB gets initialized; you should never need to call it yourself.

    To change the log level afterwards, call `set_log_level`.
    """
    formatters = {
        'simple': {
            'format': '%(asctime)s.%(msecs)03d %(levelname)s::%(name)s.%(funcName)s: %(message)s',
            'datefmt': '%Y-%m-%dT%H:%M:%S',
        },
    }
    handlers = {
        'console': {
            'class': 'logging.StreamHandler',
            'formatter': 'simple',
            'stream': 'ext://sys.stdout',
        },
        'file': {
            'class': 'logging.FileHandler',
            'filename': 'main.log',
            'mode': 'w',  # truncate the log file on every setup
            'formatter': 'simple',
        },
    }
    logging.config.dictConfig({
        'version': 1,
        'disable_existing_loggers': False,
        'formatters': formatters,
        'handlers': handlers,
        'root': {'level': DEFAULT_LOG_LEVEL, 'handlers': ['console', 'file']},
    })

    # A pmlib level of NOTSET means nobody has called set_log_level for it yet.
    pmlib_level_unset = logging.getLogger('pmlib').level == logging.NOTSET
    if pmlib_level_unset:
        logger.info('Log level has not been set yet... setting to default level of %s', DEFAULT_LOG_LEVEL)
        set_log_level(DEFAULT_LOG_LEVEL)

    # Quiet chatty third-party loggers; entries are applied in this order.
    noisy_loggers = (
        ('INFO', ['urllib3.connectionpool']),
        ('ERROR', ['sklearn', 'py4j', 'pyspark', 'numexpr', 'iotfunctions', 'ibm_watson_machine_learning', 'matplotlib']),
    )
    for level_name, logger_names in noisy_loggers:
        set_log_level(level_name, logger_names)

    # Route warnings.warn(...) output through the logging system.
    logging.captureWarnings(True)


def set_log_level(level, loggers: Union[list[str], tuple[str, ...], set[str], str, None] = None):
    """
    Sets the loggers to the specified log level.

    Parameters
    ----------
    level : str or int
        Log level to set loggers to: either a level name such as `DEBUG`, `INFO`, `WARNING`,
        `ERROR`, `FATAL`, or a numeric level such as `logging.DEBUG`. It is passed unchanged
        to `logging.Logger.setLevel`.
    loggers: list[str] | tuple[str, ...] | set[str] | str, optional
        Logger name(s) whose level is changed. A single name may be passed as a plain string.
        If not provided, the default project loggers are changed: root, pmlib, srom and
        analytics_service.
    """
    if loggers is None:
        loggers = ['root', 'pmlib', 'srom', 'analytics_service']
    elif not isinstance(loggers, (list, tuple, set, frozenset)):
        # A single logger name is wrapped so the loop below can treat every input uniformly.
        # Tuples and sets are iterated directly instead of being wrapped as one bogus "name".
        loggers = [loggers]
    for logger_name in loggers:
        logging.getLogger(logger_name).setLevel(level)


def get_logger(obj):
    """Return the logger conventionally associated with the given object.

    The logger name is ``<module>.<ClassName>``: the object's ``__module__`` joined with the
    name of its class, e.g. ``pmlib.<module_name>.<class_name>`` for PMLIB objects.

    Parameters
    ----------
    obj:
        The object whose logger is wanted.

    Returns
    -------
    The `logging.Logger` with that name.
    """
    module_name = obj.__module__
    class_name = obj.__class__.__name__
    return logging.getLogger(f'{module_name}.{class_name}')

def log_df_info(df, head=0, maxlen=0, comment=None, include_missing_value_count=False, include_memory_usage=False, logger: logging.Logger = None, log_level: int = logging.DEBUG):
    """Build a human-readable summary of a dataframe for debugging.

    The returned string contains the shape, the index and column data types, optional
    missing-value counts and memory usage, and the requested number of rows. The summary
    is framed by START/END marker lines.

    Parameters
    ----------
    df : `pandas.DataFrame`, `pandas.Series`, `list`, or `numpy.ndarray`
        The data to describe. A Series, list or ndarray is converted to a DataFrame first.
        A list that numpy cannot turn into an array, or any other type, is returned as
        ``str(df)``.
    head : `int`, optional
        Number of rows to include. If <0, include all rows. If ==0, include no rows. Default is 0.
    maxlen : `int`, optional
        Maximum number of characters kept from the summary. The END marker is appended after
        truncation. Default is 0, which means no limit.
    comment : `str`, optional
        A comment placed right after the START marker.
    include_missing_value_count : `bool`, optional
        True to include a per-column count of missing values.
    include_memory_usage : `bool`, optional
        True to include the per-column memory usage (computed with ``deep=True``).
    logger : `logging.Logger`, optional
        If given and not enabled for `log_level`, the function returns ``None`` immediately
        without processing the data. This saves computation for large DataFrames whose
        summary would not be logged anyway.
    log_level : `int`, optional
        The level the summary is going to be logged at, e.g. ``logging.DEBUG``. Only used
        together with `logger`. Default is ``logging.DEBUG``.

    Returns
    -------
    str or None
        The summary string, or ``None`` when `logger` is given but not enabled for `log_level`.
    """
    if logger is not None and not logger.isEnabledFor(log_level):
        return None

    if isinstance(df, list):
        try:
            df = pd.DataFrame(np.array(df))
        except ValueError:
            # e.g. ragged nested lists; fall through and return str(df) below
            pass
    elif isinstance(df, np.ndarray):
        df = pd.DataFrame(df)
    elif isinstance(df, pd.Series):
        df = pd.DataFrame({df.name: df})
    if not isinstance(df, pd.DataFrame):
        return str(df)

    # str() of a dtype dict renders entries as "dtype('int64')"; keep only the "'int64'" part.
    dtype_pattern = re.compile(r'dtype\(([^)]+)\)')

    info = "\n=========== START DATAFRAME LOG ==========="
    info += ('\n' + comment) if comment is not None else ''
    info += '\nshape=%s, ' % str(df.shape)
    info += '\nindex=%s, ' % dtype_pattern.sub(r'\1', str(df.index.to_frame().dtypes.to_dict()))
    info += '\ncolumns=%s, ' % dtype_pattern.sub(r'\1', str(df.dtypes.to_dict()))

    if include_missing_value_count:
        info += '\nmissing_value_count=%s' % str(df.isna().sum().to_dict())
    if include_memory_usage:
        info += '\nmemory_usage=%s' % str(df.memory_usage(deep=True).to_dict())

    if df.empty:
        info += '\ndf=Empty DataFrame'
    elif head == 1:
        info += '\n1st_row=%s' % str(df.head(1).reset_index().to_dict())
    elif head > 1:
        info += '\nhead(%d)=\n%s' % (head, str(df.head(head)))
    elif head < 0:
        info += '\ndf=\n%s' % str(df)

    if maxlen > 0:
        info = info[0:maxlen]

    info += "\n=========== END DATAFRAME LOG ==========="

    return info


def find_list_duplicate(list_obj):
    '''Return the repeated occurrences of items in the given list, keeping original order.

    The first occurrence of each value is skipped and every later occurrence is returned.
    For example, ``[1, 2, 1, 3, 2, 1]`` yields ``[1, 2, 1]``. Items must be hashable.
    '''
    seen = set()
    duplicates = []
    for item in list_obj:
        if item in seen:
            duplicates.append(item)
        else:
            seen.add(item)
    return duplicates


def remove_list_duplicate(list_obj):
    '''Return a new list with duplicates dropped, keeping the first occurrence of each item.

    Original order is preserved. Items must be hashable.
    '''
    # dict keys are unique and keep insertion order, and an existing key keeps its first object.
    return list(dict.fromkeys(list_obj))

def mask_credential(key, value, to_mask=KEYS_NEEDING_MASK):
    """Return `value`, or '********' when `key` is in `to_mask` and `value` is not None.

    Args:
        key: Name the value is stored under.
        value: The value to (possibly) mask.
        to_mask: Collection of keys considered sensitive. Defaults to KEYS_NEEDING_MASK.
    """
    is_sensitive = key in to_mask
    if is_sensitive and value is not None:
        return '********'
    return value


def mask_credential_in_dict(dict_to_mask, keys_to_mask=KEYS_NEEDING_MASK):
    """Return a shallow copy of `dict_to_mask` with sensitive top-level values replaced by '********'.

    Only keys listed in `keys_to_mask` whose value is not None are masked. Nested containers are not
    inspected. Returns None when `dict_to_mask` is None. The input mapping is not modified.
    """
    if dict_to_mask is None:
        return None

    # Use the mapping's own .copy() so the result has whatever type that method returns.
    result = dict_to_mask.copy()
    for sensitive_key in keys_to_mask:
        if sensitive_key not in result:
            continue
        if result[sensitive_key] is None:
            continue
        result[sensitive_key] = '********'
    return result

def mask_credentials_in_object(object_to_mask, keys_to_mask=KEYS_NEEDING_MASK, parent_key=None):
    """Recursively masks credentials in the parameter `object_to_mask`, returning a masked copy.

    Similar to `mask_credential_in_dict`, but walks nested dicts and lists. A string is replaced
    by '********' when the dict key it is stored under is in `keys_to_mask`. This holds whether the
    string sits directly under that key or inside any number of lists below it, so
    ``{'apikey': ['secret']}`` is masked too. Only string values are masked. The input object is
    never modified.

    Args:
        object_to_mask (Any): Object to mask credentials in.
        keys_to_mask (Collection, optional): Keys whose string values are masked. Defaults to KEYS_NEEDING_MASK.
        parent_key (Hashable, optional): The dict key under which `object_to_mask` is stored. It is used by
            the recursion to decide whether a string must be masked. Defaults to None.

    Returns:
        Any: A copy of the object passed in with masked credentials.
    """
    if isinstance(object_to_mask, list):
        # Shallow-copy the container (keeps list subclasses), then fill in masked copies of the items.
        # Items inherit parent_key so strings listed under a sensitive key are masked as well.
        masked_list = copy.copy(object_to_mask)
        for index, item in enumerate(object_to_mask):
            masked_list[index] = mask_credentials_in_object(item, keys_to_mask, parent_key=parent_key)
        return masked_list
    if isinstance(object_to_mask, dict):
        masked_dict = copy.copy(object_to_mask)
        for key, value in object_to_mask.items():
            masked_dict[key] = mask_credentials_in_object(value, keys_to_mask, parent_key=key)
        return masked_dict
    if isinstance(object_to_mask, str) and parent_key in keys_to_mask:
        return '********'
    # Leaf value: deep-copy it so the result never shares mutable state with the input.
    return copy.deepcopy(object_to_mask)


def api_request(url, method='get', headers=None, json=None, timeout=300, ssl_verify=True, session=None, **kwargs):
    """Send an HTTP request and return the response, or None for an unsuccessful status.

    Parameters
    ----------
    url : str
        Target URL. URLs containing '.svc:' are requested with certificate verification disabled.
    method : str, optional
        One of 'get', 'post', 'put', 'patch', 'delete' (case-insensitive). Default 'get'.
    headers : dict, optional
        Request headers. Credential entries are masked before being logged.
    json : optional
        Request body passed as the `json` argument. Credentials are masked before logging.
    timeout : optional
        Passed as `timeout` to the request call. Default 300.
    ssl_verify : bool, optional
        Passed as `verify`. Forced to False for '.svc:' URLs.
    session : optional
        Object used to send the request instead of the `requests` module. When it is a
        `FuturesSession`, whatever the session call returns is handed back unchecked.
    **kwargs
        Extra keyword arguments forwarded to the request call. `auth` and top-level entries of
        dict-valued arguments are masked before logging.

    Returns
    -------
    The response, or None when the status code is not 200/201/202/204 and the URL does not
    contain 'healthcheck'.

    Raises
    ------
    ValueError
        If `method` is not one of the supported HTTP methods.
    """
    # Accept 'GET', 'Post', ... as well as the lower-case names.
    http_method = method.lower() if isinstance(method, str) else method
    if http_method not in ('get', 'post', 'put', 'patch', 'delete'):
        raise ValueError('unsupported_method=%s' % method)

    if '.svc:' in url:
        # NOTE(review): verification is always disabled for these URLs (presumably in-cluster
        # service addresses) — confirm this is still intended.
        ssl_verify = False

    # kwargs can carry credentials (e.g. `auth`, or an API key inside `params`/`data`). Mask them for
    # the log with shallow copies only: values such as file objects may not be deep-copyable.
    logged_kwargs = {}
    for key, value in kwargs.items():
        if key == 'auth' and value is not None:
            logged_kwargs[key] = '********'
        elif isinstance(value, dict):
            logged_kwargs[key] = mask_credential_in_dict(value)
        else:
            logged_kwargs[key] = value

    logger.debug('Making API Request: method=%s, url=%s, headers=%s, timeout=%s, ssl_verify=%s, json=%s, session=%s, kwargs=%s',
                 http_method, url, mask_credential_in_dict(headers), timeout, ssl_verify, mask_credentials_in_object(json), session, logged_kwargs)

    sender = requests if session is None else session
    send = getattr(sender, http_method)
    resp = send(url, headers=headers, verify=ssl_verify, json=json, timeout=timeout, **kwargs)

    # FuturesSession calls return a future rather than a finished response; return it unchecked.
    if isinstance(sender, FuturesSession):
        return resp

    logger.debug('Received API Response: resp.status_code=%s, method=%s, url=%s', resp.status_code, http_method, url)

    successful_codes = (requests.codes.ok, requests.codes.created, requests.codes.accepted, requests.codes.no_content)
    if resp.status_code not in successful_codes and "healthcheck" not in url:
        logger.warning('Unsuccessful api request: url=%s, method=%s, status_code=%s, response_text=%s', url, http_method, resp.status_code, resp.text)
        return None

    return resp


def current_directory(file=__file__):
    """Return the directory containing `file` after resolving symlinks.

    Defaults to the directory of this module's own source file.
    """
    resolved_path = os.path.realpath(file)
    parent_dir, _ = os.path.split(resolved_path)
    return parent_dir


def _mkdirp(path):
    if not os.path.isdir(path):
        drive, path = os.path.splitdrive(path)
        path, file = os.path.split(path)
    os.makedirs(path)


def camel_to_snake(name):
    """Convert a CamelCase / mixedCase identifier to snake_case, e.g. 'HTTPResponseCode' -> 'http_response_code'."""
    # Pass 1: put an underscore before every capitalised word (capital + lowercase run).
    with_word_breaks = re.sub(r'(.)([A-Z][a-z]+)', r'\1_\2', name)
    # Pass 2: split a lowercase letter or digit from a following capital, then lowercase everything.
    snake = re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', with_word_breaks)
    return snake.lower()


def get_as_schema(db_schema):
    """Return `db_schema`, or the AS_SCHEMA environment variable (possibly None) when it is None."""
    if db_schema is not None:
        return db_schema
    return os.environ.get('AS_SCHEMA', None)


def get_table_name(db_schema, table):
    """Return `table` qualified as '<db_schema>.<table>', or `table` unchanged when `db_schema` is None."""
    return table if db_schema is None else db_schema + '.' + table
    
def compute_binary_class_probabilities(scores, algo='sigmoid', handle_nans=True):
    """Convert raw binary-classifier scores into (P(class 0), P(class 1)) pairs.

    Parameters
    ----------
    scores : array-like of numbers
        One score per sample. Larger scores give a larger probability for class 1.
    algo : str, optional
        'sigmoid' (default) returns a list of ``(1 - sigmoid(s), sigmoid(s))`` tuples.
        'softmax' returns an ``(n, 2)`` numpy array holding the row-wise softmax of
        ``[-s/2, s/2]``. These are the same probabilities, computed in a numerically stable way.
    handle_nans : bool, optional
        Only used with 'sigmoid'. If any probability comes out NaN (e.g. ``exp`` overflowed for
        a very negative score), the result is recomputed with 'softmax'. Default True.

    Returns
    -------
    list of tuple, numpy.ndarray, or None
        See `algo`. Returns None for an unrecognised `algo`.
    """
    class_probabilities = None
    if algo == 'sigmoid':
        class_probabilities = [(np.exp(-1 * n) / (1 + np.exp(-1 * n)), 1 / (1 + np.exp(-1 * n))) for n in scores]
        if handle_nans and any(pd.isna(p0) or pd.isna(p1) for p0, p1 in class_probabilities):
            # Same logger as the module-level `logger` (both use __name__); avoids printing to stdout.
            logging.getLogger(__name__).info('Handling nans by switching to softmax')
            class_probabilities = compute_binary_class_probabilities(scores, 'softmax')
    elif algo == 'softmax':
        # np.asarray lets plain lists work: unary minus on a list raises TypeError.
        score_array = np.asarray(scores)
        logits = np.vstack([-score_array, score_array]).T / 2
        # Row-wise softmax, shifted by the row maximum so exp() cannot overflow.
        logits = logits - logits.max(axis=1, keepdims=True)
        exp_logits = np.exp(logits)
        class_probabilities = exp_logits / exp_logits.sum(axis=1, keepdims=True)
    return class_probabilities
    







class integration_json:
    """Track notebook dependencies ("precursors") and outputs in a JSON file.

    The JSON document maps each notebook name to ``{"precursors": [...], "outputs": {...}}``.
    `create` starts a new document for the current notebook and `json_import` loads an existing
    one. The ``add_*`` methods edit the document in memory, and `write_out` saves it to disk.

    Parameters
    ----------
    name : str
        File name of the JSON document.
    path : str, optional
        Directory holding the file, with or without a trailing separator.
        Default is '/project_data/data_asset/'.
    """

    def __init__(self, name, path='/project_data/data_asset/'):
        self.name = name
        self.path = path

    def _file_path(self):
        """Full path of the JSON file (joined, so `path` need not end with a separator)."""
        return os.path.join(self.path, self.name)

    @staticmethod
    def _empty_entry():
        """Fresh entry for one notebook: no precursors and no outputs."""
        return {"precursors": [], "outputs": {}}

    def create(self, current=None):
        """Start a new document containing only `current`, select it, and write it to disk.

        Without `current`, a hint is printed instead. Data that was already loaded is still
        written, and nothing is written when there is no data yet.
        """
        if current is not None:
            self.data = {current: self._empty_entry()}
            self.current = current
        else:
            print("Please put in the name of the current notebook")
            if not hasattr(self, 'data'):
                # Nothing to write yet (this used to raise AttributeError).
                return
        self.write_out()

    def json_import(self, current=None):
        """Load the document from disk into `self.data`.

        Parameters
        ----------
        current : str, optional
            If given, also select this notebook as the one edited by `add_precursors` / `add_outputs`.
        """
        with open(self._file_path(), 'r', encoding='utf-8') as json_file:
            self.data = json.load(json_file)
        if current is not None:
            self.current = current

    def add_files(self, notebooks=None):
        """Add empty entries for one notebook (str) or several (list).

        Existing entries with the same name are replaced by empty ones.
        """
        if notebooks is None:
            print("Please add a notebook (string format) or multiple notebooks (list format)")
        elif isinstance(notebooks, str):
            self.data[notebooks] = self._empty_entry()
        elif isinstance(notebooks, list):
            for notebook in notebooks:
                self.data[notebook] = self._empty_entry()

    def add_precursors(self, precursor):
        """Append one precursor (str) or several (list) to the current notebook's entry."""
        if isinstance(precursor, str):
            # list.append mutates in place and returns None, so its result must not be assigned back.
            self.data[self.current]['precursors'].append(precursor)
        elif isinstance(precursor, list):
            self.data[self.current]['precursors'].extend(precursor)

    def add_outputs(self, output):
        """Record outputs for the current notebook: a ``(name, value)`` tuple or a dict of them."""
        if isinstance(output, tuple):
            self.data[self.current]['outputs'][output[0]] = output[1]
        elif isinstance(output, dict):
            # dict.update mutates in place and returns None, so its result must not be assigned back.
            self.data[self.current]['outputs'].update(output)

    def write_out(self):
        """Write the in-memory document to the JSON file, replacing its contents."""
        with open(self._file_path(), 'w', encoding='utf-8') as jsonfile:
            json.dump(self.data, jsonfile)
        

            
    
        
      