from typing import Any, Iterator, List, Mapping, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from requests.exceptions import ConnectionError

from langchain_community.llms.utils import enforce_stop_tokens


class TitanTakeoffPro(LLM):
    """Titan Takeoff Pro is a language model that can be used to generate text."""

    base_url: Optional[str] = "http://localhost:3000"
    """Specifies the baseURL to use for the Titan Takeoff Pro API.
    Default = http://localhost:3000.
    """

    max_new_tokens: Optional[int] = None
    """Maximum tokens generated."""

    min_new_tokens: Optional[int] = None
    """Minimum tokens generated."""

    sampling_topk: Optional[int] = None
    """Sample predictions from the top K most probable candidates."""

    sampling_topp: Optional[float] = None
    """Sample from predictions whose cumulative probability exceeds this value.
    """

    sampling_temperature: Optional[float] = None
    """Sample with randomness. Bigger temperatures are associated with 
    more randomness and 'creativity'.
    """

    repetition_penalty: Optional[float] = None
    """Penalise the generation of tokens that have been generated before. 
    Set to > 1 to penalize.
    """

    regex_string: Optional[str] = None
    """A regex string for constrained generation."""

    no_repeat_ngram_size: Optional[int] = None
    """Prevent repetitions of ngrams of this size. Default = 0 (turned off)."""

    streaming: bool = False
    """Whether to stream the output. Default = False."""

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling Titan Takeoff Server (Pro)."""
        # Only include parameters that have been explicitly set.
        params = {
            "regex_string": self.regex_string,
            "sampling_temperature": self.sampling_temperature,
            "sampling_topp": self.sampling_topp,
            "repetition_penalty": self.repetition_penalty,
            "max_new_tokens": self.max_new_tokens,
            "min_new_tokens": self.min_new_tokens,
            "sampling_topk": self.sampling_topk,
            "no_repeat_ngram_size": self.no_repeat_ngram_size,
        }
        return {key: value for key, value in params.items() if value is not None}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "titan_takeoff_pro"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to Titan Takeoff (Pro) generate endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                prompt = "What is the capital of the United Kingdom?"
                response = model.invoke(prompt)

        """
        try:
            if self.streaming:
                text_output = ""
                for chunk in self._stream(
                    prompt=prompt,
                    stop=stop,
                    run_manager=run_manager,
                ):
                    text_output += chunk.text
                return text_output
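            # Non-streaming path: POST the prompt (as "text") along with any
            # sampling parameters that were set to the generate endpoint.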
            url = f"{self.base_url}/generate"
            params = {"text": prompt, **self._default_params}

            response = requests.post(url, json=params)
            response.raise_for_status()
            response.encoding = "utf-8"

            text = ""
            if "text" in response.json():
                text = response.json()["text"]
                text = text.replace("</s>", "")
            else:
                raise ValueError("Something went wrong.")
            if stop is not None:
                text = enforce_stop_tokens(text, stop)
            return text
        except ConnectionError as err:
            raise ConnectionError(
                "Could not connect to Titan Takeoff (Pro) server. "
                "Please make sure that the server is running."
            ) from err

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Call out to Titan Takeoff (Pro) stream endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Yields:
            GenerationChunk objects, each containing a token of generated text.

        Example:
            .. code-block:: python

                prompt = "What is the capital of the United Kingdom?"
                response = model(prompt)

        """
        url = f"{self.base_url}/generate_stream"
        params = {"text": prompt, **self._default_params}

        response = requests.post(url, json=params, stream=True)
        response.raise_for_status()
        response.encoding = "utf-8"
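        # The streaming endpoint emits chunks in which each token is prefixed
        # with a "data:" marker (Server-Sent-Events style). The loop below reads
        # the response one character at a time and yields the text buffered
        # between consecutive markers.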
        buffer = ""
        for text in response.iter_content(chunk_size=1, decode_unicode=True):
            buffer += text
            if "data:" in buffer:
                # Remove the first instance of "data:" from the buffer.
                if buffer.startswith("data:"):
                    buffer = ""
                if len(buffer.split("data:", 1)) == 2:
                    content, _ = buffer.split("data:", 1)
                    buffer = content.rstrip("\n")
                # Trim the buffer to only have content after the "data:" part.
                if buffer:  # Ensure that there's content to process.
                    chunk = GenerationChunk(text=buffer)
                    buffer = ""  # Reset buffer for the next set of data.
                    yield chunk
                    if run_manager:
                        run_manager.on_llm_new_token(token=chunk.text)

        # Yield any remaining content in the buffer.
        if buffer:
            chunk = GenerationChunk(text=buffer.replace("</s>", ""))
            yield chunk
            if run_manager:
                run_manager.on_llm_new_token(token=chunk.text)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"base_url": self.base_url, **{}, **self._default_params}
