# Licensed Materials - Property of IBM
# 5737-M66, 5900-AAA
# (C) Copyright IBM Corp. 2019, 2025 All Rights Reserved.
# US Government Users Restricted Rights - Use, duplication, or disclosure
# restricted by GSA ADP Schedule Contract with IBM Corp.

import random
from contextlib import ExitStack

import numpy as np
import pandas as pd
from sqlalchemy import func, select

from .. import api
from ..loader import AssetLoader
from ..pipeline import AssetGroupPipeline
from ..util import current_directory, log_df_info, setup_logging


# Asset group id used as the default ``asset_group_id`` of the tests in this module.
ASSET_GROUP_ID = 'abcd'


# Import-time setup, run once before any test in this module: configure logging and
# initialize the API environment (presumably the DB/service connection settings that
# api._get_db() relies on -- NOTE(review): confirm in api.init_environ).
setup_logging()
api.init_environ()


def test_entity_type_with_sparse_data(asset_group_id=ASSET_GROUP_ID, iot_type=None):
    """Create an entity type from device-1 data, then append sparse device-2 episode data.

    The first ``api.setup_iot_type`` call creates the type (``write='deletefirst'``) from
    the device-1 CSV. It includes all ``*_episode`` columns, pre-created empty because the
    type cannot be updated after creation. The table is then expected to hold 100 rows.
    The second call appends (``write='append'``, ``import_only=True``) only the episode
    columns from the device-2 CSV. The table is then expected to hold 200 rows.

    The created entity type is always deleted at the end. Note that this test needs time,
    max 5 min, to clean up the created entity type on AS (deletion is an asynchronous
    operation on AS side).

    Args:
        asset_group_id: Not used by this test; kept for signature compatibility.
        iot_type: Name of the entity type to create. When None, a random name of the
            form ``abcdsensor_NNNNN`` is generated.
    """
    if iot_type is None:
        # random suffix, presumably so a rerun does not collide with a type whose
        # asynchronous deletion on AS has not finished yet
        iot_type = 'abcdsensor_%05d' % random.randrange(10**5)

    db = api._get_db()
    db_schema = None

    pwd = current_directory(__file__)
    print(pwd)

    df_device_1 = pd.read_csv('%s/test_data_preparation.device-1.csv' % pwd)
    # we can't update iot type yet, so still have to prepare all columns up front; the
    # second load below writes only these episode columns
    for episode_column in [
        'ra_humidity_episode',
        'sa_humidity_episode',
        'ra_temp_episode',
        'sa_temp_episode',
        'return_co2_episode',
        'lthw_valve_feedback_episode',
        'chw_valve_feedback_episode',
        'off_coil_temp_episode',
    ]:
        df_device_1[episode_column] = None
    print('df_device_1=%s' % log_df_info(df_device_1, head=5))

    try:
        t1 = api.setup_iot_type(
            name=iot_type,
            df=df_device_1,
            columns=[
                'sa_humidity',
                'sa_temp',
                'ra_humidity',
                'ra_temp',
                'return_co2',
                'chw_valve_feedback',
                'lthw_valve_feedback',
                'off_coil_temp',
                'sa_humidity_episode',
                'sa_temp_episode',
                'ra_humidity_episode',
                'ra_temp_episode',
                'return_co2_episode',
                'chw_valve_feedback_episode',
                'lthw_valve_feedback_episode',
                'off_coil_temp_episode',
            ],
            deviceid_column='deviceid',
            timestamp_column='rcv_timestamp_utc',
            timestamp_in_payload=False,
            rename_columns={
            },
            write='deletefirst',
            import_only=False,
            use_wiotp=False,
            db=db,
            db_schema=db_schema)

        assert 100 == db.connection.execute(
            select([func.count()]).select_from(db.get_table(t1.name, db_schema))).first()[0]

        df_device_2 = pd.read_csv('%s/test_data_preparation.device-2.csv' % pwd)
        # literal (non-regex) replacement: strip the site suffix from the asset id, then
        # append the device index
        df_device_2['deviceid'] = df_device_2['id'].str.replace('-____-BEDFORD', '', regex=False) + '_1'
        print('df_device_2=%s' % log_df_info(df_device_2, head=5))

        t2 = api.setup_iot_type(
            name=iot_type,
            df=df_device_2,
            columns=[
                'sa_humidity_episode',
                'sa_temp_episode',
                'ra_humidity_episode',
                'ra_temp_episode',
                'return_co2_episode',
                'chw_valve_feedback_episode',
                'lthw_valve_feedback_episode',
                'off_coil_temp_episode',
            ],
            deviceid_column='deviceid',
            timestamp_column='event_timestamp',
            timestamp_in_payload=False,
            rename_columns={
                'event_timestamp': 'rcv_timestamp_utc',
            },
            write='append',
            import_only=True,
            use_wiotp=False,
            db=db,
            db_schema=db_schema)

        assert 200 == db.connection.execute(
            select([func.count()]).select_from(db.get_table(t2.name, db_schema))).first()[0]
    finally:
        api.delete_iot_type(name=iot_type, use_wiotp=False, db=db, db_schema=db_schema)


