# Licensed Materials - Property of IBM
# 5737-M66, 5900-AAA
# (C) Copyright IBM Corp. 2019, 2025 All Rights Reserved.
# US Government Users Restricted Rights - Use, duplication, or disclosure
# restricted by GSA ADP Schedule Contract with IBM Corp.

import numpy as np
import pandas as pd
import pytest
from sqlalchemy import func, select

from .. import api
from ..degradation_curve import DegradationCurveAssetGroupPipeline
from ..util import log_df_info


@pytest.fixture(scope='module')
def asset_group_id():
    """Return the asset group identifier used by the tests in this module."""
    group_id = 'abcd'
    return group_id


@pytest.fixture(scope="module")
def degradation_data():
    """Return the asset records with every year shifted relative to 2019.

    Each raw record is ``(asset, year)`` or ``(asset, year, year)`` with the
    years given as strings. Every year is moved forward by the difference
    between the current year and 2019 and returned as a string again.
    NOTE(review): presumably this keeps the asset ages (and hence the fitted
    curve) stable no matter when the test runs -- confirm.

    Returns:
        list of tuples with the same arity as the raw records.
    """
    original_data = [
        ('asset001','1979',), ('asset002','1979',), ('asset003','1979',), ('asset004','1981',), ('asset005','1981',),
        ('asset006','1981',), ('asset007','1985',), ('asset008','1985',), ('asset009','1985',), ('asset010','1979',),
        ('asset011','1979',), ('asset012','1979',), ('asset013','1969',), ('asset014','1969',), ('asset015','1969',),
        ('asset016','1969',), ('asset017','1969',), ('asset018','1969',), ('asset019','1970','1996'), ('asset020','1996',),
        ('asset021','1970',), ('asset022','1970',), ('asset023','1970','1989'), ('asset024','1989',), ('asset025','1970',),
        ('asset026','1970',), ('asset027','1970',), ('asset028','1970',), ('asset029','1970',), ('asset030','1970',),
        ('asset031','1970',), ('asset032','1970',), ('asset033','1976',), ('asset034','1976',), ('asset035','1976',),
        ('asset036','1976',), ('asset037','1976',), ('asset038','1976',), ('asset039','1976',), ('asset040','1976',),
        ('asset041','1976',), ('asset042','1976',), ('asset043','1976',), ('asset044','1976',), ('asset045','1976',),
        ('asset046','1976',), ('asset047','1976',), ('asset048','1976',), ('asset049','1976',), ('asset050','1976',),
        ('asset051','1976',), ('asset052','1976',), ('asset053','1976',), ('asset054','1970',), ('asset055','1970',),
        ('asset056','1970',), ('asset057','1981',), ('asset058','1981',), ('asset059','1981',), ('asset060','1983',),
        ('asset061','1983',), ('asset062','1983',), ('asset063','1984',), ('asset064','1984',), ('asset065','1984',),
        ('asset066','1983',), ('asset067','1983',), ('asset068','1983',), ('asset069','1984',), ('asset070','1984',),
        ('asset071','1984',), ('asset072','1983',), ('asset073','1983',), ('asset074','1983',), ('asset075','1984',),
        ('asset076','1984',), ('asset077','1984',), ('asset078','1978',), ('asset079','1978',), ('asset080','1978',),
        ('asset081','1969','1996'), ('asset082','1996',), ('asset083','1969',), ('asset084','1969',), ('asset085','1969',),
        ('asset086','1969',), ('asset087','1969',), ('asset088','1969',), ('asset089','1969',), ('asset090','1969',),
        ('asset091','1969',), ('asset092','1969',), ('asset093','1969',), ('asset094','1969',), ('asset095','1969','1997'),
        ('asset096','1997',), ('asset097','1969',), ('asset098','1969',), ('asset099','1969',), ('asset100','1969',)
    ]

    # number of years elapsed since the raw data's 2019 baseline
    year_offset = pd.Timestamp('now').year - 2019

    def shift(year):
        # years travel as strings: parse, move forward, format back
        return str(int(year) + year_offset)

    return [
        (record[0], shift(record[1]))
        if len(record) == 2
        # a third element of None is passed through untouched
        else (record[0], shift(record[1]), record[2] if record[2] is None else shift(record[2]))
        for record in original_data
    ]


