Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PyModule in #[pyfunction] #1143

Merged
merged 10 commits into from
Sep 6, 2020
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ fn sum_as_string(a: usize, b: usize) -> PyResult<String> {
/// A Python module implemented in Rust.
#[pymodule]
fn string_sum(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(sum_as_string))?;
m.add_function(wrap_pyfunction!(sum_as_string))?;

Ok(())
}
Expand Down
28 changes: 14 additions & 14 deletions examples/rustapi_module/src/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,29 +215,29 @@ impl TzClass {

#[pymodule]
fn datetime(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(make_date))?;
m.add_wrapped(wrap_pyfunction!(get_date_tuple))?;
m.add_wrapped(wrap_pyfunction!(date_from_timestamp))?;
m.add_wrapped(wrap_pyfunction!(make_time))?;
m.add_wrapped(wrap_pyfunction!(get_time_tuple))?;
m.add_wrapped(wrap_pyfunction!(make_delta))?;
m.add_wrapped(wrap_pyfunction!(get_delta_tuple))?;
m.add_wrapped(wrap_pyfunction!(make_datetime))?;
m.add_wrapped(wrap_pyfunction!(get_datetime_tuple))?;
m.add_wrapped(wrap_pyfunction!(datetime_from_timestamp))?;
m.add_function(wrap_pyfunction!(make_date))?;
m.add_function(wrap_pyfunction!(get_date_tuple))?;
m.add_function(wrap_pyfunction!(date_from_timestamp))?;
m.add_function(wrap_pyfunction!(make_time))?;
m.add_function(wrap_pyfunction!(get_time_tuple))?;
m.add_function(wrap_pyfunction!(make_delta))?;
m.add_function(wrap_pyfunction!(get_delta_tuple))?;
m.add_function(wrap_pyfunction!(make_datetime))?;
m.add_function(wrap_pyfunction!(get_datetime_tuple))?;
m.add_function(wrap_pyfunction!(datetime_from_timestamp))?;

// Python 3.6+ functions
#[cfg(Py_3_6)]
{
m.add_wrapped(wrap_pyfunction!(time_with_fold))?;
m.add_function(wrap_pyfunction!(time_with_fold))?;
#[cfg(not(PyPy))]
{
m.add_wrapped(wrap_pyfunction!(get_time_tuple_fold))?;
m.add_wrapped(wrap_pyfunction!(get_datetime_tuple_fold))?;
m.add_function(wrap_pyfunction!(get_time_tuple_fold))?;
m.add_function(wrap_pyfunction!(get_datetime_tuple_fold))?;
}
}

m.add_wrapped(wrap_pyfunction!(issue_219))?;
m.add_function(wrap_pyfunction!(issue_219))?;
m.add_class::<TzClass>()?;

Ok(())
Expand Down
2 changes: 1 addition & 1 deletion examples/rustapi_module/src/othermod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn double(x: i32) -> i32 {

#[pymodule]
fn othermod(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(double))?;
m.add_function(wrap_pyfunction!(double))?;

m.add_class::<ModClass>()?;

Expand Down
4 changes: 2 additions & 2 deletions examples/word-count/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ fn count_line(line: &str, needle: &str) -> usize {
#[pymodule]
fn word_count(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(search))?;
m.add_wrapped(wrap_pyfunction!(search_sequential))?;
m.add_wrapped(wrap_pyfunction!(search_sequential_allow_threads))?;
m.add_function(wrap_pyfunction!(search_sequential))?;
m.add_function(wrap_pyfunction!(search_sequential_allow_threads))?;

Ok(())
}
54 changes: 52 additions & 2 deletions guide/src/function.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ fn double(x: usize) -> usize {

#[pymodule]
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(double)).unwrap();
m.add_function(wrap_pyfunction!(double)).unwrap();

Ok(())
}
Expand Down Expand Up @@ -65,7 +65,7 @@ fn num_kwds(kwds: Option<&PyDict>) -> usize {

#[pymodule]
fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(num_kwds)).unwrap();
m.add_function(wrap_pyfunction!(num_kwds)).unwrap();
Ok(())
}

