Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix some wit-bindgen-related issues with generated bindings #5692

Merged
merged 2 commits into from
Feb 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/component-macro/tests/codegen/function-new.wit
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
default world foo {
export new: func()
}
19 changes: 19 additions & 0 deletions crates/component-macro/tests/codegen/share-types.wit
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
interface http-types{
record request {
method: string
}
record response {
body: string
}
}

default world http-interface {
export http-handler: interface {
use self.http-types.{request,response}
handle-request: func(request: request) -> response
}
import http-fetch: interface {
use self.http-types.{request,response}
fetch-request: func(request: request) -> response
}
}
18 changes: 4 additions & 14 deletions crates/wit-bindgen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ impl Wasmtime {

fn import(&mut self, resolve: &Resolve, name: &str, item: &WorldItem) {
let snake = name.to_snake_case();
let mut gen = InterfaceGenerator::new(self, resolve, TypeMode::Owned);
let mut gen = InterfaceGenerator::new(self, resolve);
let import = match item {
WorldItem::Function(func) => {
gen.generate_function_trait_sig(TypeOwner::None, &func);
Expand Down Expand Up @@ -139,7 +139,7 @@ impl Wasmtime {

fn export(&mut self, resolve: &Resolve, name: &str, item: &WorldItem) {
let snake = name.to_snake_case();
let mut gen = InterfaceGenerator::new(self, resolve, TypeMode::AllBorrowed("'a"));
let mut gen = InterfaceGenerator::new(self, resolve);
let (ty, getter) = match item {
WorldItem::Function(func) => {
gen.define_rust_guest_export(None, func);
Expand Down Expand Up @@ -450,21 +450,15 @@ struct InterfaceGenerator<'a> {
src: Source,
gen: &'a mut Wasmtime,
resolve: &'a Resolve,
default_param_mode: TypeMode,
current_interface: Option<InterfaceId>,
}

impl<'a> InterfaceGenerator<'a> {
fn new(
gen: &'a mut Wasmtime,
resolve: &'a Resolve,
default_param_mode: TypeMode,
) -> InterfaceGenerator<'a> {
fn new(gen: &'a mut Wasmtime, resolve: &'a Resolve) -> InterfaceGenerator<'a> {
InterfaceGenerator {
src: Source::default(),
gen,
resolve,
default_param_mode,
current_interface: None,
}
}
Expand Down Expand Up @@ -1159,7 +1153,7 @@ impl<'a> InterfaceGenerator<'a> {
self.rustdoc(&func.docs);
uwrite!(
self.src,
"pub {async_} fn {}<S: wasmtime::AsContextMut>(&self, mut store: S, ",
"pub {async_} fn call_{}<S: wasmtime::AsContextMut>(&self, mut store: S, ",
func.name.to_snake_case(),
);
for (i, param) in func.params.iter().enumerate() {
Expand Down Expand Up @@ -1351,10 +1345,6 @@ impl<'a> RustGenerator<'a> for InterfaceGenerator<'a> {
self.current_interface
}

fn default_param_mode(&self) -> TypeMode {
self.default_param_mode
}

fn push_str(&mut self, s: &str) {
self.src.push_str(s);
}
Expand Down
15 changes: 4 additions & 11 deletions crates/wit-bindgen/src/rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ pub trait RustGenerator<'a> {

fn push_str(&mut self, s: &str);
fn info(&self, ty: TypeId) -> TypeInfo;
fn default_param_mode(&self) -> TypeMode;
fn current_interface(&self) -> Option<InterfaceId>;

fn print_ty(&mut self, ty: &Type, mode: TypeMode) {
Expand Down Expand Up @@ -209,10 +208,10 @@ pub trait RustGenerator<'a> {
fn modes_of(&self, ty: TypeId) -> Vec<(String, TypeMode)> {
let info = self.info(ty);
let mut result = Vec::new();
if info.param {
result.push((self.param_name(ty), self.default_param_mode()));
if info.borrowed {
result.push((self.param_name(ty), TypeMode::AllBorrowed("'a")));
}
if info.result && (!info.param || self.uses_two_names(&info)) {
if info.owned && (!info.borrowed || self.uses_two_names(&info)) {
result.push((self.result_name(ty), TypeMode::Owned));
}
return result;
Expand Down Expand Up @@ -358,13 +357,7 @@ pub trait RustGenerator<'a> {
}

fn uses_two_names(&self, info: &TypeInfo) -> bool {
info.has_list
&& info.param
&& info.result
&& match self.default_param_mode() {
TypeMode::AllBorrowed(_) => true,
TypeMode::Owned => false,
}
info.has_list && info.borrowed && info.owned
}

fn lifetime_for(&self, info: &TypeInfo, mode: TypeMode) -> Option<&'static str> {
Expand Down
39 changes: 25 additions & 14 deletions crates/wit-bindgen/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ pub struct Types {

#[derive(Default, Clone, Copy, Debug, PartialEq)]
pub struct TypeInfo {
/// Whether or not this type is ever used (transitively) within the
/// parameter of a function.
pub param: bool,
/// Whether or not this type is ever used (transitively) within a borrowed
/// context, or a parameter to an export function.
pub borrowed: bool,

/// Whether or not this type is ever used (transitively) within the
/// result of a function.
pub result: bool,
/// Whether or not this type is ever used (transitively) within an owned
/// context, such as the result of an exported function or in the params or
/// results of an imported function.
pub owned: bool,

/// Whether or not this type is ever used (transitively) within the
/// error case in the result of a function.
Expand All @@ -26,8 +27,8 @@ pub struct TypeInfo {

impl std::ops::BitOrAssign for TypeInfo {
fn bitor_assign(&mut self, rhs: Self) {
self.param |= rhs.param;
self.result |= rhs.result;
self.borrowed |= rhs.borrowed;
self.owned |= rhs.owned;
self.error |= rhs.error;
self.has_list |= rhs.has_list;
}
Expand All @@ -36,39 +37,49 @@ impl std::ops::BitOrAssign for TypeInfo {
impl Types {
pub fn analyze(&mut self, resolve: &Resolve, world: WorldId) {
let world = &resolve.worlds[world];
for (_, item) in world.imports.iter().chain(world.exports.iter()) {
for (import, (_, item)) in world
.imports
.iter()
.map(|i| (true, i))
.chain(world.exports.iter().map(|i| (false, i)))
{
match item {
WorldItem::Function(f) => self.type_info_func(resolve, f),
WorldItem::Function(f) => self.type_info_func(resolve, f, import),
WorldItem::Interface(id) => {
let iface = &resolve.interfaces[*id];

for (_, t) in iface.types.iter() {
self.type_id_info(resolve, *t);
}
for (_, f) in iface.functions.iter() {
self.type_info_func(resolve, f);
self.type_info_func(resolve, f, import);
}
}
}
}
}

fn type_info_func(&mut self, resolve: &Resolve, func: &Function) {
fn type_info_func(&mut self, resolve: &Resolve, func: &Function, import: bool) {
let mut live = LiveTypes::default();
for (_, ty) in func.params.iter() {
self.type_info(resolve, ty);
live.add_type(resolve, ty);
}
for id in live.iter() {
self.type_info.get_mut(&id).unwrap().param = true;
let info = self.type_info.get_mut(&id).unwrap();
if import {
info.owned = true;
} else {
info.borrowed = true;
}
}
let mut live = LiveTypes::default();
for ty in func.results.iter_types() {
self.type_info(resolve, ty);
live.add_type(resolve, ty);
}
for id in live.iter() {
self.type_info.get_mut(&id).unwrap().result = true;
self.type_info.get_mut(&id).unwrap().owned = true;
}

for ty in func.results.iter_types() {
Expand Down
6 changes: 3 additions & 3 deletions tests/all/component_model/bindgen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ mod no_imports {
let linker = Linker::new(&engine);
let mut store = Store::new(&engine, ());
let (no_imports, _) = NoImports::instantiate(&mut store, &component, &linker)?;
no_imports.bar(&mut store)?;
no_imports.foo().foo(&mut store)?;
no_imports.call_bar(&mut store)?;
no_imports.foo().call_foo(&mut store)?;
Ok(())
}
}
Expand Down Expand Up @@ -108,7 +108,7 @@ mod one_import {
foo::add_to_linker(&mut linker, |f: &mut MyImports| f)?;
let mut store = Store::new(&engine, MyImports::default());
let (one_import, _) = OneImport::instantiate(&mut store, &component, &linker)?;
one_import.bar(&mut store)?;
one_import.call_bar(&mut store)?;
assert!(store.data().hit);
Ok(())
}
Expand Down
36 changes: 21 additions & 15 deletions tests/all/component_model/bindgen/results.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,22 @@ mod empty_error {

assert_eq!(
results
.empty_error(&mut store, 0.0)
.call_empty_error(&mut store, 0.0)
.expect("no trap")
.expect("no error returned"),
0.0
);

results
.empty_error(&mut store, 1.0)
.call_empty_error(&mut store, 1.0)
.expect("no trap")
.err()
.expect("() error returned");

let e = results.empty_error(&mut store, 2.0).err().expect("trap");
let e = results
.call_empty_error(&mut store, 2.0)
.err()
.expect("trap");
assert_eq!(
format!("{}", e.source().expect("trap message is stored in source")),
"empty_error: trap"
Expand Down Expand Up @@ -188,20 +191,23 @@ mod string_error {

assert_eq!(
results
.string_error(&mut store, 0.0)
.call_string_error(&mut store, 0.0)
.expect("no trap")
.expect("no error returned"),
0.0
);

let e = results
.string_error(&mut store, 1.0)
.call_string_error(&mut store, 1.0)
.expect("no trap")
.err()
.expect("error returned");
assert_eq!(e, "string_error: error");

let e = results.string_error(&mut store, 2.0).err().expect("trap");
let e = results
.call_string_error(&mut store, 2.0)
.err()
.expect("trap");
assert_eq!(
format!("{}", e.source().expect("trap message is stored in source")),
"string_error: trap"
Expand Down Expand Up @@ -328,23 +334,23 @@ mod enum_error {
assert_eq!(
results
.foo()
.enum_error(&mut store, 0.0)
.call_enum_error(&mut store, 0.0)
.expect("no trap")
.expect("no error returned"),
0.0
);

let e = results
.foo()
.enum_error(&mut store, 1.0)
.call_enum_error(&mut store, 1.0)
.expect("no trap")
.err()
.expect("error returned");
assert_eq!(e, enum_error::foo::E1::A);

let e = results
.foo()
.enum_error(&mut store, 2.0)
.call_enum_error(&mut store, 2.0)
.err()
.expect("trap");
assert_eq!(
Expand Down Expand Up @@ -458,15 +464,15 @@ mod record_error {
assert_eq!(
results
.foo()
.record_error(&mut store, 0.0)
.call_record_error(&mut store, 0.0)
.expect("no trap")
.expect("no error returned"),
0.0
);

let e = results
.foo()
.record_error(&mut store, 1.0)
.call_record_error(&mut store, 1.0)
.expect("no trap")
.err()
.expect("error returned");
Expand All @@ -480,7 +486,7 @@ mod record_error {

let e = results
.foo()
.record_error(&mut store, 2.0)
.call_record_error(&mut store, 2.0)
.err()
.expect("trap");
assert_eq!(
Expand Down Expand Up @@ -594,15 +600,15 @@ mod variant_error {
assert_eq!(
results
.foo()
.variant_error(&mut store, 0.0)
.call_variant_error(&mut store, 0.0)
.expect("no trap")
.expect("no error returned"),
0.0
);

let e = results
.foo()
.variant_error(&mut store, 1.0)
.call_variant_error(&mut store, 1.0)
.expect("no trap")
.err()
.expect("error returned");
Expand All @@ -616,7 +622,7 @@ mod variant_error {

let e = results
.foo()
.variant_error(&mut store, 2.0)
.call_variant_error(&mut store, 2.0)
.err()
.expect("trap");
assert_eq!(
Expand Down