def test_degradation_curve_model_config(asset_group_id):
    """Check model_pipeline validation and the resulting training features.

    The constructor is expected to raise ``ValueError`` when
    ``statistics_distribution_args`` is missing, empty, or carries an
    unsupported ``distribution_type``. For every accepted configuration the
    resulting ``pipeline_config['features_for_training']`` is expected to be
    ``['installdate', 'statusdate', 'status']``.
    """
    def make_group(model_pipeline):
        # every case differs only in the model_pipeline it passes
        return DegradationCurveAssetGroupPipeline(
            asset_group_id=asset_group_id,
            model_pipeline=model_pipeline,
        )

    # only ever compared against, so one shared list is safe; the input
    # lists/dicts below are deliberately rebuilt for each call
    expected_features = ['installdate', 'statusdate', 'status']

    # no statistics_distribution_args given
    with pytest.raises(ValueError):
        make_group({
            'features_for_training': [':installdate', ':statusdate'],
        })

    # empty statistics_distribution_args given
    with pytest.raises(ValueError):
        make_group({
            'features_for_training': [':installdate', ':statusdate'],
            'statistics_distribution_args': {},
        })

    # invalid statistics_distribution_args.distribution_type given (None)
    with pytest.raises(ValueError):
        make_group({
            'features_for_training': [':installdate', ':statusdate'],
            'statistics_distribution_args': {
                'distribution_type': None,
            },
        })

    # invalid statistics_distribution_args.distribution_type given (unknown name)
    with pytest.raises(ValueError):
        make_group({
            'features_for_training': [':installdate', ':statusdate'],
            'statistics_distribution_args': {
                'distribution_type': 'abc',
            },
        })

    # features_for_training omitted entirely
    group = make_group({
        'statistics_distribution_args': {
            'distribution_type': 'normal',
        },
    })
    assert group.pipeline_config['features_for_training'] == expected_features

    # NOTE(review): in this case and the next, 'features_for_training' is a key
    # of 'statistics_distribution_args', not of model_pipeline. The original
    # indentation suggests it may have been meant one level up -- confirm
    # before moving it, since that changes what is being tested.
    group = make_group({
        'statistics_distribution_args': {
            'features_for_training': [':installdate', ':estendoflife'],
            'distribution_type': 'normal',
        },
    })
    assert group.pipeline_config['features_for_training'] == expected_features

    group = make_group({
        'statistics_distribution_args': {
            'features_for_training': [':installdate', ':estendoflife', ':status'],
            'distribution_type': 'normal',
        },
    })
    assert group.pipeline_config['features_for_training'] == expected_features

    # explicit features together with a normal distribution
    group = make_group({
        'features_for_training': [':installdate', ':statusdate'],
        'statistics_distribution_args': {
            'distribution_type': 'normal',
        },
    })
    assert group.pipeline_config['features_for_training'] == expected_features

    # explicit features together with a weibull distribution
    group = make_group({
        'features_for_training': [':installdate', ':statusdate'],
        'statistics_distribution_args': {
            'distribution_type': 'weibull',
        },
    })
    assert group.pipeline_config['features_for_training'] == expected_features


