Skip to content

Commit

Permalink
fix(cache): clean pending request immediately when future cancel (ris…
Browse files Browse the repository at this point in the history
…ingwavelabs#4422)

* fix future drop

Signed-off-by: Little-Wallace <[email protected]>

* clean pending request

Signed-off-by: Little-Wallace <[email protected]>

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
Little-Wallace and mergify[bot] authored Aug 4, 2022
1 parent b016a50 commit fe9cc4d
Showing 1 changed file with 82 additions and 7 deletions.
89 changes: 82 additions & 7 deletions src/common/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,21 @@ impl<K: LruKey, T: LruValue> LruCache<K, T> {
}
}

pub struct CleanCacheGuard<'a, K: LruKey + Clone + 'static, T: LruValue + 'static> {
cache: &'a Arc<LruCache<K, T>>,
key: K,
hash: u64,
success: bool,
}

impl<'a, K: LruKey + Clone + 'static, T: LruValue + 'static> Drop for CleanCacheGuard<'a, K, T> {
fn drop(&mut self) {
if !self.success {
self.cache.clear_pending_request(&self.key, self.hash);
}
}
}

/// Only implement `lookup_with_request_dedup` for static values, as they can be sent across tokio
/// spawned futures.
impl<K: LruKey + Clone + 'static, T: LruValue + 'static> LruCache<K, T> {
Expand All @@ -790,20 +805,28 @@ impl<K: LruKey + Clone + 'static, T: LruValue + 'static> LruCache<K, T> {
LookupResult::Miss => {
let this = self.clone();
let fetch_value = fetch_value();
tokio::spawn(async move {
let key2 = key.clone();
let mut guard = CleanCacheGuard {
cache: self,
key,
hash,
success: false,
};
let ret = tokio::spawn(async move {
match fetch_value.await {
Ok((value, charge)) => {
let entry = this.insert(key, hash, charge, value);
let entry = this.insert(key2, hash, charge, value);
Ok(Ok(entry))
}
Err(e) => {
this.clear_pending_request(&key, hash);
Ok(Err(e))
}
Err(e) => Ok(Err(e)),
}
})
.await
.unwrap()
.unwrap();
if let Ok(Ok(_)) = ret.as_ref() {
guard.success = true;
}
ret
}
}
}
Expand Down Expand Up @@ -841,9 +864,13 @@ impl<K: LruKey, T: LruValue> Drop for CachableEntry<K, T> {
mod tests {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::pin::Pin;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering::Relaxed;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures::FutureExt;
use rand::rngs::SmallRng;
use rand::{RngCore, SeedableRng};
use tokio::sync::oneshot::error::TryRecvError;
Expand Down Expand Up @@ -1235,4 +1262,52 @@ mod tests {
drop(cache);
assert!(listener.released.lock().is_empty());
}

pub struct SyncPointFuture<F: Future> {
inner: F,
polled: Arc<AtomicBool>,
}

impl<F: Future + Unpin> Future for SyncPointFuture<F> {
type Output = ();

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
if self.polled.load(Ordering::Acquire) {
return Poll::Ready(());
}
self.inner.poll_unpin(cx).map(|_| ())
}
}

#[tokio::test(flavor = "multi_thread")]
async fn test_future_cancel() {
let cache: Arc<LruCache<u64, u64>> = Arc::new(LruCache::new(0, 5));
// do not need sender because this receiver will be cancelled.
let (_, recv) = channel::<()>();
let polled = Arc::new(AtomicBool::new(false));
let cache2 = cache.clone();
let polled2 = polled.clone();
let f = Box::pin(async move {
cache2
.lookup_with_request_dedup(1, 2, || async move {
polled2.store(true, Ordering::Release);
recv.await.map(|_| (1, 1))
})
.await
.unwrap()
.unwrap();
});
let wrapper = SyncPointFuture {
inner: f,
polled: polled.clone(),
};
{
let handle = tokio::spawn(async move {
wrapper.await;
});
while !polled.load(Ordering::Acquire) {}
handle.await.unwrap();
}
assert!(cache.shards[0].lock().write_request.is_empty());
}
}

0 comments on commit fe9cc4d

Please sign in to comment.