Skip to content

Commit

Permalink
feat: Add an error code on the metering middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
jubianchi committed Dec 21, 2020
1 parent 6f06fc2 commit 780be36
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 46 deletions.
56 changes: 34 additions & 22 deletions examples/metering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ use wasmer::CompilerConfig;
use wasmer::{imports, wat2wasm, Instance, Module, Store};
use wasmer_compiler_cranelift::Cranelift;
use wasmer_engine_jit::JIT;
use wasmer_middlewares::metering::{get_remaining_points, set_remaining_points, Metering};
use wasmer_middlewares::{
metering::{get_remaining_points, set_remaining_points},
Metering, MeteringPoints,
};

fn main() -> anyhow::Result<()> {
// Let's declare the Wasm module.
Expand Down Expand Up @@ -55,11 +58,10 @@ fn main() -> anyhow::Result<()> {

// Now let's create our metering middleware.
//
// `Metering` needs to be configured with a limit (the gas limit) and
// a cost function.
// `Metering` needs to be configured with a limit and a cost function.
//
// For each `Operator`, the metering middleware will call the cost
// function and subtract the cost from the gas.
// function and subtract the cost from the remaining points.
let metering = Arc::new(Metering::new(10, cost_function));
let mut compiler_config = Cranelift::default();
compiler_config.push_middleware(metering);
Expand Down Expand Up @@ -93,14 +95,17 @@ fn main() -> anyhow::Result<()> {
println!("Calling `add_one` function once...");
add_one.call(1)?;

// As you can see here, after the first call we have 6 remaining gas points.
// As you can see here, after the first call we have 6 remaining points.
//
// This is correct, here are the details of how it has been computed:
// * `local.get $value` is a `Operator::LocalGet` which costs 1 point;
// * `i32.const` is a `Operator::I32Const` which costs 1 point;
// * `i32.add` is a `Operator::I32Add` which costs 2 points.
let remaining_points_after_first_call = get_remaining_points(&instance);
assert_eq!(remaining_points_after_first_call, 6);
assert_eq!(
remaining_points_after_first_call,
MeteringPoints::Remaining(6)
);

println!(
"Remaining points after the first call: {:?}",
Expand All @@ -110,18 +115,21 @@ fn main() -> anyhow::Result<()> {
println!("Calling `add_one` function twice...");
add_one.call(1)?;

// We spent 4 more gas points with the second call.
// We spent 4 more points with the second call.
// We have 2 remaining points.
let remaining_points_after_second_call = get_remaining_points(&instance);
assert_eq!(remaining_points_after_second_call, 2);
assert_eq!(
remaining_points_after_second_call,
MeteringPoints::Remaining(2)
);

println!(
"Remaining points after the second call: {:?}",
remaining_points_after_second_call
);

// Because calling our `add_one` function consumes 4 gas points,
// calling it a third time will fail: we already consume 8 gas
// Because calling our `add_one` function consumes 4 points,
// calling it a third time will fail: we already consume 8
// points, there are only two remaining.
println!("Calling `add_one` function a third time...");
match add_one.call(1) {
Expand All @@ -132,27 +140,31 @@ fn main() -> anyhow::Result<()> {
);
}
Err(_) => {
println!("Calling `add_one` failed: not enough gas points remaining.");
println!("Calling `add_one` failed.");

// Because the last needed more than the remaining points, we should have an error.
let remaining_points = get_remaining_points(&instance);

match remaining_points {
MeteringPoints::Remaining(..) => {
bail!("No metering error: there are remaining points")
}
MeteringPoints::Depleted(remaining, missing) => println!(
"Not enough points remaining: {:?} points remaining but need {:?}",
remaining,
remaining + missing
),
}
}
}

// Becasue the previous call failed, it did not consume any gas point.
// We still have 2 remaining points.
let remaining_points_after_third_call = get_remaining_points(&instance);
assert_eq!(remaining_points_after_third_call, 2);

println!(
"Remaining points after third call: {:?}",
remaining_points_after_third_call
);

// Now let's see how we can set a new limit...
println!("Set new remaining points points to 10");
let new_limit = 10;
set_remaining_points(&instance, new_limit);

let remaining_points = get_remaining_points(&instance);
assert_eq!(remaining_points, new_limit);
assert_eq!(remaining_points, MeteringPoints::Remaining(new_limit));

println!("Remaining points: {:?}", remaining_points);

Expand Down
1 change: 1 addition & 0 deletions lib/middlewares/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pub mod metering;
// The most commonly used symbol are exported at top level of the module. Others are available
// via modules, e.g. `wasmer_middlewares::metering::get_remaining_points`
pub use metering::Metering;
pub use metering::{get_remaining_points, set_remaining_points, MeteringPoints};
144 changes: 120 additions & 24 deletions lib/middlewares/src/metering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ pub struct Metering<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> {
/// Function that maps each operator to a cost in "points".
cost_function: F,

/// The global index in the current module for the error code.
missing_points_index: Mutex<Option<GlobalIndex>>,

/// The global index in the current module for remaining points.
remaining_points_index: Mutex<Option<GlobalIndex>>,
}
Expand All @@ -37,19 +40,29 @@ pub struct FunctionMetering<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync
/// Function that maps each operator to a cost in "points".
cost_function: F,

/// The global index in the current module for the error code.
missing_points_index: GlobalIndex,

/// The global index in the current module for remaining points.
remaining_points_index: GlobalIndex,

/// Accumulated cost of the current basic block.
accumulated_cost: u64,
}

#[derive(Debug, PartialEq)]
pub enum MeteringPoints {
Remaining(u64),
Depleted(u64, u64),
}

impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> Metering<F> {
/// Creates a `Metering` middleware.
pub fn new(initial_limit: u64, cost_function: F) -> Self {
Self {
initial_limit,
cost_function,
missing_points_index: Mutex::new(None),
remaining_points_index: Mutex::new(None),
}
}
Expand All @@ -60,6 +73,7 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> fmt::Debug for Meteri
f.debug_struct("Metering")
.field("initial_limit", &self.initial_limit)
.field("cost_function", &"<function>")
.field("missing_points_index", &self.missing_points_index)
.field("remaining_points_index", &self.remaining_points_index)
.finish()
}
Expand All @@ -72,6 +86,11 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync + 'static> ModuleMiddl
fn generate_function_middleware(&self, _: LocalFunctionIndex) -> Box<dyn FunctionMiddleware> {
Box::new(FunctionMetering {
cost_function: self.cost_function,
missing_points_index: self
.missing_points_index
.lock()
.unwrap()
.expect("Metering::generate_function_middleware: Missing points index not set up."),
remaining_points_index: self.remaining_points_index.lock().unwrap().expect(
"Metering::generate_function_middleware: Remaining points index not set up.",
),
Expand All @@ -81,23 +100,37 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync + 'static> ModuleMiddl

/// Transforms a `ModuleInfo` struct in-place. This is called before application on functions begins.
fn transform_module_info(&self, module_info: &mut ModuleInfo) {
let mut missing_points_index = self.missing_points_index.lock().unwrap();
let mut remaining_points_index = self.remaining_points_index.lock().unwrap();
if remaining_points_index.is_some() {

if missing_points_index.is_some() || remaining_points_index.is_some() {
panic!("Metering::transform_module_info: Attempting to use a `Metering` middleware from multiple modules.");
}

// Append a global for remaining points and initialize it.
let global_index = module_info
let remaining_points_global_index = module_info
.globals
.push(GlobalType::new(Type::I64, Mutability::Var));
*remaining_points_index = Some(global_index.clone());
*remaining_points_index = Some(remaining_points_global_index.clone());

// Append a global for the error code and initialize it.
let missing_points_global_index = module_info
.globals
.push(GlobalType::new(Type::I64, Mutability::Var));
*missing_points_index = Some(missing_points_global_index.clone());

module_info
.global_initializers
.push(GlobalInit::I64Const(self.initial_limit as i64));

module_info.exports.insert(
"remaining_points".to_string(),
ExportIndex::Global(global_index),
"wasmer_metering_remaining_points".to_string(),
ExportIndex::Global(remaining_points_global_index),
);

module_info.exports.insert(
"wasmer_metering_missing_points".to_string(),
ExportIndex::Global(missing_points_global_index),
);
}
}
Expand All @@ -106,6 +139,7 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> fmt::Debug for Functi
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("FunctionMetering")
.field("cost_function", &"<function>")
.field("missing_points_index", &self.missing_points_index)
.field("remaining_points_index", &self.remaining_points_index)
.finish()
}
Expand Down Expand Up @@ -143,7 +177,11 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> FunctionMiddleware
Operator::I64Const { value: self.accumulated_cost as i64 },
Operator::I64LtU,
Operator::If { ty: WpTypeOrFuncType::Type(WpType::EmptyBlockType) },
Operator::Unreachable, // FIXME: Signal the error properly.
Operator::I64Const { value: self.accumulated_cost as i64 },
Operator::GlobalGet { global_index: self.remaining_points_index.as_u32() },
Operator::I64Sub,
Operator::GlobalSet { global_index: self.missing_points_index.as_u32() },
Operator::Unreachable,
Operator::End,

// globals[remaining_points_index] -= self.accumulated_cost;
Expand Down Expand Up @@ -173,14 +211,28 @@ impl<F: Fn(&Operator) -> u64 + Copy + Clone + Send + Sync> FunctionMiddleware
///
/// The instance Module must have been processed with the [`Metering`] middleware
/// at compile time, otherwise this will panic.
pub fn get_remaining_points(instance: &Instance) -> u64 {
instance
pub fn get_remaining_points(instance: &Instance) -> MeteringPoints {
let points = instance
.exports
.get_global("remaining_points")
.expect("Can't get `remaining_points` from Instance")
.get_global("wasmer_metering_remaining_points")
.expect("Can't get `wasmer_metering_remaining_points` from Instance")
.get()
.try_into()
.expect("`remaining_points` from Instance has wrong type")
.expect("`wasmer_metering_remaining_points` from Instance has wrong type");

let error: u64 = instance
.exports
.get_global("wasmer_metering_missing_points")
.expect("Can't get `wasmer_metering_missing_points` from Instance")
.get()
.try_into()
.expect("`wasmer_metering_missing_points` from Instance has wrong type");

if error > 0 {
return MeteringPoints::Depleted(points, error);
}

MeteringPoints::Remaining(points)
}

/// Set the provided remaining points in an `Instance`.
Expand All @@ -195,10 +247,17 @@ pub fn get_remaining_points(instance: &Instance) -> u64 {
pub fn set_remaining_points(instance: &Instance, points: u64) {
instance
.exports
.get_global("remaining_points")
.expect("Can't get `remaining_points` from Instance")
.get_global("wasmer_metering_remaining_points")
.expect("Can't get `wasmer_metering_remaining_points` from Instance")
.set(points.into())
.expect("Can't set `remaining_points` in Instance");
.expect("Can't set `wasmer_metering_remaining_points` in Instance");

instance
.exports
.get_global("wasmer_metering_missing_points")
.expect("Can't get `wasmer_metering_missing_points` from Instance")
.set((0 as u64).into())
.expect("Can't set `wasmer_metering_missing_points` in Instance");
}

#[cfg(test)]
Expand Down Expand Up @@ -242,7 +301,10 @@ mod tests {

// Instantiate
let instance = Instance::new(&module, &imports! {}).unwrap();
assert_eq!(get_remaining_points(&instance), 10);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(10)
);

// First call
//
Expand All @@ -257,17 +319,24 @@ mod tests {
.native::<i32, i32>()
.unwrap();
add_one.call(1).unwrap();
assert_eq!(get_remaining_points(&instance), 6);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(6)
);

// Second call
add_one.call(1).unwrap();
assert_eq!(get_remaining_points(&instance), 2);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(2)
);

// Third call fails due to limit
assert!(add_one.call(1).is_err());
// TODO: what do we expect now? 0 or 2? See https://github.com/wasmerio/wasmer/issues/1931
// assert_eq!(metering.get_remaining_points(&instance), 2);
// assert_eq!(metering.get_remaining_points(&instance), 0);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Depleted(2, 2)
);
}

#[test]
Expand All @@ -280,7 +349,10 @@ mod tests {

// Instantiate
let instance = Instance::new(&module, &imports! {}).unwrap();
assert_eq!(get_remaining_points(&instance), 10);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(10)
);
let add_one = instance
.exports
.get_function("add_one")
Expand All @@ -293,10 +365,34 @@ mod tests {

// Ensure we can use the new points now
add_one.call(1).unwrap();
assert_eq!(get_remaining_points(&instance), 8);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(8)
);

add_one.call(1).unwrap();
assert_eq!(get_remaining_points(&instance), 4);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(4)
);

add_one.call(1).unwrap();
assert_eq!(get_remaining_points(&instance), 0);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(0)
);

assert!(add_one.call(1).is_err());
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Depleted(0, 4)
);

// Add some points for another call
set_remaining_points(&instance, 4);
assert_eq!(
get_remaining_points(&instance),
MeteringPoints::Remaining(4)
);
}
}

0 comments on commit 780be36

Please sign in to comment.