def test_writing_df_with_extra_columns(asset_group_id=ASSET_GROUP_ID):
    """Write asset group members from a DataFrame that carries additional device columns.

    Besides ``site`` and ``asset``, the DataFrame has ``devicetype`` and ``deviceid``
    columns (the "extra" columns in the test name). After the write,
    ``apm_asset_groups`` must hold exactly 2 rows for the group. The group members are
    deleted afterwards.

    Args:
        asset_group_id: Asset group to write the members to.
    """
    db = api._get_db()
    db_schema = None

    df_asset_group = pd.DataFrame(data={
        'site': [
            'BEDFORD',
            'BEDFORD',
        ],
        'asset': [
            'abcd-1',
            'abcd-2',
        ],
        'devicetype': [
            'TrainBrakeSensor',
            'TrainBrakeSensor',
        ],
        'deviceid': [
            'abcd-1-1',
            'abcd-2-1',
        ]
    })
    print('df_asset_group=%s' % log_df_info(df_asset_group, head=5))

    try:
        api.set_asset_group_members(asset_group_id=asset_group_id, df=df_asset_group, db=db, db_schema=db_schema)

        table_asset_group_members = db.get_table('apm_asset_groups', db_schema)
        # the Table object is used directly; it must not be looked up by get_table again
        member_count = db.connection.execute(
            select([func.count()])
            .select_from(table_asset_group_members)
            .where(table_asset_group_members.c['assetgroup'] == asset_group_id)
        ).first()[0]
        assert 2 == member_count
    finally:
        api.delete_asset_group_members(asset_group_id=asset_group_id, db=db, db_schema=db_schema)


def test_loading_asset_without_device_mapped(asset_group_id=ASSET_GROUP_ID):
    """Check that a group member without any device mapping loads with an empty device list.

    The asset group has two members (``abcd-1`` and ``abcd-2`` at site ``BEDFORD``), but
    only ``abcd-1`` gets a device mapping. After ``AssetLoader`` loads the mappings from
    the tables, ``abcd-2`` must still be present and map to an empty list.

    The mappings and group members written by this test are deleted afterwards. Each
    cleanup step runs even if an earlier one fails, so no rows for the shared asset group
    id leak into other tests.

    Args:
        asset_group_id: Asset group to create for the test.
    """
    db = api._get_db()
    db_schema = None

    df_asset_group = pd.DataFrame(data={
        'site': [
            'BEDFORD',
            'BEDFORD',
        ],
        'asset': [
            'abcd-1',
            'abcd-2',
        ],
    })
    print('df_asset_group=%s' % log_df_info(df_asset_group, head=5))

    df_mappings = pd.DataFrame(data={
        'site': [
            'BEDFORD',
        ],
        'asset': [
            'abcd-1',
        ],
        'devicetype': [
            'abcdsensor',
        ],
        'deviceid': [
            'abcd-1-1',
        ],
    })
    print('df_mappings=%s' % log_df_info(df_mappings, head=5))

    with ExitStack() as cleanup:
        # ExitStack runs callbacks in reverse registration order and still runs the
        # remaining ones if a callback raises: mappings are deleted first, members after.
        cleanup.callback(api.delete_asset_group_members, asset_group_id=asset_group_id, db=db, db_schema=db_schema)
        cleanup.callback(api.delete_asset_device_mappings, df=df_mappings, db=db, db_schema=db_schema)

        api.set_asset_group_members(asset_group_id=asset_group_id, df=df_asset_group, db=db, db_schema=db_schema)

        api.set_asset_device_mappings(df=df_mappings, db=db, db_schema=db_schema, drop_table_first=True)

        group = AssetGroupPipeline(
            asset_group_id=asset_group_id,
            model_pipeline={
                'features_for_training': [':installdate', ':estendoflife'],
            },
            db=db,
            db_schema=db_schema)
        loader = AssetLoader(
            asset_group=group.asset_group_id,
            _entity_type=group._entity_type,
            data_items=list(group.pipeline_config['inputs']),
            names=list(group.pipeline_config['renamed_inputs']),
            resamples=dict(),
            entity_type_metadata=group.db.entity_type_metadata,
            asset_device_mappings=None)
        print(loader)

        # load from tables
        loader._load_asset_device_mappings()

        assert loader.asset_device_mappings == {'ABCD-1-____-BEDFORD': ['abcdsensor:abcd-1-1'], 'ABCD-2-____-BEDFORD': []}


