From 7ddf96cd997e77a54516a5cda0c2ffd83c2b7343 Mon Sep 17 00:00:00 2001 From: Christoph Herzog Date: Mon, 29 Jul 2024 21:50:54 +0200 Subject: [PATCH] feat: Implement image validation in BuiltinPackageLoader --- .../runtime/package_loader/builtin_loader.rs | 224 +++++++++++++++++- 1 file changed, 223 insertions(+), 1 deletion(-) diff --git a/lib/wasix/src/runtime/package_loader/builtin_loader.rs b/lib/wasix/src/runtime/package_loader/builtin_loader.rs index 521b10a89a3..18b0f8c49b1 100644 --- a/lib/wasix/src/runtime/package_loader/builtin_loader.rs +++ b/lib/wasix/src/runtime/package_loader/builtin_loader.rs @@ -34,6 +34,20 @@ pub struct BuiltinPackageLoader { cache: Option, /// A mapping from hostnames to tokens tokens: HashMap, + + hash_validation: HashIntegrityValidationMode, +} + +/// Defines how to validate package hash integrity. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum HashIntegrityValidationMode { + /// Do not validate anything. + /// Best for performance. + NoValidate, + /// Compute the image hash and produce a trace warning on hash mismatches. + WarnOnHashMismatch, + /// Compute the image hash and fail on a mismatch. + FailOnHashMismatch, } impl BuiltinPackageLoader { @@ -42,10 +56,19 @@ impl BuiltinPackageLoader { in_memory: InMemoryCache::default(), client: Arc::new(crate::http::default_http_client().unwrap()), cache: None, + hash_validation: HashIntegrityValidationMode::NoValidate, tokens: HashMap::new(), } } + /// Set the validation mode to apply after downloading an image. + /// + /// See [`HashIntegrityValidationMode`] for details. + pub fn with_hash_validation_mode(mut self, mode: HashIntegrityValidationMode) -> Self { + self.hash_validation = mode; + self + } + pub fn with_cache_dir(self, cache_dir: impl Into) -> Self { BuiltinPackageLoader { cache: Some(FileSystemCache { @@ -55,6 +78,44 @@ impl BuiltinPackageLoader { } } + pub fn validate_cache( + &self, + mode: CacheValidationMode, + ) -> Result, anyhow::Error> { + let cache = self + .cache + .as_ref() + .context("can not validate cache - no cache configured")?; + + let items = cache.validate_hashes()?; + let mut errors = Vec::new(); + for (path, error) in items { + match mode { + CacheValidationMode::WarnOnMismatch => { + tracing::warn!(?error, "hash mismatch in cached image file"); + } + CacheValidationMode::PruneOnMismatch => { + tracing::warn!(?error, "deleting cached image file due to hash mismatch"); + match std::fs::remove_file(&path) { + Ok(()) => {} + Err(error) if error.kind() == std::io::ErrorKind::NotFound => {} + Err(fs_err) => { + tracing::error!( + path=%error.source, + ?fs_err, + "could not delete cached image file with hash mismatch" + ); + } + } + } + } + + errors.push(error); + } + + Ok(errors) + } + pub fn with_http_client(self, client: impl HttpClient + Send + Sync + 'static) -> Self { self.with_shared_http_client(Arc::new(client)) } @@ -110,6 +171,40 @@ impl BuiltinPackageLoader { Ok(None) } + /// Validate image contents with the specified validation mode. + fn validate_hash( + image: &[u8], + mode: HashIntegrityValidationMode, + info: &DistributionInfo, + ) -> Result<(), anyhow::Error> { + match mode { + HashIntegrityValidationMode::NoValidate => { + // Nothing to do. + Ok(()) + } + HashIntegrityValidationMode::WarnOnHashMismatch => { + let actual_hash = WebcHash::sha256(image); + if actual_hash != info.webc_sha256 { + tracing::warn!(%info.webc_sha256, %actual_hash, "image hash mismatch - actual image hash does not match the expected hash!"); + } + Ok(()) + } + HashIntegrityValidationMode::FailOnHashMismatch => { + let actual_hash = WebcHash::sha256(image); + if actual_hash != info.webc_sha256 { + Err(ImageHashMismatchError { + source: info.webc.to_string(), + actual_hash, + expected_hash: info.webc_sha256, + } + .into()) + } else { + Ok(()) + } + } + } + } + #[tracing::instrument(level = "debug", skip_all, fields(%dist.webc, %dist.webc_sha256))] async fn download(&self, dist: &DistributionInfo) -> Result { if dist.webc.scheme() == "file" { @@ -121,6 +216,9 @@ impl BuiltinPackageLoader { }) .await? .with_context(|| format!("Unable to read \"{}\"", path.display()))?; + + Self::validate_hash(&bytes, self.hash_validation, dist)?; + return Ok(bytes.into()); } Err(e) => { @@ -167,6 +265,8 @@ impl BuiltinPackageLoader { let body = response.body.context("package download failed")?; tracing::debug!(%url, "package_download_succeeded"); + Self::validate_hash(&body, self.hash_validation, dist)?; + Ok(body.into()) } @@ -267,6 +367,35 @@ impl PackageLoader for BuiltinPackageLoader { } } +#[derive(Clone, Debug)] +pub struct ImageHashMismatchError { + source: String, + expected_hash: WebcHash, + actual_hash: WebcHash, +} + +impl std::fmt::Display for ImageHashMismatchError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "image hash mismatch! expected hash '{}', but the computed hash is '{}' (source '{}')", + self.expected_hash, self.actual_hash, self.source, + ) + } +} + +impl std::error::Error for ImageHashMismatchError {} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum CacheValidationMode { + /// Just emit a warning for all images where the filename doesn't match + /// the expected hash. + WarnOnMismatch, + /// Remove images from the cache if the filename doesn't match the actual + /// hash. + PruneOnMismatch, +} + // FIXME: This implementation will block the async runtime and should use // some sort of spawn_blocking() call to run it in the background. #[derive(Debug)] @@ -275,6 +404,66 @@ struct FileSystemCache { } impl FileSystemCache { + const FILE_SUFFIX: &'static str = ".bin"; + + /// Validate that the cached image file names correspond to their actual + /// file content hashes. + fn validate_hashes(&self) -> Result, anyhow::Error> { + let mut items = Vec::<(PathBuf, ImageHashMismatchError)>::new(); + + let iter = match std::fs::read_dir(&self.cache_dir) { + Ok(v) => v, + Err(err) if err.kind() == std::io::ErrorKind::NotFound => { + // Cache dir does not exist, so nothing to validate. + return Ok(Vec::new()); + } + Err(err) => { + return Err(err).with_context(|| { + format!( + "Could not read image cache dir: '{}'", + self.cache_dir.display() + ) + }); + } + }; + + for res in iter { + let entry = res?; + if !entry.file_type()?.is_file() { + continue; + } + + // Extract the hash from the filename. + + let hash_opt = entry + .file_name() + .to_str() + .and_then(|x| { + let (raw_hash, _) = x.split_once(Self::FILE_SUFFIX)?; + Some(raw_hash) + }) + .and_then(|x| WebcHash::parse_hex(x).ok()); + let Some(expected_hash) = hash_opt else { + continue; + }; + + // Compute the actual hash. + let path = entry.path(); + let actual_hash = WebcHash::for_file(&path)?; + + if actual_hash != expected_hash { + let err = ImageHashMismatchError { + source: path.to_string_lossy().to_string(), + actual_hash, + expected_hash, + }; + items.push((path, err)); + } + } + + Ok(items) + } + async fn lookup(&self, hash: &WebcHash) -> Result, Error> { let path = self.path(hash); @@ -357,7 +546,7 @@ impl FileSystemCache { for b in hash { write!(filename, "{b:02x}").unwrap(); } - filename.push_str(".bin"); + filename.push_str(Self::FILE_SUFFIX); self.cache_dir.join(filename) } @@ -484,3 +673,36 @@ mod tests { cache_misses_will_trigger_a_download_internal().await } } + +#[cfg(test)] +mod test { + use super::*; + + // NOTE: must be a tokio test because the BuiltinPackageLoader::new() + // constructor requires a runtime... + #[tokio::test] + async fn test_builtin_package_downloader_cache_validation() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path(); + + let contents = "fail"; + let correct_hash = WebcHash::sha256(&contents); + let used_hash = + WebcHash::parse_hex("0000a28ea38a000f3a3328cb7fabe330638d3258affe1a869e3f92986222d997") + .unwrap(); + let filename = format!("{}{}", used_hash, FileSystemCache::FILE_SUFFIX); + let file_path = path.join(filename); + std::fs::write(&file_path, contents).unwrap(); + + let dl = BuiltinPackageLoader::new().with_cache_dir(path); + + let errors = dl + .validate_cache(CacheValidationMode::PruneOnMismatch) + .unwrap(); + assert_eq!(errors.len(), 1); + assert_eq!(errors[0].actual_hash, correct_hash); + assert_eq!(errors[0].expected_hash, used_hash); + + assert_eq!(file_path.exists(), false); + } +}