diff --git a/pytests/src/pyclasses.rs b/pytests/src/pyclasses.rs index 48fa628b0c7..46c8523c2dd 100644 --- a/pytests/src/pyclasses.rs +++ b/pytests/src/pyclasses.rs @@ -58,10 +58,14 @@ impl AssertingBaseClass { } } +#[pyclass] +struct ClassWithoutConstructor; + #[pymodule] pub fn pyclasses(_py: Python<'_>, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } diff --git a/pytests/tests/test_pyclasses.py b/pytests/tests/test_pyclasses.py index 4a45b413669..aa6de694c99 100644 --- a/pytests/tests/test_pyclasses.py +++ b/pytests/tests/test_pyclasses.py @@ -1,3 +1,4 @@ +from typing import Type import pytest from pyo3_pytests import pyclasses @@ -32,7 +33,29 @@ class AssertingSubClass(pyclasses.AssertingBaseClass): def test_new_classmethod(): - # The `AssertingBaseClass` constructor errors if it is not passed the relevant subclass. + # The `AssertingBaseClass` constructor errors if it is not passed the + # relevant subclass. _ = AssertingSubClass(expected_type=AssertingSubClass) with pytest.raises(ValueError): _ = AssertingSubClass(expected_type=str) + + +class ClassWithoutConstructorPy: + def __new__(cls): + raise TypeError("No constructor defined") + + +@pytest.mark.parametrize( + "cls", [pyclasses.ClassWithoutConstructor, ClassWithoutConstructorPy] +) +def test_no_constructor_defined_propagates_cause(cls: Type): + original_error = ValueError("Original message") + with pytest.raises(Exception) as exc_info: + try: + raise original_error + except Exception: + cls() # should raise TypeError("No constructor defined") + + assert exc_info.type is TypeError + assert exc_info.value.args == ("No constructor defined",) + assert exc_info.value.__context__ is original_error diff --git a/src/err/err_state.rs b/src/err/err_state.rs index 2a32387ba7e..ccf243f81b8 100644 --- a/src/err/err_state.rs +++ b/src/err/err_state.rs @@ -15,6 +15,21 @@ pub(crate) struct PyErrStateNormalized { } impl PyErrStateNormalized { + fn from_value(pvalue: &PyBaseException) -> Self { + Self { + #[cfg(not(Py_3_12))] + ptype: pvalue.get_type().into(), + pvalue: pvalue.into(), + #[cfg(not(Py_3_12))] + ptraceback: unsafe { + Py::from_owned_ptr_or_opt( + pvalue.py(), + ffi::PyException_GetTraceback(pvalue.as_ptr()), + ) + }, + } + } + #[cfg(not(Py_3_12))] pub(crate) fn ptype<'py>(&'py self, py: Python<'py>) -> &'py PyType { self.ptype.as_ref(py) @@ -82,70 +97,72 @@ impl PyErrState { } pub(crate) fn normalized(pvalue: &PyBaseException) -> Self { - Self::Normalized(PyErrStateNormalized { - #[cfg(not(Py_3_12))] - ptype: pvalue.get_type().into(), - pvalue: pvalue.into(), - #[cfg(not(Py_3_12))] - ptraceback: unsafe { - Py::from_owned_ptr_or_opt( - pvalue.py(), - ffi::PyException_GetTraceback(pvalue.as_ptr()), - ) - }, - }) + Self::Normalized(PyErrStateNormalized::from_value(pvalue)) } - #[cfg(not(Py_3_12))] - pub(crate) fn into_ffi_tuple( - self, - py: Python<'_>, - ) -> (*mut ffi::PyObject, *mut ffi::PyObject, *mut ffi::PyObject) { + pub(crate) fn normalize(self, py: Python<'_>) -> PyErrStateNormalized { + use crate::{types::PyTuple, PyResult}; + match self { PyErrState::Lazy(lazy) => { + fn exceptions_must_derive_from_base_exception( + py: Python<'_>, + ) -> PyResult<&PyBaseException> { + PyTypeError::type_object(py) + .call1(("exceptions must derive from BaseException",)) + .map(|any| unsafe { any.downcast_unchecked::() }) + } + let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py); - if unsafe { ffi::PyExceptionClass_Check(ptype.as_ptr()) } == 0 { - PyErrState::lazy( - PyTypeError::type_object(py), - "exceptions must derive from BaseException", - ) - .into_ffi_tuple(py) + let result = if unsafe { ffi::PyExceptionClass_Check(ptype.as_ptr()) } == 0 { + exceptions_must_derive_from_base_exception(py) } else { - (ptype.into_ptr(), pvalue.into_ptr(), std::ptr::null_mut()) + // already an exception instance + let result = if let Ok(base_exc) = pvalue.downcast::(py) { + return PyErrStateNormalized::from_value(base_exc); + } else if pvalue.is_none(py) { + ptype.as_ref(py).call0() + } else if let Ok(tup) = pvalue.as_ref(py).downcast::() { + ptype.as_ref(py).call1(tup) + } else { + ptype.as_ref(py).call1((pvalue,)) + }; + result.and_then(|any| match any.downcast::() { + Ok(base_exc) => Ok(base_exc), + Err(_) => exceptions_must_derive_from_base_exception(py), + }) + }; + + match result { + Ok(base_exc) => PyErrStateNormalized::from_value(base_exc), + Err(e) => e + .state + .into_inner() + .expect("exception is not being normalized") + .normalize(py), } } + #[cfg(not(Py_3_12))] PyErrState::FfiTuple { ptype, pvalue, ptraceback, - } => ( - ptype.into_ptr(), - pvalue.map_or(std::ptr::null_mut(), Py::into_ptr), - ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr), - ), - PyErrState::Normalized(PyErrStateNormalized { - ptype, - pvalue, - ptraceback, - }) => ( - ptype.into_ptr(), - pvalue.into_ptr(), - ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr), - ), - } - } - - #[cfg(not(Py_3_12))] - pub(crate) fn normalize(self, py: Python<'_>) -> PyErrStateNormalized { - let (mut ptype, mut pvalue, mut ptraceback) = self.into_ffi_tuple(py); - - unsafe { - ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback); - PyErrStateNormalized { - ptype: Py::from_owned_ptr_or_opt(py, ptype).expect("Exception type missing"), - pvalue: Py::from_owned_ptr_or_opt(py, pvalue).expect("Exception value missing"), - ptraceback: Py::from_owned_ptr_or_opt(py, ptraceback), + } => { + let mut ptype = ptype.into_ptr(); + let mut pvalue = pvalue.map_or(std::ptr::null_mut(), Py::into_ptr); + let mut ptraceback = ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr); + unsafe { + ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback); + PyErrStateNormalized { + ptype: Py::from_owned_ptr_or_opt(py, ptype) + .expect("Exception type missing"), + pvalue: Py::from_owned_ptr_or_opt(py, pvalue) + .expect("Exception value missing"), + ptraceback: Py::from_owned_ptr_or_opt(py, ptraceback), + } + } } + PyErrState::Normalized(normalized) => normalized, } } @@ -159,13 +176,6 @@ impl PyErrState { PyErrStateNormalized { pvalue } } - #[cfg(not(Py_3_12))] - pub(crate) fn restore(self, py: Python<'_>) { - let (ptype, pvalue, ptraceback) = self.into_ffi_tuple(py); - unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) } - } - - #[cfg(Py_3_12)] pub(crate) fn restore(self, py: Python<'_>) { match self { PyErrState::Lazy(lazy) => { @@ -183,8 +193,42 @@ impl PyErrState { } } } - PyErrState::Normalized(PyErrStateNormalized { pvalue }) => unsafe { - ffi::PyErr_SetRaisedException(pvalue.into_ptr()) + #[cfg(not(Py_3_12))] + PyErrState::FfiTuple { + ptype, + pvalue, + ptraceback, + } => unsafe { + ffi::PyErr_Restore( + ptype.into_ptr(), + pvalue.map_or(std::ptr::null_mut(), Py::into_ptr), + ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr), + ) + }, + PyErrState::Normalized(PyErrStateNormalized { + #[cfg(not(Py_3_12))] + ptype, + pvalue, + #[cfg(not(Py_3_12))] + ptraceback, + }) => unsafe { + #[cfg(not(Py_3_12))] + { + ffi::PyErr_Restore( + ptype.into_ptr(), + pvalue.into_ptr(), + ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr), + ) + } + + // FIXME if the exception has no traceback, we should probably add one + // FIXME if sys.exc_info is set (i.e. an exception is being handled), + // we should chain it. + + #[cfg(Py_3_12)] + { + ffi::PyErr_SetRaisedException(pvalue.into_ptr()) + } }, } }