Skip to content

Commit

Permalink
Implement caching via tasks instead of shared mutexes
Browse files Browse the repository at this point in the history
This helps prevent a rare suboptimal case where a visitor to the
playground triggered the crates / versions request but then left the
site before the request finished, resulting in the request being
canceled and the work being wasted without being cached. This mostly
showed up when running heavy load tests locally to try to suss out
other bugs.

Other benefits:

- This will also result in reduced contention when multiple requests
  would have triggered a cache refresh. Only one computation should
  occur.

- The cache value is now computed out-of-band and requests should not
  block on it.
  • Loading branch information
shepmaster committed Nov 24, 2024
1 parent db8a083 commit f116c28
Show file tree
Hide file tree
Showing 2 changed files with 358 additions and 159 deletions.
251 changes: 92 additions & 159 deletions ui/src/server_axum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use axum_extra::{
headers::{authorization::Bearer, Authorization, CacheControl, ETag, IfNoneMatch},
TypedHeader,
};
use futures::{future::BoxFuture, FutureExt};
use futures::{future::BoxFuture, FutureExt, TryFutureExt};
use orchestrator::coordinator::{self, CoordinatorFactory, DockerBackend, TRACKED_CONTAINERS};
use snafu::prelude::*;
use std::{
Expand All @@ -34,9 +34,9 @@ use std::{
mem, path,
str::FromStr,
sync::{Arc, LazyLock},
time::{Duration, Instant, SystemTime, UNIX_EPOCH},
time::{Duration, Instant, UNIX_EPOCH},
};
use tokio::{select, sync::Mutex};
use tokio::{select, sync::mpsc};
use tower_http::{
cors::{self, CorsLayer},
request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer},
Expand All @@ -48,25 +48,35 @@ use tracing::{error, error_span, field};

use crate::{env::PLAYGROUND_GITHUB_TOKEN, public_http_api as api};

use cache::{
cache_task, CacheTaskItem, CacheTx, CacheTxError, Stamped, SANDBOX_CACHE_TIME_TO_LIVE,
};

// How long browsers may cache CORS preflight responses.
const ONE_HOUR: Duration = Duration::from_secs(60 * 60);
const CORS_CACHE_TIME_TO_LIVE: Duration = ONE_HOUR;

// `Cache-Control` header values for long-lived static responses.
const MAX_AGE_ONE_DAY: HeaderValue = HeaderValue::from_static("public, max-age=86400");
const MAX_AGE_ONE_YEAR: HeaderValue = HeaderValue::from_static("public, max-age=31536000");

// Soft deadline applied to Docker-backed coordinator operations.
const DOCKER_PROCESS_TIMEOUT_SOFT: Duration = Duration::from_secs(10);

mod cache;
mod websocket;

// Shared, cheaply-clonable handle to the coordinator factory, injected
// into handlers via an axum `Extension` layer.
#[derive(Debug, Clone)]
struct Factory(Arc<CoordinatorFactory>);

#[tokio::main]
pub(crate) async fn serve(config: Config) {
let factory = Factory(Arc::new(config.coordinator_factory()));
let factory = Arc::new(config.coordinator_factory());

let (cache_crates_task, cache_crates_tx) =
CacheTx::spawn(|rx| cache_crates_task(factory.clone(), rx));
let (cache_versions_task, cache_versions_tx) =
CacheTx::spawn(|rx| cache_versions_task(factory.clone(), rx));

let factory = Factory(factory);

let request_db = config.request_database();
let (db_task, db_handle) = request_db.spawn();
Expand Down Expand Up @@ -101,7 +111,8 @@ pub(crate) async fn serve(config: Config) {
)
.layer(Extension(factory))
.layer(Extension(db_handle))
.layer(Extension(Arc::new(SandboxCache::default())))
.layer(Extension(cache_crates_tx))
.layer(Extension(cache_versions_tx))
.layer(Extension(config.github_token()))
.layer(Extension(config.feature_flags))
.layer(Extension(config.websocket_config));
Expand Down Expand Up @@ -167,6 +178,8 @@ pub(crate) async fn serve(config: Config) {
select! {
v = server => v.unwrap(),
v = db_task => v.unwrap(),
v = cache_crates_task => v.unwrap(),
v = cache_versions_task => v.unwrap(),
}
}

Expand Down Expand Up @@ -477,24 +490,22 @@ where
}

/// `GET` handler returning the cached list of available crates.
///
/// The value is computed out-of-band by the crates cache task; this
/// handler only requests the latest value over the task's channel, so
/// it never blocks on the underlying Docker work.
async fn meta_crates(
    Extension(tx): Extension<CacheCratesTx>,
    if_none_match: Option<TypedHeader<IfNoneMatch>>,
) -> Result<impl IntoResponse> {
    // NOTE: the scrape interleaved the pre-commit handler (factory +
    // SandboxCache extensions) with this post-commit one; only the
    // cache-task version is kept.
    let value = track_metric_no_request_async(Endpoint::MetaCrates, || tx.get())
        .await
        .context(CratesSnafu)?;

    apply_timestamped_caching(value, if_none_match)
}

/// `GET` handler returning the cached list of tool/channel versions.
///
/// The value is computed out-of-band by the versions cache task; this
/// handler only requests the latest value over the task's channel, so
/// it never blocks on the underlying Docker work.
async fn meta_versions(
    Extension(tx): Extension<CacheVersionsTx>,
    if_none_match: Option<TypedHeader<IfNoneMatch>>,
) -> Result<impl IntoResponse> {
    // NOTE: the scrape interleaved the pre-commit handler (factory +
    // SandboxCache extensions) with this post-commit one; only the
    // cache-task version is kept.
    let value = track_metric_no_request_async(Endpoint::MetaVersions, || tx.get())
        .await
        .context(VersionsSnafu)?;

    apply_timestamped_caching(value, if_none_match)
}

Expand Down Expand Up @@ -640,6 +651,67 @@ impl MetricsAuthorization {
const FAILURE: MetricsAuthorizationRejection = (StatusCode::UNAUTHORIZED, "Wrong credentials");
}

// Handle request handlers use to ask the crates cache task for a value.
type CacheCratesTx = CacheTx<api::MetaCratesResponse, CacheCratesError>;
// Message type flowing over the crates cache task's channel.
type CacheCratesItem = CacheTaskItem<api::MetaCratesResponse, CacheCratesError>;

#[tracing::instrument(skip_all)]
/// Long-lived task that serves cached crate lists over `rx`.
///
/// Each refresh builds a short-lived coordinator, queries it for the
/// crate list, and shuts it down before handing the value back.
async fn cache_crates_task(factory: Arc<CoordinatorFactory>, rx: mpsc::Receiver<CacheCratesItem>) {
    cache_task(rx, move || {
        let coordinator = factory.build::<DockerBackend>();

        async move {
            let fetched = coordinator.crates().await?;
            let response = From::from(fetched);

            // Tear the coordinator down even on the happy path; its
            // container should not outlive this refresh.
            coordinator.shutdown().await?;

            Ok::<_, CacheCratesError>(response)
        }
        .boxed()
    })
    .await
}

/// Failures that can occur while refreshing the cached crates list.
#[derive(Debug, Snafu)]
enum CacheCratesError {
    /// Querying the coordinator for the crate list failed.
    #[snafu(transparent)]
    Crates { source: coordinator::CratesError },

    /// Shutting down the coordinator after the query failed.
    #[snafu(transparent)]
    Shutdown { source: coordinator::Error },
}

// Handle request handlers use to ask the versions cache task for a value.
type CacheVersionsTx = CacheTx<api::MetaVersionsResponse, CacheVersionsError>;
// Message type flowing over the versions cache task's channel.
type CacheVersionsItem = CacheTaskItem<api::MetaVersionsResponse, CacheVersionsError>;

#[tracing::instrument(skip_all)]
/// Long-lived task that serves cached version information over `rx`.
///
/// Each refresh builds a short-lived coordinator, queries it for the
/// version data, and shuts it down before handing the value back.
async fn cache_versions_task(
    factory: Arc<CoordinatorFactory>,
    rx: mpsc::Receiver<CacheVersionsItem>,
) {
    cache_task(rx, move || {
        let coordinator = factory.build::<DockerBackend>();

        async move {
            let fetched = coordinator.versions().await?;
            let response = From::from(fetched);

            // Tear the coordinator down even on the happy path; its
            // container should not outlive this refresh.
            coordinator.shutdown().await?;

            Ok::<_, CacheVersionsError>(response)
        }
        .boxed()
    })
    .await
}

/// Failures that can occur while refreshing the cached version data.
#[derive(Debug, Snafu)]
enum CacheVersionsError {
    /// Querying the coordinator for version information failed.
    #[snafu(transparent)]
    Versions { source: coordinator::VersionsError },

    /// Shutting down the coordinator after the query failed.
    #[snafu(transparent)]
    Shutdown { source: coordinator::Error },
}

#[async_trait]
impl<S> extract::FromRequestParts<S> for MetricsAuthorization
where
Expand Down Expand Up @@ -667,145 +739,6 @@ where
}
}

