# -*- coding: utf-8 -*-
# Licensed Materials - Property of IBM
# 5737-M66, 5900-AAA, 5900-AMG
# (C) Copyright IBM Corp. 2019, 2025 All Rights Reserved.
# US Government Users Restricted Rights - Use, duplication, or disclosure
# restricted by GSA ADP Schedule Contract with IBM Corp.

import pandas as pd
import numpy as np
import datetime
from datetime import datetime
from datetime import timedelta
import logging

from pmlib import util

# debug in notebook begin
#import IPython
#from IPython.display import display
# debug in notebook end

class TSAggOps:
    def __init__(self, timeunits = 'D', loglevel = logging.INFO):
        """
        Create the aggregation helper and configure its logger.

        Parameters
        ----------
        timeunits : string, optional
            Unit code stored in self.time_units and used by the time
            computations of this class (for example 'D' for days, 'h' for
            hours, 'm' for minutes). The default is 'D'.
        loglevel : int, optional
            Level applied to the logger obtained from util.get_logger.
            The default is logging.INFO.

        Returns
        -------
        None.

        """
        self.time_units = timeunits

        # Bind the logger once and reuse the local alias for the follow-up calls.
        logger = self.__logger = util.get_logger(self)
        logger.setLevel(loglevel)
        logger.debug('TSAggOps __init__: self.time_units=%s', self.time_units)

    def calculate_mtbf(self, df_sensor_data, df_fault_data, sensor_timestamp_column_name = 'event_timestamp', failure_timestamp_column = 'faildate', total_failure_downtime = 0, time_unit = 'D'):
        """
        Description
        -----------
        Compute a sample mean-time-between-failures (MTBF) figure for one device / asset.
        There are many different ways and definitions to calculate mean time between failures
        and the downtime related numbers. This is one sample calculation.

        Parameters
        ----------
        df_sensor_data : Pandas DataFrame
            All the sensor data for one device / asset. It is not modified.
        df_fault_data : Pandas DataFrame
            All the failures for one device / asset.
        sensor_timestamp_column_name : string, optional
            Name of the timestamp column in df_sensor_data. The default is 'event_timestamp'.
        failure_timestamp_column : string, optional
            Name of the timestamp column in df_fault_data. The default is 'faildate'.
        total_failure_downtime : number, optional
            Downtime, expressed in self.time_units, that is subtracted from the observed
            span when there is more than one failure. The default is 0.
        time_unit : string, optional
            Kept for backward compatibility; it is not used. The unit configured on the
            instance (self.time_units) drives the calculation. The default is 'D'.

        Returns
        -------
        number
            * 0 when there are no failures or no sensor records.
            * With more than one failure: (observed span - total_failure_downtime) divided
              by the number of failures, in self.time_units.
            * With exactly one failure: half of the observed span (in self.time_units) when
              the failure lies strictly inside the sensor window, otherwise 0.

        """
        number_of_failures = df_fault_data.shape[0]
        if number_of_failures == 0 or df_sensor_data.empty:
            return 0

        # Sort only the timestamp column (a copy) so the caller's frame keeps its order.
        sensor_timestamps = df_sensor_data[sensor_timestamp_column_name].sort_values().values
        initial_sensor_record_ts = sensor_timestamps[0]
        self.__logger.debug(initial_sensor_record_ts)
        final_sensor_record_ts = sensor_timestamps[-1]
        self.__logger.debug(final_sensor_record_ts)

        # Length of the observed sensor window expressed in the configured time unit.
        observed_span = (final_sensor_record_ts - initial_sensor_record_ts) / np.timedelta64(1, self.time_units)

        if number_of_failures > 1:
            return (observed_span - total_failure_downtime) / number_of_failures

        fault_ts = df_fault_data[failure_timestamp_column].values[0]
        if initial_sensor_record_ts < fault_ts < final_sensor_record_ts:
            # Mean of the time before and the time after the single failure,
            # i.e. half of the observed window.
            return observed_span / 2
        return 0

    def create_time_lagged_records_multiple_devices(self, df_sensor_data_all_devices, df_fault_data_all_devices, sensor_asset_id_column_name = 'id', 
                                                    failure_asset_id_column_name = 'asset_id', sensor_timestamp_column_name = 'evt_timestamp', 
                                                    failure_timestamp_column_name = 'faildate', time_to_event_column_name = 'time_to_event'):
        """
        Build time-to-event records for every device that has at least one failure.

        The sensor data is grouped by device; each group is passed, together with the
        failures of the same device, to create_time_lagged_records and the non-empty
        results are stacked into one frame.

        Parameters
        ----------
        df_sensor_data_all_devices : Pandas DataFrame
            Sensor records of all devices. NOTE: it is sorted in place by
            sensor_timestamp_column_name.
        df_fault_data_all_devices : Pandas DataFrame
            Failure records of all devices. NOTE: it is sorted in place by
            failure_timestamp_column_name.
        sensor_asset_id_column_name : string, optional
            Device id column of the sensor frame. The default is 'id'.
        failure_asset_id_column_name : string, optional
            Device id column of the failure frame. The default is 'asset_id'.
        sensor_timestamp_column_name : string, optional
            Timestamp column of the sensor frame. The default is 'evt_timestamp'.
        failure_timestamp_column_name : string, optional
            Timestamp column of the failure frame. The default is 'faildate'.
        time_to_event_column_name : string, optional
            Name of the column that receives the time to the next failure.
            The default is 'time_to_event'.

        Returns
        -------
        Pandas DataFrame
            The stacked per-device results with a fresh 0..n-1 index, or an empty
            DataFrame when no device produced any record.

        """
        self.__logger.debug('Beginning of create_time_lagged_records_multiple_devices:')

        # The in-place sorts are kept so callers that rely on the inputs being
        # ordered after this call keep working.
        df_fault_data_all_devices.sort_values(failure_timestamp_column_name, inplace=True)
        df_sensor_data_all_devices.sort_values(sensor_timestamp_column_name, inplace=True)

        per_device_frames = []
        for group_name, df_group_by_device in df_sensor_data_all_devices.groupby(sensor_asset_id_column_name):
            self.__logger.debug('%s has the shape %s', group_name, df_group_by_device.shape)
            df_fault_data_by_device = df_fault_data_all_devices[df_fault_data_all_devices[failure_asset_id_column_name] == group_name]
            if df_fault_data_by_device.empty:
                # Without failures there is no event to measure the lag against.
                continue
            df_time_lagged_for_curr_group = self.create_time_lagged_records(
                df_group_by_device, df_fault_data_by_device,
                sensor_timestamp_column_name = sensor_timestamp_column_name,
                failure_timestamp_column_name = failure_timestamp_column_name,
                time_to_event_column_name = time_to_event_column_name)
            # Guard against None before touching any attribute of the result.
            if df_time_lagged_for_curr_group is None:
                continue
            self.__logger.debug('df_time_lagged_for_curr_group.shape: %s', df_time_lagged_for_curr_group.shape)
            if df_time_lagged_for_curr_group.empty:
                continue
            per_device_frames.append(df_time_lagged_for_curr_group)

        if not per_device_frames:
            return pd.DataFrame()
        # One concat at the end instead of re-copying the accumulated frame per device.
        return pd.concat(per_device_frames, ignore_index=True)

    def create_time_lagged_records(self, df_sensor_data,df_faults_data,sensor_timestamp_column_name = 'evt_timestamp', 
                                                    failure_timestamp_column_name = 'faildate', time_to_event_column_name = 'time_to_event', time_units = 'D'):
        """
        Attach a time-to-next-failure column to the sensor records of one device.

        The failure timestamps split the sensor records into consecutive segments.
        Every record of a segment gets the distance to the failure that closes the
        segment, expressed in self.time_units. Records with a distance of 0 or less
        and records after the last failure are not part of the result.

        Parameters
        ----------
        df_sensor_data : Pandas DataFrame
            Sensor records of one device. It is not modified.
        df_faults_data : Pandas DataFrame
            Failure records of the same device. The failure timestamps are parsed
            with pandas.to_datetime and processed in ascending order.
        sensor_timestamp_column_name : string, optional
            Timestamp column of the sensor frame. The default is 'evt_timestamp'.
        failure_timestamp_column_name : string, optional
            Timestamp column of the failure frame. The default is 'faildate'.
        time_to_event_column_name : string, optional
            Name of the added column. The default is 'time_to_event'.
        time_units : string, optional
            Kept for backward compatibility; it is not used. The unit configured on
            the instance (self.time_units) is applied. The default is 'D'.

        Returns
        -------
        Pandas DataFrame
            The selected sensor records plus the time-to-event column, sorted by the
            sensor timestamp with a fresh index. An empty DataFrame is returned when
            there are no sensor records or no record falls into any segment.

        """
        if df_sensor_data.empty:
            return pd.DataFrame()

        # One step of the configured unit; used both as the look-back before the first
        # record and as the divisor that converts durations into numbers.
        one_unit = np.timedelta64(1, self.time_units)
        sensor_timestamps = df_sensor_data[sensor_timestamp_column_name]

        # Ascending order is required so that every record is matched with the *next* failure.
        fault_timestamp_list = np.sort(pd.to_datetime(df_faults_data[failure_timestamp_column_name]).values)

        # Start one unit before the first record so that the first record is included
        # by the strict "greater than" comparison below.
        prev_ts = sensor_timestamps.iloc[0] - pd.Timedelta(one_unit)

        segments = []
        for fault_ts in fault_timestamp_list:
            self.__logger.debug('Processing the records between %s and %s', prev_ts, fault_ts)
            in_segment = (sensor_timestamps <= fault_ts) & (sensor_timestamps > prev_ts)
            prev_ts = fault_ts
            if not in_segment.any():
                self.__logger.debug('Either a None type or empty dataframe is created - skipping this segment')
                continue

            self.__logger.debug('Computing time to event values for the records in the current segment')
            # Copy the slice before adding a column so the caller's frame is never touched.
            df_segment = df_sensor_data[in_segment].copy()
            df_segment[time_to_event_column_name] = (fault_ts - df_segment[sensor_timestamp_column_name]) / one_unit

            if segments:
                self.__logger.debug('Adding this batch to the time lagged records dataframe')
            else:
                self.__logger.debug('Creating the first batch of records in the time lagged records dataframe')
            segments.append(df_segment)

        if not segments:
            df_sensor_data_time_lagged = pd.DataFrame()
            self.__logger.debug('Accumumated records in the final dataframe with columns %s', df_sensor_data_time_lagged.columns)
            return df_sensor_data_time_lagged

        df_sensor_data_time_lagged = pd.concat(segments, ignore_index=True)
        self.__logger.debug('Accumumated records in the final dataframe with columns %s', df_sensor_data_time_lagged.columns)

        # Records that coincide with the failure itself (distance 0) are dropped.
        df_sensor_data_time_lagged = df_sensor_data_time_lagged[df_sensor_data_time_lagged[time_to_event_column_name] > 0]
        df_sensor_data_time_lagged = df_sensor_data_time_lagged.sort_values(sensor_timestamp_column_name)
        return df_sensor_data_time_lagged.reset_index(drop=True)

    def aggregate_lagged_records_multiple_devices(self,df_time_lagged_data_all_devices, df_fault_data_all_devices, agg_times_dict, lagged_time_col_name = 'time_to_event', 
                                                  sensor_asset_id_column_name = 'id', failure_asset_id_column_name = 'asset_id', sensor_timestamp_column_name = 'evt_timestamp', 
                                                    failure_timestamp_column_name = 'faildate',columnwise_aggfns = None, agg_fns = None, cols_for_aggregation = None):
        """
        Run aggregate_lagged_records for every device that has failures and stack the results.

        Parameters
        ----------
        df_time_lagged_data_all_devices : Pandas DataFrame
            Time lagged sensor records of all devices.
        df_fault_data_all_devices : Pandas DataFrame
            Failure records of all devices.
        agg_times_dict : dict
            Rolling window per device, looked up with the device id. A device that has
            failures but no entry raises KeyError.
        lagged_time_col_name : string, optional
            Column holding the time to event. The default is 'time_to_event'.
        sensor_asset_id_column_name : string, optional
            Device id column of the sensor frame. The default is 'id'.
        failure_asset_id_column_name : string, optional
            Device id column of the failure frame. The default is 'asset_id'.
        sensor_timestamp_column_name : string, optional
            Timestamp column of the sensor frame. The default is 'evt_timestamp'.
        failure_timestamp_column_name : string, optional
            Timestamp column of the failure frame. The default is 'faildate'.
        columnwise_aggfns : dict, optional
            Column name -> list of aggregation functions; forwarded to
            aggregate_lagged_records. The default is None.
        agg_fns : list, optional
            Aggregation functions applied to every column of cols_for_aggregation when
            columnwise_aggfns is not given. The default is None.
        cols_for_aggregation : list, optional
            Columns to aggregate when columnwise_aggfns is not given. The default is None.

        Returns
        -------
        Pandas DataFrame
            The aggregated records of all processed devices, each tagged with its
            device id in sensor_asset_id_column_name.

        """
        df_time_lagged_agg_data = pd.DataFrame()
        for group_name, df_group_by_device in df_time_lagged_data_all_devices.groupby(sensor_asset_id_column_name):
            # Lazy formatting: string concatenation raised TypeError for non-string ids.
            self.__logger.debug('%s has the shape %s\n', group_name, df_group_by_device.shape)
            df_fault_data_by_device = df_fault_data_all_devices[df_fault_data_all_devices[failure_asset_id_column_name] == group_name]
            if df_fault_data_by_device.empty:
                continue
            df_time_lagged_agg_data_curr_group = self.aggregate_lagged_records(
                df_group_by_device, sensor_timestamp_column_name, agg_times_dict[group_name],
                lagged_time_col_name, df_fault_data_by_device[failure_timestamp_column_name].values,
                # Forward the per-column definition as well; it used to be dropped silently.
                columnwise_aggfns = columnwise_aggfns, agg_fns = agg_fns,
                cols_for_aggregation = cols_for_aggregation)
            df_time_lagged_agg_data_curr_group[sensor_asset_id_column_name] = group_name
            self.__logger.debug(df_time_lagged_agg_data_curr_group.shape)
            if df_time_lagged_agg_data.empty:
                df_time_lagged_agg_data = df_time_lagged_agg_data_curr_group
            else:
                df_time_lagged_agg_data = pd.concat([df_time_lagged_agg_data, df_time_lagged_agg_data_curr_group], ignore_index=True)

        return df_time_lagged_agg_data

    def aggregate_lagged_records(self, df_time_lagged_data, timestamp_column_name, agg_time, lagged_time_column_name,
                                       failure_timestamps, columnwise_aggfns = None, agg_fns = None, cols_for_aggregation = None):
        """
        Compute rolling aggregates of the time lagged records of one device.

        The records are split into the segments delimited by failure_timestamps. Inside
        every segment a rolling window of size agg_time is evaluated over the timestamp
        index. The two level column index of the result is flattened with
        flatten_dataframe_multi_index and rows containing missing values are dropped.

        Parameters
        ----------
        df_time_lagged_data : Pandas DataFrame
            Time lagged records of one device. It is not modified.
        timestamp_column_name : string
            Timestamp column of df_time_lagged_data.
        agg_time : window specification accepted by DataFrame.rolling
            Size of the rolling window.
        lagged_time_column_name : string
            Column holding the time to event. When columnwise_aggfns is not given, the
            last value of each window is kept for it and the column keeps its plain name.
        failure_timestamps : iterable
            Failure timestamps of the device, used as segment boundaries in the given order.
        columnwise_aggfns : dict, optional
            Column name -> list of aggregation functions. When given and not empty it is
            used as is. The default is None.
        agg_fns : list, optional
            Aggregation functions applied to every column of cols_for_aggregation when
            columnwise_aggfns is not given. The default is None.
        cols_for_aggregation : list, optional
            Columns to aggregate when columnwise_aggfns is not given. None is treated as
            "no additional columns". The default is None.

        Returns
        -------
        Pandas DataFrame
            Flattened aggregates without missing values, or an empty DataFrame when the
            input is empty or no segment contains records.

        """
        def last_value(data_series):
            # Positional access: each window is indexed by timestamp, so an integer key
            # must not be interpreted as an index label.
            return data_series.iloc[-1]

        if columnwise_aggfns is None or len(columnwise_aggfns) == 0:
            columnwise_aggfns = {lagged_time_column_name: [last_value]}
            for col in (cols_for_aggregation if cols_for_aggregation is not None else ()):
                columnwise_aggfns[col] = agg_fns
        self.__logger.debug('The columns and the aggregation functions are '+str(columnwise_aggfns))

        if df_time_lagged_data.empty:
            return pd.DataFrame()

        # Start one unit before the first record so that the strict comparison below
        # keeps the first record. The same unit code is used everywhere else in the class.
        timestamps = df_time_lagged_data[timestamp_column_name]
        prev_ts = timestamps.iloc[0] - pd.Timedelta(np.timedelta64(1, self.time_units))

        df_time_lagged_agg_data = pd.DataFrame()
        for failure_ts in failure_timestamps:
            self.__logger.debug('The failure timestamp is = ' + str(failure_ts))
            # Copy the slice: columns are overwritten and the index is replaced below.
            df1 = df_time_lagged_data[(timestamps < failure_ts) & (timestamps > prev_ts)].copy()
            if df1.empty:
                self.__logger.debug('There are no records between the timestamps ' + str(prev_ts) + ' and ' + str(failure_ts))
                prev_ts = failure_ts
                continue
            prev_ts = failure_ts
            df1[timestamp_column_name] = pd.to_datetime(df1[timestamp_column_name])
            # The rolling window is evaluated on the timestamp index; the column itself is kept.
            df1.set_index(timestamp_column_name, inplace = True, drop = False)
            df_agg = df1.rolling(agg_time).agg(columnwise_aggfns)
            if df_agg.empty:
                self.__logger.debug('The aggregated data frame is empty. So skipping this one ... ')
                continue
            if df_time_lagged_agg_data.empty:
                self.__logger.debug('First time aggregation loop - so assigning. The shape = '+ str(df_time_lagged_agg_data.shape))
                df_time_lagged_agg_data = df_agg
            else:
                # ignore_index must stay False, otherwise the timestamp index is dropped.
                df_time_lagged_agg_data = pd.concat([df_time_lagged_agg_data, df_agg], ignore_index=False)
                self.__logger.debug('Appended the latest aggregated set of records to the growing data frame. Shape = ' + str(df_time_lagged_agg_data.shape))

        self.__logger.debug('The final data frame to return is printed below')
        if df_time_lagged_agg_data.empty:
            # Nothing was aggregated, so there is no two level column index to flatten.
            return pd.DataFrame()
        # A one element tuple: a bare string would turn the membership test of the
        # flattening step into a substring match.
        result_df = self.flatten_dataframe_multi_index(df_time_lagged_agg_data, use_top_level_for_columns = (lagged_time_column_name,))
        return result_df.dropna()

    def aggregate_scoring_data(self, scoring_df, device_id_column_name, timestamp_column_name, agg_time,
                             agg_fns = None, columnwise_aggfns = None, cols_for_aggregation = None,
                             use_zero_stddev_for_one_record = True):
        """
        Compute rolling aggregates per device for data that is about to be scored.

        Parameters
        ----------
        scoring_df : Pandas DataFrame
            Records of one or more devices. It is not modified; the work is done on a
            re-indexed copy.
        device_id_column_name : string
            Column holding the device id.
        timestamp_column_name : string
            Column holding the record timestamp; it becomes the index the rolling
            window is evaluated on.
        agg_time : window specification accepted by DataFrame.rolling
            Size of the rolling window.
        agg_fns : list, optional
            Aggregation functions applied to every column of cols_for_aggregation when
            columnwise_aggfns is not given. The default is None.
        columnwise_aggfns : dict, optional
            Column name -> list of aggregation functions. When given and not empty it
            also determines the aggregated columns. The default is None.
        cols_for_aggregation : list, optional
            Columns to aggregate. When missing or empty, all columns of scoring_df except
            the timestamp and the device id column are used. The default is None.
        use_zero_stddev_for_one_record : bool, optional
            When True, a device with a single record gets its standard deviation columns
            replaced through replace_stddev. The default is True.

        Returns
        -------
        Pandas DataFrame
            Flattened aggregates indexed by (device id, timestamp) with rows containing
            missing values removed. An empty DataFrame is returned when scoring_df holds
            no device group.

        """
        if cols_for_aggregation is None or len(cols_for_aggregation) == 0:
            self.__logger.debug('From TSAggOps::aggregate_scoring_data(): Columns for aggregation not provided explcitly. Using all columns of the incoming data frame except the index')
            excluded_columns = (timestamp_column_name, device_id_column_name)
            cols_for_aggregation = [col for col in scoring_df.columns if col not in excluded_columns]
        if columnwise_aggfns is None or len(columnwise_aggfns) == 0:
            self.__logger.debug('From TSAggOps::aggregate_scoring_data: Preparing column wise aggregation definitions')
            columnwise_aggfns = {col: agg_fns for col in cols_for_aggregation}
        else:
            cols_for_aggregation = list(columnwise_aggfns.keys())

        self.__logger.debug(('From TSAggOps::aggregate_scoring_data: Columns for aggregation = ' + str(cols_for_aggregation)))

        self.__logger.debug(('From TSAggOps::aggregate_scoring_data: Aggregation function for columns = ' + str(columnwise_aggfns)))

        # Re-index a copy. Doing this in place leaked an extra column and a new index
        # into the caller's frame, which changed the outcome of a second call.
        working_df = scoring_df.reset_index(drop = False)
        self.__logger.debug(('From TSAggOps::aggregate_scoring_data: Reset the index of the incoming data frame. The columns are '+ str(working_df.columns)))
        working_df = working_df.set_index(timestamp_column_name)
        self.__logger.debug(('From TSAggOps::aggregate_scoring_data: Set the index of the incoming data frame to timestamp values. The columns are '+ str(working_df.columns)))

        per_device_frames = []
        for device_id, grouped_df in working_df.groupby(device_id_column_name):
            # The timestamp is the index here; sort_values accepts the index level name.
            grouped_df = grouped_df.sort_values(timestamp_column_name)
            agg_df = grouped_df[cols_for_aggregation].rolling(agg_time).agg(columnwise_aggfns)
            single_level_index_agg_df =  self.flatten_dataframe_multi_index(agg_df)
            if grouped_df.shape[0] == 1 and use_zero_stddev_for_one_record:
                # Lazy formatting: string concatenation raised TypeError for non-string ids.
                self.__logger.debug('From TSAggOps::aggregate_scoring_data: Data frame with single row encountered for the device %s. Replacing the standard deviation with 0 for scoring ', device_id)
                single_level_index_agg_df = TSAggOps.replace_stddev(grouped_df, single_level_index_agg_df)
            single_level_index_agg_df[device_id_column_name] = device_id
            self.__logger.debug('From TSAggOps::aggregate_scoring_data: After aggregation for the device %s the columns are %s', device_id, single_level_index_agg_df.columns)
            single_level_index_agg_df.reset_index(inplace = True, drop = True)
            self.__logger.debug('From TSAggOps::aggregate_scoring_data: After aggregation for the device %s the reset columns are %s', device_id, single_level_index_agg_df.columns)
            per_device_frames.append(single_level_index_agg_df)

        self.__logger.debug("after fo loop feature_engineered_scoring_df")
        if not per_device_frames:
            return pd.DataFrame()

        # One concat at the end instead of re-copying the accumulated frame per device.
        feature_engineered_scoring_df = pd.concat(per_device_frames, ignore_index=True)
        feature_engineered_scoring_df.set_index([device_id_column_name, timestamp_column_name], inplace = True, drop = True)
        self.__logger.debug(str(feature_engineered_scoring_df.isna().sum()))
        # Summary statistics are diagnostic output, so they go to the debug log and are
        # only computed when that level is enabled.
        if self.__logger.isEnabledFor(logging.DEBUG):
            self.__logger.debug('%s', feature_engineered_scoring_df.describe().T)
        return feature_engineered_scoring_df.dropna()

    def flatten_dataframe_multi_index(self, df, use_top_level_for_columns = ()):
        """
        Turn a frame with two level columns into a frame with single level columns.

        Every (top, sub) column becomes 'top__sub'. For top level names listed in
        use_top_level_for_columns the plain top level name is used instead; when such a
        top level has several sub columns, the last one wins. The index of df is moved
        into a regular column by reset_index.

        Parameters
        ----------
        df : Pandas DataFrame
            Frame to flatten. A frame whose columns are not a MultiIndex is returned
            with only its index reset.
        use_top_level_for_columns : tuple / list / set of strings or a single string, optional
            Top level names that keep their plain name. A single string is treated as a
            one element collection. The default is ().

        Returns
        -------
        Pandas DataFrame
            The flattened frame with a fresh 0..n-1 index.

        """
        if not isinstance(df.columns, pd.MultiIndex):
            # Nothing to flatten, only the index handling applies.
            return df.reset_index()

        if isinstance(use_top_level_for_columns, str):
            # A bare string would make the membership test below a substring match.
            use_top_level_for_columns = (use_top_level_for_columns,)

        flat_columns = {}
        for top_level in df.columns.levels[0]:
            df_loop = df[top_level]
            keep_top_level_name = top_level in use_top_level_for_columns
            for col in df_loop.columns:
                flat_name = top_level if keep_top_level_name else '{}__{}'.format(top_level, col)
                flat_columns[flat_name] = df_loop[col]

        # Build the frame in one go instead of inserting the columns one at a time.
        flattened_df = pd.DataFrame(flat_columns, index = df.index)
        flattened_df.reset_index(inplace = True)
        return flattened_df

    def get_timestamp_limits(self, sensor_df, fault_df, sensor_timestamp_column_name = 'timestamp', failure_timestamp_column_name = 'timestamp'):
        """
        Return the overall time span covered by the sensor and the failure records.

        The first and the last row of each frame are taken as its earliest and most
        recent record, so both frames are expected to be ordered by their timestamp
        column. The four boundary values are printed.

        Parameters
        ----------
        sensor_df : Pandas DataFrame
            Sensor records; must contain at least one row.
        fault_df : Pandas DataFrame
            Failure records; must contain at least one row.
        sensor_timestamp_column_name : string, optional
            Timestamp column of sensor_df. The default is 'timestamp'.
        failure_timestamp_column_name : string, optional
            Timestamp column of fault_df. The default is 'timestamp'.

        Returns
        -------
        tuple
            (start_ts, end_ts): the earlier of the two first timestamps and the later
            of the two last timestamps.

        """
        sensor_timestamps = sensor_df[sensor_timestamp_column_name].values
        failure_timestamps = fault_df[failure_timestamp_column_name].values

        sensor_readings_begin_timestamp = sensor_timestamps[0]
        print('The earliest sensor reading was recorded at ', sensor_readings_begin_timestamp)
        sensor_readings_end_timestamp = sensor_timestamps[-1]
        print('The most recent sensor reading was recorded at ', sensor_readings_end_timestamp)
        failure_readings_begin_timestamp = failure_timestamps[0]
        print('The earliest failure was recorded at ', failure_readings_begin_timestamp)
        failure_readings_end_timestamp = failure_timestamps[-1]
        print('The most recent failure was recorded at ', failure_readings_end_timestamp)

        start_ts = sensor_readings_begin_timestamp
        end_ts = sensor_readings_end_timestamp
        if sensor_readings_begin_timestamp > failure_readings_begin_timestamp:
            start_ts = failure_readings_begin_timestamp
        # Mirror of the start handling: a failure recorded after the last sensor
        # reading extends the end of the span.
        if failure_readings_end_timestamp > sensor_readings_end_timestamp:
            end_ts = failure_readings_end_timestamp

        return start_ts, end_ts

    def train_test_split_by_asset(self, sensor_df, fault_df, sensor_df_asset_id_column_name, failure_df_asset_id_column_name, sensor_df_timestamp_column_name = 'timestamp',
                              failure_df_timestamp_column_name = 'timestamp', failure_df_timestamp_format = '%Y-%m-%d', failure_timestamp_resolution = 'D', training_fraction = 0.7):
        """
        Split sensor and failure records into a training and a testing part, asset by asset.

        For every asset the failures are ordered in time. The earliest ones form the
        training segment (one failure when the asset has one or two failures, otherwise
        the first int(count * training_fraction) failures, but at least one). Training
        data runs from the overall start up to the last training failure; testing data
        runs from that same failure up to the overall end. Assets without failures are
        skipped.

        Parameters
        ----------
        sensor_df : Pandas DataFrame
            Sensor records of all assets.
        fault_df : Pandas DataFrame
            Failure records of all assets.
        sensor_df_asset_id_column_name : string
            Asset id column of sensor_df.
        failure_df_asset_id_column_name : string
            Asset id column of fault_df.
        sensor_df_timestamp_column_name : string, optional
            Timestamp column of sensor_df. The default is 'timestamp'.
        failure_df_timestamp_column_name : string, optional
            Timestamp column of fault_df. The default is 'timestamp'.
        failure_df_timestamp_format : string, optional
            strptime format used to parse the boundaries after they were rendered at
            failure_timestamp_resolution; the two must match. The default is '%Y-%m-%d'.
        failure_timestamp_resolution : string, optional
            numpy datetime unit the boundaries are truncated to. The default is 'D'.
        training_fraction : float, optional
            Share of the failures of an asset used for training when the asset has more
            than two failures. The default is 0.7.

        Returns
        -------
        tuple of four Pandas DataFrames
            (sensor training data, failure training data, sensor testing data,
            failure testing data). A part without any record is an empty DataFrame.

        """
        sensor_train_frames = []
        failure_train_frames = []
        sensor_test_frames = []
        failure_test_frames = []

        start_ts, end_ts = self.get_timestamp_limits(sensor_df, fault_df, sensor_timestamp_column_name = sensor_df_timestamp_column_name,
                                                     failure_timestamp_column_name = failure_df_timestamp_column_name)

        def to_boundary(timestamp):
            # Truncate to the failure resolution and parse the rendered text back into a datetime.
            return datetime.strptime(np.datetime_as_string(timestamp, unit = failure_timestamp_resolution), failure_df_timestamp_format)

        def stack(frames):
            # One concat per output instead of re-copying the accumulated frame per asset.
            return pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()

        for asset_id, grouped_df in sensor_df.groupby(sensor_df_asset_id_column_name):
            self.__logger.debug('Asset Id = %s', asset_id)
            failure_asset_df = fault_df[fault_df[failure_df_asset_id_column_name] == asset_id]
            sorted_failure_timestamps = sorted(failure_asset_df[failure_df_timestamp_column_name].values)

            failure_count = len(sorted_failure_timestamps)
            if failure_count == 0:
                self.__logger.debug('No failure data for asset %s skipping and going to the next asset', asset_id)
                continue
            elif failure_count == 1:
                failures_for_training_segment = [sorted_failure_timestamps[0]]
                failures_for_testing_segment = []
            elif failure_count == 2:
                failures_for_training_segment = [sorted_failure_timestamps[0]]
                failures_for_testing_segment = [sorted_failure_timestamps[1]]
            else:
                # At least one failure must close the training window; a small
                # training_fraction would otherwise leave the segment empty.
                training_count = max(1, int(failure_count * training_fraction))
                failures_for_training_segment = sorted_failure_timestamps[0:training_count]
                failures_for_testing_segment = set(sorted_failure_timestamps).difference(failures_for_training_segment)

            train_start = to_boundary(start_ts)
            train_end = to_boundary(failures_for_training_segment[-1])

            sensor_timestamps = grouped_df[sensor_df_timestamp_column_name]
            failure_timestamps = failure_asset_df[failure_df_timestamp_column_name]

            # Copies are taken before the id column is written so that neither input frame is touched.
            sensor_train_df = grouped_df.loc[(sensor_timestamps <= train_end) & (sensor_timestamps >= train_start)].copy()
            sensor_train_df[sensor_df_asset_id_column_name] = asset_id

            failure_train_df = failure_asset_df.loc[(failure_timestamps <= train_end) & (failure_timestamps >= train_start)].copy()
            failure_train_df[failure_df_asset_id_column_name] = asset_id

            sensor_train_frames.append(sensor_train_df)
            failure_train_frames.append(failure_train_df)

            # NOTE(review): the testing window starts exactly where the training window
            # ends and both comparisons are inclusive, so records on that boundary are
            # part of both sets - confirm this overlap is intended.
            test_start = train_end
            test_end = to_boundary(end_ts)
            self.__logger.debug('Testing data range for asset %s %s %s', asset_id, test_start, test_end)
            sensor_test_df = grouped_df.loc[(sensor_timestamps <= test_end) & (sensor_timestamps >= test_start)].copy()
            sensor_test_df[sensor_df_asset_id_column_name] = asset_id
            sensor_test_frames.append(sensor_test_df)

            if len(failures_for_testing_segment) > 0:
                failure_test_df = failure_asset_df.loc[(failure_timestamps <= test_end) & (failure_timestamps >= test_start)].copy()
                failure_test_df[failure_df_asset_id_column_name] = asset_id
                failure_test_frames.append(failure_test_df)

        return stack(sensor_train_frames), stack(failure_train_frames), stack(sensor_test_frames), stack(failure_test_frames)

    @classmethod
    def replace_stddev(cls, source_df, aggregated_df):
        """
        Overwrite the standard deviation columns of an aggregated frame.

        A column of aggregated_df is treated as a standard deviation column when its
        name contains the marker '__std' and the text in front of the last marker is a
        column of source_df. Such a column is set to numpy.std of that source column.
        All other columns are left untouched.

        Parameters
        ----------
        source_df : Pandas DataFrame
            The records the aggregates were computed from.
        aggregated_df : Pandas DataFrame
            The flattened aggregates; it is modified in place.

        Returns
        -------
        Pandas DataFrame
            aggregated_df, for convenience.

        """
        for col in aggregated_df.columns.values:
            if not isinstance(col, str):
                continue
            # Split on the last marker so source column names that contain the marker
            # themselves are resolved correctly.
            source_col, marker, _ = col.rpartition('__std')
            if marker and source_col in source_df.columns:
                aggregated_df[col] = np.std(source_df[source_col])
        return aggregated_df