Expand Down Expand Up @@ -189,3 +189,53 @@ If you have a static function, you can expose it with `#[pyfunction]` and use [`
[`PyAny::call1`]: https://docs.rs/pyo3/latest/pyo3/struct.PyAny.html#tymethod.call1
[`PyObject`]: https://docs.rs/pyo3/latest/pyo3/type.PyObject.html
[`wrap_pyfunction!`]: https://docs.rs/pyo3/latest/pyo3/macro.wrap_pyfunction.html

### Accessing the module of a function

Functions are usually associated with modules, in the C-API, the self parameter in a function call corresponds
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this first sentence is probably an implementation detail that the user doesn't need to know? Might want to take it out for simplicity.

to the module of the function. It is possible to access the module of a `#[pyfunction]` and `#[pyfn]` in the
function body by passing the `need_module` argument to the attribute:

```rust
use pyo3::wrap_pyfunction;
use pyo3::prelude::*;

#[pyfunction(need_module)]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bikeshedding: I'd like to propose to call the attribute pass_module.

Motivation is that this was already the verb used in #828

Also I've seen this phrasing before e.g. in Python's click library: https://click.palletsprojects.com/en/7.x/api/#click.pass_context

fn pyfunction_with_module(
sebpuetz marked this conversation as resolved.
Show resolved Hide resolved
module: &PyModule
) -> PyResult<&str> {
module.name()
}

#[pymodule]
fn module_with_fn(py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(pyfunction_with_module))
}

# fn main() {}
```

If `need_module` is set, the first argument **must** be the `&PyModule`. It is then possible to interact with
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe "it is then possible to use the module in the function body"?

the module.

The same works for `#[pyfn]`:

```rust
use pyo3::wrap_pyfunction;
use pyo3::prelude::*;

#[pymodule]
fn module_with_fn(py: Python, m: &PyModule) -> PyResult<()> {

#[pyfn(m, "module_name", need_module)]
fn module_name(module: &PyModule) -> PyResult<&str> {
module.name()
}
Ok(())
}

# fn main() {}
```

Within Python, the name of the module that a function belongs to can be accessed through the `__module__`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is probably true of all Python functions (although you've just fixed this for pyO3) so not sure this sentence needs to be in this section?

attribute.
2 changes: 1 addition & 1 deletion guide/src/logging.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ fn my_module(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
// A good place to install the Rust -> Python logger.
pyo3_log::init();

m.add_wrapped(wrap_pyfunction!(log_something))?;
m.add_function(wrap_pyfunction!(log_something))?;
Ok(())
}
```
Expand Down
4 changes: 2 additions & 2 deletions guide/src/module.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,13 @@ fn subfunction() -> String {

#[pymodule]
fn submodule(_py: Python, module: &PyModule) -> PyResult<()> {
module.add_wrapped(wrap_pyfunction!(subfunction))?;
module.add_function(wrap_pyfunction!(subfunction))?;
Ok(())
}

#[pymodule]
fn supermodule(_py: Python, module: &PyModule) -> PyResult<()> {
module.add_wrapped(wrap_pymodule!(submodule))?;
module.add_submodule(wrap_pymodule!(submodule))?;
Ok(())
}

Expand Down
2 changes: 1 addition & 1 deletion guide/src/trait_bounds.md
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@ pub struct UserModel {
#[pymodule]
fn trait_exposure(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<UserModel>()?;
m.add_wrapped(wrap_pyfunction!(solve_wrapper))?;
m.add_function(wrap_pyfunction!(solve_wrapper))?;
Ok(())
}

Expand Down
83 changes: 63 additions & 20 deletions pyo3-derive-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
//! Code generation for the function that initializes a python module and adds classes and function.

use crate::method;
use crate::pyfunction;
use crate::pyfunction::PyFunctionAttr;
use crate::pymethod;
use crate::pymethod::get_arg_names;
Expand Down Expand Up @@ -45,7 +44,7 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
let item: syn::ItemFn = syn::parse_quote! {
fn block_wrapper() {
#function_to_python
#module_name.add_wrapped(&#function_wrapper_ident)?;
#module_name.add_function(&#function_wrapper_ident)?;
}
};
stmts.extend(item.block.stmts.into_iter());
Expand Down Expand Up @@ -78,11 +77,11 @@ fn wrap_fn_argument<'a>(cap: &'a syn::PatType) -> syn::Result<method::FnArg<'a>>
/// Extracts the data from the #[pyfn(...)] attribute of a function
fn extract_pyfn_attrs(
attrs: &mut Vec<syn::Attribute>,
) -> syn::Result<Option<(syn::Path, Ident, Vec<pyfunction::Argument>)>> {
) -> syn::Result<Option<(syn::Path, Ident, PyFunctionAttr)>> {
let mut new_attrs = Vec::new();
let mut fnname = None;
let mut modname = None;
let mut fn_attrs = Vec::new();
let mut fn_attrs = PyFunctionAttr::default();

for attr in attrs.iter() {
match attr.parse_meta() {
Expand Down Expand Up @@ -115,9 +114,7 @@ fn extract_pyfn_attrs(
}
// Read additional arguments
if list.nested.len() >= 3 {
fn_attrs = PyFunctionAttr::from_meta(&meta[2..meta.len()])
.unwrap()
.arguments;
fn_attrs = PyFunctionAttr::from_meta(&meta[2..meta.len()])?;
}
} else {
return Err(syn::Error::new_spanned(
Expand Down Expand Up @@ -148,11 +145,11 @@ fn function_wrapper_ident(name: &Ident) -> Ident {
pub fn add_fn_to_module(
func: &mut syn::ItemFn,
python_name: Ident,
pyfn_attrs: Vec<pyfunction::Argument>,
pyfn_attrs: PyFunctionAttr,
) -> syn::Result<TokenStream> {
let mut arguments = Vec::new();

for input in func.sig.inputs.iter() {
for (i, input) in func.sig.inputs.iter().enumerate() {
match input {
syn::FnArg::Receiver(_) => {
return Err(syn::Error::new_spanned(
Expand All @@ -161,7 +158,27 @@ pub fn add_fn_to_module(
))
}
syn::FnArg::Typed(ref cap) => {
arguments.push(wrap_fn_argument(cap)?);
if pyfn_attrs.need_module && i == 0 {
if let syn::Type::Reference(tyref) = cap.ty.as_ref() {
if let syn::Type::Path(typath) = tyref.elem.as_ref() {
if typath
.path
.segments
.last()
.map(|seg| seg.ident == "PyModule")
.unwrap_or(false)
{
continue;
}
}
}
return Err(syn::Error::new_spanned(
cap,
"Expected &PyModule as first argument with `need_module`.",
));
} else {
arguments.push(wrap_fn_argument(cap)?);
}
}
}
}
Expand All @@ -177,7 +194,7 @@ pub fn add_fn_to_module(
tp: method::FnType::FnStatic,
name: &function_wrapper_ident,
python_name,
attrs: pyfn_attrs,
attrs: pyfn_attrs.arguments,
args: arguments,
output: ty,
doc,
Expand All @@ -187,10 +204,14 @@ pub fn add_fn_to_module(

let python_name = &spec.python_name;

let wrapper = function_c_wrapper(&func.sig.ident, &spec);
let wrapper = function_c_wrapper(&func.sig.ident, &spec, pyfn_attrs.need_module);

Ok(quote! {
fn #function_wrapper_ident(py: pyo3::Python) -> pyo3::PyObject {
fn #function_wrapper_ident<'a>(
args: impl Into<pyo3::derive_utils::WrapPyFunctionArguments<'a>>
) -> pyo3::PyResult<pyo3::PyObject> {
let arg = args.into();
let (py, maybe_module) = arg.into_py_and_maybe_module();
#wrapper

let _def = pyo3::class::PyMethodDef {
Expand All @@ -200,28 +221,49 @@ pub fn add_fn_to_module(
ml_doc: #doc,
};

let (mod_ptr, name) = if let Some(m) = maybe_module {
let mod_ptr = <pyo3::types::PyModule as ::pyo3::conversion::AsPyPointer>::as_ptr(m);
let name = m.name()?;
let name = <&str as pyo3::conversion::IntoPy<PyObject>>::into_py(name, py);
(mod_ptr, <PyObject as pyo3::AsPyPointer>::as_ptr(&name))
} else {
(std::ptr::null_mut(), std::ptr::null_mut())
};

let function = unsafe {
pyo3::PyObject::from_owned_ptr(
py,
pyo3::ffi::PyCFunction_New(
pyo3::ffi::PyCFunction_NewEx(
Box::into_raw(Box::new(_def.as_method_def())),
::std::ptr::null_mut()
mod_ptr,
name
)
)
};

function
Ok(function)
}
})
}

/// Generate static function wrapper (PyCFunction, PyCFunctionWithKeywords)
fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>) -> TokenStream {
fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>, need_module: bool) -> TokenStream {
let names: Vec<Ident> = get_arg_names(&spec);
let cb = quote! {
#name(#(#names),*)
let cb;
let slf_module;
if need_module {
cb = quote! {
#name(_slf, #(#names),*)
};
slf_module = Some(quote! {
let _slf = _py.from_borrowed_ptr::<pyo3::types::PyModule>(_slf);
});
} else {
cb = quote! {
#name(#(#names),*)
};
slf_module = None;
};

let body = pymethod::impl_arg_params(spec, None, cb);

quote! {
Expand All @@ -232,6 +274,7 @@ fn function_c_wrapper(name: &Ident, spec: &method::FnSpec<'_>) -> TokenStream {
{
const _LOCATION: &'static str = concat!(stringify!(#name), "()");
pyo3::callback_body!(_py, {
#slf_module
let _args = _py.from_borrowed_ptr::<pyo3::types::PyTuple>(_args);
let _kwargs: Option<&pyo3::types::PyDict> = _py.from_borrowed_ptr_or_opt(_kwargs);

Expand Down
6 changes: 5 additions & 1 deletion pyo3-derive-backend/src/pyfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ pub struct PyFunctionAttr {
has_kw: bool,
has_varargs: bool,
has_kwargs: bool,
pub need_module: bool,
}

impl syn::parse::Parse for PyFunctionAttr {
Expand All @@ -45,6 +46,9 @@ impl PyFunctionAttr {

pub fn add_item(&mut self, item: &NestedMeta) -> syn::Result<()> {
match item {
NestedMeta::Meta(syn::Meta::Path(ref ident)) if ident.is_ident("need_module") => {
self.need_module = true;
}
NestedMeta::Meta(syn::Meta::Path(ref ident)) => self.add_work(item, ident)?,
NestedMeta::Meta(syn::Meta::NameValue(ref nv)) => {
self.add_name_value(item, nv)?;
Expand Down Expand Up @@ -204,7 +208,7 @@ pub fn parse_name_attribute(attrs: &mut Vec<syn::Attribute>) -> syn::Result<Opti
pub fn build_py_function(ast: &mut syn::ItemFn, args: PyFunctionAttr) -> syn::Result<TokenStream> {
let python_name =
parse_name_attribute(&mut ast.attrs)?.unwrap_or_else(|| ast.sig.ident.unraw());
add_fn_to_module(ast, python_name, args.arguments)
add_fn_to_module(ast, python_name, args)
}

#[cfg(test)]
Expand Down
Loading