# Licensed Materials - Property of IBM
# 5737-M66, 5900-AAA, 5900-AMG
# (C) Copyright IBM Corp. 2019, 2025 All Rights Reserved.
# US Government Users Restricted Rights - Use, duplication, or disclosure
# restricted by GSA ADP Schedule Contract with IBM Corp.

import datetime as dt
import logging

import numpy as np
import pandas as pd
from sqlalchemy import (Boolean, Column, DateTime, Float, Index, MetaData,
                        String, Table, func)
from sqlalchemy.schema import UniqueConstraint

from . import api
from .util import get_as_schema, get_logger, log_df_info

# Data item metadata keys (not referenced by the code visible in this module).
DATA_ITEM_COLUMN_TYPE_KEY = 'columnType'
DATA_ITEM_SOURCETABLE_KEY = 'sourceTableName'

# Data item data types assigned to DataFrame columns before persisting.
DATA_ITEM_DATATYPE_BOOLEAN = 'BOOLEAN'
DATA_ITEM_DATATYPE_NUMBER = 'NUMBER'
DATA_ITEM_DATATYPE_LITERAL = 'LITERAL'
DATA_ITEM_DATATYPE_TIMESTAMP = 'TIMESTAMP'

# Column names of the long-format KPI tables.
KPI_ENTITY_ID_COLUMN = 'entity_id'
KPI_TIMESTAMP_COLUMN = 'timestamp'
KPI_KEY_COLUMN = 'key'
KPI_VALUE_S_COLUMN = 'value_s'
KPI_LAST_UPDATE_COLUMN = 'last_update'

# Value column that receives the values of a data item, by data type.
KPI_VALUE_COLUMN_NAMES = dict((
    (DATA_ITEM_DATATYPE_BOOLEAN, 'value_b'),
    (DATA_ITEM_DATATYPE_NUMBER, 'value_n'),
    (DATA_ITEM_DATATYPE_LITERAL, KPI_VALUE_S_COLUMN),
    (DATA_ITEM_DATATYPE_TIMESTAMP, 'value_t'),
))


