Skip to content

Commit

Permalink
Merge pull request #13 from arnaudgolfouse/avoid-return-copy
Browse files Browse the repository at this point in the history
Avoid copy when the plugin returns
  • Loading branch information
astrale-sharp authored Aug 6, 2023
2 parents 60664a7 + dfaf3c5 commit aa695aa
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 90 deletions.
39 changes: 27 additions & 12 deletions examples/hello_c/hello.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,37 @@
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#define PROTOCOL_FUNCTION __attribute__((import_module("typst_env"))) extern "C"
#else
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#define PROTOCOL_FUNCTION __attribute__((import_module("typst_env"))) extern
#endif

// ===
// Functions for the protocol

PROTOCOL_FUNCTION void
wasm_minimal_protocol_send_result_to_host(const uint8_t *ptr, size_t len);
PROTOCOL_FUNCTION void wasm_minimal_protocol_write_args_to_buffer(uint8_t *ptr);

EMSCRIPTEN_KEEPALIVE void wasm_minimal_protocol_free_byte_buffer(uint8_t *ptr,
size_t len) {
free(ptr);
}

// ===

EMSCRIPTEN_KEEPALIVE
int32_t hello(void) {
const char message[] = "Hello world !";
wasm_minimal_protocol_send_result_to_host((uint8_t *)message,
sizeof(message) - 1);
const char static_message[] = "Hello world !";
const size_t length = sizeof(static_message);
char *message = malloc(length);
memcpy((void *)message, (void *)static_message, length);
wasm_minimal_protocol_send_result_to_host((uint8_t *)message, length - 1);
return 0;
}

Expand All @@ -36,7 +50,6 @@ int32_t double_it(size_t arg_len) {
alloc_result[arg_len + i] = alloc_result[i];
}
wasm_minimal_protocol_send_result_to_host(alloc_result, result_len);
free(alloc_result);
return 0;
}

Expand Down Expand Up @@ -66,7 +79,6 @@ int32_t concatenate(size_t arg1_len, size_t arg2_len) {

wasm_minimal_protocol_send_result_to_host(result, total_len + 1);

free(result);
free(args);
return 0;
}
Expand Down Expand Up @@ -102,24 +114,27 @@ int32_t shuffle(size_t arg1_len, size_t arg2_len, size_t arg3_len) {

wasm_minimal_protocol_send_result_to_host(result, result_len);

free(result);
free(args);
return 0;
}

EMSCRIPTEN_KEEPALIVE
int32_t returns_ok() {
const char message[] = "This is an `Ok`";
wasm_minimal_protocol_send_result_to_host((uint8_t *)message,
sizeof(message) - 1);
const char static_message[] = "This is an `Ok`";
const size_t length = sizeof(static_message);
char *message = malloc(length);
memcpy((void *)message, (void *)static_message, length);
wasm_minimal_protocol_send_result_to_host((uint8_t *)message, length - 1);
return 0;
}

EMSCRIPTEN_KEEPALIVE
int32_t returns_err() {
const char message[] = "This is an `Err`";
wasm_minimal_protocol_send_result_to_host((uint8_t *)message,
sizeof(message) - 1);
const char static_message[] = "This is an `Err`";
const size_t length = sizeof(static_message);
char *message = malloc(length);
memcpy((void *)message, (void *)static_message, length);
wasm_minimal_protocol_send_result_to_host((uint8_t *)message, length - 1);
return 1;
}

Expand Down
49 changes: 32 additions & 17 deletions examples/hello_zig/hello.zig
Original file line number Diff line number Diff line change
@@ -1,23 +1,36 @@
const std = @import("std");
const allocator = std.heap.page_allocator;

// ===
// Functions for the protocol

extern "typst_env" fn wasm_minimal_protocol_send_result_to_host(ptr: [*]const u8, len: usize) void;
extern "typst_env" fn wasm_minimal_protocol_write_args_to_buffer(ptr: [*]u8) void;

export fn wasm_minimal_protocol_free_byte_buffer(ptr: [*]u8, len: usize) void {
var slice: []u8 = undefined;
slice.ptr = ptr;
slice.len = len;
allocator.free(slice);
}

// ===

export fn hello() i32 {
const message = "Hello world !";
wasm_minimal_protocol_send_result_to_host(message.ptr, message.len);
var result = allocator.alloc(u8, message.len) catch return 1;
@memcpy(result, message);
wasm_minimal_protocol_send_result_to_host(result.ptr, result.len);
return 0;
}

