# Licensed Materials - Property of IBM
# 5737-M66, 5900-AAA, 5900-AMG
# (C) Copyright IBM Corp. 2019, 2025 All Rights Reserved.
# US Government Users Restricted Rights - Use, duplication, or disclosure
# restricted by GSA ADP Schedule Contract with IBM Corp.

import math
import calendar
import logging
from collections import defaultdict
from datetime import datetime

import dill as pickle
import numpy as np
import pandas as pd
import json
# # Work around to set a global timeout value (on requests) for WMLExecutor
import requests
#from watson_machine_learning_client import WatsonMachineLearningAPIClient
from ibm_watsonx_ai import APIClient  # CP4D 3.5
from iotfunctions.base import BaseEstimatorFunction
from srom.classification.smart_classification import SmartClassification
from srom.pipeline.anomaly_pipeline import AnomalyPipeline
from srom.pipeline.srom_pipeline import SROMPipeline
from srom.anomaly_detection import GaussianGraphicalModel
from srom.wml.wrappers.onprem.scoring import WMLScorer
from .temporal_feature_eng import create_temporal_features, create_simple_temporal_features, create_higher_order_temporal_features, create_advanced_temporal_features

# change to .api
from .api import get_entity_type_id_by_entity_type_name
from .util import _mkdirp, get_logger, log_df_info,compute_binary_class_probabilities

# class _WmlTimeout(requests.adapters.TimeoutSauce):
#     def __init__(self, *args, **kwargs):
#         connect = kwargs.get('connect', 30)
#         read = kwargs.get('read', connect)
#         super(_WmlTimeout, self).__init__(connect=connect, read=read)
# requests.adapters.TimeoutSauce = _WmlTimeout


