Skip to content

Commit 8e59780

Browse files
committed
Add custom weights to rag indexes
1 parent 515518c commit 8e59780

File tree

1 file changed

+46
-2
lines changed

1 file changed

+46
-2
lines changed

ols/src/rag_index/index_loader.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@
99

1010
logger = logging.getLogger(__name__)
1111

12+
13+
SCORE_DILUTION_WEIGHT = 0.05
14+
SCORE_DILUTION_DEPTH = 2
15+
16+
1217
# delay import of llama_index dependencies
1318
BaseIndex = Any
1419
BaseRetriever = Any
@@ -42,6 +47,41 @@ def load_llama_index_deps() -> None:
4247
from llama_index.core.retrievers import BaseRetriever, QueryFusionRetriever
4348
from llama_index.vector_stores.faiss import FaissVectorStore
4449

50+
# Set custom query fusion class to override existing normalized weighted score.
51+
global QueryFusionRetrieverCustom
52+
53+
class QueryFusionRetrieverCustom(QueryFusionRetriever):
54+
"""Custom query fusion retriever."""
55+
56+
def __init__(self, **kwargs):
57+
"""Initialize custom query fusion class."""
58+
super().__init__(**kwargs)
59+
60+
retrievers = kwargs.get("retrievers", None)
61+
retriever_weights = kwargs.get("retriever_weights", None)
62+
if not retriever_weights:
63+
retriever_weights = [1.0] * len(retrievers)
64+
self._custom_retriever_weights = retriever_weights
65+
66+
def _simple_fusion(self, results):
67+
"""Override internal method and apply weighted score."""
68+
# Overriding one of the method is okay, we just need to add our custom logic.
69+
all_nodes = {}
70+
for i, nodes_with_scores in enumerate(results.values()):
71+
for j, node_with_score in enumerate(nodes_with_scores):
72+
node_index_id = f"{i}_{j}"
73+
all_nodes[node_index_id] = node_with_score
74+
# weighted_score = node_with_score.score * self._custom_retriever_weights[i]
75+
# Uncomment above and delete below, if we decide weights to be set from config.
76+
weighted_score = node_with_score.score * (
77+
1 - min(i, SCORE_DILUTION_DEPTH - 1) * SCORE_DILUTION_WEIGHT
78+
)
79+
all_nodes[node_index_id].score = weighted_score
80+
81+
return sorted(
82+
all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True
83+
)
84+
4585

4686
class IndexLoader:
4787
"""Load index from local file storage."""
@@ -146,12 +186,16 @@ def get_retriever(
146186
and self._retriever.similarity_top_k == similarity_top_k
147187
):
148188
return self._retriever
149-
retriever = QueryFusionRetriever(
150-
[
189+
190+
# Note: we are using a custom retriever, based on our need
191+
retriever = QueryFusionRetrieverCustom(
192+
retrievers=[
151193
index.as_retriever(similarity_top_k=similarity_top_k)
152194
for index in self._indexes
153195
],
154196
similarity_top_k=similarity_top_k,
197+
retriever_weights=None, # Setting as None, until this gets added to config
198+
mode="simple", # Don't modify this as we are adding our own logic
155199
num_queries=1, # set this to 1 to disable query generation
156200
use_async=False,
157201
verbose=False,

0 commit comments

Comments
 (0)