# Licensed Materials - Property of IBM
# 5737-M66, 5900-AAA, 5900-AMG
# (C) Copyright IBM Corp. 2019, 2025 All Rights Reserved.
# US Government Users Restricted Rights - Use, duplication, or disclosure
# restricted by GSA ADP Schedule Contract with IBM Corp.

import numpy as np
import pandas as pd
import pytest
from sqlalchemy import func, select

from pandas.tseries.frequencies import to_offset
from pandas.tseries.offsets import DateOffset

from .. import api
from ..anomaly_detection import AnomalyDetectionAssetGroupPipeline
from ..util import current_directory, log_df_info


def test_anomaly_detection_prediction_backtrack(asset_group_id='abcd', iot_type=None):
    """
    Check the value returned by ``get_prediction_backtrack`` for the default summary, the default
    summary with incremental summary disabled, and custom hourly/weekly/monthly summaries.

    When ``iot_type`` is not given, a random 'abcdsensor_NNNNN' device type name is generated.
    """
    if iot_type is None:
        import math
        import random
        suffix = math.floor(random.random() * 10**5)
        iot_type = f'abcdsensor_{suffix:05d}'

    def build_group(**extra):
        # build a fresh model_pipeline dict for every pipeline so instances never share config objects
        return AnomalyDetectionAssetGroupPipeline(
            asset_group_id=asset_group_id,
            model_pipeline={
                'features': [f'{iot_type}:axlevibration', f'{iot_type}:axlemomentum'],
                'features_for_training': [':faildate'],
                'predictions': ['anomaly_score', 'anomaly_threshold'],
            },
            **extra,
        )

    def max_summary(granularity):
        # summary config requesting only the 'max' of the first prediction at the given granularity
        return {'${predictions[0]}': {granularity: {'max': None}}}

    # offset fields that reset a timestamp to the start of its day
    midnight = {'hour': 0, 'minute': 0, 'second': 0, 'microsecond': 0}

    # default summary
    group = build_group()
    expected = [[DateOffset(**midnight), DateOffset(days=1)], []]
    assert expected == group.get_prediction_backtrack(group.pipeline_config)

    # default summary, incremental update turned off
    group = build_group(incremental_summary=False)
    expected = [[DateOffset(**midnight), DateOffset(days=1)], [DateOffset(**midnight)]]
    assert expected == group.get_prediction_backtrack(group.pipeline_config, incremental=group.incremental_summary)

    # custom hourly summary
    group = build_group(summary=max_summary('hourly'))
    expected = [[DateOffset(minute=0, second=0, microsecond=0), DateOffset(hours=1)], []]
    assert expected == group.get_prediction_backtrack(group.pipeline_config)

    # custom weekly summary
    group = build_group(summary=max_summary('weekly'))
    week_start = DateOffset(weekday=6, **midnight)
    # the expected multiplier depends on today: 1 when today's weekday() is 6, otherwise 2
    weeks_back = 1 if pd.Timestamp('today').weekday() == 6 else 2
    expected = [[week_start, DateOffset(n=weeks_back, weeks=1)], []]
    assert expected == group.get_prediction_backtrack(group.pipeline_config)

    # custom monthly summary
    group = build_group(summary=max_summary('monthly'))
    expected = [[DateOffset(day=1, **midnight), DateOffset(months=1)], []]
    assert expected == group.get_prediction_backtrack(group.pipeline_config)


