Skip to content

Commit

Permalink
Move set_as_missing_if_expired from Store to BlockExpirationTracker
Browse files Browse the repository at this point in the history
  • Loading branch information
inetic committed Sep 11, 2023
1 parent a36180c commit ade00bc
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 111 deletions.
2 changes: 1 addition & 1 deletion lib/src/network/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ impl<'a> Responder<'a> {
.store()
.acquire_read()
.await?
.read_block_on_peer_request(&id, &mut content, &self.vault.block_tracker)
.read_block(&id, &mut content)
.await;

match result {
Expand Down
5 changes: 1 addition & 4 deletions lib/src/repository/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,7 @@ impl Repository {
{
let mut conn = vault.store().db().acquire().await?;
if let Some(block_expiration) = metadata::block_expiration::get(&mut conn).await? {
vault
.store()
.set_block_expiration(Some(block_expiration))
.await?;
vault.set_block_expiration(Some(block_expiration)).await?;
}
}

Expand Down
5 changes: 4 additions & 1 deletion lib/src/repository/vault.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,10 @@ impl Vault {
}

pub async fn set_block_expiration(&self, duration: Option<Duration>) -> Result<()> {
Ok(self.store.set_block_expiration(duration).await?)
Ok(self
.store
.set_block_expiration(duration, self.block_tracker.clone())
.await?)
}

pub async fn block_expiration(&self) -> Option<Duration> {
Expand Down
157 changes: 129 additions & 28 deletions lib/src/store/block_expiration_tracker.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
use super::error::Error;
use super::{
cache::{Cache, CacheTransaction},
error::Error,
index::{self, UpdateSummaryReason},
leaf_node, root_node,
};
use crate::{
block_tracker::BlockTracker as BlockDownloadTracker,
collections::{hash_map, HashMap, HashSet},
crypto::sign::PublicKey,
db,
deadlock::BlockingMutex,
future::try_collect_into,
protocol::{BlockId, SingleBlockPresence},
sync::uninitialized_watch,
sync::{broadcast_hash_set, uninitialized_watch},
};
use futures_util::{StreamExt, TryStreamExt};
use scoped_task::{self, ScopedJoinHandle};
Expand All @@ -17,25 +25,24 @@ use std::{
use tokio::{select, sync::watch, time::sleep};

pub(crate) struct BlockExpirationTracker {
pool: db::Pool,
shared: Arc<BlockingMutex<Shared>>,
watch_tx: uninitialized_watch::Sender<()>,
expiration_time_tx: watch::Sender<Duration>,
_task: ScopedJoinHandle<()>,
}

impl BlockExpirationTracker {
pub fn pool(&self) -> &db::Pool {
&self.pool
}

pub async fn enable_expiration(
pub(super) async fn enable_expiration(
pool: db::Pool,
expiration_time: Duration,
block_download_tracker: BlockDownloadTracker,
client_reload_index_tx: broadcast_hash_set::Sender<PublicKey>,
cache: Arc<Cache>,
) -> Result<Self, Error> {
let mut shared = Shared {
blocks_by_id: Default::default(),
blocks_by_expiration: Default::default(),
to_missing_if_expired: Default::default(),
};

let mut tx = pool.begin_read().await?;
Expand All @@ -59,28 +66,39 @@ impl BlockExpirationTracker {

let _task = scoped_task::spawn({
let shared = shared.clone();
let pool = pool.clone();

async move {
if let Err(err) = run_task(shared, pool, watch_rx, expiration_time_rx).await {
if let Err(err) = run_task(
shared,
pool,
watch_rx,
expiration_time_rx,
block_download_tracker,
client_reload_index_tx,
cache,
)
.await
{
tracing::error!("BlockExpirationTracker task has ended with {err:?}");
}
}
});

Ok(Self {
pool,
shared,
watch_tx,
expiration_time_tx,
_task,
})
}

pub fn handle_block_update(&self, block_id: &BlockId) {
pub fn handle_block_update(&self, block_id: &BlockId, is_missing: bool) {
// Not inlining these lines to call `SystemTime::now()` only once the `lock` is acquired.
let mut lock = self.shared.lock().unwrap();
lock.handle_block_update(block_id, SystemTime::now());
if is_missing {
lock.to_missing_if_expired.insert(*block_id);
}
drop(lock);
self.watch_tx.send(()).unwrap_or(());
}
Expand Down Expand Up @@ -109,6 +127,8 @@ struct Shared {
//
blocks_by_id: HashMap<BlockId, TimeUpdated>,
blocks_by_expiration: BTreeMap<TimeUpdated, HashSet<BlockId>>,

to_missing_if_expired: HashSet<BlockId>,
}

impl Shared {
Expand Down Expand Up @@ -204,27 +224,51 @@ async fn run_task(
pool: db::Pool,
mut watch_rx: uninitialized_watch::Receiver<()>,
mut expiration_time_rx: watch::Receiver<Duration>,
block_download_tracker: BlockDownloadTracker,
client_reload_index_tx: broadcast_hash_set::Sender<PublicKey>,
cache: Arc<Cache>,
) -> Result<(), Error> {
loop {
let expiration_time = *expiration_time_rx.borrow();

let (ts, block) = {
let oldest_entry = shared
.lock()
.unwrap()
.blocks_by_expiration
.first_entry()
// Unwrap OK due to the invariant #2.
.map(|e| (*e.key(), *e.get().iter().next().unwrap()));

match oldest_entry {
Some((ts, block)) => (ts, block),
None => {
enum Enum {
OldestEntry(Option<(TimeUpdated, BlockId)>),
ToMissing(HashSet<BlockId>),
}

match {
let mut lock = shared.lock().unwrap();

if !lock.to_missing_if_expired.is_empty() {
Enum::ToMissing(std::mem::take(&mut lock.to_missing_if_expired))
} else {
Enum::OldestEntry(
lock.blocks_by_expiration
.first_entry()
// Unwrap OK due to the invariant #2.
.map(|e| (*e.key(), *e.get().iter().next().unwrap())),
)
}
} {
Enum::OldestEntry(Some((time_updated, block_id))) => (time_updated, block_id),
Enum::OldestEntry(None) => {
if watch_rx.changed().await.is_err() {
return Ok(());
}
continue;
}
Enum::ToMissing(to_missing_if_expired) => {
set_as_missing_if_expired(
&pool,
to_missing_if_expired,
&block_download_tracker,
&client_reload_index_tx,
cache.begin(),
)
.await?;
continue;
}
}
};

Expand All @@ -238,6 +282,9 @@ async fn run_task(
_ = expiration_time_rx.changed() => {
continue;
}
_ = watch_rx.changed() => {
continue;
}
}
}

Expand Down Expand Up @@ -296,6 +343,54 @@ async fn run_task(
}
}

async fn set_as_missing_if_expired(
pool: &db::Pool,
block_ids: HashSet<BlockId>,
block_download_tracker: &BlockDownloadTracker,
client_reload_index_tx: &broadcast_hash_set::Sender<PublicKey>,
mut cache: CacheTransaction,
) -> Result<(), Error> {
let mut tx = pool.begin_write().await?;

// Branches where we have newly missing blocks. We need to tell the client to reload indices
// for these branches from peers.
let mut branches: HashSet<PublicKey> = HashSet::default();

for block_id in &block_ids {
let changed = leaf_node::set_missing_if_expired(&mut tx, block_id).await?;

if !changed {
continue;
}

block_download_tracker.require(*block_id);

let nodes: Vec<_> = leaf_node::load_parent_hashes(&mut tx, block_id)
.try_collect()
.await?;

for (hash, _state) in index::update_summaries(
&mut tx,
&mut cache,
nodes,
UpdateSummaryReason::BlockRemoved,
)
.await?
{
try_collect_into(root_node::load_writer_ids(&mut tx, &hash), &mut branches).await?;
}
}

tx.commit().await?;

for branch_id in branches {
// TODO: Throttle these messages.
client_reload_index_tx.insert(&branch_id);
}

Ok(())
}

#[cfg(test)]
mod test {
use super::super::*;
Expand All @@ -308,6 +403,7 @@ mod test {
let mut shared = Shared {
blocks_by_id: Default::default(),
blocks_by_expiration: Default::default(),
to_missing_if_expired: Default::default(),
};

// add once
Expand Down Expand Up @@ -363,15 +459,20 @@ mod test {

assert_eq!(count_blocks(store.db()).await, 1);

let tracker =
BlockExpirationTracker::enable_expiration(store.db().clone(), Duration::from_secs(1))
.await
.unwrap();
let tracker = BlockExpirationTracker::enable_expiration(
store.db().clone(),
Duration::from_secs(1),
BlockDownloadTracker::new(),
broadcast_hash_set::channel().0,
Arc::new(Cache::new()),
)
.await
.unwrap();

sleep(Duration::from_millis(700)).await;

let block_id = add_block(&write_keys, &branch_id, &store).await;
tracker.handle_block_update(&block_id);
tracker.handle_block_update(&block_id, false);

assert_eq!(count_blocks(store.db()).await, 2);

Expand Down
6 changes: 3 additions & 3 deletions lib/src/store/changeset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ impl Changeset {
patch.save(tx, &self.bump, write_keys).await?;

for block in self.blocks {
block::write(tx.db(), &block).await?;

if let Some(tracker) = &tx.block_expiration_tracker {
tracker.handle_block_update(&block.id);
tracker.handle_block_update(&block.id, false);
}

block::write(tx.db(), &block).await?;
}

Ok(())
Expand Down
Loading

0 comments on commit ade00bc

Please sign in to comment.