diff --git a/lib/src/network/server.rs b/lib/src/network/server.rs
index 2731a1e1a..ea37d55d6 100644
--- a/lib/src/network/server.rs
+++ b/lib/src/network/server.rs
@@ -211,7 +211,7 @@ impl<'a> Responder<'a> {
             .store()
             .acquire_read()
             .await?
-            .read_block_on_peer_request(&id, &mut content, &self.vault.block_tracker)
+            .read_block(&id, &mut content)
             .await;
 
         match result {
diff --git a/lib/src/repository/mod.rs b/lib/src/repository/mod.rs
index 420e5b90c..a20cd1165 100644
--- a/lib/src/repository/mod.rs
+++ b/lib/src/repository/mod.rs
@@ -180,10 +180,7 @@ impl Repository {
         {
             let mut conn = vault.store().db().acquire().await?;
             if let Some(block_expiration) = metadata::block_expiration::get(&mut conn).await? {
-                vault
-                    .store()
-                    .set_block_expiration(Some(block_expiration))
-                    .await?;
+                vault.set_block_expiration(Some(block_expiration)).await?;
             }
         }
 
diff --git a/lib/src/repository/vault.rs b/lib/src/repository/vault.rs
index 4b5521489..6789b8427 100644
--- a/lib/src/repository/vault.rs
+++ b/lib/src/repository/vault.rs
@@ -222,7 +222,10 @@ impl Vault {
     }
 
     pub async fn set_block_expiration(&self, duration: Option<Duration>) -> Result<()> {
-        Ok(self.store.set_block_expiration(duration).await?)
+        Ok(self
+            .store
+            .set_block_expiration(duration, self.block_tracker.clone())
+            .await?)
     }
 
     pub async fn block_expiration(&self) -> Option<Duration> {
diff --git a/lib/src/store/block_expiration_tracker.rs b/lib/src/store/block_expiration_tracker.rs
index f83487ca9..175af5a22 100644
--- a/lib/src/store/block_expiration_tracker.rs
+++ b/lib/src/store/block_expiration_tracker.rs
@@ -1,10 +1,18 @@
-use super::error::Error;
+use super::{
+    cache::{Cache, CacheTransaction},
+    error::Error,
+    index::{self, UpdateSummaryReason},
+    leaf_node, root_node,
+};
 use crate::{
+    block_tracker::BlockTracker as BlockDownloadTracker,
     collections::{hash_map, HashMap, HashSet},
+    crypto::sign::PublicKey,
     db,
     deadlock::BlockingMutex,
+    future::try_collect_into,
     protocol::{BlockId, SingleBlockPresence},
-    sync::uninitialized_watch,
+    sync::{broadcast_hash_set, uninitialized_watch},
 };
 use futures_util::{StreamExt, TryStreamExt};
 use scoped_task::{self, ScopedJoinHandle};
@@ -17,7 +25,6 @@ use std::{
 use tokio::{select, sync::watch, time::sleep};
 
 pub(crate) struct BlockExpirationTracker {
-    pool: db::Pool,
     shared: Arc<BlockingMutex<Shared>>,
     watch_tx: uninitialized_watch::Sender<()>,
     expiration_time_tx: watch::Sender<Duration>,
@@ -25,17 +32,17 @@
 }
 
 impl BlockExpirationTracker {
-    pub fn pool(&self) -> &db::Pool {
-        &self.pool
-    }
-
-    pub async fn enable_expiration(
+    pub(super) async fn enable_expiration(
         pool: db::Pool,
         expiration_time: Duration,
+        block_download_tracker: BlockDownloadTracker,
+        client_reload_index_tx: broadcast_hash_set::Sender<PublicKey>,
+        cache: Arc<Cache>,
     ) -> Result<Self, Error> {
         let mut shared = Shared {
             blocks_by_id: Default::default(),
             blocks_by_expiration: Default::default(),
+            to_missing_if_expired: Default::default(),
         };
 
         let mut tx = pool.begin_read().await?;
@@ -49,7 +56,7 @@
         let now = SystemTime::now();
 
         while let Some(id) = ids.next().await {
-            shared.handle_block_update(&id?, now);
+            shared.insert_block(&id?, now);
         }
 
         let (watch_tx, watch_rx) = uninitialized_watch::channel();
@@ -59,17 +66,25 @@
 
         let _task = scoped_task::spawn({
             let shared = shared.clone();
-            let pool = pool.clone();
 
             async move {
-                if let Err(err) = run_task(shared, pool, watch_rx, expiration_time_rx).await {
+                if let Err(err) = run_task(
+                    shared,
+                    pool,
+                    watch_rx,
+                    expiration_time_rx,
+                    block_download_tracker,
+                    client_reload_index_tx,
+                    cache,
+                )
+                .await
+                {
                     tracing::error!("BlockExpirationTracker task has ended with {err:?}");
                 }
             }
         });
 
         Ok(Self {
-            pool,
             shared,
             watch_tx,
             expiration_time_tx,
@@ -77,16 +92,19 @@ impl BlockExpirationTracker {
         })
     }
 
-    pub fn handle_block_update(&self, block_id: &BlockId) {
+    pub fn handle_block_update(&self, block_id: &BlockId, is_missing: bool) {
         // Not inlining these lines to call `SystemTime::now()` only once the `lock` is acquired.
         let mut lock = self.shared.lock().unwrap();
-        lock.handle_block_update(block_id, SystemTime::now());
+        lock.insert_block(block_id, SystemTime::now());
+        if is_missing {
+            lock.to_missing_if_expired.insert(*block_id);
+        }
         drop(lock);
         self.watch_tx.send(()).unwrap_or(());
     }
 
     pub fn handle_block_removed(&self, block: &BlockId) {
-        self.shared.lock().unwrap().handle_block_removed(block);
+        self.shared.lock().unwrap().remove_block(block);
     }
 
     pub fn set_expiration_time(&self, expiration_time: Duration) {
@@ -109,12 +127,14 @@ struct Shared {
     //
     blocks_by_id: HashMap<BlockId, TimeUpdated>,
     blocks_by_expiration: BTreeMap<TimeUpdated, HashSet<BlockId>>,
+
+    to_missing_if_expired: HashSet<BlockId>,
 }
 
 impl Shared {
     /// Add the `block` into `Self`. If it's already there, remove it and add it back with the new
     /// time stamp.
-    fn handle_block_update(&mut self, block: &BlockId, ts: TimeUpdated) {
+    fn insert_block(&mut self, block: &BlockId, ts: TimeUpdated) {
         // Asserts and unwraps are OK due to the `Shared` invariants defined above.
         match self.blocks_by_id.entry(*block) {
             hash_map::Entry::Occupied(mut entry) => {
@@ -155,8 +175,7 @@
         }
     }
 
-    /// Remove `block` from `Self`.
-    fn handle_block_removed(&mut self, block: &BlockId) {
+    fn remove_block(&mut self, block: &BlockId) {
         // Asserts and unwraps are OK due to the `Shared` invariants defined above.
         let ts = match self.blocks_by_id.entry(*block) {
             hash_map::Entry::Occupied(entry) => entry.remove(),
@@ -204,27 +223,51 @@ async fn run_task(
     pool: db::Pool,
     mut watch_rx: uninitialized_watch::Receiver<()>,
     mut expiration_time_rx: watch::Receiver<Duration>,
+    block_download_tracker: BlockDownloadTracker,
+    client_reload_index_tx: broadcast_hash_set::Sender<PublicKey>,
+    cache: Arc<Cache>,
 ) -> Result<(), Error> {
     loop {
         let expiration_time = *expiration_time_rx.borrow();
 
         let (ts, block) = {
-            let oldest_entry = shared
-                .lock()
-                .unwrap()
-                .blocks_by_expiration
-                .first_entry()
-                // Unwrap OK due to the invariant #2.
-                .map(|e| (*e.key(), *e.get().iter().next().unwrap()));
-
-            match oldest_entry {
-                Some((ts, block)) => (ts, block),
-                None => {
+            enum Enum {
+                OldestEntry(Option<(TimeUpdated, BlockId)>),
+                ToMissing(HashSet<BlockId>),
+            }
+
+            match {
+                let mut lock = shared.lock().unwrap();
+
+                if !lock.to_missing_if_expired.is_empty() {
+                    Enum::ToMissing(std::mem::take(&mut lock.to_missing_if_expired))
+                } else {
+                    Enum::OldestEntry(
+                        lock.blocks_by_expiration
+                            .first_entry()
+                            // Unwrap OK due to the invariant #2.
+                            .map(|e| (*e.key(), *e.get().iter().next().unwrap())),
+                    )
+                }
+            } {
+                Enum::OldestEntry(Some((time_updated, block_id))) => (time_updated, block_id),
+                Enum::OldestEntry(None) => {
                     if watch_rx.changed().await.is_err() {
                         return Ok(());
                     }
 
                     continue;
                 }
+                Enum::ToMissing(to_missing_if_expired) => {
+                    set_as_missing_if_expired(
+                        &pool,
+                        to_missing_if_expired,
+                        &block_download_tracker,
+                        &client_reload_index_tx,
+                        cache.begin(),
+                    )
+                    .await?;
+                    continue;
+                }
             }
         };
@@ -238,6 +281,9 @@
                 _ = expiration_time_rx.changed() => {
                     continue;
                 }
+                _ = watch_rx.changed() => {
+                    continue;
+                }
             }
         }
 
@@ -290,12 +336,60 @@
         // TODO: Should we then restart the tracker or do we rely on the fact that if committing
        // into the database fails, then we know the app will be closed? The situation would
        // resolve itself upon restart.
-        shared.lock().unwrap().handle_block_removed(&block);
+        shared.lock().unwrap().remove_block(&block);
 
         tx.commit().await?;
     }
 }
 
+async fn set_as_missing_if_expired(
+    pool: &db::Pool,
+    block_ids: HashSet<BlockId>,
+    block_download_tracker: &BlockDownloadTracker,
+    client_reload_index_tx: &broadcast_hash_set::Sender<PublicKey>,
+    mut cache: CacheTransaction,
+) -> Result<(), Error> {
+    let mut tx = pool.begin_write().await?;
+
+    // Branches where we have newly missing blocks. We need to tell the client to reload indices
+    // for these branches from peers.
+    let mut branches: HashSet<PublicKey> = HashSet::default();
+
+    for block_id in &block_ids {
+        let changed = leaf_node::set_missing_if_expired(&mut tx, block_id).await?;
+
+        if !changed {
+            continue;
+        }
+
+        block_download_tracker.require(*block_id);
+
+        let nodes: Vec<_> = leaf_node::load_parent_hashes(&mut tx, block_id)
+            .try_collect()
+            .await?;
+
+        for (hash, _state) in index::update_summaries(
+            &mut tx,
+            &mut cache,
+            nodes,
+            UpdateSummaryReason::BlockRemoved,
+        )
+        .await?
+        {
+            try_collect_into(root_node::load_writer_ids(&mut tx, &hash), &mut branches).await?;
+        }
+    }
+
+    tx.commit().await?;
+
+    for branch_id in branches {
+        // TODO: Throttle these messages.
+        client_reload_index_tx.insert(&branch_id);
+    }
+
+    Ok(())
+}
+
 #[cfg(test)]
 mod test {
     use super::super::*;
@@ -308,6 +402,7 @@
         let mut shared = Shared {
             blocks_by_id: Default::default(),
             blocks_by_expiration: Default::default(),
+            to_missing_if_expired: Default::default(),
         };
 
         // add once
@@ -315,25 +410,25 @@
         let ts = SystemTime::now();
         let block: BlockId = rand::random();
 
-        shared.handle_block_update(&block, ts);
+        shared.insert_block(&block, ts);
 
         assert_eq!(*shared.blocks_by_id.get(&block).unwrap(), ts);
         shared.assert_invariants();
 
-        shared.handle_block_removed(&block);
+        shared.remove_block(&block);
 
         assert!(shared.blocks_by_id.is_empty());
         shared.assert_invariants();
 
         // add twice
 
-        shared.handle_block_update(&block, ts);
-        shared.handle_block_update(&block, ts);
+        shared.insert_block(&block, ts);
+        shared.insert_block(&block, ts);
 
         assert_eq!(*shared.blocks_by_id.get(&block).unwrap(), ts);
         shared.assert_invariants();
 
-        shared.handle_block_removed(&block);
+        shared.remove_block(&block);
 
         assert!(shared.blocks_by_id.is_empty());
         shared.assert_invariants();
@@ -363,15 +458,20 @@
 
         assert_eq!(count_blocks(store.db()).await, 1);
 
-        let tracker =
-            BlockExpirationTracker::enable_expiration(store.db().clone(), Duration::from_secs(1))
-                .await
-                .unwrap();
+        let tracker = BlockExpirationTracker::enable_expiration(
+            store.db().clone(),
+            Duration::from_secs(1),
+            BlockDownloadTracker::new(),
+            broadcast_hash_set::channel().0,
+            Arc::new(Cache::new()),
+        )
+        .await
+        .unwrap();
 
         sleep(Duration::from_millis(700)).await;
 
         let block_id = add_block(&write_keys, &branch_id, &store).await;
-        tracker.handle_block_update(&block_id);
+        tracker.handle_block_update(&block_id, false);
 
         assert_eq!(count_blocks(store.db()).await, 2);
 
diff --git a/lib/src/store/changeset.rs b/lib/src/store/changeset.rs
index dc01f087d..34203cc23 100644
--- a/lib/src/store/changeset.rs
+++ b/lib/src/store/changeset.rs
@@ -46,11 +46,11 @@ impl Changeset {
         patch.save(tx, &self.bump, write_keys).await?;
 
         for block in self.blocks {
+            block::write(tx.db(), &block).await?;
+
             if let Some(tracker) = &tx.block_expiration_tracker {
-                tracker.handle_block_update(&block.id);
+                tracker.handle_block_update(&block.id, false);
             }
-
-            block::write(tx.db(), &block).await?;
         }
 
         Ok(())
diff --git a/lib/src/store/mod.rs b/lib/src/store/mod.rs
index 962758b37..cf7bda76f 100644
--- a/lib/src/store/mod.rs
+++ b/lib/src/store/mod.rs
@@ -30,11 +30,11 @@ use self::{
     index::UpdateSummaryReason,
 };
 use crate::{
+    block_tracker::BlockTracker as BlockDownloadTracker,
     collections::HashSet,
     crypto::{sign::PublicKey, CacheHash, Hash, Hashable},
     db,
     debug::DebugPrinter,
-    future::try_collect_into,
     progress::Progress,
     protocol::{
         get_bucket, Block, BlockContent, BlockId, BlockNonce, InnerNodeMap, LeafNodeSet,
@@ -77,6 +77,7 @@ impl Store {
     pub async fn set_block_expiration(
         &self,
         expiration_time: Option<Duration>,
+        block_download_tracker: BlockDownloadTracker,
     ) -> Result<(), Error> {
         let mut tracker_lock = self.block_expiration_tracker.write().await;
 
@@ -93,8 +94,14 @@
             None => return Ok(()),
         };
 
-        let tracker =
-            BlockExpirationTracker::enable_expiration(self.db.clone(), expiration_time).await?;
+        let tracker = BlockExpirationTracker::enable_expiration(
+            self.db.clone(),
+            expiration_time,
+            block_download_tracker,
+            self.client_reload_index_tx.clone(),
+            self.cache.clone(),
+        )
+        .await?;
 
         *tracker_lock = Some(Arc::new(tracker));
 
@@ -114,7 +121,6 @@
         Ok(Reader {
             inner: Handle::Connection(self.db.acquire().await?),
             cache: self.cache.begin(),
-            client_reload_index_tx: self.client_reload_index_tx.clone(),
            block_expiration_tracker: self.block_expiration_tracker.read().await.clone(),
         })
     }
@@ -125,7 +131,6 @@
             inner: Reader {
                 inner: Handle::ReadTransaction(self.db.begin_read().await?),
                 cache: self.cache.begin(),
-                client_reload_index_tx: self.client_reload_index_tx.clone(),
                 block_expiration_tracker: self.block_expiration_tracker.read().await.clone(),
             },
         })
@@ -138,7 +143,6 @@
             inner: Reader {
                 inner: Handle::WriteTransaction(self.db.begin_write().await?),
                 cache: self.cache.begin(),
-                client_reload_index_tx: self.client_reload_index_tx.clone(),
                 block_expiration_tracker: self.block_expiration_tracker.read().await.clone(),
             },
         },
@@ -241,7 +245,6 @@
 pub(crate) struct Reader {
     inner: Handle,
     cache: CacheTransaction,
-    client_reload_index_tx: broadcast_hash_set::Sender<PublicKey>,
     block_expiration_tracker: Option<Arc<BlockExpirationTracker>>,
 }
 
@@ -256,78 +259,16 @@ impl Reader {
         id: &BlockId,
         content: &mut BlockContent,
     ) -> Result {
-        if let Some(expiration_tracker) = &self.block_expiration_tracker {
-            expiration_tracker.handle_block_update(id);
-        }
-
-        let result = block::read(self.db(), id, content).await;
-
-        if matches!(result, Err(Error::BlockNotFound)) {
-            self.set_as_missing_if_expired(id).await?;
-        }
-
-        result
-    }
-
-    pub async fn read_block_on_peer_request(
-        &mut self,
-        id: &BlockId,
-        content: &mut BlockContent,
-        block_tracker: &crate::block_tracker::BlockTracker,
-    ) -> Result {
-        if let Some(expiration_tracker) = &self.block_expiration_tracker {
-            expiration_tracker.handle_block_update(id);
-        }
-
-        let result = block::read(self.db(), id, content).await;
-
-        if matches!(result, Err(Error::BlockNotFound)) && self.set_as_missing_if_expired(id).await?
-        {
-            block_tracker.require(*id);
+        let result = block::read(self.db(), id, content).await;
+
+        if let Some(expiration_tracker) = &self.block_expiration_tracker {
+            let is_missing = matches!(result, Err(Error::BlockNotFound));
+            expiration_tracker.handle_block_update(id, is_missing);
         }
 
         result
     }
 
-    async fn set_as_missing_if_expired(&mut self, block_id: &BlockId) -> Result<bool, Error> {
-        let expiration_tracker = match &self.block_expiration_tracker {
-            Some(expiration_tracker) => expiration_tracker,
-            None => return Ok(false),
-        };
-
-        // TODO: This actually does a DB write operation
-        let mut tx = expiration_tracker.pool().begin_write().await?;
-        let cache = &mut self.cache;
-
-        let changed = leaf_node::set_missing_if_expired(&mut tx, block_id).await?;
-
-        if !changed {
-            return Ok(false);
-        }
-
-        let nodes: Vec<_> = leaf_node::load_parent_hashes(&mut tx, block_id)
-            .try_collect()
-            .await?;
-
-        let mut branches: HashSet<PublicKey> = HashSet::default();
-
-        for (hash, _state) in
-            index::update_summaries(&mut tx, cache, nodes, UpdateSummaryReason::BlockRemoved)
-                .await?
-        {
-            try_collect_into(root_node::load_writer_ids(&mut tx, &hash), &mut branches).await?;
-        }
-
-        // TODO: Use the 'on commit' machinery to ensure the below code is executed.
-        tx.commit().await?;
-
-        for branch_id in branches {
-            self.client_reload_index_tx.insert(&branch_id);
-        }
-
-        Ok(true)
-    }
-
     /// Checks whether the block exists in the store.
     pub async fn block_exists(&mut self, id: &BlockId) -> Result {
         block::exists(self.db(), id).await
@@ -633,11 +574,14 @@
         &mut self,
         block: &Block,
     ) -> Result {
+        let (db, cache) = self.db_and_cache();
+        let result = block::receive(db, cache, block).await;
+
         if let Some(tracker) = &self.block_expiration_tracker {
-            tracker.handle_block_update(&block.id);
+            tracker.handle_block_update(&block.id, false);
        }
-        let (db, cache) = self.db_and_cache();
-        block::receive(db, cache, block).await
+
+        result
     }
 
     pub async fn commit(self) -> Result<(), Error> {