def test_set_device_mapping(asset_group_id=ASSET_GROUP_ID):
    """Check how ``delete_df_asset_first`` affects repeated asset device mapping writes.

    A mapping for ``abcd-2`` is written first (``drop_table_first=True``). The later
    writes only concern ``abcd-1``, so the ``abcd-2`` mapping must be left alone.

    1. With ``delete_df_asset_first=True``, the second write for ``abcd-1`` replaces the
       first one, so only ``abcd-1-2`` remains mapped to it.
    2. With ``delete_df_asset_first=False``, both writes are kept, so ``abcd-1`` maps to
       both ``abcd-1-1`` and ``abcd-1-2``.

    All mappings and group members written by this test are deleted afterwards. Each
    cleanup step runs even if an earlier one fails.

    Args:
        asset_group_id: Asset group to create for the test.
    """
    db = api._get_db()
    db_schema = None

    df_asset_group = pd.DataFrame(data={
        'site': [
            'BEDFORD',
            'BEDFORD',
        ],
        'asset': [
            'abcd-1',
            'abcd-2',
        ],
    })
    print('df_asset_group=%s' % log_df_info(df_asset_group, head=5))

    df_mappings_1 = pd.DataFrame(data={
        'site': [
            'BEDFORD',
        ],
        'asset': [
            'abcd-1',
        ],
        'devicetype': [
            'abcdsensor',
        ],
        'deviceid': [
            'abcd-1-1',
        ],
    })
    print('df_mappings_1=%s' % log_df_info(df_mappings_1, head=5))
    df_mappings_2 = pd.DataFrame(data={
        'site': [
            'BEDFORD',
        ],
        'asset': [
            'abcd-1',
        ],
        'devicetype': [
            'abcdsensor',
        ],
        'deviceid': [
            'abcd-1-2',
        ],
    })
    print('df_mappings_2=%s' % log_df_info(df_mappings_2, head=5))
    df_mappings_others = pd.DataFrame(data={
        'site': [
            'BEDFORD',
        ],
        'asset': [
            'abcd-2',
        ],
        'devicetype': [
            'abcdsensor',
        ],
        'deviceid': [
            'abcd-2-1',
        ],
    })
    print('df_mappings_others=%s' % log_df_info(df_mappings_others, head=5))

    def load_asset_device_mappings():
        # Build a fresh pipeline/loader for the group and return the mappings it loads
        # from the tables.
        group = AssetGroupPipeline(
            asset_group_id=asset_group_id,
            model_pipeline={
                'features_for_training': [':installdate', ':estendoflife'],
            },
            db=db,
            db_schema=db_schema)
        loader = AssetLoader(
            asset_group=group.asset_group_id,
            _entity_type=group._entity_type,
            data_items=list(group.pipeline_config['inputs']),
            names=list(group.pipeline_config['renamed_inputs']),
            resamples=dict(),
            entity_type_metadata=group.db.entity_type_metadata,
            asset_device_mappings=None)
        print(loader)

        # load from tables
        loader._load_asset_device_mappings()
        return loader.asset_device_mappings

    with ExitStack() as cleanup:
        # Callbacks run in reverse registration order (mappings_2, mappings_1, group
        # members, others), and the remaining ones still run if one of them raises.
        cleanup.callback(api.delete_asset_device_mappings, df=df_mappings_others, db=db, db_schema=db_schema)
        cleanup.callback(api.delete_asset_group_members, asset_group_id=asset_group_id, db=db, db_schema=db_schema)
        cleanup.callback(api.delete_asset_device_mappings, df=df_mappings_1, db=db, db_schema=db_schema)
        cleanup.callback(api.delete_asset_device_mappings, df=df_mappings_2, db=db, db_schema=db_schema)

        # add some assets that will not be included in the given df (should be left alone)
        api.set_asset_device_mappings(df=df_mappings_others, db=db, db_schema=db_schema, drop_table_first=True)

        api.set_asset_group_members(asset_group_id=asset_group_id, df=df_asset_group, db=db, db_schema=db_schema)

        api.set_asset_device_mappings(df=df_mappings_1, db=db, db_schema=db_schema, delete_df_asset_first=True)
        api.set_asset_device_mappings(df=df_mappings_2, db=db, db_schema=db_schema, delete_df_asset_first=True)

        # result should be only from the last df
        assert load_asset_device_mappings() == {'ABCD-1-____-BEDFORD': ['abcdsensor:abcd-1-2'], 'ABCD-2-____-BEDFORD': ['abcdsensor:abcd-2-1']}

        # cleanup before the second scenario
        api.delete_asset_device_mappings(df=df_mappings_1, db=db, db_schema=db_schema)
        api.delete_asset_device_mappings(df=df_mappings_2, db=db, db_schema=db_schema)

        # now without deletion of existing data
        api.set_asset_device_mappings(df=df_mappings_1, db=db, db_schema=db_schema, delete_df_asset_first=False)
        api.set_asset_device_mappings(df=df_mappings_2, db=db, db_schema=db_schema, delete_df_asset_first=False)

        # result should be composite of two df
        mappings = load_asset_device_mappings()
        assert set(mappings['ABCD-1-____-BEDFORD']) == {'abcdsensor:abcd-1-1', 'abcdsensor:abcd-1-2'}
        assert set(mappings['ABCD-2-____-BEDFORD']) == {'abcdsensor:abcd-2-1'}

