diff --git a/newsfragments/3185.fixed.md b/newsfragments/3185.fixed.md new file mode 100644 index 00000000000..6b3a594c91a --- /dev/null +++ b/newsfragments/3185.fixed.md @@ -0,0 +1 @@ +Fix conversion of classes implementing `__complex__` to `Complex` when using `abi3` or PyPy. diff --git a/src/conversions/num_complex.rs b/src/conversions/num_complex.rs index 217d862a542..df6b54b45bc 100644 --- a/src/conversions/num_complex.rs +++ b/src/conversions/num_complex.rs @@ -152,6 +152,18 @@ macro_rules! complex_conversion { #[cfg(any(Py_LIMITED_API, PyPy))] unsafe { + let obj = if obj.is_instance_of::() { + obj + } else if let Some(method) = + obj.lookup_special(crate::intern!(obj.py(), "__complex__"))? + { + method.call0()? + } else { + // `obj` might still implement `__float__` or `__index__`, which will be + // handled by `PyComplex_{Real,Imag}AsDouble`, including propagating any + // errors if those methods don't exist / raise exceptions. + obj + }; let ptr = obj.as_ptr(); let real = ffi::PyComplex_RealAsDouble(ptr); if real == -1.0 { @@ -172,6 +184,7 @@ complex_conversion!(f64); #[cfg(test)] mod tests { use super::*; + use crate::types::PyModule; #[test] fn from_complex() { @@ -197,4 +210,131 @@ mod tests { assert!(obj.extract::>(py).is_err()); }); } + #[test] + fn from_python_magic() { + Python::with_gil(|py| { + let module = PyModule::from_code( + py, + r#" +class A: + def __complex__(self): return 3.0+1.2j +class B: + def __float__(self): return 3.0 +class C: + def __index__(self): return 3 + "#, + "test.py", + "test", + ) + .unwrap(); + let from_complex = module.getattr("A").unwrap().call0().unwrap(); + assert_eq!( + from_complex.extract::>().unwrap(), + Complex::new(3.0, 1.2) + ); + let from_float = module.getattr("B").unwrap().call0().unwrap(); + assert_eq!( + from_float.extract::>().unwrap(), + Complex::new(3.0, 0.0) + ); + // Before Python 3.8, `__index__` wasn't tried by `float`/`complex`. + #[cfg(Py_3_8)] + { + let from_index = module.getattr("C").unwrap().call0().unwrap(); + assert_eq!( + from_index.extract::>().unwrap(), + Complex::new(3.0, 0.0) + ); + } + }) + } + #[test] + fn from_python_inherited_magic() { + Python::with_gil(|py| { + let module = PyModule::from_code( + py, + r#" +class First: pass +class ComplexMixin: + def __complex__(self): return 3.0+1.2j +class FloatMixin: + def __float__(self): return 3.0 +class IndexMixin: + def __index__(self): return 3 +class A(First, ComplexMixin): pass +class B(First, FloatMixin): pass +class C(First, IndexMixin): pass + "#, + "test.py", + "test", + ) + .unwrap(); + let from_complex = module.getattr("A").unwrap().call0().unwrap(); + assert_eq!( + from_complex.extract::>().unwrap(), + Complex::new(3.0, 1.2) + ); + let from_float = module.getattr("B").unwrap().call0().unwrap(); + assert_eq!( + from_float.extract::>().unwrap(), + Complex::new(3.0, 0.0) + ); + #[cfg(Py_3_8)] + { + let from_index = module.getattr("C").unwrap().call0().unwrap(); + assert_eq!( + from_index.extract::>().unwrap(), + Complex::new(3.0, 0.0) + ); + } + }) + } + #[test] + fn from_python_noncallable_descriptor_magic() { + // Functions and lambdas implement the descriptor protocol in a way that makes + // `type(inst).attr(inst)` equivalent to `inst.attr()` for methods, but this isn't the only + // way the descriptor protocol might be implemented. + Python::with_gil(|py| { + let module = PyModule::from_code( + py, + r#" +class A: + @property + def __complex__(self): + return lambda: 3.0+1.2j + "#, + "test.py", + "test", + ) + .unwrap(); + let obj = module.getattr("A").unwrap().call0().unwrap(); + assert_eq!( + obj.extract::>().unwrap(), + Complex::new(3.0, 1.2) + ); + }) + } + #[test] + fn from_python_nondescriptor_magic() { + // Magic methods don't need to implement the descriptor protocol, if they're callable. + Python::with_gil(|py| { + let module = PyModule::from_code( + py, + r#" +class MyComplex: + def __call__(self): return 3.0+1.2j +class A: + __complex__ = MyComplex() + "#, + "test.py", + "test", + ) + .unwrap(); + let obj = module.getattr("A").unwrap().call0().unwrap(); + assert_eq!( + obj.extract::>().unwrap(), + Complex::new(3.0, 1.2) + ); + }) + } } diff --git a/src/types/any.rs b/src/types/any.rs index afdeb6ab573..d6ad2cbc4d7 100644 --- a/src/types/any.rs +++ b/src/types/any.rs @@ -124,6 +124,55 @@ impl PyAny { } } + /// Retrieve an attribute value, skipping the instance dictionary during the lookup but still + /// binding the object to the instance. + /// + /// This is useful when trying to resolve Python's "magic" methods like `__getitem__`, which + /// are looked up starting from the type object. This returns an `Option` as it is not + /// typically a direct error for the special lookup to fail, as magic methods are optional in + /// many situations in which they might be called. + /// + /// To avoid repeated temporary allocations of Python strings, the [`intern!`] macro can be used + /// to intern `attr_name`. + #[allow(dead_code)] // Currently only used with num-complex+abi3, so dead without that. + pub(crate) fn lookup_special(&self, attr_name: N) -> PyResult> + where + N: IntoPy>, + { + let py = self.py(); + let self_type = self.get_type(); + let attr = if let Ok(attr) = self_type.getattr(attr_name) { + attr + } else { + return Ok(None); + }; + + // Manually resolve descriptor protocol. + unsafe { + if cfg!(Py_3_10) + || ffi::PyType_HasFeature(attr.get_type_ptr(), ffi::Py_TPFLAGS_HEAPTYPE) != 0 + { + // This is the preferred faster path, but does not work on static types (generally, + // types defined in extension modules) before Python 3.10. + let descr_get_ptr = ffi::PyType_GetSlot(attr.get_type_ptr(), ffi::Py_tp_descr_get); + if descr_get_ptr.is_null() { + return Ok(Some(attr)); + } + let descr_get: ffi::descrgetfunc = std::mem::transmute(descr_get_ptr); + let ret = descr_get(attr.as_ptr(), self.as_ptr(), self_type.as_ptr()); + if ret.is_null() { + Err(PyErr::fetch(py)) + } else { + Ok(Some(py.from_owned_ptr(ret))) + } + } else if let Ok(descr_get) = attr.get_type().getattr(crate::intern!(py, "__get__")) { + descr_get.call1((attr, self, self_type)).map(Some) + } else { + Ok(Some(attr)) + } + } + } + /// Sets an attribute value. /// /// This is equivalent to the Python expression `self.attr_name = value`. @@ -974,9 +1023,82 @@ impl PyAny { #[cfg(test)] mod tests { use crate::{ - types::{IntoPyDict, PyBool, PyList, PyLong, PyModule}, + types::{IntoPyDict, PyAny, PyBool, PyList, PyLong, PyModule}, Python, ToPyObject, }; + + #[test] + fn test_lookup_special() { + Python::with_gil(|py| { + let module = PyModule::from_code( + py, + r#" +class CustomCallable: + def __call__(self): + return 1 + +class SimpleInt: + def __int__(self): + return 1 + +class InheritedInt(SimpleInt): pass + +class NoInt: pass + +class NoDescriptorInt: + __int__ = CustomCallable() + +class InstanceOverrideInt: + def __int__(self): + return 1 +instance_override = InstanceOverrideInt() +instance_override.__int__ = lambda self: 2 + +class ErrorInDescriptorInt: + @property + def __int__(self): + raise ValueError("uh-oh!") + +class NonHeapNonDescriptorInt: + # A static-typed callable that doesn't implement `__get__`. These are pretty hard to come by. + __int__ = int + "#, + "test.py", + "test", + ) + .unwrap(); + + let int = crate::intern!(py, "__int__"); + let eval_int = + |obj: &PyAny| obj.lookup_special(int)?.unwrap().call0()?.extract::(); + + let simple = module.getattr("SimpleInt").unwrap().call0().unwrap(); + assert_eq!(eval_int(simple).unwrap(), 1); + let inherited = module.getattr("InheritedInt").unwrap().call0().unwrap(); + assert_eq!(eval_int(inherited).unwrap(), 1); + let no_descriptor = module.getattr("NoDescriptorInt").unwrap().call0().unwrap(); + assert_eq!(eval_int(no_descriptor).unwrap(), 1); + let missing = module.getattr("NoInt").unwrap().call0().unwrap(); + assert!(missing.lookup_special(int).unwrap().is_none()); + // Note the instance override should _not_ call the instance method that returns 2, + // because that's not how special lookups are meant to work. + let instance_override = module.getattr("instance_override").unwrap(); + assert_eq!(eval_int(instance_override).unwrap(), 1); + let descriptor_error = module + .getattr("ErrorInDescriptorInt") + .unwrap() + .call0() + .unwrap(); + assert!(descriptor_error.lookup_special(int).is_err()); + let nonheap_nondescriptor = module + .getattr("NonHeapNonDescriptorInt") + .unwrap() + .call0() + .unwrap(); + assert_eq!(eval_int(nonheap_nondescriptor).unwrap(), 0); + }) + } + #[test] fn test_call_for_non_existing_method() { Python::with_gil(|py| {