Skip to content

Commit

Permalink
Merge branch 'expiration-race'
Browse files Browse the repository at this point in the history
  • Loading branch information
inetic committed Sep 15, 2023
2 parents 39ee967 + 77aa21b commit 4fd53de
Show file tree
Hide file tree
Showing 5 changed files with 544 additions and 74 deletions.
84 changes: 76 additions & 8 deletions lib/src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,20 @@ use std::{
ops::{Deref, DerefMut},
panic::Location,
path::Path,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::Duration,
};
#[cfg(test)]
use tempfile::TempDir;
use thiserror::Error;
use tokio::{fs, sync::OwnedMutexGuard as AsyncOwnedMutexGuard, task};

#[cfg(test)]
use crate::sync::break_point::BreakPoint;

const WARN_AFTER_TRANSACTION_LIFETIME: Duration = Duration::from_secs(3);

pub(crate) use self::connection::Connection;
Expand All @@ -41,6 +48,7 @@ pub(crate) struct Pool {
reads: SqlitePool,
// Pool with a single writable connection.
write: SqlitePool,
next_transaction_id: Arc<AtomicU64>,
}

impl Pool {
Expand All @@ -66,7 +74,11 @@ impl Pool {
.connect_with(read_options)
.await?;

Ok(Self { reads, write })
Ok(Self {
reads,
write,
next_transaction_id: Arc::new(AtomicU64::new(0)),
})
}

/// Acquire a read-only database connection.
Expand Down Expand Up @@ -117,7 +129,9 @@ impl Pool {
#[track_caller]
pub fn begin_write(&self) -> impl Future<Output = Result<WriteTransaction, sqlx::Error>> + '_ {
let location = Location::caller();
WriteTransaction::begin(&self.write, location)
let transaction_id =
TransactionId(self.next_transaction_id.fetch_add(1, Ordering::Relaxed));
WriteTransaction::begin(&self.write, location, transaction_id)
}

pub(crate) async fn close(&self) -> Result<(), sqlx::Error> {
Expand Down Expand Up @@ -183,15 +197,18 @@ impl fmt::Debug for ReadTransaction {
impl_executor_by_deref!(ReadTransaction);

/// Transaction that allows both reading and writing.
#[derive(Debug)]
pub(crate) struct WriteTransaction {
inner: ReadTransaction,
id: TransactionId,
#[cfg(test)]
break_on_commit: Option<BreakPoint>,
}

impl WriteTransaction {
async fn begin(
pool: &SqlitePool,
location: &'static Location<'static>,
id: TransactionId,
) -> Result<Self, sqlx::Error> {
let tx = pool.begin().await?;
let track_lifetime = ExpectShortLifetime::new_in(WARN_AFTER_TRANSACTION_LIFETIME, location);
Expand All @@ -201,6 +218,9 @@ impl WriteTransaction {
inner: tx,
track_lifetime: Some(track_lifetime),
},
id: id,

Check warning on line 221 in lib/src/db/mod.rs

View workflow job for this annotation

GitHub Actions / check (windows)

redundant field names in struct initialization

Check warning on line 221 in lib/src/db/mod.rs

View workflow job for this annotation

GitHub Actions / check (android)

redundant field names in struct initialization
#[cfg(test)]
break_on_commit: None,
})
}

Expand All @@ -211,8 +231,17 @@ impl WriteTransaction {
/// If the future returned by this function is cancelled before completion, the transaction
/// is guaranteed to be either committed or rolled back but there is no way to tell in advance
/// which of the two operations happens.
pub async fn commit(self) -> Result<(), sqlx::Error> {
self.inner.inner.commit().await
pub async fn commit(self) -> Result<CommitId, sqlx::Error> {
let result = self.inner.inner.commit().await;

#[cfg(test)]
if let Some(mut break_point) = self.break_on_commit {
// Unwrap is OK because this is code is only executed in tests and we want to make sure
// the BreakPointController is used appropriately.
break_point.hit().await.unwrap();
}

result.map(|()| CommitId(self.id.0))
}

/// Commits the transaction and if (and only if) the commit completes successfully, runs the
Expand Down Expand Up @@ -250,16 +279,25 @@ impl WriteTransaction {
/// case and disabling the guard but there is nothing to do about number 4.
pub async fn commit_and_then<F, R>(self, f: F) -> Result<R, sqlx::Error>
where
F: FnOnce() -> R + Send + 'static,
F: FnOnce(CommitId) -> R + Send + 'static,
R: Send + 'static,
{
let span = Span::current();
let f = move || span.in_scope(f);
let f = move |commit_id| span.in_scope(|| f(commit_id));

task::spawn(async move { self.commit().await.map(|()| f()) })
task::spawn(async move { self.commit().await.map(f) })
.await
.unwrap()
}

pub fn id(&self) -> TransactionId {
self.id
}

#[cfg(test)]
pub fn break_on_commit(&mut self, break_point: BreakPoint) {
self.break_on_commit = Some(break_point);
}
}

impl Deref for WriteTransaction {
Expand All @@ -276,6 +314,36 @@ impl DerefMut for WriteTransaction {
}
}

impl std::fmt::Debug for WriteTransaction {
fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
write!(f, "WriteTransaction{{ inner:{:?} }}", self.inner)
}
}

#[derive(Eq, PartialEq, Ord, PartialOrd, Debug, Copy, Clone)]
pub(crate) struct TransactionId(u64);

impl TransactionId {
#[cfg(test)]
pub fn new(n: u64) -> Self {
Self(n)
}
}

#[derive(Debug, Copy, Clone)]
pub(crate) struct CommitId(u64);

impl CommitId {
pub fn as_transaction_id(&self) -> TransactionId {
TransactionId(self.0)
}

#[cfg(test)]
pub fn new(n: u64) -> Self {
Self(n)
}
}

impl_executor_by_deref!(WriteTransaction);

/// Shared write transaction
Expand Down
Loading

0 comments on commit 4fd53de

Please sign in to comment.