# Licensed Materials - Property of IBM
# 5737-M66, 5900-AAA, 5900-AMG
# (C) Copyright IBM Corp. 2019, 2025 All Rights Reserved.
# US Government Users Restricted Rights - Use, duplication, or disclosure
# restricted by GSA ADP Schedule Contract with IBM Corp.

import math
import random

import numpy as np
import pandas as pd
import pytest
from sqlalchemy import func, select

from pandas.tseries.frequencies import to_offset
from pandas.tseries.offsets import DateOffset

from .. import api
from ..failure_prediction import FailurePredictionAssetGroupPipeline
from ..util import current_directory, log_df_info


@pytest.fixture(scope='module')
def asset_group_id():
    """Return the asset group id used by every test in this module (module-scoped)."""
    group_id = 'abcd'
    return group_id


@pytest.fixture
def iot_type():
    """Return a randomized device type name such as ``abcdsensor_01234``.

    The fixture is function-scoped, so each test draws its own five-digit suffix
    (presumably to keep device types from different tests/runs apart — confirm).
    """
    suffix = math.floor(random.random() * 10**5)
    return f'abcdsensor_{suffix:05d}'


def test_failure_probability_summary_config(asset_group_id, iot_type):
    """Verify the default ``post_processing`` summaries of failure probability outputs.

    Covers non-multiclass and multiclass pipelines with a daily ('5d') and an hourly
    ('5h') prediction window. Every failure probability output is expected to get a
    "Maximum" summary followed by a group "Mean" summary.
    """

    def build_group(prediction_window_size, multiclass):
        # Keys are inserted in the same order for every scenario; 'multiclass' is only
        # present for the multiclass scenarios.
        model_pipeline = {
            'features': ['%s:axlevibration' % iot_type, '%s:axlemomentum' % iot_type],
            'features_for_training': [':problemcode'],
            'predictions': ['failure_probability', 'rca_path'],
            'aggregation_methods': ['mean', 'max', 'min', 'std'],
            'prediction_window_size': prediction_window_size,
        }
        if multiclass:
            model_pipeline['multiclass'] = True
        model_pipeline['failure_modes'] = ['STOPPED', 'BROKEN']
        return FailurePredictionAssetGroupPipeline(
            asset_group_id=asset_group_id,
            model_pipeline=model_pipeline,
        )

    def expected_summaries(sources, prefix, granularity, group_granularity):
        # One "Maximum" entry and one "Mean" (group) entry per source, in source order.
        summaries = []
        for source in sources:
            summaries.append({
                "functionName": "Maximum",
                "enabled": True,
                "granularity": granularity,
                "output": {"name": '%s_%s' % (prefix, source)},
                "input": {"source": source},
            })
            summaries.append({
                "functionName": "Mean",
                "enabled": True,
                "granularity": group_granularity,
                "output": {"name": 'group_%s_%s' % (prefix, source)},
                "input": {"source": source},
            })
        return summaries

    # non-multiclass, daily window: the group summary granularity is also "Daily"
    group = build_group('5d', multiclass=False)
    assert group.post_processing == expected_summaries(
        ['failure_probability_5d'], 'daily', 'Daily', 'Daily')

    # non-multiclass, hourly window: the group summary granularity is "GroupHourly"
    group = build_group('5h', multiclass=False)
    assert group.post_processing == expected_summaries(
        ['failure_probability_5h'], 'hourly', 'Hourly', 'GroupHourly')

    # multiclass: one pair of summaries per failure mode, in failure_modes order
    group = build_group('5d', multiclass=True)
    assert group.post_processing == expected_summaries(
        ['failure_probability_stopped_5d', 'failure_probability_broken_5d'],
        'daily', 'Daily', 'Daily')

    group = build_group('5h', multiclass=True)
    assert group.post_processing == expected_summaries(
        ['failure_probability_stopped_5h', 'failure_probability_broken_5h'],
        'hourly', 'Hourly', 'GroupHourly')


