import json
import logging
from typing import Any, Dict, Iterator, List, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk

# Module-level logger named after this module, so applications can configure it.
logger = logging.getLogger(__name__)


class CloudflareWorkersAI(LLM):
    """LangChain LLM wrapper for the Cloudflare Workers AI text-generation API.

    To use, you must provide an API token and an account ID for Cloudflare
    Workers AI and pass them as named parameters to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.llms.cloudflare_workersai import CloudflareWorkersAI

            my_account_id = "my_account_id"
            my_api_token = "my_secret_api_token"
            llm_model = "@cf/meta/llama-2-7b-chat-int8"

            cf_ai = CloudflareWorkersAI(
                account_id=my_account_id,
                api_token=my_api_token,
                model=llm_model
            )
    """  # noqa: E501

    account_id: str
    api_token: str
    model: str = "@cf/meta/llama-2-7b-chat-int8"
    base_url: str = "https://api.cloudflare.com/client/v4/accounts"
    # When True, ``_call`` collects its result through ``_stream``.
    streaming: bool = False
    # Derived in ``__init__``; any value passed by the caller is overwritten.
    endpoint_url: str = ""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the client and derive ``endpoint_url`` from the fields."""
        super().__init__(**kwargs)

        self.endpoint_url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "cloudflare"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Default parameters (none are sent by default)."""
        return {}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Identifying parameters.

        The API token is deliberately excluded. LangChain uses these values
        for cache keys, serialization and tracing metadata, where a secret
        must not be recorded.
        """
        return {
            "account_id": self.account_id,
            "model": self.model,
            "base_url": self.base_url,
        }

    def _call_api(
        self,
        prompt: str,
        params: Dict[str, Any],
        stream: Optional[bool] = None,
    ) -> requests.Response:
        """POST a generation request to the Workers AI run endpoint.

        Args:
            prompt: Prompt text, sent as the ``prompt`` JSON field.
            params: Extra JSON fields merged into the request body. They are
                applied last, so they can override ``prompt`` and ``stream``.
            stream: Whether to request a streamed response. ``None`` (the
                default) falls back to ``self.streaming``.

        Returns:
            The raw ``requests.Response``. When ``stream`` is true, the body is
            not pre-loaded and the caller should close the response.
        """
        if stream is None:
            stream = self.streaming
        headers = {"Authorization": f"Bearer {self.api_token}"}
        data = {"prompt": prompt, "stream": stream, **params}
        # With stream=True, requests does not buffer the whole body up front,
        # so iter_lines() yields events as they arrive.
        return requests.post(
            self.endpoint_url, headers=headers, json=data, stream=stream
        )

    @staticmethod
    def _raise_for_error(response: requests.Response) -> None:
        """Raise ``ValueError`` if ``response`` has an error HTTP status."""
        if not response.ok:
            raise ValueError(
                f"Request failed with status {response.status_code}: "
                f"{response.text}"
            )

    def _process_response(self, response: requests.Response) -> str:
        """Return the generated text from a non-streamed API response.

        Raises:
            ValueError: If the HTTP status indicates failure.
        """
        self._raise_for_error(response)
        data = response.json()
        return data["result"]["response"]

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream generated tokens from the API's event stream.

        Each event line is expected as ``data: <json>``, and the stream ends
        at ``data: [DONE]``. Events whose JSON object has a ``response`` key
        are yielded as chunks and reported to ``run_manager``. ``stop`` is
        accepted for interface compatibility but is not applied.

        Raises:
            ValueError: If the HTTP status indicates failure or an event
                payload is not valid JSON.
        """
        prefix = b"data: "
        end_marker = b"data: [DONE]"
        # The stream flag is passed explicitly instead of temporarily setting
        # self.streaming. An exception or an early close of this generator
        # therefore cannot leave the instance stuck in streaming mode.
        # The context manager releases the connection in those cases too.
        with self._call_api(prompt, kwargs, stream=True) as response:
            self._raise_for_error(response)
            for line in response.iter_lines():
                if line == end_marker:
                    break
                payload = line[len(prefix):]
                # Skip keep-alive blank lines and non-data lines.
                if not line.startswith(prefix) or not payload:
                    continue
                try:
                    data = json.loads(payload)
                except ValueError:
                    logger.debug(line)
                    raise
                if isinstance(data, dict) and "response" in data:
                    token = data["response"]
                    yield GenerationChunk(text=token)
                    if run_manager:
                        run_manager.on_llm_new_token(token)
        logger.debug("stream end")

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the full completion for ``prompt``.

        With ``streaming`` enabled, the streamed chunks are concatenated (and
        the callbacks fire per token). Otherwise a single request is made.
        """
        if self.streaming:
            return "".join(
                chunk.text
                for chunk in self._stream(prompt, stop, run_manager, **kwargs)
            )
        response = self._call_api(prompt, kwargs)
        return self._process_response(response)