class BaseEstimator(BaseEstimatorFunction):
    '''Base class for estimators, supporting training/prediction/scoring.

    Note that though the AS base class supports multiple targets per estimator, we
    stick to a single target per estimator for now. So this class assumes the
    given targets and predictions are always 1-element arrays. Also, when
    prediction is not needed, the passed-in predictions should be an array of one
    element 'None'.
    '''
    def __init__(self, features, targets, predictions, features_for_training=None, training_options=None, **kwargs):
        '''Set up the estimator state on top of the AS base estimator.

        Args:
            features: feature column names; forwarded to the base class.
            targets: target names; forwarded to the base class.
            predictions: prediction output names; forwarded to the base class.
            features_for_training: optional extra columns that
                get_df_for_training() appends to the features.
            training_options: optional training options; None becomes {}.
            **kwargs: every entry whose key is not already an attribute of the
                instance is set as an attribute. 'event_timestamp_column_name'
                (default 'evt_timestamp') and 'latest_prediction_timestamp'
                (default None) are always read from here.
        '''
        super().__init__(features, targets, predictions)

        self.features_for_training = features_for_training
        self.training_options = {} if training_options is None else training_options

        # models and their extra artifacts, keyed by model name
        self.models = {}
        self.model_extras = defaultdict(list)

        self.df_traces = {}
        self.logger = get_logger(self)

        self.local_model = True
        self.model_timestamp = None
        self.training_timestamp = None

        # expose unknown keyword options as attributes without clobbering existing ones
        for option, option_value in kwargs.items():
            if hasattr(self, option):
                continue
            setattr(self, option, option_value)

        self.event_timestamp_column_name = kwargs.get('event_timestamp_column_name', 'evt_timestamp')
        self.latest_prediction_timestamp = kwargs.get('latest_prediction_timestamp', None)
        
                

    def get_model_name(self, target_name, suffix=None):
        if suffix is None:
            suffix = self.model_timestamp
        return self.generate_model_name(target_name=target_name, prefix=None, suffix=suffix)

    def generate_model_name(self, target_name, prefix=None, suffix=None):
        name = ['apm', 'pmi', 'model']

        if prefix is not None:
            if isinstance(prefix, str):
                prefix = [prefix]
            if len(prefix) > 0:
                name += prefix
        name.extend([self._entity_type.logical_name, self.name, target_name])
        name = '/'.join(name)

        if suffix is not None:
            if isinstance(suffix, datetime):
                name += '_' + str(calendar.timegm(suffix.timetuple()))
            else:
                name += '_' + str(suffix)

        return name

    def _get_target_name(self):
        return '_' if (self.predictions is None or len(self.predictions) == 0) else self.predictions[0]

    def _load_model(self, bucket):
        model_name = self.get_model_name(target_name=self._get_target_name()) # load with default suffix timestamp
        return (model_name, self.load_model(model_name, bucket, self.local_model))

    def load_model(self, model_path, bucket, local):
        if local:
            # local FS
            model = None
            try:
                with open(model_path, mode='rb') as file:
                    model = file.read()
            except FileNotFoundError as e:
                pass
            return pickle.loads(model) if model is not None else None
        else:
            self.logger.debug('In load_model')

            import os
            cos_kpi = os.environ.get('COS_BUCKET_KPI')
            self.logger.debug('In load_model cos_kpi='+ str(cos_kpi))

            if cos_kpi is not None:
                self.logger.debug('In load_model , load from COS')
                return self._entity_type.db.cos_load(filename=model_path, bucket=bucket, binary=True)
            else:
                if self._entity_type.db.model_store is not None:
                    self.logger.debug('In load_model , load from KPI_MODEL_STORE')
                    return self._entity_type.db.model_store.retrieve_model(model_path, deserialize=True)

    def _save_model(self, bucket, new_model, suffix=None, local=True):
        filename = self.get_model_name(target_name=self._get_target_name(), suffix=suffix) # save with explicity suffix timestamp set

        objects = [(filename, new_model, True, True)] # model itself always pickle dumped and binary
        extras = self.get_model_extra(new_model, filename)
        objects.extend(extras)
        for fname, obj, picket_dump, binary in objects:
            self.save_model(obj, fname, bucket, picket_dump, binary, local)

        # add model to internal list for prediction usage
        self.models[filename] = new_model
        if len(extras) > 0:
            self.model_extras[filename].extend(extras)

    def _get_model_trace_base_path(self, target_name, suffix=None, training=False):
        base_path = self.get_model_name(target_name=self._get_target_name(), suffix=suffix)
        if training:
            # for training, put together with model itself
            return base_path
        else:
            # for prediction, put in separate log location
            now = datetime.utcnow()
            return base_path.replace('apm/pmi/model/', 'apm/pmi/log/', 1) + '/%s/%s' % (now.strftime('%Y%m%d'), now.strftime('%H%M%S'))

    def _save_model_df_trace(self, bucket, df_trace_dict=None, suffix=None, local=True, training=False):
        filename = self._get_model_trace_base_path(target_name=self._get_target_name(), suffix=suffix, training=training)

        # make sure path do exist
        try:
            _mkdirp(filename)
        except:
            pass

        if df_trace_dict is None:
            df_trace_dict = self.df_traces
            if df_trace_dict is None:
                self.logger.warning('model_df_traces not found, skipped saving traces')
                return

        for key, df in df_trace_dict.items():
            if not isinstance(df, pd.DataFrame):
                if isinstance(df, list):
                    df = pd.DataFrame(np.array(df))
                elif isinstance(df, np.ndarray):
                    df = pd.DataFrame(df)
                elif isinstance(df, pd.Series):
                    df = pd.DataFrame({df.name: df})
                # elif isinstance(df,dict):
                #     df = df # it is begin_training_date
                #     self.logger.debug('_save_model_df_trace df=%s', df)
                else:
                    self.logger.warning('unknown df_trace type: %s', type(df))
                    continue

            # always write to local first
            #path = '%s_%s.gz' % (filename, key)
            #df.to_csv(path)

            if key == 'training_date_range':
                # always write to local first

                path = '%s_%s' % (filename, key)
                df.to_csv(path,header=False)

                if not local:
                    self.logger.debug('training_date_range df=%s', log_df_info(df,head=3, logger=self.logger, log_level=logging.DEBUG))
                    with open(path, 'r') as file:
                        self.save_model(new_model=file.read(), model_path=path, bucket=bucket, pickle_dump=False, binary=False, local=False)

            else:
                # always write to local first
                path = '%s_%s.gz' % (filename, key)
                df.to_csv(path)
                if not local:
                    with open(path, 'rb') as file:
                        self.save_model(new_model=file.read(), model_path=path, bucket=bucket, pickle_dump=False, binary=True, local=False)

    def save_model(self, new_model, model_path, bucket, pickle_dump, binary, local):
        self.logger.debug('Saving model...')
        if local:
            mode = 'w'
            if pickle_dump:
                new_model = pickle.dumps(new_model)
            if binary:
                mode += 'b'

            try:
                _mkdirp(model_path)
            except:
                pass
            with open(model_path, mode=mode) as file:
                file.write(new_model)
        else:
            import os
            cos_kpi = os.environ.get('COS_BUCKET_KPI')

            try:
                if pickle_dump:
                    if cos_kpi is not None:
                        self.logger.debug('Saving to COS, pickle_dump=true')
                        self._entity_type.db.cos_save(persisted_object=new_model, filename=model_path, bucket=bucket, binary=binary)
                    else:
                        if self._entity_type.db.model_store is not None:
                            self.logger.debug('Saving to KPI_MODEL_STORE, pickle_dump=true')
                            #entity_type_id is null because of db cached during training. When Monitor tries to score, entity_type_id is not null
                            if self._entity_type.db.model_store.entity_type_id is None:
                                # get_entity_type_id_by_entity_type_name(db,entity_type_name)
                                self._entity_type.db.model_store.entity_type_id = get_entity_type_id_by_entity_type_name(entity_type_name=self._entity_type.logical_name)
                            self.logger.debug('Retrieved entity type ID before saving model: db.model_store.entity_type_id=%s', self._entity_type.db.model_store.entity_type_id)

                            self._entity_type.db.model_store.store_model(model_name=model_path, model=new_model, user_name=None, serialize=True)
                else:
                    if self._entity_type.db.model_store is not None and cos_kpi is None:
                        self.logger.debug('Saving to KPI_MODEL_STORE, pickle_dump=false')
                        #entity_type_id is null because of db cached during training. When Monitor tries to score, entity_type_id is not null
                        if self._entity_type.db.model_store.entity_type_id is None:
                            self._entity_type.db.model_store.entity_type_id = get_entity_type_id_by_entity_type_name(entity_type_name=self._entity_type.logical_name)

                            self.logger.debug('Retrieved entity type ID before saving model: db.model_store.entity_type_id=%s', self._entity_type.db.model_store.entity_type_id)
                        self._entity_type.db.model_store.store_model(model_name=model_path, model=new_model, user_name=None, serialize=False)
                    else:
                        # work-around to be able to not pickle save to cos
                        self.logger.debug('Saving to COS , pickle_dump=false')
                        ret = self._entity_type.db.cos_client._cos_api_request('PUT', bucket=bucket, key=model_path, payload=new_model, binary=binary)
                        if ret is None:
                            self.logger.warning('Not able to PUT %s to COS bucket %s', model_path, bucket)
            except requests.exceptions.ReadTimeout as err:
                self.logger.warning('Timeout saving %s to cos: %s', model_path, err)

        self.logger.debug('Saved model to path: %s', model_path)

    def get_model_extra(self, new_model, model_path):
        '''Return extra objects to be saved along with the model as a list of (path, object, pickle_dump, binary) tuplies.

        Normal estimator only has one model object to be saved to COS. Some estimtors might want to save
        other objects, possibly caching/deriving from the model object, for other usage. You can override
        this method to return a list of such additional objects, in the form of (cos_path, object, pickle_dump, binary) tuple.

        It is recommended to construct your extra object cos_path based on the given model_path, with
        different suffix appended.
        '''
        return []

    def get_models_for_training(self, db, df, bucket=None):
        model_name, model = self._load_model(bucket=bucket)

        if model is not None:
            self.models[model_name] = model
            return []
        else:
            return [model]

    def get_models_for_predict(self, db, bucket=None):
        if len(self.predictions) == 0 or self.predictions[0] is None:
            return []
        else:
            return list(self.models.values())

    def conform_index(self,df,entity_id_col = None, timestamp_col = None):
        # workaround for avoiding base class adding columns
        return df

    def add_training_preprocessor(self, stage):
        if hasattr(self, '_entity_type') and self._entity_type is not None:
            stage.set_entity_type(self._entity_type)
        self.add_preprocessor(stage)

    def execute_training_preprocessing(self, df):
        if len(self._preprocessors) == 0:
            return df
        else:
            return super().execute_preprocessing(df)

    def get_df_for_training(self, df):
        features = [] + self.features
        if self.features_for_training is not None:
            features.extend(self.features_for_training)

        #df = df[self.features]
        df = df[features]
        df = df.reset_index(drop=True)

        return df

    def execute_train_test_split(self,df):
        # TODO disable splitting for now
        return (df, None)

    def get_df_for_prediction(self, df):
        '''Return the feature columns (self.features) of df, logging a summary of them.'''
        feature_df = df[self.features]
        summary = log_df_info(feature_df, head=5, logger=self.logger, log_level=logging.DEBUG)
        self.logger.debug('Returning DataFrame for prediction: df_for_prediction=%s', summary)
        return feature_df

    def predict(self, model, df):
        '''Score df with model.

        Returns:
            A list of (prediction, probabilities) pairs, one per row, built
            from model.predict() and model.predict_proba(); None when model is
            None.
        '''
        self.logger.debug('Running prediction on DF=%s', log_df_info(df, head=5, logger=self.logger, log_level=logging.DEBUG))

        # The former SmartClassification special case (dropped on 09/07/2023,
        # see https://github.ibm.com/maximo/Asset-Health-Insight/issues/13920)
        # is no longer necessary because SROM provides it in its base
        # implementation.
        if model is None:
            return None

        labels = model.predict(df)
        probabilities = model.predict_proba(df)
        return list(zip(labels, probabilities))

    def get_prediction_result_value_index(self):
        raise NotImplementedError()

    def process_prediction_result(self, df, prediction_result, model):
        if prediction_result is None:
            df[self.predictions[0]] = None

            self.logger.debug('No suitable model found. Created null predictions')
        else:
            for idx in self.get_prediction_result_value_index():
                if not all([isinstance(p, (tuple, list, np.ndarray)) for p in prediction_result]):
                    break

                try:
                    best_estimator = str(model.get_best_estimator())
                    self.logger.debug('process_prediction_result model.get_best_estimator()=%s', best_estimator)
                    if 'GaussianGraphicalModel' in best_estimator:
                        self.logger.debug('Processing as GaussianGraphicalModel output')
                        prediction_result = [np.max(p[idx]) if len(p) > idx else 0.0 for p in prediction_result]
                    else:
                        self.logger.debug('Processing as non-GaussianGraphicalModel output')
                        prediction_result = [p[idx] if len(p) > idx else 0.0 for p in prediction_result]
                except:
                    # custom model deployed in the WML
                    prediction_result = [p[idx] if len(p) > idx else 0.0 for p in prediction_result]

            df[self.predictions[0]] = prediction_result

        return df

    def execute(self, df=None, start_ts=None, end_ts=None, entities=None):
        '''Train a model when none could be loaded, then score df.

        Args:
            df: input dataframe. It must contain the event timestamp column
                (self.event_timestamp_column_name) after reset_index() and all
                feature columns.
            start_ts, end_ts, entities: accepted but not used by this method.

        Returns:
            df with the prediction column added. When
            self.latest_prediction_timestamp is set, only rows with a newer
            event timestamp are kept. When no model was scored, a dataframe
            without any columns is returned.

        Side effects: self.df_traces is rebuilt, traces (and a newly trained
        model) are saved, and self.training_timestamp / self.model_timestamp
        are updated.
        '''
        # start every run with a fresh set of traces
        self.df_traces = {}

        self.logger.debug('Running estimator on input DataFrame: df=%s', log_df_info(df, head=5, logger=self.logger, log_level=logging.DEBUG))
        self.df_traces['input'] = df.copy()

        db = self._entity_type.db
        bucket = self.get_bucket_name()

        self.training_timestamp = None

        # transform incoming data using any preprocessors
        # include whatever preprocessing stages are required by implementing a set_preprocessors method
        # a non-empty list means no stored model was found, so training is needed
        required_models = self.get_models_for_training(db=db, df=df, bucket=bucket)
        is_training = len(required_models) > 0
        if is_training:
            # first and last row of the input delimit the training date range
            begin_date_row= df.head(1)
            end_date_row= df.tail(1)
            begin_date_row.reset_index(inplace = True)
            end_date_row.reset_index(inplace = True)
            # NOTE(review): astype() returns a new object and the result is
            # discarded here, so these two lines have no effect
            begin_date_row.astype({self.event_timestamp_column_name:'str'})
            end_date_row.astype({self.event_timestamp_column_name:'str'})

            # NOTE(review): begin_date, end_date and time_range_json are not
            # used further down in this method
            begin_date = str(begin_date_row.at[0,self.event_timestamp_column_name])
            end_date =   str(end_date_row.at[0,self.event_timestamp_column_name])
            time_range_json = {'begin_date': begin_date, 'end_date':  end_date}

            #self.model_df_traces['training_date_range'] = time_range_json

            # two-row frame (begin, end) reduced to the event timestamp column
            training_date_range_df = pd.DataFrame()
            training_date_range_df = pd.concat([training_date_range_df, begin_date_row], ignore_index=True)
            training_date_range_df = pd.concat([training_date_range_df, end_date_row], ignore_index=True)
            training_date_range_df_result=training_date_range_df[self.event_timestamp_column_name]


            self.df_traces['training_date_range'] = training_date_range_df_result
            #self.model_df_traces['end_training_date'] = time_range_json

            self.logger.debug('Training date range: training_date_range_df_result=%s', log_df_info(training_date_range_df_result,head=-1, logger=self.logger, log_level=logging.DEBUG))

            # only do preprocessing and splitting once
            df_train = self.execute_training_preprocessing(df)
            self.df_traces['input_after_train_preprocess'] = df_train
            df_train = self.get_df_for_training(df_train)
            self.df_traces['to_train'] = df_train
            self.logger.debug('Dataframe to train: df=%s', log_df_info(df_train, head=5, logger=self.logger, log_level=logging.DEBUG))
            # df_train, df_test = self.execute_train_test_split(df_train)

            self.logger.info('Iterating over models to train...')
            for model in required_models:
                self.logger.info('Beginning model training for model: %s', model)

                new_model = self.train_model(df_train)

                # epoch seconds (UTC) of the training; used as model name suffix
                self.training_timestamp = str(calendar.timegm(datetime.utcnow().timetuple()))

                self._save_model_df_trace(bucket=bucket, df_trace_dict=self.df_traces, suffix=self.training_timestamp, local=self.local_model, training=True)

                self._save_model(bucket=bucket, new_model=new_model, suffix=self.training_timestamp, local=self.local_model)

                # switch to the new one just trained
                self.model_timestamp = self.training_timestamp
        elif self.model_timestamp is not None:
            # no training needed: reuse the loaded model and collect its extras again
            #self.training_timestamp = self.model_timestamp
            self.training_timestamp = self.model_timestamp
            filename = self.get_model_name(target_name=self._get_target_name(), suffix=self.training_timestamp) # save with explicity suffix timestamp set
            # NOTE(review): raises KeyError if no model is registered under this name - confirm it always is
            extras = self.get_model_extra(self.models[filename], filename)
            if len(extras) > 0:
                self.model_extras[filename].extend(extras)

        # Predictions

        # after training, the traces were already saved above, so the prediction
        # traces start empty; otherwise they start as a copy of the traces so far
        prediction_model_df_traces = {} if is_training else self.df_traces.copy()
        df_for_prediction = None
        for idx, model in enumerate(self.get_models_for_predict(db=db, bucket=bucket)):
            # TODO deal with multiple predictions
            # the prediction input is built once and shared by all models
            if df_for_prediction is None:
                df_for_prediction = self.get_df_for_prediction(df)
            self.df_traces['to_predict'] = prediction_model_df_traces['to_predict'] = df_for_prediction

            df_prediction = self.predict(model, df_for_prediction)
            df = self.process_prediction_result(df, df_prediction, model)

            # trim off repeat predictions- only want to write new predictions to table
            self.logger.debug('Trimming off repeated predictions using latest_prediction_timestamp=%s', self.latest_prediction_timestamp)
            if self.latest_prediction_timestamp is not None:
                # move the index levels into columns so the timestamp can be filtered on
                df.reset_index(inplace=True)

                self.logger.debug(
                    'Prediction DF max, min dates before trimming old predictions: max=%s, min=%s', 
                    max(df[self.event_timestamp_column_name], default=None), 
                    min(df[self.event_timestamp_column_name], default=None))

                # convert self.event_timestamp_column_name to datetime64 because latest_prediction_timestamp is pd.Timestamp
                df[self.event_timestamp_column_name]=  pd.to_datetime(df[self.event_timestamp_column_name])

                df = df.loc[df[self.event_timestamp_column_name] > self.latest_prediction_timestamp]
                self.logger.debug(
                    'Prediction DF max, min dates after trimming old predictions: max=%s, min=%s', 
                    max(df[self.event_timestamp_column_name], default=None), 
                    min(df[self.event_timestamp_column_name], default=None))

                # NOTE(review): the entity index level name 'id' is hard-coded here
                df.set_index(['id', self.event_timestamp_column_name], inplace=True)

        if df_for_prediction is None:
            # no prediction needed, return empty df
            df = df[[]]

        self.df_traces['output'] = prediction_model_df_traces['output'] = df.copy()
        self._save_model_df_trace(bucket=bucket, df_trace_dict=prediction_model_df_traces, suffix=self.training_timestamp, local=self.local_model, training=is_training)

        self.logger.debug('Final DF with predictions: %s', log_df_info(df, head=5, logger=self.logger, log_level=logging.DEBUG))

        return df