def test_degradation_curve_data_substitution(degradation_data, asset_group_id):
    """Train the degradation curve from data handed in via ``data_substitution``.

    The asset records are passed to the pipeline as an in-memory DataFrame
    (keyed by ``assetid``) together with an explicit, device-less
    ``asset_device_mappings``. After ``execute()`` the test expects an empty
    result frame, exactly one trained model, and a specific
    ``final_degradation_curve``.
    """
    # expected curve values keyed by the integers 0..100
    expected_curve = {
        0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0,
        5: 5.551115123125783e-14, 6: 7.438494264988549e-13, 7: 6.7390537594747e-12,
        8: 4.5563552930616424e-11, 9: 2.4582558211250216e-10, 10: 1.110200820164664e-09,
        11: 4.342326498374405e-09, 12: 1.508237978953275e-08, 13: 4.7414250303745575e-08,
        14: 1.36919542370606e-07, 15: 3.6747968001904496e-07, 16: 9.253774524431435e-07,
        17: 2.203336912920406e-06, 18: 4.9923102096727234e-06, 19: 1.0822098173157713e-05,
        20: 2.2546665601197446e-05, 21: 4.532074354690607e-05, 22: 8.818677535993302e-05,
        23: 0.0001665928322958088, 24: 0.00030630208448512164, 25: 0.0005493443154258593,
        26: 0.0009629140249822754, 27: 0.0016524551682439714, 28: 0.002780610319996768,
        29: 0.0045942767419537844, 30: 0.007462732367424785, 31: 0.011930704267060221,
        32: 0.01879138769413924, 33: 0.029185824733568566, 34: 0.044736756655960885,
        35: 0.06772710537473747, 36: 0.10133563121027533, 37: 0.14994503352093735,
        38: 0.21954071442480405, 39: 0.31822139581526576, 40: 0.45684533603300714,
        41: 0.649837256394703, 42: 0.9161799269595594, 43: 1.2806084919250171,
        44: 1.7750115968482194, 45: 2.4400160060278786, 46: 3.32668316068061,
        47: 4.498166870807674, 48: 6.031058598438188, 49: 8.015967862997941,
        50: 10.556642731302013, 51: 13.766638704537126, 52: 17.762241512871448,
        53: 22.650159887654098, 54: 28.50866043727187, 55: 35.361691071801125,
        56: 43.147606463287936, 57: 51.68773971792341, 58: 60.665034219411965,
        59: 69.62764626141784, 60: 78.0331495036052, 61: 85.34045348007432,
        62: 91.1354302766718, 63: 95.24778246280941, 64: 97.79982444341154,
        65: 99.14743899451898, 66: 99.73364830544686, 67: 99.93583227545949,
        68: 99.98869248236701, 69: 99.99863085478329, 70: 99.9998941979418,
        71: 99.99999521674788, 72: 99.999999885781, 73: 99.99999999872216,
        74: 99.99999999999419, 75: 99.99999999999999,
        76: 100.0, 77: 100.0, 78: 100.0, 79: 100.0, 80: 100.0,
        81: 100.0, 82: 100.0, 83: 100.0, 84: 100.0, 85: 100.0,
        86: 100.0, 87: 100.0, 88: 100.0, 89: 100.0, 90: 100.0,
        91: 100.0, 92: 100.0, 93: 100.0, 94: 100.0, 95: 100.0,
        96: 100.0, 97: 100.0, 98: 100.0, 99: 100.0, 100: 100.0,
    }

    asset_frame = pd.DataFrame(data=degradation_data, columns=['assetid', 'installdate', 'statusdate'])
    # special year only string, astype/dtype does not work, have to use pd.to_datetime
    for date_column in ('installdate', 'statusdate'):
        asset_frame[date_column] = pd.to_datetime(asset_frame[date_column])
    # add status column, like we call from Maximo
    asset_frame['status'] = 'DECOMMISSIONED'

    model_pipeline = {
        'features_for_training': [':installdate', ':statusdate'],
        'statistics_distribution_args': {
            'distribution_type': 'WEIBULL',
            'mean_or_scale': None,
            'stddev_or_shape': None,
        },
    }
    # asset001 .. asset100, none of them mapped to a device
    device_mappings = {'asset%03d' % i: [] for i in range(1, 101)}
    substitution = {
        '': [
            {
                'df': asset_frame,
                'keys': ['assetid'],
                'columns': ['installdate', 'statusdate', 'status'],
            },
        ],
    }

    group = DegradationCurveAssetGroupPipeline(
        asset_group_id=asset_group_id,
        model_pipeline=model_pipeline,
        asset_device_mappings=device_mappings,
        data_substitution=substitution,
    )
    df = group.execute()

    print(log_df_info(df, head=0))

    assert group.new_training
    assert False == group.training
    assert 'DegradationCurveEstimator' in group.model_timestamp
    assert df.empty
    assert len(df.columns) == 0

    assert len(group.trained_models) == 1
    # each trained-model entry unpacks into exactly three parts
    _estimator, model, _model_paths = group.trained_models[0]
    assert model['final_degradation_curve'] == expected_curve


