# Licensed Materials - Property of IBM
# 5737-M66, 5900-AAA, 5900-AMG
# (C) Copyright IBM Corp. 2019, 2025 All Rights Reserved.
# US Government Users Restricted Rights - Use, duplication, or disclosure
# restricted by GSA ADP Schedule Contract with IBM Corp.

import numpy as np
import pandas as pd
import pytest

from .. import api
from ..anomaly_detection import AnomalyDetectionAssetGroupPipeline, AnomalyDetectionEstimator
from ..transformer import IdentifyPreFailureWindow, TransformNotNaToEvent
from ..pipeline import SimpleCustomAssetGroupPipelineLoader
from ..util import current_directory, log_df_info


class MyAnomalyDetectionEstimator(AnomalyDetectionEstimator):
    """Custom anomaly-detection estimator used by this test.

    Overrides ``get_stages`` to supply a small, fixed set of candidate
    stages: a choice of feature scaler followed by an IsolationForest-based
    SROM anomaly model.
    """

    def get_stages(self, df):
        """Return the candidate pipeline stages for this estimator.

        The result is a list of stages; each stage is a list of
        ``(name, estimator)`` pairs (presumably alternatives the base class
        chooses between during SROM training — confirm against
        ``AnomalyDetectionEstimator``). ``df`` is accepted for interface
        compatibility and is not used here.
        """
        # Imported lazily so the heavy ML dependencies load only when stages are built.
        from sklearn.ensemble import IsolationForest
        from sklearn.preprocessing import StandardScaler, RobustScaler, MinMaxScaler
        from srom.anomaly_detection.generalized_anomaly_model import GeneralizedAnomalyModel
        from srom.utils.no_op import NoOp

        scaling_choices = [
            ('skipscaling', NoOp()),
            ('standardscaler', StandardScaler()),
            ('robustscaler', RobustScaler()),
            ('minmaxscaling', MinMaxScaler()),
        ]

        # Rule/Density based Anomaly Models
        isolation_forest_model = GeneralizedAnomalyModel(
            base_learner=IsolationForest(),
            predict_function='decision_function',
            score_sign=-1,
        )
        model_choices = [('isolationforest', isolation_forest_model)]

        return [scaling_choices, model_choices]


class MyAnomalyDetectionAssetGroupPipeline(AnomalyDetectionAssetGroupPipeline):
    """Asset-group pipeline that trains ``MyAnomalyDetectionEstimator``."""

    def prepare_execute(self, pipeline, model_config):
        """Build a ``MyAnomalyDetectionEstimator`` and add it to ``pipeline``.

        Args:
            pipeline: pipeline object receiving the estimator via ``add_stage``.
            model_config: keyword arguments for the estimator; must contain
                ``'features_for_training'``. When that list is non-empty, its
                first entry (``':faildate'`` in this test) is handed to the
                ``TransformNotNaToEvent`` and ``IdentifyPreFailureWindow``
                training preprocessors. The optional ``'pre_failure_window_size'``
                (default 20) and ``'pre_failure_failure_size'`` (default 10)
                keys configure the latter.
        """
        estimator = MyAnomalyDetectionEstimator(**model_config)

        training_features = model_config['features_for_training']
        if len(training_features) > 0:
            target_feature = training_features[0]
            window_size = model_config.get('pre_failure_window_size', 20)
            failure_size = model_config.get('pre_failure_failure_size', 10)

            estimator.add_training_preprocessor(TransformNotNaToEvent(target_feature))
            estimator.add_training_preprocessor(
                IdentifyPreFailureWindow(
                    target_feature,
                    pre_failure_window_size=window_size,
                    pre_failure_failure_size=failure_size,
                )
            )

        pipeline.add_stage(estimator)


def _model_pipeline_config(iot_type):
    """Build the ``model_pipeline`` config shared by training and scoring.

    A fresh dict is built on every call, so the training and scoring
    pipelines never share config objects (in case either mutates its own).
    """
    return {
        'features': ['%s:axlevibration' % iot_type, '%s:axlemomentum' % iot_type],
        'features_for_training': [':faildate'],
        'predictions': ['anomaly_score_custom', 'anomaly_threshold_custom'],
        'srom_training_options': {
            'exectype': 'single_node_random_search'
        },
    }


def _asset_device_mappings(iot_type):
    """Build the asset -> device mapping shared by training and scoring.

    Only ``abcd-1`` is mapped to a device; ``abcd-2`` deliberately has none.
    """
    return {
        'abcd-1-____-BEDFORD': ['%s:abcd-1' % iot_type],
        'abcd-2-____-BEDFORD': [],
    }