def test_failure_probability_model_config_prediction_window(asset_group_id, iot_type):
    """Check prediction/rolling window handling in ``pipeline_config``.

    Every scenario uses a multiclass pipeline with failure modes STOPPED and BROKEN,
    so each ``failure_probability*`` prediction is expected to expand into one name
    per failure mode. The prediction window is expected to be appended as a
    postfix unless a name already ends with it. Invalid window specifications must
    raise ``ValueError``.
    """
    basic_methods = ['mean', 'max', 'min', 'std']
    # aggregation methods expected when the model pipeline does not give any
    default_methods = ['mean', 'max', 'min', 'median', 'std', 'sum', 'count']
    one_day_predictions = ['failure_probability_stopped_1d', 'failure_probability_broken_1d', 'rca_path_1d']
    ten_day_predictions = ['failure_probability_stopped_10d', 'failure_probability_broken_10d', 'rca_path_10d']

    def make_group(predictions, prediction_window_size, rolling_window_size=None, aggregation_methods=None):
        # Build a multiclass pipeline. Optional keys are only present when given, and
        # keys are always inserted in the same order.
        model_pipeline = {
            'features': ['%s:axlevibration' % iot_type, '%s:axlemomentum' % iot_type],
            'features_for_training': [':problemcode'],
            'predictions': list(predictions),
        }
        if aggregation_methods is not None:
            # copy, so the expected lists used in assertions are never shared with the pipeline
            model_pipeline['aggregation_methods'] = list(aggregation_methods)
        model_pipeline['prediction_window_size'] = prediction_window_size
        if rolling_window_size is not None:
            model_pipeline['rolling_window_size'] = rolling_window_size
        model_pipeline['multiclass'] = True
        model_pipeline['failure_modes'] = ['STOPPED', 'BROKEN']
        return FailurePredictionAssetGroupPipeline(
            asset_group_id=asset_group_id,
            model_pipeline=model_pipeline,
        )

    def assert_config(group, aggregation_methods, prediction_window_size, predictions, rolling_window_size=None):
        # The rolling window is expected to equal the prediction window unless set explicitly.
        config = group.pipeline_config
        assert aggregation_methods == config['aggregation_methods']
        assert prediction_window_size == config['prediction_window_size']
        assert (rolling_window_size or prediction_window_size) == config['rolling_window_size']
        assert predictions == config['predictions']

    # the prediction window is appended to every prediction name
    group = make_group(['failure_probability', 'rca_path'], '5d', aggregation_methods=basic_methods)
    assert_config(group, basic_methods, '5d',
                  ['failure_probability_stopped_5d', 'failure_probability_broken_5d', 'rca_path_5d'])

    # special single 'd': the postfix becomes '1d'
    group = make_group(['failure_probability', 'rca_path'], 'd', aggregation_methods=basic_methods)
    assert_config(group, basic_methods, 'd', one_day_predictions)

    # special single 'D'
    group = make_group(['failure_probability_1d', 'rca_path_1d'], 'D', aggregation_methods=basic_methods)
    assert_config(group, basic_methods, 'D', one_day_predictions)

    # special '1d'
    group = make_group(['failure_probability', 'rca_path'], '1d', aggregation_methods=basic_methods)
    assert_config(group, basic_methods, '1d', one_day_predictions)

    # special '1D': the generated postfix is uniformly lower-case
    group = make_group(['failure_probability', 'rca_path'], '1D', aggregation_methods=basic_methods)
    assert_config(group, basic_methods, '1D', one_day_predictions)

    # if the given predictions already carry the postfix, there is no conflict
    group = make_group(['failure_probability_1d', 'rca_path_1d'], '1d', aggregation_methods=basic_methods)
    assert_config(group, basic_methods, '1d', one_day_predictions)

    # explicitly given prediction names are preserved (case-sensitive) without appending an extra postfix
    group = make_group(['failure_Probability_1D', 'rca_path_1d'], '1d', aggregation_methods=basic_methods)
    assert_config(group, basic_methods, '1d',
                  ['failure_Probability_stopped_1D', 'failure_Probability_broken_1D', 'rca_path_1d'])

    # no aggregation methods given: the default set is used
    group = make_group(['failure_probability', 'rca_path'], '10d')
    assert_config(group, default_methods, '10d', ten_day_predictions)

    # different rolling window size, and no re-padding of the prediction name postfix
    group = make_group(['failure_probability_10d', 'rca_path_10d'], '10d', rolling_window_size='5d')
    assert_config(group, default_methods, '10d', ten_day_predictions, rolling_window_size='5d')

    # the prediction window size is padded onto the prediction output names
    group = make_group(['failure_probability_10d', 'rca_path_10d'], '12h')
    assert_config(group, default_methods, '12h',
                  ['failure_probability_10d_stopped_12h', 'failure_probability_10d_broken_12h', 'rca_path_10d_12h'])

    # invalid prediction window given: no unit
    with pytest.raises(ValueError):
        make_group(['failure_probability_10d', 'rca_path_10d'], '10')

    # invalid prediction window
    with pytest.raises(ValueError):
        make_group(['failure_probability_10m', 'rca_path_10m'], '10ms')

    # invalid rolling window given
    with pytest.raises(ValueError):
        make_group(['failure_probability_10d', 'rca_path_10d'], '10d', rolling_window_size='abc')

    # the prediction window is auto-padded at the end of prediction names, if not already there
    group = make_group(['failure_probability', 'rca_path_5'], '5d', aggregation_methods=basic_methods)
    assert ['failure_probability_stopped_5d', 'failure_probability_broken_5d', 'rca_path_5_5d'] == group.pipeline_config['predictions']

    group = make_group(['failure_probability', 'rca_path_5h'], '5h', aggregation_methods=basic_methods)
    assert ['failure_probability_stopped_5h', 'failure_probability_broken_5h', 'rca_path_5h'] == group.pipeline_config['predictions']