def test_degradation_curve(degradation_data, asset_group_id):
    """Train the degradation curve from records stored through the ``api`` module.

    Group members, an empty set of device mappings and the asset dimension
    data are written through ``api.set_*`` before the pipeline is built and
    executed. The test then expects an empty result frame, exactly one
    trained model, and a specific ``final_degradation_curve``. The matching
    ``api.delete_*`` calls run in a ``finally`` clause, so they execute
    whether or not the assertions pass.
    """
    # expected curve values keyed by the integers 0..100
    expected_curve = {
        0: 0.0, 1: 0.0, 2: 0.0, 3: 0.0, 4: 0.0,
        5: 5.551115123125783e-14, 6: 7.438494264988549e-13, 7: 6.7390537594747e-12,
        8: 4.5563552930616424e-11, 9: 2.4582558211250216e-10, 10: 1.110200820164664e-09,
        11: 4.342326498374405e-09, 12: 1.508237978953275e-08, 13: 4.7414250303745575e-08,
        14: 1.36919542370606e-07, 15: 3.6747968001904496e-07, 16: 9.253774524431435e-07,
        17: 2.203336912920406e-06, 18: 4.9923102096727234e-06, 19: 1.0822098173157713e-05,
        20: 2.2546665601197446e-05, 21: 4.532074354690607e-05, 22: 8.818677535993302e-05,
        23: 0.0001665928322958088, 24: 0.00030630208448512164, 25: 0.0005493443154258593,
        26: 0.0009629140249822754, 27: 0.0016524551682439714, 28: 0.002780610319996768,
        29: 0.0045942767419537844, 30: 0.007462732367424785, 31: 0.011930704267060221,
        32: 0.01879138769413924, 33: 0.029185824733568566, 34: 0.044736756655960885,
        35: 0.06772710537473747, 36: 0.10133563121027533, 37: 0.14994503352093735,
        38: 0.21954071442480405, 39: 0.31822139581526576, 40: 0.45684533603300714,
        41: 0.649837256394703, 42: 0.9161799269595594, 43: 1.2806084919250171,
        44: 1.7750115968482194, 45: 2.4400160060278786, 46: 3.32668316068061,
        47: 4.498166870807674, 48: 6.031058598438188, 49: 8.015967862997941,
        50: 10.556642731302013, 51: 13.766638704537126, 52: 17.762241512871448,
        53: 22.650159887654098, 54: 28.50866043727187, 55: 35.361691071801125,
        56: 43.147606463287936, 57: 51.68773971792341, 58: 60.665034219411965,
        59: 69.62764626141784, 60: 78.0331495036052, 61: 85.34045348007432,
        62: 91.1354302766718, 63: 95.24778246280941, 64: 97.79982444341154,
        65: 99.14743899451898, 66: 99.73364830544686, 67: 99.93583227545949,
        68: 99.98869248236701, 69: 99.99863085478329, 70: 99.9998941979418,
        71: 99.99999521674788, 72: 99.999999885781, 73: 99.99999999872216,
        74: 99.99999999999419, 75: 99.99999999999999,
        76: 100.0, 77: 100.0, 78: 100.0, 79: 100.0, 80: 100.0,
        81: 100.0, 82: 100.0, 83: 100.0, 84: 100.0, 85: 100.0,
        86: 100.0, 87: 100.0, 88: 100.0, 89: 100.0, 90: 100.0,
        91: 100.0, 92: 100.0, 93: 100.0, 94: 100.0, 95: 100.0,
        96: 100.0, 97: 100.0, 98: 100.0, 99: 100.0, 100: 100.0,
    }

    db = api._get_db()
    db_schema = None

    # asset001 .. asset100, all at site BEDFORD
    df_asset_group = pd.DataFrame(
        columns=['site', 'asset'],
        data=np.array([
            ['BEDFORD', 'asset%03d' % i] for i in range(1, 101)
        ]),
    )
    print('df_asset_group=%s' % log_df_info(df_asset_group, head=5))

    # no devices are mapped to any asset
    df_mappings = pd.DataFrame(
        columns=['site', 'asset', 'devicetype', 'deviceid'],
        data=[],
    )
    print('df_mappings=%s' % log_df_info(df_mappings, head=5))

    df_data_asset = pd.DataFrame(data=degradation_data, columns=['asset', 'installdate', 'statusdate'])
    df_data_asset['site'] = 'BEDFORD'
    # special year only string, astype/dtype does not work, have to use pd.to_datetime
    df_data_asset['installdate'] = pd.to_datetime(df_data_asset['installdate'])
    df_data_asset['statusdate'] = pd.to_datetime(df_data_asset['statusdate'])
    # add status column, like we call from Maximo
    df_data_asset['status'] = 'DECOMMISSIONED'

    # try/finally alone propagates any failure after running the cleanup;
    # no except clause is needed for that
    try:
        api.set_asset_group_members(asset_group_id=asset_group_id, df=df_asset_group, db=db, db_schema=db_schema)

        api.set_asset_device_mappings(df=df_mappings, db=db, db_schema=db_schema)

        api.set_asset_cache_dimension(df=df_data_asset, db=db, db_schema=db_schema)

        group = DegradationCurveAssetGroupPipeline(
            asset_group_id=asset_group_id,
            model_pipeline={
                'features_for_training': [':installdate', ':statusdate'],
                'statistics_distribution_args': {
                    'distribution_type': 'WEIBULL',
                    'mean_or_scale': None,
                    'stddev_or_shape': None,
                },
            },
        )
        df = group.execute()

        print(log_df_info(df, head=0))

        assert group.new_training
        assert False == group.training
        assert 'DegradationCurveEstimator' in group.model_timestamp
        assert df.empty
        assert 0 == len(df.columns)

        assert 1 == len(group.trained_models)
        # each trained-model entry unpacks into exactly three parts
        _estimator, model, _model_paths = group.trained_models[0]
        assert model['final_degradation_curve'] == expected_curve
    finally:
        # undo the three api.set_* calls in reverse order
        api.delete_asset_cache_dimension(df=df_data_asset, db=db, db_schema=db_schema)

        api.delete_asset_device_mappings(df=df_mappings, db=db, db_schema=db_schema)

        api.delete_asset_group_members(asset_group_id=asset_group_id, db=db, db_schema=db_schema)