/// A cached value paired with the wall-clock time it was created.
type Stamped<T> = (T, SystemTime);

/// Process-wide cache for the crates and versions metadata responses.
/// (Pre-commit design; replaced by the cache tasks in this commit.)
#[derive(Debug, Default)]
struct SandboxCache {
    // Each response kind gets its own independently-refreshed slot.
    crates: CacheOne<api::MetaCratesResponse>,
    versions: CacheOne<api::MetaVersionsResponse>,
}

impl SandboxCache {
    /// Returns the cached crates list, refreshing it through a freshly
    /// built coordinator when the slot is stale or empty.
    async fn crates(
        &self,
        factory: &CoordinatorFactory,
    ) -> Result<Stamped<api::MetaCratesResponse>> {
        let coordinator = factory.build::<DockerBackend>();

        let fetched = self
            .crates
            .fetch(|| async {
                let list = coordinator.crates().await.context(CratesSnafu)?;
                Ok(list.into())
            })
            .await;

        // Always shut the coordinator down, even if the fetch failed.
        coordinator
            .shutdown()
            .await
            .context(ShutdownCoordinatorSnafu)?;

        fetched
    }

    /// Returns the cached version data, refreshing it through a freshly
    /// built coordinator when the slot is stale or empty.
    async fn versions(
        &self,
        factory: &CoordinatorFactory,
    ) -> Result<Stamped<api::MetaVersionsResponse>> {
        let coordinator = factory.build::<DockerBackend>();

        let fetched = self
            .versions
            .fetch(|| async {
                let data = coordinator.versions().await.context(VersionsSnafu)?;
                Ok(data.into())
            })
            .await;

        // Always shut the coordinator down, even if the fetch failed.
        coordinator
            .shutdown()
            .await
            .context(ShutdownCoordinatorSnafu)?;

        fetched
    }
}