def test_data_substitution_validation(asset_group_id, iot_type):
    """Check validation of ``data_substitution`` entity types in training and scoring.

    Asset failure dates and device readings are loaded from CSV files next to this
    module and relabeled as assets/devices ``abcd-1`` / ``abcd-2``. Substitutions keyed
    by an unknown entity type must raise ``ValueError``. In scoring mode (a
    ``model_timestamp`` is given), asset data ('' key) must also be rejected.
    """
    here = current_directory(file=__file__)

    # asset failure history; TRAINBRAKE1 becomes abcd-1, every other asset abcd-2
    df_data_asset = pd.read_csv('%s/trainbrake_asset_faildates.csv' % here, parse_dates=['faildate'])
    df_data_asset['asset'] = np.where(df_data_asset['asset'] == 'TRAINBRAKE1', 'abcd-1', 'abcd-2')
    df_data_asset['assetid'] = df_data_asset['asset'] + '-____-' + df_data_asset['site']
    df_data_asset['datetime'] = df_data_asset['faildate']

    # first 10000 sensor rows; TrainBrake_1 becomes abcd-1, every other device abcd-2
    df_data_sensor = pd.read_csv('%s/trainbrake_device_data.csv' % here, parse_dates=['RCV_TIMESTAMP_UTC']).head(10000)
    df_data_sensor['DEVICETYPE'] = iot_type
    df_data_sensor['DEVICEID'] = np.where(df_data_sensor['DEVICEID'] == 'TrainBrake_1', 'abcd-1', 'abcd-2')

    model_pipeline = {
        'features': ['%s:axlevibration' % iot_type, '%s:axlemomentum' % iot_type],
        'features_for_training': [':problemcode'],
        'predictions': ['failure_probability_1d', 'rca_path_1d'],
        'aggregation_methods': ['mean', 'max', 'min', 'median', 'std', 'sum', 'count'],
        'prediction_window_size': '1d',
        'multiclass': True,
        'failure_modes': ['STOPPED', 'BROKEN'],
    }
    asset_device_mappings = {
        'abcd-1-____-BEDFORD': ['%s:abcd-1' % iot_type],
        'abcd-2-____-BEDFORD': [],
    }
    substitution_asset = [
        dict(df=df_data_asset, keys=['assetid'], columns=['faildate'], timestamp='datetime'),
    ]
    substitution_iot = [
        dict(
            df=df_data_sensor,
            keys=['DEVICEID'],
            columns=['TRAINBRAKESIMULATION_AXLEVIBRATION', 'TRAINBRAKESIMULATION_AXLEMOMENTUM'],
            timestamp='RCV_TIMESTAMP_UTC',
            rename_columns={
                'TRAINBRAKESIMULATION_AXLEVIBRATION': 'axlevibration',
                'TRAINBRAKESIMULATION_AXLEMOMENTUM': 'axlemomentum',
            },
        ),
    ]
    unknown_iot_type = iot_type + '_abc'

    def new_group(data_substitution, scoring=False):
        # Passing model_timestamp is what distinguishes the scoring scenarios below.
        kwargs = {
            'asset_group_id': asset_group_id,
            'model_pipeline': model_pipeline,
            'asset_device_mappings': asset_device_mappings,
        }
        if scoring:
            kwargs['model_timestamp'] = {'FailurePredictionRcaEstimator': '1570169739', 'FailurePredictionEstimator': '1570169523'}
        kwargs['data_substitution'] = data_substitution
        return FailurePredictionAssetGroupPipeline(**kwargs)

    # training: asset ('' entity type) and device data substitutions are accepted
    new_group({'': substitution_asset, iot_type: substitution_iot})

    # training: an unknown entity type is rejected
    with pytest.raises(ValueError):
        new_group({'': substitution_asset, unknown_iot_type: substitution_iot})

    # scoring: device data only is accepted
    new_group({iot_type: substitution_iot}, scoring=True)

    # scoring: an unknown entity type is rejected
    with pytest.raises(ValueError):
        new_group({unknown_iot_type: substitution_iot}, scoring=True)

    # scoring: unused asset data is rejected (scoring does not need it)
    with pytest.raises(ValueError):
        new_group({'': substitution_asset, iot_type: substitution_iot}, scoring=True)


