# Licensed Materials - Property of IBM
# 5737-M66, 5900-AAA, 5900-AMG
# (C) Copyright IBM Corp. 2019, 2025 All Rights Reserved.
# US Government Users Restricted Rights - Use, duplication, or disclosure
# restricted by GSA ADP Schedule Contract with IBM Corp.

import numpy as np
import pandas as pd
import pytest

from sqlalchemy import func, select

from .. import api
from ..failure_prediction import FailurePredictionAssetGroupPipeline
from ..util import current_directory, log_df_info


def test_resample_config(asset_group_id='abcd', iot_type=None):
    """Check how FailurePredictionAssetGroupPipeline handles 'features_resampled'.

    The test loads the train-brake CSV fixtures, registers an asset group, one
    asset/device mapping, the asset cache and an IoT type through ``api``, and
    then constructs the pipeline with a series of 'features_resampled'
    configurations:

    * configurations that must be rejected with ``ValueError`` (unknown IoT
      type, non-dict config, missing or invalid frequency, unknown data item);
    * configurations that must be accepted, for which
      ``group.pipeline_config['features_resampled']`` is compared against the
      expected ``{iot_type: (frequency, methods, outputs)}`` structure.

    Everything registered through ``api`` is deleted again in ``finally``.

    Args:
        asset_group_id: ID of the asset group created for the test.
        iot_type: Name of the IoT type to create. When ``None`` a name of the
            form ``abcdsensor_NNNNN`` with a random 5-digit suffix is used.
    """
    if iot_type is None:
        import random
        iot_type = 'abcdsensor_%05d' % random.randrange(10**5)

    data_dir = current_directory(file=__file__)

    df_data_asset = pd.read_csv('%s/trainbrake_asset_faildates.csv' % data_dir, parse_dates=['faildate'])
    df_data_asset['asset'] = np.where(df_data_asset['asset'] == 'TRAINBRAKE1', 'abcd-1', 'abcd-2')
    # one row per distinct (site, asset) pair
    df_asset_group = df_data_asset.groupby(['site', 'asset']).size().reset_index()[['site', 'asset']]

    df_data_sensor = pd.read_csv('%s/trainbrake_device_data.csv' % data_dir, parse_dates=['RCV_TIMESTAMP_UTC'])[:10000]
    df_data_sensor['DEVICEID'] = np.where(df_data_sensor['DEVICEID'] == 'TrainBrake_1', 'abcd-1', 'abcd-2')

    df_mappings = pd.DataFrame(
        columns=['site', 'asset', 'devicetype', 'deviceid'],
        data=[
            ['BEDFORD', 'abcd-1', iot_type, 'abcd-1'],
        ],
    )

    # Acquire the database handle *before* the try block: the cleanup in
    # ``finally`` needs ``db``/``db_schema``, and if they were bound inside the
    # try a failure of ``_get_db()`` would surface as an UnboundLocalError from
    # the cleanup instead of the real error.
    db = api._get_db()
    db_schema = None

    try:
        api.set_asset_group_members(asset_group_id=asset_group_id, df=df_asset_group, db=db, db_schema=db_schema)

        api.set_asset_device_mappings(df=df_mappings, db=db, db_schema=db_schema)

        api.set_asset_cache(df=df_data_asset, siteid_column='site', assetid_column='asset', faildate_column='faildate', db=db, db_schema=db_schema)

        api.setup_iot_type(iot_type, df_data_sensor, columns=['TRAINBRAKESIMULATION_AXLEVIBRATION', 'TRAINBRAKESIMULATION_AXLEMOMENTUM'], deviceid_column='DEVICEID', timestamp_column='RCV_TIMESTAMP_UTC', timestamp_in_payload=False, parse_dates=None, rename_columns={'TRAINBRAKESIMULATION_AXLEVIBRATION': 'axlevibration', 'TRAINBRAKESIMULATION_AXLEMOMENTUM': 'axlemomentum'}, write=None, use_wiotp=False, import_only=False, db=db, db_schema=db_schema)

        model_pipeline_base = {
            'features': ['%s:axlevibration' % iot_type, '%s:axlemomentum' % iot_type],
            'features_for_training': [':faildate'],
            'predictions': ['failure_probability', 'rca_path'],
            'aggregation_methods': ['mean', 'max', 'min'],
            'prediction_window_size': '5d',
        }

        # the frequency '1T' is expected back unchanged on sqlite and as '1min' otherwise
        expected_frequency = '1T' if db.db_type == 'sqlite' else '1min'

        def build_group(features_resampled, features=None):
            # Build a pipeline from a fresh (shallow) copy of the base config.
            model_pipeline = model_pipeline_base.copy()
            if features is not None:
                model_pipeline['features'] = features
            model_pipeline['features_resampled'] = features_resampled
            return FailurePredictionAssetGroupPipeline(asset_group_id=asset_group_id, model_pipeline=model_pipeline)

        def assert_rejected(features_resampled):
            # The pipeline constructor must refuse this resample config.
            with pytest.raises(ValueError):
                build_group(features_resampled)

        def assert_resolved(features_resampled, methods, outputs, features=None):
            # The pipeline must expand the resample config to the expected tuple.
            group = build_group(features_resampled, features=features)
            assert group.pipeline_config['features_resampled'] == {iot_type: (expected_frequency, methods, outputs)}

        # unknown/invalid feature type used in resample config
        assert_rejected({
            iot_type + '_2': {
            },
        })

        # error as long as one type is invalid
        assert_rejected({
            iot_type + '_2': {
            },
            iot_type: {
            },
        })

        # invalid resample config, must be a dict
        assert_rejected({
            iot_type: 123,
        })

        # invalid resample config, must have the key '${freqency}'
        # (the key is spelled exactly like this throughout this test)
        assert_rejected({
            iot_type: {
            },
        })

        # invalid resample config frequency, must be a valid offset
        # (or one of the special ones supported by iotfunctions)
        assert_rejected({
            iot_type: {
                '${freqency}': 'abc',
            },
        })

        # with just the frequency specified, default methods and outputs of
        # all features are automatically generated
        assert_resolved(
            {
                iot_type: {
                    '${freqency}': '1T',
                },
            },
            methods={'axlevibration': ['mean'], 'axlemomentum': ['mean']},
            outputs={'axlevibration': ['axlevibration'], 'axlemomentum': ['axlemomentum']},
        )

        # non-numeric features default to the 'max' method
        assert_resolved(
            {
                iot_type: {
                    '${freqency}': '1T',
                },
            },
            methods={'axlevibration': ['mean'], 'axlemomentum': ['mean'], 'deviceid': ['max']},
            outputs={'axlevibration': ['axlevibration'], 'axlemomentum': ['axlemomentum'], 'deviceid': ['deviceid']},
            features=['%s:axlevibration' % iot_type, '%s:axlemomentum' % iot_type, '%s:deviceid' % iot_type],
        )

        # error if any data item name is invalid
        assert_rejected({
            iot_type: {
                '${freqency}': '1T',
                'abc': {
                },
            },
        })

        # a data item that is given should have its output names composed correctly;
        # data items not specified get their default method/output automatically generated
        assert_resolved(
            {
                iot_type: {
                    '${freqency}': '1T',
                    'axlevibration': {
                        'max': '${data_item}_${method}',
                        'mean': None,
                        'min': None,
                    },
                },
            },
            methods={'axlevibration': ['max', 'mean', 'min'], 'axlemomentum': ['mean']},
            outputs={'axlevibration': ['axlevibration_max', 'axlevibration_mean', 'axlevibration_min'], 'axlemomentum': ['axlemomentum']},
        )

        # same as above, but with an explicit output name (template or literal)
        # for every method of the given data item
        assert_resolved(
            {
                iot_type: {
                    '${freqency}': '1T',
                    'axlevibration': {
                        'max': '${data_item}_${method}',
                        'mean': '${data_item}',
                        'min': 'vibration_min',
                    },
                },
            },
            methods={'axlevibration': ['max', 'mean', 'min'], 'axlemomentum': ['mean']},
            outputs={'axlevibration': ['axlevibration_max', 'axlevibration', 'vibration_min'], 'axlemomentum': ['axlemomentum']},
        )

        # all details are given
        assert_resolved(
            {
                iot_type: {
                    '${freqency}': '1T',
                    'axlevibration': {
                        'max': '${data_item}_${method}',
                        'mean': '${data_item}',
                        'min': 'vibration_min',
                    },
                    'axlemomentum': {
                        'count': None,
                        'sum': None,
                    },
                },
            },
            methods={'axlevibration': ['max', 'mean', 'min'], 'axlemomentum': ['count', 'sum']},
            outputs={'axlevibration': ['axlevibration_max', 'axlevibration', 'vibration_min'], 'axlemomentum': ['axlemomentum_count', 'axlemomentum_sum']},
        )

        # everything right, but with an invalid extra data item specified: still an error
        assert_rejected({
            iot_type: {
                '${freqency}': '1T',
                'axlevibration': {
                    'max': '${data_item}_${method}',
                    'mean': '${data_item}',
                    'min': 'vibration_min',
                },
                'axlemomentum': {
                    'count': None,
                    'sum': None,
                },
                'abc': {
                },
            },
        })
    finally:
        # remove everything registered above, whether or not the checks passed
        api.delete_iot_type(iot_type, use_wiotp=False, db=db, db_schema=db_schema)

        api.delete_asset_cache(df=df_data_asset, db=db, db_schema=db_schema)

        api.delete_asset_device_mappings(df=df_mappings, db=db, db_schema=db_schema)

        api.delete_asset_group_members(asset_group_id=asset_group_id, db=db, db_schema=db_schema)