class WmlDeploymentEstimator(BaseEstimator):
    '''This estimator takes a WML deployed model and caches it on COS for prediction.

    The caching can be disabled in which case the prediction would be using the WML
    deployment.
    '''
    wml_client = None

    def __init__(self, features, targets, predictions, wml_credentials, wml_deployment_uid,
                 result_value_index=None, cache_model=True, wml_deployment_space_name=None,
                 wml_deployment_space_id=None, **kwargs):
        '''Configure the estimator for a WML deployment.

        Args:
            features, targets, predictions: forwarded to BaseEstimator.
            wml_credentials: WML credentials; must contain the keys
                'instance_id' and 'url'.
            wml_deployment_uid: id of the WML deployment.
            result_value_index: indexes returned by
                get_prediction_result_value_index().
            cache_model: True to load and cache the deployed model, False to
                score through the WML deployment.
            wml_deployment_space_name: deployment space name, used to look up
                the space id when wml_deployment_space_id is not given.
            wml_deployment_space_id: deployment space id.
            **kwargs: forwarded to BaseEstimator and kept as config_options.
                'wml_model_uid', 'wml_chunk_size' (default 1000) and
                'wml_client_version' are read from here.

        Raises:
            ValueError: when wml_credentials lacks a mandatory key.
        '''
        super().__init__(features, targets, predictions, **kwargs)

        for mandatory_key in ('instance_id', 'url'):
            if mandatory_key not in wml_credentials:
                raise ValueError('parameter wml_credentials missing mandatory key "%s"' % mandatory_key)

        self.wml_credentials = wml_credentials
        self.wml_deployment_uid = wml_deployment_uid
        self.result_value_index = result_value_index
        self.cache_model = cache_model

        # cp4d config
        self.wml_deployment_space_name = wml_deployment_space_name
        self.wml_deployment_space_id = wml_deployment_space_id
        self.wml_model_uid = kwargs.get('wml_model_uid', None)

        # number of records sent per scoring request; default to 1000, you can configure it
        self.chunk = kwargs.get('wml_chunk_size', 1000)

        self.config_options = kwargs
        self.wml_client_version = kwargs.get('wml_client_version')

    def get_models_for_training(self, db, df, bucket=None):
        if self.cache_model:
            return super().get_models_for_training(db=db, df=df, bucket=bucket)
        else:
            return []

    def train_model(self, df):
        self._init_wml_executor()

        deployment_details = self.wml_client.deployments.get_details(self.wml_deployment_uid)
        #try:
        #    return self.wml_client.repository.load(deployment_details['entity']['deployable_asset']['guid'])
        #except:
        if self.wml_model_uid == None:
            self.wml_model_uid = self.wml_client.repository.load(deployment_details['entity']['asset']['id'])
        return self.wml_model_uid

    def get_models_for_predict(self, db, bucket=None):
        if self.cache_model:
            return super().get_models_for_predict(db=db, bucket=bucket)
        else:
            self._init_wml_executor()
            return [self.wml_client.deployments]

    def get_df_for_prediction(self, df):
        '''Return the prediction input for df.

        With a cached model this is the feature dataframe. Without caching it
        is the WML scoring payload, a dict with

        * 'fields': the feature column names, and
        * 'values': one list of values per row,

        in which every column that is neither numeric nor boolean has been
        cast to string.
        '''
        df = super().get_df_for_prediction(df=df)

        if self.cache_model:
            return df

        # filter only features
        df = df[self.features]

        # for any column not of numeric or boolean type, cast it to string
        # ('bool' instead of np.bool: that alias was removed in NumPy 1.24)
        non_numeric_columns = df.select_dtypes(exclude=[np.number, 'bool']).columns
        df = df.astype({col: str for col in non_numeric_columns})
        self.logger.debug('df_for_prediction=%s', log_df_info(df, head=5, logger=self.logger, log_level=logging.DEBUG))

        # predict() sends the field names along with every chunk of values
        payload_for_prediction = {'fields': df.columns.tolist(), 'values': df.values.tolist()}
        self.logger.debug('payload_for_prediction=%s', payload_for_prediction['values'][:10])

        return payload_for_prediction

    def predict(self, model, df):
        self.logger.debug('Start of WmlDeploymentEstimator')
        #print('in WmlDeploymentEstimator predict='+str(df))
        if self.cache_model:
            return super().predict(model=model, df=df)
        else:
            self.logger.debug('WML Client Version = %s', self.wml_client_version)
            self.logger.debug('WML Deployment UID = %s', self.wml_deployment_uid)
            #if (self.wml_client_version is None) or (self.wml_client_version == "V4") : # remove this line in MAS 8.11 / Predict 8.9 and beyond
            ####
            #
            #
            #######

            payload_for_prediction=df

            
            
    
            self.logger.debug('chunk size= %s',str(self.chunk))
            predictions = {"predictions": [{ "fields": [], "values": []}]}

            total_records=len(payload_for_prediction['values'])
            self.logger.debug('total_records=%s',str(total_records))
            
            for x in range(math.ceil(total_records/self.chunk)) :
                st = self.chunk * x
                end = self.chunk * (x+1)
                if end > total_records :
                    end = total_records
                self.logger.debug("chunk #" + str(x))
        
                partitioned_df = payload_for_prediction['values'][st:end]
    
    
                wml_payload_for_prediction = {"fields": payload_for_prediction['fields'],'values': partitioned_df}
                #print(wml_payload_for_prediction)

            
                scoring_payload = {
                    self.wml_client.deployments.ScoringMetaNames.INPUT_DATA: [wml_payload_for_prediction]
                }
                #print(str(scoring_payload))
                
                predictions_chunk = self.wml_client.deployments.score(self.wml_deployment_uid, scoring_payload)

                predictions['predictions'][0]['values'].extend(predictions_chunk['predictions'][0]['values'])
            
            # for custom model
            #predictions['predictions'][1]['values'].extend(predictions_chunk['predictions'][0]['values'])
            predictions['predictions'][0]['fields'].extend(predictions_chunk['predictions'][0]['fields'])     
            #print('final_output',predictions)  

            return predictions  


            #scoring_payload = {
            #            self.wml_client.deployments.ScoringMetaNames.INPUT_DATA: [df]
            #       }
            #predictions = self.wml_client.deployments.score(self.wml_deployment_uid, scoring_payload)
            #return predictions




    def get_prediction_result_value_index(self):
        return self.result_value_index

    def process_prediction_result(self, df, prediction_result, model):
        '''Write the prediction result into df.

        With a cached model the result is handed to BaseEstimator unchanged.
        A WML scoring response is unwrapped first: either the flat form
        {'values': [...]} or the nested form
        {'predictions': [{'values': [...]}]}.
        '''
        if self.cache_model:
            return super().process_prediction_result(df=df, prediction_result=prediction_result, model=model)

        # Only the lookup decides which response form is used. Errors raised
        # while processing the values are not retried with the other form, so
        # the real error is what the caller gets to see.
        try:
            values = prediction_result['values']
        except (KeyError, TypeError, IndexError):
            values = prediction_result['predictions'][0]['values']
        return super().process_prediction_result(df=df, prediction_result=values, model=model)

    def guid_from_space_name(self,client, space_name):
        #instance_details = client.service_instance.get_details()
        #space = client.spaces.get_details()
        #return(next(item for item in space['resources'] if item['entity']["name"] == space_name)['metadata']['guid'])
        space = client.spaces.get_details()
        try:
            meta = next(item for item in space['resources'] if item['entity']["name"] == space_name)['metadata']
        except StopIteration:
            print("ERROR: Deployment space " + space_name + ' not found.')
            return None
        if 'guid' in meta:
            return (meta['guid'])
        elif 'id' in meta:
            return (meta['id'])
        else :
            print("ERROR: Can't find deployment space id for " + space_name)
            return None

    def _init_wml_executor(self):
        # cache and reuse WML runtime
        #print('self.wml_client_version='+self.wml_client_version)
        import os
        if self.wml_client is None:
            self.wml_client = APIClient(self.wml_credentials)
            if self.wml_deployment_space_id is None:
                self.wml_deployment_space_id = self.guid_from_space_name(self.wml_client,self.wml_deployment_space_name)
        
            print('in self.wml_client is None, space_guid='+str(self.wml_deployment_space_id ))    
            self.wml_client.set.default_space(self.wml_deployment_space_id)
            #else:
                #if self.wml_client_version =="V4":

                    #self.wml_client = APIClient(self.wml_credentials)

                    #space_guid = self.guid_from_space_name(self.wml_client,self.wml_deployment_space_name)
                    #print('space_guid='+str(space_guid))
                    #self.wml_client.set.default_space(space_guid)