def test_model_config(asset_group_id, iot_type):
    """Check that invalid multiclass model configurations raise ``ValueError``."""

    def build(features_for_training, **extra):
        # Multiclass pipeline; ``extra`` (e.g. failure_modes) is appended as the last key(s).
        model_pipeline = {
            'features': ['%s:Axlevibration' % iot_type, '%s:Axlemomentum' % iot_type],
            'features_for_training': features_for_training,
            'predictions': ['failure_probability', 'rca_path'],
            'aggregation_methods': ['mean', 'max'],
            'prediction_window_size': '5d',
            'multiclass': True,
            **extra,
        }
        return FailurePredictionAssetGroupPipeline(
            asset_group_id=asset_group_id,
            model_pipeline=model_pipeline,
        )

    # problemcode must be the 1st element of features_for_training when failure_modes is used
    with pytest.raises(ValueError):
        build([':faildate', ':problemcode'], failure_modes=['STOPPED', 'BROKEN'])

    # multi-class must have failure_modes specified (missing or empty). ':problemcode' is
    # listed first so that the ordering rule checked above is not what triggers the error.
    with pytest.raises(ValueError):
        build([':problemcode'])
    with pytest.raises(ValueError):
        build([':problemcode'], failure_modes=[])


def test_failure_probability_with_failure_modes(asset_group_id, iot_type):
    """Run a multiclass failure-prediction pipeline on in-memory substituted data.

    Sensor readings (``current``/``voltage``) and asset failure records are
    passed to the pipeline as DataFrames through ``data_substitution``. The
    test then checks that a new multiclass model (and RCA model) was trained,
    that one ``failure_probability_*_3d`` column is produced per configured
    failure mode plus ``rca_path_3d``, and that ``failure_modes`` in the
    resulting pipeline config is mapped to ``[index, bool]`` pairs.
    """
    # 2084 ten-minute readings ending 2020-01-05 10:31:40 UTC (~14.5 days back
    # to 2019-12-21), so the window spans every failure date in df_data_asset.
    # '10min' is the same offset as the legacy '10T' alias, which pandas >= 2.2
    # deprecates.
    df_data_sensor = pd.DataFrame()
    df_data_sensor['rcv_timestamp_utc'] = pd.date_range(
        end=pd.Timestamp('2020-01-05T10:31:40'), periods=2084, freq='10min', tz='UTC')
    df_data_sensor['devicetype'] = iot_type
    df_data_sensor['deviceid'] = 'abcd-1'

    # Fixed seed keeps the random sensor values reproducible between runs.
    rng = np.random.default_rng(12345)
    df_data_sensor['current'] = rng.integers(low=0, high=1000, size=len(df_data_sensor))
    df_data_sensor['voltage'] = rng.integers(low=0, high=1000, size=len(df_data_sensor))

    print(log_df_info(df_data_sensor, head=5))

    # Failure history for a single asset. 2019-12-31 08:00 deliberately carries
    # two failures with different problem codes at the same timestamp.
    df_data_asset = pd.DataFrame(data={
        'datetime': [
            pd.Timestamp('2019-12-25T08:00'),
            pd.Timestamp('2019-12-31T08:00'),
            pd.Timestamp('2019-12-31T08:00'),
            pd.Timestamp('2020-01-03T08:00'),
        ],
        'problemcode': [
            'PUMPS/STOPPED',
            'PUMPS/STOPPED',
            'PUMPS/BROKEN',
            'PUMPS/BROKEN',
        ],
        'classcode': [
            'PUMPS',
            'PUMPS',
            'PUMPS',
            'PUMPS',
        ]
    })
    df_data_asset['asset'] = 'abcd-1'
    df_data_asset['site'] = 'BEDFORD'
    # Composite asset id, matching the key used in asset_device_mappings below.
    df_data_asset['assetid'] = df_data_asset['asset'] + '-____-' + df_data_asset['site']
    df_data_asset['faildate'] = df_data_asset['datetime']

    print(log_df_info(df_data_asset, head=5))

    # Non-multiclass run with failure_modes specified: disabled because it
    # raised an exception inside sklearn/pipeline.py
    # (opt/app-root/lib/python3.6/site-packages). NOTE(review): no `group` is
    # constructed for this case here, so re-enabling the lines below also
    # requires building a non-multiclass FailurePredictionAssetGroupPipeline.
    #df = group.execute()

    #print(log_df_info(df, head=0))

    #assert group.new_training
    #assert False == group.training
    #assert 'FailurePredictionEstimator' in group.model_timestamp and 'FailurePredictionRcaEstimator' in group.model_timestamp
    #assert (1655, 2) == df.shape
    #assert 'failure_probability_3d' in df.columns
    #assert 'rca_path_3d' in df.columns
    #assert group.pipeline_config['failure_modes'] == {'STOPPED': [1, True], 'BROKEN': [2, True], 'UNKNOWN': [3, False]}

    # multi-class

    group = FailurePredictionAssetGroupPipeline(
                asset_group_id=asset_group_id,
                model_pipeline={
                    'features': ['%s:current' % iot_type, '%s:voltage' % iot_type],
                    #'features_for_training': [':problemcode'],
                    'predictions': ['failure_probability', 'rca_path'],
                    'aggregation_methods': ['mean', 'max', 'min', 'median', 'std', 'sum', 'count'],
                    'prediction_window_size': '3d',
                    'multiclass': True,
                    # 'PUMPS/UNKNOWN' never occurs in df_data_asset.
                    'failure_modes': ['PUMPS/STOPPED', 'PUMPS/BROKEN', 'PUMPS/UNKNOWN'],
                },
                asset_device_mappings={
                    'abcd-1-____-BEDFORD': ['%s:abcd-1' % iot_type],
                },
                data_substitution={
                    # '' key: asset-level (non-device) data source.
                    '': [
                        {
                            'df': df_data_asset,
                            'keys': ['assetid'],
                            'columns': ['problemcode','classcode'],
                            'timestamp': 'datetime',
                        },
                    ],
                    iot_type: [
                        {
                            'df': df_data_sensor,
                            'keys': ['deviceid'],
                            'columns': [
                                'current',
                                'voltage',
                            ],
                            'timestamp': 'rcv_timestamp_utc',
                            # NOTE(review): identity mapping; presumably kept to
                            # exercise the rename_columns path.
                            'rename_columns': {
                                'current': 'current',
                                'voltage': 'voltage',
                            },
                        },
                    ],
                },
            )
    df = group.execute()

    print(log_df_info(df, head=10))

    assert group.new_training
    # Strict equality on purpose: a None/unset `training` must not pass.
    assert False == group.training
    # Separate assertions so a failure names the missing estimator.
    assert 'MulticlassFailurePredictionEstimator' in group.model_timestamp
    assert 'MulticlassFailurePredictionRcaEstimator' in group.model_timestamp
    # 4 columns: one probability column per failure mode plus rca_path.
    assert (1655, 4) == df.shape
    assert 'failure_probability_pumps_stopped_3d' in df.columns
    assert 'failure_probability_pumps_broken_3d' in df.columns
    assert 'failure_probability_pumps_unknown_3d' in df.columns
    assert 'rca_path_3d' in df.columns
    assert group.pipeline_config['failure_modes'] == {'PUMPS/STOPPED': [1, True], 'PUMPS/BROKEN': [2, True], 'PUMPS/UNKNOWN': [3, False]}