def test_resample(asset_group_id='abcd', iot_type=None):
    """Run the failure prediction pipeline end to end with resampled features.

    The test loads the train-brake CSV fixtures, registers an asset group, one
    asset/device mapping, the asset cache and an IoT type through ``api``, and
    then:

    1. builds a pipeline with a 'features_resampled' config and a '5d'
       prediction window, calls ``execute()`` and checks the flags, the model
       timestamps and the shape/columns of the returned frame;
    2. empties the ``dm_<group>`` tables if they exist, calls ``_write()``
       directly and checks the resulting row counts;
    3. calls ``predict()`` for a range with data and for a range without data
       (which must yield an empty frame that still has the prediction columns);
    4. builds a second pipeline without resampling and with an '8h' prediction
       window and checks the result of ``execute()`` again.

    Everything registered through ``api`` is deleted again in ``finally``.

    Args:
        asset_group_id: ID of the asset group created for the test.
        iot_type: Name of the IoT type to create. When ``None`` a name of the
            form ``abcdsensor_NNNNN`` with a random 5-digit suffix is used.
    """
    if iot_type is None:
        import random
        iot_type = 'abcdsensor_%05d' % random.randrange(10**5)

    data_dir = current_directory(file=__file__)

    df_data_asset = pd.read_csv('%s/trainbrake_asset_faildates.csv' % data_dir, parse_dates=['faildate'])
    df_data_asset['asset'] = np.where(df_data_asset['asset'] == 'TRAINBRAKE1', 'abcd-1', 'abcd-2')
    # one row per distinct (site, asset) pair
    df_asset_group = df_data_asset.groupby(['site', 'asset']).size().reset_index()[['site', 'asset']]

    df_data_sensor = pd.read_csv('%s/trainbrake_device_data.csv' % data_dir, parse_dates=['RCV_TIMESTAMP_UTC'])[:10000]
    df_data_sensor['DEVICEID'] = np.where(df_data_sensor['DEVICEID'] == 'TrainBrake_1', 'abcd-1', 'abcd-2')

    df_mappings = pd.DataFrame(
        columns=['site', 'asset', 'devicetype', 'deviceid'],
        data=[
            ['BEDFORD', 'abcd-1', iot_type, 'abcd-1'],
        ],
    )

    # Acquire the database handle *before* the try block: the cleanup in
    # ``finally`` needs ``db``/``db_schema``, and if they were bound inside the
    # try a failure of ``_get_db()`` would surface as an UnboundLocalError from
    # the cleanup instead of the real error.
    db = api._get_db()
    db_schema = None

    def count_rows(table_name):
        # Number of rows currently stored in the given table.
        table = db.get_table(table_name, db_schema)
        return db.connection.execute(select([func.count()]).select_from(table)).first()[0]

    try:
        api.set_asset_group_members(asset_group_id=asset_group_id, df=df_asset_group, db=db, db_schema=db_schema)

        api.set_asset_device_mappings(df=df_mappings, db=db, db_schema=db_schema)

        api.set_asset_cache(df=df_data_asset, siteid_column='site', assetid_column='asset', faildate_column='faildate', db=db, db_schema=db_schema)

        api.setup_iot_type(iot_type, df_data_sensor, columns=['TRAINBRAKESIMULATION_AXLEVIBRATION', 'TRAINBRAKESIMULATION_AXLEMOMENTUM'], deviceid_column='DEVICEID', timestamp_column='RCV_TIMESTAMP_UTC', timestamp_in_payload=False, parse_dates=None, rename_columns={'TRAINBRAKESIMULATION_AXLEVIBRATION': 'axlevibration', 'TRAINBRAKESIMULATION_AXLEMOMENTUM': 'axlemomentum'}, write='deletefirst', use_wiotp=False, import_only=False, db=db, db_schema=db_schema)

        group = FailurePredictionAssetGroupPipeline(
            asset_group_id=asset_group_id,
            model_pipeline={
                'features': ['%s:axlevibration' % iot_type, '%s:axlemomentum' % iot_type],
                'features_for_training': [':faildate'],
                'predictions': ['failure_probability', 'rca_path'],
                'features_resampled': {
                    iot_type: {
                        '${freqency}': '1T',
                        'axlevibration': {
                            'max': None,
                            'mean': None,
                            'min': None,
                        },
                    },
                },
                'aggregation_methods': ['mean', 'max', 'min', 'median', 'std', 'sum', 'count'],
                'prediction_window_size': '5d',
            },
        )
        df = group.execute()

        print(log_df_info(df, head=0))

        assert group.new_training
        assert False == group.training
        assert 'FailurePredictionEstimator' in group.model_timestamp and 'FailurePredictionRcaEstimator' in group.model_timestamp
        assert (103, 2) == df.shape
        assert 'failure_probability_5d' in df.columns
        assert 'rca_path_5d' in df.columns

        # empty the output tables (when they exist) so the row counts checked
        # after the direct write below start from zero
        target_tables = ['dm_%s' % asset_group_id, 'dm_%s_daily' % asset_group_id, 'dm_%s_Daily' % asset_group_id]
        for table_name in target_tables:
            try:
                table = db.get_table(table_name, db_schema)
            except Exception:
                # best effort: a table that cannot be looked up (presumably
                # because it does not exist yet) has nothing to delete
                continue
            db.connection.execute(table.delete())

        # test writing directly
        group._write(df)

        assert 206 == count_rows('dm_%s' % asset_group_id)
        assert 3 == count_rows('dm_%s_daily' % asset_group_id)
        assert 3 == count_rows('dm_%s_Daily' % asset_group_id)

        df_scored = group.predict(start_ts='2019-07-11 03:25:00', end_ts='2019-07-17 00:00:00')

        print(log_df_info(df_scored, head=0))

        assert (16, 2) == df_scored.shape
        assert 'failure_probability_5d' in df_scored.columns
        assert 'rca_path_5d' in df_scored.columns

        # test empty input data for scoring, which should return an empty df
        # with the prediction columns added
        df_scored = group.predict(start_ts='1976-01-01 03:25:00', end_ts='1976-01-02 00:00:00')

        print(log_df_info(df_scored, head=0))

        assert (0, 2) == df_scored.shape
        assert 'failure_probability_5d' in df_scored.columns
        assert 'rca_path_5d' in df_scored.columns

        # test hourly summary (no resampling, 8h prediction window)
        group = FailurePredictionAssetGroupPipeline(
            asset_group_id=asset_group_id,
            model_pipeline={
                'features': ['%s:axlevibration' % iot_type, '%s:axlemomentum' % iot_type],
                'features_for_training': [':faildate'],
                'predictions': ['failure_probability', 'rca_path'],
                'aggregation_methods': ['mean', 'max', 'min', 'std'],
                'prediction_window_size': '8h',
            },
        )
        df = group.execute()

        print(log_df_info(df, head=0))

        assert group.new_training
        assert False == group.training
        assert 'FailurePredictionEstimator' in group.model_timestamp and 'FailurePredictionRcaEstimator' in group.model_timestamp
        assert (4010, 2) == df.shape
        assert 'failure_probability_8h' in df.columns
        assert 'rca_path_8h' in df.columns
    finally:
        # remove everything registered above, whether or not the checks passed
        api.delete_iot_type(iot_type, use_wiotp=False, db=db, db_schema=db_schema)

        api.delete_asset_cache(df=df_data_asset, db=db, db_schema=db_schema)

        api.delete_asset_device_mappings(df=df_mappings, db=db, db_schema=db_schema)

        api.delete_asset_group_members(asset_group_id=asset_group_id, db=db, db_schema=db_schema)


# Script entry point: runs only the end-to-end resample test, outside pytest.
# NOTE(review): this module uses relative imports (`from .. import api`), so it
# presumably has to be launched as a module of its package (`python -m ...`)
# rather than as a plain file path -- confirm.
if __name__ == '__main__':
    test_resample()
