@@ -43,6 +43,28 @@ def load_llama_index_deps() -> None:
43
43
from llama_index .vector_stores .faiss import FaissVectorStore
44
44
45
45
46
+ def calculate_retrievers_weights (number_of_retrievers : int ) -> list [float ]:
47
+ """Calculate weights for retrievers.
48
+
49
+ Weights always prioritizes the first retriever.
50
+ """
51
+ if number_of_retrievers == 0 :
52
+ logger .warning ("Number of retrievers is 0. Returning empty weights list." )
53
+ return []
54
+ if number_of_retrievers == 1 :
55
+ return [1.0 ]
56
+
57
+ # assign a higher weight to the first retriever
58
+ first_weight = 0.6 # adjust this value as needed
59
+ remaining_weight = 1.0 - first_weight
60
+ other_weight = remaining_weight / (number_of_retrievers - 1 )
61
+
62
+ # generate the weights list
63
+ retriever_weights = [first_weight ] + [other_weight ] * (number_of_retrievers - 1 )
64
+
65
+ return retriever_weights
66
+
67
+
46
68
class IndexLoader :
47
69
"""Load index from local file storage."""
48
70
@@ -146,15 +168,30 @@ def get_retriever(
146
168
and self ._retriever .similarity_top_k == similarity_top_k
147
169
):
148
170
return self ._retriever
149
- retriever = QueryFusionRetriever (
150
- [
151
- index .as_retriever (similarity_top_k = similarity_top_k )
152
- for index in self ._indexes
153
- ],
154
- similarity_top_k = similarity_top_k ,
155
- num_queries = 1 , # set this to 1 to disable query generation
156
- use_async = False ,
157
- verbose = False ,
158
- )
171
+
172
+ retrievers = [
173
+ index .as_retriever (similarity_top_k = similarity_top_k )
174
+ for index in self ._indexes
175
+ ]
176
+
177
+ if len (retrievers ) == 1 :
178
+ retriever = QueryFusionRetriever (
179
+ retrievers = retrievers ,
180
+ similarity_top_k = similarity_top_k ,
181
+ num_queries = 1 , # set this to 1 to disable query generation
182
+ use_async = False ,
183
+ verbose = False ,
184
+ )
185
+ else :
186
+ # set weights and mode in case of multiple retrievers/indices
187
+ retriever = QueryFusionRetriever (
188
+ retrievers = retrievers ,
189
+ retriever_weights = calculate_retrievers_weights (len (retrievers )),
190
+ mode = "relative_score" , # relative score fusion mode to apply weights
191
+ similarity_top_k = similarity_top_k ,
192
+ num_queries = 1 , # set this to 1 to disable query generation
193
+ use_async = False ,
194
+ verbose = False ,
195
+ )
159
196
self ._retriever = retriever
160
197
return self ._retriever
0 commit comments