# Licensed Materials - Property of IBM
# 5737-M66, 5900-AAA, 5900-AMG
# (C) Copyright IBM Corp. 2019, 2025 All Rights Reserved.
# US Government Users Restricted Rights - Use, duplication, or disclosure
# restricted by GSA ADP Schedule Contract with IBM Corp.

import numpy as np
import pandas as pd
import pytest
from sqlalchemy import column, func, select, Column, DateTime, String, Table

from .. import api
from ..util import current_directory, log_df_info


def test_asset_cache_import(mocker):
    """Write the train-brake fixture into the asset cache and validate the result.

    The fixture is written with every combination of the optional
    ``failurecode_column`` / ``problemcode_column`` arguments. After each write
    the test checks the cached row count and the stored code values. It then
    checks that missing or unknown column arguments raise ``ValueError``.
    The cached rows are removed in ``finally``.
    """
    df_data_asset = pd.read_csv('%s/trainbrake_asset_faildates.csv' % current_directory(file=__file__), parse_dates=['faildate'])
    print(log_df_info(df_data_asset, head=10))

    # Acquire the connection before ``try``: if this fails, there is nothing to
    # clean up, and the ``finally`` clause would otherwise mask the real error
    # with a NameError on ``db``.
    db = api._get_db()
    db_schema = None

    def cache_row_count():
        """Return the number of rows currently in the asset cache table."""
        table = db.get_table(api.asset_cache_table_name, db_schema)
        return db.connection.execute(select([func.count()]).select_from(table)).first()[0]

    def first_cached_value(column_name):
        """Return ``column_name`` from the first row the asset cache query yields."""
        table = db.get_table(api.asset_cache_table_name, db_schema)
        return db.connection.execute(select([column(column_name)]).select_from(table)).first()[0]

    def set_cache(**overrides):
        """Call ``api.set_asset_cache`` with the valid baseline arguments, replacing any given in ``overrides``."""
        kwargs = dict(df=df_data_asset, siteid_column='site', assetid_column='asset', faildate_column='faildate', db=db, db_schema=db_schema)
        kwargs.update(overrides)
        api.set_asset_cache(**kwargs)

    try:
        set_cache()

        assert 8 == cache_row_count()

        set_cache(failurecode_column='failurecode')

        assert 8 == cache_row_count()
        assert 'PUMPS' == first_cached_value(api.default_failurecode_column)

        set_cache(problemcode_column='problemcode')

        assert 8 == cache_row_count()
        assert 'STOPPED' == first_cached_value(api.default_problemcode_column)

        set_cache(failurecode_column='failurecode', problemcode_column='problemcode')

        assert 8 == cache_row_count()
        assert 'PUMPS' == first_cached_value(api.default_failurecode_column)
        assert 'STOPPED' == first_cached_value(api.default_problemcode_column)

        # Parameter validation: each case makes exactly one argument invalid, so
        # the expected ValueError can only come from the check under test.
        with pytest.raises(ValueError):
            set_cache(df=None)
        with pytest.raises(ValueError):
            set_cache(siteid_column=None)
        with pytest.raises(ValueError):
            set_cache(assetid_column=None)
        with pytest.raises(ValueError):
            set_cache(faildate_column=None)
        with pytest.raises(ValueError):
            set_cache(siteid_column='site_')
        with pytest.raises(ValueError):
            set_cache(assetid_column='asset_')
        with pytest.raises(ValueError):
            set_cache(faildate_column='faildate_')
        with pytest.raises(ValueError):
            set_cache(failurecode_column='failurecode_')
        with pytest.raises(ValueError):
            set_cache(problemcode_column='problemcode_')
    finally:
        api.delete_asset_cache(df=df_data_asset, db=db, db_schema=db_schema)


def test_asset_cache_table_upgrade():
    """Check that ``api.set_asset_cache`` upgrades an asset cache table that has the old schema.

    The test drops the cache table and recreates it without the failure-code
    and problem-code columns. It then writes the fixture with both code columns
    and asserts the row count and the stored code values. The cached rows are
    removed in ``finally``.
    """
    df_data_asset = pd.read_csv('%s/trainbrake_asset_faildates.csv' % current_directory(file=__file__), parse_dates=['faildate'])
    print(log_df_info(df_data_asset, head=10))

    # Acquire the connection before ``try`` so that a failure here is reported
    # as-is, not masked by a NameError on ``db`` in the ``finally`` clause.
    db = api._get_db()
    db_schema = None

    def cache_row_count():
        """Return the number of rows currently in the asset cache table."""
        table = db.get_table(api.asset_cache_table_name, db_schema)
        return db.connection.execute(select([func.count()]).select_from(table)).first()[0]

    def first_cached_value(column_name):
        """Return ``column_name`` from the first row the asset cache query yields."""
        table = db.get_table(api.asset_cache_table_name, db_schema)
        return db.connection.execute(select([column(column_name)]).select_from(table)).first()[0]

    try:
        # Drop the table first in case it already exists.
        db.drop_table(api.asset_cache_table_name, schema=db_schema)

        # Create a new table with the old schema (no failure-code / problem-code columns).
        table = Table(api.asset_cache_table_name,
                      db.metadata,
                      Column(api.default_deviceid_column, String(256), nullable=False),
                      Column(api.asset_cache_entity_type_timestamp, DateTime(), nullable=False),
                      Column(api.default_site_column, String(64), nullable=False),
                      Column(api.default_asset_column, String(64), nullable=False),
                      Column(api.default_faildate_column, DateTime(), nullable=True),
                      schema=db_schema)
        table.create()

        # Test that the upgrade successfully replaces the old table with one that has the new schema.
        api.set_asset_cache(df=df_data_asset, siteid_column='site', assetid_column='asset', faildate_column='faildate', failurecode_column='failurecode', problemcode_column='problemcode', db=db, db_schema=db_schema)

        assert 8 == cache_row_count()
        assert 'PUMPS' == first_cached_value(api.default_failurecode_column)
        assert 'STOPPED' == first_cached_value(api.default_problemcode_column)
    finally:
        api.delete_asset_cache(df=df_data_asset, db=db, db_schema=db_schema)

