Commit a4a83f0
[0.6] Scrub reports during aggregation job creation, rather than the initial aggregation step. (#2559)
Besides being the earliest point at which we can scrub reports, this avoids (briefly) duplicating report data in both the `client_reports` and `report_aggregations` tables. I think placing this logic in the aggregation job creator also makes more sense for when we eventually expand it to handle VDAFs that require multiple aggregations per report.
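In brief (condensed from the `aggregation_job_creator.rs` diff below): while building aggregation jobs, the creator now records the ID of every report it places into a job, then scrubs those reports in the same datastore transaction that persists the jobs and report aggregations. A minimal sketch of that write path, using the identifiers from the diff:

```rust
// Condensed from the new write path in the aggregation job creator (full diff below).
// The aggregation jobs / report aggregations are written and the corresponding client
// reports are scrubbed concurrently, within a single datastore transaction.
try_join!(
    aggregation_job_writer.write(tx, vdaf),
    try_join_all(
        report_ids_to_scrub
            .iter()
            .map(|report_id| tx.scrub_client_report(task.id(), report_id))
    )
)?;
```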
1 parent f3bab98 commit a4a83f0

File tree: 5 files changed, +102 −61 lines


aggregator/src/aggregator/aggregation_job_creator.rs (+47 −6)
@@ -4,6 +4,7 @@ use fixed::{
     types::extra::{U15, U31},
     FixedI16, FixedI32,
 };
+use futures::future::try_join_all;
 use janus_aggregator_core::{
     datastore::models::{AggregationJob, AggregationJobState},
     datastore::{self, Datastore},
@@ -33,8 +34,15 @@ use prio::{
     },
 };
 use rand::{random, thread_rng, Rng};
-use std::{collections::HashMap, sync::Arc, time::Duration};
-use tokio::time::{self, sleep_until, Instant, MissedTickBehavior};
+use std::{
+    collections::{HashMap, HashSet},
+    sync::Arc,
+    time::Duration,
+};
+use tokio::{
+    time::{self, sleep_until, Instant, MissedTickBehavior},
+    try_join,
+};
 use tracing::{debug, error, info};
 use trillium_tokio::{CloneCounterObserver, Stopper};
 
@@ -528,6 +536,7 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
 
                 // Generate aggregation jobs & report aggregations based on the reports we read.
                 let mut aggregation_job_writer = AggregationJobWriter::new(Arc::clone(&task));
+                let mut report_ids_to_scrub = HashSet::new();
                 for agg_job_reports in reports.chunks(this.max_aggregation_job_size) {
                     if agg_job_reports.len() < this.min_aggregation_job_size {
                         if !agg_job_reports.is_empty() {
@@ -585,12 +594,22 @@ impl<C: Clock + 'static> AggregationJobCreator<C> {
                             ))
                         })
                         .collect::<Result<_, datastore::Error>>()?;
+                    report_ids_to_scrub
+                        .extend(agg_job_reports.iter().map(|report| *report.metadata().id()));
 
                     aggregation_job_writer.put(aggregation_job, report_aggregations)?;
                 }
 
                 // Write the aggregation jobs & report aggregations we created.
-                aggregation_job_writer.write(tx, vdaf).await?;
+                try_join!(
+                    aggregation_job_writer.write(tx, vdaf),
+                    try_join_all(
+                        report_ids_to_scrub
+                            .iter()
+                            .map(|report_id| tx.scrub_client_report(task.id(), report_id))
+                    )
+                )?;
+
                 Ok(!aggregation_job_writer.is_empty())
             })
         })
@@ -2585,7 +2604,7 @@ mod tests {
         for<'a> A::PrepareState: ParameterizedDecode<(&'a A, usize)>,
         A::PublicShare: PartialEq,
     {
-        try_join!(
+        let (agg_jobs_and_report_ids, batches) = try_join!(
             try_join_all(
                 tx.get_aggregation_jobs_for_task(task_id)
                     .await
@@ -2602,7 +2621,7 @@ mod tests {
                         .await
                         .map(|report_aggs| {
                             // Verify that each report aggregation has the expected state.
-                            let report_ids = report_aggs
+                            let report_ids: Vec<_> = report_aggs
                                 .into_iter()
                                 .map(|ra| {
                                     let want_ra_state =
@@ -2624,6 +2643,28 @@
             ),
             tx.get_batches_for_task(task_id),
         )
-        .unwrap()
+        .unwrap();
+
+        // Verify that all reports we saw a report aggregation for are scrubbed.
+        let all_seen_report_ids: HashSet<_> = agg_jobs_and_report_ids
+            .iter()
+            .flat_map(|(_, report_ids)| report_ids)
+            .collect();
+        for report_id in &all_seen_report_ids {
+            tx.verify_client_report_scrubbed(task_id, report_id).await;
+        }
+
+        // Verify that all reports we did not see a report aggregation for are not scrubbed. (We do
+        // so by reading the report, since reading a report will fail if the report is scrubbed.)
+        for report_id in want_ra_states.keys() {
+            if all_seen_report_ids.contains(report_id) {
+                continue;
+            }
+            tx.get_client_report(vdaf, task_id, report_id)
+                .await
+                .unwrap();
+        }
+
+        (agg_jobs_and_report_ids, batches)
     }
 }

aggregator/src/aggregator/aggregation_job_driver.rs (+3 −24)
@@ -23,7 +23,7 @@ use janus_messages::{
     query_type::{FixedSize, TimeInterval},
     AggregationJobContinueReq, AggregationJobInitializeReq, AggregationJobResp,
     PartialBatchSelector, PrepareContinue, PrepareError, PrepareInit, PrepareResp,
-    PrepareStepResult, ReportId, ReportShare, Role,
+    PrepareStepResult, ReportShare, Role,
 };
 use opentelemetry::{
     metrics::{Counter, Histogram, Meter, Unit},
@@ -310,16 +310,6 @@ impl AggregationJobDriver {
         A::PrepareMessage: PartialEq + Eq + Send + Sync,
         A::PublicShare: PartialEq + Send + Sync,
     {
-        // We currently scrub all reports included in an aggregation job as part of completing the
-        // first step. Once we support VDAFs which accept use an aggregation parameter &
-        // permit/require multiple aggregations per report, we will need a more complicated strategy
-        // where we only scrub a report as part of the _final_ aggregation over the report.
-        let report_ids_to_scrub = report_aggregations
-            .iter()
-            .map(|ra| ra.report_id())
-            .copied()
-            .collect();
-
         // Only process non-failed report aggregations.
         let report_aggregations: Vec<_> = report_aggregations
             .into_iter()
@@ -479,7 +469,6 @@ impl AggregationJobDriver {
                 aggregation_job,
                 &stepped_aggregations,
                 report_aggregations_to_write,
-                report_ids_to_scrub,
                 resp.prepare_resps(),
             )
             .await
@@ -577,7 +566,6 @@ impl AggregationJobDriver {
                 aggregation_job,
                 &stepped_aggregations,
                 report_aggregations_to_write,
-                Vec::new(), /* reports are only scrubbed on the initial step */
                 resp.prepare_resps(),
             )
             .await
@@ -598,7 +586,6 @@ impl AggregationJobDriver {
         aggregation_job: AggregationJob<SEED_SIZE, Q, A>,
         stepped_aggregations: &[SteppedAggregation<SEED_SIZE, A>],
         mut report_aggregations_to_write: Vec<ReportAggregation<SEED_SIZE, A>>,
-        report_ids_to_scrub: Vec<ReportId>,
         helper_prep_resps: &[PrepareResp],
     ) -> Result<(), Error>
     where
@@ -759,27 +746,19 @@ impl AggregationJobDriver {
             report_aggregations_to_write,
         )?;
         let aggregation_job_writer = Arc::new(aggregation_job_writer);
-
-        let report_ids_to_scrub = Arc::new(report_ids_to_scrub);
         let accumulator = Arc::new(accumulator);
+
         datastore
             .run_tx("step_aggregation_job_2", |tx| {
-                let task_id = *task.id();
                 let vdaf = Arc::clone(&vdaf);
                 let aggregation_job_writer = Arc::clone(&aggregation_job_writer);
                 let accumulator = Arc::clone(&accumulator);
-                let report_ids_to_scrub = Arc::clone(&report_ids_to_scrub);
                 let lease = Arc::clone(&lease);
 
                 Box::pin(async move {
-                    let (unwritable_ra_report_ids, unwritable_ba_report_ids, _, _) = try_join!(
+                    let (unwritable_ra_report_ids, unwritable_ba_report_ids, _) = try_join!(
                         aggregation_job_writer.write(tx, Arc::clone(&vdaf)),
                         accumulator.flush_to_datastore(tx, &vdaf),
-                        try_join_all(
-                            report_ids_to_scrub
-                                .iter()
-                                .map(|report_id| tx.scrub_client_report(&task_id, report_id))
-                        ),
                         tx.release_aggregation_job(&lease),
                     )?;
 
aggregator/src/aggregator/batch_creator.rs (+23 −8)
@@ -7,13 +7,13 @@ use janus_aggregator_core::datastore::{
 };
 use janus_core::time::{Clock, DurationExt, TimeExt};
 use janus_messages::{
-    query_type::FixedSize, AggregationJobStep, BatchId, Duration, Interval, TaskId, Time,
+    query_type::FixedSize, AggregationJobStep, BatchId, Duration, Interval, ReportId, TaskId, Time,
 };
 use prio::{codec::Encode, vdaf::Aggregator};
 use rand::random;
 use std::{
     cmp::{max, min, Ordering},
-    collections::{binary_heap::PeekMut, hash_map, BinaryHeap, HashMap, VecDeque},
+    collections::{binary_heap::PeekMut, hash_map, BinaryHeap, HashMap, HashSet, VecDeque},
     ops::RangeInclusive,
     sync::Arc,
 };
@@ -32,8 +32,9 @@ where
 {
     properties: Properties,
     aggregation_job_writer: &'a mut AggregationJobWriter<SEED_SIZE, FixedSize, A>,
-    map: HashMap<Option<Time>, Bucket<SEED_SIZE, A>>,
+    buckets: HashMap<Option<Time>, Bucket<SEED_SIZE, A>>,
     new_batches: Vec<(BatchId, Option<Time>)>,
+    report_ids_to_scrub: HashSet<ReportId>,
 }
 
 /// Common properties used by [`BatchCreator`]. This is broken out into a separate structure to make
@@ -72,8 +73,9 @@ where
                 task_batch_time_window_size,
             },
             aggregation_job_writer,
-            map: HashMap::new(),
+            buckets: HashMap::new(),
             new_batches: Vec::new(),
+            report_ids_to_scrub: HashSet::new(),
         }
     }
 
@@ -95,15 +97,15 @@ where
                     .to_batch_interval_start(&batch_time_window_size)
             })
             .transpose()?;
-        let mut map_entry = self.map.entry(time_bucket_start_opt);
+        let mut map_entry = self.buckets.entry(time_bucket_start_opt);
         let bucket = match &mut map_entry {
             hash_map::Entry::Occupied(occupied) => occupied.get_mut(),
             hash_map::Entry::Vacant(_) => {
                 // Lazily find existing unfilled batches.
                 let outstanding_batches = tx
                     .get_outstanding_batches(&self.properties.task_id, &time_bucket_start_opt)
                     .await?;
-                self.map
+                self.buckets
                     .entry(time_bucket_start_opt)
                     .or_insert_with(|| Bucket::new(outstanding_batches))
             }
@@ -115,6 +117,7 @@ where
         Self::process_batches(
             &self.properties,
             self.aggregation_job_writer,
+            &mut self.report_ids_to_scrub,
             &mut self.new_batches,
             &time_bucket_start_opt,
             bucket,
@@ -136,6 +139,7 @@ where
     fn process_batches(
         properties: &Properties,
         aggregation_job_writer: &mut AggregationJobWriter<SEED_SIZE, FixedSize, A>,
+        report_ids_to_scrub: &mut HashSet<ReportId>,
         new_batches: &mut Vec<(BatchId, Option<Time>)>,
         time_bucket_start: &Option<Time>,
         bucket: &mut Bucket<SEED_SIZE, A>,
@@ -182,6 +186,7 @@ where
                         desired_aggregation_job_size,
                         &mut bucket.unaggregated_reports,
                         aggregation_job_writer,
+                        report_ids_to_scrub,
                     )?;
                     largest_outstanding_batch.add_reports(desired_aggregation_job_size);
                 } else {
@@ -209,6 +214,7 @@ where
                         desired_aggregation_job_size,
                         &mut bucket.unaggregated_reports,
                         aggregation_job_writer,
+                        report_ids_to_scrub,
                     )?;
                     largest_outstanding_batch.add_reports(desired_aggregation_job_size);
                 } else {
@@ -249,6 +255,7 @@ where
                     desired_aggregation_job_size,
                     &mut bucket.unaggregated_reports,
                     aggregation_job_writer,
+                    report_ids_to_scrub,
                 )?;
 
                 // Loop to the top of this method to create more aggregation jobs in this newly
@@ -268,6 +275,7 @@ where
         aggregation_job_size: usize,
         unaggregated_reports: &mut VecDeque<LeaderStoredReport<SEED_SIZE, A>>,
         aggregation_job_writer: &mut AggregationJobWriter<SEED_SIZE, FixedSize, A>,
+        report_ids_to_scrub: &mut HashSet<ReportId>,
     ) -> Result<(), Error> {
         let aggregation_job_id = random();
         debug!(
@@ -280,7 +288,7 @@ where
         let mut min_client_timestamp = None;
         let mut max_client_timestamp = None;
 
-        let report_aggregations = (0u64..)
+        let report_aggregations: Vec<_> = (0u64..)
             .zip(unaggregated_reports.drain(..aggregation_job_size))
             .map(|(ord, report)| {
                 let client_timestamp = *report.metadata().time();
@@ -294,6 +302,7 @@ where
                 report.as_start_leader_report_aggregation(aggregation_job_id, ord)
             })
             .collect();
+        report_ids_to_scrub.extend(report_aggregations.iter().map(|ra| *ra.report_id()));
 
         let min_client_timestamp = min_client_timestamp.unwrap(); // unwrap safety: aggregation_job_size > 0
         let max_client_timestamp = max_client_timestamp.unwrap(); // unwrap safety: aggregation_job_size > 0
@@ -329,10 +338,11 @@ where
         // be smaller than max_aggregation_job_size. We will only create jobs smaller than
         // min_aggregation_job_size if the remaining headroom in a batch requires it, otherwise
        // remaining reports will be added to unaggregated_report_ids, to be marked as unaggregated.
-        for (time_bucket_start, mut bucket) in self.map.into_iter() {
+        for (time_bucket_start, mut bucket) in self.buckets.into_iter() {
             Self::process_batches(
                 &self.properties,
                 self.aggregation_job_writer,
+                &mut self.report_ids_to_scrub,
                 &mut self.new_batches,
                 &time_bucket_start,
                 &mut bucket,
@@ -348,6 +358,11 @@ where
 
         try_join!(
             self.aggregation_job_writer.write(tx, vdaf),
+            try_join_all(
+                self.report_ids_to_scrub
+                    .iter()
+                    .map(|report_id| tx.scrub_client_report(&self.properties.task_id, report_id))
+            ),
             try_join_all(
                 self.new_batches
                     .iter()

aggregator_core/src/datastore.rs (+28 −0)
@@ -1555,6 +1555,34 @@ impl<C: Clock> Transaction<'_, C> {
         )
     }
 
+    #[cfg(feature = "test-util")]
+    #[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
+    pub async fn verify_client_report_scrubbed(&self, task_id: &TaskId, report_id: &ReportId) {
+        let row = self
+            .query_one(
+                "SELECT
+                    client_reports.extensions,
+                    client_reports.public_share,
+                    client_reports.leader_input_share,
+                    client_reports.helper_encrypted_input_share
+                FROM client_reports
+                JOIN tasks ON tasks.id = client_reports.task_id
+                WHERE tasks.task_id = $1
+                  AND client_reports.report_id = $2",
+                &[task_id.as_ref(), report_id.as_ref()],
+            )
+            .await
+            .unwrap();
+
+        assert_eq!(row.get::<_, Option<Vec<u8>>>("extensions"), None);
+        assert_eq!(row.get::<_, Option<Vec<u8>>>("public_share"), None);
+        assert_eq!(row.get::<_, Option<Vec<u8>>>("leader_input_share"), None);
+        assert_eq!(
+            row.get::<_, Option<Vec<u8>>>("helper_encrypted_input_share"),
+            None
+        );
+    }
+
     /// put_report_share stores a report share, given its associated task ID.
     ///
     /// This method is intended for use by aggregators acting in the Helper role; notably, it does
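The new `verify_client_report_scrubbed` helper is gated behind the `test-util` feature and is exercised by the datastore tests (next diff); usage inside a test transaction amounts to:

```rust
// After scrubbing, every payload column of the stored report should be NULL
// (mirrors the updated roundtrip_report test below).
tx.scrub_client_report(&task_id, &report_id).await.unwrap();
tx.verify_client_report_scrubbed(&task_id, &report_id).await;
```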

aggregator_core/src/datastore/tests.rs (+1 −23)
@@ -660,29 +660,7 @@ async fn roundtrip_report(ephemeral_datastore: EphemeralDatastore) {
         Box::pin(async move {
             tx.scrub_client_report(&task_id, &report_id).await.unwrap();
 
-            let row = tx
-                .query_one(
-                    "SELECT
-                        client_reports.extensions,
-                        client_reports.public_share,
-                        client_reports.leader_input_share,
-                        client_reports.helper_encrypted_input_share
-                    FROM client_reports
-                    JOIN tasks ON tasks.id = client_reports.task_id
-                    WHERE tasks.task_id = $1
-                      AND client_reports.report_id = $2",
-                    &[&task_id.as_ref(), &report_id.as_ref()],
-                )
-                .await
-                .unwrap();
-
-            assert_eq!(row.get::<_, Option<Vec<u8>>>("extensions"), None);
-            assert_eq!(row.get::<_, Option<Vec<u8>>>("public_share"), None);
-            assert_eq!(row.get::<_, Option<Vec<u8>>>("leader_input_share"), None);
-            assert_eq!(
-                row.get::<_, Option<Vec<u8>>>("helper_encrypted_input_share"),
-                None
-            );
+            tx.verify_client_report_scrubbed(&task_id, &report_id).await;
 
             assert_matches!(
                 tx.get_client_report::<0, dummy_vdaf::Vdaf>(