class WmlSPSSDeploymentEstimator(WmlDeploymentEstimator):
    '''This estimator takes a WML based SPSS stream as the deployed model for prediction.

    Configuration is read from keyword arguments by `init_config()`:

    * ``source_spss_mapping`` (required) - dict keyed by SPSS stream input
      field name; each value is either the name of the source data frame
      column feeding that field or the directive ``'$derive_from_id$site_id'``.
    * ``asset_id_column_name`` / ``timestamp_column_name`` (required).
    * ``anomly_score_output_field_name`` (sic),
      ``anomaly_threshold_output_field_name``,
      ``anomaly_flag_output_field_name`` (required).
    * ``anomaly_flag_output_labels`` (optional) - dict whose ``'normal'``
      entry is the flag value standing for "not an anomaly".
    * ``wml_deployed_model_uid`` (optional), ``wml_deployment_mode``
      (default ``'online'``), ``log_level`` (default 10).
    '''

    # SPSS field types accepted in the stream input schema and the dtype name
    # each one is converted to before being handed to DataFrame.astype().
    _SPSS_TO_PANDAS_DTYPE = {'integer': 'int', 'double': 'float', 'string': 'str'}

    def init_config(self, **kwargs):
        '''Read the estimator configuration from keyword arguments.

        Raises KeyError when one of the required keys is missing.
        '''
        self.wml_deployed_model_id = kwargs.get('wml_deployed_model_uid', None)
        self.wml_deployment_mode = kwargs.get('wml_deployment_mode', 'online')

        # NOTE(review): getLogger() without a name returns the root logger, so
        # the level set here applies process-wide - confirm this is intended.
        self.logger = logging.getLogger()
        self.logger.setLevel(kwargs.get('log_level', 10))

        self.source_spss_mapping = kwargs['source_spss_mapping']
        self.asset_id_column_name = kwargs['asset_id_column_name']
        # Timestamp or date column name
        self.timestamp_column_name = kwargs['timestamp_column_name']
        # Both schemas are fetched lazily by get_spss_stream_input_schema().
        self.spss_stream_input_schema = None
        self.spss_stream_output_schema = None
        self.anomly_score_output_field_name = kwargs['anomly_score_output_field_name']
        self.anomaly_threshold_output_field_name = kwargs['anomaly_threshold_output_field_name']
        self.anomaly_flag_output_field_name = kwargs['anomaly_flag_output_field_name']
        self.anomaly_flag_output_labels = kwargs.get('anomaly_flag_output_labels', None)

        self.logger.debug('WmlSPSSDeploymentEstimator::init_config() - Finished initializing parameters')

    def get_wml_client(self):
        '''Return the cached WML client, connecting through WMLScorer on first use.'''
        if self.wml_client is None:
            wml_wrapper = WMLScorer()
            wml_wrapper.connect(self.wml_credentials, self.wml_deployment_space_name)
            self.wml_client = wml_wrapper.wml_client()
        return self.wml_client

    def get_spss_stream_input_schema(self):
        '''Return the SPSS stream input schema as ``{field name: dtype name}``.

        On first use the deployed model id is resolved from the deployment
        descriptor (unless it was configured), the model details are fetched
        from the WML repository, and both the input and the output schema are
        cached on the instance.

        Raises RuntimeError when the deployment descriptor cannot be retrieved
        or when an input field has a type other than integer/double/string.
        '''
        wml_client = self.get_wml_client()
        if self.wml_deployed_model_id is None:
            try:
                wml_deployment_descriptor = wml_client.deployments.get_details(
                    self.wml_deployment_uid, self.wml_deployment_space_name)
                self.logger.debug('WmlSPSSDeploymentEstimator::get_spss_stream_input_schema(): Retrieved the WML deployment descriptor \n  = %s',
                                  json.dumps(wml_deployment_descriptor, indent=2))
                self.wml_deployed_model_id = wml_deployment_descriptor['entity']['asset']['id']
                self.logger.debug('WmlSPSSDeploymentEstimator::get_spss_stream_input_schema(): Retrieved the deployed model ID = %s',
                                  self.wml_deployed_model_id)
            except ValueError as err:
                msg = "The WML deployment descriptor is not retrievable from WML. Verify the deployment space, deployment UID, and make sure the model deployment exists in WML"
                # The message is passed as a logging argument; concatenating
                # it onto the format string left a literal "%s" in the output.
                self.logger.error('WmlSPSSDeploymentEstimator::get_spss_stream_input_schema(): %s', msg)
                raise RuntimeError(msg) from err

        if self.spss_stream_input_schema is None:
            wml_deployed_model_schema = wml_client.repository.get_details(self.wml_deployed_model_id)
            schemas = wml_deployed_model_schema['entity']['schemas']
            raw_input_schema = {item['name']: item['type'] for item in schemas['input'][0]['fields']}
            self.logger.debug('WmlSPSSDeploymentEstimator::get_spss_stream_input_schema(): %s %s ',
                              'SPSS Stream Input Schema as retrieved from WML = \n', raw_input_schema)

            # Convert into a fresh dict and publish it only when every field
            # was accepted, so a rejected schema is never left cached in a
            # partially converted state.
            converted_schema = {}
            for field_name, spss_type in raw_input_schema.items():
                try:
                    converted_schema[field_name] = self._SPSS_TO_PANDAS_DTYPE[spss_type]
                except (KeyError, TypeError):
                    raise RuntimeError('From WmlSPSSDeploymentEstimator::get_spss_stream_input_schema(): Wrong input schema was provided for the SPSS based model') from None

            ## Also initialize the output schema in the same invocation
            output_schema = {item['name']: item['type'] for item in schemas['output'][0]['fields']}

            self.spss_stream_input_schema = converted_schema
            self.spss_stream_output_schema = output_schema
        return self.spss_stream_input_schema

    def prepare_data_for_spss_stream(self, df):
        '''Build the scoring payload for the SPSS stream from `df`.

        Caution: `df` is intentionally modified in place (index reset,
        timestamp column cast to string, derived columns added, columns
        renamed to their SPSS field names).

        :return: dict with ``'fields'`` (list of column names) and
            ``'values'`` (the data frame values as an array).
        '''
        self.spss_stream_input_schema = self.get_spss_stream_input_schema()
        # Caution, the input dataframe is intentionally being modified inside
        df.reset_index(inplace=True, drop=False)
        df[self.timestamp_column_name] = df[self.timestamp_column_name].astype(str)
        self.logger.debug('WmlSPSSDeploymentEstimator::prepare_data_for_spss_stream(): %s ',
                          log_df_info(df, head=5, logger=self.logger, log_level=logging.DEBUG,
                                      comment='Input data as received = \n', include_missing_value_count=True))

        # Work on a copy: directives are rewritten below and the configured
        # mapping is needed unchanged by process_prediction_result().
        spss_monitor_mapping = self.source_spss_mapping.copy()
        self.logger.debug('WmlSPSSDeploymentEstimator::prepare_data_for_spss_stream(): %s %s',
                          "SPSS to Monitor Mapping values as read: \n ", spss_monitor_mapping)

        for spss_in in list(spss_monitor_mapping.keys()):
            if spss_monitor_mapping[spss_in] == '$derive_from_id$site_id':
                # The derived value is the second '-____-' separated part of
                # the asset id; the column is created under its SPSS name, so
                # it then maps to itself.
                df[spss_in] = df[self.asset_id_column_name].map(lambda x: x.split('-____-')[1])
                spss_monitor_mapping[spss_in] = spss_in

        self.logger.debug('WmlSPSSDeploymentEstimator::prepare_data_for_spss_stream(): %s %s',
                          "SPSS to Monitor Mapping values after processing the directives: \n ", spss_monitor_mapping)

        # The mapping is SPSS name -> source column, rename needs the inverse.
        df.rename({value: key for key, value in spss_monitor_mapping.items()}, axis=1, inplace=True)

        self.logger.debug('WmlSPSSDeploymentEstimator::prepare_data_for_spss_stream(): %s %s',
                          'Input data frame columns AFTER renaming according to SPSS type mapping = \n', df.columns.values)

        self.logger.debug('WmlSPSSDeploymentEstimator::prepare_data_for_spss_stream(): %s %s',
                          'SPSS stream input schema with column names and types = \n', self.spss_stream_input_schema)
        df = df.astype(self.spss_stream_input_schema)
        payload_for_prediction = {'fields': df.columns.values.tolist(), 'values': df.values}

        self.logger.debug('WmlSPSSDeploymentEstimator::prepare_data_for_spss_stream(): %s %s %s %s',
                          'Payload formatted for invoking the SPSS stream deployed on WML with payload size = ',
                          payload_for_prediction['values'].shape,
                          ' (only the first 10 rows are printed) \n',
                          payload_for_prediction['values'][0:10])

        return payload_for_prediction

    def get_models_for_training(self, db, df, bucket=None):
        '''Delegate to the parent implementation.'''
        return super().get_models_for_training(db=db, df=df, bucket=bucket)

    def train_model(self, df):
        '''Delegate to the parent implementation.'''
        return super().train_model(df)

    def get_models_for_predict(self, db, bucket=None):
        '''Delegate to the parent implementation.'''
        return super().get_models_for_predict(db=db, bucket=bucket)

    def get_df_for_prediction(self, df):
        '''Return the scoring payload for `df` (see prepare_data_for_spss_stream).'''
        self.logger.debug('WmlSPSSDeploymentEstimator::get_df_for_prediction(): %s',
                          log_df_info(df, 5, comment='Input payload (df) = \n', include_missing_value_count=True))
        df1 = self.prepare_data_for_spss_stream(df)
        self.logger.debug('WmlSPSSDeploymentEstimator::get_df_for_prediction(): %s %s %s %s %s %s %s',
                          'Data Frame after processing the raw input data frame = \n', list(df1.keys()), '\n',
                          df1['values'].shape, '\n', ' (only the first 10 rows are printed) \n',
                          df1['values'][0:10])
        return df1

    def predict(self, model, df):
        '''Delegate to the parent implementation.'''
        return super().predict(model=model, df=df)

    def get_prediction_result_value_index(self):
        '''Delegate to the parent implementation.'''
        return super().get_prediction_result_value_index()

    def process_prediction_result(self, df, prediction_result, model):
        '''Turn the raw WML scoring response into an indexed data frame.

        The first entry of ``prediction_result['predictions']`` is loaded into
        a data frame, the timestamp column is parsed to datetimes, SPSS field
        names are mapped back to their source column names, the anomaly flag
        is converted to booleans when labels are configured, and the result is
        indexed by (asset id, timestamp). `df` is only used to log its shape
        and `model` is not used.
        '''
        first_prediction = prediction_result['predictions'][0]
        prediction_output_df = pd.DataFrame(first_prediction['values'], columns=first_prediction['fields'])

        # Monitor & DB2 based data lake needs the timestamp to be formatted like "2023-01-27-14.31.00.621000"
        prediction_output_df[self.timestamp_column_name] = pd.to_datetime(prediction_output_df[self.timestamp_column_name])

        self.logger.debug('WmlSPSSDeploymentEstimator::process_prediction_result(): %s ',
                          log_df_info(prediction_output_df, head=5, logger=self.logger, log_level=logging.DEBUG,
                                      comment='Prediction output data frame with results as received = \n',
                                      include_missing_value_count=True))

        # Directive entries have no source column to map back to.
        prediction_output_df.rename(
            {key: val for key, val in self.source_spss_mapping.items() if val != '$derive_from_id$site_id'},
            axis=1, inplace=True)

        self.logger.debug('WmlSPSSDeploymentEstimator::process_prediction_result(): %s ',
                          log_df_info(prediction_output_df, head=5, logger=self.logger, log_level=logging.DEBUG,
                                      comment='Prediction output data frame with results after column mapping = \n',
                                      include_missing_value_count=True))

        if self.anomaly_flag_output_labels is not None:
            # Anything other than the configured 'normal' label counts as an anomaly.
            normal_label = self.anomaly_flag_output_labels['normal']
            prediction_output_df[self.anomaly_flag_output_field_name] = \
                prediction_output_df[self.anomaly_flag_output_field_name].map(lambda x: x != normal_label)

        self.logger.debug('WmlSPSSDeploymentEstimator::process_prediction_result(): %s ',
                          log_df_info(prediction_output_df, head=5, logger=self.logger, log_level=logging.DEBUG,
                                      comment='Prediction output data frame with output transformation (if any) = \n',
                                      include_missing_value_count=True))

        prediction_output_df.set_index([self.asset_id_column_name, self.timestamp_column_name], inplace=True)

        self.logger.debug('WmlSPSSDeploymentEstimator::process_prediction_result(): %s ',
                          log_df_info(prediction_output_df, head=5, logger=self.logger, log_level=logging.DEBUG,
                                      comment='Prediction output data frame with results and column mapping with multi-index = \n',
                                      include_missing_value_count=True))

        self.logger.debug("WmlSPSSDeploymentEstimator::process_prediction_result(): %s %s %s %s %s ",
                          "Comparison of source and result data frame shapes = \n", df.shape, ' == ',
                          prediction_output_df.shape, ' - Are they equal?')

        return prediction_output_df

    def guid_from_space_name(self, client, space_name):
        '''Delegate to the parent implementation.'''
        return super().guid_from_space_name(client, space_name)

    def _init_wml_executor(self):
        '''Delegate to the parent implementation.'''
        super()._init_wml_executor()

