Skip to content

Commit

Permalink
Fix the expiration_race test
Browse files Browse the repository at this point in the history
  • Loading branch information
inetic committed Sep 15, 2023
1 parent 3c11db2 commit be43243
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 63 deletions.
56 changes: 49 additions & 7 deletions lib/src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ use std::{
ops::{Deref, DerefMut},
panic::Location,
path::Path,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::Duration,
};
#[cfg(test)]
Expand All @@ -44,6 +48,7 @@ pub(crate) struct Pool {
reads: SqlitePool,
// Pool with a single writable connection.
write: SqlitePool,
next_transaction_id: Arc<AtomicU64>,
}

impl Pool {
Expand All @@ -69,7 +74,11 @@ impl Pool {
.connect_with(read_options)
.await?;

Ok(Self { reads, write })
Ok(Self {
reads,
write,
next_transaction_id: Arc::new(AtomicU64::new(0)),
})
}

/// Acquire a read-only database connection.
Expand Down Expand Up @@ -120,7 +129,9 @@ impl Pool {
#[track_caller]
pub fn begin_write(&self) -> impl Future<Output = Result<WriteTransaction, sqlx::Error>> + '_ {
let location = Location::caller();
WriteTransaction::begin(&self.write, location)
let transaction_id =
TransactionId(self.next_transaction_id.fetch_add(1, Ordering::Relaxed));
WriteTransaction::begin(&self.write, location, transaction_id)
}

pub(crate) async fn close(&self) -> Result<(), sqlx::Error> {
Expand Down Expand Up @@ -188,6 +199,7 @@ impl_executor_by_deref!(ReadTransaction);
/// Transaction that allows both reading and writing.
///
/// Created by `WriteTransaction::begin`, which is given the `TransactionId` to store in `id`.
pub(crate) struct WriteTransaction {
// The wrapped transaction. `commit` commits it through `inner.inner`.
inner: ReadTransaction,
// Identifier passed in when the transaction was begun. On a successful commit the same number
// is handed back to the caller wrapped in a `CommitId`.
id: TransactionId,
// Test-only break point, set via `break_on_commit()` and hit from inside `commit`.
#[cfg(test)]
break_on_commit: Option<BreakPoint>,
}
Expand All @@ -196,6 +208,7 @@ impl WriteTransaction {
async fn begin(
pool: &SqlitePool,
location: &'static Location<'static>,
id: TransactionId,
) -> Result<Self, sqlx::Error> {
let tx = pool.begin().await?;
let track_lifetime = ExpectShortLifetime::new_in(WARN_AFTER_TRANSACTION_LIFETIME, location);
Expand All @@ -205,6 +218,7 @@ impl WriteTransaction {
inner: tx,
track_lifetime: Some(track_lifetime),
},
id: id,
#[cfg(test)]
break_on_commit: None,
})
Expand All @@ -217,7 +231,7 @@ impl WriteTransaction {
/// If the future returned by this function is cancelled before completion, the transaction
/// is guaranteed to be either committed or rolled back but there is no way to tell in advance
/// which of the two operations happens.
pub async fn commit(self) -> Result<(), sqlx::Error> {
pub async fn commit(self) -> Result<CommitId, sqlx::Error> {
let result = self.inner.inner.commit().await;

#[cfg(test)]
Expand All @@ -227,7 +241,7 @@ impl WriteTransaction {
break_point.hit().await.unwrap();
}

result
result.map(|()| CommitId(self.id.0))
}

/// Commits the transaction and if (and only if) the commit completes successfully, runs the
Expand Down Expand Up @@ -265,17 +279,21 @@ impl WriteTransaction {
/// case and disabling the guard but there is nothing to do about number 4.
pub async fn commit_and_then<F, R>(self, f: F) -> Result<R, sqlx::Error>
where
F: FnOnce() -> R + Send + 'static,
F: FnOnce(CommitId) -> R + Send + 'static,
R: Send + 'static,
{
let span = Span::current();
let f = move || span.in_scope(f);
let f = move |commit_id| span.in_scope(|| f(commit_id));

task::spawn(async move { self.commit().await.map(|()| f()) })
task::spawn(async move { self.commit().await.map(f) })
.await
.unwrap()
}

/// Returns the identifier this transaction was given when it was begun.
///
/// `TransactionId` is `Copy`, so the value is returned by value.
pub fn id(&self) -> TransactionId {
self.id
}

/// Test-only: stores `break_point` in this transaction, replacing any break point set earlier.
#[cfg(test)]
pub fn break_on_commit(&mut self, break_point: BreakPoint) {
self.break_on_commit = Some(break_point);
}
}
}

/// Identifier of a write transaction.
///
/// `Pool::begin_write` produces these from an incrementing `AtomicU64` counter. The derived
/// ordering compares the wrapped numbers, which lets callers tell which of two transactions was
/// requested later.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct TransactionId(u64);

impl TransactionId {
    /// Wraps a raw number into an id. Only available in test builds.
    #[cfg(test)]
    pub fn new(n: u64) -> Self {
        TransactionId(n)
    }
}

/// Value returned by a successful `WriteTransaction::commit`.
///
/// It carries the same number as the id of the transaction that was committed. APIs that must
/// only run after a commit succeeded can ask for a `CommitId` instead of a `TransactionId`.
#[derive(Clone, Copy, Debug)]
pub(crate) struct CommitId(u64);

impl CommitId {
    /// Returns the id of the transaction whose commit produced this value.
    pub fn as_transaction_id(&self) -> TransactionId {
        let CommitId(raw) = *self;
        TransactionId(raw)
    }

    /// Wraps a raw number into a commit id. Only available in test builds.
    #[cfg(test)]
    pub fn new(n: u64) -> Self {
        CommitId(n)
    }
}

impl_executor_by_deref!(WriteTransaction);

/// Shared write transaction
Expand Down
123 changes: 82 additions & 41 deletions lib/src/store/block_expiration_tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl BlockExpirationTracker {
let now = SystemTime::now();

while let Some(id) = ids.next().await {
shared.insert_block(&id?, now);
shared.insert_block(&id?, now, None);
}

let (watch_tx, watch_rx) = uninitialized_watch::channel();
Expand Down Expand Up @@ -92,10 +92,15 @@ impl BlockExpirationTracker {
})
}

pub fn handle_block_update(&self, block_id: &BlockId, is_missing: bool) {
pub fn handle_block_update(
&self,
block_id: &BlockId,
is_missing: bool,
transaction_id: Option<db::TransactionId>,
) {
// Not inlining these lines to call `SystemTime::now()` only once the `lock` is acquired.
let mut lock = self.shared.lock().unwrap();
lock.insert_block(block_id, SystemTime::now());
lock.insert_block(block_id, SystemTime::now(), transaction_id);
if is_missing {
lock.to_missing_if_expired.insert(*block_id);
}
Expand Down Expand Up @@ -141,29 +146,37 @@ impl UntrackTransaction {
self.block_ids.insert(block_id);
}

pub fn commit(self) {
// We require CommitId here instead of the TransactionId to ensure this function is called only
// after the removal of the blocks has been successfully committed into the database.
pub fn commit(self, commit_id: db::CommitId) {
if self.block_ids.is_empty() {
return;
}

let mut shared = self.shared.lock().unwrap();

for block_id in &self.block_ids {
shared.remove_block(block_id);
shared.remove_block(block_id, commit_id);
}
}
}

// For semantics
type TimeUpdated = SystemTime;

// Per-block record kept as the value type of `Shared::blocks_by_id`.
#[derive(Eq, PartialEq, Debug)]
struct Entry {
// Time stamp passed to the most recent `insert_block` call for this block. It is also the key
// under which the block is filed in `blocks_by_expiration`.
time_updated: TimeUpdated,
// Id of the write transaction passed to the most recent `insert_block` call, or `None` when the
// caller supplied no id. `remove_block` compares it against the commit id it is given.
transaction_id: Option<db::TransactionId>,
}

struct Shared {
// Invariant #1: There exists (`block`, `ts`) in `blocks_by_id` iff there exists `block` in
// `blocks_by_expiration[ts]`.
// Invariant #1: There exists `(block, Entry { time_updated: ts, ..})` in `blocks_by_id` *iff*
// there exists `block` in `blocks_by_expiration[ts]`.
//
// Invariant #2: `blocks_by_expiration[x]` is never empty for any `x`.
//
blocks_by_id: HashMap<BlockId, TimeUpdated>,
blocks_by_id: HashMap<BlockId, Entry>,
blocks_by_expiration: BTreeMap<TimeUpdated, HashSet<BlockId>>,

to_missing_if_expired: HashSet<BlockId>,
Expand All @@ -172,17 +185,28 @@ struct Shared {
impl Shared {
/// Add the `block` into `Self`. If it's already there, remove it and add it back with the new
/// time stamp.
fn insert_block(&mut self, block: &BlockId, ts: TimeUpdated) {
fn insert_block(
&mut self,
block: &BlockId,
ts: TimeUpdated,
transaction_id: Option<db::TransactionId>,
) {
// Asserts and unwraps are OK due to the `Shared` invariants defined above.
match self.blocks_by_id.entry(*block) {
hash_map::Entry::Occupied(mut entry) => {
let old_ts = *entry.get();
let Entry {
time_updated: old_ts,
transaction_id: old_id,
} = *entry.get();

if old_ts == ts {
if (old_ts, old_id) == (ts, transaction_id) {
return;
}

entry.insert(ts);
entry.insert(Entry {
time_updated: ts,
transaction_id,
});

let mut entry = match self.blocks_by_expiration.entry(old_ts) {
btree_map::Entry::Occupied(entry) => entry,
Expand All @@ -208,19 +232,29 @@ impl Shared {
.or_insert_with(Default::default)
.insert(*block));

entry.insert(ts);
entry.insert(Entry {
time_updated: ts,
transaction_id,
});
}
}
}

fn remove_block(&mut self, block: &BlockId) {
fn remove_block(&mut self, block: &BlockId, commit_id: db::CommitId) {
// Asserts and unwraps are OK due to the `Shared` invariants defined above.
let ts = match self.blocks_by_id.entry(*block) {
hash_map::Entry::Occupied(entry) => entry.remove(),
let time_updated = match self.blocks_by_id.entry(*block) {
hash_map::Entry::Occupied(entry) => {
if entry.get().transaction_id >= Some(commit_id.as_transaction_id()) {
// A race condition happened and operations switched order, we need to ignore
// this one. See the `expiration_race` test for more information.
return;
}
entry.remove().time_updated
}
hash_map::Entry::Vacant(_) => return,
};

let mut entry = match self.blocks_by_expiration.entry(ts) {
let mut entry = match self.blocks_by_expiration.entry(time_updated) {
btree_map::Entry::Occupied(entry) => entry,
btree_map::Entry::Vacant(_) => unreachable!(),
};
Expand All @@ -235,13 +269,19 @@ impl Shared {
#[cfg(test)]
fn assert_invariants(&self) {
// #1 =>
for (block, ts) in self.blocks_by_id.iter() {
for (
block,
Entry {
time_updated: ts, ..
},
) in self.blocks_by_id.iter()
{
assert!(self.blocks_by_expiration.get(ts).unwrap().contains(block));
}
// #1 <=
for (ts, blocks) in self.blocks_by_expiration.iter() {
for block in blocks.iter() {
assert_eq!(self.blocks_by_id.get(block).unwrap(), ts);
assert_eq!(&self.blocks_by_id.get(block).unwrap().time_updated, ts);
}
}
// Degenerate case
Expand Down Expand Up @@ -363,20 +403,9 @@ async fn run_task(

tracing::warn!("Block {block:?} has expired");

// We need to remove the block from `shared` here while the database is locked for writing.
// If we did it after the commit, it could happen that after the commit but between the
// removal someone re-adds the block into the database. That would result in there being a
// block in the database, but not in the BlockExpirationTracker. That is, the block will
// have been forgotten by the tracker.
//
// Such situation can also happen in this current scenario where we first remove the block
// from `shared` and then the commit fails. But that case will be detected.
// TODO: Should we then restart the tracker or do we rely on the fact that if committing
// into the database fails, then we know the app will be closed? The situation would
// resolve itself upon restart.
shared.lock().unwrap().remove_block(&block);

tx.commit().await?;
let commit_id = tx.commit().await?;

shared.lock().unwrap().remove_block(&block, commit_id);
}
}

Expand Down Expand Up @@ -448,25 +477,37 @@ mod test {
let ts = SystemTime::now();
let block: BlockId = rand::random();

shared.insert_block(&block, ts);
shared.insert_block(&block, ts, Some(db::TransactionId::new(0)));

assert_eq!(*shared.blocks_by_id.get(&block).unwrap(), ts);
assert_eq!(
*shared.blocks_by_id.get(&block).unwrap(),
Entry {
time_updated: ts,
transaction_id: Some(db::TransactionId::new(0))
}
);
shared.assert_invariants();

shared.remove_block(&block);
shared.remove_block(&block, db::CommitId::new(1));

assert!(shared.blocks_by_id.is_empty());
shared.assert_invariants();

// add twice

shared.insert_block(&block, ts);
shared.insert_block(&block, ts);
shared.insert_block(&block, ts, Some(db::TransactionId::new(2)));
shared.insert_block(&block, ts, Some(db::TransactionId::new(3)));

assert_eq!(*shared.blocks_by_id.get(&block).unwrap(), ts);
assert_eq!(
*shared.blocks_by_id.get(&block).unwrap(),
Entry {
time_updated: ts,
transaction_id: Some(db::TransactionId::new(3))
}
);
shared.assert_invariants();

shared.remove_block(&block);
shared.remove_block(&block, db::CommitId::new(4));

assert!(shared.blocks_by_id.is_empty());
shared.assert_invariants();
Expand Down Expand Up @@ -509,7 +550,7 @@ mod test {
sleep(Duration::from_millis(700)).await;

let block_id = add_block(rand::random(), &write_keys, &branch_id, &store).await;
tracker.handle_block_update(&block_id, false);
tracker.handle_block_update(&block_id, false, Some(db::TransactionId::new(1)));

assert_eq!(count_blocks(store.db()).await, 2);

Expand Down
4 changes: 3 additions & 1 deletion lib/src/store/changeset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,13 @@ impl Changeset {

patch.save(tx, &self.bump, write_keys).await?;

let tx_id = tx.db().id();

for block in self.blocks {
block::write(tx.db(), &block).await?;

if let Some(tracker) = &tx.block_expiration_tracker {
tracker.handle_block_update(&block.id, false);
tracker.handle_block_update(&block.id, false, Some(tx_id));
}
}

Expand Down
Loading

0 comments on commit be43243

Please sign in to comment.