# def test_failure_probability(asset_group_id, iot_type):
#     df_data_asset = pd.read_csv('%s/trainbrake_asset_faildates.csv' % current_directory(file=__file__), parse_dates=['faildate'])
#     df_data_asset['asset'] = np.where(df_data_asset['asset'] == 'TRAINBRAKE1', 'abcd-1', 'abcd-2')
#     df_asset_group = df_data_asset.groupby(['site', 'asset']).size().reset_index()[['site', 'asset']]

#     df_data_sensor = pd.read_csv('%s/trainbrake_device_data.csv' % current_directory(file=__file__), parse_dates=['RCV_TIMESTAMP_UTC'])[:10000]
#     df_data_sensor['DEVICEID'] = np.where(df_data_sensor['DEVICEID'] == 'TrainBrake_1', 'abcd-1', 'abcd-2')

#     df_mappings = pd.DataFrame(
#         columns=['site', 'asset', 'devicetype', 'deviceid'],
#         data=[
#             ['BEDFORD', 'abcd-1', iot_type, 'abcd-1'],
#         ],
#     )

#     try:
#         db = api._get_db()
#         db_schema = None

#         api.set_asset_group_members(asset_group_id=asset_group_id, df=df_asset_group, db=db, db_schema=db_schema)

#         api.set_asset_device_mappings(df=df_mappings, db=db, db_schema=db_schema)