class SromEstimator(BaseEstimator):
    def __init__(self, features, targets, predictions, srom_training_options=None, override_training_stages=None, **kwargs):
        '''Set up the estimator and its temporal feature engineering settings.

        :param features: forwarded to the parent constructor.
        :param targets: forwarded to the parent constructor.
        :param predictions: forwarded to the parent constructor.
        :param srom_training_options: optional dict of options passed to the
            pipeline's ``execute()`` during training.
        :param override_training_stages: optional stages used instead of
            ``get_stages()`` during training.
        :param kwargs: forwarded to the parent constructor; the keys
            ``temporal_features``, ``rolling_window_size``, ``minimum_periods``,
            ``asset_id_column_name``, ``timestamp_column_name`` and
            ``timestamp_format`` are also read here.
        '''
        super().__init__(features=features, targets=targets, predictions=predictions, **kwargs)
        self._set_srom_training_options(srom_training_options)
        self.override_training_stages = override_training_stages

        self.temporal_agg_funcs = kwargs.get('temporal_features', {})
        # temporal_features=None is a supported value (get_df_for_prediction
        # checks for it), so fall back to an empty mapping for the lookups
        # below instead of failing on None.get(...).
        agg_funcs = self.temporal_agg_funcs if self.temporal_agg_funcs is not None else {}
        self.rolling_window_size = kwargs.get('rolling_window_size', None)
        self.minimum_periods = kwargs.get('minimum_periods', None)
        self.simple_aggregation_functions = agg_funcs.get('simple_aggregation_functions', [])
        self.higher_order_aggregation_functions = agg_funcs.get('higher_order_aggregation_functions', [])
        self.advanced_aggregation_functions = agg_funcs.get('advanced_aggregation_functions', [])

        self.asset_id_column_name = kwargs.get('asset_id_column_name', 'id')
        self.timestamp_column_name = kwargs.get('timestamp_column_name', 'evt_timestamp')
        self.timestamp_format = kwargs.get('timestamp_format', None)

    def _set_srom_training_options(self, srom_training_options):
        self.srom_training_options = srom_training_options
        if self.srom_training_options is None:
            self.srom_training_options = {}
        if 'verbosity' not in self.srom_training_options:
            self.srom_training_options['verbosity'] = 'low'

    def train_model(self, df):
        self.logger.info('Pre-trained model instance = %s', str(self.pre_trained_model))
        if self.pre_trained_model is not None:
            self.logger.debug('Since pre-trained model is available, training is being skipped here')
            return self.pre_trained_model
        
        self.logger.info('A pre-trained model was not provided. Therefore beginning model training...')
        self.logger.debug('Input DF = %s', log_df_info(df, head=5, logger=self.logger, log_level=logging.DEBUG))
        self.logger.debug('srom_training_options=%s', self.srom_training_options)
        srom_pipeline = self.create_pipeline()
        srom_pipeline = self.configure_pipeline(srom_pipeline)
        self.logger.debug('override_training_stages=%s', self.override_training_stages)
        srom_pipeline.set_stages(self.get_stages(df) if self.override_training_stages is None else self.override_training_stages)

        if self.get_param_grid() is not None:
            self.srom_training_options['param_grid'] = self.get_param_grid()

        df_train = df
        if isinstance(srom_pipeline, AnomalyPipeline):
            if  not self.use_labeled_data: # Force the Unsupervised Learning mode
                self.logger.info('No labeled data. Using Unsupervised learning mode')
                df_train = df[self.features]
                validX = None
                validy = None
            elif len(self.features_for_training) > 0:
                self.logger.info('Found labeled data. Using Semisupervised learning mode')
                label = self.features_for_training[0]
                df_train = df[pd.isna(df[label])].drop(labels=label, axis=1, errors='ignore').reset_index(drop=True)
                validX = df[pd.notna(df[label])].drop(labels=label, axis=1, errors='ignore').reset_index(drop=True)
                validy = df[pd.notna(df[label])][label].reset_index(drop=True)
            else:# Examine this branch and change later. If no validation data then do Unsupervised
                # no validation data given, but SROM does expect it, we'll simply split
                # some training data into validation data
                df_train = df.reset_index(drop=True)
                from sklearn.model_selection import train_test_split
                df_train, validX = train_test_split(df_train, test_size=0.2)
                validy = np.zeros(validX.shape[0])


            self.logger.debug('Training data: trainX=%s', log_df_info(df_train, head=5, logger=self.logger, log_level=logging.DEBUG))
            self.logger.debug('Validation data: validX=%s', log_df_info(validX, head=5, logger=self.logger, log_level=logging.DEBUG))
            self.logger.debug('Validation data: validy=%s', log_df_info(validy, head=5, logger=self.logger, log_level=logging.DEBUG))
            if self.use_labeled_data:

                faildate_number=df['faildate'].sum()
                self.logger.debug('Number of failures = %s', faildate_number)
                if faildate_number == 0:
                    self.logger.error('Number of failure record is zero. Please set the failure history of work order in Maximo Manage. The failurecode must have value.')
                    raise RuntimeError("Number of failure record is zero. Please set the failure history of work order in Maximo Manage. The failurecode must have value.")

            # add model df traces
            self.df_traces['trainX'] = df_train
            if self.use_labeled_data:
                self.df_traces['validX'] = validX
                self.df_traces['validy'] = validy

            srom_pipeline.execute(
                trainX=df_train,
                validX=validX,
                validy=validy,
                **self.srom_training_options)

        else:
            srom_pipeline.execute(
                df_train,
                **self.srom_training_options)

        if srom_pipeline.get_best_estimator() is None:
            raise RuntimeError('Training failed to find the best estimator, try retraining with more data or try again directly if using random search.')

        srom_pipeline.fit(df_train)

        return srom_pipeline
    
    def get_df_for_prediction(self, df):
        '''Prepare `df` for scoring, adding temporal features when configured.

        The parent implementation runs first. If ``self.temporal_agg_funcs``
        lists at least one aggregation function, temporal features are
        computed and the result is indexed by (asset id, timestamp).

        :return: the data frame to score.
        '''
        self.logger.debug('SROMEstimator::get_df_for_prediction() df=%s', log_df_info(df, head=5, logger=self.logger, log_level=logging.DEBUG))
        df = super().get_df_for_prediction(df)

        if self.temporal_agg_funcs is None:
            self.logger.debug('SROMEstimator::get_df_for_prediction() No feature creation needed. Returning the data frame with raw features df=%s', log_df_info(df, head=5, logger=self.logger, log_level=logging.DEBUG))
            return df

        # Total number of aggregation functions across all categories.
        requested_function_count = sum(len(functions) for functions in self.temporal_agg_funcs.values())
        if requested_function_count > 0:
            df = create_temporal_features(
                df,
                self.features,
                rolling_window_size=self.rolling_window_size,
                minimum_periods=self.minimum_periods,
                simple_agg_fns=self.simple_aggregation_functions,
                higher_order_agg_fns=self.higher_order_aggregation_functions,
                adv_agg_fns=self.advanced_aggregation_functions,
                asset_id_column_name=self.asset_id_column_name,
                timestamp_column_name=self.timestamp_column_name,
                timestamp_format=self.timestamp_format,
                data_source_type='mas_monitor_data_lake',
                flatten_headers=True,
            )
            index_columns = [self.asset_id_column_name, self.timestamp_column_name]
            df.set_index(index_columns, inplace=True)
            self.logger.debug('SROMEstimator::get_df_for_prediction() AFTER computing the temporal features df=%s', log_df_info(df, head=5, logger=self.logger, log_level=logging.DEBUG))
        return df
    
    def predict(self, model, df):
        '''Score `df` with `model` after applying scoring-time feature creation.

        The data frame is first passed through ``get_df_for_prediction()``
        and the result is handed to the parent ``predict()``.
        '''
        self.logger.debug(
            'SROMEstimator::predict() Intercepting the predict call to enable scoring time features if any df=%s',
            log_df_info(df, head=5, logger=self.logger, log_level=logging.DEBUG),
        )
        scoring_df = self.get_df_for_prediction(df)
        return super().predict(model, scoring_df)
    
    def create_pipeline(self):
        '''Return a new, unconfigured SROMPipeline instance.'''
        pipeline = SROMPipeline()
        return pipeline

    def configure_pipeline(self, srom_pipeline):
        '''Return `srom_pipeline` unchanged; no configuration is applied here.'''
        configured_pipeline = srom_pipeline
        return configured_pipeline

    def get_param_grid(self):
        '''Return the parameter grid for training; None means no grid is set.'''
        param_grid = None
        return param_grid

