# ------------------------------------------------------------------------------------------------

from dotenv import load_dotenv
import os

load_dotenv(dotenv_path="env.txt")

BASE_URL = os.getenv("BASE_URL")
TEMPERATURE = os.getenv("TEMPERATURE", "0")  # kept as a string; cast to float where used
MODEL = os.getenv("MODEL")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")

PG_VECTOR_CONNECTION = os.getenv("PG_VECTOR_CONNECTION")
PG_VECTOR_COLLECTION_NAME = os.getenv("PG_VECTOR_COLLECTION_NAME")

MAX_LENGTH = 768  # max characters of input text passed to the keyword-extraction prompt
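
# Example env.txt (all values below are illustrative assumptions; 11434 is
# Ollama's default port):
#
#   BASE_URL=http://localhost:11434
#   TEMPERATURE=0.7
#   MODEL=qwen2:7b
#   EMBEDDING_MODEL=nomic-embed-text
#   PG_VECTOR_CONNECTION=postgresql+psycopg://user:password@localhost:5432/vectordb
#   PG_VECTOR_COLLECTION_NAME=econ_docs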

# print("embeddings model: ", EMBEDDING_MODEL)

# ------------------------------------------------------------------------------------------------
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_text_splitters.base import TextSplitter

import logging

logger = logging.getLogger(__name__)


def get_splitter(language: str) -> TextSplitter:
    """Return a text splitter suited to the given language ("zh" for Chinese)."""
    if language == "zh":
        # Local module providing a splitter aware of Chinese punctuation.
        from chinese_recursive_text_splitter import ChineseRecursiveTextSplitter

        return ChineseRecursiveTextSplitter(
            keep_separator=True,
            is_separator_regex=True,
            chunk_size=200,
            chunk_overlap=0,
        )
    else:
        from langchain_text_splitters import RecursiveCharacterTextSplitter

        return RecursiveCharacterTextSplitter(
            chunk_size=200,
            chunk_overlap=20,
        )
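
# Quick sanity check (commented out to keep the script side-effect free; the
# sample sentence is illustrative):
#
#   splitter = get_splitter("en")
#   chunks = splitter.split_text("LangChain splits long documents into chunks. " * 20)
#   print(len(chunks), chunks[0])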


def get_documents_from_web(url: str, splitter: TextSplitter) -> list[Document]:
    """Load a web page and split it into chunks with the given splitter."""
    from langchain_community.document_loaders import WebBaseLoader

    loader = WebBaseLoader(url)
    docs = loader.load()
    split_docs = splitter.split_documents(docs)
    return split_docs
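
# WebBaseLoader fetches the page with requests and parses it with BeautifulSoup;
# recent langchain_community versions log a warning if the USER_AGENT environment
# variable is not set. Example call (commented out; the URL is a placeholder):
#
#   docs = get_documents_from_web("https://example.com/article", get_splitter("en"))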


from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore


def get_vectordb(
    embeddings: Embeddings, connection: str, collection_name: str
) -> VectorStore:
    from langchain_postgres.vectorstores import PGVector

    vectordb = PGVector(
        embeddings=embeddings,
        collection_name=collection_name,
        connection=connection,
        use_jsonb=True,
    )
    return vectordb
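
# The connection string uses the SQLAlchemy/psycopg 3 form, e.g.
# "postgresql+psycopg://user:password@localhost:5432/dbname" (credentials here
# are placeholders). use_jsonb=True stores document metadata as JSONB, which is
# what the $like-style metadata filters below operate on.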


from langchain_core.language_models.llms import BaseLLM


def get_llm(base_url: str, model: str, temperature: float) -> BaseLLM:
    """Build an Ollama-served completion LLM."""
    from langchain_community.llms import Ollama

    llm = Ollama(
        model=model,
        base_url=base_url,
        temperature=temperature,
    )
    return llm
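
# Note: langchain_community's Ollama class is deprecated in newer LangChain
# releases in favor of the langchain-ollama package. A drop-in sketch, assuming
# that package is installed:
#
#   from langchain_ollama import OllamaLLM
#   llm = OllamaLLM(model=model, base_url=base_url, temperature=temperature)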


def get_embeddings(base_url: str, model: str) -> Embeddings:
    """Build an Ollama-served embedding model."""
    from langchain_community.embeddings import OllamaEmbeddings

    embeddings = OllamaEmbeddings(model=model, base_url=base_url)
    return embeddings


def get_keywords_from_text(
    llm: BaseLLM,
    text: str,
) -> str:
    """Extract a single search keyword from text using the LLM.

    Note: this sets the temperature of the passed-in llm to 0 (deterministic
    output) and does not restore it afterwards.
    """
    llm.temperature = 0
    from langchain_core.output_parsers import StrOutputParser

    text = text[:MAX_LENGTH]
    # The text is passed as a proper template variable rather than interpolated
    # with an f-string, so braces in the input cannot break template parsing.
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful assistant that returns the keyword of the user's input. Return the keyword only, without any explanation.",
            ),
            (
                "human",
                "Please return the keyword of text >>> {text} <<< (avoid explaining the original text)",
            ),
        ]
    )

    chain = prompt | llm | StrOutputParser()
    keyword = chain.invoke({"text": text})
    return keyword
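
# Illustrative call (commented out; the output depends on the model, so the
# value shown is an assumption):
#
#   keyword = get_keywords_from_text(llm, "什么是垄断性竞争?")
#   # -> "垄断性竞争" (i.e. "monopolistic competition")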


# ------------------------------------------------------------------------------------------------
llm = get_llm(
    base_url=BASE_URL,
    model=MODEL,
    temperature=float(TEMPERATURE),
)
embeddings = get_embeddings(
    base_url=BASE_URL,
    model=EMBEDDING_MODEL,
)


vectordb = get_vectordb(
    embeddings=embeddings,
    connection=PG_VECTOR_CONNECTION,
    collection_name=PG_VECTOR_COLLECTION_NAME,
)
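
# This script assumes the collection has already been populated. A minimal
# ingestion sketch with a placeholder URL (add_documents embeds and stores the
# chunks along with their metadata):
#
#   splitter = get_splitter("zh")
#   vectordb.add_documents(get_documents_from_web("https://example.com/econ", splitter))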

# ------------------------------------------------------------------------------------------------
query = "什么是垄断性竞争? 它有什么特点"  # "What is monopolistic competition? What are its characteristics?"
keyword = get_keywords_from_text(llm, query)
metadata_filter = {
    "title": {"$like": f"%{keyword}%"},
}

docs = vectordb.similarity_search(query, k=5, filter=metadata_filter)
if len(docs) < 5:
    # The metadata filter was too restrictive; PGVector has no Chroma-style
    # where_document/$contains content filter, so fall back to an unfiltered search.
    docs = vectordb.similarity_search(query, k=5)
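
# An equivalent retriever-style formulation (a sketch; as_retriever is part of
# the standard VectorStore interface):
#
#   retriever = vectordb.as_retriever(search_kwargs={"k": 5, "filter": metadata_filter})
#   docs = retriever.invoke(query)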

# for doc in docs:
#     print(doc.metadata["source"])

prompt = ChatPromptTemplate.from_template(
    """
Answer the user's question using the given context; ignore the context if it is not relevant.
Context: {context}
Question: {input}
"""
)

chain = create_stuff_documents_chain(
    llm=llm,
    prompt=prompt,
)

response = chain.invoke(
    {
        "input": query,
        "context": docs,
    }
)

print(query)
print("\n")
print(response)
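
# For incremental output, the same chain can be streamed (a sketch; chunk size
# depends on the model server):
#
#   for chunk in chain.stream({"input": query, "context": docs}):
#       print(chunk, end="", flush=True)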