export fn double_it(arg1_len: usize) i32 {
var alloc_result = allocator.alloc(u8, arg1_len * 2) catch return 1;
defer allocator.free(alloc_result);
wasm_minimal_protocol_write_args_to_buffer(alloc_result.ptr);
var result = allocator.alloc(u8, arg1_len * 2) catch return 1;
wasm_minimal_protocol_write_args_to_buffer(result.ptr);
for (0..arg1_len) |i| {
alloc_result[i + arg1_len] = alloc_result[i];
result[i + arg1_len] = result[i];
}
wasm_minimal_protocol_send_result_to_host(alloc_result.ptr, alloc_result.len);
wasm_minimal_protocol_send_result_to_host(result.ptr, result.len);
return 0;
}

Expand All @@ -27,7 +40,6 @@ export fn concatenate(arg1_len: usize, arg2_len: usize) i32 {
wasm_minimal_protocol_write_args_to_buffer(args.ptr);

var result = allocator.alloc(u8, arg1_len + arg2_len + 1) catch return 1;
defer allocator.free(result);
for (0..arg1_len) |i| {
result[i] = args[i];
}
Expand All @@ -49,27 +61,30 @@ export fn shuffle(arg1_len: usize, arg2_len: usize, arg3_len: usize) i32 {
var arg2 = args[arg1_len .. arg1_len + arg2_len];
var arg3 = args[arg1_len + arg2_len .. args.len];

var result: std.ArrayList(u8) = std.ArrayList(u8).initCapacity(allocator, args_len + 2) catch return 1;
defer result.deinit();
result.appendSlice(arg3) catch return 1;
result.append('-') catch return 1;
result.appendSlice(arg1) catch return 1;
result.append('-') catch return 1;
result.appendSlice(arg2) catch return 1;
var result = allocator.alloc(u8, arg1_len + arg2_len + arg3_len + 2) catch return 1;
@memcpy(result[0..arg3.len], arg3);
result[arg3.len] = '-';
@memcpy(result[arg3.len + 1 ..][0..arg1.len], arg1);
result[arg3.len + arg1.len + 1] = '-';
@memcpy(result[arg3.len + arg1.len + 2 ..][0..arg2.len], arg2);

wasm_minimal_protocol_send_result_to_host(result.items.ptr, result.items.len);
wasm_minimal_protocol_send_result_to_host(result.ptr, result.len);
return 0;
}

export fn returns_ok() i32 {
const message = "This is an `Ok`";
wasm_minimal_protocol_send_result_to_host(message.ptr, message.len);
var result = allocator.alloc(u8, message.len) catch return 1;
@memcpy(result, message);
wasm_minimal_protocol_send_result_to_host(result.ptr, result.len);
return 0;
}

export fn returns_err() i32 {
const message = "This is an `Err`";
wasm_minimal_protocol_send_result_to_host(message.ptr, message.len);
var result = allocator.alloc(u8, message.len) catch return 1;
@memcpy(result, message);
wasm_minimal_protocol_send_result_to_host(result.ptr, result.len);
return 1;
}

Expand Down
121 changes: 79 additions & 42 deletions examples/host-wasmi/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,67 @@
use wasmi::{AsContext, Caller, Engine, Func as Function, Linker, Module, Value};
use wasmi::{AsContext, Caller, Engine, Func as Function, Linker, Memory, Module, Value};

type Store = wasmi::Store<PersistentData>;

/// Reference to a slice of memory returned after
/// [calling a wasm function](PluginInstance::call).
///
/// # Drop
/// On [`Drop`], this will free the slice of memory inside the plugin.
///
/// As such, this structure mutably borrows the [`PluginInstance`], which prevents
/// another function from being called.
pub struct ReturnedData<'a> {
memory: Memory,
ptr: u32,
len: u32,
free_function: &'a Function,
context_mut: &'a mut Store,
}

impl<'a> ReturnedData<'a> {
/// Get a reference to the returned slice of data.
///
/// # Panic
/// This may panic if the function returned an invalid `(ptr, len)` pair.
pub fn get(&self) -> &[u8] {
&self.memory.data(&*self.context_mut)[self.ptr as usize..(self.ptr + self.len) as usize]
}
}

impl Drop for ReturnedData<'_> {
fn drop(&mut self) {
self.free_function
.call(
&mut *self.context_mut,
&[Value::I32(self.ptr as _), Value::I32(self.len as _)],
&mut [],
)
.unwrap();
}
}

#[derive(Debug, Clone)]
struct PersistentData {
result_data: Vec<u8>,
result_ptr: u32,
result_len: u32,
arg_buffer: Vec<u8>,
}

#[derive(Debug)]
pub struct PluginInstance {
store: Store,
memory: Memory,
free_function: Function,
functions: Vec<(String, Function)>,
}