/// A single cached value behind an async mutex; `None` until first filled.
#[derive(Debug)]
struct CacheOne<T>(Mutex<Option<CacheInfo<T>>>);

impl<T> Default for CacheOne<T> {
fn default() -> Self {
Self(Default::default())
}
}

impl<T> CacheOne<T>
where
    T: Clone + PartialEq,
{
    /// Returns the cached value if it is still within the TTL, otherwise
    /// runs `generator` to recompute it.
    async fn fetch<F, FFut>(&self, generator: F) -> Result<Stamped<T>>
    where
        F: FnOnce() -> FFut,
        FFut: Future<Output = Result<T>>,
    {
        let slot = &mut *self.0.lock().await;

        match &*slot {
            // Fresh enough — hand back the stamped value without work.
            Some(info) if info.validation_time.elapsed() <= SANDBOX_CACHE_TIME_TO_LIVE => {
                Ok(info.stamped_value())
            }
            // Empty or stale — recompute.
            _ => Self::set_value(slot, generator).await,
        }
    }

    /// Recomputes the value and stores it, preserving the original
    /// creation time when the recomputed value is unchanged.
    async fn set_value<F, FFut>(data: &mut Option<CacheInfo<T>>, generator: F) -> Result<Stamped<T>>
    where
        F: FnOnce() -> FFut,
        FFut: Future<Output = Result<T>>,
    {
        // Compute first; on error the existing cache entry is untouched.
        let fresh = generator().await?;
        let replacement = CacheInfo::build(fresh);

        let info = match data.take() {
            // Same value as before: bump only the validation time so the
            // creation timestamp (and downstream HTTP caching) survives.
            Some(mut previous) if previous.value == replacement.value => {
                previous.validation_time = replacement.validation_time;
                previous
            }
            _ => replacement,
        };

        let stamped = info.stamped_value();
        *data = Some(info);

        Ok(stamped)
    }
}

/// A cached value plus the bookkeeping timestamps used by `CacheOne`.
#[derive(Debug)]
struct CacheInfo<T> {
    value: T,
    // Wall-clock time the value was first created; stamped onto responses.
    creation_time: SystemTime,
    // Monotonic time of the last successful validation; compared to the TTL.
    validation_time: Instant,
}

impl<T> CacheInfo<T> {
    /// Wraps a freshly computed value, stamping both clocks at "now".
    fn build(value: T) -> Self {
        Self {
            value,
            creation_time: SystemTime::now(),
            validation_time: Instant::now(),
        }
    }

    /// Clones the value together with its creation timestamp.
    fn stamped_value(&self) -> Stamped<T>
    where
        T: Clone,
    {
        (self.value.clone(), self.creation_time)
    }
}

impl IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
let error = snafu::CleanedErrorText::new(&self)
Expand Down Expand Up @@ -901,12 +834,12 @@ enum Error {

#[snafu(display("Unable to find the available crates"))]
Crates {
source: orchestrator::coordinator::CratesError,
source: CacheTxError<CacheCratesError>,
},

#[snafu(display("Unable to find the available versions"))]
Versions {
source: orchestrator::coordinator::VersionsError,
source: CacheTxError<CacheVersionsError>,
},

#[snafu(display("Unable to shutdown the coordinator"))]
Expand Down
Loading

0 comments on commit f116c28

Please sign in to comment.