Merge branch 'no-write-in-reader'
inetic committed Sep 11, 2023
2 parents f578436 + ed50c9d commit 7e8fe2f
Showing 6 changed files with 167 additions and 123 deletions.
2 changes: 1 addition & 1 deletion lib/src/network/server.rs
@@ -211,7 +211,7 @@ impl<'a> Responder<'a> {
.store()
.acquire_read()
.await?
.read_block_on_peer_request(&id, &mut content, &self.vault.block_tracker)
.read_block(&id, &mut content)
.await;

match result {
5 changes: 1 addition & 4 deletions lib/src/repository/mod.rs
@@ -180,10 +180,7 @@ impl Repository {
{
let mut conn = vault.store().db().acquire().await?;
if let Some(block_expiration) = metadata::block_expiration::get(&mut conn).await? {
vault
.store()
.set_block_expiration(Some(block_expiration))
.await?;
vault.set_block_expiration(Some(block_expiration)).await?;
}
}

5 changes: 4 additions & 1 deletion lib/src/repository/vault.rs
@@ -222,7 +222,10 @@ impl Vault {
}

pub async fn set_block_expiration(&self, duration: Option<Duration>) -> Result<()> {
Ok(self.store.set_block_expiration(duration).await?)
Ok(self
.store
.set_block_expiration(duration, self.block_tracker.clone())
.await?)
}

pub async fn block_expiration(&self) -> Option<Duration> {
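For orientation, here is a minimal usage sketch of the updated `Vault::set_block_expiration` API shown above. The helper function name is hypothetical and the `Result` alias is assumed to be the crate's own, as elsewhere in this diff; only the `Option<Duration>` parameter comes from the commit.

```rust
use std::time::Duration;

// Hypothetical helper (not part of this commit): callers now only pass the
// optional duration; the vault forwards its own BlockTracker to the store.
async fn expire_blocks_after_one_hour(vault: &Vault) -> Result<()> {
    vault
        .set_block_expiration(Some(Duration::from_secs(3600)))
        .await?;
    Ok(())
}
```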
180 changes: 140 additions & 40 deletions lib/src/store/block_expiration_tracker.rs
@@ -1,10 +1,18 @@
use super::error::Error;
use super::{
cache::{Cache, CacheTransaction},
error::Error,
index::{self, UpdateSummaryReason},
leaf_node, root_node,
};
use crate::{
block_tracker::BlockTracker as BlockDownloadTracker,
collections::{hash_map, HashMap, HashSet},
crypto::sign::PublicKey,
db,
deadlock::BlockingMutex,
future::try_collect_into,
protocol::{BlockId, SingleBlockPresence},
sync::uninitialized_watch,
sync::{broadcast_hash_set, uninitialized_watch},
};
use futures_util::{StreamExt, TryStreamExt};
use scoped_task::{self, ScopedJoinHandle};
@@ -17,25 +25,24 @@ use std::{
use tokio::{select, sync::watch, time::sleep};

pub(crate) struct BlockExpirationTracker {
pool: db::Pool,
shared: Arc<BlockingMutex<Shared>>,
watch_tx: uninitialized_watch::Sender<()>,
expiration_time_tx: watch::Sender<Duration>,
_task: ScopedJoinHandle<()>,
}

impl BlockExpirationTracker {
pub fn pool(&self) -> &db::Pool {
&self.pool
}

pub async fn enable_expiration(
pub(super) async fn enable_expiration(
pool: db::Pool,
expiration_time: Duration,
block_download_tracker: BlockDownloadTracker,
client_reload_index_tx: broadcast_hash_set::Sender<PublicKey>,
cache: Arc<Cache>,
) -> Result<Self, Error> {
let mut shared = Shared {
blocks_by_id: Default::default(),
blocks_by_expiration: Default::default(),
to_missing_if_expired: Default::default(),
};

let mut tx = pool.begin_read().await?;
@@ -49,7 +56,7 @@ impl BlockExpirationTracker {
let now = SystemTime::now();

while let Some(id) = ids.next().await {
shared.handle_block_update(&id?, now);
shared.insert_block(&id?, now);
}

let (watch_tx, watch_rx) = uninitialized_watch::channel();
@@ -59,34 +66,45 @@ impl BlockExpirationTracker {

let _task = scoped_task::spawn({
let shared = shared.clone();
let pool = pool.clone();

async move {
if let Err(err) = run_task(shared, pool, watch_rx, expiration_time_rx).await {
if let Err(err) = run_task(
shared,
pool,
watch_rx,
expiration_time_rx,
block_download_tracker,
client_reload_index_tx,
cache,
)
.await
{
tracing::error!("BlockExpirationTracker task has ended with {err:?}");
}
}
});

Ok(Self {
pool,
shared,
watch_tx,
expiration_time_tx,
_task,
})
}

pub fn handle_block_update(&self, block_id: &BlockId) {
pub fn handle_block_update(&self, block_id: &BlockId, is_missing: bool) {
// Not inlining these lines to call `SystemTime::now()` only once the `lock` is acquired.
let mut lock = self.shared.lock().unwrap();
lock.handle_block_update(block_id, SystemTime::now());
lock.insert_block(block_id, SystemTime::now());
if is_missing {
lock.to_missing_if_expired.insert(*block_id);
}
drop(lock);
self.watch_tx.send(()).unwrap_or(());
}

pub fn handle_block_removed(&self, block: &BlockId) {
self.shared.lock().unwrap().handle_block_removed(block);
self.shared.lock().unwrap().remove_block(block);
}

pub fn set_expiration_time(&self, expiration_time: Duration) {
@@ -109,12 +127,14 @@ struct Shared {
//
blocks_by_id: HashMap<BlockId, TimeUpdated>,
blocks_by_expiration: BTreeMap<TimeUpdated, HashSet<BlockId>>,

to_missing_if_expired: HashSet<BlockId>,
}

impl Shared {
/// Add the `block` into `Self`. If it's already there, remove it and add it back with the new
/// time stamp.
fn handle_block_update(&mut self, block: &BlockId, ts: TimeUpdated) {
fn insert_block(&mut self, block: &BlockId, ts: TimeUpdated) {
// Asserts and unwraps are OK due to the `Shared` invariants defined above.
match self.blocks_by_id.entry(*block) {
hash_map::Entry::Occupied(mut entry) => {
@@ -155,8 +175,7 @@ impl Shared {
}
}

/// Remove `block` from `Self`.
fn handle_block_removed(&mut self, block: &BlockId) {
fn remove_block(&mut self, block: &BlockId) {
// Asserts and unwraps are OK due to the `Shared` invariants defined above.
let ts = match self.blocks_by_id.entry(*block) {
hash_map::Entry::Occupied(entry) => entry.remove(),
@@ -204,27 +223,51 @@ async fn run_task(
pool: db::Pool,
mut watch_rx: uninitialized_watch::Receiver<()>,
mut expiration_time_rx: watch::Receiver<Duration>,
block_download_tracker: BlockDownloadTracker,
client_reload_index_tx: broadcast_hash_set::Sender<PublicKey>,
cache: Arc<Cache>,
) -> Result<(), Error> {
loop {
let expiration_time = *expiration_time_rx.borrow();

let (ts, block) = {
let oldest_entry = shared
.lock()
.unwrap()
.blocks_by_expiration
.first_entry()
// Unwrap OK due to the invariant #2.
.map(|e| (*e.key(), *e.get().iter().next().unwrap()));

match oldest_entry {
Some((ts, block)) => (ts, block),
None => {
enum Enum {
OldestEntry(Option<(TimeUpdated, BlockId)>),
ToMissing(HashSet<BlockId>),
}

match {
let mut lock = shared.lock().unwrap();

if !lock.to_missing_if_expired.is_empty() {
Enum::ToMissing(std::mem::take(&mut lock.to_missing_if_expired))
} else {
Enum::OldestEntry(
lock.blocks_by_expiration
.first_entry()
// Unwrap OK due to the invariant #2.
.map(|e| (*e.key(), *e.get().iter().next().unwrap())),
)
}
} {
Enum::OldestEntry(Some((time_updated, block_id))) => (time_updated, block_id),
Enum::OldestEntry(None) => {
if watch_rx.changed().await.is_err() {
return Ok(());
}
continue;
}
Enum::ToMissing(to_missing_if_expired) => {
set_as_missing_if_expired(
&pool,
to_missing_if_expired,
&block_download_tracker,
&client_reload_index_tx,
cache.begin(),
)
.await?;
continue;
}
}
};

Expand All @@ -238,6 +281,9 @@ async fn run_task(
_ = expiration_time_rx.changed() => {
continue;
}
_ = watch_rx.changed() => {
continue;
}
}
}

@@ -290,12 +336,60 @@ async fn run_task(
// TODO: Should we then restart the tracker or do we rely on the fact that if committing
// into the database fails, then we know the app will be closed? The situation would
// resolve itself upon restart.
shared.lock().unwrap().handle_block_removed(&block);
shared.lock().unwrap().remove_block(&block);

tx.commit().await?;
}
}

async fn set_as_missing_if_expired(
pool: &db::Pool,
block_ids: HashSet<BlockId>,
block_download_tracker: &BlockDownloadTracker,
client_reload_index_tx: &broadcast_hash_set::Sender<PublicKey>,
mut cache: CacheTransaction,
) -> Result<(), Error> {
let mut tx = pool.begin_write().await?;

// Branches where we have newly missing blocks. We need to tell the client to reload indices
// for these branches from peers.
let mut branches: HashSet<PublicKey> = HashSet::default();

for block_id in &block_ids {
let changed = leaf_node::set_missing_if_expired(&mut tx, block_id).await?;

if !changed {
continue;
}

block_download_tracker.require(*block_id);

let nodes: Vec<_> = leaf_node::load_parent_hashes(&mut tx, block_id)
.try_collect()
.await?;

for (hash, _state) in index::update_summaries(
&mut tx,
&mut cache,
nodes,
UpdateSummaryReason::BlockRemoved,
)
.await?
{
try_collect_into(root_node::load_writer_ids(&mut tx, &hash), &mut branches).await?;
}
}

tx.commit().await?;

for branch_id in branches {
// TODO: Throttle these messages.
client_reload_index_tx.insert(&branch_id);
}

Ok(())
}

#[cfg(test)]
mod test {
use super::super::*;
@@ -308,32 +402,33 @@ mod test {
let mut shared = Shared {
blocks_by_id: Default::default(),
blocks_by_expiration: Default::default(),
to_missing_if_expired: Default::default(),
};

// add once

let ts = SystemTime::now();
let block: BlockId = rand::random();

shared.handle_block_update(&block, ts);
shared.insert_block(&block, ts);

assert_eq!(*shared.blocks_by_id.get(&block).unwrap(), ts);
shared.assert_invariants();

shared.handle_block_removed(&block);
shared.remove_block(&block);

assert!(shared.blocks_by_id.is_empty());
shared.assert_invariants();

// add twice

shared.handle_block_update(&block, ts);
shared.handle_block_update(&block, ts);
shared.insert_block(&block, ts);
shared.insert_block(&block, ts);

assert_eq!(*shared.blocks_by_id.get(&block).unwrap(), ts);
shared.assert_invariants();

shared.handle_block_removed(&block);
shared.remove_block(&block);

assert!(shared.blocks_by_id.is_empty());
shared.assert_invariants();
@@ -363,15 +458,20 @@

assert_eq!(count_blocks(store.db()).await, 1);

let tracker =
BlockExpirationTracker::enable_expiration(store.db().clone(), Duration::from_secs(1))
.await
.unwrap();
let tracker = BlockExpirationTracker::enable_expiration(
store.db().clone(),
Duration::from_secs(1),
BlockDownloadTracker::new(),
broadcast_hash_set::channel().0,
Arc::new(Cache::new()),
)
.await
.unwrap();

sleep(Duration::from_millis(700)).await;

let block_id = add_block(&write_keys, &branch_id, &store).await;
tracker.handle_block_update(&block_id);
tracker.handle_block_update(&block_id, false);

assert_eq!(count_blocks(store.db()).await, 2);

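Since most of the tracker changes revolve around the `Shared` struct, here is a small standalone sketch of the two-index pattern it maintains: a map from block id to its last-updated timestamp, plus an ordered map from timestamp to the set of blocks updated at that instant, which makes the oldest block cheap to find. The types are simplified (`u64` stands in for `BlockId`), and the code is an illustration of the pattern rather than code from the repository.

```rust
use std::collections::{BTreeMap, HashMap, HashSet};
use std::time::SystemTime;

// Simplified stand-in for `Shared`. Its invariants mirror the ones the unwraps
// in the diff rely on: every id in `by_id` appears under exactly one timestamp
// in `by_expiration`, and `by_expiration` never holds an empty set.
#[derive(Default)]
struct ExpirationIndex {
    by_id: HashMap<u64, SystemTime>,
    by_expiration: BTreeMap<SystemTime, HashSet<u64>>,
}

impl ExpirationIndex {
    // Insert a block, or move it to a new timestamp if it is already tracked.
    fn insert(&mut self, block: u64, ts: SystemTime) {
        if let Some(old_ts) = self.by_id.insert(block, ts) {
            if let Some(set) = self.by_expiration.get_mut(&old_ts) {
                set.remove(&block);
                if set.is_empty() {
                    self.by_expiration.remove(&old_ts);
                }
            }
        }
        self.by_expiration.entry(ts).or_default().insert(block);
    }

    // Forget a block entirely.
    fn remove(&mut self, block: u64) {
        if let Some(ts) = self.by_id.remove(&block) {
            if let Some(set) = self.by_expiration.get_mut(&ts) {
                set.remove(&block);
                if set.is_empty() {
                    self.by_expiration.remove(&ts);
                }
            }
        }
    }

    // The least recently updated block, if any. BTreeMap iterates in ascending
    // key order, so the first entry holds the oldest timestamp.
    fn oldest(&self) -> Option<(SystemTime, u64)> {
        self.by_expiration
            .iter()
            .next()
            .and_then(|(ts, set)| set.iter().next().map(|b| (*ts, *b)))
    }
}

fn main() {
    let mut index = ExpirationIndex::default();
    index.insert(1, SystemTime::now());
    index.insert(2, SystemTime::now());
    index.remove(1);
    println!("oldest: {:?}", index.oldest());
}
```

The expiration task in the diff builds on this shape: it repeatedly peeks at the oldest entry and, once that entry is older than the configured expiration time, removes the block (the middle of that loop is collapsed in this view).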
6 changes: 3 additions & 3 deletions lib/src/store/changeset.rs
@@ -46,11 +46,11 @@ impl Changeset {
patch.save(tx, &self.bump, write_keys).await?;

for block in self.blocks {
block::write(tx.db(), &block).await?;

if let Some(tracker) = &tx.block_expiration_tracker {
tracker.handle_block_update(&block.id);
tracker.handle_block_update(&block.id, false);
}

block::write(tx.db(), &block).await?;
}

Ok(())
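The changeset hunk is small, so here it is again in condensed form with comments spelling out the two changes: the tracker notification gains the `is_missing = false` argument, and it now happens before `block::write` rather than after. The commit does not state why the order was swapped, so the comments stick to what the code does.

```rust
for block in self.blocks {
    // New order in this commit: the expiration tracker hears about the block
    // before the block row is written. `false` = the block is not missing.
    if let Some(tracker) = &tx.block_expiration_tracker {
        tracker.handle_block_update(&block.id, false);
    }

    block::write(tx.db(), &block).await?;
}
```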