impl PluginInstance {
pub fn new_from_bytes(bytes: impl AsRef<[u8]>) -> Result<Self, String> {
let engine = Engine::default();
let data = PersistentData {
result_data: Vec::new(),
arg_buffer: Vec::new(),
result_ptr: 0,
result_len: 0,
};
let mut store = Store::new(&engine, data);

Expand All @@ -32,11 +74,8 @@ impl PluginInstance {
"typst_env",
"wasm_minimal_protocol_send_result_to_host",
move |mut caller: Caller<PersistentData>, ptr: u32, len: u32| {
let memory = caller.get_export("memory").unwrap().into_memory().unwrap();
let mut buffer = std::mem::take(&mut caller.data_mut().result_data);
buffer.resize(len as usize, 0);
memory.read(&caller, ptr as _, &mut buffer).unwrap();
caller.data_mut().result_data = buffer;
caller.data_mut().result_ptr = ptr;
caller.data_mut().result_len = len;
},
)
.unwrap()
Expand All @@ -51,54 +90,44 @@ impl PluginInstance {
},
)
.unwrap()
// hack to accept wasi file
// https://github.com/near/wasi-stub is preferred
/*
.func_wrap(
"wasi_snapshot_preview1",
"fd_write",
|_: i32, _: i32, _: i32, _: i32| 0i32,
)
.unwrap()
.func_wrap(
"wasi_snapshot_preview1",
"environ_get",
|_: i32, _: i32| 0i32,
)
.unwrap()
.func_wrap(
"wasi_snapshot_preview1",
"environ_sizes_get",
|_: i32, _: i32| 0i32,
)
.unwrap()
.func_wrap(
"wasi_snapshot_preview1",
"proc_exit",
|_: i32| {},
)
.unwrap()
*/
.instantiate(&mut store, &module)
.map_err(|e| format!("{e}"))?
.start(&mut store)
.map_err(|e| format!("{e}"))?;

let mut free_function = None;
let functions = instance
.exports(&store)
.filter_map(|e| {
let name = e.name().to_owned();
e.into_func().map(|func| (name, func))

e.into_func().map(|func| {
if name == "wasm_minimal_protocol_free_byte_buffer" {
free_function = Some(func);
}
(name, func)
})
})
.collect::<Vec<_>>();
Ok(Self { store, functions })
let free_function = free_function.unwrap();
let memory = instance
.get_export(&store, "memory")
.unwrap()
.into_memory()
.unwrap();
Ok(Self {
store,
memory,
free_function,
functions,
})
}

fn write(&mut self, args: &[&[u8]]) {
self.store.data_mut().arg_buffer = args.concat();
}

pub fn call(&mut self, function: &str, args: &[&[u8]]) -> Result<Vec<u8>, String> {
pub fn call(&mut self, function: &str, args: &[&[u8]]) -> Result<ReturnedData, String> {
self.write(args);

let (_, function) = self
Expand All @@ -122,11 +151,19 @@ impl PluginInstance {
code.first().cloned().unwrap_or(Value::I32(3)) // if the function returns nothing
};

let s = std::mem::take(&mut self.store.data_mut().result_data);
let (ptr, len) = (self.store.data().result_ptr, self.store.data().result_len);

let result = ReturnedData {
memory: self.memory,
ptr,
len,
free_function: &self.free_function,
context_mut: &mut self.store,
};

match code {
Value::I32(0) => Ok(s),
Value::I32(1) => Err(match String::from_utf8(s) {
Value::I32(0) => Ok(result),
Value::I32(1) => Err(match std::str::from_utf8(result.get()) {
Ok(err) => format!("plugin errored with: '{}'", err,),
Err(_) => String::from("plugin errored and did not return valid UTF-8"),
}),
Expand Down
7 changes: 3 additions & 4 deletions examples/test-runner/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
// you need to build the hello example first

use anyhow::Result;
use std::process::Command;

use host_wasmi::PluginInstance;
use std::process::Command;

#[cfg(not(feature = "wasi"))]
mod consts {
Expand Down Expand Up @@ -118,7 +117,7 @@ fn main() -> Result<()> {
return Ok(());
}
};
match String::from_utf8(result) {
match std::str::from_utf8(result.get()) {
Ok(s) => println!("{s}"),
Err(_) => panic!("Error: function call '{function}' did not return UTF-8"),
}
Expand All @@ -141,7 +140,7 @@ fn main() -> Result<()> {
continue;
}
};
match String::from_utf8(result) {
match std::str::from_utf8(result.get()) {
Ok(s) => println!("{s}"),
Err(_) => panic!("Error: function call '{function}' did not return UTF-8"),
}
Expand Down
Loading

0 comments on commit aa695aa

Please sign in to comment.