def test_anomaly_detection_summary_config(asset_group_id='abcd', iot_type=None):
    """
    Check how the ``summary`` configuration is translated into ``post_processing`` entries, and
    that invalid summary configurations are rejected with ``ValueError``.

    When ``iot_type`` is not given, a random 'abcdsensor_NNNNN' device type name is generated.
    """
    if iot_type is None:
        import math
        import random
        suffix = math.floor(random.random() * 10**5)
        iot_type = f'abcdsensor_{suffix:05d}'

    def build_group(predictions=('anomaly_score', 'anomaly_threshold'), **extra):
        # build a fresh model_pipeline dict for every pipeline so instances never share config objects
        return AnomalyDetectionAssetGroupPipeline(
            asset_group_id=asset_group_id,
            model_pipeline={
                'features': [f'{iot_type}:axlevibration', f'{iot_type}:axlemomentum'],
                'features_for_training': [':faildate'],
                'predictions': list(predictions),
            },
            **extra,
        )

    def aggregation(function_name, granularity, output_name, source):
        # one expected post_processing entry
        return {
            "functionName": function_name,
            "enabled": True,
            "granularity": granularity,
            "output": {
                "name": output_name
            },
            "input": {
                "source": source
            },
        }

    # default summary
    group = build_group()
    assert group.post_processing == [
        aggregation("Maximum", "Daily", "daily_anomaly_score", "anomaly_score"),
    ]

    # default summary follows the name of the first prediction
    group = build_group(predictions=['ascore', 'anomaly_threshold'])
    assert group.post_processing == [
        aggregation("Maximum", "Daily", "daily_ascore", "ascore"),
    ]

    # a custom summary overrides the default one; two methods are requested here
    group = build_group(
        summary={
            '${predictions[1]}': {
                'daily': {
                    'max': None,
                    'mean': '${granularity}_${data_item}_${method}',
                },
            },
        },
    )
    assert group.post_processing == [
        aggregation("Maximum", "Daily", "daily_anomaly_threshold_max", "anomaly_threshold"),
        aggregation("Mean", "Daily", "daily_anomaly_threshold_mean", "anomaly_threshold"),
    ]

    # custom summary with multiple source data items and multiple granularity/methods
    group = build_group(
        summary={
            '${predictions[0]}': {
                'hourly': {
                    'max': '${granularity}_${data_item}',
                },
                'daily': {
                    'max': '${granularity}_${data_item}',
                    'mean': '${granularity}_${data_item}_${method}',
                },
            },
            '${predictions[1]}': {
                'daily': {
                    'first': '${granularity}_${data_item}',
                },
            },
        },
    )
    assert group.post_processing == [
        aggregation("Maximum", "Hourly", "hourly_anomaly_score", "anomaly_score"),
        aggregation("Maximum", "Daily", "daily_anomaly_score", "anomaly_score"),
        aggregation("Mean", "Daily", "daily_anomaly_score_mean", "anomaly_score"),
        aggregation("First", "Daily", "daily_anomaly_threshold", "anomaly_threshold"),
    ]

    # every summary below must be rejected by validation
    invalid_summaries = [
        # invalid index of predictions array
        {'${predictions[10]}': {'daily': {'max': None}}},
        # duplicate data item: two methods of one prediction use the same name template
        {
            '${predictions[0]}': {
                'daily': {
                    'max': '${granularity}_${data_item}',
                    'min': '${granularity}_${data_item}',
                },
            },
        },
        # duplicate data item: two predictions use the same name template
        {
            '${predictions[0]}': {'daily': {'max': '${granularity}'}},
            '${predictions[1]}': {'daily': {'max': '${granularity}'}},
        },
        # invalid granularity
        {'${predictions[0]}': {'abc': {'max': None}}},
        # invalid agg method
        {'${predictions[0]}': {'daily': {'abc': None}}},
    ]
    for bad_summary in invalid_summaries:
        with pytest.raises(ValueError):
            build_group(summary=bad_summary)


def test_anomaly_detection_model_config(asset_group_id='abcd', iot_type=None):
    """
    Check the ``pre_failure_window_size`` / ``pre_failure_failure_size`` values found in
    ``pipeline_config`` with and without overrides, and that a failure size larger than the
    window size is rejected with ``ValueError``.

    When ``iot_type`` is not given, a random 'abcdsensor_NNNNN' device type name is generated.
    """
    if iot_type is None:
        import math
        import random
        suffix = math.floor(random.random() * 10**5)
        iot_type = f'abcdsensor_{suffix:05d}'

    def build_group(**overrides):
        # the overrides are appended after the common model_pipeline entries
        return AnomalyDetectionAssetGroupPipeline(
            asset_group_id=asset_group_id,
            model_pipeline={
                'features': [f'{iot_type}:axlevibration', f'{iot_type}:axlemomentum'],
                'features_for_training': [':faildate'],
                'predictions': ['anomaly_score', 'anomaly_threshold'],
                **overrides,
            },
        )

    # (model_pipeline overrides, expected window size, expected failure size)
    cases = [
        ({}, 20, 10),
        ({'pre_failure_window_size': 40}, 40, 10),
        ({'pre_failure_failure_size': 8}, 20, 8),
    ]
    for overrides, window_size, failure_size in cases:
        group = build_group(**overrides)
        assert window_size == group.pipeline_config['pre_failure_window_size']
        assert failure_size == group.pipeline_config['pre_failure_failure_size']

    # pre_failure_failure_size must not be larger than pre_failure_window_size
    with pytest.raises(ValueError):
        build_group(pre_failure_window_size=20, pre_failure_failure_size=21)


