"""Llama Dataset Class."""

import asyncio
import time
from typing import Any, Dict, List, Optional, Sequence

from llama_index.core.base.base_query_engine import BaseQueryEngine
from llama_index.core.bridge.pydantic import Field
from llama_index.core.llama_dataset.base import (
    BaseLlamaDataExample,
    BaseLlamaDataset,
    BaseLlamaExamplePrediction,
    BaseLlamaPredictionDataset,
    CreatedBy,
)


class RagExamplePrediction(BaseLlamaExamplePrediction):
    """
    Prediction produced for a single RAG example.

    Args:
        response (str): The generated (predicted) response. Defaults to an
                        empty string.
        contexts (Optional[List[str]]): The contexts, as raw text, used to
                                        generate the response. Defaults to None.

    """

    response: str = Field(
        description="The generated (predicted) response that can be compared to a reference (ground-truth) answer.",
        default="",
    )
    contexts: Optional[List[str]] = Field(
        description="The contexts in raw text form used to generate the response.",
        default=None,
    )

    @property
    def class_name(self) -> str:
        """Name identifying this prediction class."""
        return "RagExamplePrediction"


class LabelledRagDataExample(BaseLlamaDataExample):
    """
    A single labelled RAG example.

    As in a traditional ML dataset, each example carries the "features"
    (i.e., query + context) needed to make a prediction together with the
    "label" (i.e., response) against which that prediction is evaluated.

    Args:
        query (str): The user query. Defaults to an empty string.
        query_by (Optional[CreatedBy]): Whether the query was generated by a
                                        human or an ai (model-name).
        reference_contexts (Optional[List[str]]): The contexts used to generate
                                                  the reference answer.
        reference_answer (str): Reference (ground-truth) answer to the query,
                                i.e. an answer that would receive full marks
                                upon evaluation. Defaults to an empty string.
        reference_answer_by (Optional[CreatedBy]): Whether the reference answer
                                                   was generated by a human or
                                                   an ai (model-name).

    """

    query: str = Field(
        description="The user query for the example.", default_factory=str
    )
    query_by: Optional[CreatedBy] = Field(
        description="What generated the query.", default=None
    )
    reference_contexts: Optional[List[str]] = Field(
        description="The contexts used to generate the reference answer.",
        default=None,
    )
    reference_answer: str = Field(
        description="The reference (ground-truth) answer to the example.",
        default_factory=str,
    )
    reference_answer_by: Optional[CreatedBy] = Field(
        description="What generated the reference answer.", default=None
    )

    @property
    def class_name(self) -> str:
        """Name identifying this data example class."""
        return "LabelledRagDataExample"


class RagPredictionDataset(BaseLlamaPredictionDataset):
    """Dataset of RAG predictions (``RagExamplePrediction`` items)."""

    _prediction_type = RagExamplePrediction

    def to_pandas(self) -> Any:
        """
        Create pandas dataframe.

        Returns:
            A ``pandas.DataFrame`` with one row per prediction and the columns
            ``response`` and ``contexts``.

        Raises:
            ImportError: If pandas cannot be imported.
            ValueError: If any prediction is not a ``RagExamplePrediction``.

        """
        try:
            import pandas as pd
        except ImportError as err:
            # Chain the original error so the underlying import failure is
            # reported as the explicit cause of this one.
            raise ImportError(
                "pandas is required for this function. Please install it with `pip install pandas`."
            ) from err

        responses: List = []
        contexts: List = []
        for pred in self.predictions:
            # Validate inside the loop so the isinstance check also narrows
            # the type before the attribute accesses below.
            if not isinstance(pred, RagExamplePrediction):
                raise ValueError(
                    "All predictions in the dataset must be of type RagExamplePrediction."
                )
            responses.append(pred.response)
            contexts.append(pred.contexts)

        data: Dict[str, List] = {"response": responses, "contexts": contexts}
        return pd.DataFrame(data)

    @property
    def class_name(self) -> str:
        """Class name."""
        return "RagPredictionDataset"


