Skip to content

Commit

Permalink
feat: Add an error code on the metering middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
jubianchi committed Dec 17, 2020
1 parent 44fb48e commit 15a34bf
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 18 deletions.
35 changes: 23 additions & 12 deletions examples/metering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@ use wasmer::CompilerConfig;
use wasmer::{imports, wat2wasm, Instance, Module, Store};
use wasmer_compiler_cranelift::Cranelift;
use wasmer_engine_jit::JIT;
use wasmer_middlewares::metering::{get_remaining_points, set_remaining_points, Metering};
use wasmer_middlewares::{
metering::{get_remaining_points, set_remaining_points, get_error},
Metering,
MeteringError,
};

fn main() -> anyhow::Result<()> {
// Let's declare the Wasm module.
Expand Down Expand Up @@ -132,19 +136,26 @@ fn main() -> anyhow::Result<()> {
);
}
Err(_) => {
println!("Calling `add_one` failed: not enough gas points remaining.");
}
}
println!("Calling `add_one` failed.");

// Becasue the previous call failed, it did not consume any gas point.
// We still have 2 remaining points.
let remaining_points_after_third_call = get_remaining_points(&instance);
assert_eq!(remaining_points_after_third_call, 2);
let error = get_error(&instance);

println!(
"Remaining points after third call: {:?}",
remaining_points_after_third_call
);
match error {
MeteringError::NoError => bail!("No error."),
MeteringError::OutOfGas => println!("Not enough gas points remaining."),
}

// Becasue the previous consumed all the remaining gas points, we got an error.
// There is now 0 remaining points (we can't go below 0).
let remaining_points_after_third_call = get_remaining_points(&instance);
assert_eq!(remaining_points_after_third_call, 0);

println!(
"Remaining points after third call: {:?}",
remaining_points_after_third_call
);
}
}

// Now let's see how we can set a new limit...
println!("Set new remaining points points to 10");
Expand Down
1 change: 1 addition & 0 deletions lib/middlewares/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pub mod metering;
// The most commonly used symbol are exported at top level of the module. Others are available
// via modules, e.g. `wasmer_middlewares::metering::get_remaining_points`
pub use metering::Metering;
pub use metering::{get_remaining_points, set_remaining_points, get_error, MeteringError};
79 changes: 73 additions & 6 deletions lib/middlewares/src/metering.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! `metering` is a middleware for tracking how many operators are executed in total
//! and putting a limit on the total number of operators executed.
use core::convert::TryFrom;
use std::convert::TryInto;
use std::fmt;
use std::sync::Mutex;
Expand All @@ -11,7 +12,7 @@ use wasmer::{
ExportIndex, FunctionMiddleware, GlobalInit, GlobalType, Instance, LocalFunctionIndex,
MiddlewareReaderState, ModuleMiddleware, Mutability, Type,
};
use wasmer_types::GlobalIndex;
use wasmer_types::{GlobalIndex, Value};
use wasmer_vm::ModuleInfo;

/// The module-level metering middleware.
Expand All @@ -30,6 +31,8 @@ pub struct Metering<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> {

/// The global index in the current module for remaining points.
remaining_points_index: Mutex<Option<GlobalIndex>>,

error_code_index: Mutex<Option<GlobalIndex>>,
}

/// The function-level metering middleware.
Expand All @@ -40,17 +43,49 @@ pub struct FunctionMetering<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync
/// The global index in the current module for remaining points.
remaining_points_index: GlobalIndex,

error_code_index: GlobalIndex,

/// Accumulated cost of the current basic block.
accumulated_cost: u64,
}

pub enum MeteringError {
NoError,
OutOfGas,
}

impl <T>TryFrom<Value<T>> for MeteringError {
type Error = ();

fn try_from(v: Value<T>) -> Result<Self, Self::Error> {
match v {
Value::I32(value) => {
match value {
value if value == MeteringError::NoError as _ => Ok(MeteringError::NoError),
value if value == MeteringError::OutOfGas as _ => Ok(MeteringError::OutOfGas),
_ => Err(()),
}
},
Value::I64(value ) => {
match value {
value if value == MeteringError::NoError as i64 => Ok(MeteringError::NoError),
value if value == MeteringError::OutOfGas as i64 => Ok(MeteringError::OutOfGas),
_ => Err(()),
}
}
_ => Err(()),
}
}
}

impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> Metering<F> {
/// Creates a `Metering` middleware.
pub fn new(initial_limit: u64, cost_function: F) -> Self {
Self {
initial_limit,
cost_function,
remaining_points_index: Mutex::new(None),
error_code_index: Mutex::new(None),
}
}
}
Expand All @@ -61,6 +96,7 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> fmt::Debug for Meteri
.field("initial_limit", &self.initial_limit)
.field("cost_function", &"<function>")
.field("remaining_points_index", &self.remaining_points_index)
.field("error_code_index", &self.error_code_index)
.finish()
}
}
Expand All @@ -75,29 +111,45 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync + 'static> ModuleMiddl
remaining_points_index: self.remaining_points_index.lock().unwrap().expect(
"Metering::generate_function_middleware: Remaining points index not set up.",
),
error_code_index: self.error_code_index.lock().unwrap().expect(
"Metering::generate_function_middleware: Error code index not set up.",
),
accumulated_cost: 0,
})
}

/// Transforms a `ModuleInfo` struct in-place. This is called before application on functions begins.
fn transform_module_info(&self, module_info: &mut ModuleInfo) {
let mut remaining_points_index = self.remaining_points_index.lock().unwrap();
if remaining_points_index.is_some() {
let mut error_code_index = self.error_code_index.lock().unwrap();

if remaining_points_index.is_some() || error_code_index.is_some() {
panic!("Metering::transform_module_info: Attempting to use a `Metering` middleware from multiple modules.");
}

// Append a global for remaining points and initialize it.
let global_index = module_info
let remaining_points_global_index = module_info
.globals
.push(GlobalType::new(Type::I64, Mutability::Var));
*remaining_points_index = Some(global_index.clone());
*remaining_points_index = Some(remaining_points_global_index.clone());

let error_code_global_index = module_info
.globals
.push(GlobalType::new(Type::I32, Mutability::Var));
*error_code_index = Some(error_code_global_index.clone());

module_info
.global_initializers
.push(GlobalInit::I64Const(self.initial_limit as i64));

module_info.exports.insert(
"remaining_points".to_string(),
ExportIndex::Global(global_index),
ExportIndex::Global(remaining_points_global_index),
);

module_info.exports.insert(
"error_code".to_string(),
ExportIndex::Global(error_code_global_index),
);
}
}
Expand All @@ -107,6 +159,7 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> fmt::Debug for Functi
f.debug_struct("FunctionMetering")
.field("cost_function", &"<function>")
.field("remaining_points_index", &self.remaining_points_index)
.field("error_code_index", &self.error_code_index)
.finish()
}
}
Expand Down Expand Up @@ -143,7 +196,11 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> FunctionMiddleware
Operator::I64Const { value: self.accumulated_cost as i64 },
Operator::I64LtU,
Operator::If { ty: WpTypeOrFuncType::Type(WpType::EmptyBlockType) },
Operator::Unreachable, // FIXME: Signal the error properly.
Operator::I32Const { value: MeteringError::OutOfGas as i32 },
Operator::GlobalSet { global_index: self.error_code_index.as_u32() },
Operator::I32Const { value: 0 },
Operator::GlobalSet { global_index: self.remaining_points_index.as_u32() },
Operator::Unreachable,
Operator::End,

// globals[remaining_points_index] -= self.accumulated_cost;
Expand Down Expand Up @@ -183,6 +240,16 @@ pub fn get_remaining_points(instance: &Instance) -> u64 {
.expect("`remaining_points` from Instance has wrong type")
}

pub fn get_error(instance: &Instance) -> MeteringError {
instance
.exports
.get_global("error_code")
.expect("Can't get `error_code` from Instance")
.get()
.try_into()
.expect("`error_code` from Instance has wrong type")
}

/// Set the provided remaining points in an `Instance`.
///
/// This can be used in a headless engine after an ahead-of-time compilation
Expand Down

0 comments on commit 15a34bf

Please sign in to comment.