#         api.set_asset_cache(df=df_data_asset, siteid_column='site', assetid_column='asset', faildate_column='faildate', db=db, db_schema=db_schema)

#         api.setup_iot_type(iot_type, df_data_sensor, columns=['TRAINBRAKESIMULATION_AXLEVIBRATION', 'TRAINBRAKESIMULATION_AXLEMOMENTUM'], deviceid_column='DEVICEID', timestamp_column='RCV_TIMESTAMP_UTC', timestamp_in_payload=False, parse_dates=None, rename_columns={'TRAINBRAKESIMULATION_AXLEVIBRATION': 'Axlevibration', 'TRAINBRAKESIMULATION_AXLEMOMENTUM': 'Axlemomentum'}, write='deletefirst', use_wiotp=False, import_only=False, db=db, db_schema=db_schema)

#         group = FailurePredictionAssetGroupPipeline(
#                     asset_group_id=asset_group_id,
#                     model_pipeline={
#                         'features': ['%s:Axlevibration' % iot_type, '%s:Axlemomentum' % iot_type],
#                         'features_for_training': [':problemcode'],
#                         'predictions': ['failure_probability', 'rca_path'],
#                         'aggregation_methods': ['mean', 'max'],
#                         'prediction_window_size': '5d',
#                         'multiclass': True,
#                         'failure_modes': ['STOPPED','BROKEN'],
#                     },
#                 )
#         df = group.execute()

