魔戒检索器(Merger Retriever)

魔戒检索器,也称为MergerRetriever,以一个检索器列表作为输入,并将它们的get_relevant_documents()方法的结果合并为一个列表。合并的结果将是与查询相关且已由不同的检索器进行排名的文档列表。

MergerRetriever类可以在许多方面提高文档检索的准确性。首先,它可以组合多个检索器的结果,有助于减少结果中的偏见风险。其次,它可以对不同检索器的结果进行排名,确保最相关的文档首先返回。

import os
import chromadb
from langchain.retrievers.merger_retriever import MergerRetriever
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.embeddings import OpenAIEmbeddings
from langchain.document_transformers import EmbeddingsRedundantFilter
from langchain.retrievers.document_compressors import DocumentCompressorPipeline
from langchain.retrievers import ContextualCompressionRetriever

# 获取3种不同的嵌入向量
all_mini = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
multi_qa_mini = HuggingFaceEmbeddings(model_name="multi-qa-MiniLM-L6-dot-v1")
filter_embeddings = OpenAIEmbeddings()

ABS_PATH = os.path.dirname(os.path.abspath(__file__))
DB_DIR = os.path.join(ABS_PATH, "db")

# 实例化两个不同的ChromaDB索引,每个索引都有不同的嵌入向量。
client_settings = chromadb.config.Settings(
    chroma_db_impl="duckdb+parquet",
    persist_directory=DB_DIR,
    anonymized_telemetry=False,
)
db_all = Chroma(
    collection_name="project_store_all",
    persist_directory=DB_DIR,
    client_settings=client_settings,
    embedding_function=all_mini,
)
db_multi_qa = Chroma(
    collection_name="project_store_multi",
    persist_directory=DB_DIR,
    client_settings=client_settings,
    embedding_function=multi_qa_mini,
)

# 使用不同的嵌入向量和不同的搜索类型定义两个不同的检索器。
retriever_all = db_all.as_retriever(
    search_type="similarity", search_kwargs={"k": 5, "include_metadata": True}
)
retriever_multi_qa = db_multi_qa.as_retriever(
    search_type="mmr", search_kwargs={"k": 5, "include_metadata": True}
)

# 魔戒检索器将保存两个检索器的输出,并且可以像其他检索器一样用于不同类型的链式操作。
lotr = MergerRetriever(retrievers=[retriever_all, retriever_multi_qa])

从合并的检索器中删除冗余结果。


# 我们可以使用另一个嵌入向量从两个检索器中删除冗余的结果。
# 在不同的步骤使用多个嵌入向量可以帮助减少偏见。
filter = EmbeddingsRedundantFilter(embeddings=filter_embeddings)
pipeline = DocumentCompressorPipeline(transformers=[filter])
compression_retriever = ContextualCompressionRetriever(
    base_compressor=pipeline, base_retriever=lotr
)
Last Updated:
Contributors: 刘强