class PersistColumns:
    """Persist the columns of a wide DataFrame into a long-format KPI table.

    Every input column becomes rows whose ``key`` column holds the column name.
    The value lands in ``value_b``/``value_n``/``value_s``/``value_t``,
    depending on the column's data item type.
    """

    def __init__(self, target_grain, target_grain_tuple, target_table, db=None, db_schema=None, **kwargs):
        """
        :param target_grain: name of the target grain (e.g. ``'Daily'``), or None.
        :param target_grain_tuple: ``(frequency, grain_dimensions, is_entity_first)``;
            ``('', None, True)`` is used when None.
        :param target_table: initial target table name, expected to look like
            ``dm_<asset_group_id>[_<suffix>]`` (e.g. ``dm_1015``, ``dm_1025_Daily``).
        :param db: database handle; ``execute`` replaces it with a fresh one.
        :param db_schema: database schema name (normalized with ``get_as_schema``).
        :param kwargs: supports ``event_timestamp_column_name`` (default ``'evt_timestamp'``).
        """
        self.logger = get_logger(self)

        self.target_grain = target_grain
        # (frequency, grain_dimensions, is_entity_first)
        self.target_grain_tuple = target_grain_tuple
        if self.target_grain_tuple is None:
            self.target_grain_tuple = ('', None, True)
        self.target_table = target_table

        self.logger.debug('target_grain=%s, target_grain_tuple=%s, target_table=%s', self.target_grain, self.target_grain_tuple, self.target_table)

        self.db_schema = get_as_schema(db_schema)
        self.db = db

        self.event_timestamp_column_name = kwargs.get('event_timestamp_column_name', 'evt_timestamp')

    @staticmethod
    def _parse_asset_group_id(table_name):
        """Return the asset group id encoded in a table name such as ``dm_1015`` or ``dm_1025_Daily``.

        :raises ValueError: if the name has no ``_``-separated second token.
        """
        parts = str(table_name).split('_')
        if len(parts) < 2:
            raise ValueError('Can not find asset_group_id in target_table=%s' % table_name)
        return parts[1]

    def execute(self, df, start_ts=None, end_ts=None, entities=None, force_create=False):
        """Write ``df`` to the prediction-result monitor table of its asset group.

        ``start_ts``, ``end_ts`` and ``entities`` are accepted for interface
        compatibility but are not used.

        :param df: wide DataFrame; every column is persisted as a KPI. Presumably
            indexed by ``id`` and the event timestamp column, which are renamed to
            ``entity_id``/``timestamp`` on write. The caller's frame is not modified.
        :param force_create: create the target table if it does not exist.
        :returns: the long-format (stacked) DataFrame that was written.
        :raises ValueError: if no asset group id can be parsed from ``self.target_table``.
        :raises Exception: if no monitor table matches the DataFrame columns.
        """
        self.logger.debug('df_input=%s', log_df_info(df, head=5, logger=self.logger, log_level=logging.DEBUG))

        # Copy up front so the anomaly-flag conversion below never mutates the caller's frame.
        df = df.copy()
        col_list = list(df.columns)

        for name in col_list:
            if 'anomaly_detected' in name:
                self.logger.debug('Persist execute found anomaly detection name=%s', name)
                df[name] = df[name].map(bool)
        self.logger.debug('df_input after change=%s', log_df_info(df, head=5, logger=self.logger, log_level=logging.DEBUG))

        # e.g. dm_1015, dm_1025_Daily
        asset_group_id = self._parse_asset_group_id(self.target_table)
        self.logger.debug('Found asset group ID in target table name: asset_group_id=%s', asset_group_id)

        # Re-resolve the db handle for this asset group. The handle passed to the
        # constructor may carry stale entity type metadata (e.g. missing an
        # entity type registered for this group).
        self.db = api._get_db(asset_group_id=asset_group_id)

        self.target_table = api.get_new_monitor_table_name_for_prediction_result(self.db, asset_group_id, col_list)
        if self.target_table is None:
            error_msg = 'Can not find  the target_table for col_list' + str(col_list)
            raise Exception(error_msg)
        self.logger.debug('Retrieved new monitor table name for prediction result: %s', self.target_table)

        # derived tables have unique constraint on keys, so have to dedup first (while keeping the last one only)
        df = df.loc[~df.index.duplicated(keep='last')]

        data_item_types = self._get_data_item_meta(df_grain=df, grain_dimensions=self.target_grain_tuple[1])

        # Pair each column with the value column for its data type.
        columns_to_persist = []
        for name in df.dtypes.to_dict():
            data_item_type = data_item_types.get(name)
            if data_item_type is not None:
                value_column_name = KPI_VALUE_COLUMN_NAMES.get(data_item_type, KPI_VALUE_S_COLUMN)
                columns_to_persist.append((value_column_name, name))

        # "group" same-type columns under one top level label so stack() can be used directly
        df.columns = pd.MultiIndex.from_tuples(columns_to_persist)

        self.logger.debug('Columns to persist: %s', columns_to_persist)

        # stack the original column names into a new innermost index level ...
        df = df.stack()

        # ... which becomes the KPI key column; then restore proper column names
        df.index = df.index.set_names(names=df.index.names[0:-1] + [KPI_KEY_COLUMN])
        df = df.reset_index()
        df = df.rename(columns={'id': KPI_ENTITY_ID_COLUMN, self.event_timestamp_column_name: KPI_TIMESTAMP_COLUMN})

        # naive UTC timestamp, matching the timezone-naive DateTime columns of the KPI tables
        df[KPI_LAST_UPDATE_COLUMN] = dt.datetime.utcnow()

        # make sure literal columns are indeed strings (sqlalchemy does not convert type automatically)
        if KPI_VALUE_S_COLUMN in df.columns:
            # only convert non-NA rows so NaN does not become the string 'nan'
            df[KPI_VALUE_S_COLUMN] = df[pd.notna(df[KPI_VALUE_S_COLUMN])][KPI_VALUE_S_COLUMN].astype(str)

        self.logger.debug('Target table: %s, DataFrame to write: %s', self.target_table, log_df_info(df, logger=self.logger, log_level=logging.DEBUG))

        self.db.start_session()
        try:
            # Make sure the target table exists, or create it if not. _write must be
            # usable even before a model instance is enabled, so the base kpi table
            # and the Daily grain one (AS default) may not exist yet; for those two
            # cases creation is forced. AS tolerates kpi tables that already exist
            # and skips creating them.
            table = self._get_table(self.target_table, data_item_types, force_create=(force_create or self.target_grain is None or self.target_grain == 'Daily'))

            # Clear existing rows of every persisted KPI before writing. The lookup
            # is case-insensitive because the key column is upper-cased on db2.
            key_col = [_c for _c in table.columns if _c.name.lower() == KPI_KEY_COLUMN][0]
            # begin() commits on success and rolls back on error; a bare connect()
            # does not commit under SQLAlchemy 2.x.
            with self.db.engine.begin() as conn:
                for out in data_item_types:
                    self.logger.debug('clearing existing data of kpi=%s', out)
                    conn.execute(table.delete().where(key_col == out))

            self.logger.debug('DataFrame to write: %s', log_df_info(df, logger=self.logger, log_level=logging.DEBUG))
            api._write_dataframe(df=df, table_name=self.target_table, db=self.db, db_schema=self.db_schema)
        except BaseException:
            # roll back on any failure or interruption, then propagate it
            self.db.session.rollback()
            raise
        finally:
            self.db.commit()

        # return the long-format frame that was written, for downstream use
        return df

    def _get_data_item_meta(self, df_grain, grain_dimensions):
        """Map each column (and each grain-dimension index level) to a data item type.

        :param df_grain: DataFrame whose dtypes are inspected.
        :param grain_dimensions: names of index levels to include as well, or None.
        :returns: dict of name -> one of the ``DATA_ITEM_DATATYPE_*`` values. Dtypes
            that are not bool/numeric/string/datetime fall back to NUMBER.
        """
        data_item_meta = {}

        # grain dimensions are in the index, so collect dtypes from both columns and index
        all_items = df_grain.dtypes.to_dict()
        if grain_dimensions is None:
            grain_dimensions = []
        all_items.update({k: v for k, v in df_grain.index.to_frame().dtypes.to_dict().items() if k in grain_dimensions})

        for out, dtype in all_items.items():
            data_type = DATA_ITEM_DATATYPE_NUMBER  # default is NUMBER in AS when no dataType is given
            # bool must be tested before numeric: pandas treats bool dtypes as numeric
            if pd.api.types.is_bool_dtype(dtype):
                data_type = DATA_ITEM_DATATYPE_BOOLEAN
            elif pd.api.types.is_numeric_dtype(dtype):
                data_type = DATA_ITEM_DATATYPE_NUMBER
            elif pd.api.types.is_string_dtype(dtype):
                data_type = DATA_ITEM_DATATYPE_LITERAL
            elif pd.api.types.is_datetime64_any_dtype(dtype):
                data_type = DATA_ITEM_DATATYPE_TIMESTAMP

            data_item_meta[out] = data_type

        return data_item_meta

    def _unique_key_column_names(self, key_column, timestamp_column):
        """Return, in order, the names of the columns that form the table's unique key.

        Grain dimension names are lower-cased to match how ``_get_table`` creates
        those columns.
        """
        names = []
        if self.target_grain_tuple[2]:
            names.append(KPI_ENTITY_ID_COLUMN)
        names.append(key_column)
        if self.target_grain_tuple[0] is not None:
            names.append(timestamp_column)
        if self.target_grain_tuple[1] is not None:
            names.extend(col.lower() for col in self.target_grain_tuple[1])
        return names

    def _get_table(self, table_name, data_item_types, force_create=False):
        """Return the KPI table ``table_name``, creating it when missing and allowed.

        Creation is only needed for local-mode models (e.g. sqlite) or when
        ``force_create`` is set. Otherwise AS should already have created the
        target table, and the lookup's ``KeyError`` is re-raised.

        :param table_name: name of the KPI table.
        :param data_item_types: mapping of column name to data item type, used to
            type the grain dimension columns.
        :param force_create: create the table even when not in local mode.
        :returns: the table object returned by ``self.db.get_table``.
        :raises KeyError: if the table is missing and creation is not allowed.
        """
        try:
            return self.db.get_table(table_name, self.db_schema)
        except KeyError:
            if not api.is_local_mode() and not force_create:
                # we don't want to create grain table directly since it will conflict with AS
                raise

        key_column = KPI_KEY_COLUMN
        timestamp_column = KPI_TIMESTAMP_COLUMN
        if self.db.db_type == 'db2':
            key_column = key_column.upper()
            timestamp_column = timestamp_column.upper()

        columns = []
        if self.target_grain_tuple[2]:
            columns.append(Column(KPI_ENTITY_ID_COLUMN, String(255), nullable=False))
        columns.append(Column(key_column, String(255), nullable=False))
        columns.append(Column(KPI_VALUE_COLUMN_NAMES.get(DATA_ITEM_DATATYPE_NUMBER), Float(), nullable=True))
        columns.append(Column(KPI_VALUE_COLUMN_NAMES.get(DATA_ITEM_DATATYPE_BOOLEAN), Boolean(), nullable=True))
        columns.append(Column(KPI_VALUE_COLUMN_NAMES.get(DATA_ITEM_DATATYPE_LITERAL), String(255), nullable=True))
        columns.append(Column(KPI_VALUE_COLUMN_NAMES.get(DATA_ITEM_DATATYPE_TIMESTAMP), DateTime(), nullable=True))
        if self.target_grain_tuple[0] is not None:
            columns.append(Column(timestamp_column, DateTime(), nullable=False))
        if self.target_grain_tuple[1] is not None:
            for col in self.target_grain_tuple[1]:
                data_item_type = data_item_types.get(col)
                if data_item_type == DATA_ITEM_DATATYPE_NUMBER:
                    columns.append(Column(col.lower(), Float(), nullable=False))
                elif data_item_type == DATA_ITEM_DATATYPE_BOOLEAN:
                    columns.append(Column(col.lower(), Boolean(), nullable=False))
                elif data_item_type == DATA_ITEM_DATATYPE_LITERAL:
                    columns.append(Column(col.lower(), String(255), nullable=False))
                elif data_item_type == DATA_ITEM_DATATYPE_TIMESTAMP:
                    columns.append(Column(col.lower(), DateTime(), nullable=False))
        columns.append(Column(KPI_LAST_UPDATE_COLUMN, DateTime(), nullable=False, server_default=func.current_timestamp()))

        uc_name = ('uc_%s' % table_name).lower()
        # shared with the db2 index below so both reference exactly the created columns
        uc_columns = self._unique_key_column_names(key_column, timestamp_column)
        columns.append(UniqueConstraint(*uc_columns, name=uc_name))

        table = Table(table_name,
                      self.db.metadata,
                      *columns,
                      schema=self.db_schema)
        table.create(bind=self.db.engine, checkfirst=True)

        if self.db.db_type == 'db2':
            # Strangely, with sqlalchemy on db2 merely adding the unique constraint is
            # not enough; the index has to be created explicitly (SQLAlchemy 1.3.10).
            index_name = uc_name
            try:
                Index(index_name,
                      *[table.c[name] for name in uc_columns],
                      unique=True).create(bind=self.db.connection)
            except Exception as e:
                self.logger.warning('failed creating index=%s for table=%s schema=%s: %s', index_name, table_name, self.db_schema, e)

        return self.db.get_table(table_name, self.db_schema)