#         print(log_df_info(df, head=0))

#         assert group.new_training
#         assert False == group.training
#         assert 'MulticlassFailurePredictionEstimator' in group.model_timestamp and 'MulticlassFailurePredictionRcaEstimator' in group.model_timestamp
#         assert (545, 2) == df.shape
#         assert 'failure_probability_stopped_5d' in df.columns
#         assert 'failure_probability_broken_5d' in df.columns
#         assert 'rca_path_5d' in df.columns

#         target_tables = ['dm_%s' % asset_group_id, 'dm_%s_daily' % asset_group_id, 'dm_%s_Daily' % asset_group_id]
#         for table_name in target_tables:
#             try:
#                 table = db.get_table(table_name, db_schema)
#             except:
#                 pass
#             else:
#                 db.connection.execute(table.delete())

#         # test writing directly
#         group._write(df)

#         table = db.get_table('dm_%s' % asset_group_id, db_schema)
#         assert 1090 == db.connection.execute(select([func.count()]).select_from(table)).first()[0]
#         table = db.get_table('dm_%s_daily' % asset_group_id, db_schema)
#         assert 3 == db.connection.execute(select([func.count()]).select_from(table)).first()[0]
#         table = db.get_table('dm_%s_Daily' % asset_group_id, db_schema)
#         assert 3 == db.connection.execute(select([func.count()]).select_from(table)).first()[0]

#         df_scored = group.execute(start_ts='2019-07-11 03:25:00', end_ts='2019-07-17 00:00:00')

#         print(log_df_info(df_scored, head=0))

#         assert (132, 2) == df_scored.shape
#         assert 'failure_probability_stopped_5d' in df.columns
#         assert 'failure_probability_broken_5d' in df.columns
#         assert 'rca_path_5d' in df_scored.columns

#         # test empty input data for scoring, which should return empty df with prediction columns added
#         df_scored = group.execute(start_ts='1976-01-01 03:25:00', end_ts='1976-01-02 00:00:00')

#         print(log_df_info(df_scored, head=0))

#         assert (0, 2) == df_scored.shape
#         assert 'failure_probability_stopped_5d' in df.columns
#         assert 'failure_probability_broken_5d' in df.columns
#         assert 'rca_path_5d' in df_scored.columns

#         # test hourly summary
#         group = FailurePredictionAssetGroupPipeline(
#                     asset_group_id=asset_group_id,
#                     model_pipeline={
#                         'features': ['%s:Axlevibration' % iot_type, '%s:Axlemomentum' % iot_type],
#                         'features_for_training': [':problemcode'],
#                         'predictions': ['failure_probability', 'rca_path'],
#                         'aggregation_methods': ['mean', 'max', 'min', 'std'],
#                         'prediction_window_size': '8h',
#                         'multiclass': True,
#                         'failure_modes': ['STOPPED','BROKEN'],
#                     },
#                 )
#         df = group.execute()

#         print(log_df_info(df, head=0))

#         assert group.new_training
#         assert False == group.training
#         assert 'FailurePredictionEstimator' in group.model_timestamp and 'FailurePredictionRcaEstimator' in group.model_timestamp
#         assert (4010, 2) == df.shape
#         assert 'failure_probability_stopped_8h' in df.columns
#         assert 'failure_probability_broken_8h' in df.columns
#         assert 'rca_path_8h' in df.columns
#     except:
#         raise
#     finally:
#         api.delete_iot_type(iot_type, use_wiotp=False, db=db, db_schema=db_schema)

#         api.delete_asset_cache(df=df_data_asset, db=db, db_schema=db_schema)

#         api.delete_asset_device_mappings(df=df_mappings, db=db, db_schema=db_schema)

#         api.delete_asset_group_members(asset_group_id=asset_group_id, db=db, db_schema=db_schema)

