Skip to content

Commit

Permalink
Add ability to return from __next__ / __anext__
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Jun 23, 2020
1 parent 0c59b05 commit a9827a4
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Add `PyByteArray::data`, `PyByteArray::as_bytes`, and `PyByteArray::as_bytes_mut`. [#967](https://github.com/PyO3/pyo3/pull/967)
- Add `GILOnceCell` to use in situations where `lazy_static` or `once_cell` can deadlock. [#975](https://github.com/PyO3/pyo3/pull/975)
- Add `Py::borrow`, `Py::borrow_mut`, `Py::try_borrow`, and `Py::try_borrow_mut` for accessing `#[pyclass]` values. [#976](https://github.com/PyO3/pyo3/pull/976)
- Add `IterNextOutput` and `IterANextOutput` for returning from `__next__` / `__anext__`. [#997](https://github.com/PyO3/pyo3/pull/997)

### Changed
- Simplify internals of `#[pyo3(get)]` attribute. (Remove the hidden API `GetPropertyValue`.) [#934](https://github.com/PyO3/pyo3/pull/934)
Expand Down
1 change: 1 addition & 0 deletions examples/rustapi_module/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def make_rust_extension(module_name):
make_rust_extension("rustapi_module.othermod"),
make_rust_extension("rustapi_module.subclassing"),
make_rust_extension("rustapi_module.test_dict"),
make_rust_extension("rustapi_module.pyclass_iter"),
],
install_requires=install_requires,
tests_require=tests_require,
Expand Down
1 change: 1 addition & 0 deletions examples/rustapi_module/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ pub mod datetime;
pub mod dict_iter;
pub mod objstore;
pub mod othermod;
pub mod pyclass_iter;
pub mod subclassing;
34 changes: 34 additions & 0 deletions examples/rustapi_module/src/pyclass_iter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use pyo3::class::iter::{IterNextOutput, PyIterProtocol};
use pyo3::prelude::*;

/// This is for demonstrating how to return a value from __next__
#[pyclass]
struct PyClassIter {
count: usize,
}

#[pymethods]
impl PyClassIter {
#[new]
pub fn new() -> Self {
PyClassIter { count: 0 }
}
}

#[pyproto]
impl PyIterProtocol for PyClassIter {
fn __next__(mut slf: PyRefMut<Self>) -> IterNextOutput<usize, &'static str> {
if slf.count < 5 {
slf.count += 1;
IterNextOutput::Yield(slf.count)
} else {
IterNextOutput::Return("Ended")
}
}
}

#[pymodule]
pub fn pyclass_iter(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyClassIter>()?;
Ok(())
}
15 changes: 15 additions & 0 deletions examples/rustapi_module/tests/test_pyclass_iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pytest
from rustapi_module import pyclass_iter


def test_iter():
i = pyclass_iter.PyClassIter()
assert next(i) == 1
assert next(i) == 2
assert next(i) == 3
assert next(i) == 4
assert next(i) == 5

with pytest.raises(StopIteration) as excinfo:
next(i)
assert excinfo.value.value == "Ended"
49 changes: 47 additions & 2 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -808,11 +808,10 @@ It includes two methods `__iter__` and `__next__`:
* `fn __iter__(slf: PyRefMut<Self>) -> PyResult<impl IntoPy<PyObject>>`
* `fn __next__(slf: PyRefMut<Self>) -> PyResult<Option<impl IntoPy<PyObject>>>`

Returning `Ok(None)` from `__next__` indicates that that there are no further items.
Returning `None` from `__next__` indicates that that there are no further items.
These two methods can be take either `PyRef<Self>` or `PyRefMut<Self>` as their
first argument, so that mutable borrow can be avoided if needed.


Example:

```rust
Expand Down Expand Up @@ -891,6 +890,52 @@ impl PyIterProtocol for Container {
For more details on Python's iteration protocols, check out [the "Iterator Types" section of the library
documentation](https://docs.python.org/3/library/stdtypes.html#iterator-types).

#### Returning a value from iteration

In Python it is possible to set a value in the `StopIteration` exception raised at the end of iteration,
for example with the following code:

```python
def iter_():
for i in range(5):
yield i
return "Ended"
```

When the generator `iter_()` returns with `"Ended"`, this is translated to a `StopIteration` exception `e` with `e.value = "Ended"`.

PyO3 supports this same behavior using the `IterNextOutput` enum. The Python code seen above could be translated into a Rust iterator like the following:

```rust
use pyo3::prelude::*;
use pyo3::PyIterProtocol;
use pyo3::class::iter::IterNextOutput;

#[pyclass]
struct Iter {
count: usize
}

#[pyproto]
impl PyIterProtocol for Iter {
fn __next__(mut slf: PyRefMut<Self>) -> IterNextOutput<usize, &'static str> {
if slf.count < 5 {
let out = slf.count;
slf.count += 1;
IterNextOutput::Yield(out)
} else {
IterNextOutput::Return("Ended")
}
}
}

# let gil = Python::acquire_gil();
# let py = gil.python();
# let inst = Py::new(py, Iter { count: 0 }).unwrap();
# pyo3::py_run!(py, inst, "assert next(inst) == 0");
# // test of StopIteration is done in examples/rustapi_module/pyclass_iter.rs
```

## How methods are implemented

Users should be able to define a `#[pyclass]` with or without `#[pymethods]`, while PyO3 needs a
Expand Down
39 changes: 30 additions & 9 deletions src/class/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ pub trait PyIterIterProtocol<'p>: PyIterProtocol<'p> {

pub trait PyIterNextProtocol<'p>: PyIterProtocol<'p> {
type Receiver: TryFromPyCell<'p, Self>;
type Result: IntoPyCallbackOutput<IterNextOutput>;
type Result: IntoPyCallbackOutput<PyIterNextOutput>;
}

#[derive(Default)]
Expand Down Expand Up @@ -64,22 +64,43 @@ impl PyIterMethods {
}
}

pub struct IterNextOutput(Option<PyObject>);
pub enum IterNextOutput<T, U> {
Yield(T),
Return(U),
}

pub type PyIterNextOutput = IterNextOutput<PyObject, PyObject>;

impl IntoPyCallbackOutput<*mut ffi::PyObject> for IterNextOutput {
impl IntoPyCallbackOutput<*mut ffi::PyObject> for PyIterNextOutput {
fn convert(self, _py: Python) -> PyResult<*mut ffi::PyObject> {
match self.0 {
Some(o) => Ok(o.into_ptr()),
None => Err(crate::exceptions::StopIteration::py_err(())),
match self {
IterNextOutput::Yield(o) => Ok(o.into_ptr()),
IterNextOutput::Return(opt) => Err(crate::exceptions::StopIteration::py_err((opt,))),
}
}
}

impl<T, U> IntoPyCallbackOutput<PyIterNextOutput> for IterNextOutput<T, U>
where
T: IntoPy<PyObject>,
U: IntoPy<PyObject>,
{
fn convert(self, py: Python) -> PyResult<PyIterNextOutput> {
match self {
IterNextOutput::Yield(o) => Ok(IterNextOutput::Yield(o.into_py(py))),
IterNextOutput::Return(o) => Ok(IterNextOutput::Return(o.into_py(py))),
}
}
}

impl<T> IntoPyCallbackOutput<IterNextOutput> for Option<T>
impl<T> IntoPyCallbackOutput<PyIterNextOutput> for Option<T>
where
T: IntoPy<PyObject>,
{
fn convert(self, py: Python) -> PyResult<IterNextOutput> {
Ok(IterNextOutput(self.map(|o| o.into_py(py))))
fn convert(self, py: Python) -> PyResult<PyIterNextOutput> {
match self {
Some(o) => Ok(PyIterNextOutput::Yield(o.into_py(py))),
None => Ok(PyIterNextOutput::Return(py.None())),
}
}
}
39 changes: 30 additions & 9 deletions src/class/pyasync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ pub trait PyAsyncAiterProtocol<'p>: PyAsyncProtocol<'p> {

pub trait PyAsyncAnextProtocol<'p>: PyAsyncProtocol<'p> {
type Receiver: TryFromPyCell<'p, Self>;
type Result: IntoPyCallbackOutput<IterANextOutput>;
type Result: IntoPyCallbackOutput<PyIterANextOutput>;
}

pub trait PyAsyncAenterProtocol<'p>: PyAsyncProtocol<'p> {
Expand Down Expand Up @@ -107,23 +107,44 @@ impl ffi::PyAsyncMethods {
}
}

pub struct IterANextOutput(Option<PyObject>);
pub enum IterANextOutput<T, U> {
Yield(T),
Return(U),
}

pub type PyIterANextOutput = IterANextOutput<PyObject, PyObject>;

impl IntoPyCallbackOutput<*mut ffi::PyObject> for IterANextOutput {
impl IntoPyCallbackOutput<*mut ffi::PyObject> for PyIterANextOutput {
fn convert(self, _py: Python) -> PyResult<*mut ffi::PyObject> {
match self.0 {
Some(o) => Ok(o.into_ptr()),
None => Err(crate::exceptions::StopAsyncIteration::py_err(())),
match self {
IterANextOutput::Yield(o) => Ok(o.into_ptr()),
IterANextOutput::Return(opt) => Err(crate::exceptions::StopAsyncIteration::py_err((opt,))),
}
}
}

impl<T, U> IntoPyCallbackOutput<PyIterANextOutput> for IterANextOutput<T, U>
where
T: IntoPy<PyObject>,
U: IntoPy<PyObject>,
{
fn convert(self, py: Python) -> PyResult<PyIterANextOutput> {
match self {
IterANextOutput::Yield(o) => Ok(IterANextOutput::Yield(o.into_py(py))),
IterANextOutput::Return(o) => Ok(IterANextOutput::Return(o.into_py(py))),
}
}
}

impl<T> IntoPyCallbackOutput<IterANextOutput> for Option<T>
impl<T> IntoPyCallbackOutput<PyIterANextOutput> for Option<T>
where
T: IntoPy<PyObject>,
{
fn convert(self, py: Python) -> PyResult<IterANextOutput> {
Ok(IterANextOutput(self.map(|o| o.into_py(py))))
fn convert(self, py: Python) -> PyResult<PyIterANextOutput> {
match self {
Some(o) => Ok(PyIterANextOutput::Yield(o.into_py(py))),
None => Ok(PyIterANextOutput::Return(py.None())),
}
}
}

Expand Down

0 comments on commit a9827a4

Please sign in to comment.