Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow async methods to accept &self/&mut self #3609

Merged
merged 1 commit into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions guide/src/async-await.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ Resulting future of an `async fn` decorated by `#[pyfunction]` must be `Send + '

As a consequence, `async fn` parameters and return types must also be `Send + 'static`, so it is not possible to have a signature like `async fn does_not_compile(arg: &PyAny, py: Python<'_>) -> &PyAny`.

It also means that methods cannot use `&self`/`&mut self`, *but this restriction should be dropped in the future.*

However, there is an exception for method receiver, so async methods can accept `&self`/`&mut self`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth pointing out here that there's a good case for being extra careful with &mut self as a method receiver in a future? I think it's very likely users will accidentally hit borrow errors due to overlapping futures.

We could potentially restrict this to &self on frozen classes? That would help prevent users creating aliasing borrow errors from async code. What do you think about that restriction?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could potentially restrict this to &self on frozen classes?

Honestly, I prefer to just add a warning in the documentation. I don't want to prevent careful users to use a mpsc::Receiver in an efficient way.
Maybe you could ping another maintainer to discuss this decision?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok let's keep this. I think what I'd prefer to do is make frozen the default and then document what the risks of &mut are if users choose to enable this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But that is a separate PR.


## Implicit GIL holding

Expand Down
1 change: 1 addition & 0 deletions newsfragments/3609.changed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow async methods to accept `&self`/`&mut self`
48 changes: 31 additions & 17 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
use std::fmt::Display;

use crate::attributes::{TextSignatureAttribute, TextSignatureAttributeValue};
use crate::deprecations::{Deprecation, Deprecations};
use crate::params::impl_arg_params;
use crate::pyfunction::{FunctionSignature, PyFunctionArgPyO3Attributes};
use crate::pyfunction::{PyFunctionOptions, SignatureAttribute};
use crate::quotes;
use crate::utils::{self, PythonDoc};
use proc_macro2::{Span, TokenStream};
use quote::ToTokens;
use quote::{quote, quote_spanned};
use syn::ext::IdentExt;
use syn::spanned::Spanned;
use syn::{Ident, Result};
use quote::{quote, quote_spanned, ToTokens};
use syn::{ext::IdentExt, spanned::Spanned, Ident, Result};

use crate::{
attributes::{TextSignatureAttribute, TextSignatureAttributeValue},
deprecations::{Deprecation, Deprecations},
params::impl_arg_params,
pyfunction::{
FunctionSignature, PyFunctionArgPyO3Attributes, PyFunctionOptions, SignatureAttribute,
},
quotes,
utils::{self, PythonDoc},
};

#[derive(Clone, Debug)]
pub struct FnArg<'a> {
Expand Down Expand Up @@ -473,8 +474,7 @@ impl<'a> FnSpec<'a> {
}

let rust_call = |args: Vec<TokenStream>| {
let mut call = quote! { function(#self_arg #(#args),*) };
if self.asyncness.is_some() {
let call = if self.asyncness.is_some() {
let throw_callback = if cancel_handle.is_some() {
quote! { Some(__throw_callback) }
} else {
Expand All @@ -485,8 +485,19 @@ impl<'a> FnSpec<'a> {
Some(cls) => quote!(Some(<#cls as _pyo3::PyTypeInfo>::NAME)),
None => quote!(None),
};
call = quote! {{
let future = #call;
let future = match self.tp {
FnType::Fn(SelfType::Receiver { mutable: false, .. }) => quote! {{
let __guard = _pyo3::impl_::coroutine::RefGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?;
async move { function(&__guard, #(#args),*).await }
}},
FnType::Fn(SelfType::Receiver { mutable: true, .. }) => quote! {{
let mut __guard = _pyo3::impl_::coroutine::RefMutGuard::<#cls>::new(py.from_borrowed_ptr::<_pyo3::types::PyAny>(_slf))?;
async move { function(&mut __guard, #(#args),*).await }
}},
_ => quote! { function(#self_arg #(#args),*) },
};
let mut call = quote! {{
let future = #future;
_pyo3::impl_::coroutine::new_coroutine(
_pyo3::intern!(py, stringify!(#python_name)),
#qualname_prefix,
Expand All @@ -501,7 +512,10 @@ impl<'a> FnSpec<'a> {
#call
}};
}
}
call
} else {
quote! { function(#self_arg #(#args),*) }
};
quotes::map_result_into_ptr(quotes::ok_wrap(call))
};

Expand Down
74 changes: 71 additions & 3 deletions src/impl_/coroutine.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
use std::future::Future;
use std::{
future::Future,
mem,
ops::{Deref, DerefMut},
};

use crate::coroutine::cancel::ThrowCallback;
use crate::{coroutine::Coroutine, types::PyString, IntoPy, PyErr, PyObject};
use crate::{
coroutine::{cancel::ThrowCallback, Coroutine},
pyclass::boolean_struct::False,
types::PyString,
IntoPy, Py, PyAny, PyCell, PyClass, PyErr, PyObject, PyResult, Python,
};

pub fn new_coroutine<F, T, E>(
name: &PyString,
Expand All @@ -16,3 +24,63 @@ where
{
Coroutine::new(Some(name.into()), qualname_prefix, throw_callback, future)
}

fn get_ptr<T: PyClass>(obj: &Py<T>) -> *mut T {
// SAFETY: Py<T> can be casted as *const PyCell<T>
unsafe { &*(obj.as_ptr() as *const PyCell<T>) }.get_ptr()
}

pub struct RefGuard<T: PyClass>(Py<T>);

davidhewitt marked this conversation as resolved.
Show resolved Hide resolved
impl<T: PyClass> RefGuard<T> {
pub fn new(obj: &PyAny) -> PyResult<Self> {
let owned: Py<T> = obj.extract()?;
mem::forget(owned.try_borrow(obj.py())?);
Ok(RefGuard(owned))
}
}

impl<T: PyClass> Deref for RefGuard<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
// SAFETY: `RefGuard` has been built from `PyRef` and provides the same guarantees
unsafe { &*get_ptr(&self.0) }
}
}

impl<T: PyClass> Drop for RefGuard<T> {
fn drop(&mut self) {
Python::with_gil(|gil| self.0.as_ref(gil).release_ref())
}
}

pub struct RefMutGuard<T: PyClass<Frozen = False>>(Py<T>);

impl<T: PyClass<Frozen = False>> RefMutGuard<T> {
pub fn new(obj: &PyAny) -> PyResult<Self> {
let owned: Py<T> = obj.extract()?;
mem::forget(owned.try_borrow_mut(obj.py())?);
Ok(RefMutGuard(owned))
}
}

impl<T: PyClass<Frozen = False>> Deref for RefMutGuard<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
// SAFETY: `RefMutGuard` has been built from `PyRefMut` and provides the same guarantees
unsafe { &*get_ptr(&self.0) }
}
}

impl<T: PyClass<Frozen = False>> DerefMut for RefMutGuard<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
// SAFETY: `RefMutGuard` has been built from `PyRefMut` and provides the same guarantees
unsafe { &mut *get_ptr(&self.0) }
}
}

impl<T: PyClass<Frozen = False>> Drop for RefMutGuard<T> {
fn drop(&mut self) {
Python::with_gil(|gil| self.0.as_ref(gil).release_mut())
}
}
10 changes: 10 additions & 0 deletions src/pycell.rs
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,16 @@ impl<T: PyClass> PyCell<T> {
#[allow(clippy::useless_conversion)]
offset.try_into().expect("offset should fit in Py_ssize_t")
}

#[cfg(feature = "macros")]
pub(crate) fn release_ref(&self) {
self.borrow_checker().release_borrow();
}

#[cfg(feature = "macros")]
pub(crate) fn release_mut(&self) {
self.borrow_checker().release_borrow_mut();
}
}

impl<T: PyClassImpl> PyCell<T> {
Expand Down
53 changes: 53 additions & 0 deletions tests/test_coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,56 @@ fn coroutine_panic() {
py_run!(gil, panic, &handle_windows(test));
})
}

#[test]
fn test_async_method_receiver() {
#[pyclass]
struct Counter(usize);
#[pymethods]
impl Counter {
#[new]
fn new() -> Self {
Self(0)
}
async fn get(&self) -> usize {
self.0
}
async fn incr(&mut self) -> usize {
self.0 += 1;
self.0
}
}
Python::with_gil(|gil| {
let test = r#"
import asyncio

obj = Counter()
coro1 = obj.get()
coro2 = obj.get()
try:
obj.incr() # borrow checking should fail
except RuntimeError as err:
pass
else:
assert False
assert asyncio.run(coro1) == 0
coro2.close()
coro3 = obj.incr()
try:
obj.incr() # borrow checking should fail
except RuntimeError as err:
pass
else:
assert False
try:
obj.get() # borrow checking should fail
except RuntimeError as err:
pass
else:
assert False
assert asyncio.run(coro3) == 1
"#;
let locals = [("Counter", gil.get_type::<Counter>())].into_py_dict(gil);
py_run!(gil, *locals, test);
})
}
Loading