diff --git a/CHANGELOG.md b/CHANGELOG.md index ba89f0d5f48..d0eea9506d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Add `PyByteArray::data`, `PyByteArray::as_bytes`, and `PyByteArray::as_bytes_mut`. [#967](https://github.com/PyO3/pyo3/pull/967) - Add `GILOnceCell` to use in situations where `lazy_static` or `once_cell` can deadlock. [#975](https://github.com/PyO3/pyo3/pull/975) - Add `Py::borrow`, `Py::borrow_mut`, `Py::try_borrow`, and `Py::try_borrow_mut` for accessing `#[pyclass]` values. [#976](https://github.com/PyO3/pyo3/pull/976) +- Add `IterNextOutput` and `IterANextOutput` for returning from `__next__` / `__anext__`. [#997](https://github.com/PyO3/pyo3/pull/997) ### Changed - Simplify internals of `#[pyo3(get)]` attribute. (Remove the hidden API `GetPropertyValue`.) [#934](https://github.com/PyO3/pyo3/pull/934) diff --git a/examples/rustapi_module/setup.py b/examples/rustapi_module/setup.py index c90755f5c71..f1fe9002891 100644 --- a/examples/rustapi_module/setup.py +++ b/examples/rustapi_module/setup.py @@ -99,6 +99,7 @@ def make_rust_extension(module_name): make_rust_extension("rustapi_module.othermod"), make_rust_extension("rustapi_module.subclassing"), make_rust_extension("rustapi_module.test_dict"), + make_rust_extension("rustapi_module.pyclass_iter"), ], install_requires=install_requires, tests_require=tests_require, diff --git a/examples/rustapi_module/src/lib.rs b/examples/rustapi_module/src/lib.rs index 588ffa7239c..ce63565a494 100644 --- a/examples/rustapi_module/src/lib.rs +++ b/examples/rustapi_module/src/lib.rs @@ -3,4 +3,5 @@ pub mod datetime; pub mod dict_iter; pub mod objstore; pub mod othermod; +pub mod pyclass_iter; pub mod subclassing; diff --git a/examples/rustapi_module/src/pyclass_iter.rs b/examples/rustapi_module/src/pyclass_iter.rs new file mode 100644 index 00000000000..779028de5d1 --- /dev/null +++ b/examples/rustapi_module/src/pyclass_iter.rs @@ -0,0 +1,34 @@ +use pyo3::class::iter::{IterNextOutput, PyIterProtocol}; +use pyo3::prelude::*; + +/// This is for demonstrating how to return a value from __next__ +#[pyclass] +struct PyClassIter { + count: usize, +} + +#[pymethods] +impl PyClassIter { + #[new] + pub fn new() -> Self { + PyClassIter { count: 0 } + } +} + +#[pyproto] +impl PyIterProtocol for PyClassIter { + fn __next__(mut slf: PyRefMut) -> IterNextOutput { + if slf.count < 5 { + slf.count += 1; + IterNextOutput::Yield(slf.count.into_py(slf.py())) + } else { + IterNextOutput::Return(Some("Ended".into_py(slf.py()))) + } + } +} + +#[pymodule] +pub fn pyclass_iter(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + Ok(()) +} diff --git a/examples/rustapi_module/tests/test_pyclass_iter.py b/examples/rustapi_module/tests/test_pyclass_iter.py new file mode 100644 index 00000000000..653f738209c --- /dev/null +++ b/examples/rustapi_module/tests/test_pyclass_iter.py @@ -0,0 +1,14 @@ +import pytest +from rustapi_module import pyclass_iter + +def test_iter(): + i = pyclass_iter.PyClassIter() + assert next(i) == 1 + assert next(i) == 2 + assert next(i) == 3 + assert next(i) == 4 + assert next(i) == 5 + + with pytest.raises(StopIteration) as excinfo: + next(i) + assert excinfo.value.value == "Ended" diff --git a/src/class/iter.rs b/src/class/iter.rs index bb3e3a9d10d..cbdefef6511 100644 --- a/src/class/iter.rs +++ b/src/class/iter.rs @@ -64,22 +64,37 @@ impl PyIterMethods { } } -pub struct IterNextOutput(Option); +pub enum IterNextOutput { + Yield(PyObject), + Return(Option), +} impl IntoPyCallbackOutput<*mut ffi::PyObject> for IterNextOutput { fn convert(self, _py: Python) -> PyResult<*mut ffi::PyObject> { - match self.0 { - Some(o) => Ok(o.into_ptr()), - None => Err(crate::exceptions::StopIteration::py_err(())), + match self { + IterNextOutput::Yield(o) => Ok(o.into_ptr()), + IterNextOutput::Return(opt) => match opt { + Some(o) => Err(crate::exceptions::StopIteration::py_err((o,))), + None => Err(crate::exceptions::StopIteration::py_err(())), + }, } } } +impl IntoPyCallbackOutput for IterNextOutput { + fn convert(self, _py: Python) -> PyResult { + Ok(self) + } +} + impl IntoPyCallbackOutput for Option where T: IntoPy, { fn convert(self, py: Python) -> PyResult { - Ok(IterNextOutput(self.map(|o| o.into_py(py)))) + match self { + Some(o) => Ok(IterNextOutput::Yield(o.into_py(py))), + None => Ok(IterNextOutput::Return(None)), + } } } diff --git a/src/class/pyasync.rs b/src/class/pyasync.rs index 83df141053d..5b77df6a976 100644 --- a/src/class/pyasync.rs +++ b/src/class/pyasync.rs @@ -107,23 +107,38 @@ impl ffi::PyAsyncMethods { } } -pub struct IterANextOutput(Option); +pub enum IterANextOutput { + Yield(PyObject), + Return(Option), +} impl IntoPyCallbackOutput<*mut ffi::PyObject> for IterANextOutput { fn convert(self, _py: Python) -> PyResult<*mut ffi::PyObject> { - match self.0 { - Some(o) => Ok(o.into_ptr()), - None => Err(crate::exceptions::StopAsyncIteration::py_err(())), + match self { + IterANextOutput::Yield(o) => Ok(o.into_ptr()), + IterANextOutput::Return(opt) => match opt { + Some(o) => Err(crate::exceptions::StopAsyncIteration::py_err((o,))), + None => Err(crate::exceptions::StopAsyncIteration::py_err(())), + }, } } } +impl IntoPyCallbackOutput for IterANextOutput { + fn convert(self, _py: Python) -> PyResult { + Ok(self) + } +} + impl IntoPyCallbackOutput for Option where T: IntoPy, { fn convert(self, py: Python) -> PyResult { - Ok(IterANextOutput(self.map(|o| o.into_py(py)))) + match self { + Some(o) => Ok(IterANextOutput::Yield(o.into_py(py))), + None => Ok(IterANextOutput::Return(None)), + } } }