def _sensor_substitution(df_data_sensor):
    """Build the sensor ``data_substitution`` entry shared by training and scoring.

    The raw TRAINBRAKESIMULATION_* columns are renamed to the feature names
    used in ``_model_pipeline_config``.
    """
    return {
        'df': df_data_sensor,
        'keys': ['DEVICEID'],
        'columns': [
            'TRAINBRAKESIMULATION_AXLEVIBRATION',
            'TRAINBRAKESIMULATION_AXLEMOMENTUM',
        ],
        'timestamp': 'RCV_TIMESTAMP_UTC',
        'rename_columns': {
            'TRAINBRAKESIMULATION_AXLEVIBRATION': 'axlevibration',
            'TRAINBRAKESIMULATION_AXLEMOMENTUM': 'axlemomentum',
        },
    }


def test_simple_custom_model_loading(asset_group_id='abcd', iot_type=None):
    """Train and register a custom anomaly model, then reload it for scoring.

    Trains ``MyAnomalyDetectionAssetGroupPipeline`` on the train-brake sample
    CSVs (remapped onto the ``abcd-1``/``abcd-2`` assets) and registers it.
    It then checks that ``SimpleCustomAssetGroupPipelineLoader`` can load the
    model back by pipeline name and model timestamp and score the same
    sensor data.

    Args:
        asset_group_id: asset group the model is trained and scored for.
        iot_type: device type name for the substituted sensor data; a random
            ``abcdsensor_NNNNN`` name is generated when omitted.
    """
    if iot_type is None:
        import random
        # Random suffix, presumably so repeated runs don't clash with device types from earlier runs.
        iot_type = 'abcdsensor_%05d' % random.randrange(10**5)

    data_dir = current_directory(file=__file__)

    df_data_asset = pd.read_csv('%s/trainbrake_asset_faildates.csv' % data_dir, parse_dates=['faildate'])
    # Map the two sample train brakes onto this test's asset ids.
    df_data_asset['asset'] = np.where(df_data_asset['asset'] == 'TRAINBRAKE1', 'abcd-1', 'abcd-2')
    df_data_asset['assetid'] = df_data_asset['asset'] + '-____-' + df_data_asset['site']
    df_data_asset['datetime'] = df_data_asset['faildate']

    df_data_sensor = pd.read_csv('%s/trainbrake_device_data.csv' % data_dir, parse_dates=['RCV_TIMESTAMP_UTC'])
    df_data_sensor['DEVICETYPE'] = iot_type
    df_data_sensor['DEVICEID'] = np.where(df_data_sensor['DEVICEID'] == 'TrainBrake_1', 'abcd-1', 'abcd-2')

    # --- Training run ---
    trained = MyAnomalyDetectionAssetGroupPipeline(
        asset_group_id=asset_group_id,
        model_pipeline=_model_pipeline_config(iot_type),
        asset_device_mappings=_asset_device_mappings(iot_type),
        data_substitution={
            # Asset-level failure dates under the empty type key; presumably the
            # source of the ':faildate' training feature.
            '': [
                {
                    'df': df_data_asset,
                    'keys': ['assetid'],
                    'columns': ['faildate'],
                    'timestamp': 'datetime',
                },
            ],
            iot_type: [_sensor_substitution(df_data_sensor)],
        },
    )
    df = trained.execute()

    print(log_df_info(df, head=0))

    assert trained.new_training
    assert not trained.training
    assert 'MyAnomalyDetectionEstimator' in trained.model_timestamp
    assert (13571, 2) == df.shape
    assert 'anomaly_score_custom' in df.columns
    assert 'anomaly_threshold_custom' in df.columns

    trained.register()

    # --- Scoring run ---
    # After normal training of a custom model, load it back through the
    # simple custom model loader and score the sensor data only.
    loader = SimpleCustomAssetGroupPipelineLoader(
        target_pipeline_name=MyAnomalyDetectionAssetGroupPipeline.__name__,
        asset_group_id=asset_group_id,
        model_pipeline=_model_pipeline_config(iot_type),
        asset_device_mappings=_asset_device_mappings(iot_type),
        data_substitution={
            iot_type: [_sensor_substitution(df_data_sensor)],
        },
        model_timestamp=trained.model_timestamp,
    )
    df = loader.execute()

    print(log_df_info(df, head=0))

    # Fewer rows than training; presumably because no faildate data is substituted here.
    assert (13566, 2) == df.shape
    assert 'anomaly_score_custom' in df.columns
    assert 'anomaly_threshold_custom' in df.columns

