From acbb3eed6082c22da839874cff710b709a73e40f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Wed, 2 Sep 2020 17:26:43 +0200 Subject: [PATCH 01/10] Replace PyFunction_New with extern C function. PyFunction_New was previously implemented as a Rust function wrapper around a call to the extern C function PyFunction_NewExt with a hard-coded third argument. This commit removes the Rust wrapper and directly exposes the function from the CPython API. --- src/ffi/methodobject.rs | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/ffi/methodobject.rs b/src/ffi/methodobject.rs index 7c3872c5dc8..921cca847c1 100644 --- a/src/ffi/methodobject.rs +++ b/src/ffi/methodobject.rs @@ -1,6 +1,6 @@ use crate::ffi::object::{PyObject, PyTypeObject, Py_TYPE}; +use std::mem; use std::os::raw::{c_char, c_int}; -use std::{mem, ptr}; #[cfg_attr(windows, link(name = "pythonXY"))] extern "C" { @@ -96,19 +96,16 @@ impl Default for PyMethodDef { } } -#[inline] -pub unsafe fn PyCFunction_New(ml: *mut PyMethodDef, slf: *mut PyObject) -> *mut PyObject { - #[cfg_attr(PyPy, link_name = "PyPyCFunction_NewEx")] - PyCFunction_NewEx(ml, slf, ptr::null_mut()) -} - extern "C" { #[cfg_attr(PyPy, link_name = "PyPyCFunction_NewEx")] pub fn PyCFunction_NewEx( - arg1: *mut PyMethodDef, - arg2: *mut PyObject, - arg3: *mut PyObject, + ml: *mut PyMethodDef, + slf: *mut PyObject, + module: *mut PyObject, ) -> *mut PyObject; + + #[cfg_attr(PyPy, link_name = "PyPyCFunction_NewEx")] + pub fn PyCFunction_New(ml: *mut PyMethodDef, slf: *mut PyObject) -> *mut PyObject; } /* Flag passed to newmethodobject */ From 5bbca1a0528d10c003679f92f06241f8ae36e571 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Thu, 3 Sep 2020 13:39:07 +0200 Subject: [PATCH 02/10] Set the module of `#[pyfunction]`s. Previously neither the module nor the name of the module of pyfunctions were registered. This commit passes the module and its name when creating a new pyfunction. PyModule::add_function and PyModule::add_module have been added and are set to replace `add_wrapped` in a future release. `add_wrapped` is kept for compatibility reasons during the transition. Depending on whether a `PyModule` or `Python` is the argument for the Python function-wrapper, the module will be registered with the function. --- examples/rustapi_module/src/datetime.rs | 28 ++++++++--------- examples/rustapi_module/src/othermod.rs | 2 +- examples/word-count/src/lib.rs | 6 ++-- guide/src/function.md | 4 +-- guide/src/module.md | 4 +-- pyo3-derive-backend/src/module.rs | 32 ++++++++++++++++--- src/derive_utils.rs | 19 ++++++++++++ src/lib.rs | 2 +- src/python.rs | 2 +- src/types/module.rs | 41 ++++++++++++++++++++++++- tests/test_module.rs | 18 +++++------ 11 files changed, 120 insertions(+), 38 deletions(-) diff --git a/examples/rustapi_module/src/datetime.rs b/examples/rustapi_module/src/datetime.rs index 3181ae79d97..3ccb7c697f6 100644 --- a/examples/rustapi_module/src/datetime.rs +++ b/examples/rustapi_module/src/datetime.rs @@ -215,29 +215,29 @@ impl TzClass { #[pymodule] fn datetime(_py: Python<'_>, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(make_date))?; - m.add_wrapped(wrap_pyfunction!(get_date_tuple))?; - m.add_wrapped(wrap_pyfunction!(date_from_timestamp))?; - m.add_wrapped(wrap_pyfunction!(make_time))?; - m.add_wrapped(wrap_pyfunction!(get_time_tuple))?; - m.add_wrapped(wrap_pyfunction!(make_delta))?; - m.add_wrapped(wrap_pyfunction!(get_delta_tuple))?; - m.add_wrapped(wrap_pyfunction!(make_datetime))?; - m.add_wrapped(wrap_pyfunction!(get_datetime_tuple))?; - m.add_wrapped(wrap_pyfunction!(datetime_from_timestamp))?; + m.add_function(wrap_pyfunction!(make_date))?; + m.add_function(wrap_pyfunction!(get_date_tuple))?; + m.add_function(wrap_pyfunction!(date_from_timestamp))?; + m.add_function(wrap_pyfunction!(make_time))?; + m.add_function(wrap_pyfunction!(get_time_tuple))?; + m.add_function(wrap_pyfunction!(make_delta))?; + m.add_function(wrap_pyfunction!(get_delta_tuple))?; + m.add_function(wrap_pyfunction!(make_datetime))?; + m.add_function(wrap_pyfunction!(get_datetime_tuple))?; + m.add_function(wrap_pyfunction!(datetime_from_timestamp))?; // Python 3.6+ functions #[cfg(Py_3_6)] { - m.add_wrapped(wrap_pyfunction!(time_with_fold))?; + m.add_function(wrap_pyfunction!(time_with_fold))?; #[cfg(not(PyPy))] { - m.add_wrapped(wrap_pyfunction!(get_time_tuple_fold))?; - m.add_wrapped(wrap_pyfunction!(get_datetime_tuple_fold))?; + m.add_function(wrap_pyfunction!(get_time_tuple_fold))?; + m.add_function(wrap_pyfunction!(get_datetime_tuple_fold))?; } } - m.add_wrapped(wrap_pyfunction!(issue_219))?; + m.add_function(wrap_pyfunction!(issue_219))?; m.add_class::()?; Ok(()) diff --git a/examples/rustapi_module/src/othermod.rs b/examples/rustapi_module/src/othermod.rs index 20745b29fb6..b9955806186 100644 --- a/examples/rustapi_module/src/othermod.rs +++ b/examples/rustapi_module/src/othermod.rs @@ -31,7 +31,7 @@ fn double(x: i32) -> i32 { #[pymodule] fn othermod(_py: Python<'_>, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(double))?; + m.add_function(wrap_pyfunction!(double))?; m.add_class::()?; diff --git a/examples/word-count/src/lib.rs b/examples/word-count/src/lib.rs index 06d696e895f..8d65199c8bf 100644 --- a/examples/word-count/src/lib.rs +++ b/examples/word-count/src/lib.rs @@ -55,9 +55,9 @@ fn count_line(line: &str, needle: &str) -> usize { #[pymodule] fn word_count(_py: Python<'_>, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(search))?; - m.add_wrapped(wrap_pyfunction!(search_sequential))?; - m.add_wrapped(wrap_pyfunction!(search_sequential_allow_threads))?; + m.add_function(wrap_pyfunction!(search))?; + m.add_function(wrap_pyfunction!(search_sequential))?; + m.add_function(wrap_pyfunction!(search_sequential_allow_threads))?; Ok(()) } diff --git a/guide/src/function.md b/guide/src/function.md index 1a12d8ec6f1..b8167c0b7cb 100644 --- a/guide/src/function.md +++ b/guide/src/function.md @@ -36,7 +36,7 @@ fn double(x: usize) -> usize { #[pymodule] fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(double)).unwrap(); + m.add_function(wrap_pyfunction!(double)).unwrap(); Ok(()) } @@ -65,7 +65,7 @@ fn num_kwds(kwds: Option<&PyDict>) -> usize { #[pymodule] fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(num_kwds)).unwrap(); + m.add_function(wrap_pyfunction!(num_kwds)).unwrap(); Ok(()) } diff --git a/guide/src/module.md b/guide/src/module.md index 4dea21b1b9b..042b11f0178 100644 --- a/guide/src/module.md +++ b/guide/src/module.md @@ -67,13 +67,13 @@ fn subfunction() -> String { #[pymodule] fn submodule(_py: Python, module: &PyModule) -> PyResult<()> { - module.add_wrapped(wrap_pyfunction!(subfunction))?; + module.add_function(wrap_pyfunction!(subfunction))?; Ok(()) } #[pymodule] fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> { - module.add_wrapped(wrap_pymodule!(submodule))?; + module.add_module(wrap_pymodule!(submodule))?; Ok(()) } diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index bd6e4182793..ff5d9f964ab 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -45,7 +45,7 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> { let item: syn::ItemFn = syn::parse_quote! { fn block_wrapper() { #function_to_python - #module_name.add_wrapped(&#function_wrapper_ident)?; + #module_name.add_function(&#function_wrapper_ident)?; } }; stmts.extend(item.block.stmts.into_iter()); @@ -190,7 +190,17 @@ pub fn add_fn_to_module( let wrapper = function_c_wrapper(&func.sig.ident, &spec); Ok(quote! { - fn #function_wrapper_ident(py: pyo3::Python) -> pyo3::PyObject { + fn #function_wrapper_ident<'a>( + args: impl Into> + ) -> pyo3::PyObject { + let arg = args.into(); + let (py, maybe_module) = match arg { + pyo3::derive_utils::WrapPyFunctionArguments::Python(py) => (py, None), + pyo3::derive_utils::WrapPyFunctionArguments::PyModule(module) => { + let py = ::py(module); + (py, Some(module)) + } + }; #wrapper let _def = pyo3::class::PyMethodDef { @@ -200,12 +210,26 @@ pub fn add_fn_to_module( ml_doc: #doc, }; + let (mod_ptr, name) = if let Some(m) = maybe_module { + let mod_ptr = ::as_ptr(m); + let name = match m.name() { + Ok(name) => <&str as pyo3::conversion::IntoPy>::into_py(name, py), + Err(err) => { + return >::into_py(err, py); + } + }; + (mod_ptr, ::as_ptr(&name)) + } else { + (std::ptr::null_mut(), std::ptr::null_mut()) + }; + let function = unsafe { pyo3::PyObject::from_owned_ptr( py, - pyo3::ffi::PyCFunction_New( + pyo3::ffi::PyCFunction_NewEx( Box::into_raw(Box::new(_def.as_method_def())), - ::std::ptr::null_mut() + mod_ptr, + name ) ) }; diff --git a/src/derive_utils.rs b/src/derive_utils.rs index 2a736d7ebcd..f68f3ab1a83 100644 --- a/src/derive_utils.rs +++ b/src/derive_utils.rs @@ -207,3 +207,22 @@ where >>::try_from(cell) } } + +/// Enum to abstract over the arguments of Python function wrappers. +#[doc(hidden)] +pub enum WrapPyFunctionArguments<'a> { + Python(Python<'a>), + PyModule(&'a PyModule), +} + +impl<'a> From> for WrapPyFunctionArguments<'a> { + fn from(py: Python<'a>) -> WrapPyFunctionArguments<'a> { + WrapPyFunctionArguments::Python(py) + } +} + +impl<'a> From<&'a PyModule> for WrapPyFunctionArguments<'a> { + fn from(module: &'a PyModule) -> WrapPyFunctionArguments<'a> { + WrapPyFunctionArguments::PyModule(module) + } +} diff --git a/src/lib.rs b/src/lib.rs index 10f3e768f8e..4c2313e3a47 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -71,7 +71,7 @@ //! #[pymodule] //! /// A Python module implemented in Rust. //! fn string_sum(py: Python, m: &PyModule) -> PyResult<()> { -//! m.add_wrapped(wrap_pyfunction!(sum_as_string))?; +//! m.add_function(wrap_pyfunction!(sum_as_string))?; //! //! Ok(()) //! } diff --git a/src/python.rs b/src/python.rs index 901a426bcf0..db4abfe6b1a 100644 --- a/src/python.rs +++ b/src/python.rs @@ -134,7 +134,7 @@ impl<'p> Python<'p> { /// let gil = Python::acquire_gil(); /// let py = gil.python(); /// let m = PyModule::new(py, "pcount").unwrap(); - /// m.add_wrapped(wrap_pyfunction!(parallel_count)).unwrap(); + /// m.add_function(wrap_pyfunction!(parallel_count)).unwrap(); /// let locals = [("pcount", m)].into_py_dict(py); /// py.run(r#" /// s = ["Flow", "my", "tears", "the", "Policeman", "Said"] diff --git a/src/types/module.rs b/src/types/module.rs index b345fcab563..c6edb370632 100644 --- a/src/types/module.rs +++ b/src/types/module.rs @@ -194,11 +194,50 @@ impl PyModule { /// ```rust,ignore /// m.add("also_double", wrap_pyfunction!(double)(py)); /// ``` - pub fn add_wrapped(&self, wrapper: &impl Fn(Python) -> PyObject) -> PyResult<()> { + /// + /// **This function will be deprecated in the next release. Please use the specific + /// [add_function] and [add_module] functions instead.** + pub fn add_wrapped<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> { let function = wrapper(self.py()); let name = function .getattr(self.py(), "__name__") .expect("A function or module must have a __name__"); self.add(name.extract(self.py()).unwrap(), function) } + + /// Adds a (sub)module to a module. + /// + /// Use this together with `#[pymodule]` and [wrap_pymodule!]. + /// + /// ```rust,ignore + /// m.add_module(wrap_pymodule!(utils)); + /// ``` + pub fn add_module<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> { + let function = wrapper(self.py()); + let name = function + .getattr(self.py(), "__name__") + .expect("A module must have a __name__"); + self.add(name.extract(self.py()).unwrap(), function) + } + + /// Adds a function to a module, using the functions __name__ as name. + /// + /// Use this together with the`#[pyfunction]` and [wrap_pyfunction!]. + /// + /// ```rust,ignore + /// m.add_function(wrap_pyfunction!(double)); + /// ``` + /// + /// You can also add a function with a custom name using [add](PyModule::add): + /// + /// ```rust,ignore + /// m.add("also_double", wrap_pyfunction!(double)(py, m)); + /// ``` + pub fn add_function<'a>(&'a self, wrapper: &impl Fn(&'a Self) -> PyObject) -> PyResult<()> { + let function = wrapper(self); + let name = function + .getattr(self.py(), "__name__") + .expect("A function or module must have a __name__"); + self.add(name.extract(self.py()).unwrap(), function) + } } diff --git a/tests/test_module.rs b/tests/test_module.rs index 0746fb8f868..f3c44667f06 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -35,7 +35,7 @@ fn double(x: usize) -> usize { /// This module is implemented in Rust. #[pymodule] -fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { +fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; #[pyfn(m, "sum_as_string")] @@ -60,8 +60,8 @@ fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { m.add("foo", "bar").unwrap(); - m.add_wrapped(wrap_pyfunction!(double)).unwrap(); - m.add("also_double", wrap_pyfunction!(double)(py)).unwrap(); + m.add_function(wrap_pyfunction!(double)).unwrap(); + m.add("also_double", wrap_pyfunction!(double)(m)).unwrap(); Ok(()) } @@ -157,7 +157,7 @@ fn r#move() -> usize { fn raw_ident_module(_py: Python, module: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; - module.add_wrapped(wrap_pyfunction!(r#move)) + module.add_function(wrap_pyfunction!(r#move)) } #[test] @@ -182,7 +182,7 @@ fn custom_named_fn() -> usize { fn foobar_module(_py: Python, m: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; - m.add_wrapped(wrap_pyfunction!(custom_named_fn))?; + m.add_function(wrap_pyfunction!(custom_named_fn))?; m.dict().set_item("yay", "me")?; Ok(()) } @@ -216,7 +216,7 @@ fn subfunction() -> String { fn submodule(_py: Python, module: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; - module.add_wrapped(wrap_pyfunction!(subfunction))?; + module.add_function(wrap_pyfunction!(subfunction))?; Ok(()) } @@ -229,8 +229,8 @@ fn superfunction() -> String { fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> { use pyo3::{wrap_pyfunction, wrap_pymodule}; - module.add_wrapped(wrap_pyfunction!(superfunction))?; - module.add_wrapped(wrap_pymodule!(submodule))?; + module.add_function(wrap_pyfunction!(superfunction))?; + module.add_module(wrap_pymodule!(submodule))?; Ok(()) } @@ -268,7 +268,7 @@ fn vararg_module(_py: Python, m: &PyModule) -> PyResult<()> { ext_vararg_fn(py, a, vararg) } - m.add_wrapped(pyo3::wrap_pyfunction!(ext_vararg_fn)) + m.add_function(pyo3::wrap_pyfunction!(ext_vararg_fn)) .unwrap(); Ok(()) } From 1f017b66fbd21eaf8a2408415e23fb2e42aeab92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Thu, 3 Sep 2020 15:21:33 +0200 Subject: [PATCH 03/10] Move py fn wrapper argument expansion to associated function. Suggestion by @kngwyu. Additionally replace some `expect` calls with error handling. --- pyo3-derive-backend/src/module.rs | 8 +------- src/derive_utils.rs | 12 ++++++++++++ src/types/module.rs | 18 ++++++------------ 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index ff5d9f964ab..843f3578560 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -194,13 +194,7 @@ pub fn add_fn_to_module( args: impl Into> ) -> pyo3::PyObject { let arg = args.into(); - let (py, maybe_module) = match arg { - pyo3::derive_utils::WrapPyFunctionArguments::Python(py) => (py, None), - pyo3::derive_utils::WrapPyFunctionArguments::PyModule(module) => { - let py = ::py(module); - (py, Some(module)) - } - }; + let (py, maybe_module) = arg.into_py_and_maybe_module(); #wrapper let _def = pyo3::class::PyMethodDef { diff --git a/src/derive_utils.rs b/src/derive_utils.rs index f68f3ab1a83..cef90fd1248 100644 --- a/src/derive_utils.rs +++ b/src/derive_utils.rs @@ -215,6 +215,18 @@ pub enum WrapPyFunctionArguments<'a> { PyModule(&'a PyModule), } +impl<'a> WrapPyFunctionArguments<'a> { + pub fn into_py_and_maybe_module(self) -> (Python<'a>, Option<&'a PyModule>) { + match self { + WrapPyFunctionArguments::Python(py) => (py, None), + WrapPyFunctionArguments::PyModule(module) => { + let py = module.py(); + (py, Some(module)) + } + } + } +} + impl<'a> From> for WrapPyFunctionArguments<'a> { fn from(py: Python<'a>) -> WrapPyFunctionArguments<'a> { WrapPyFunctionArguments::Python(py) diff --git a/src/types/module.rs b/src/types/module.rs index c6edb370632..d4bfa7222ed 100644 --- a/src/types/module.rs +++ b/src/types/module.rs @@ -199,10 +199,8 @@ impl PyModule { /// [add_function] and [add_module] functions instead.** pub fn add_wrapped<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> { let function = wrapper(self.py()); - let name = function - .getattr(self.py(), "__name__") - .expect("A function or module must have a __name__"); - self.add(name.extract(self.py()).unwrap(), function) + let name = function.getattr(self.py(), "__name__")?; + self.add(name.extract(self.py())?, function) } /// Adds a (sub)module to a module. @@ -214,10 +212,8 @@ impl PyModule { /// ``` pub fn add_module<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> { let function = wrapper(self.py()); - let name = function - .getattr(self.py(), "__name__") - .expect("A module must have a __name__"); - self.add(name.extract(self.py()).unwrap(), function) + let name = function.getattr(self.py(), "__name__")?; + self.add(name.extract(self.py())?, function) } /// Adds a function to a module, using the functions __name__ as name. @@ -235,9 +231,7 @@ impl PyModule { /// ``` pub fn add_function<'a>(&'a self, wrapper: &impl Fn(&'a Self) -> PyObject) -> PyResult<()> { let function = wrapper(self); - let name = function - .getattr(self.py(), "__name__") - .expect("A function or module must have a __name__"); - self.add(name.extract(self.py()).unwrap(), function) + let name = function.getattr(self.py(), "__name__")?; + self.add(name.extract(self.py())?, function) } } From 32142490100391d557e645c8a2c66a6d52955a39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Thu, 3 Sep 2020 15:48:32 +0200 Subject: [PATCH 04/10] Make python function wrapper creation fallible. Wrapping a function can fail if we can't get the module name. Based on suggestion by @kngwyu --- README.md | 2 +- examples/word-count/src/lib.rs | 2 +- guide/src/logging.md | 2 +- guide/src/trait_bounds.md | 2 +- pyo3-derive-backend/src/module.rs | 12 ++++-------- src/types/module.rs | 21 ++++++++++++++------- tests/test_bytes.rs | 6 +++--- tests/test_exceptions.rs | 4 ++-- tests/test_module.rs | 2 +- tests/test_pyfunction.rs | 4 ++-- tests/test_string.rs | 2 +- tests/test_text_signature.rs | 2 +- tests/test_various.rs | 4 ++-- 13 files changed, 34 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index bf06df9e06e..ff041bf60b7 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,7 @@ fn sum_as_string(a: usize, b: usize) -> PyResult { /// A Python module implemented in Rust. #[pymodule] fn string_sum(py: Python, m: &PyModule) -> PyResult<()> { - m.add_wrapped(wrap_pyfunction!(sum_as_string))?; + m.add_function(wrap_pyfunction!(sum_as_string))?; Ok(()) } diff --git a/examples/word-count/src/lib.rs b/examples/word-count/src/lib.rs index 8d65199c8bf..50a0078026b 100644 --- a/examples/word-count/src/lib.rs +++ b/examples/word-count/src/lib.rs @@ -55,7 +55,7 @@ fn count_line(line: &str, needle: &str) -> usize { #[pymodule] fn word_count(_py: Python<'_>, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(search))?; + m.add_wrapped(wrap_pyfunction!(search))?; m.add_function(wrap_pyfunction!(search_sequential))?; m.add_function(wrap_pyfunction!(search_sequential_allow_threads))?; diff --git a/guide/src/logging.md b/guide/src/logging.md index a9078b5d01a..500a8804709 100644 --- a/guide/src/logging.md +++ b/guide/src/logging.md @@ -35,7 +35,7 @@ fn my_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> { // A good place to install the Rust -> Python logger. pyo3_log::init(); - m.add_wrapped(wrap_pyfunction!(log_something))?; + m.add_function(wrap_pyfunction!(log_something))?; Ok(()) } ``` diff --git a/guide/src/trait_bounds.md b/guide/src/trait_bounds.md index c11d585ced9..65e173cd40d 100644 --- a/guide/src/trait_bounds.md +++ b/guide/src/trait_bounds.md @@ -488,7 +488,7 @@ pub struct UserModel { #[pymodule] fn trait_exposure(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; - m.add_wrapped(wrap_pyfunction!(solve_wrapper))?; + m.add_function(wrap_pyfunction!(solve_wrapper))?; Ok(()) } diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index 843f3578560..0554e71c51a 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -192,7 +192,7 @@ pub fn add_fn_to_module( Ok(quote! { fn #function_wrapper_ident<'a>( args: impl Into> - ) -> pyo3::PyObject { + ) -> pyo3::PyResult { let arg = args.into(); let (py, maybe_module) = arg.into_py_and_maybe_module(); #wrapper @@ -206,12 +206,8 @@ pub fn add_fn_to_module( let (mod_ptr, name) = if let Some(m) = maybe_module { let mod_ptr = ::as_ptr(m); - let name = match m.name() { - Ok(name) => <&str as pyo3::conversion::IntoPy>::into_py(name, py), - Err(err) => { - return >::into_py(err, py); - } - }; + let name = m.name()?; + let name = <&str as pyo3::conversion::IntoPy>::into_py(name, py); (mod_ptr, ::as_ptr(&name)) } else { (std::ptr::null_mut(), std::ptr::null_mut()) @@ -228,7 +224,7 @@ pub fn add_fn_to_module( ) }; - function + Ok(function) } }) } diff --git a/src/types/module.rs b/src/types/module.rs index d4bfa7222ed..0358257aa3f 100644 --- a/src/types/module.rs +++ b/src/types/module.rs @@ -2,6 +2,7 @@ // // based on Daniel Grunwald's https://github.com/dgrunwald/rust-cpython +use crate::callback::IntoPyCallbackOutput; use crate::err::{PyErr, PyResult}; use crate::exceptions; use crate::ffi; @@ -197,8 +198,11 @@ impl PyModule { /// /// **This function will be deprecated in the next release. Please use the specific /// [add_function] and [add_module] functions instead.** - pub fn add_wrapped<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> { - let function = wrapper(self.py()); + pub fn add_wrapped<'a, T>(&'a self, wrapper: &impl Fn(Python<'a>) -> T) -> PyResult<()> + where + T: IntoPyCallbackOutput, + { + let function = wrapper(self.py()).convert(self.py())?; let name = function.getattr(self.py(), "__name__")?; self.add(name.extract(self.py())?, function) } @@ -211,9 +215,9 @@ impl PyModule { /// m.add_module(wrap_pymodule!(utils)); /// ``` pub fn add_module<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> { - let function = wrapper(self.py()); - let name = function.getattr(self.py(), "__name__")?; - self.add(name.extract(self.py())?, function) + let module = wrapper(self.py()); + let name = module.getattr(self.py(), "__name__")?; + self.add(name.extract(self.py())?, module) } /// Adds a function to a module, using the functions __name__ as name. @@ -229,8 +233,11 @@ impl PyModule { /// ```rust,ignore /// m.add("also_double", wrap_pyfunction!(double)(py, m)); /// ``` - pub fn add_function<'a>(&'a self, wrapper: &impl Fn(&'a Self) -> PyObject) -> PyResult<()> { - let function = wrapper(self); + pub fn add_function<'a>( + &'a self, + wrapper: &impl Fn(&'a Self) -> PyResult, + ) -> PyResult<()> { + let function = wrapper(self)?; let name = function.getattr(self.py(), "__name__")?; self.add(name.extract(self.py())?, function) } diff --git a/tests/test_bytes.rs b/tests/test_bytes.rs index a48c05fc014..e458e35b553 100644 --- a/tests/test_bytes.rs +++ b/tests/test_bytes.rs @@ -14,7 +14,7 @@ fn test_pybytes_bytes_conversion() { let gil = Python::acquire_gil(); let py = gil.python(); - let f = wrap_pyfunction!(bytes_pybytes_conversion)(py); + let f = wrap_pyfunction!(bytes_pybytes_conversion)(py).unwrap(); py_assert!(py, f, "f(b'Hello World') == b'Hello World'"); } @@ -28,7 +28,7 @@ fn test_pybytes_vec_conversion() { let gil = Python::acquire_gil(); let py = gil.python(); - let f = wrap_pyfunction!(bytes_vec_conversion)(py); + let f = wrap_pyfunction!(bytes_vec_conversion)(py).unwrap(); py_assert!(py, f, "f(b'Hello World') == b'Hello World'"); } @@ -37,6 +37,6 @@ fn test_bytearray_vec_conversion() { let gil = Python::acquire_gil(); let py = gil.python(); - let f = wrap_pyfunction!(bytes_vec_conversion)(py); + let f = wrap_pyfunction!(bytes_vec_conversion)(py).unwrap(); py_assert!(py, f, "f(bytearray(b'Hello World')) == b'Hello World'"); } diff --git a/tests/test_exceptions.rs b/tests/test_exceptions.rs index 3726dfb7e39..d232f29d1c6 100644 --- a/tests/test_exceptions.rs +++ b/tests/test_exceptions.rs @@ -19,7 +19,7 @@ fn fail_to_open_file() -> PyResult<()> { fn test_filenotfounderror() { let gil = Python::acquire_gil(); let py = gil.python(); - let fail_to_open_file = wrap_pyfunction!(fail_to_open_file)(py); + let fail_to_open_file = wrap_pyfunction!(fail_to_open_file)(py).unwrap(); py_run!( py, @@ -64,7 +64,7 @@ fn call_fail_with_custom_error() -> PyResult<()> { fn test_custom_error() { let gil = Python::acquire_gil(); let py = gil.python(); - let call_fail_with_custom_error = wrap_pyfunction!(call_fail_with_custom_error)(py); + let call_fail_with_custom_error = wrap_pyfunction!(call_fail_with_custom_error)(py).unwrap(); py_run!( py, diff --git a/tests/test_module.rs b/tests/test_module.rs index f3c44667f06..da40809380f 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -61,7 +61,7 @@ fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> { m.add("foo", "bar").unwrap(); m.add_function(wrap_pyfunction!(double)).unwrap(); - m.add("also_double", wrap_pyfunction!(double)(m)).unwrap(); + m.add("also_double", wrap_pyfunction!(double)(m)?).unwrap(); Ok(()) } diff --git a/tests/test_pyfunction.rs b/tests/test_pyfunction.rs index e8e95bdf332..0d8500a36fe 100644 --- a/tests/test_pyfunction.rs +++ b/tests/test_pyfunction.rs @@ -14,7 +14,7 @@ fn test_optional_bool() { // Regression test for issue #932 let gil = Python::acquire_gil(); let py = gil.python(); - let f = wrap_pyfunction!(optional_bool)(py); + let f = wrap_pyfunction!(optional_bool)(py).unwrap(); py_assert!(py, f, "f() == 'Some(true)'"); py_assert!(py, f, "f(True) == 'Some(true)'"); @@ -36,7 +36,7 @@ fn buffer_inplace_add(py: Python, x: PyBuffer, y: PyBuffer) { fn test_buffer_add() { let gil = Python::acquire_gil(); let py = gil.python(); - let f = wrap_pyfunction!(buffer_inplace_add)(py); + let f = wrap_pyfunction!(buffer_inplace_add)(py).unwrap(); py_expect_exception!( py, diff --git a/tests/test_string.rs b/tests/test_string.rs index 6236484a578..38d375b5418 100644 --- a/tests/test_string.rs +++ b/tests/test_string.rs @@ -14,7 +14,7 @@ fn test_unicode_encode_error() { let gil = Python::acquire_gil(); let py = gil.python(); - let take_str = wrap_pyfunction!(take_str)(py); + let take_str = wrap_pyfunction!(take_str)(py).unwrap(); py_run!( py, take_str, diff --git a/tests/test_text_signature.rs b/tests/test_text_signature.rs index 85211a34f10..e81260811e5 100644 --- a/tests/test_text_signature.rs +++ b/tests/test_text_signature.rs @@ -104,7 +104,7 @@ fn test_function() { let gil = Python::acquire_gil(); let py = gil.python(); - let f = wrap_pyfunction!(my_function)(py); + let f = wrap_pyfunction!(my_function)(py).unwrap(); py_assert!(py, f, "f.__text_signature__ == '(a, b=None, *, c=42)'"); } diff --git a/tests/test_various.rs b/tests/test_various.rs index 87270c39186..b2de718b22e 100644 --- a/tests/test_various.rs +++ b/tests/test_various.rs @@ -59,7 +59,7 @@ fn return_custom_class() { assert_eq!(get_zero().unwrap().value, 0); // Using from python - let get_zero = wrap_pyfunction!(get_zero)(py); + let get_zero = wrap_pyfunction!(get_zero)(py).unwrap(); py_assert!(py, get_zero, "get_zero().value == 0"); } @@ -206,5 +206,5 @@ fn result_conversion_function() -> Result<(), MyError> { fn test_result_conversion() { let gil = Python::acquire_gil(); let py = gil.python(); - wrap_pyfunction!(result_conversion_function)(py); + wrap_pyfunction!(result_conversion_function)(py).unwrap(); } From 795c05451169d40c9e2515c26e287bf82ee437d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Thu, 3 Sep 2020 17:27:24 +0200 Subject: [PATCH 05/10] Possible to pass PyModule as first arg. This commit makes it possible to access the module of a function by passing the `need_module` argument to the pyfn and pyfunction macros. --- pyo3-derive-backend/src/module.rs | 59 ++++++++---- pyo3-derive-backend/src/pyfunction.rs | 6 +- tests/test_compile_error.rs | 1 + tests/test_module.rs | 89 ++++++++++++++++++- tests/ui/invalid_need_module_arg_position.rs | 12 +++ .../invalid_need_module_arg_position.stderr | 5 ++ 6 files changed, 154 insertions(+), 18 deletions(-) create mode 100644 tests/ui/invalid_need_module_arg_position.rs create mode 100644 tests/ui/invalid_need_module_arg_position.stderr diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index 0554e71c51a..2493caa66ae 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -2,7 +2,6 @@ //! Code generation for the function that initializes a python module and adds classes and function. use crate::method; -use crate::pyfunction; use crate::pyfunction::PyFunctionAttr; use crate::pymethod; use crate::pymethod::get_arg_names; @@ -78,11 +77,11 @@ fn wrap_fn_argument<'a>(cap: &'a syn::PatType) -> syn::Result> /// Extracts the data from the #[pyfn(...)] attribute of a function fn extract_pyfn_attrs( attrs: &mut Vec, -) -> syn::Result)>> { +) -> syn::Result> { let mut new_attrs = Vec::new(); let mut fnname = None; let mut modname = None; - let mut fn_attrs = Vec::new(); + let mut fn_attrs = PyFunctionAttr::default(); for attr in attrs.iter() { match attr.parse_meta() { @@ -115,9 +114,7 @@ fn extract_pyfn_attrs( } // Read additional arguments if list.nested.len() >= 3 { - fn_attrs = PyFunctionAttr::from_meta(&meta[2..meta.len()]) - .unwrap() - .arguments; + fn_attrs = PyFunctionAttr::from_meta(&meta[2..meta.len()])?; } } else { return Err(syn::Error::new_spanned( @@ -148,11 +145,11 @@ fn function_wrapper_ident(name: &Ident) -> Ident { pub fn add_fn_to_module( func: &mut syn::ItemFn, python_name: Ident, - pyfn_attrs: Vec, + pyfn_attrs: PyFunctionAttr, ) -> syn::Result { let mut arguments = Vec::new(); - for input in func.sig.inputs.iter() { + for (i, input) in func.sig.inputs.iter().enumerate() { match input { syn::FnArg::Receiver(_) => { return Err(syn::Error::new_spanned( @@ -161,7 +158,27 @@ pub fn add_fn_to_module( )) } syn::FnArg::Typed(ref cap) => { - arguments.push(wrap_fn_argument(cap)?); + if pyfn_attrs.need_module && i == 0 { + if let syn::Type::Reference(tyref) = cap.ty.as_ref() { + if let syn::Type::Path(typath) = tyref.elem.as_ref() { + if typath + .path + .segments + .last() + .map(|seg| seg.ident == "PyModule") + .unwrap_or(false) + { + continue; + } + } + } + return Err(syn::Error::new_spanned( + cap, + "Expected &PyModule as first argument with `need_module`.", + )); + } else { + arguments.push(wrap_fn_argument(cap)?); + } } } } @@ -177,7 +194,7 @@ pub fn add_fn_to_module( tp: method::FnType::FnStatic, name: &function_wrapper_ident, python_name, - attrs: pyfn_attrs, + attrs: pyfn_attrs.arguments, args: arguments, output: ty, doc, @@ -187,7 +204,7 @@ pub fn add_fn_to_module( let python_name = &spec.python_name; - let wrapper = function_c_wrapper(&func.sig.ident, &spec); + let wrapper = function_c_wrapper(&func.sig.ident, &spec, pyfn_attrs.need_module); Ok(quote! { fn #function_wrapper_ident<'a>( @@ -230,12 +247,23 @@ pub fn add_fn_to_module( } /// Generate static function wrapper (PyCFunction, PyCFunctionWithKeywords) -fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>) -> TokenStream { +fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>, need_module: bool) -> TokenStream { let names: Vec = get_arg_names(&spec); - let cb = quote! { - #name(#(#names),*) + let cb; + let slf_module; + if need_module { + cb = quote! { + #name(_slf, #(#names),*) + }; + slf_module = Some(quote! { + let _slf = _py.from_borrowed_ptr::(_slf); + }); + } else { + cb = quote! { + #name(#(#names),*) + }; + slf_module = None; }; - let body = pymethod::impl_arg_params(spec, None, cb); quote! { @@ -246,6 +274,7 @@ fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>) -> TokenStream { { const _LOCATION: &'static str = concat!(stringify!(#name), "()"); pyo3::callback_body!(_py, { + #slf_module let _args = _py.from_borrowed_ptr::(_args); let _kwargs: Option<&pyo3::types::PyDict> = _py.from_borrowed_ptr_or_opt(_kwargs); diff --git a/pyo3-derive-backend/src/pyfunction.rs b/pyo3-derive-backend/src/pyfunction.rs index 96ef584a9ca..cde619b79ae 100644 --- a/pyo3-derive-backend/src/pyfunction.rs +++ b/pyo3-derive-backend/src/pyfunction.rs @@ -24,6 +24,7 @@ pub struct PyFunctionAttr { has_kw: bool, has_varargs: bool, has_kwargs: bool, + pub need_module: bool, } impl syn::parse::Parse for PyFunctionAttr { @@ -45,6 +46,9 @@ impl PyFunctionAttr { pub fn add_item(&mut self, item: &NestedMeta) -> syn::Result<()> { match item { + NestedMeta::Meta(syn::Meta::Path(ref ident)) if ident.is_ident("need_module") => { + self.need_module = true; + } NestedMeta::Meta(syn::Meta::Path(ref ident)) => self.add_work(item, ident)?, NestedMeta::Meta(syn::Meta::NameValue(ref nv)) => { self.add_name_value(item, nv)?; @@ -204,7 +208,7 @@ pub fn parse_name_attribute(attrs: &mut Vec) -> syn::Result syn::Result { let python_name = parse_name_attribute(&mut ast.attrs)?.unwrap_or_else(|| ast.sig.ident.unraw()); - add_fn_to_module(ast, python_name, args.arguments) + add_fn_to_module(ast, python_name, args) } #[cfg(test)] diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index 653a80d56b5..5d02b8e8991 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -4,6 +4,7 @@ fn test_compile_errors() { let t = trybuild::TestCases::new(); t.compile_fail("tests/ui/invalid_frompy_derive.rs"); t.compile_fail("tests/ui/invalid_macro_args.rs"); + t.compile_fail("tests/ui/invalid_need_module_arg_position.rs"); t.compile_fail("tests/ui/invalid_property_args.rs"); t.compile_fail("tests/ui/invalid_pyclass_args.rs"); t.compile_fail("tests/ui/invalid_pymethod_names.rs"); diff --git a/tests/test_module.rs b/tests/test_module.rs index da40809380f..0b071af9558 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -1,6 +1,6 @@ use pyo3::prelude::*; -use pyo3::types::{IntoPyDict, PyTuple}; +use pyo3::types::{IntoPyDict, PyDict, PyTuple}; mod common; @@ -49,6 +49,11 @@ fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> { Ok(42) } + #[pyfn(m, "with_module", need_module)] + fn with_module(module: &PyModule) -> PyResult<&str> { + module.name() + } + #[pyfn(m, "double_value")] fn double_value(v: &ValueClass) -> usize { v.value * 2 @@ -97,6 +102,7 @@ fn test_module_with_functions() { run("assert module_with_functions.also_double(3) == 6"); run("assert module_with_functions.also_double.__doc__ == 'Doubles the given value'"); run("assert module_with_functions.double_value(module_with_functions.ValueClass(1)) == 2"); + run("assert module_with_functions.with_module() == 'module_with_functions'"); } #[pymodule(other_name)] @@ -230,7 +236,7 @@ fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> { use pyo3::{wrap_pyfunction, wrap_pymodule}; module.add_function(wrap_pyfunction!(superfunction))?; - module.add_module(wrap_pymodule!(submodule))?; + module.add_submodule(wrap_pymodule!(submodule))?; Ok(()) } @@ -305,3 +311,82 @@ fn test_module_with_constant() { py_assert!(py, m, "isinstance(m.ANON, m.AnonClass)"); }); } + +#[pyfunction(need_module)] +fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> { + module.name() +} + +#[pyfunction(need_module)] +fn pyfunction_with_module_and_py<'a>( + module: &'a PyModule, + _python: Python<'a>, +) -> PyResult<&'a str> { + module.name() +} + +#[pyfunction(need_module)] +fn pyfunction_with_module_and_arg(module: &PyModule, string: String) -> PyResult<(&str, String)> { + module.name().map(|s| (s, string)) +} + +#[pyfunction(need_module, string = "\"foo\"")] +fn pyfunction_with_module_and_default_arg<'a>( + module: &'a PyModule, + string: &str, +) -> PyResult<(&'a str, String)> { + module.name().map(|s| (s, string.into())) +} + +#[pyfunction(need_module, args = "*", kwargs = "**")] +fn pyfunction_with_module_and_args_kwargs<'a>( + module: &'a PyModule, + args: &PyTuple, + kwargs: Option<&PyDict>, +) -> PyResult<(&'a str, usize, Option)> { + module + .name() + .map(|s| (s, args.len(), kwargs.map(|d| d.len()))) +} + +#[pymodule] +fn module_with_functions_with_module(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module))?; + m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module_and_py))?; + m.add_function(pyo3::wrap_pyfunction!(pyfunction_with_module_and_arg))?; + m.add_function(pyo3::wrap_pyfunction!( + pyfunction_with_module_and_default_arg + ))?; + m.add_function(pyo3::wrap_pyfunction!( + pyfunction_with_module_and_args_kwargs + )) +} + +#[test] +fn test_module_functions_with_module() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let m = pyo3::wrap_pymodule!(module_with_functions_with_module)(py); + py_assert!( + py, + m, + "m.pyfunction_with_module() == 'module_with_functions_with_module'" + ); + py_assert!( + py, + m, + "m.pyfunction_with_module_and_py() == 'module_with_functions_with_module'" + ); + py_assert!( + py, + m, + "m.pyfunction_with_module_and_default_arg() \ + == ('module_with_functions_with_module', 'foo')" + ); + py_assert!( + py, + m, + "m.pyfunction_with_module_and_args_kwargs(1, x=1, y=2) \ + == ('module_with_functions_with_module', 1, 2)" + ); +} diff --git a/tests/ui/invalid_need_module_arg_position.rs b/tests/ui/invalid_need_module_arg_position.rs new file mode 100644 index 00000000000..8f69716110c --- /dev/null +++ b/tests/ui/invalid_need_module_arg_position.rs @@ -0,0 +1,12 @@ +use pyo3::prelude::*; + +#[pymodule] +fn module(_py: Python, m: &PyModule) -> PyResult<()> { + #[pyfn(m, "with_module", need_module)] + fn fail(string: &str, module: &PyModule) -> PyResult<&str> { + module.name() + } + Ok(()) +} + +fn main(){} \ No newline at end of file diff --git a/tests/ui/invalid_need_module_arg_position.stderr b/tests/ui/invalid_need_module_arg_position.stderr new file mode 100644 index 00000000000..ae0f261f774 --- /dev/null +++ b/tests/ui/invalid_need_module_arg_position.stderr @@ -0,0 +1,5 @@ +error: Expected &PyModule as first argument with `need_module`. + --> $DIR/invalid_need_module_arg_position.rs:6:13 + | +6 | fn fail(string: &str, module: &PyModule) -> PyResult<&str> { + | ^^^^^^^^^^^^ From 4aae523e54ddd4060050191dcf1477f9842962c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Fri, 4 Sep 2020 09:02:49 +0200 Subject: [PATCH 06/10] Rename add_module to add_submodule, documentation fixes. --- guide/src/module.md | 2 +- src/types/module.rs | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/guide/src/module.md b/guide/src/module.md index 042b11f0178..a458eadd7f3 100644 --- a/guide/src/module.md +++ b/guide/src/module.md @@ -73,7 +73,7 @@ fn submodule(_py: Python, module: &PyModule) -> PyResult<()> { #[pymodule] fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> { - module.add_module(wrap_pymodule!(submodule))?; + module.add_submodule(wrap_pymodule!(submodule))?; Ok(()) } diff --git a/src/types/module.rs b/src/types/module.rs index 0358257aa3f..8a0dd078bd8 100644 --- a/src/types/module.rs +++ b/src/types/module.rs @@ -197,7 +197,7 @@ impl PyModule { /// ``` /// /// **This function will be deprecated in the next release. Please use the specific - /// [add_function] and [add_module] functions instead.** + /// [add_function] and [add_submodule] functions instead.** pub fn add_wrapped<'a, T>(&'a self, wrapper: &impl Fn(Python<'a>) -> T) -> PyResult<()> where T: IntoPyCallbackOutput, @@ -207,20 +207,20 @@ impl PyModule { self.add(name.extract(self.py())?, function) } - /// Adds a (sub)module to a module. + /// Add a submodule to a module. /// /// Use this together with `#[pymodule]` and [wrap_pymodule!]. /// /// ```rust,ignore - /// m.add_module(wrap_pymodule!(utils)); + /// m.add_submodule(wrap_pymodule!(utils)); /// ``` - pub fn add_module<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> { + pub fn add_submodule<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> { let module = wrapper(self.py()); let name = module.getattr(self.py(), "__name__")?; self.add(name.extract(self.py())?, module) } - /// Adds a function to a module, using the functions __name__ as name. + /// Add a function to a module. /// /// Use this together with the`#[pyfunction]` and [wrap_pyfunction!]. /// From 9137855e810b186e58ad83d7f7b22ffe1fa9402b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Fri, 4 Sep 2020 10:02:40 +0200 Subject: [PATCH 07/10] Add documentation for accessing PyModule in #[pyfunction]s. --- guide/src/function.md | 50 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/guide/src/function.md b/guide/src/function.md index b8167c0b7cb..99487a1b44d 100644 --- a/guide/src/function.md +++ b/guide/src/function.md @@ -189,3 +189,53 @@ If you have a static function, you can expose it with `#[pyfunction]` and use [` [`PyAny::call1`]: https://docs.rs/pyo3/latest/pyo3/struct.PyAny.html#tymethod.call1 [`PyObject`]: https://docs.rs/pyo3/latest/pyo3/type.PyObject.html [`wrap_pyfunction!`]: https://docs.rs/pyo3/latest/pyo3/macro.wrap_pyfunction.html + +### Accessing the module of a function + +Functions are usually associated with modules, in the C-API, the self parameter in a function call corresponds +to the module of the function. It is possible to access the module of a `#[pyfunction]` and `#[pyfn]` in the +function body by passing the `need_module` argument to the attribute: + +```rust +use pyo3::wrap_pyfunction; +use pyo3::prelude::*; + +#[pyfunction(need_module)] +fn pyfunction_with_module( + module: &PyModule +) -> PyResult<&str> { + module.name() +} + +#[pymodule] +fn module_with_fn(py: Python, m: &PyModule) -> PyResult<()> { + m.add_function(wrap_pyfunction!(pyfunction_with_module)) +} + +# fn main() {} +``` + +If `need_module` is set, the first argument **must** be the `&PyModule`. It is then possible to interact with +the module. + +The same works for `#[pyfn]`: + +```rust +use pyo3::wrap_pyfunction; +use pyo3::prelude::*; + +#[pymodule] +fn module_with_fn(py: Python, m: &PyModule) -> PyResult<()> { + + #[pyfn(m, "module_name", need_module)] + fn module_name(module: &PyModule) -> PyResult<&str> { + module.name() + } + Ok(()) +} + +# fn main() {} +``` + +Within Python, the name of the module that a function belongs to can be accessed through the `__module__` +attribute. \ No newline at end of file From e65b849ab66b5e7f651536c9e74463e5df51b005 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Sat, 5 Sep 2020 10:06:24 +0200 Subject: [PATCH 08/10] Doc fixes, changelog and rename. --- CHANGELOG.md | 4 ++++ guide/src/function.md | 16 ++++++---------- pyo3-derive-backend/src/module.rs | 10 +++++----- pyo3-derive-backend/src/pyfunction.rs | 6 +++--- tests/test_module.rs | 12 ++++++------ tests/ui/invalid_need_module_arg_position.rs | 2 +- tests/ui/invalid_need_module_arg_position.stderr | 2 +- 7 files changed, 26 insertions(+), 26 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e3886d0e1da..bd649b91d13 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Add optional implementations of `ToPyObject`, `IntoPy`, and `FromPyObject` for [hashbrown](https://crates.io/crates/hashbrown)'s `HashMap` and `HashSet` types. The `hashbrown` feature must be enabled for these implementations to be built. [#1114](https://github.com/PyO3/pyo3/pull/1114/) - Allow other `Result` types when using `#[pyfunction]`. [#1106](https://github.com/PyO3/pyo3/issues/1106). - Add `#[derive(FromPyObject)]` macro for enums and structs. [#1065](https://github.com/PyO3/pyo3/pull/1065) +- Add macro attribute to `#[pyfn]` and `#[pyfunction]` to pass the module of a Python function to the function + body. [#1143](https://github.com/PyO3/pyo3/pull/1143) +- Add `add_function()` and `add_submodule()` functions to `PyModule` [#1143](https://github.com/PyO3/pyo3/pull/1143) ### Changed - Exception types have been renamed from e.g. `RuntimeError` to `PyRuntimeError`, and are now only accessible by `&T` or `Py` similar to other Python-native types. The old names continue to exist but are deprecated. [#1024](https://github.com/PyO3/pyo3/pull/1024) @@ -50,6 +53,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Link against libpython on android with `extension-module` set. [#1095](https://github.com/PyO3/pyo3/pull/1095) - Fix support for both `__add__` and `__radd__` in the `+` operator when both are defined in `PyNumberProtocol` (and similar for all other reversible operators). [#1107](https://github.com/PyO3/pyo3/pull/1107) +- Associate Python functions with their module by passing the Module and Module name [#1143](https://github.com/PyO3/pyo3/pull/1143) ## [0.11.1] - 2020-06-30 ### Added diff --git a/guide/src/function.md b/guide/src/function.md index 99487a1b44d..82d3669a1b1 100644 --- a/guide/src/function.md +++ b/guide/src/function.md @@ -192,15 +192,14 @@ If you have a static function, you can expose it with `#[pyfunction]` and use [` ### Accessing the module of a function -Functions are usually associated with modules, in the C-API, the self parameter in a function call corresponds -to the module of the function. It is possible to access the module of a `#[pyfunction]` and `#[pyfn]` in the -function body by passing the `need_module` argument to the attribute: +It is possible to access the module of a `#[pyfunction]` and `#[pyfn]` in the +function body by passing the `pass_module` argument to the attribute: ```rust use pyo3::wrap_pyfunction; use pyo3::prelude::*; -#[pyfunction(need_module)] +#[pyfunction(pass_module)] fn pyfunction_with_module( module: &PyModule ) -> PyResult<&str> { @@ -215,8 +214,8 @@ fn module_with_fn(py: Python, m: &PyModule) -> PyResult<()> { # fn main() {} ``` -If `need_module` is set, the first argument **must** be the `&PyModule`. It is then possible to interact with -the module. +If `pass_module` is set, the first argument **must** be the `&PyModule`. It is then possible to use the module +in the function body. The same works for `#[pyfn]`: @@ -227,7 +226,7 @@ use pyo3::prelude::*; #[pymodule] fn module_with_fn(py: Python, m: &PyModule) -> PyResult<()> { - #[pyfn(m, "module_name", need_module)] + #[pyfn(m, "module_name", pass_module)] fn module_name(module: &PyModule) -> PyResult<&str> { module.name() } @@ -236,6 +235,3 @@ fn module_with_fn(py: Python, m: &PyModule) -> PyResult<()> { # fn main() {} ``` - -Within Python, the name of the module that a function belongs to can be accessed through the `__module__` -attribute. \ No newline at end of file diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index 2493caa66ae..a706100ecd2 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -158,7 +158,7 @@ pub fn add_fn_to_module( )) } syn::FnArg::Typed(ref cap) => { - if pyfn_attrs.need_module && i == 0 { + if pyfn_attrs.pass_module && i == 0 { if let syn::Type::Reference(tyref) = cap.ty.as_ref() { if let syn::Type::Path(typath) = tyref.elem.as_ref() { if typath @@ -174,7 +174,7 @@ pub fn add_fn_to_module( } return Err(syn::Error::new_spanned( cap, - "Expected &PyModule as first argument with `need_module`.", + "Expected &PyModule as first argument with `pass_module`.", )); } else { arguments.push(wrap_fn_argument(cap)?); @@ -204,7 +204,7 @@ pub fn add_fn_to_module( let python_name = &spec.python_name; - let wrapper = function_c_wrapper(&func.sig.ident, &spec, pyfn_attrs.need_module); + let wrapper = function_c_wrapper(&func.sig.ident, &spec, pyfn_attrs.pass_module); Ok(quote! { fn #function_wrapper_ident<'a>( @@ -247,11 +247,11 @@ pub fn add_fn_to_module( } /// Generate static function wrapper (PyCFunction, PyCFunctionWithKeywords) -fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>, need_module: bool) -> TokenStream { +fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>, pass_module: bool) -> TokenStream { let names: Vec = get_arg_names(&spec); let cb; let slf_module; - if need_module { + if pass_module { cb = quote! { #name(_slf, #(#names),*) }; diff --git a/pyo3-derive-backend/src/pyfunction.rs b/pyo3-derive-backend/src/pyfunction.rs index cde619b79ae..80ac1cf35f0 100644 --- a/pyo3-derive-backend/src/pyfunction.rs +++ b/pyo3-derive-backend/src/pyfunction.rs @@ -24,7 +24,7 @@ pub struct PyFunctionAttr { has_kw: bool, has_varargs: bool, has_kwargs: bool, - pub need_module: bool, + pub pass_module: bool, } impl syn::parse::Parse for PyFunctionAttr { @@ -46,8 +46,8 @@ impl PyFunctionAttr { pub fn add_item(&mut self, item: &NestedMeta) -> syn::Result<()> { match item { - NestedMeta::Meta(syn::Meta::Path(ref ident)) if ident.is_ident("need_module") => { - self.need_module = true; + NestedMeta::Meta(syn::Meta::Path(ref ident)) if ident.is_ident("pass_module") => { + self.pass_module = true; } NestedMeta::Meta(syn::Meta::Path(ref ident)) => self.add_work(item, ident)?, NestedMeta::Meta(syn::Meta::NameValue(ref nv)) => { diff --git a/tests/test_module.rs b/tests/test_module.rs index 0b071af9558..037c0217524 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -49,7 +49,7 @@ fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> { Ok(42) } - #[pyfn(m, "with_module", need_module)] + #[pyfn(m, "with_module", pass_module)] fn with_module(module: &PyModule) -> PyResult<&str> { module.name() } @@ -312,12 +312,12 @@ fn test_module_with_constant() { }); } -#[pyfunction(need_module)] +#[pyfunction(pass_module)] fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> { module.name() } -#[pyfunction(need_module)] +#[pyfunction(pass_module)] fn pyfunction_with_module_and_py<'a>( module: &'a PyModule, _python: Python<'a>, @@ -325,12 +325,12 @@ fn pyfunction_with_module_and_py<'a>( module.name() } -#[pyfunction(need_module)] +#[pyfunction(pass_module)] fn pyfunction_with_module_and_arg(module: &PyModule, string: String) -> PyResult<(&str, String)> { module.name().map(|s| (s, string)) } -#[pyfunction(need_module, string = "\"foo\"")] +#[pyfunction(pass_module, string = "\"foo\"")] fn pyfunction_with_module_and_default_arg<'a>( module: &'a PyModule, string: &str, @@ -338,7 +338,7 @@ fn pyfunction_with_module_and_default_arg<'a>( module.name().map(|s| (s, string.into())) } -#[pyfunction(need_module, args = "*", kwargs = "**")] +#[pyfunction(pass_module, args = "*", kwargs = "**")] fn pyfunction_with_module_and_args_kwargs<'a>( module: &'a PyModule, args: &PyTuple, diff --git a/tests/ui/invalid_need_module_arg_position.rs b/tests/ui/invalid_need_module_arg_position.rs index 8f69716110c..607b21273f6 100644 --- a/tests/ui/invalid_need_module_arg_position.rs +++ b/tests/ui/invalid_need_module_arg_position.rs @@ -2,7 +2,7 @@ use pyo3::prelude::*; #[pymodule] fn module(_py: Python, m: &PyModule) -> PyResult<()> { - #[pyfn(m, "with_module", need_module)] + #[pyfn(m, "with_module", pass_module)] fn fail(string: &str, module: &PyModule) -> PyResult<&str> { module.name() } diff --git a/tests/ui/invalid_need_module_arg_position.stderr b/tests/ui/invalid_need_module_arg_position.stderr index ae0f261f774..0fd00964c53 100644 --- a/tests/ui/invalid_need_module_arg_position.stderr +++ b/tests/ui/invalid_need_module_arg_position.stderr @@ -1,4 +1,4 @@ -error: Expected &PyModule as first argument with `need_module`. +error: Expected &PyModule as first argument with `pass_module`. --> $DIR/invalid_need_module_arg_position.rs:6:13 | 6 | fn fail(string: &str, module: &PyModule) -> PyResult<&str> { From 06cd7c7d5aa8a4db74165085c03dd2a72a96d54a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Sat, 5 Sep 2020 11:22:13 +0200 Subject: [PATCH 09/10] Fix some more docs. --- guide/src/function.md | 4 +-- src/types/module.rs | 79 +++++++++++++++++++++++++++++++++---------- 2 files changed, 62 insertions(+), 21 deletions(-) diff --git a/guide/src/function.md b/guide/src/function.md index 82d3669a1b1..b33221c9d41 100644 --- a/guide/src/function.md +++ b/guide/src/function.md @@ -200,9 +200,7 @@ use pyo3::wrap_pyfunction; use pyo3::prelude::*; #[pyfunction(pass_module)] -fn pyfunction_with_module( - module: &PyModule -) -> PyResult<&str> { +fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> { module.name() } diff --git a/src/types/module.rs b/src/types/module.rs index 8a0dd078bd8..437085f5143 100644 --- a/src/types/module.rs +++ b/src/types/module.rs @@ -185,15 +185,28 @@ impl PyModule { /// Use this together with the`#[pyfunction]` and [wrap_pyfunction!] or `#[pymodule]` and /// [wrap_pymodule!]. /// - /// ```rust,ignore - /// m.add_wrapped(wrap_pyfunction!(double)); - /// m.add_wrapped(wrap_pymodule!(utils)); + /// ```rust + /// use pyo3::prelude::*; + /// #[pymodule] + /// fn utils(_py: Python, _module: &PyModule) -> PyResult<()> { + /// Ok(()) + /// } + /// + /// #[pyfunction] + /// fn double(x: usize) -> usize { + /// x * 2 + /// } + /// #[pymodule] + /// fn top_level(_py: Python, module: &PyModule) -> PyResult<()> { + /// module.add_wrapped(pyo3::wrap_pymodule!(utils))?; + /// module.add_wrapped(pyo3::wrap_pyfunction!(double)) + /// } /// ``` /// /// You can also add a function with a custom name using [add](PyModule::add): /// /// ```rust,ignore - /// m.add("also_double", wrap_pyfunction!(double)(py)); + /// m.add("also_double", wrap_pyfunction!(double)(m)?)?; /// ``` /// /// **This function will be deprecated in the next release. Please use the specific @@ -202,43 +215,73 @@ impl PyModule { where T: IntoPyCallbackOutput, { - let function = wrapper(self.py()).convert(self.py())?; - let name = function.getattr(self.py(), "__name__")?; - self.add(name.extract(self.py())?, function) + let py = self.py(); + let function = wrapper(py).convert(py)?; + let name = function.getattr(py, "__name__")?; + let name = name.extract(py)?; + self.add(name, function) } /// Add a submodule to a module. /// /// Use this together with `#[pymodule]` and [wrap_pymodule!]. /// - /// ```rust,ignore - /// m.add_submodule(wrap_pymodule!(utils)); + /// ```rust + /// use pyo3::prelude::*; + /// #[pymodule] + /// fn utils(_py: Python, _module: &PyModule) -> PyResult<()> { + /// Ok(()) + /// } + /// #[pymodule] + /// fn top_level(_py: Python, module: &PyModule) -> PyResult<()> { + /// module.add_submodule(pyo3::wrap_pymodule!(utils)) + /// } /// ``` pub fn add_submodule<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> { - let module = wrapper(self.py()); - let name = module.getattr(self.py(), "__name__")?; - self.add(name.extract(self.py())?, module) + let py = self.py(); + let module = wrapper(py); + let name = module.getattr(py, "__name__")?; + let name = name.extract(py)?; + self.add(name, module) } /// Add a function to a module. /// /// Use this together with the`#[pyfunction]` and [wrap_pyfunction!]. /// - /// ```rust,ignore - /// m.add_function(wrap_pyfunction!(double)); + /// ```rust + /// use pyo3::prelude::*; + /// #[pyfunction] + /// fn double(x: usize) -> usize { + /// x * 2 + /// } + /// #[pymodule] + /// fn double_mod(_py: Python, module: &PyModule) -> PyResult<()> { + /// module.add_function(pyo3::wrap_pyfunction!(double)) + /// } /// ``` /// /// You can also add a function with a custom name using [add](PyModule::add): /// - /// ```rust,ignore - /// m.add("also_double", wrap_pyfunction!(double)(py, m)); + /// ```rust + /// use pyo3::prelude::*; + /// #[pyfunction] + /// fn double(x: usize) -> usize { + /// x * 2 + /// } + /// #[pymodule] + /// fn double_mod(_py: Python, module: &PyModule) -> PyResult<()> { + /// module.add("also_double", pyo3::wrap_pyfunction!(double)(module)?) + /// } /// ``` pub fn add_function<'a>( &'a self, wrapper: &impl Fn(&'a Self) -> PyResult, ) -> PyResult<()> { + let py = self.py(); let function = wrapper(self)?; - let name = function.getattr(self.py(), "__name__")?; - self.add(name.extract(self.py())?, function) + let name = function.getattr(py, "__name__")?; + let name = name.extract(py)?; + self.add(name, function) } } From 64b06ea9ec1b0f1c1c58f2cd59d923cb9de4a4f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Sat, 5 Sep 2020 15:54:03 +0200 Subject: [PATCH 10/10] Change `add_submodule()` to take `&PyModule`. The C-exported wrapper generated through `#[pymodule]` is only required for the top-level module. --- CHANGELOG.md | 1 + guide/src/module.md | 28 +++++++++++++++++++--------- src/types/module.rs | 19 +++++++++---------- tests/test_module.rs | 25 +++++++++++++++++++++---- 4 files changed, 50 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bd649b91d13..962ba01e0de 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Implement `Send + Sync` for `PyErr`. `PyErr::new`, `PyErr::from_type`, `PyException::py_err` and `PyException::into` have had these bounds added to their arguments. [#1067](https://github.com/PyO3/pyo3/pull/1067) - Change `#[pyproto]` to return NotImplemented for operators for which Python can try a reversed operation. #[1072](https://github.com/PyO3/pyo3/pull/1072) - `PyModule::add` now uses `IntoPy` instead of `ToPyObject`. #[1124](https://github.com/PyO3/pyo3/pull/1124) +- Add nested modules as `&PyModule` instead of using the wrapper generated by `#[pymodule]`. [#1143](https://github.com/PyO3/pyo3/pull/1143) ### Removed - Remove `PyString::as_bytes`. [#1023](https://github.com/PyO3/pyo3/pull/1023) diff --git a/guide/src/module.md b/guide/src/module.md index a458eadd7f3..6b1d4581eec 100644 --- a/guide/src/module.md +++ b/guide/src/module.md @@ -32,16 +32,22 @@ fn sum_as_string(a: i64, b: i64) -> String { # fn main() {} ``` -The `#[pymodule]` procedural macro attribute takes care of exporting the initialization function of your module to Python. It can take as an argument the name of your module, which must be the name of the `.so` or `.pyd` file; the default is the Rust function's name. +The `#[pymodule]` procedural macro attribute takes care of exporting the initialization function of your +module to Python. It can take as an argument the name of your module, which must be the name of the `.so` +or `.pyd` file; the default is the Rust function's name. -If the name of the module (the default being the function name) does not match the name of the `.so` or `.pyd` file, you will get an import error in Python with the following message: +If the name of the module (the default being the function name) does not match the name of the `.so` or +`.pyd` file, you will get an import error in Python with the following message: `ImportError: dynamic module does not define module export function (PyInit_name_of_your_module)` -To import the module, either copy the shared library as described in [the README](https://github.com/PyO3/pyo3) or use a tool, e.g. `maturin develop` with [maturin](https://github.com/PyO3/maturin) or `python setup.py develop` with [setuptools-rust](https://github.com/PyO3/setuptools-rust). +To import the module, either copy the shared library as described in [the README](https://github.com/PyO3/pyo3) +or use a tool, e.g. `maturin develop` with [maturin](https://github.com/PyO3/maturin) or +`python setup.py develop` with [setuptools-rust](https://github.com/PyO3/setuptools-rust). ## Documentation -The [Rust doc comments](https://doc.rust-lang.org/stable/book/first-edition/comments.html) of the module initialization function will be applied automatically as the Python docstring of your module. +The [Rust doc comments](https://doc.rust-lang.org/stable/book/first-edition/comments.html) of the module +initialization function will be applied automatically as the Python docstring of your module. ```python import rust2py @@ -53,7 +59,8 @@ Which means that the above Python code will print `This module is implemented in ## Modules as objects -In Python, modules are first class objects. This means that you can store them as values or add them to dicts or other modules: +In Python, modules are first class objects. This means that you can store them as values or add them to +dicts or other modules: ```rust use pyo3::prelude::*; @@ -65,15 +72,16 @@ fn subfunction() -> String { "Subfunction".to_string() } -#[pymodule] -fn submodule(_py: Python, module: &PyModule) -> PyResult<()> { +fn init_submodule(module: &PyModule) -> PyResult<()> { module.add_function(wrap_pyfunction!(subfunction))?; Ok(()) } #[pymodule] -fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> { - module.add_submodule(wrap_pymodule!(submodule))?; +fn supermodule(py: Python, module: &PyModule) -> PyResult<()> { + let submod = PyModule::new(py, "submodule")?; + init_submodule(submod)?; + module.add_submodule(submod)?; Ok(()) } @@ -86,3 +94,5 @@ fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> { ``` This way, you can create a module hierarchy within a single extension module. + +It is not necessary to add `#[pymodule]` on nested modules, this is only required on the top-level module. \ No newline at end of file diff --git a/src/types/module.rs b/src/types/module.rs index 437085f5143..0e480010822 100644 --- a/src/types/module.rs +++ b/src/types/module.rs @@ -228,20 +228,19 @@ impl PyModule { /// /// ```rust /// use pyo3::prelude::*; - /// #[pymodule] - /// fn utils(_py: Python, _module: &PyModule) -> PyResult<()> { - /// Ok(()) + /// + /// fn init_utils(module: &PyModule) -> PyResult<()> { + /// module.add("super_useful_constant", "important") /// } /// #[pymodule] - /// fn top_level(_py: Python, module: &PyModule) -> PyResult<()> { - /// module.add_submodule(pyo3::wrap_pymodule!(utils)) + /// fn top_level(py: Python, module: &PyModule) -> PyResult<()> { + /// let utils = PyModule::new(py, "utils")?; + /// init_utils(utils)?; + /// module.add_submodule(utils) /// } /// ``` - pub fn add_submodule<'a>(&'a self, wrapper: &impl Fn(Python<'a>) -> PyObject) -> PyResult<()> { - let py = self.py(); - let module = wrapper(py); - let name = module.getattr(py, "__name__")?; - let name = name.extract(py)?; + pub fn add_submodule(&self, module: &PyModule) -> PyResult<()> { + let name = module.name()?; self.add(name, module) } diff --git a/tests/test_module.rs b/tests/test_module.rs index 037c0217524..7c278bdcd5e 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -218,8 +218,15 @@ fn subfunction() -> String { "Subfunction".to_string() } +fn submodule(module: &PyModule) -> PyResult<()> { + use pyo3::wrap_pyfunction; + + module.add_function(wrap_pyfunction!(subfunction))?; + Ok(()) +} + #[pymodule] -fn submodule(_py: Python, module: &PyModule) -> PyResult<()> { +fn submodule_with_init_fn(_py: Python, module: &PyModule) -> PyResult<()> { use pyo3::wrap_pyfunction; module.add_function(wrap_pyfunction!(subfunction))?; @@ -232,11 +239,16 @@ fn superfunction() -> String { } #[pymodule] -fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> { - use pyo3::{wrap_pyfunction, wrap_pymodule}; +fn supermodule(py: Python, module: &PyModule) -> PyResult<()> { + use pyo3::wrap_pyfunction; module.add_function(wrap_pyfunction!(superfunction))?; - module.add_submodule(wrap_pymodule!(submodule))?; + let module_to_add = PyModule::new(py, "submodule")?; + submodule(module_to_add)?; + module.add_submodule(module_to_add)?; + let module_to_add = PyModule::new(py, "submodule_with_init_fn")?; + submodule_with_init_fn(py, module_to_add)?; + module.add_submodule(module_to_add)?; Ok(()) } @@ -258,6 +270,11 @@ fn test_module_nesting() { supermodule, "supermodule.submodule.subfunction() == 'Subfunction'" ); + py_assert!( + py, + supermodule, + "supermodule.submodule_with_init_fn.subfunction() == 'Subfunction'" + ); } // Test that argument parsing specification works for pyfunctions