From e63e0cbf5adefec629b9c3a92af48bdede188664 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Wed, 4 Mar 2020 01:23:32 +0900 Subject: [PATCH 1/3] Make it enable to take &PyClass as arguments as pyfunctions/methods --- pyo3-derive-backend/src/method.rs | 64 +++++------------------- pyo3-derive-backend/src/module.rs | 2 +- pyo3-derive-backend/src/pyclass.rs | 9 ++++ pyo3-derive-backend/src/pymethod.rs | 77 ++++++++++++++++++----------- src/derive_utils.rs | 15 ++++++ tests/test_dunder.rs | 0 tests/test_methods.rs | 56 +++++++++++++++++++++ tests/test_module.rs | 26 +++++++++- 8 files changed, 166 insertions(+), 83 deletions(-) mode change 100755 => 100644 tests/test_dunder.rs diff --git a/pyo3-derive-backend/src/method.rs b/pyo3-derive-backend/src/method.rs index 8042ce1169d..6d06c377945 100644 --- a/pyo3-derive-backend/src/method.rs +++ b/pyo3-derive-backend/src/method.rs @@ -109,12 +109,11 @@ impl<'a> FnSpec<'a> { let py = crate::utils::if_type_is_python(ty); - let opt = check_arg_ty_and_optional(name, ty); + let opt = check_ty_optional(ty); arguments.push(FnArg { name: ident, by_ref, mutability, - // mode: mode, ty, optional: opt, py, @@ -305,55 +304,18 @@ pub fn is_ref(name: &syn::Ident, ty: &syn::Type) -> bool { false } -pub fn check_arg_ty_and_optional<'a>( - name: &'a syn::Ident, - ty: &'a syn::Type, -) -> Option<&'a syn::Type> { - match ty { - syn::Type::Path(syn::TypePath { ref path, .. }) => { - //if let Some(ref qs) = qs { - // panic!("explicit Self type in a 'qualified path' is not supported: {:?} - {:?}", - // name, qs); - //} - - if let Some(segment) = path.segments.last() { - match segment.ident.to_string().as_str() { - "Option" => match segment.arguments { - syn::PathArguments::AngleBracketed(ref params) => { - if params.args.len() != 1 { - panic!("argument type is not supported by python method: {:?} ({:?}) {:?}", - name, - ty, - path); - } - - match ¶ms.args[0] { - syn::GenericArgument::Type(ref ty) => Some(ty), - _ => panic!("argument type is not supported by python method: {:?} ({:?}) {:?}", - name, - ty, - path), - } - } - _ => { - panic!( - "argument type is not supported by python method: {:?} ({:?}) {:?}", - name, ty, path - ); - } - }, - _ => None, - } - } else { - None - } - } - _ => { - None - //panic!("argument type is not supported by python method: {:?} ({:?})", - //name, - //ty); - } +pub(crate) fn check_ty_optional<'a>(ty: &'a syn::Type) -> Option<&'a syn::Type> { + let path = match ty { + syn::Type::Path(syn::TypePath { ref path, .. }) => path, + _ => return None, + }; + let seg = path.segments.last().filter(|s| s.ident == "Option")?; + match seg.arguments { + syn::PathArguments::AngleBracketed(ref params) => match params.args.first() { + Some(syn::GenericArgument::Type(ref ty)) => Some(ty), + _ => None, + }, + _ => None, } } diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index 64c88be2212..bd9d601adee 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -68,7 +68,7 @@ fn wrap_fn_argument<'a>(cap: &'a syn::PatType, name: &'a Ident) -> syn::Result pyo3::derive_utils::ExtractExt<'a> for &'a #cls + { + type Target = pyo3::PyRef<'a, #cls>; + } + impl<'a> pyo3::derive_utils::ExtractExt<'a> for &'a mut #cls + { + type Target = pyo3::PyRefMut<'a, #cls>; + } + #into_pyobject #inventory_impl diff --git a/pyo3-derive-backend/src/pymethod.rs b/pyo3-derive-backend/src/pymethod.rs index 0449c204395..7149d9c7a27 100644 --- a/pyo3-derive-backend/src/pymethod.rs +++ b/pyo3-derive-backend/src/pymethod.rs @@ -387,10 +387,7 @@ pub(crate) fn impl_wrap_setter( }; match _result { Ok(_) => 0, - Err(e) => { - e.restore(_py); - -1 - } + Err(e) => e.restore_and_minus1(_py), } } }) @@ -523,45 +520,65 @@ fn impl_arg_param( } let arg_value = quote!(output[#option_pos]); *option_pos += 1; - if arg.optional.is_some() { - let default = if let Some(d) = spec.default_value(name) { - if d.to_string() == "None" { - quote! { None } - } else { - quote! { Some(#d) } - } + + return if let Some(ty) = arg.optional.as_ref() { + let default = if let Some(d) = spec.default_value(name).filter(|d| d.to_string() != "None") + { + quote! { Some(#d) } } else { quote! { None } }; - quote! { - let #arg_name = match #arg_value.as_ref() { - Some(_obj) => { - if _obj.is_none() { - #default - } else { - Some(_obj.extract()?) - } - }, - None => #default + if let syn::Type::Reference(tref) = ty { + let (tref, mut_) = tref_preprocess(tref); + let as_deref = if mut_.is_some() { + quote! { as_deref_mut } + } else { + quote! { as_deref } }; + // Get Option<&T> from Option> + quote! { + let #mut_ _tmp = match #arg_value.as_ref().filter(|obj| !obj.is_none()) { + Some(_obj) => { + Some(_obj.extract::<<#tref as pyo3::derive_utils::ExtractExt>::Target>()?) + }, + None => #default, + }; + let #arg_name = _tmp.#as_deref(); + } + } else { + quote! { + let #arg_name = match #arg_value.as_ref().filter(|obj| !obj.is_none()) { + Some(_obj) => Some(_obj.extract()?), + None => #default, + }; + } } } else if let Some(default) = spec.default_value(name) { quote! { - let #arg_name = match #arg_value.as_ref() { - Some(_obj) => { - if _obj.is_none() { - #default - } else { - _obj.extract()? - } - }, - None => #default + let #arg_name = match #arg_value.as_ref().filter(|obj| !obj.is_none()) { + Some(_obj) => _obj.extract()?, + None => #default, }; } + } else if let syn::Type::Reference(tref) = arg.ty { + let (tref, mut_) = tref_preprocess(tref); + // Get &T from PyRef + quote! { + let #mut_ _tmp: <#tref as pyo3::derive_utils::ExtractExt>::Target + = #arg_value.unwrap().extract()?; + let #arg_name = &#mut_ *_tmp; + } } else { quote! { let #arg_name = #arg_value.unwrap().extract()?; } + }; + + fn tref_preprocess(tref: &syn::TypeReference) -> (syn::TypeReference, Option) { + let mut tref = tref.to_owned(); + tref.lifetime = None; + let mut_ = tref.mutability; + (tref, mut_) } } diff --git a/src/derive_utils.rs b/src/derive_utils.rs index 653f7c316b5..baa134f6e8a 100644 --- a/src/derive_utils.rs +++ b/src/derive_utils.rs @@ -198,6 +198,7 @@ impl>> IntoPyNewResult for PyRes } } +#[doc(hidden)] pub trait GetPropertyValue { fn get_property_value(&self, py: Python) -> PyObject; } @@ -218,6 +219,7 @@ impl GetPropertyValue for PyObject { } /// Utilities for basetype +#[doc(hidden)] pub trait PyBaseTypeUtils { type Dict; type WeakRef; @@ -231,3 +233,16 @@ impl PyBaseTypeUtils for T { type LayoutAsBase = crate::pycell::PyCellInner; type BaseNativeType = T::BaseNativeType; } + +/// Utility trait to enable &PyClass as a pymethod/function argument +#[doc(hidden)] +pub trait ExtractExt<'a> { + type Target: crate::FromPyObject<'a>; +} + +impl<'a, T> ExtractExt<'a> for T +where + T: crate::FromPyObject<'a>, +{ + type Target = T; +} diff --git a/tests/test_dunder.rs b/tests/test_dunder.rs old mode 100755 new mode 100644 diff --git a/tests/test_methods.rs b/tests/test_methods.rs index 24a5d3bf434..906b887dc09 100644 --- a/tests/test_methods.rs +++ b/tests/test_methods.rs @@ -405,6 +405,62 @@ fn method_with_lifetime() { ); } +#[pyclass] +struct MethodWithPyClassArg { + #[pyo3(get)] + value: i64, +} + +#[pymethods] +impl MethodWithPyClassArg { + fn add(&self, other: &MethodWithPyClassArg) -> MethodWithPyClassArg { + MethodWithPyClassArg { + value: self.value + other.value, + } + } + fn add_pyref(&self, other: PyRef) -> MethodWithPyClassArg { + MethodWithPyClassArg { + value: self.value + other.value, + } + } + fn inplace_add(&self, other: &mut MethodWithPyClassArg) { + other.value += self.value; + } + fn optional_add(&self, other: Option<&MethodWithPyClassArg>) -> MethodWithPyClassArg { + MethodWithPyClassArg { + value: self.value + other.map(|o| o.value).unwrap_or(10), + } + } +} + +#[test] +fn method_with_pyclassarg() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let obj1 = PyCell::new(py, MethodWithPyClassArg { value: 10 }).unwrap(); + let obj2 = PyCell::new(py, MethodWithPyClassArg { value: 10 }).unwrap(); + py_run!( + py, + obj1 obj2, + "obj = obj1.add(obj2); assert obj.value == 20" + ); + py_run!( + py, + obj1 obj2, + "obj = obj1.add_pyref(obj2); assert obj.value == 20" + ); + py_run!( + py, + obj1 obj2, + "obj = obj1.optional_add(); assert obj.value == 20" + ); + py_run!( + py, + obj1 obj2, + "obj1.inplace_add(obj2); assert obj2.value == 20" + ); +} + #[pyclass] #[cfg(unix)] struct CfgStruct {} diff --git a/tests/test_module.rs b/tests/test_module.rs index 198f2f05cdc..535a74c943e 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -7,6 +7,19 @@ mod common; #[pyclass] struct AnonClass {} +#[pyclass] +struct ValueClass { + value: usize, +} + +#[pymethods] +impl ValueClass { + #[new] + fn new(value: usize) -> ValueClass { + ValueClass { value } + } +} + #[pyclass(module = "module")] struct LocatedClass {} @@ -36,7 +49,13 @@ fn module_with_functions(py: Python, m: &PyModule) -> PyResult<()> { Ok(42) } + #[pyfn(m, "double_value")] + fn double_value(v: &ValueClass) -> usize { + v.value * 2 + } + m.add_class::().unwrap(); + m.add_class::().unwrap(); m.add_class::().unwrap(); m.add("foo", "bar").unwrap(); @@ -60,7 +79,11 @@ fn test_module_with_functions() { )] .into_py_dict(py); - let run = |code| py.run(code, None, Some(d)).unwrap(); + let run = |code| { + py.run(code, None, Some(d)) + .map_err(|e| e.print(py)) + .unwrap() + }; run("assert module_with_functions.__doc__ == 'This module is implemented in Rust.'"); run("assert module_with_functions.sum_as_string(1, 2) == '3'"); @@ -73,6 +96,7 @@ fn test_module_with_functions() { run("assert module_with_functions.double.__doc__ == 'Doubles the given value'"); run("assert module_with_functions.also_double(3) == 6"); run("assert module_with_functions.also_double.__doc__ == 'Doubles the given value'"); + run("assert module_with_functions.double_value(module_with_functions.ValueClass(1)) == 2"); } #[pymodule(other_name)] From bbe4393b1e8c145b3fe5a5a35075d7c150c08d78 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Wed, 4 Mar 2020 20:21:36 +0900 Subject: [PATCH 2/3] Add more tests in method_with_pyclassarg --- tests/test_methods.rs | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/tests/test_methods.rs b/tests/test_methods.rs index 906b887dc09..797b0af3b50 100644 --- a/tests/test_methods.rs +++ b/tests/test_methods.rs @@ -426,11 +426,19 @@ impl MethodWithPyClassArg { fn inplace_add(&self, other: &mut MethodWithPyClassArg) { other.value += self.value; } + fn inplace_add_pyref(&self, mut other: PyRefMut) { + other.value += self.value; + } fn optional_add(&self, other: Option<&MethodWithPyClassArg>) -> MethodWithPyClassArg { MethodWithPyClassArg { value: self.value + other.map(|o| o.value).unwrap_or(10), } } + fn optional_inplace_add(&self, other: Option<&mut MethodWithPyClassArg>) { + if let Some(other) = other { + other.value += self.value; + } + } } #[test] @@ -439,26 +447,20 @@ fn method_with_pyclassarg() { let py = gil.python(); let obj1 = PyCell::new(py, MethodWithPyClassArg { value: 10 }).unwrap(); let obj2 = PyCell::new(py, MethodWithPyClassArg { value: 10 }).unwrap(); - py_run!( - py, - obj1 obj2, - "obj = obj1.add(obj2); assert obj.value == 20" - ); - py_run!( - py, - obj1 obj2, - "obj = obj1.add_pyref(obj2); assert obj.value == 20" - ); - py_run!( - py, - obj1 obj2, - "obj = obj1.optional_add(); assert obj.value == 20" - ); - py_run!( - py, - obj1 obj2, - "obj1.inplace_add(obj2); assert obj2.value == 20" - ); + let objs = [("obj1", obj1), ("obj2", obj2)].into_py_dict(py); + let run = |code| { + py.run(code, None, Some(objs)) + .map_err(|e| e.print(py)) + .unwrap() + }; + run("obj = obj1.add(obj2); assert obj.value == 20"); + run("obj = obj1.add_pyref(obj2); assert obj.value == 20"); + run("obj = obj1.optional_add(); assert obj.value == 20"); + run("obj = obj1.optional_add(obj2); assert obj.value == 20"); + run("obj1.inplace_add(obj2); assert obj.value == 20"); + run("obj1.inplace_add_pyref(obj2); assert obj.value == 30"); + run("obj1.optional_inplace_add(obj2); assert obj.value == 40"); + run("obj1.optional_inplace_add(); assert obj.value == 40"); } #[pyclass] From 96115eaaaad6df8039fb0dc07521137a89ce15b4 Mon Sep 17 00:00:00 2001 From: kngwyu Date: Wed, 4 Mar 2020 20:35:46 +0900 Subject: [PATCH 3/3] Refactor some tests in test_methods --- tests/test_methods.rs | 98 +++++++++++++------------------------------ 1 file changed, 29 insertions(+), 69 deletions(-) diff --git a/tests/test_methods.rs b/tests/test_methods.rs index 797b0af3b50..e1b2546b66b 100644 --- a/tests/test_methods.rs +++ b/tests/test_methods.rs @@ -82,30 +82,15 @@ fn class_method() { let py = gil.python(); let d = [("C", py.get_type::())].into_py_dict(py); - py.run( - "assert C.method() == 'ClassMethod.method()!'", - None, - Some(d), - ) - .unwrap(); - py.run( - "assert C().method() == 'ClassMethod.method()!'", - None, - Some(d), - ) - .unwrap(); - py.run( - "assert C.method.__doc__ == 'Test class method.'", - None, - Some(d), - ) - .unwrap(); - py.run( - "assert C().method.__doc__ == 'Test class method.'", - None, - Some(d), - ) - .unwrap(); + let run = |code| { + py.run(code, None, Some(d)) + .map_err(|e| e.print(py)) + .unwrap() + }; + run("assert C.method() == 'ClassMethod.method()!'"); + run("assert C().method() == 'ClassMethod.method()!'"); + run("assert C.method.__doc__ == 'Test class method.'"); + run("assert C().method.__doc__ == 'Test class method.'"); } #[pyclass] @@ -158,30 +143,15 @@ fn static_method() { assert_eq!(StaticMethod::method(py).unwrap(), "StaticMethod.method()!"); let d = [("C", py.get_type::())].into_py_dict(py); - py.run( - "assert C.method() == 'StaticMethod.method()!'", - None, - Some(d), - ) - .unwrap(); - py.run( - "assert C().method() == 'StaticMethod.method()!'", - None, - Some(d), - ) - .unwrap(); - py.run( - "assert C.method.__doc__ == 'Test static method.'", - None, - Some(d), - ) - .unwrap(); - py.run( - "assert C().method.__doc__ == 'Test static method.'", - None, - Some(d), - ) - .unwrap(); + let run = |code| { + py.run(code, None, Some(d)) + .map_err(|e| e.print(py)) + .unwrap() + }; + run("assert C.method() == 'StaticMethod.method()!'"); + run("assert C().method() == 'StaticMethod.method()!'"); + run("assert C.method.__doc__ == 'Test static method.'"); + run("assert C().method.__doc__ == 'Test static method.'"); } #[pyclass] @@ -356,25 +326,15 @@ fn meth_doc() { let gil = Python::acquire_gil(); let py = gil.python(); let d = [("C", py.get_type::())].into_py_dict(py); + let run = |code| { + py.run(code, None, Some(d)) + .map_err(|e| e.print(py)) + .unwrap() + }; - py.run( - "assert C.__doc__ == 'A class with \"documentation\".'", - None, - Some(d), - ) - .unwrap(); - py.run( - "assert C.method.__doc__ == 'A method with \"documentation\" as well.'", - None, - Some(d), - ) - .unwrap(); - py.run( - "assert C.x.__doc__ == '`int`: a very \"important\" member of \\'this\\' instance.'", - None, - Some(d), - ) - .unwrap(); + run("assert C.__doc__ == 'A class with \"documentation\".'"); + run("assert C.method.__doc__ == 'A method with \"documentation\" as well.'"); + run("assert C.x.__doc__ == '`int`: a very \"important\" member of \\'this\\' instance.'"); } #[pyclass] @@ -458,9 +418,9 @@ fn method_with_pyclassarg() { run("obj = obj1.optional_add(); assert obj.value == 20"); run("obj = obj1.optional_add(obj2); assert obj.value == 20"); run("obj1.inplace_add(obj2); assert obj.value == 20"); - run("obj1.inplace_add_pyref(obj2); assert obj.value == 30"); - run("obj1.optional_inplace_add(obj2); assert obj.value == 40"); - run("obj1.optional_inplace_add(); assert obj.value == 40"); + run("obj1.inplace_add_pyref(obj2); assert obj2.value == 30"); + run("obj1.optional_inplace_add(); assert obj2.value == 30"); + run("obj1.optional_inplace_add(obj2); assert obj2.value == 40"); } #[pyclass]