From ade00bc71f3f63f1574ea48aa302379f0b166912 Mon Sep 17 00:00:00 2001
From: Peter Jankuliak
Date: Mon, 11 Sep 2023 15:32:12 +0200
Subject: [PATCH] Move set_as_missing_if_expired from Store to
BlockExpirationTracker
---
lib/src/network/server.rs | 2 +-
lib/src/repository/mod.rs | 5 +-
lib/src/repository/vault.rs | 5 +-
lib/src/store/block_expiration_tracker.rs | 157 ++++++++++++++++++----
lib/src/store/changeset.rs | 6 +-
lib/src/store/mod.rs | 92 +++----------
6 files changed, 156 insertions(+), 111 deletions(-)
diff --git a/lib/src/network/server.rs b/lib/src/network/server.rs
index 2731a1e1a..ea37d55d6 100644
--- a/lib/src/network/server.rs
+++ b/lib/src/network/server.rs
@@ -211,7 +211,7 @@ impl<'a> Responder<'a> {
.store()
.acquire_read()
.await?
- .read_block_on_peer_request(&id, &mut content, &self.vault.block_tracker)
+ .read_block(&id, &mut content)
.await;
match result {
diff --git a/lib/src/repository/mod.rs b/lib/src/repository/mod.rs
index 420e5b90c..a20cd1165 100644
--- a/lib/src/repository/mod.rs
+++ b/lib/src/repository/mod.rs
@@ -180,10 +180,7 @@ impl Repository {
{
let mut conn = vault.store().db().acquire().await?;
if let Some(block_expiration) = metadata::block_expiration::get(&mut conn).await? {
- vault
- .store()
- .set_block_expiration(Some(block_expiration))
- .await?;
+ vault.set_block_expiration(Some(block_expiration)).await?;
}
}
diff --git a/lib/src/repository/vault.rs b/lib/src/repository/vault.rs
index 4b5521489..6789b8427 100644
--- a/lib/src/repository/vault.rs
+++ b/lib/src/repository/vault.rs
@@ -222,7 +222,10 @@ impl Vault {
}
pub async fn set_block_expiration(&self, duration: Option) -> Result<()> {
- Ok(self.store.set_block_expiration(duration).await?)
+ Ok(self
+ .store
+ .set_block_expiration(duration, self.block_tracker.clone())
+ .await?)
}
pub async fn block_expiration(&self) -> Option {
diff --git a/lib/src/store/block_expiration_tracker.rs b/lib/src/store/block_expiration_tracker.rs
index f83487ca9..27d495945 100644
--- a/lib/src/store/block_expiration_tracker.rs
+++ b/lib/src/store/block_expiration_tracker.rs
@@ -1,10 +1,18 @@
-use super::error::Error;
+use super::{
+ cache::{Cache, CacheTransaction},
+ error::Error,
+ index::{self, UpdateSummaryReason},
+ leaf_node, root_node,
+};
use crate::{
+ block_tracker::BlockTracker as BlockDownloadTracker,
collections::{hash_map, HashMap, HashSet},
+ crypto::sign::PublicKey,
db,
deadlock::BlockingMutex,
+ future::try_collect_into,
protocol::{BlockId, SingleBlockPresence},
- sync::uninitialized_watch,
+ sync::{broadcast_hash_set, uninitialized_watch},
};
use futures_util::{StreamExt, TryStreamExt};
use scoped_task::{self, ScopedJoinHandle};
@@ -17,7 +25,6 @@ use std::{
use tokio::{select, sync::watch, time::sleep};
pub(crate) struct BlockExpirationTracker {
- pool: db::Pool,
shared: Arc>,
watch_tx: uninitialized_watch::Sender<()>,
expiration_time_tx: watch::Sender,
@@ -25,17 +32,17 @@ pub(crate) struct BlockExpirationTracker {
}
impl BlockExpirationTracker {
- pub fn pool(&self) -> &db::Pool {
- &self.pool
- }
-
- pub async fn enable_expiration(
+ pub(super) async fn enable_expiration(
pool: db::Pool,
expiration_time: Duration,
+ block_download_tracker: BlockDownloadTracker,
+ client_reload_index_tx: broadcast_hash_set::Sender,
+ cache: Arc,
) -> Result {
let mut shared = Shared {
blocks_by_id: Default::default(),
blocks_by_expiration: Default::default(),
+ to_missing_if_expired: Default::default(),
};
let mut tx = pool.begin_read().await?;
@@ -59,17 +66,25 @@ impl BlockExpirationTracker {
let _task = scoped_task::spawn({
let shared = shared.clone();
- let pool = pool.clone();
async move {
- if let Err(err) = run_task(shared, pool, watch_rx, expiration_time_rx).await {
+ if let Err(err) = run_task(
+ shared,
+ pool,
+ watch_rx,
+ expiration_time_rx,
+ block_download_tracker,
+ client_reload_index_tx,
+ cache,
+ )
+ .await
+ {
tracing::error!("BlockExpirationTracker task has ended with {err:?}");
}
}
});
Ok(Self {
- pool,
shared,
watch_tx,
expiration_time_tx,
@@ -77,10 +92,13 @@ impl BlockExpirationTracker {
})
}
- pub fn handle_block_update(&self, block_id: &BlockId) {
+ pub fn handle_block_update(&self, block_id: &BlockId, is_missing: bool) {
// Not inlining these lines to call `SystemTime::now()` only once the `lock` is acquired.
let mut lock = self.shared.lock().unwrap();
lock.handle_block_update(block_id, SystemTime::now());
+ if is_missing {
+ lock.to_missing_if_expired.insert(*block_id);
+ }
drop(lock);
self.watch_tx.send(()).unwrap_or(());
}
@@ -109,6 +127,8 @@ struct Shared {
//
blocks_by_id: HashMap,
blocks_by_expiration: BTreeMap>,
+
+ to_missing_if_expired: HashSet,
}
impl Shared {
@@ -204,27 +224,51 @@ async fn run_task(
pool: db::Pool,
mut watch_rx: uninitialized_watch::Receiver<()>,
mut expiration_time_rx: watch::Receiver,
+ block_download_tracker: BlockDownloadTracker,
+ client_reload_index_tx: broadcast_hash_set::Sender,
+ cache: Arc,
) -> Result<(), Error> {
loop {
let expiration_time = *expiration_time_rx.borrow();
let (ts, block) = {
- let oldest_entry = shared
- .lock()
- .unwrap()
- .blocks_by_expiration
- .first_entry()
- // Unwrap OK due to the invariant #2.
- .map(|e| (*e.key(), *e.get().iter().next().unwrap()));
-
- match oldest_entry {
- Some((ts, block)) => (ts, block),
- None => {
+ enum Enum {
+ OldestEntry(Option<(TimeUpdated, BlockId)>),
+ ToMissing(HashSet),
+ }
+
+ match {
+ let mut lock = shared.lock().unwrap();
+
+ if !lock.to_missing_if_expired.is_empty() {
+ Enum::ToMissing(std::mem::take(&mut lock.to_missing_if_expired))
+ } else {
+ Enum::OldestEntry(
+ lock.blocks_by_expiration
+ .first_entry()
+ // Unwrap OK due to the invariant #2.
+ .map(|e| (*e.key(), *e.get().iter().next().unwrap())),
+ )
+ }
+ } {
+ Enum::OldestEntry(Some((time_updated, block_id))) => (time_updated, block_id),
+ Enum::OldestEntry(None) => {
if watch_rx.changed().await.is_err() {
return Ok(());
}
continue;
}
+ Enum::ToMissing(to_missing_if_expired) => {
+ set_as_missing_if_expired(
+ &pool,
+ to_missing_if_expired,
+ &block_download_tracker,
+ &client_reload_index_tx,
+ cache.begin(),
+ )
+ .await?;
+ continue;
+ }
}
};
@@ -238,6 +282,9 @@ async fn run_task(
_ = expiration_time_rx.changed() => {
continue;
}
+ _ = watch_rx.changed() => {
+ continue;
+ }
}
}
@@ -296,6 +343,54 @@ async fn run_task(
}
}
+async fn set_as_missing_if_expired(
+ pool: &db::Pool,
+ block_ids: HashSet,
+ block_download_tracker: &BlockDownloadTracker,
+ client_reload_index_tx: &broadcast_hash_set::Sender,
+ mut cache: CacheTransaction,
+) -> Result<(), Error> {
+ let mut tx = pool.begin_write().await?;
+
+ // Branches where we have newly missing blocks. We need to tell the client to reload indices
+ // for these branches from peers.
+ let mut branches: HashSet = HashSet::default();
+
+ for block_id in &block_ids {
+ let changed = leaf_node::set_missing_if_expired(&mut tx, block_id).await?;
+
+ if !changed {
+ continue;
+ }
+
+ block_download_tracker.require(*block_id);
+
+ let nodes: Vec<_> = leaf_node::load_parent_hashes(&mut tx, block_id)
+ .try_collect()
+ .await?;
+
+ for (hash, _state) in index::update_summaries(
+ &mut tx,
+ &mut cache,
+ nodes,
+ UpdateSummaryReason::BlockRemoved,
+ )
+ .await?
+ {
+ try_collect_into(root_node::load_writer_ids(&mut tx, &hash), &mut branches).await?;
+ }
+ }
+
+ tx.commit().await?;
+
+ for branch_id in branches {
+ // TODO: Throttle these messages.
+ client_reload_index_tx.insert(&branch_id);
+ }
+
+ Ok(())
+}
+
#[cfg(test)]
mod test {
use super::super::*;
@@ -308,6 +403,7 @@ mod test {
let mut shared = Shared {
blocks_by_id: Default::default(),
blocks_by_expiration: Default::default(),
+ to_missing_if_expired: Default::default(),
};
// add once
@@ -363,15 +459,20 @@ mod test {
assert_eq!(count_blocks(store.db()).await, 1);
- let tracker =
- BlockExpirationTracker::enable_expiration(store.db().clone(), Duration::from_secs(1))
- .await
- .unwrap();
+ let tracker = BlockExpirationTracker::enable_expiration(
+ store.db().clone(),
+ Duration::from_secs(1),
+ BlockDownloadTracker::new(),
+ broadcast_hash_set::channel().0,
+ Arc::new(Cache::new()),
+ )
+ .await
+ .unwrap();
sleep(Duration::from_millis(700)).await;
let block_id = add_block(&write_keys, &branch_id, &store).await;
- tracker.handle_block_update(&block_id);
+ tracker.handle_block_update(&block_id, false);
assert_eq!(count_blocks(store.db()).await, 2);
diff --git a/lib/src/store/changeset.rs b/lib/src/store/changeset.rs
index dc01f087d..34203cc23 100644
--- a/lib/src/store/changeset.rs
+++ b/lib/src/store/changeset.rs
@@ -46,11 +46,11 @@ impl Changeset {
patch.save(tx, &self.bump, write_keys).await?;
for block in self.blocks {
+ block::write(tx.db(), &block).await?;
+
if let Some(tracker) = &tx.block_expiration_tracker {
- tracker.handle_block_update(&block.id);
+ tracker.handle_block_update(&block.id, false);
}
-
- block::write(tx.db(), &block).await?;
}
Ok(())
diff --git a/lib/src/store/mod.rs b/lib/src/store/mod.rs
index 962758b37..cf7bda76f 100644
--- a/lib/src/store/mod.rs
+++ b/lib/src/store/mod.rs
@@ -30,11 +30,11 @@ use self::{
index::UpdateSummaryReason,
};
use crate::{
+ block_tracker::BlockTracker as BlockDownloadTracker,
collections::HashSet,
crypto::{sign::PublicKey, CacheHash, Hash, Hashable},
db,
debug::DebugPrinter,
- future::try_collect_into,
progress::Progress,
protocol::{
get_bucket, Block, BlockContent, BlockId, BlockNonce, InnerNodeMap, LeafNodeSet,
@@ -77,6 +77,7 @@ impl Store {
pub async fn set_block_expiration(
&self,
expiration_time: Option,
+ block_download_tracker: BlockDownloadTracker,
) -> Result<(), Error> {
let mut tracker_lock = self.block_expiration_tracker.write().await;
@@ -93,8 +94,14 @@ impl Store {
None => return Ok(()),
};
- let tracker =
- BlockExpirationTracker::enable_expiration(self.db.clone(), expiration_time).await?;
+ let tracker = BlockExpirationTracker::enable_expiration(
+ self.db.clone(),
+ expiration_time,
+ block_download_tracker,
+ self.client_reload_index_tx.clone(),
+ self.cache.clone(),
+ )
+ .await?;
*tracker_lock = Some(Arc::new(tracker));
@@ -114,7 +121,6 @@ impl Store {
Ok(Reader {
inner: Handle::Connection(self.db.acquire().await?),
cache: self.cache.begin(),
- client_reload_index_tx: self.client_reload_index_tx.clone(),
block_expiration_tracker: self.block_expiration_tracker.read().await.clone(),
})
}
@@ -125,7 +131,6 @@ impl Store {
inner: Reader {
inner: Handle::ReadTransaction(self.db.begin_read().await?),
cache: self.cache.begin(),
- client_reload_index_tx: self.client_reload_index_tx.clone(),
block_expiration_tracker: self.block_expiration_tracker.read().await.clone(),
},
})
@@ -138,7 +143,6 @@ impl Store {
inner: Reader {
inner: Handle::WriteTransaction(self.db.begin_write().await?),
cache: self.cache.begin(),
- client_reload_index_tx: self.client_reload_index_tx.clone(),
block_expiration_tracker: self.block_expiration_tracker.read().await.clone(),
},
},
@@ -241,7 +245,6 @@ impl Store {
pub(crate) struct Reader {
inner: Handle,
cache: CacheTransaction,
- client_reload_index_tx: broadcast_hash_set::Sender,
block_expiration_tracker: Option>,
}
@@ -256,78 +259,16 @@ impl Reader {
id: &BlockId,
content: &mut BlockContent,
) -> Result {
- if let Some(expiration_tracker) = &self.block_expiration_tracker {
- expiration_tracker.handle_block_update(id);
- }
-
let result = block::read(self.db(), id, content).await;
- if matches!(result, Err(Error::BlockNotFound)) {
- self.set_as_missing_if_expired(id).await?;
- }
-
- result
- }
-
- pub async fn read_block_on_peer_request(
- &mut self,
- id: &BlockId,
- content: &mut BlockContent,
- block_tracker: &crate::block_tracker::BlockTracker,
- ) -> Result {
if let Some(expiration_tracker) = &self.block_expiration_tracker {
- expiration_tracker.handle_block_update(id);
- }
-
- let result = block::read(self.db(), id, content).await;
-
- if matches!(result, Err(Error::BlockNotFound)) && self.set_as_missing_if_expired(id).await?
- {
- block_tracker.require(*id);
+ let is_missing = matches!(result, Err(Error::BlockNotFound));
+ expiration_tracker.handle_block_update(id, is_missing);
}
result
}
- async fn set_as_missing_if_expired(&mut self, block_id: &BlockId) -> Result {
- let expiration_tracker = match &self.block_expiration_tracker {
- Some(expiration_tracker) => expiration_tracker,
- None => return Ok(false),
- };
-
- // TODO: This actually does a DB write operation
- let mut tx = expiration_tracker.pool().begin_write().await?;
- let cache = &mut self.cache;
-
- let changed = leaf_node::set_missing_if_expired(&mut tx, block_id).await?;
-
- if !changed {
- return Ok(false);
- }
-
- let nodes: Vec<_> = leaf_node::load_parent_hashes(&mut tx, block_id)
- .try_collect()
- .await?;
-
- let mut branches: HashSet = HashSet::default();
-
- for (hash, _state) in
- index::update_summaries(&mut tx, cache, nodes, UpdateSummaryReason::BlockRemoved)
- .await?
- {
- try_collect_into(root_node::load_writer_ids(&mut tx, &hash), &mut branches).await?;
- }
-
- // TODO: Use the 'on commit' machinery to ensure the below code is executed.
- tx.commit().await?;
-
- for branch_id in branches {
- self.client_reload_index_tx.insert(&branch_id);
- }
-
- Ok(true)
- }
-
/// Checks whether the block exists in the store.
pub async fn block_exists(&mut self, id: &BlockId) -> Result {
block::exists(self.db(), id).await
@@ -633,11 +574,14 @@ impl WriteTransaction {
&mut self,
block: &Block,
) -> Result {
+ let (db, cache) = self.db_and_cache();
+ let result = block::receive(db, cache, block).await;
+
if let Some(tracker) = &self.block_expiration_tracker {
- tracker.handle_block_update(&block.id);
+ tracker.handle_block_update(&block.id, false);
}
- let (db, cache) = self.db_and_cache();
- block::receive(db, cache, block).await
+
+ result
}
pub async fn commit(self) -> Result<(), Error> {