|
9 | 9 |
|
10 | 10 | logger = logging.getLogger(__name__)
|
11 | 11 |
|
| 12 | + |
| 13 | +SCORE_DILUTION_WEIGHT = 0.05 |
| 14 | +SCORE_DILUTION_DEPTH = 2 |
| 15 | + |
| 16 | + |
12 | 17 | # delay import of llama_index dependencies
|
13 | 18 | BaseIndex = Any
|
14 | 19 | BaseRetriever = Any
|
@@ -42,6 +47,41 @@ def load_llama_index_deps() -> None:
|
42 | 47 | from llama_index.core.retrievers import BaseRetriever, QueryFusionRetriever
|
43 | 48 | from llama_index.vector_stores.faiss import FaissVectorStore
|
44 | 49 |
|
| 50 | + # Set custom query fusion class to override existing normalized weighted score. |
| 51 | + global QueryFusionRetrieverCustom |
| 52 | + |
| 53 | + class QueryFusionRetrieverCustom(QueryFusionRetriever): |
| 54 | + """Custom query fusion retriever.""" |
| 55 | + |
| 56 | + def __init__(self, **kwargs): |
| 57 | + """Initialize custom query fusion class.""" |
| 58 | + super().__init__(**kwargs) |
| 59 | + |
| 60 | + retrievers = kwargs.get("retrievers", None) |
| 61 | + retriever_weights = kwargs.get("retriever_weights", None) |
| 62 | + if not retriever_weights: |
| 63 | + retriever_weights = [1.0] * len(retrievers) |
| 64 | + self._custom_retriever_weights = retriever_weights |
| 65 | + |
| 66 | + def _simple_fusion(self, results): |
| 67 | + """Override internal method and apply weighted score.""" |
| 68 | + # Overriding one of the method is okay, we just need to add our custom logic. |
| 69 | + all_nodes = {} |
| 70 | + for i, nodes_with_scores in enumerate(results.values()): |
| 71 | + for j, node_with_score in enumerate(nodes_with_scores): |
| 72 | + node_index_id = f"{i}_{j}" |
| 73 | + all_nodes[node_index_id] = node_with_score |
| 74 | + # weighted_score = node_with_score.score * self._custom_retriever_weights[i] |
| 75 | + # Uncomment above and delete below, if we decide weights to be set from config. |
| 76 | + weighted_score = node_with_score.score * ( |
| 77 | + 1 - min(i, SCORE_DILUTION_DEPTH - 1) * SCORE_DILUTION_WEIGHT |
| 78 | + ) |
| 79 | + all_nodes[node_index_id].score = weighted_score |
| 80 | + |
| 81 | + return sorted( |
| 82 | + all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True |
| 83 | + ) |
| 84 | + |
45 | 85 |
|
46 | 86 | class IndexLoader:
|
47 | 87 | """Load index from local file storage."""
|
@@ -146,12 +186,16 @@ def get_retriever(
|
146 | 186 | and self._retriever.similarity_top_k == similarity_top_k
|
147 | 187 | ):
|
148 | 188 | return self._retriever
|
149 |
| - retriever = QueryFusionRetriever( |
150 |
| - [ |
| 189 | + |
| 190 | + # Note: we are using a custom retriever, based on our need |
| 191 | + retriever = QueryFusionRetrieverCustom( |
| 192 | + retrievers=[ |
151 | 193 | index.as_retriever(similarity_top_k=similarity_top_k)
|
152 | 194 | for index in self._indexes
|
153 | 195 | ],
|
154 | 196 | similarity_top_k=similarity_top_k,
|
| 197 | + retriever_weights=None, # Setting as None, until this gets added to config |
| 198 | + mode="simple", # Don't modify this as we are adding our own logic |
155 | 199 | num_queries=1, # set this to 1 to disable query generation
|
156 | 200 | use_async=False,
|
157 | 201 | verbose=False,
|
|
0 commit comments