def test_anomaly_detection_data_substitution(asset_group_id='abcd', iot_type=None):
    """
    Run the pipeline on asset and sensor data substituted from the CSV files next to this module.

    The asset data has 2 assets while only one of them has a device mapping. This tests whether
    the asset without device mapping is excluded from the training/scoring, since sensor data is
    required for this pipeline.
    """
    if iot_type is None:
        import math
        import random
        suffix = math.floor(random.random() * 10**5)
        iot_type = f'abcdsensor_{suffix:05d}'

    def csv_path(file_name):
        # the CSV fixtures are located relative to this test module
        return f'{current_directory(file=__file__)}/{file_name}'

    # asset failure history, with assets renamed to abcd-1 / abcd-2
    asset_frame = pd.read_csv(csv_path('trainbrake_asset_faildates.csv'), parse_dates=['faildate'])
    is_first_asset = asset_frame['asset'] == 'TRAINBRAKE1'
    asset_frame['asset'] = np.where(is_first_asset, 'abcd-1', 'abcd-2')
    asset_frame['assetid'] = asset_frame['asset'] + '-____-' + asset_frame['site']
    asset_frame['datetime'] = asset_frame['faildate']

    # device readings, with device type/id rewritten to match the mapping below
    sensor_frame = pd.read_csv(csv_path('trainbrake_device_data.csv'), parse_dates=['RCV_TIMESTAMP_UTC'])
    sensor_frame['DEVICETYPE'] = iot_type
    is_first_device = sensor_frame['DEVICEID'] == 'TrainBrake_1'
    sensor_frame['DEVICEID'] = np.where(is_first_device, 'abcd-1', 'abcd-2')

    from sklearn.ensemble import IsolationForest
    from sklearn.preprocessing import StandardScaler, RobustScaler, MinMaxScaler
    from srom.anomaly_detection.generalized_anomaly_model import GeneralizedAnomalyModel
    from srom.utils.no_op import NoOp

    scaling_stage = [
        ('skipscaling', NoOp()),
        ('standardscaler', StandardScaler()),
        ('robustscaler', RobustScaler()),
        ('minmaxscaling', MinMaxScaler()),
    ]
    model_stage = [
        (
            'isolationforest',
            GeneralizedAnomalyModel(
                base_learner=IsolationForest(),
                predict_function='decision_function',
                score_sign=-1,
            ),
        ),
    ]

    model_pipeline = {
        'features': [f'{iot_type}:axlevibration', f'{iot_type}:axlemomentum'],
        'features_for_training': [':faildate'],
        'predictions': ['anomaly_score', 'anomaly_threshold'],
        'srom_training_options': {
            'exectype': 'single_node_random_search'
        },
        'override_training_stages': [scaling_stage, model_stage],
    }

    # only the first asset is mapped to a device; the second one has no device at all
    asset_device_mappings = {
        'abcd-1-____-BEDFORD': [f'{iot_type}:abcd-1'],
        'abcd-2-____-BEDFORD': [],
    }

    asset_substitution = {
        'df': asset_frame,
        'keys': ['assetid'],
        'columns': ['faildate'],
        'timestamp': 'datetime',
    }
    sensor_substitution = {
        'df': sensor_frame,
        'keys': ['DEVICEID'],
        'columns': [
            'TRAINBRAKESIMULATION_AXLEVIBRATION',
            'TRAINBRAKESIMULATION_AXLEMOMENTUM',
        ],
        'timestamp': 'RCV_TIMESTAMP_UTC',
        'rename_columns': {
            'TRAINBRAKESIMULATION_AXLEVIBRATION': 'axlevibration',
            'TRAINBRAKESIMULATION_AXLEMOMENTUM': 'axlemomentum',
        },
    }

    group = AnomalyDetectionAssetGroupPipeline(
        asset_group_id=asset_group_id,
        model_pipeline=model_pipeline,
        asset_device_mappings=asset_device_mappings,
        data_substitution={
            '': [asset_substitution],
            iot_type: [sensor_substitution],
        },
    )
    df = group.execute()

    print(log_df_info(df, head=0))

    assert group.new_training
    assert False == group.training
    assert 'AnomalyDetectionEstimator' in group.model_timestamp
    assert (13571, 2) == df.shape
    for prediction in ('anomaly_score', 'anomaly_threshold'):
        assert prediction in df.columns


if __name__ == '__main__':
    # run the tests one after another, in the order they are defined above
    for _test_case in (
        test_anomaly_detection_prediction_backtrack,
        test_anomaly_detection_summary_config,
        test_anomaly_detection_model_config,
        test_anomaly_detection_data_substitution,
    ):
        _test_case()
