Skip to content

Commit 427e1e4

Browse files
committed
Add custom weights to rag indexes
1 parent 515518c commit 427e1e4

File tree

1 file changed

+41
-2
lines changed

1 file changed

+41
-2
lines changed

ols/src/rag_index/index_loader.py

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,15 @@
99

1010
logger = logging.getLogger(__name__)
1111

12+
13+
SCORE_DILUTION_WEIGHT = 0.05
14+
SCORE_DILUTION_DEPTH = 2
15+
16+
1217
# delay import of llama_index dependencies
1318
BaseIndex = Any
1419
BaseRetriever = Any
20+
QueryFusionRetriever = Any
1521

1622

1723
# NOTE: Loading/importing something from llama_index bumps memory
@@ -43,6 +49,35 @@ def load_llama_index_deps() -> None:
4349
from llama_index.vector_stores.faiss import FaissVectorStore
4450

4551

52+
class QueryFusionRetrieverCustom(QueryFusionRetriever):
53+
def __init__(self, **kwargs):
54+
"""Initialize custom query fusion class."""
55+
super().__init__(**kwargs)
56+
57+
retrievers = kwargs.get("retrievers", None)
58+
retriever_weights = kwargs.get("retriever_weights", None)
59+
if not retriever_weights:
60+
retriever_weights = [1.0] * len(retrievers)
61+
self._custom_retriever_weights = retriever_weights
62+
63+
def _simple_fusion(self, results):
64+
"""Override internal method and apply weighted score."""
65+
# Overriding one of the method is okay, we just need to add our custom logic.
66+
all_nodes = {}
67+
for i, nodes_with_scores in enumerate(results.values()):
68+
for j, node_with_score in enumerate(nodes_with_scores):
69+
node_index_id = f"{i}_{j}"
70+
all_nodes[node_index_id] = node_with_score
71+
# weighted_score = node_with_score.score * self._custom_retriever_weights[i]
72+
# Uncomment above and delete below, if we decide weights to be set from config.
73+
weighted_score = node_with_score.score * (
74+
1 - min(i, SCORE_DILUTION_DEPTH - 1) * SCORE_DILUTION_WEIGHT
75+
)
76+
all_nodes[node_index_id].score = weighted_score
77+
78+
return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)
79+
80+
4681
class IndexLoader:
4782
"""Load index from local file storage."""
4883

@@ -146,12 +181,16 @@ def get_retriever(
146181
and self._retriever.similarity_top_k == similarity_top_k
147182
):
148183
return self._retriever
149-
retriever = QueryFusionRetriever(
150-
[
184+
185+
# Note: we are using a custom retriever, based on our need
186+
retriever = QueryFusionRetrieverCustom(
187+
retrievers=[
151188
index.as_retriever(similarity_top_k=similarity_top_k)
152189
for index in self._indexes
153190
],
154191
similarity_top_k=similarity_top_k,
192+
retriever_weights=None, # Setting as None, until this gets added to config
193+
mode="simple", # Don't modify this as we are adding our own logic
155194
num_queries=1, # set this to 1 to disable query generation
156195
use_async=False,
157196
verbose=False,

0 commit comments

Comments
 (0)