Skip to content

Commit

Permalink
feat: Add an error code on the metering middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
jubianchi committed Dec 17, 2020
1 parent 44fb48e commit 84d2647
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 36 deletions.
48 changes: 29 additions & 19 deletions examples/metering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ use wasmer::CompilerConfig;
use wasmer::{imports, wat2wasm, Instance, Module, Store};
use wasmer_compiler_cranelift::Cranelift;
use wasmer_engine_jit::JIT;
use wasmer_middlewares::metering::{get_remaining_points, set_remaining_points, Metering};
use wasmer_middlewares::{
metering::{get_error, get_remaining_points, set_remaining_points},
Metering, MeteringError,
};

fn main() -> anyhow::Result<()> {
// Let's declare the Wasm module.
Expand Down Expand Up @@ -55,11 +58,10 @@ fn main() -> anyhow::Result<()> {

// Now let's create our metering middleware.
//
// `Metering` needs to be configured with a limit (the gas limit) and
// a cost function.
// `Metering` needs to be configured with a limit and a cost function.
//
// For each `Operator`, the metering middleware will call the cost
// function and subtract the cost from the gas.
// function and subtract the cost from the remaining points.
let metering = Arc::new(Metering::new(10, cost_function));
let mut compiler_config = Cranelift::default();
compiler_config.push_middleware(metering);
Expand Down Expand Up @@ -93,7 +95,7 @@ fn main() -> anyhow::Result<()> {
println!("Calling `add_one` function once...");
add_one.call(1)?;

// As you can see here, after the first call we have 6 remaining gas points.
// As you can see here, after the first call we have 6 remaining points.
//
// This is correct, here are the details of how it has been computed:
// * `local.get $value` is a `Operator::LocalGet` which costs 1 point;
Expand All @@ -110,7 +112,7 @@ fn main() -> anyhow::Result<()> {
println!("Calling `add_one` function twice...");
add_one.call(1)?;

// We spent 4 more gas points with the second call.
// We spent 4 more points with the second call.
// We have 2 remaining points.
let remaining_points_after_second_call = get_remaining_points(&instance);
assert_eq!(remaining_points_after_second_call, 2);
Expand All @@ -120,8 +122,8 @@ fn main() -> anyhow::Result<()> {
remaining_points_after_second_call
);

// Because calling our `add_one` function consumes 4 gas points,
// calling it a third time will fail: we already consume 8 gas
// Because calling our `add_one` function consumes 4 points,
// calling it a third time will fail: we already consume 8
// points, there are only two remaining.
println!("Calling `add_one` function a third time...");
match add_one.call(1) {
Expand All @@ -132,19 +134,27 @@ fn main() -> anyhow::Result<()> {
);
}
Err(_) => {
println!("Calling `add_one` failed: not enough gas points remaining.");
}
}
println!("Calling `add_one` failed.");

// Becasue the previous call failed, it did not consume any gas point.
// We still have 2 remaining points.
let remaining_points_after_third_call = get_remaining_points(&instance);
assert_eq!(remaining_points_after_third_call, 2);
// Because the last needed more than the remaining points, we got an error.
let error = get_error(&instance);

println!(
"Remaining points after third call: {:?}",
remaining_points_after_third_call
);
match error {
MeteringError::None => bail!("No metering error."),
MeteringError::OutOfPoints => println!("Not enough points remaining."),
_ => bail!("Unknown metering error."),
}

// There is now 0 remaining points (we can't go below 0).
let remaining_points_after_third_call = get_remaining_points(&instance);
assert_eq!(remaining_points_after_third_call, 0);

println!(
"Remaining points after third call: {:?}",
remaining_points_after_third_call
);
}
}

// Now let's see how we can set a new limit...
println!("Set new remaining points points to 10");
Expand Down
1 change: 1 addition & 0 deletions lib/middlewares/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pub mod metering;
// The most commonly used symbol are exported at top level of the module. Others are available
// via modules, e.g. `wasmer_middlewares::metering::get_remaining_points`
pub use metering::Metering;
pub use metering::{get_error, get_remaining_points, set_remaining_points, MeteringError};
124 changes: 107 additions & 17 deletions lib/middlewares/src/metering.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! `metering` is a middleware for tracking how many operators are executed in total
//! and putting a limit on the total number of operators executed.
use std::convert::TryInto;
use std::convert::{TryFrom, TryInto};
use std::fmt;
use std::sync::Mutex;
use wasmer::wasmparser::{
Expand All @@ -11,7 +11,7 @@ use wasmer::{
ExportIndex, FunctionMiddleware, GlobalInit, GlobalType, Instance, LocalFunctionIndex,
MiddlewareReaderState, ModuleMiddleware, Mutability, Type,
};
use wasmer_types::GlobalIndex;
use wasmer_types::{GlobalIndex, Value};
use wasmer_vm::ModuleInfo;

/// The module-level metering middleware.
Expand All @@ -28,6 +28,9 @@ pub struct Metering<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> {
/// Function that maps each operator to a cost in "points".
cost_function: F,

/// The global index in the current module for the error code.
error_code_index: Mutex<Option<GlobalIndex>>,

/// The global index in the current module for remaining points.
remaining_points_index: Mutex<Option<GlobalIndex>>,
}
Expand All @@ -37,19 +40,53 @@ pub struct FunctionMetering<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync
/// Function that maps each operator to a cost in "points".
cost_function: F,

/// The global index in the current module for the error code.
error_code_index: GlobalIndex,

/// The global index in the current module for remaining points.
remaining_points_index: GlobalIndex,

/// Accumulated cost of the current basic block.
accumulated_cost: u64,
}

#[derive(Debug, PartialEq)]
#[non_exhaustive]
pub enum MeteringError {
None,
OutOfPoints,
}

impl TryFrom<i32> for MeteringError {
type Error = ();

fn try_from(value: i32) -> Result<Self, Self::Error> {
match value {
value if value == MeteringError::None as _ => Ok(MeteringError::None),
value if value == MeteringError::OutOfPoints as _ => Ok(MeteringError::OutOfPoints),
_ => Err(()),
}
}
}

impl<T> TryFrom<Value<T>> for MeteringError {
type Error = ();

fn try_from(v: Value<T>) -> Result<Self, Self::Error> {
match v {
Value::I32(value) => value.try_into(),
_ => Err(()),
}
}
}

impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> Metering<F> {
/// Creates a `Metering` middleware.
pub fn new(initial_limit: u64, cost_function: F) -> Self {
Self {
initial_limit,
cost_function,
error_code_index: Mutex::new(None),
remaining_points_index: Mutex::new(None),
}
}
Expand All @@ -60,6 +97,7 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> fmt::Debug for Meteri
f.debug_struct("Metering")
.field("initial_limit", &self.initial_limit)
.field("cost_function", &"<function>")
.field("error_code_index", &self.error_code_index)
.field("remaining_points_index", &self.remaining_points_index)
.finish()
}
Expand All @@ -72,6 +110,11 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync + 'static> ModuleMiddl
fn generate_function_middleware(&self, _: LocalFunctionIndex) -> Box<dyn FunctionMiddleware> {
Box::new(FunctionMetering {
cost_function: self.cost_function,
error_code_index: self
.error_code_index
.lock()
.unwrap()
.expect("Metering::generate_function_middleware: Error code index not set up."),
remaining_points_index: self.remaining_points_index.lock().unwrap().expect(
"Metering::generate_function_middleware: Remaining points index not set up.",
),
Expand All @@ -81,23 +124,37 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync + 'static> ModuleMiddl

/// Transforms a `ModuleInfo` struct in-place. This is called before application on functions begins.
fn transform_module_info(&self, module_info: &mut ModuleInfo) {
let mut error_code_index = self.error_code_index.lock().unwrap();
let mut remaining_points_index = self.remaining_points_index.lock().unwrap();
if remaining_points_index.is_some() {

if error_code_index.is_some() || remaining_points_index.is_some() {
panic!("Metering::transform_module_info: Attempting to use a `Metering` middleware from multiple modules.");
}

// Append a global for remaining points and initialize it.
let global_index = module_info
let remaining_points_global_index = module_info
.globals
.push(GlobalType::new(Type::I64, Mutability::Var));
*remaining_points_index = Some(global_index.clone());
*remaining_points_index = Some(remaining_points_global_index.clone());

// Append a global for the error code and initialize it.
let error_code_global_index = module_info
.globals
.push(GlobalType::new(Type::I32, Mutability::Var));
*error_code_index = Some(error_code_global_index.clone());

module_info
.global_initializers
.push(GlobalInit::I64Const(self.initial_limit as i64));

module_info.exports.insert(
"remaining_points".to_string(),
ExportIndex::Global(global_index),
"wasmer_metering_remaining_points".to_string(),
ExportIndex::Global(remaining_points_global_index),
);

module_info.exports.insert(
"wasmer_metering_error_code".to_string(),
ExportIndex::Global(error_code_global_index),
);
}
}
Expand All @@ -106,6 +163,7 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> fmt::Debug for Functi
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FunctionMetering")
.field("cost_function", &"<function>")
.field("error_code_index", &self.error_code_index)
.field("remaining_points_index", &self.remaining_points_index)
.finish()
}
Expand Down Expand Up @@ -143,7 +201,11 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> FunctionMiddleware
Operator::I64Const { value: self.accumulated_cost as i64 },
Operator::I64LtU,
Operator::If { ty: WpTypeOrFuncType::Type(WpType::EmptyBlockType) },
Operator::Unreachable, // FIXME: Signal the error properly.
Operator::I32Const { value: MeteringError::OutOfPoints as i32 },
Operator::GlobalSet { global_index: self.error_code_index.as_u32() },
Operator::I64Const { value: 0 },
Operator::GlobalSet { global_index: self.remaining_points_index.as_u32() },
Operator::Unreachable,
Operator::End,

// globals[remaining_points_index] -= self.accumulated_cost;
Expand All @@ -164,6 +226,28 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> FunctionMiddleware
}
}

/// Get the error code in an `Instance`.
///
/// When instance execution traps, this error code will help know if it was caused by
/// remaining points being exhausted.
///
/// This can be used in a headless engine after an ahead-of-time compilation
/// as all required state lives in the instance.
///
/// # Panic
///
/// The instance Module must have been processed with the [`Metering`] middleware
/// at compile time, otherwise this will panic.
pub fn get_error(instance: &Instance) -> MeteringError {
instance
.exports
.get_global("wasmer_metering_error_code")
.expect("Can't get `wasmer_metering_error_code` from Instance")
.get()
.try_into()
.expect("`wasmer_metering_error_code` from Instance has wrong type")
}

/// Get the remaining points in an `Instance`.
///
/// This can be used in a headless engine after an ahead-of-time compilation
Expand All @@ -176,11 +260,11 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> FunctionMiddleware
pub fn get_remaining_points(instance: &Instance) -> u64 {
instance
.exports
.get_global("remaining_points")
.expect("Can't get `remaining_points` from Instance")
.get_global("wasmer_metering_remaining_points")
.expect("Can't get `wasmer_metering_remaining_points` from Instance")
.get()
.try_into()
.expect("`remaining_points` from Instance has wrong type")
.expect("`wasmer_metering_remaining_points` from Instance has wrong type")
}

/// Set the provided remaining points in an `Instance`.
Expand All @@ -195,10 +279,10 @@ pub fn get_remaining_points(instance: &Instance) -> u64 {
pub fn set_remaining_points(instance: &Instance, points: u64) {
instance
.exports
.get_global("remaining_points")
.expect("Can't get `remaining_points` from Instance")
.get_global("wasmer_metering_remaining_points")
.expect("Can't get `wasmer_metering_remaining_points` from Instance")
.set(points.into())
.expect("Can't set `remaining_points` in Instance");
.expect("Can't set `wasmer_metering_remaining_points` in Instance");
}

#[cfg(test)]
Expand Down Expand Up @@ -258,16 +342,17 @@ mod tests {
.unwrap();
add_one.call(1).unwrap();
assert_eq!(get_remaining_points(&instance), 6);
assert_eq!(get_error(&instance), MeteringError::None);

// Second call
add_one.call(1).unwrap();
assert_eq!(get_remaining_points(&instance), 2);
assert_eq!(get_error(&instance), MeteringError::None);

// Third call fails due to limit
assert!(add_one.call(1).is_err());
// TODO: what do we expect now? 0 or 2? See https://github.com/wasmerio/wasmer/issues/1931
// assert_eq!(metering.get_remaining_points(&instance), 2);
// assert_eq!(metering.get_remaining_points(&instance), 0);
assert_eq!(get_remaining_points(&instance), 0);
assert_eq!(get_error(&instance), MeteringError::OutOfPoints);
}

#[test]
Expand All @@ -294,9 +379,14 @@ mod tests {
// Ensure we can use the new points now
add_one.call(1).unwrap();
assert_eq!(get_remaining_points(&instance), 8);
assert_eq!(get_error(&instance), MeteringError::None);

add_one.call(1).unwrap();
assert_eq!(get_remaining_points(&instance), 4);
assert_eq!(get_error(&instance), MeteringError::None);

add_one.call(1).unwrap();
assert_eq!(get_remaining_points(&instance), 0);
assert_eq!(get_error(&instance), MeteringError::None);
}
}

0 comments on commit 84d2647

Please sign in to comment.