Skip to content

Commit a0e7d0a

Browse files
committed
Apply weights on index retrievers
1 parent fb735b5 commit a0e7d0a

File tree

2 files changed

+61
-11
lines changed

2 files changed

+61
-11
lines changed

ols/src/rag_index/index_loader.py

Lines changed: 47 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,28 @@ def load_llama_index_deps() -> None:
4343
from llama_index.vector_stores.faiss import FaissVectorStore
4444

4545

46+
def calculate_retrievers_weights(
    number_of_retrievers: int, *, first_weight: float = 0.6
) -> list[float]:
    """Calculate fusion weights for a list of retrievers.

    The first retriever receives ``first_weight`` and the remaining weight
    (``1.0 - first_weight``) is split evenly between all other retrievers.
    With the default ``first_weight`` the first retriever always gets the
    largest share.

    Args:
        number_of_retrievers: How many retrievers need a weight.
        first_weight: Weight given to the first retriever when there is
            more than one retriever. Must be within ``[0.0, 1.0]``.

    Returns:
        One weight per retriever, first retriever first. An empty list is
        returned when ``number_of_retrievers`` is zero or negative, and
        ``[1.0]`` when there is exactly one retriever.

    Raises:
        ValueError: If ``first_weight`` is outside ``[0.0, 1.0]``.
    """
    if not 0.0 <= first_weight <= 1.0:
        raise ValueError(
            f"first_weight must be between 0.0 and 1.0, got {first_weight}"
        )

    # A negative count is treated like "no retrievers" rather than falling
    # through to the arithmetic below and producing a meaningless list.
    if number_of_retrievers <= 0:
        logger.warning(
            "Number of retrievers is %d. Returning empty weights list.",
            number_of_retrievers,
        )
        return []

    # A single retriever takes the whole weight; first_weight does not apply.
    if number_of_retrievers == 1:
        return [1.0]

    # Whatever the first retriever does not take is shared equally.
    other_count = number_of_retrievers - 1
    other_weight = (1.0 - first_weight) / other_count

    return [first_weight] + [other_weight] * other_count
66+
67+
4668
class IndexLoader:
4769
"""Load index from local file storage."""
4870

@@ -146,15 +168,30 @@ def get_retriever(
146168
and self._retriever.similarity_top_k == similarity_top_k
147169
):
148170
return self._retriever
149-
retriever = QueryFusionRetriever(
150-
[
151-
index.as_retriever(similarity_top_k=similarity_top_k)
152-
for index in self._indexes
153-
],
154-
similarity_top_k=similarity_top_k,
155-
num_queries=1, # set this to 1 to disable query generation
156-
use_async=False,
157-
verbose=False,
158-
)
171+
172+
retrievers = [
173+
index.as_retriever(similarity_top_k=similarity_top_k)
174+
for index in self._indexes
175+
]
176+
177+
if len(retrievers) == 1:
178+
retriever = QueryFusionRetriever(
179+
retrievers=retrievers,
180+
similarity_top_k=similarity_top_k,
181+
num_queries=1, # set this to 1 to disable query generation
182+
use_async=False,
183+
verbose=False,
184+
)
185+
else:
186+
# set weights and mode in case of multiple retrievers/indices
187+
retriever = QueryFusionRetriever(
188+
retrievers=retrievers,
189+
retriever_weights=calculate_retrievers_weights(len(retrievers)),
190+
mode="relative_score", # relative score fusion mode to apply weights
191+
similarity_top_k=similarity_top_k,
192+
num_queries=1, # set this to 1 to disable query generation
193+
use_async=False,
194+
verbose=False,
195+
)
159196
self._retriever = retriever
160197
return self._retriever

tests/unit/rag_index/test_index_loader.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@
55

66
from ols import config
77
from ols.app.models.config import ReferenceContent, ReferenceContentIndex
8-
from ols.src.rag_index.index_loader import IndexLoader
8+
from ols.src.rag_index.index_loader import IndexLoader, calculate_retrievers_weights
99
from tests.mock_classes.mock_llama_index import MockLlamaIndex
1010

1111

@@ -69,3 +69,16 @@ def test_index_loader():
6969

7070
assert len(indexes) == 1
7171
assert isinstance(indexes[0], MockLlamaIndex)
72+
73+
74+
def test_calculate_retrievers_weights():
    """Check the weights returned for zero, one, two and three retrievers."""
    # retriever count -> weights the function is expected to return
    expected_by_count = {
        0: [],
        1: [1.0],
        2: [0.6, 0.4],
        3: [0.6, 0.2, 0.2],
    }

    for retriever_count, expected_weights in expected_by_count.items():
        assert calculate_retrievers_weights(retriever_count) == expected_weights

0 commit comments

Comments
 (0)