Skip to content

Commit be229e5

Browse files
dqhl76zhang2014
andauthored
fix: use spill depth partition bit in new final agg (#19626)
* fix: use spill depth partition bit in new final agg * feat: add final agg task statistics * fix: need check_spill for every payload * test: add test * polish statistics.rs * cargo fmt Revert "cargo fmt" This reverts commit c9cccc3. --------- Co-authored-by: Winter Zhang <coswde@gmail.com>
1 parent 141ae60 commit be229e5

File tree

9 files changed

+280
-77
lines changed

9 files changed

+280
-77
lines changed

src/query/expression/src/aggregate/aggregate_hashtable.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,11 @@ impl AggregateHashTable {
7474
Self {
7575
direct_append: false,
7676
current_radix_bits: config.initial_radix_bits,
77-
payload: PartitionedPayload::new(
77+
payload: PartitionedPayload::new_with_start_bit(
7878
group_types,
7979
aggrs,
8080
1 << config.initial_radix_bits,
81+
config.partition_start_bit,
8182
vec![arena],
8283
),
8384
hash_index: HashIndex::new(&config, capacity),
@@ -105,10 +106,11 @@ impl AggregateHashTable {
105106
Self {
106107
direct_append: !need_init_entry,
107108
current_radix_bits: config.initial_radix_bits,
108-
payload: PartitionedPayload::new(
109+
payload: PartitionedPayload::new_with_start_bit(
109110
group_types,
110111
aggrs,
111112
1 << config.initial_radix_bits,
113+
config.partition_start_bit,
112114
vec![arena],
113115
),
114116
hash_index,

src/query/expression/src/aggregate/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ pub struct HashTableConfig {
154154
// Max radix bits across all threads, this is a hint to repartition
155155
pub current_max_radix_bits: Arc<AtomicU64>,
156156
pub initial_radix_bits: u64,
157+
pub partition_start_bit: u64,
157158
pub max_radix_bits: u64,
158159
pub repartition_radix_bits_incr: u64,
159160
pub block_fill_factor: f64,
@@ -167,6 +168,7 @@ impl Default for HashTableConfig {
167168
Self {
168169
current_max_radix_bits: Arc::new(AtomicU64::new(3)),
169170
initial_radix_bits: 3,
171+
partition_start_bit: 0,
170172
max_radix_bits: MAX_RADIX_BITS,
171173
repartition_radix_bits_incr: 2,
172174
block_fill_factor: 1.8,
@@ -211,6 +213,11 @@ impl HashTableConfig {
211213
self
212214
}
213215

216+
pub fn with_partition_start_bit(mut self, partition_start_bit: u64) -> Self {
217+
self.partition_start_bit = partition_start_bit;
218+
self
219+
}
220+
214221
pub fn with_experiment_hash_index(mut self, enable: bool) -> Self {
215222
self.enable_experiment_hash_index = enable;
216223
self

src/query/expression/src/aggregate/partitioned_payload.rs

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,15 @@ struct PartitionMask {
3636

3737
impl PartitionMask {
3838
fn new(partition_count: u64) -> Self {
39+
Self::with_start_bit(partition_count, 0)
40+
}
41+
42+
fn with_start_bit(partition_count: u64, start_bit: u64) -> Self {
3943
let radix_bits = partition_count.trailing_zeros() as u64;
4044
debug_assert_eq!(1 << radix_bits, partition_count);
45+
debug_assert!(start_bit + radix_bits <= 48);
4146

42-
let shift = 48 - radix_bits;
47+
let shift = 48 - start_bit - radix_bits;
4348
let mask = ((1 << radix_bits) - 1) << shift;
4449

4550
Self { mask, shift }
@@ -59,6 +64,7 @@ pub struct PartitionedPayload {
5964

6065
pub arenas: Vec<Arc<Bump>>,
6166

67+
partition_start_bit: u64,
6268
partition_mask: PartitionMask,
6369
}
6470

@@ -71,6 +77,16 @@ impl PartitionedPayload {
7177
aggrs: Vec<AggregateFunctionRef>,
7278
partition_count: u64,
7379
arenas: Vec<Arc<Bump>>,
80+
) -> Self {
81+
Self::new_with_start_bit(group_types, aggrs, partition_count, 0, arenas)
82+
}
83+
84+
pub fn new_with_start_bit(
85+
group_types: Vec<DataType>,
86+
aggrs: Vec<AggregateFunctionRef>,
87+
partition_count: u64,
88+
partition_start_bit: u64,
89+
arenas: Vec<Arc<Bump>>,
7490
) -> Self {
7591
let states_layout = if !aggrs.is_empty() {
7692
Some(get_states_layout(&aggrs).unwrap())
@@ -101,7 +117,8 @@ impl PartitionedPayload {
101117
row_layout,
102118

103119
arenas,
104-
partition_mask: PartitionMask::new(partition_count),
120+
partition_start_bit,
121+
partition_mask: PartitionMask::with_start_bit(partition_count, partition_start_bit),
105122
}
106123
}
107124

@@ -169,11 +186,17 @@ impl PartitionedPayload {
169186
group_types,
170187
aggrs,
171188
arenas,
189+
partition_start_bit,
172190
..
173191
} = self;
174192

175-
let mut new_partition_payload =
176-
PartitionedPayload::new(group_types, aggrs, new_partition_count as u64, arenas);
193+
let mut new_partition_payload = PartitionedPayload::new_with_start_bit(
194+
group_types,
195+
aggrs,
196+
new_partition_count as u64,
197+
partition_start_bit,
198+
arenas,
199+
);
177200

178201
state.clear();
179202
for payload in payloads.into_iter() {
@@ -184,7 +207,9 @@ impl PartitionedPayload {
184207
}
185208

186209
pub fn combine(&mut self, other: PartitionedPayload, state: &mut PayloadFlushState) {
187-
if other.partition_count() == self.partition_count() {
210+
if other.partition_count() == self.partition_count()
211+
&& other.partition_start_bit == self.partition_start_bit
212+
{
188213
for (l, r) in self.payloads.iter_mut().zip(other.payloads.into_iter()) {
189214
l.combine(r);
190215
}
@@ -293,3 +318,19 @@ impl PartitionedPayload {
293318
self.payloads.iter().map(|x| x.memory_size()).sum()
294319
}
295320
}
321+
322+
#[cfg(test)]
323+
mod tests {
324+
use super::PartitionMask;
325+
326+
#[test]
327+
fn test_partition_mask_with_start_bit() {
328+
let top_bit_mask = PartitionMask::new(2);
329+
assert_eq!(top_bit_mask.index(1_u64 << 47), 1);
330+
assert_eq!(top_bit_mask.index(1_u64 << 44), 0);
331+
332+
let shifted_mask = PartitionMask::with_start_bit(2, 3);
333+
assert_eq!(shifted_mask.index(1_u64 << 47), 0);
334+
assert_eq!(shifted_mask.index(1_u64 << 44), 1);
335+
}
336+
}

src/query/service/src/pipelines/processors/transforms/aggregator/build_partition_bucket.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
use std::sync::Arc;
16+
use std::sync::atomic::AtomicU64;
1617

1718
use databend_common_catalog::table_context::TableContext;
1819
use databend_common_exception::Result;
@@ -49,6 +50,7 @@ fn build_partition_bucket_experimental(
4950
shuffle_mode: AggregateShuffleMode,
5051
) -> Result<()> {
5152
let mut final_parallelism = ctx.get_settings().get_max_threads()? as usize;
53+
let base_consumed_bits = shuffle_mode.determine_radix_bits();
5254
match shuffle_mode {
5355
AggregateShuffleMode::Row => {
5456
let schema = params.spill_schema();
@@ -107,6 +109,7 @@ fn build_partition_bucket_experimental(
107109

108110
let mut builder = TransformPipeBuilder::create();
109111
let (tx, rx) = async_channel::unbounded();
112+
let next_task_id = Arc::new(AtomicU64::new(1));
110113
for id in 0..final_parallelism {
111114
let input_port = InputPort::create();
112115
let output_port = OutputPort::create();
@@ -115,9 +118,11 @@ fn build_partition_bucket_experimental(
115118
output_port.clone(),
116119
params.clone(),
117120
id,
121+
base_consumed_bits,
118122
ctx.clone(),
119123
tx.clone(),
120124
rx.clone(),
125+
next_task_id.clone(),
121126
)?;
122127
builder.add_transform(input_port, output_port, ProcessorPtr::create(processor));
123128
}

0 commit comments

Comments
 (0)