from typing import Any, Dict, List, Type, Union

from langchain_community.graphs import NetworkxEntityGraph
from langchain_community.graphs.networkx_graph import (
    KnowledgeTriple,
    get_entities,
    parse_triples,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import BaseMessage, SystemMessage, get_buffer_string
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import Field

from langchain.chains.llm import LLMChain
from langchain.memory.chat_memory import BaseChatMemory
from langchain.memory.prompt import (
    ENTITY_EXTRACTION_PROMPT,
    KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT,
)
from langchain.memory.utils import get_prompt_input_key


class ConversationKGMemory(BaseChatMemory):
    """Knowledge graph conversation memory.

    Integrates with external knowledge graph to store and retrieve
    information about knowledge triples in the conversation.
    """

    # Number of previous conversation turns (human/AI message pairs) to
    # include as context when extracting entities and knowledge triples.
    # (Was previously mis-documented next to ``summary_message_cls``.)
    k: int = 2
    human_prefix: str = "Human"
    ai_prefix: str = "AI"
    # Backing knowledge graph; each memory instance gets its own empty graph.
    kg: NetworkxEntityGraph = Field(default_factory=NetworkxEntityGraph)
    knowledge_extraction_prompt: BasePromptTemplate = KNOWLEDGE_TRIPLE_EXTRACTION_PROMPT
    entity_extraction_prompt: BasePromptTemplate = ENTITY_EXTRACTION_PROMPT
    llm: BaseLanguageModel
    # Message class used to wrap each entity summary when return_messages=True.
    summary_message_cls: Type[BaseMessage] = SystemMessage
    memory_key: str = "history"  #: :meta private:

    def load_memory_variables(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return history buffer.

        Looks up every entity mentioned in the current input and summarizes
        what the knowledge graph knows about each one. Returns either a list
        of messages or a newline-joined string, depending on
        ``return_messages``.
        """
        entities = self._get_current_entities(inputs)

        summary_strings = []
        for entity in entities:
            knowledge = self.kg.get_entity_knowledge(entity)
            if knowledge:
                summary = f"On {entity}: {'. '.join(knowledge)}."
                summary_strings.append(summary)
        context: Union[str, List]
        if not summary_strings:
            # Preserve the expected empty type for either output mode.
            context = [] if self.return_messages else ""
        elif self.return_messages:
            context = [
                self.summary_message_cls(content=text) for text in summary_strings
            ]
        else:
            context = "\n".join(summary_strings)

        return {self.memory_key: context}

    @property
    def memory_variables(self) -> List[str]:
        """Will always return list of memory variables.

        :meta private:
        """
        return [self.memory_key]

    def _get_prompt_input_key(self, inputs: Dict[str, Any]) -> str:
        """Get the input key for the prompt."""
        if self.input_key is None:
            return get_prompt_input_key(inputs, self.memory_variables)
        return self.input_key

    def _get_prompt_output_key(self, outputs: Dict[str, Any]) -> str:
        """Get the output key for the prompt.

        Raises:
            ValueError: If ``output_key`` is unset and ``outputs`` does not
                contain exactly one key to fall back on.
        """
        if self.output_key is None:
            if len(outputs) != 1:
                raise ValueError(f"One output key expected, got {outputs.keys()}")
            return next(iter(outputs))
        return self.output_key

    def get_current_entities(self, input_string: str) -> List[str]:
        """Extract entities mentioned in ``input_string``.

        Runs the entity-extraction prompt over the last ``k`` conversation
        turns plus the new input.
        """
        chain = LLMChain(llm=self.llm, prompt=self.entity_extraction_prompt)
        buffer_string = get_buffer_string(
            self.chat_memory.messages[-self.k * 2 :],
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )
        output = chain.predict(
            history=buffer_string,
            input=input_string,
        )
        return get_entities(output)

    def _get_current_entities(self, inputs: Dict[str, Any]) -> List[str]:
        """Get the current entities in the conversation."""
        prompt_input_key = self._get_prompt_input_key(inputs)
        return self.get_current_entities(inputs[prompt_input_key])

    def get_knowledge_triplets(self, input_string: str) -> List[KnowledgeTriple]:
        """Extract (subject, predicate, object) triples from ``input_string``.

        Runs the knowledge-triple-extraction prompt over the last ``k``
        conversation turns plus the new input.
        """
        chain = LLMChain(llm=self.llm, prompt=self.knowledge_extraction_prompt)
        buffer_string = get_buffer_string(
            self.chat_memory.messages[-self.k * 2 :],
            human_prefix=self.human_prefix,
            ai_prefix=self.ai_prefix,
        )
        # NOTE: a stray ``verbose=True`` kwarg was removed here. ``predict``
        # forwards **kwargs as *prompt inputs*, not chain options, so it never
        # enabled verbosity; chain verbosity belongs on the LLMChain
        # constructor. This also matches get_current_entities above.
        output = chain.predict(
            history=buffer_string,
            input=input_string,
        )
        knowledge = parse_triples(output)
        return knowledge

    def _get_and_update_kg(self, inputs: Dict[str, Any]) -> None:
        """Get and update knowledge graph from the conversation history."""
        prompt_input_key = self._get_prompt_input_key(inputs)
        knowledge = self.get_knowledge_triplets(inputs[prompt_input_key])
        for triple in knowledge:
            self.kg.add_triple(triple)

    def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str]) -> None:
        """Save context from this conversation to buffer."""
        super().save_context(inputs, outputs)
        self._get_and_update_kg(inputs)

    def clear(self) -> None:
        """Clear memory contents."""
        super().clear()
        self.kg.clear()
