# Licensed Materials - Property of IBM
# 5737-M66, 5900-AAA, 5900-AMG
# (C) Copyright IBM Corp. 2019, 2025 All Rights Reserved.
# US Government Users Restricted Rights - Use, duplication, or disclosure
# restricted by GSA ADP Schedule Contract with IBM Corp.

import numpy as np
import pandas as pd
import pytest
from sqlalchemy import func, select

from iotfunctions.pipeline import CalcPipeline

from .. import api
from ..failure_prediction import FailurePredictionAssetGroupPipeline
from ..util import current_directory, log_df_info


def _failure_prediction_stage(asset_group_id, iot_type, prediction_window_size, aggregation_methods):
    """Build a failure-prediction stage over the Axlevibration/Axlemomentum features of ``iot_type``.

    The stages built by ``test_multi_failure_probability`` share features, training
    features and predictions; they differ only in the prediction window and in the
    aggregation methods applied to the features.

    Args:
        asset_group_id: id of the asset group the stage runs over.
        iot_type: IoT device type whose ``Axlevibration`` / ``Axlemomentum`` columns are used.
        prediction_window_size: window string passed through unchanged (e.g. ``'5d'``).
        aggregation_methods: list of aggregation names passed through unchanged.

    Returns:
        A new ``FailurePredictionAssetGroupPipeline``.
    """
    return FailurePredictionAssetGroupPipeline(
        asset_group_id=asset_group_id,
        model_pipeline={
            'features': ['%s:Axlevibration' % iot_type, '%s:Axlemomentum' % iot_type],
            'features_for_training': [':faildate'],
            'predictions': ['failure_probability', 'rca_path'],
            'aggregation_methods': aggregation_methods,
            'prediction_window_size': prediction_window_size,
        },
    )


def test_multi_failure_probability(asset_group_id='abcd', iot_type=None):
    """Run a failure-prediction pipeline end to end against the train-brake sample data.

    Registers the sample asset group, an asset/device mapping, the asset failure dates
    and a sensor IoT type in the database. Then executes a ``CalcPipeline`` holding the
    5-day failure-prediction stage and checks that stage's ``new_training`` /
    ``training`` flags and ``model_timestamp``. Every registration that was attempted is
    undone afterwards, also when setup or an assertion fails.

    Args:
        asset_group_id: id under which the sample assets are registered as a group.
        iot_type: name of the IoT device type to create. When None, a random
            ``abcdsensor_NNNNN`` name is generated.
    """
    import random
    from contextlib import ExitStack

    if iot_type is None:
        # Random suffix, presumably so that runs do not collide on one type name.
        iot_type = 'abcdsensor_%05d' % random.randrange(10**5)

    data_dir = current_directory(file=__file__)

    # Failure dates per asset. TRAINBRAKE1 becomes abcd-1; every other asset becomes abcd-2.
    df_data_asset = pd.read_csv('%s/trainbrake_asset_faildates.csv' % data_dir, parse_dates=['faildate'])
    df_data_asset['asset'] = np.where(df_data_asset['asset'] == 'TRAINBRAKE1', 'abcd-1', 'abcd-2')
    # One row per distinct (site, asset) pair: the members of the asset group.
    df_asset_group = df_data_asset.groupby(['site', 'asset']).size().reset_index()[['site', 'asset']]

    # Sensor readings (first 10000 rows only). Device ids are renamed to match the assets above.
    df_data_sensor = pd.read_csv('%s/trainbrake_device_data.csv' % data_dir, parse_dates=['RCV_TIMESTAMP_UTC'])[:10000]
    df_data_sensor['DEVICEID'] = np.where(df_data_sensor['DEVICEID'] == 'TrainBrake_1', 'abcd-1', 'abcd-2')

    # Only asset abcd-1 is mapped to a device of the new IoT type.
    df_mappings = pd.DataFrame(
        columns=['site', 'asset', 'devicetype', 'deviceid'],
        data=[
            ['BEDFORD', 'abcd-1', iot_type, 'abcd-1'],
        ],
    )

    # Acquire the connection before any cleanup is registered. If this raises, there is
    # nothing to undo, and no cleanup code runs against an unbound `db`.
    db = api._get_db()
    db_schema = None

    # Each delete is registered right before its matching setup call. On exit, ExitStack
    # runs the callbacks in reverse registration order: iot type, asset cache, mappings,
    # group members. It keeps running the remaining callbacks even if one of them raises.
    with ExitStack() as cleanup:
        cleanup.callback(api.delete_asset_group_members, asset_group_id=asset_group_id, db=db, db_schema=db_schema)
        api.set_asset_group_members(asset_group_id=asset_group_id, df=df_asset_group, db=db, db_schema=db_schema)

        cleanup.callback(api.delete_asset_device_mappings, df=df_mappings, db=db, db_schema=db_schema)
        api.set_asset_device_mappings(df=df_mappings, db=db, db_schema=db_schema)

        cleanup.callback(api.delete_asset_cache, df=df_data_asset, db=db, db_schema=db_schema)
        api.set_asset_cache(df=df_data_asset, siteid_column='site', assetid_column='asset', faildate_column='faildate', db=db, db_schema=db_schema)

        cleanup.callback(api.delete_iot_type, iot_type, use_wiotp=False, db=db, db_schema=db_schema)
        api.setup_iot_type(iot_type, df_data_sensor, columns=['TRAINBRAKESIMULATION_AXLEVIBRATION', 'TRAINBRAKESIMULATION_AXLEMOMENTUM'], deviceid_column='DEVICEID', timestamp_column='RCV_TIMESTAMP_UTC', timestamp_in_payload=False, parse_dates=None, rename_columns={'TRAINBRAKESIMULATION_AXLEVIBRATION': 'Axlevibration', 'TRAINBRAKESIMULATION_AXLEMOMENTUM': 'Axlemomentum'}, write='deletefirst', use_wiotp=False, import_only=False, db=db, db_schema=db_schema)

        # 5-day prediction
        fp1 = _failure_prediction_stage(asset_group_id, iot_type, '5d', ['mean', 'max', 'min', 'median', 'std', 'sum', 'count'])
        # 3-day prediction (constructed, but not yet added to the pipeline -- see TODO below)
        fp2 = _failure_prediction_stage(asset_group_id, iot_type, '3d', ['mean', 'max', 'min', 'std'])

        pipeline = CalcPipeline(stages=None, entity_type=fp1._entity_type)
        pipeline.add_stage(fp1)
        # TODO: add the 3-day stage as well and re-enable the disabled assertions below:
        # pipeline.add_stage(fp2)

        df = pipeline.execute()

        print(log_df_info(df, head=0))

        assert fp1.new_training
        assert False == fp1.training
        assert 'FailurePredictionEstimator' in fp1.model_timestamp and 'FailurePredictionRcaEstimator' in fp1.model_timestamp

        # Disabled until fp2 is part of the pipeline:
        # assert fp2.new_training
        # assert False == fp2.training
        # assert 'FailurePredictionEstimator' in fp2.model_timestamp and 'FailurePredictionRcaEstimator' in fp2.model_timestamp
        # assert (3028, 4) == df.shape
        # assert 'failure_probability_5d' in df.columns
        # assert 'rca_path_5d' in df.columns
        # assert 'failure_probability_3d' in df.columns
        # assert 'rca_path_3d' in df.columns


if __name__ == '__main__':
    # Deliberately does nothing when executed directly; the test above is presumably
    # collected and run through pytest instead.
    ...

