# Licensed Materials - Property of IBM
# 5737-M66, 5900-AAA
# (C) Copyright IBM Corp. 2019, 2025 All Rights Reserved.
# US Government Users Restricted Rights - Use, duplication, or disclosure
# restricted by GSA ADP Schedule Contract with IBM Corp.

import pytest
import unittest
import os
import nbformat
from testbook import testbook
from nbclient import NotebookClient
from nbclient.exceptions import CellExecutionError


# testbook usage documentation: https://testbook.readthedocs.io/en/latest/usage/index.html

class Semi(unittest.TestCase):
    """FVT checks for the semi-supervised anomaly detection notebook.

    The notebook (or the cell range delimited by the ``startFVT`` and
    ``endFVT`` tags) is executed once from ``setup_class``. The test methods
    then inspect the outputs of tagged cells from that single run instead of
    re-executing the notebook per test.
    """

    notebook_content = None  # Executed notebook (or subset); set by extract_cells().
    tagged_cells = {}  # Executed cells that carry tags, keyed by each tag.

    @classmethod
    def setup_class(cls):
        """Derive the notebook paths from NOTEBOOK_PATH_PREFIX and run the notebook."""
        Semi.NOTEBOOK_PATH_PREFIX = os.getenv("NOTEBOOK_PATH_PREFIX")
        Semi.NOTEBOOK_NAME = 'WS - Anomaly Detection - SemiSupervised.ipynb'
        Semi.NOTEBOOK_PATH = Semi.NOTEBOOK_PATH_PREFIX + Semi.NOTEBOOK_NAME
        Semi.OUTPUT_NOTEBOOK_NAME = Semi.NOTEBOOK_NAME.replace(".ipynb", "_output.ipynb")
        Semi.OUTPUT_NOTEBOOK_PATH = Semi.NOTEBOOK_PATH_PREFIX + Semi.OUTPUT_NOTEBOOK_NAME
        Semi.extract_cells()

    # Per-test setup hook; intentionally a no-op.
    def setup(self, method):
        pass

    @staticmethod
    def extract_cells():
        """Execute the FVT section of the notebook and index its tagged cells.

        Reads the notebook at ``Semi.NOTEBOOK_PATH``, selects the cells from
        the one tagged ``startFVT`` through the one tagged ``endFVT``
        (falling back to the first / last cell when a tag is missing), and
        executes them with nbclient. Whether or not execution succeeds, the
        notebook is stored in ``Semi.notebook_content``, every tagged cell is
        recorded in ``Semi.tagged_cells``, and the notebook is written to
        ``Semi.OUTPUT_NOTEBOOK_PATH`` for debugging.

        Raises:
            Exception: If the notebook contains no cells. Any exception
                raised while executing the notebook is re-raised after the
                notebook has been saved.
        """
        with open(Semi.NOTEBOOK_PATH, 'r', encoding='utf-8') as notebook_file:
            nb = nbformat.read(notebook_file, as_version=4)

        cells = nb.get('cells')
        cells_count = len(cells)

        if cells_count == 0:
            raise Exception("Notebook is empty!!")

        # Extracting cells between startFVT and endFVT.
        start_tag = "startFVT"
        end_tag = "endFVT"

        start_index = -1
        end_index = -1

        for index, cell in enumerate(cells):
            tags = (cell.get('metadata') or {}).get('tags') or ()

            if start_tag in tags:
                start_index = index
            elif end_tag in tags:
                end_index = index

            if start_index != -1 and end_index != -1:
                break
        else:
            # Loop finished without finding both tags: fall back to the
            # notebook boundaries for whichever tag is missing.
            if start_index == -1:
                print("startFVT tag not found.")
                start_index = 0

            if end_index == -1:
                print("endFVT tag not found.")
                end_index = (cells_count - 1)

        if start_index == 0 and end_index == cells_count - 1:
            notebook_subset = nb
        else:
            notebook_subset = nbformat.v4.new_notebook(cells=cells[start_index:end_index + 1])

        client = NotebookClient(notebook_subset, timeout=21600, kernel_name='python3', allow_errors=False, resources={})

        try:
            client.execute()
            print("Notebook executed successfully")

        except Exception as ex:
            print("Error occurred while executing notebook. \n Exception trace:", str(ex))
            raise

        finally:
            # Saving entire notebook in memory.
            Semi.notebook_content = notebook_subset.copy()

            # Collect every tagged cell. The mapping is accumulated across
            # all cells and published once, so tags from earlier cells are
            # not discarded when a later tagged cell is encountered.
            tagged = {}
            for cell in Semi.notebook_content.get('cells'):
                metadata = cell.get('metadata')
                if (metadata is not None) and ("tags" in metadata):
                    for tag in metadata.get('tags'):
                        tagged[tag] = cell
            Semi.tagged_cells = tagged

            # Saving executed notebook for debugging.
            nbformat.write(Semi.notebook_content, Semi.OUTPUT_NOTEBOOK_PATH)

    @staticmethod
    def _get_tagged_cell(tag):
        """Return the executed cell recorded under ``tag``.

        Raises:
            Exception: If no executed cell carries ``tag``.
        """
        if tag not in Semi.tagged_cells:
            raise Exception(
                f"{tag} tag not found in notebook. Please refer to saved notebook at {Semi.OUTPUT_NOTEBOOK_PATH}")
        return Semi.tagged_cells[tag]

    # The cell tagged endFVT must have produced non-empty text as its first output.
    @pytest.mark.predict_fvt
    def test_EndToEndExecution(self):
        cell = Semi._get_tagged_cell("endFVT")
        assert len(cell.get('outputs')) > 0 and cell.get('outputs')[0]['text'] != ""

        # testbook-based variant (not active):
        # with testbook(notebookPath,execute=slice('startFVT', 'endFVT'),timeout=3000) as tb:
        #   assert tb.cell_output_text('endFVT') != ''

    # The cell tagged deployWML must report a successful deployment in its first output.
    @pytest.mark.predict_fvt
    def test_ModelDeploys(self):
        cell = Semi._get_tagged_cell("deployWML")
        assert len(cell.get('outputs')) > 0 and "Successfully finished deployment creation" in cell.get('outputs')[0]['text']

        # testbook-based variant (not active):
        # notebookPath = self.NOTEBOOK_PATH_PREFIX + 'WS - Anomaly Detection - SemiSupervised.ipynb'
        # with testbook(notebookPath,execute=slice('startFVT', 'endFVT'),timeout=3000) as tb:
        #   assert "Successfully finished deployment creation" in tb.cell_output_text('deployWML')

 
class Unsup(unittest.TestCase):
    """End-to-end FVT check for the unsupervised anomaly detection notebook."""

    def setUp(self):
        """Build the notebook path from the NOTEBOOK_PATH_PREFIX environment variable."""
        prefix = os.getenv("NOTEBOOK_PATH_PREFIX")
        name = 'WS - Anomaly Detection - UnSupervised.ipynb'
        self.NOTEBOOK_PATH_PREFIX = prefix
        self.NOTEBOOK_NAME = name
        self.NOTEBOOK_PATH = prefix + name

    # Runs the cells from the startFVT tag to the endFVT tag through testbook
    # and requires the endFVT cell to have produced non-empty output text.
    @pytest.mark.predict_fvt
    @pytest.mark.dependency(name="AD_WS_Unsup", depends=['Data_Loader_FastStart'], scope='session')
    def test_EndToEndExecution(self):
        fvt_cells = slice('startFVT', 'endFVT')
        with testbook(self.NOTEBOOK_PATH, execute=fvt_cells, timeout=21600) as tb:
            end_output = tb.cell_output_text('endFVT')
            assert end_output != ''
