From 7ebeaa8ff811bebd9bb8aa412a1e5c525a3e2ba3 Mon Sep 17 00:00:00 2001 From: Michael-J-Ward Date: Tue, 16 Apr 2024 14:21:04 -0500 Subject: [PATCH] update pyo3 to new bounds api --- apis/python/node/src/lib.rs | 20 ++++++------ apis/python/operator/src/lib.rs | 16 ++++++---- apis/rust/operator/types/src/lib.rs | 2 +- binaries/runtime/src/operator/python.rs | 32 ++++++++++--------- .../extensions/ros2-bridge/python/src/lib.rs | 16 +++++----- .../ros2-bridge/python/src/typed/mod.rs | 18 ++++++----- 6 files changed, 56 insertions(+), 48 deletions(-) diff --git a/apis/python/node/src/lib.rs b/apis/python/node/src/lib.rs index 0d6714f7..f471637c 100644 --- a/apis/python/node/src/lib.rs +++ b/apis/python/node/src/lib.rs @@ -90,17 +90,17 @@ impl Node { &mut self, output_id: String, data: PyObject, - metadata: Option<&PyDict>, + metadata: Option>, py: Python, ) -> eyre::Result<()> { let parameters = pydict_to_metadata(metadata)?; - if let Ok(py_bytes) = data.downcast::(py) { + if let Ok(py_bytes) = data.downcast_bound::(py) { let data = py_bytes.as_bytes(); self.node .send_output_bytes(output_id.into(), parameters, data.len(), data) .wrap_err("failed to send output")?; - } else if let Ok(arrow_array) = arrow::array::ArrayData::from_pyarrow(data.as_ref(py)) { + } else if let Ok(arrow_array) = arrow::array::ArrayData::from_pyarrow_bound(data.bind(py)) { self.node.send_output( output_id.into(), parameters, @@ -203,15 +203,15 @@ pub fn start_runtime() -> eyre::Result<()> { } #[pymodule] -fn dora(py: Python, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(start_runtime, m)?)?; +fn dora(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(start_runtime, &m)?)?; m.add_class::().unwrap(); - let ros2_bridge = PyModule::new(py, "ros2_bridge")?; - dora_ros2_bridge_python::create_dora_ros2_bridge_module(ros2_bridge)?; - let experimental = PyModule::new(py, "experimental")?; - experimental.add_submodule(ros2_bridge)?; - m.add_submodule(experimental)?; + let ros2_bridge = PyModule::new_bound(py, "ros2_bridge")?; + dora_ros2_bridge_python::create_dora_ros2_bridge_module(&ros2_bridge)?; + let experimental = PyModule::new_bound(py, "experimental")?; + experimental.add_submodule(&ros2_bridge)?; + m.add_submodule(&experimental)?; Ok(()) } diff --git a/apis/python/operator/src/lib.rs b/apis/python/operator/src/lib.rs index f4712e92..f29c2c0a 100644 --- a/apis/python/operator/src/lib.rs +++ b/apis/python/operator/src/lib.rs @@ -1,7 +1,7 @@ use arrow::{array::ArrayRef, pyarrow::ToPyArrow}; use dora_node_api::{merged::MergedEvent, Event, Metadata, MetadataParameters}; use eyre::{Context, Result}; -use pyo3::{exceptions::PyLookupError, prelude::*, types::PyDict}; +use pyo3::{exceptions::PyLookupError, prelude::*, pybacked::PyBackedStr, types::PyDict}; #[pyclass] pub struct PyEvent { @@ -110,11 +110,15 @@ impl From> for PyEvent { } } -pub fn pydict_to_metadata(dict: Option<&PyDict>) -> Result { +pub fn pydict_to_metadata(dict: Option>) -> Result { let mut default_metadata = MetadataParameters::default(); if let Some(metadata) = dict { for (key, value) in metadata.iter() { - match key.extract::<&str>().context("Parsing metadata keys")? { + match key + .extract::() + .context("Parsing metadata keys")? + .as_ref() + { "watermark" => { default_metadata.watermark = value.extract().context("parsing watermark failed")?; @@ -124,7 +128,7 @@ pub fn pydict_to_metadata(dict: Option<&PyDict>) -> Result { value.extract().context("parsing deadline failed")?; } "open_telemetry_context" => { - let otel_context: &str = value + let otel_context: PyBackedStr = value .extract() .context("parsing open telemetry context failed")?; default_metadata.open_telemetry_context = otel_context.to_string(); @@ -136,8 +140,8 @@ pub fn pydict_to_metadata(dict: Option<&PyDict>) -> Result { Ok(default_metadata) } -pub fn metadata_to_pydict<'a>(metadata: &'a Metadata, py: Python<'a>) -> &'a PyDict { - let dict = PyDict::new(py); +pub fn metadata_to_pydict<'a>(metadata: &'a Metadata, py: Python<'a>) -> pyo3::Bound<'a, PyDict> { + let dict = PyDict::new_bound(py); dict.set_item( "open_telemetry_context", &metadata.parameters.open_telemetry_context, diff --git a/apis/rust/operator/types/src/lib.rs b/apis/rust/operator/types/src/lib.rs index 7f299b00..49ca0a4b 100644 --- a/apis/rust/operator/types/src/lib.rs +++ b/apis/rust/operator/types/src/lib.rs @@ -165,7 +165,7 @@ pub fn dora_free_input_id(_input_id: char_p_boxed) {} #[ffi_export] pub fn dora_read_data(input: &mut Input) -> Option> { let data_array = input.data_array.take()?; - let data = unsafe {arrow::ffi::from_ffi(data_array, &input.schema).ok()? }; + let data = unsafe { arrow::ffi::from_ffi(data_array, &input.schema).ok()? }; let array = ArrowData(arrow::array::make_array(data)); let bytes: &[u8] = TryFrom::try_from(&array).ok()?; Some(bytes.to_owned().into()) diff --git a/binaries/runtime/src/operator/python.rs b/binaries/runtime/src/operator/python.rs index eec5b7ec..f6c907ad 100644 --- a/binaries/runtime/src/operator/python.rs +++ b/binaries/runtime/src/operator/python.rs @@ -12,7 +12,7 @@ use dora_operator_api_types::DoraStatus; use eyre::{bail, eyre, Context, Result}; use pyo3::{ pyclass, - types::{IntoPyDict, PyDict}, + types::{IntoPyDict, PyAnyMethods, PyDict, PyTracebackMethods}, Py, PyAny, Python, }; use std::{ @@ -23,7 +23,7 @@ use tokio::sync::{mpsc::Sender, oneshot}; use tracing::{error, field, span, warn}; fn traceback(err: pyo3::PyErr) -> eyre::Report { - let traceback = Python::with_gil(|py| err.traceback(py).and_then(|t| t.format().ok())); + let traceback = Python::with_gil(|py| err.traceback_bound(py).and_then(|t| t.format().ok())); if let Some(traceback) = traceback { eyre::eyre!("{traceback}\n{err}") } else { @@ -78,7 +78,9 @@ pub fn run( let parent_path = parent_path .to_str() .ok_or_else(|| eyre!("module path is not valid utf8"))?; - let sys = py.import("sys").wrap_err("failed to import `sys` module")?; + let sys = py + .import_bound("sys") + .wrap_err("failed to import `sys` module")?; let sys_path = sys .getattr("path") .wrap_err("failed to import `sys.path` module")?; @@ -90,14 +92,14 @@ pub fn run( .wrap_err("failed to append module path to python search path")?; } - let module = py.import(module_name).map_err(traceback)?; + let module = py.import_bound(module_name).map_err(traceback)?; let operator_class = module .getattr("Operator") .wrap_err("no `Operator` class found in module")?; - let locals = [("Operator", operator_class)].into_py_dict(py); + let locals = [("Operator", operator_class)].into_py_dict_bound(py); let operator = py - .eval("Operator()", None, Some(locals)) + .eval_bound("Operator()", None, Some(&locals)) .map_err(traceback)?; operator.setattr( "dataflow_descriptor", @@ -140,11 +142,11 @@ pub fn run( .wrap_err("could not extract operator state as a PyDict")?; // Reload module let module = py - .import(module_name) + .import_bound(module_name) .map_err(traceback) .wrap_err(format!("Could not retrieve {module_name} while reloading"))?; let importlib = py - .import("importlib") + .import_bound("importlib") .wrap_err("failed to import `importlib` module")?; let module = importlib .call_method("reload", (module,), None) @@ -154,9 +156,9 @@ pub fn run( .wrap_err("no `Operator` class found in module")?; // Create a new reloaded operator - let locals = [("Operator", reloaded_operator_class)].into_py_dict(py); + let locals = [("Operator", reloaded_operator_class)].into_py_dict_bound(py); let operator: Py = py - .eval("Operator()", None, Some(locals)) + .eval_bound("Operator()", None, Some(&locals)) .map_err(traceback) .wrap_err("Could not initialize reloaded operator")? .into(); @@ -299,8 +301,8 @@ mod callback_impl { use eyre::{eyre, Context, Result}; use pyo3::{ pymethods, - types::{PyBytes, PyDict}, - PyObject, Python, + types::{PyBytes, PyBytesMethods, PyDict}, + Bound, PyObject, Python, }; use tokio::sync::oneshot; use tracing::{field, span}; @@ -317,7 +319,7 @@ mod callback_impl { &mut self, output: &str, data: PyObject, - metadata: Option<&PyDict>, + metadata: Option>, py: Python, ) -> Result<()> { let parameters = pydict_to_metadata(metadata) @@ -353,12 +355,12 @@ mod callback_impl { } }; - let (sample, type_info) = if let Ok(py_bytes) = data.downcast::(py) { + let (sample, type_info) = if let Ok(py_bytes) = data.downcast_bound::(py) { let data = py_bytes.as_bytes(); let mut sample = allocate_sample(data.len())?; sample.copy_from_slice(data); (sample, ArrowTypeInfo::byte_array(data.len())) - } else if let Ok(arrow_array) = ArrayData::from_pyarrow(data.as_ref(py)) { + } else if let Ok(arrow_array) = ArrayData::from_pyarrow_bound(data.bind(py)) { let total_len = required_data_size(&arrow_array); let mut sample = allocate_sample(total_len)?; diff --git a/libraries/extensions/ros2-bridge/python/src/lib.rs b/libraries/extensions/ros2-bridge/python/src/lib.rs index ac0a0bec..15554b63 100644 --- a/libraries/extensions/ros2-bridge/python/src/lib.rs +++ b/libraries/extensions/ros2-bridge/python/src/lib.rs @@ -15,8 +15,8 @@ use eyre::{eyre, Context, ContextCompat}; use futures::{Stream, StreamExt}; use pyo3::{ prelude::{pyclass, pymethods}, - types::{PyDict, PyList, PyModule}, - PyAny, PyObject, PyResult, Python, + types::{PyAnyMethods, PyDict, PyList, PyModule, PyModuleMethods}, + Bound, PyAny, PyObject, PyResult, Python, }; use typed::{deserialize::StructDeserializer, TypeInfo, TypedValue}; @@ -194,8 +194,8 @@ pub struct Ros2Publisher { #[pymethods] impl Ros2Publisher { - pub fn publish(&self, data: &PyAny) -> eyre::Result<()> { - let pyarrow = PyModule::import(data.py(), "pyarrow")?; + pub fn publish(&self, data: Bound<'_, PyAny>) -> eyre::Result<()> { + let pyarrow = PyModule::import_bound(data.py(), "pyarrow")?; let data = if data.is_instance_of::() { // convert to arrow struct scalar @@ -204,15 +204,15 @@ impl Ros2Publisher { data }; - let data = if data.is_instance(pyarrow.getattr("StructScalar")?)? { + let data = if data.is_instance(&pyarrow.getattr("StructScalar")?)? { // convert to arrow array - let list = PyList::new(data.py(), [data]); + let list = PyList::new_bound(data.py(), [data]); pyarrow.getattr("array")?.call1((list,))? } else { data }; - let value = arrow::array::ArrayData::from_pyarrow(data)?; + let value = arrow::array::ArrayData::from_pyarrow_bound(&data)?; //// add type info to ensure correct serialization (e.g. struct types //// and map types need to be serialized differently) let typed_value = TypedValue { @@ -297,7 +297,7 @@ impl Stream for Ros2SubscriptionStream { } } -pub fn create_dora_ros2_bridge_module(m: &PyModule) -> PyResult<()> { +pub fn create_dora_ros2_bridge_module(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/libraries/extensions/ros2-bridge/python/src/typed/mod.rs b/libraries/extensions/ros2-bridge/python/src/typed/mod.rs index c6875b7f..2b841589 100644 --- a/libraries/extensions/ros2-bridge/python/src/typed/mod.rs +++ b/libraries/extensions/ros2-bridge/python/src/typed/mod.rs @@ -37,10 +37,12 @@ mod tests { use arrow::pyarrow::ToPyArrow; use pyo3::types::IntoPyDict; + use pyo3::types::PyAnyMethods; use pyo3::types::PyDict; use pyo3::types::PyList; use pyo3::types::PyModule; use pyo3::types::PyTuple; + use pyo3::PyNativeType; use pyo3::Python; use serde::de::DeserializeSeed; use serde::Serialize; @@ -61,13 +63,13 @@ mod tests { let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); //.join("test_utils.py"); // Adjust this path as needed // Add the Python module's directory to sys.path - py.run( + py.run_bound( "import sys; sys.path.append(str(path))", - Some([("path", path)].into_py_dict(py)), + Some(&[("path", path)].into_py_dict_bound(py)), None, )?; - let my_module = PyModule::import(py, "test_utils")?; + let my_module = PyModule::import_bound(py, "test_utils")?; let arrays: &PyList = my_module.getattr("TEST_ARRAYS")?.extract()?; for array_wrapper in arrays.iter() { @@ -77,7 +79,7 @@ mod tests { println!("Checking {}::{}", package_name, message_name); let in_pyarrow = arrays.get_item(2)?; - let array = arrow::array::ArrayData::from_pyarrow(in_pyarrow)?; + let array = arrow::array::ArrayData::from_pyarrow_bound(&in_pyarrow.as_borrowed())?; let type_info = TypeInfo { package_name: package_name.into(), message_name: message_name.clone().into(), @@ -99,17 +101,17 @@ mod tests { let out_pyarrow = out_value.to_pyarrow(py)?; - let test_utils = PyModule::import(py, "test_utils")?; - let context = PyDict::new(py); + let test_utils = PyModule::import_bound(py, "test_utils")?; + let context = PyDict::new_bound(py); context.set_item("test_utils", test_utils)?; context.set_item("in_pyarrow", in_pyarrow)?; context.set_item("out_pyarrow", out_pyarrow)?; let _ = py - .eval( + .eval_bound( "test_utils.is_subset(in_pyarrow, out_pyarrow)", - Some(context), + Some(&context), None, ) .context("could not check if it is a subset")?;