# ------------------------------------------------------------------------------------------------
from dotenv import load_dotenv
import os
load_dotenv(dotenv_path="env.txt")
BASE_URL = os.getenv("BASE_URL")
TEMPERATURE = os.getenv("TEMPERATURE")
MODEL = os.getenv("MODEL")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
PG_VECTOR_CONNECTION = os.getenv("PG_VECTOR_CONNECTION")
PG_VECTOR_COLLECTION_NAME = os.getenv("PG_VECTOR_COLLECTION_NAME")
MAX_LENGTH = 768
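# A sketch of what env.txt might contain (illustrative values, not from the
# original; http://localhost:11434 is Ollama's default URL, and
# langchain_postgres expects a postgresql+psycopg connection string):
#
#   BASE_URL=http://localhost:11434
#   TEMPERATURE=0.7
#   MODEL=llama3
#   EMBEDDING_MODEL=nomic-embed-text
#   PG_VECTOR_CONNECTION=postgresql+psycopg://user:password@localhost:5432/dbname
#   PG_VECTOR_COLLECTION_NAME=my_collection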
# print("embeddings model: ", EMBEDDING_MODEL)
# ------------------------------------------------------------------------------------------------
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_text_splitters.base import TextSplitter
import logging
logger = logging.getLogger(__name__)
def get_splitter(language: str) -> TextSplitter:
    if language == "zh":
        from chinese_recursive_text_splitter import ChineseRecursiveTextSplitter

        # Chinese text has no spaces between words, so split on Chinese
        # punctuation (via regex separators) rather than whitespace.
        text_splitter = ChineseRecursiveTextSplitter(
            keep_separator=True,
            is_separator_regex=True,
            chunk_size=200,
            chunk_overlap=0,
        )
        return text_splitter
    else:
        from langchain.text_splitter import RecursiveCharacterTextSplitter

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=200,
            chunk_overlap=20,
        )
        return text_splitter
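# Example usage (hypothetical input string; actual chunk boundaries depend on
# the splitter's separators):
#
#   splitter = get_splitter("zh")
#   chunks = splitter.split_text("垄断性竞争是一种市场结构。它结合了垄断和完全竞争的特点。")
#   # -> a list of strings, each at most ~200 characters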
def get_documents_from_web(url: str, splitter: TextSplitter):
    from langchain_community.document_loaders import WebBaseLoader

    # Load the page, then split it into chunks sized for embedding.
    loader = WebBaseLoader(url)
    docs = loader.load()
    split_docs = splitter.split_documents(docs)
    return split_docs
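# Example usage (placeholder URL): fetch a page and split it into chunks.
#
#   docs = get_documents_from_web("https://example.com/some-article", get_splitter("zh"))
#   # -> a list of Document objects; WebBaseLoader records the page URL
#   #    under metadata["source"]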
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
def get_vectordb(
    embeddings: Embeddings, connection: str, collection_name: str
) -> VectorStore:
    from langchain_postgres.vectorstores import PGVector

    vectordb = PGVector(
        embeddings=embeddings,
        collection_name=collection_name,
        connection=connection,
        use_jsonb=True,
    )
    return vectordb
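# Ingestion is not shown in this script; the collection is assumed to already
# hold documents. A sketch of loading it with the helpers above:
#
#   docs = get_documents_from_web("https://example.com/some-article", get_splitter("zh"))
#   vectordb.add_documents(docs)  # embeds each chunk and stores it in Postgres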
from langchain_core.language_models.llms import BaseLLM
def get_llm(base_url: str, model: str, temperature: float) -> BaseLLM:
    from langchain_community.llms import Ollama

    llm = Ollama(
        model=model,
        base_url=base_url,
        temperature=temperature,
    )
    return llm
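# Quick smoke test (hypothetical model name and prompt):
#
#   llm = get_llm("http://localhost:11434", "llama3", 0.7)
#   print(llm.invoke("Say hello in one word."))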
def get_embeddings(base_url: str, model: str) -> Embeddings:
    from langchain_community.embeddings import OllamaEmbeddings

    embeddings = OllamaEmbeddings(model=model, base_url=base_url)
    return embeddings
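# Embeddings can be sanity-checked directly; embed_query returns one vector
# (hypothetical model name):
#
#   embeddings = get_embeddings("http://localhost:11434", "nomic-embed-text")
#   vector = embeddings.embed_query("monopolistic competition")
#   print(len(vector))  # dimensionality depends on the embedding model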
def get_keywords_from_text(
    llm: BaseLLM,
    text: str,
):
    from langchain_core.output_parsers import StrOutputParser
    from langchain_core.prompts import ChatPromptTemplate

    # Keyword extraction should be deterministic. Save and restore the
    # temperature so the shared llm instance is not left at 0 for later chains.
    original_temperature = llm.temperature
    llm.temperature = 0
    # Truncate long inputs so the prompt stays small.
    text = text[:MAX_LENGTH]
    # Use a {text} template variable rather than an f-string: raw interpolation
    # would break the template if the text contained braces, and the value
    # passed to invoke() below would otherwise be silently ignored.
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful assistant that returns the keyword of the user input. You must return the keyword only, without any explanations.",
            ),
            (
                "human",
                "Please return the keyword of the text >>> {text} <<< (avoid explaining the original text)",
            ),
        ]
    )
    chain = prompt | llm | StrOutputParser()
    keyword = chain.invoke({"text": text})
    llm.temperature = original_temperature
    return keyword
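# Example (the extracted keyword depends on the model; the shown output is a
# plausible guess, not a recorded result):
#
#   keyword = get_keywords_from_text(llm, "什么是垄断性竞争? 它有什么特点")
#   # -> e.g. "垄断性竞争"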
# ------------------------------------------------------------------------------------------------
llm = get_llm(
    base_url=BASE_URL,
    model=MODEL,
    temperature=float(TEMPERATURE),
)
embeddings = get_embeddings(
    base_url=BASE_URL,
    model=EMBEDDING_MODEL,
)
vectordb = get_vectordb(
    embeddings=embeddings,
    connection=PG_VECTOR_CONNECTION,
    collection_name=PG_VECTOR_COLLECTION_NAME,
)
# ------------------------------------------------------------------------------------------------
query = "什么是垄断性竞争? 它有什么特点"  # "What is monopolistic competition? What are its characteristics?"
keyword = get_keywords_from_text(llm, query)
# First try restricting by metadata: PGVector supports operators such as
# $like on JSONB metadata fields (here, matching the keyword in the title).
metadata_filter = {
    "title": {"$like": f"%{keyword}%"},
}
docs = vectordb.similarity_search(query, k=5, filter=metadata_filter)
if len(docs) < 5:
    # PGVector has no Chroma-style where_document / $contains filter on page
    # content, so fall back to an unfiltered similarity search instead.
    docs = vectordb.similarity_search(query, k=5)
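# A case-insensitive variant of the metadata filter, useful if the stored
# titles have inconsistent casing (PGVector also supports $ilike):
#
#   metadata_filter = {"title": {"$ilike": f"%{keyword}%"}}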
# for doc in docs:
# print(doc.metadata["source"])
prompt = ChatPromptTemplate.from_template(
    """
Answer the user's question using the given context; ignore the context if it is not relevant.
Context: {context}
Question: {input}
"""
)
# create_stuff_documents_chain formats all of the retrieved documents into
# the prompt's {context} variable and makes a single LLM call.
chain = create_stuff_documents_chain(
    llm=llm,
    prompt=prompt,
)
response = chain.invoke(
    {
        "input": query,
        "context": docs,
    }
)
print(query)
print("\n")
print(response)