class LabelledRagDataset(BaseLlamaDataset[BaseQueryEngine]):
    """Dataset of labelled RAG examples, predicted on with a query engine."""

    _example_type = LabelledRagDataExample

    def to_pandas(self) -> Any:
        """
        Create pandas dataframe.

        Returns:
            A ``pandas.DataFrame`` with one row per example and the columns
            ``query``, ``reference_contexts``, ``reference_answer``,
            ``reference_answer_by`` and ``query_by``.

        Raises:
            ImportError: If pandas cannot be imported.
            ValueError: If any example is not a ``LabelledRagDataExample``.

        """
        try:
            import pandas as pd
        except ImportError as err:
            # Chain the original error so the underlying import failure is
            # reported as the explicit cause of this one.
            raise ImportError(
                "pandas is required for this function. Please install it with `pip install pandas`."
            ) from err

        data: Dict[str, List] = {
            "query": [],
            "reference_contexts": [],
            "reference_answer": [],
            "reference_answer_by": [],
            "query_by": [],
        }
        for example in self.examples:
            if not isinstance(example, LabelledRagDataExample):
                raise ValueError(
                    "All examples in the dataset must be of type LabelledRagDataExample."
                )
            data["query"].append(example.query)
            data["reference_contexts"].append(example.reference_contexts)
            data["reference_answer"].append(example.reference_answer)
            # The "*_by" values are stored stringified, so an unset (None)
            # value ends up as the string "None" in the dataframe.
            data["reference_answer_by"].append(str(example.reference_answer_by))
            data["query_by"].append(str(example.query_by))

        return pd.DataFrame(data)

    async def _apredict_example(  # type: ignore
        self,
        predictor: BaseQueryEngine,
        example: LabelledRagDataExample,
        sleep_time_in_seconds: int = 0,
    ) -> RagExamplePrediction:
        """
        Async predict RAG example with a query engine.

        Args:
            predictor (BaseQueryEngine): Query engine whose ``aquery`` is
                awaited with the example's query.
            example (LabelledRagDataExample): The example to predict on.
            sleep_time_in_seconds (int): Time to sleep (without blocking)
                before querying. Defaults to 0, matching ``_predict_example``.

        Returns:
            RagExamplePrediction: The stringified response together with the
            text of each of the response's source nodes as contexts.

        """
        await asyncio.sleep(sleep_time_in_seconds)
        response = await predictor.aquery(example.query)
        return RagExamplePrediction(
            response=str(response), contexts=[s.text for s in response.source_nodes]
        )

    def _predict_example(  # type: ignore
        self,
        predictor: BaseQueryEngine,
        example: LabelledRagDataExample,
        sleep_time_in_seconds: int = 0,
    ) -> RagExamplePrediction:
        """
        Predict RAG example with a query engine.

        Args:
            predictor (BaseQueryEngine): Query engine whose ``query`` is called
                with the example's query.
            example (LabelledRagDataExample): The example to predict on.
            sleep_time_in_seconds (int): Time to sleep (blocking) before
                querying. Defaults to 0.

        Returns:
            RagExamplePrediction: The stringified response together with the
            text of each of the response's source nodes as contexts.

        """
        time.sleep(sleep_time_in_seconds)
        response = predictor.query(example.query)
        return RagExamplePrediction(
            response=str(response), contexts=[s.text for s in response.source_nodes]
        )

    def _construct_prediction_dataset(  # type: ignore
        self, predictions: Sequence[RagExamplePrediction]
    ) -> RagPredictionDataset:
        """Wrap the given predictions in a ``RagPredictionDataset``."""
        return RagPredictionDataset(predictions=predictions)

    @property
    def class_name(self) -> str:
        """Class name."""
        return "LabelledRagDataset"


# Spelling aliases: expose the British English "Labelled..." classes under
# the American English "Labeled..." names as well, so both spellings work.
LabeledRagDataExample = LabelledRagDataExample
LabeledRagDataset = LabelledRagDataset
