import logging
import uuid
from abc import ABC
from collections import defaultdict
from contextlib import contextmanager
from contextvars import ContextVar
from typing import Any, Dict, Generator, List, Optional, cast, Type

from llama_index.core.callbacks.base_handler import BaseCallbackHandler
from llama_index.core.callbacks.schema import (
    BASE_TRACE_EVENT,
    LEAF_EVENTS,
    CBEventType,
    EventPayload,
)
from llama_index.core.bridge.pydantic import (
    GetCoreSchemaHandler,
    GetJsonSchemaHandler,
)
from llama_index.core.bridge.pydantic_core import CoreSchema, core_schema

logger = logging.getLogger(__name__)
global_stack_trace = ContextVar("trace", default=[BASE_TRACE_EVENT])
empty_trace_ids: List[str] = []
global_stack_trace_ids = ContextVar("trace_ids", default=empty_trace_ids)


class CallbackManager(BaseCallbackHandler, ABC):
    """
    Callback manager that handles callbacks for events within LlamaIndex.

    The callback manager provides a way to call handlers on event starts/ends.

    Additionally, the callback manager traces the current stack of events.
    It does this by using a few key attributes.
    - trace_stack - The current stack of events that have not ended yet.
                    When an event ends, it's removed from the stack.
                    Since this is a contextvar, it is unique to each
                    thread/task.
    - trace_map - A mapping of event ids to their children events.
                  On the start of events, the bottom of the trace stack
                  is used as the current parent event for the trace map.
    - trace_id - A simple name for the current trace, usually denoting the
                 entrypoint (query, index_construction, insert, etc.)

    Args:
        handlers (List[BaseCallbackHandler]): list of handlers to use.

    Usage:
        with callback_manager.event(CBEventType.QUERY) as event:
            event.on_start(payload={key, val})
            ...
            event.on_end(payload={key, val})

    """

    def __init__(self, handlers: Optional[List[BaseCallbackHandler]] = None):
        """Initialize the manager with a list of handlers."""
        from llama_index.core import global_handler

        handlers = handlers or []

        # add eval handlers based on global defaults
        if global_handler is not None:
            new_handler = global_handler
            # go through existing handlers, check if any are same type as new handler
            # if so, error
            for existing_handler in handlers:
                if isinstance(existing_handler, type(new_handler)):
                    raise ValueError(
                        "Cannot add two handlers of the same type "
                        f"{type(new_handler)} to the callback manager."
                    )
            handlers.append(new_handler)

        # if we passed in no handlers, use the global default
        if len(handlers) == 0:
            from llama_index.core.settings import Settings

            # hidden var access to prevent recursion in getter
            cb_manager = Settings._callback_manager
            if cb_manager is not None:
                handlers = cb_manager.handlers

        self.handlers: List[BaseCallbackHandler] = handlers
        self._trace_map: Dict[str, List[str]] = defaultdict(list)

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: Optional[str] = None,
        parent_id: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """Run handlers when an event starts and return id of event."""
        event_id = event_id or str(uuid.uuid4())

        # if no trace is running, start a default trace
        try:
            parent_id = parent_id or global_stack_trace.get()[-1]
        except IndexError:
            self.start_trace("llama-index")
            parent_id = global_stack_trace.get()[-1]
        parent_id = cast(str, parent_id)
        self._trace_map[parent_id].append(event_id)
        for handler in self.handlers:
            if event_type not in handler.event_starts_to_ignore:
                handler.on_event_start(
                    event_type,
                    payload,
                    event_id=event_id,
                    parent_id=parent_id,
                    **kwargs,
                )

        if event_type not in LEAF_EVENTS:
            # copy the stack trace to prevent conflicts with threads/coroutines
            current_trace_stack = global_stack_trace.get().copy()
            current_trace_stack.append(event_id)
            global_stack_trace.set(current_trace_stack)

        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Run handlers when an event ends."""
        event_id = event_id or str(uuid.uuid4())
        for handler in self.handlers:
            if event_type not in handler.event_ends_to_ignore:
                handler.on_event_end(event_type, payload, event_id=event_id, **kwargs)

        if event_type not in LEAF_EVENTS:
            # copy the stack trace to prevent conflicts with threads/coroutines
            current_trace_stack = global_stack_trace.get().copy()
            current_trace_stack.pop()
            global_stack_trace.set(current_trace_stack)

    def add_handler(self, handler: BaseCallbackHandler) -> None:
        """Add a handler to the callback manager."""
        self.handlers.append(handler)

    def remove_handler(self, handler: BaseCallbackHandler) -> None:
        """Remove a handler from the callback manager."""
        self.handlers.remove(handler)

    def set_handlers(self, handlers: List[BaseCallbackHandler]) -> None:
        """Set handlers as the only handlers on the callback manager."""
        self.handlers = handlers

    @contextmanager
    def event(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: Optional[str] = None,
    ) -> Generator["EventContext", None, None]:
        """
        Context manager for lanching and shutdown of events.

        Handles sending on_evnt_start and on_event_end to handlers for specified event.

        Usage:
            with callback_manager.event(CBEventType.QUERY, payload={key, val}) as event:
                ...
                event.on_end(payload={key, val})  # optional
        """
        # create event context wrapper
        event = EventContext(self, event_type, event_id=event_id)
        event.on_start(payload=payload)

        payload = None
        try:
            yield event
        except Exception as e:
            # data already logged to trace?
            if not hasattr(e, "event_added"):
                payload = {EventPayload.EXCEPTION: e}
                e.event_added = True  # type: ignore
                if not event.finished:
                    event.on_end(payload=payload)
            raise
        finally:
            # ensure event is ended
            if not event.finished:
                event.on_end(payload=payload)

    @contextmanager
    def as_trace(self, trace_id: str) -> Generator[None, None, None]:
        """Context manager tracer for lanching and shutdown of traces."""
        self.start_trace(trace_id=trace_id)

        try:
            yield
        except Exception as e:
            # event already added to trace?
            if not hasattr(e, "event_added"):
                self.on_event_start(
                    CBEventType.EXCEPTION, payload={EventPayload.EXCEPTION: e}
                )
                e.event_added = True  # type: ignore

            raise
        finally:
            # ensure trace is ended
            self.end_trace(trace_id=trace_id)

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """Run when an overall trace is launched."""
        current_trace_stack_ids = global_stack_trace_ids.get().copy()
        if trace_id is not None:
            if len(current_trace_stack_ids) == 0:
                self._reset_trace_events()

                for handler in self.handlers:
                    handler.start_trace(trace_id=trace_id)

                current_trace_stack_ids = [trace_id]
            else:
                current_trace_stack_ids.append(trace_id)

        global_stack_trace_ids.set(current_trace_stack_ids)

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """Run when an overall trace is exited."""
        current_trace_stack_ids = global_stack_trace_ids.get().copy()
        if trace_id is not None and len(current_trace_stack_ids) > 0:
            current_trace_stack_ids.pop()
            if len(current_trace_stack_ids) == 0:
                for handler in self.handlers:
                    handler.end_trace(trace_id=trace_id, trace_map=self._trace_map)
                current_trace_stack_ids = []

        global_stack_trace_ids.set(current_trace_stack_ids)

    def _reset_trace_events(self) -> None:
        """Helper function to reset the current trace."""
        self._trace_map = defaultdict(list)
        global_stack_trace.set([BASE_TRACE_EVENT])

    @property
    def trace_map(self) -> Dict[str, List[str]]:
        return self._trace_map

    @classmethod
    def __get_pydantic_core_schema__(
        cls, source: Type[Any], handler: GetCoreSchemaHandler
    ) -> CoreSchema:
        return core_schema.any_schema()

    @classmethod
    def __get_pydantic_json_schema__(
        cls, core_schema: CoreSchema, handler: GetJsonSchemaHandler
    ) -> Dict[str, Any]:
        json_schema = handler(core_schema)
        return handler.resolve_ref_schema(json_schema)


class EventContext:
    """
    Simple wrapper to call callbacks on event starts and ends
    with an event type and id.
    """

    def __init__(
        self,
        callback_manager: CallbackManager,
        event_type: CBEventType,
        event_id: Optional[str] = None,
    ):
        self._callback_manager = callback_manager
        self._event_type = event_type
        self._event_id = event_id or str(uuid.uuid4())
        self.started = False
        self.finished = False

    def on_start(self, payload: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
        if not self.started:
            self.started = True
            self._callback_manager.on_event_start(
                self._event_type, payload=payload, event_id=self._event_id, **kwargs
            )
        else:
            logger.warning(
                f"Event {self._event_type!s}: {self._event_id} already started!"
            )

    def on_end(self, payload: Optional[Dict[str, Any]] = None, **kwargs: Any) -> None:
        if not self.finished:
            self.finished = True
            self._callback_manager.on_event_end(
                self._event_type, payload=payload, event_id=self._event_id, **kwargs
            )
