"""Wrapper around YandexGPT embedding models."""
from __future__ import annotations

import logging
import time
from typing import Any, Callable, Dict, List

from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


class YandexGPTEmbeddings(BaseModel, Embeddings):
    """YandexGPT Embeddings models.

    To use, you should have the ``yandexcloud`` python package installed.

    There are two authentication options for the service account
    with the ``ai.languageModels.user`` role:
        - You can specify the token in a constructor parameter `iam_token`
        or in an environment variable `YC_IAM_TOKEN`.
        - You can specify the key in a constructor parameter `api_key`
        or in an environment variable `YC_API_KEY`.

    To use the default model specify the folder ID in a parameter `folder_id`
    or in an environment variable `YC_FOLDER_ID`.
    Or specify the model URI in a constructor parameter `model_uri`

    Example:
        .. code-block:: python

            from langchain_community.embeddings.yandex import YandexGPTEmbeddings
            embeddings = YandexGPTEmbeddings(iam_token="t1.9eu...", model_uri="emb://<folder-id>/text-search-query/latest")
    """

    iam_token: SecretStr = ""  # type: ignore[assignment]
    """Yandex Cloud IAM token for service account
    with the `ai.languageModels.user` role"""
    api_key: SecretStr = ""  # type: ignore[assignment]
    """Yandex Cloud Api Key for service account
    with the `ai.languageModels.user` role"""
    model_uri: str = ""
    """Model uri to use."""
    folder_id: str = ""
    """Yandex Cloud folder ID"""
    model_name: str = "text-search-query"
    """Model name to use."""
    model_version: str = "latest"
    """Model version to use."""
    url: str = "llm.api.cloud.yandex.net:443"
    """The url of the API."""
    max_retries: int = 6
    """Maximum number of retries to make when generating."""
    sleep_interval: float = 0.0
    """Delay between API requests"""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that iam token exists in environment."""

        iam_token = convert_to_secret_str(
            get_from_dict_or_env(values, "iam_token", "YC_IAM_TOKEN", "")
        )
        values["iam_token"] = iam_token
        api_key = convert_to_secret_str(
            get_from_dict_or_env(values, "api_key", "YC_API_KEY", "")
        )
        values["api_key"] = api_key
        folder_id = get_from_dict_or_env(values, "folder_id", "YC_FOLDER_ID", "")
        values["folder_id"] = folder_id
        if api_key.get_secret_value() == "" and iam_token.get_secret_value() == "":
            raise ValueError("Either 'YC_API_KEY' or 'YC_IAM_TOKEN' must be provided.")
        if values["iam_token"]:
            values["_grpc_metadata"] = [
                ("authorization", f"Bearer {values['iam_token'].get_secret_value()}")
            ]
            if values["folder_id"]:
                values["_grpc_metadata"].append(("x-folder-id", values["folder_id"]))
        else:
            values["_grpc_metadata"] = (
                ("authorization", f"Api-Key {values['api_key'].get_secret_value()}"),
            )
        if values["model_uri"] == "" and values["folder_id"] == "":
            raise ValueError("Either 'model_uri' or 'folder_id' must be provided.")
        if not values["model_uri"]:
            values[
                "model_uri"
            ] = f"emb://{values['folder_id']}/{values['model_name']}/{values['model_version']}"
        return values

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents using a YandexGPT embeddings models.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """

        return _embed_with_retry(self, texts=texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using a YandexGPT embeddings models.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return _embed_with_retry(self, texts=[text])[0]


def _create_retry_decorator(llm: YandexGPTEmbeddings) -> Callable[[Any], Any]:
    from grpc import RpcError

    min_seconds = 1
    max_seconds = 60
    return retry(
        reraise=True,
        stop=stop_after_attempt(llm.max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=(retry_if_exception_type((RpcError))),
        before_sleep=before_sleep_log(logger, logging.WARNING),
    )


def _embed_with_retry(llm: YandexGPTEmbeddings, **kwargs: Any) -> Any:
    """Use tenacity to retry the embedding call."""
    retry_decorator = _create_retry_decorator(llm)

    @retry_decorator
    def _completion_with_retry(**_kwargs: Any) -> Any:
        return _make_request(llm, **_kwargs)

    return _completion_with_retry(**kwargs)


def _make_request(self: YandexGPTEmbeddings, texts: List[str]):  # type: ignore[no-untyped-def]
    try:
        import grpc
        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import (  # noqa: E501
            TextEmbeddingRequest,
        )
        from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2_grpc import (  # noqa: E501
            EmbeddingsServiceStub,
        )
    except ImportError as e:
        raise ImportError(
            "Please install YandexCloud SDK  with `pip install yandexcloud` \
            or upgrade it to recent version."
        ) from e
    result = []
    channel_credentials = grpc.ssl_channel_credentials()
    channel = grpc.secure_channel(self.url, channel_credentials)

    for text in texts:
        request = TextEmbeddingRequest(model_uri=self.model_uri, text=text)
        stub = EmbeddingsServiceStub(channel)
        res = stub.TextEmbedding(request, metadata=self._grpc_metadata)  # type: ignore[attr-defined]
        result.append(list(res.embedding))
        time.sleep(self.sleep_interval)

    return result
