Skip to content

Commit 1d305c5

Browse files
committed
add memory limit for aggregations
Introduce AggregationLimits to set a memory consumption limit and a bucket limit. The memory limit is checked during aggregation; the bucket limit is checked before returning the result of the aggregation request.
1 parent 8459efa commit 1d305c5

15 files changed

+390
-149
lines changed

examples/aggregation.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ fn main() -> tantivy::Result<()> {
192192
//
193193

194194
let agg_req: Aggregations = serde_json::from_str(agg_req_str)?;
195-
let collector = AggregationCollector::from_aggs(agg_req, None);
195+
let collector = AggregationCollector::from_aggs(agg_req, Default::default());
196196

197197
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
198198
let res2: Value = serde_json::to_value(agg_res)?;
@@ -239,7 +239,7 @@ fn main() -> tantivy::Result<()> {
239239
.into_iter()
240240
.collect();
241241

242-
let collector = AggregationCollector::from_aggs(agg_req, None);
242+
let collector = AggregationCollector::from_aggs(agg_req, Default::default());
243243
// We use the `AllQuery` which will pass all documents to the AggregationCollector.
244244
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
245245

@@ -287,7 +287,7 @@ fn main() -> tantivy::Result<()> {
287287

288288
let agg_req: Aggregations = serde_json::from_str(agg_req_str)?;
289289

290-
let collector = AggregationCollector::from_aggs(agg_req, None);
290+
let collector = AggregationCollector::from_aggs(agg_req, Default::default());
291291

292292
let agg_res: AggregationResults = searcher.search(&AllQuery, &collector).unwrap();
293293
let res: Value = serde_json::to_value(agg_res)?;

src/aggregation/agg_limits.rs

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
use std::collections::HashMap;
2+
use std::sync::atomic::AtomicU64;
3+
use std::sync::Arc;
4+
5+
use super::collector::DEFAULT_MEMORY_LIMIT;
6+
use super::{AggregationError, DEFAULT_BUCKET_LIMIT};
7+
use crate::TantivyError;
8+
9+
/// An estimate for memory consumption
pub trait MemoryConsumption {
    /// Returns the estimated number of bytes held by `self`.
    fn memory_consumption(&self) -> usize;
}

impl<K, V, S> MemoryConsumption for HashMap<K, V, S> {
    /// Estimates the size of the map's table from its capacity.
    ///
    /// The estimate is based on `capacity()` rather than `len()`, because the map keeps
    /// storage reserved for every entry it can hold without reallocating, not only for the
    /// entries currently stored. One extra byte per slot is added as an approximation of the
    /// table's per-slot bookkeeping.
    ///
    /// Only `size_of::<K>()` and `size_of::<V>()` are counted; heap data owned by keys or
    /// values (for example the contents of a `String`) is not included.
    fn memory_consumption(&self) -> usize {
        let bytes_per_slot = std::mem::size_of::<K>() + std::mem::size_of::<V>() + 1;
        bytes_per_slot * self.capacity()
    }
}
20+
21+
/// Limits applied to one aggregation request.
///
/// Holds the aggregation memory limit after which the request fails (defaults to
/// DEFAULT_MEMORY_LIMIT, 500MB) and the limit on the number of returned buckets.
/// The limit is shared by all SegmentCollectors: cloning copies the two limits and yields
/// another handle to the same memory counter, not a fresh counter.
#[derive(Clone)]
pub struct AggregationLimits {
    /// The counter which is shared between the aggregations for one request.
    memory_consumption: Arc<AtomicU64>,
    /// The memory_limit in bytes
    memory_limit: u64,
    /// The maximum number of buckets _returned_
    /// This is not counting intermediate buckets.
    bucket_limit: u32,
}
41+
42+
impl Default for AggregationLimits {
43+
fn default() -> Self {
44+
Self {
45+
memory_consumption: Default::default(),
46+
memory_limit: DEFAULT_MEMORY_LIMIT,
47+
bucket_limit: DEFAULT_BUCKET_LIMIT,
48+
}
49+
}
50+
}
51+
52+
impl AggregationLimits {
53+
/// *memory_limit*
54+
/// memory_limit is defined in bytes.
55+
/// Aggregation fails when the estimated memory consumption of the aggregation is higher than
56+
/// memory_limit.
57+
/// memory_limit will default to `DEFAULT_MEMORY_LIMIT` (500MB)
58+
///
59+
/// *bucket_limit*
60+
/// Limits the maximum number of buckets returned from an aggregation request.
61+
/// bucket_limit will default to `DEFAULT_BUCKET_LIMIT` (65000)
62+
pub fn new(memory_limit: Option<u64>, bucket_limit: Option<u32>) -> Self {
63+
Self {
64+
memory_consumption: Default::default(),
65+
memory_limit: memory_limit.unwrap_or(DEFAULT_MEMORY_LIMIT),
66+
bucket_limit: bucket_limit.unwrap_or(DEFAULT_BUCKET_LIMIT),
67+
}
68+
}
69+
pub(crate) fn validate_memory_consumption(&self) -> crate::Result<()> {
70+
if self.get_memory_consumed() > self.memory_limit {
71+
return Err(TantivyError::AggregationError(
72+
AggregationError::MemoryExceeded {
73+
limit: self.memory_limit,
74+
current: self.get_memory_consumed(),
75+
},
76+
));
77+
}
78+
Ok(())
79+
}
80+
pub(crate) fn add_memory_consumed(&self, num_bytes: u64) {
81+
self.memory_consumption
82+
.fetch_add(num_bytes, std::sync::atomic::Ordering::Relaxed);
83+
}
84+
pub fn get_memory_consumed(&self) -> u64 {
85+
self.memory_consumption
86+
.load(std::sync::atomic::Ordering::Relaxed)
87+
}
88+
pub fn get_bucket_limit(&self) -> u32 {
89+
self.bucket_limit
90+
}
91+
}

src/aggregation/agg_req_with_accessor.rs

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
//! This will enhance the request tree with access to the fastfield and metadata.
22
3-
use std::rc::Rc;
4-
use std::sync::atomic::AtomicU32;
5-
63
use columnar::{Column, ColumnType, StrColumn};
74

85
use super::agg_req::{Aggregation, Aggregations, BucketAggregationType, MetricAggregation};
@@ -13,7 +10,7 @@ use super::metric::{
1310
AverageAggregation, CountAggregation, MaxAggregation, MinAggregation, StatsAggregation,
1411
SumAggregation,
1512
};
16-
use super::segment_agg_result::BucketCount;
13+
use super::segment_agg_result::AggregationLimits;
1714
use super::VecWithNames;
1815
use crate::{SegmentReader, TantivyError};
1916

@@ -45,16 +42,15 @@ pub struct BucketAggregationWithAccessor {
4542
pub(crate) field_type: ColumnType,
4643
pub(crate) bucket_agg: BucketAggregationType,
4744
pub(crate) sub_aggregation: AggregationsWithAccessor,
48-
pub(crate) bucket_count: BucketCount,
45+
pub(crate) limits: AggregationLimits,
4946
}
5047

5148
impl BucketAggregationWithAccessor {
5249
fn try_from_bucket(
5350
bucket: &BucketAggregationType,
5451
sub_aggregation: &Aggregations,
5552
reader: &SegmentReader,
56-
bucket_count: Rc<AtomicU32>,
57-
max_bucket_count: u32,
53+
limits: AggregationLimits,
5854
) -> crate::Result<BucketAggregationWithAccessor> {
5955
let mut str_dict_column = None;
6056
let (accessor, field_type) = match &bucket {
@@ -82,15 +78,11 @@ impl BucketAggregationWithAccessor {
8278
sub_aggregation: get_aggs_with_accessor_and_validate(
8379
&sub_aggregation,
8480
reader,
85-
bucket_count.clone(),
86-
max_bucket_count,
81+
&limits.clone(),
8782
)?,
8883
bucket_agg: bucket.clone(),
8984
str_dict_column,
90-
bucket_count: BucketCount {
91-
bucket_count,
92-
max_bucket_count,
93-
},
85+
limits,
9486
})
9587
}
9688
}
@@ -130,8 +122,7 @@ impl MetricAggregationWithAccessor {
130122
pub(crate) fn get_aggs_with_accessor_and_validate(
131123
aggs: &Aggregations,
132124
reader: &SegmentReader,
133-
bucket_count: Rc<AtomicU32>,
134-
max_bucket_count: u32,
125+
limits: &AggregationLimits,
135126
) -> crate::Result<AggregationsWithAccessor> {
136127
let mut metrics = vec![];
137128
let mut buckets = vec![];
@@ -143,8 +134,7 @@ pub(crate) fn get_aggs_with_accessor_and_validate(
143134
&bucket.bucket_agg,
144135
&bucket.sub_aggregation,
145136
reader,
146-
Rc::clone(&bucket_count),
147-
max_bucket_count,
137+
limits.clone(),
148138
)?,
149139
)),
150140
Aggregation::Metric(metric) => metrics.push((

src/aggregation/agg_result.rs

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use super::agg_req::BucketAggregationInternal;
1111
use super::bucket::GetDocCount;
1212
use super::intermediate_agg_result::{IntermediateBucketResult, IntermediateMetricResult};
1313
use super::metric::{SingleMetricResult, Stats};
14+
use super::segment_agg_result::AggregationLimits;
1415
use super::Key;
1516
use crate::TantivyError;
1617

@@ -19,6 +20,13 @@ use crate::TantivyError;
1920
pub struct AggregationResults(pub FxHashMap<String, AggregationResult>);
2021

2122
impl AggregationResults {
23+
pub(crate) fn get_bucket_count(&self) -> u64 {
24+
self.0
25+
.values()
26+
.map(|agg| agg.get_bucket_count())
27+
.sum::<u64>()
28+
}
29+
2230
pub(crate) fn get_value_from_aggregation(
2331
&self,
2432
name: &str,
@@ -47,6 +55,13 @@ pub enum AggregationResult {
4755
}
4856

4957
impl AggregationResult {
58+
pub(crate) fn get_bucket_count(&self) -> u64 {
59+
match self {
60+
AggregationResult::BucketResult(bucket) => bucket.get_bucket_count(),
61+
AggregationResult::MetricResult(_) => 0,
62+
}
63+
}
64+
5065
pub(crate) fn get_value_from_aggregation(
5166
&self,
5267
_name: &str,
@@ -153,9 +168,28 @@ pub enum BucketResult {
153168
}
154169

155170
impl BucketResult {
156-
pub(crate) fn empty_from_req(req: &BucketAggregationInternal) -> crate::Result<Self> {
171+
pub(crate) fn get_bucket_count(&self) -> u64 {
172+
match self {
173+
BucketResult::Range { buckets } => {
174+
buckets.iter().map(|bucket| bucket.get_bucket_count()).sum()
175+
}
176+
BucketResult::Histogram { buckets } => {
177+
buckets.iter().map(|bucket| bucket.get_bucket_count()).sum()
178+
}
179+
BucketResult::Terms {
180+
buckets,
181+
sum_other_doc_count: _,
182+
doc_count_error_upper_bound: _,
183+
} => buckets.iter().map(|bucket| bucket.get_bucket_count()).sum(),
184+
}
185+
}
186+
187+
pub(crate) fn empty_from_req(
188+
req: &BucketAggregationInternal,
189+
limits: &AggregationLimits,
190+
) -> crate::Result<Self> {
157191
let empty_bucket = IntermediateBucketResult::empty_from_req(&req.bucket_agg);
158-
empty_bucket.into_final_bucket_result(req)
192+
empty_bucket.into_final_bucket_result(req, limits)
159193
}
160194
}
161195

@@ -170,6 +204,15 @@ pub enum BucketEntries<T> {
170204
HashMap(FxHashMap<String, T>),
171205
}
172206

207+
impl<T> BucketEntries<T> {
208+
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = &T> + 'a> {
209+
match self {
210+
BucketEntries::Vec(vec) => Box::new(vec.iter()),
211+
BucketEntries::HashMap(map) => Box::new(map.values()),
212+
}
213+
}
214+
}
215+
173216
/// This is the default entry for a bucket, which contains a key, count, and optionally
174217
/// sub-aggregations.
175218
///
@@ -209,6 +252,11 @@ pub struct BucketEntry {
209252
/// Sub-aggregations in this bucket.
210253
pub sub_aggregation: AggregationResults,
211254
}
255+
impl BucketEntry {
256+
pub(crate) fn get_bucket_count(&self) -> u64 {
257+
1 + self.sub_aggregation.get_bucket_count()
258+
}
259+
}
212260
impl GetDocCount for &BucketEntry {
213261
fn doc_count(&self) -> u64 {
214262
self.doc_count
@@ -272,3 +320,8 @@ pub struct RangeBucketEntry {
272320
#[serde(skip_serializing_if = "Option::is_none")]
273321
pub to_as_string: Option<String>,
274322
}
323+
impl RangeBucketEntry {
324+
pub(crate) fn get_bucket_count(&self) -> u64 {
325+
1 + self.sub_aggregation.get_bucket_count()
326+
}
327+
}

0 commit comments

Comments
 (0)