From e749a08a1846728d1d0ba85ae2fb86002588bd35 Mon Sep 17 00:00:00 2001 From: smartgoo Date: Fri, 12 Jul 2024 19:44:21 -0400 Subject: [PATCH] Python wRPC client subscription prototype (#62) * change python package name to 'kaspa' * Introduce Py wRPC client wrapping Inner struct that contains KaspaRpcClient * scaffolding for Python wRPC client Inner struct * Python wRPC subscribtions prototype * Python wRPC subscription callback args/kwargs * lint * minor refactor, handling of UTXO change notification * properly gate python code in kaspa-rpc-core * Attempt to fix test suite CI failure * Subscribe UTXOs Changed * subscriptions * wRPC client disconnect * unregister callbacks * fix failing kaspad build --- Cargo.lock | 6 + Cargo.toml | 5 +- python/Cargo.toml | 3 +- python/README.md | 4 +- python/examples/addresses.py | 2 +- python/examples/rpc.py | 42 ++- python/examples/test.py | 2 +- python/pyproject.toml | 6 +- python/src/lib.rs | 2 +- rpc/core/Cargo.toml | 6 +- rpc/core/src/api/notifications.rs | 21 ++ rpc/macros/src/lib.rs | 6 + rpc/macros/src/wrpc/python.rs | 93 +++++- rpc/wrpc/python/Cargo.toml | 11 +- rpc/wrpc/python/src/client.rs | 455 ++++++++++++++++++++++++++++-- 15 files changed, 620 insertions(+), 44 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 935df6ec5b..cdf56bd9e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3207,6 +3207,7 @@ dependencies = [ "rand 0.8.5", "pyo3", "serde", + "serde-pyobject", "serde-wasm-bindgen", "serde_json", "smallvec", @@ -3726,8 +3727,12 @@ dependencies = [ name = "kaspa-wrpc-python" version = "0.14.1" dependencies = [ + "ahash 0.8.11", "cfg-if 1.0.0", + "futures", + "kaspa-addresses", "kaspa-consensus-core", + "kaspa-notify", "kaspa-python-macros", "kaspa-rpc-core", "kaspa-rpc-macros", @@ -3738,6 +3743,7 @@ dependencies = [ "serde_json", "thiserror", "workflow-core", + "workflow-log", "workflow-rpc", ] diff --git a/Cargo.toml b/Cargo.toml index bd41b6295b..5bb7f7507f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -219,7 +219,7 @@ pbkdf2 = "0.12.2" portable-atomic = { version = "1.5.1", features = ["float"] } prost = "0.12.1" # prost = "0.13.1" -pyo3 = { version = "0.21.0", features = ["extension-module", "multiple-pymethods"] } +pyo3 = { version = "0.21.0", features = ["multiple-pymethods"] } pyo3-asyncio-0-21 = { version = "0.21", features = ["attributes", "tokio-runtime"] } rand = "0.8.5" rand_chacha = "0.3.1" @@ -240,6 +240,7 @@ seqlock = "0.2.0" serde = { version = "1.0.190", features = ["derive", "rc"] } serde_bytes = "0.11.12" serde_json = "1.0.107" +serde-pyobject = "0.3.0" serde_repr = "0.1.18" serde-value = "0.7.0" serde-wasm-bindgen = "0.6.1" @@ -344,4 +345,4 @@ debug = true strip = false [workspace.lints.clippy] -empty_docs = "allow" +empty_docs = "allow" \ No newline at end of file diff --git a/python/Cargo.toml b/python/Cargo.toml index da7e079b65..38bc00132b 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -7,7 +7,7 @@ edition.workspace = true include.workspace = true [lib] -name = "kaspapy" +name = "kaspa" crate-type = ["cdylib"] [dependencies] @@ -20,6 +20,7 @@ pyo3.workspace = true [features] default = [] py-sdk = [ + "pyo3/extension-module", "kaspa-addresses/py-sdk", "kaspa-wallet-keys/py-sdk", "kaspa-wrpc-python/py-sdk", diff --git a/python/README.md b/python/README.md index 717beb1988..dbc4534a94 100644 --- a/python/README.md +++ b/python/README.md @@ -17,6 +17,6 @@ Rusty-Kaspa/Rust bindings for Python, using [PyO3](https://pyo3.rs/v0.20.0/) and See Python files in `./python/examples`. # Project Layout -The Python package `kaspapy` is built from the `kaspa-python` crate, which is located at `./python`. +The Python package `kaspa` is built from the `kaspa-python` crate, which is located at `./python`. -As such, the `kaspapy` function in `./python/src/lib.rs` is a good starting point. This function uses PyO3 to add functionality to the package. +As such, the `kaspa` function in `./python/src/lib.rs` is a good starting point. This function uses PyO3 to add functionality to the package. diff --git a/python/examples/addresses.py b/python/examples/addresses.py index dcc01159fc..db00c2bea6 100644 --- a/python/examples/addresses.py +++ b/python/examples/addresses.py @@ -1,4 +1,4 @@ -from kaspapy import ( +from kaspa import ( PrivateKey, ) diff --git a/python/examples/rpc.py b/python/examples/rpc.py index 9b151be72d..a7fa8c3c62 100644 --- a/python/examples/rpc.py +++ b/python/examples/rpc.py @@ -1,14 +1,37 @@ import asyncio import json import time +import os -from kaspapy import RpcClient +from kaspa import RpcClient -async def main(): - client = await RpcClient.connect(url = "ws://localhost:17110") - print(f'Client is connected: {client.is_connected()}') +def subscription_callback(event, name, **kwargs): + print(f'{name} | {event}') + +async def rpc_subscriptions(client): + # client.add_event_listener('all', subscription_callback, callback_id=1, kwarg1='Im a kwarg!!') + client.add_event_listener('all', subscription_callback, name="all") + + await client.subscribe_virtual_daa_score_changed() + await client.subscribe_virtual_chain_changed(True) + await client.subscribe_block_added() + await client.subscribe_new_block_template() + + await asyncio.sleep(5) + + client.remove_event_listener('all') + print('Removed all event listeners. Sleeping for 5 seconds before unsubscribing. Should see nothing print.') + + await asyncio.sleep(5) + await client.unsubscribe_virtual_daa_score_changed() + await client.unsubscribe_virtual_chain_changed(True) + await client.unsubscribe_block_added() + await client.unsubscribe_new_block_template() + + +async def rpc_calls(client): get_server_info_response = await client.get_server_info() print(get_server_info_response) @@ -24,6 +47,17 @@ async def main(): get_balances_by_addresses_response = await client.get_balances_by_addresses_call(get_balances_by_addresses_request) print(get_balances_by_addresses_response) +async def main(): + rpc_host = os.environ.get("KASPA_RPC_HOST") + client = RpcClient(url = f"ws://{rpc_host}:17210") + await client.connect() + print(f'Client is connected: {client.is_connected()}') + + await rpc_calls(client) + await rpc_subscriptions(client) + + await client.disconnect() + if __name__ == "__main__": asyncio.run(main()) \ No newline at end of file diff --git a/python/examples/test.py b/python/examples/test.py index bbdb01d6bb..cc2cd78065 100644 --- a/python/examples/test.py +++ b/python/examples/test.py @@ -1,4 +1,4 @@ -from kaspapy import PrivateKeyGenerator +from kaspa import PrivateKeyGenerator if __name__ == "__main__": x = PrivateKeyGenerator('xprv9s21ZrQH143K2hP7m1bU4ZT6tWgX1Qn2cWvtLVDX6sTJVyg3XBa4p1So4s7uEvVFGyBhQWWRe8JeLPeDZ462LggxkkJpZ9z1YMzmPahnaZA', False, 1) diff --git a/python/pyproject.toml b/python/pyproject.toml index 4bcdcb9260..b41e388b64 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -3,7 +3,7 @@ requires = ["maturin>=1.0,<2.0"] build-backend = "maturin" [project] -name = "kaspapy" +name = "kaspa" description = "Kaspa Python Bindings" version = "0.1.0" requires-python = ">=3.8" @@ -23,8 +23,8 @@ dependencies = [] # changelog = "" [package.metadata.maturin] -name = "kaspapy" +name = "kaspa" description = "Kaspa Python Bindings" [tool.maturin] -name = "kaspapy" \ No newline at end of file +name = "kaspa" \ No newline at end of file diff --git a/python/src/lib.rs b/python/src/lib.rs index e1afc2e16f..ee8ce06421 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -3,7 +3,7 @@ cfg_if::cfg_if! { use pyo3::prelude::*; #[pymodule] - fn kaspapy(m: &Bound<'_, PyModule>) -> PyResult<()> { + fn kaspa(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; diff --git a/rpc/core/Cargo.toml b/rpc/core/Cargo.toml index 2d9576ad40..f83a642701 100644 --- a/rpc/core/Cargo.toml +++ b/rpc/core/Cargo.toml @@ -14,7 +14,10 @@ wasm32-sdk = [ "kaspa-consensus-client/wasm32-sdk", "kaspa-consensus-wasm/wasm32-sdk" ] -py-sdk = ["pyo3"] +py-sdk = [ + "pyo3", + "serde-pyobject" +] [dependencies] kaspa-addresses.workspace = true @@ -46,6 +49,7 @@ paste.workspace = true rand.workspace = true pyo3 = { workspace = true, optional = true } serde-wasm-bindgen.workspace = true +serde-pyobject = { workspace = true, optional = true } serde.workspace = true smallvec.workspace = true thiserror.workspace = true diff --git a/rpc/core/src/api/notifications.rs b/rpc/core/src/api/notifications.rs index e07a7c4d98..1792e037eb 100644 --- a/rpc/core/src/api/notifications.rs +++ b/rpc/core/src/api/notifications.rs @@ -9,7 +9,11 @@ use kaspa_notify::{ Subscription, }, }; +#[cfg(feature = "py-sdk")] +use pyo3::prelude::*; use serde::{Deserialize, Serialize}; +#[cfg(feature = "py-sdk")] +use serde_pyobject::to_pyobject; use std::sync::Arc; use wasm_bindgen::JsValue; use workflow_serializer::prelude::*; @@ -62,6 +66,23 @@ impl Notification { Notification::VirtualChainChanged(v) => to_value(&v), } } + + #[cfg(feature = "py-sdk")] + pub fn to_pyobject(&self, py: Python) -> PyResult { + let bound_obj = match self { + Notification::BlockAdded(v) => to_pyobject(py, &v), + Notification::FinalityConflict(v) => to_pyobject(py, &v), + Notification::FinalityConflictResolved(v) => to_pyobject(py, &v), + Notification::NewBlockTemplate(v) => to_pyobject(py, &v), + Notification::PruningPointUtxoSetOverride(v) => to_pyobject(py, &v), + Notification::UtxosChanged(v) => to_pyobject(py, &v), + Notification::VirtualDaaScoreChanged(v) => to_pyobject(py, &v), + Notification::SinkBlueScoreChanged(v) => to_pyobject(py, &v), + Notification::VirtualChainChanged(v) => to_pyobject(py, &v), + }; + + Ok(bound_obj.unwrap().to_object(py)) + } } impl NotificationTrait for Notification { diff --git a/rpc/macros/src/lib.rs b/rpc/macros/src/lib.rs index 6876f74d1f..a6d60b2ec5 100644 --- a/rpc/macros/src/lib.rs +++ b/rpc/macros/src/lib.rs @@ -16,6 +16,12 @@ pub fn build_wrpc_python_interface(input: TokenStream) -> TokenStream { wrpc::python::build_wrpc_python_interface(input) } +#[proc_macro] +#[proc_macro_error] +pub fn build_wrpc_python_subscriptions(input: TokenStream) -> TokenStream { + wrpc::python::build_wrpc_python_subscriptions(input) +} + #[proc_macro] #[proc_macro_error] pub fn declare_typescript_wasm_interface(input: TokenStream) -> TokenStream { diff --git a/rpc/macros/src/wrpc/python.rs b/rpc/macros/src/wrpc/python.rs index aad60e7416..0b535cd4ba 100644 --- a/rpc/macros/src/wrpc/python.rs +++ b/rpc/macros/src/wrpc/python.rs @@ -1,6 +1,8 @@ use crate::handler::*; -use proc_macro2::TokenStream; +use convert_case::{Case, Casing}; +use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens}; +use regex::Regex; use std::convert::Into; use syn::{ parse::{Parse, ParseStream}, @@ -41,7 +43,7 @@ impl ToTokens for RpcTable { #[pymethods] impl RpcClient { fn #fn_call(&self, py: Python, request: Py) -> PyResult> { - let client = self.client.clone(); + let client = self.inner.client.clone(); let request : #request_type = serde_pyobject::from_pyobject(request.into_bound(py)).unwrap(); @@ -71,3 +73,90 @@ pub fn build_wrpc_python_interface(input: proc_macro::TokenStream) -> proc_macro // println!("MACRO: {}", ts.to_string()); ts.into() } + +#[derive(Debug)] +struct RpcSubscriptions { + handlers: ExprArray, +} + +impl Parse for RpcSubscriptions { + fn parse(input: ParseStream) -> Result { + let parsed = Punctuated::::parse_terminated(input).unwrap(); + if parsed.len() != 1 { + return Err(Error::new_spanned(parsed, "usage: build_wrpc_python_!([getInfo, ..])".to_string())); + } + + let mut iter = parsed.iter(); + // Intake enum variants as an array + let handlers = get_handlers(iter.next().unwrap().clone())?; + + Ok(RpcSubscriptions { handlers }) + } +} + +impl ToTokens for RpcSubscriptions { + fn to_tokens(&self, tokens: &mut TokenStream) { + let mut targets = Vec::new(); + + for handler in self.handlers.elems.iter() { + // TODO docs (name, docs) + let (name, _) = match handler { + syn::Expr::Path(expr_path) => (expr_path.path.to_token_stream().to_string(), &expr_path.attrs), + _ => { + continue; + } + }; + + let name = format!("Notify{}", name.as_str()); + let regex = Regex::new(r"^Notify").unwrap(); + let blank = regex.replace(&name, ""); + let subscribe = regex.replace(&name, "Subscribe"); + let unsubscribe = regex.replace(&name, "Unsubscribe"); + let scope = Ident::new(&blank, Span::call_site()); + let sub_scope = Ident::new(format!("{blank}Scope").as_str(), Span::call_site()); + let fn_subscribe_snake = Ident::new(&subscribe.to_case(Case::Snake), Span::call_site()); + let fn_unsubscribe_snake = Ident::new(&unsubscribe.to_case(Case::Snake), Span::call_site()); + + targets.push(quote! { + #[pymethods] + impl RpcClient { + fn #fn_subscribe_snake(&self, py: Python) -> PyResult> { + if let Some(listener_id) = self.listener_id() { + let client = self.inner.client.clone(); + py_async! {py, async move { + client.start_notify(listener_id, Scope::#scope(#sub_scope {})).await?; + Ok(()) + }} + } else { + Err(PyErr::new::("RPC subscribe on a closed connection")) + } + } + + fn #fn_unsubscribe_snake(&self, py: Python) -> PyResult> { + if let Some(listener_id) = self.listener_id() { + let client = self.inner.client.clone(); + py_async! {py, async move { + client.stop_notify(listener_id, Scope::#scope(#sub_scope {})).await?; + Ok(()) + }} + } else { + Err(PyErr::new::("RPC unsubscribe on a closed connection")) + } + } + } + }); + } + + quote! { + #(#targets)* + } + .to_tokens(tokens); + } +} + +pub fn build_wrpc_python_subscriptions(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let rpc_table = parse_macro_input!(input as RpcSubscriptions); + let ts = rpc_table.to_token_stream(); + // println!("MACRO: {}", ts.to_string()); + ts.into() +} diff --git a/rpc/wrpc/python/Cargo.toml b/rpc/wrpc/python/Cargo.toml index 6596e2f308..d0ba90bfe3 100644 --- a/rpc/wrpc/python/Cargo.toml +++ b/rpc/wrpc/python/Cargo.toml @@ -10,15 +10,21 @@ license.workspace = true repository.workspace = true [features] -default = ["py-sdk"] +default = [] py-sdk = [ + "pyo3/extension-module", + "kaspa-addresses/py-sdk", "kaspa-rpc-core/py-sdk", "kaspa-wrpc-client/py-sdk", ] [dependencies] +ahash.workspace = true cfg-if.workspace = true +futures.workspace = true +kaspa-addresses.workspace = true kaspa-consensus-core.workspace = true +kaspa-notify.workspace = true kaspa-rpc-core.workspace = true kaspa-rpc-macros.workspace = true kaspa-wrpc-client.workspace = true @@ -26,7 +32,8 @@ kaspa-python-macros.workspace = true pyo3.workspace = true pyo3-asyncio-0-21.workspace = true serde_json.workspace = true -serde-pyobject = "0.3.0" +serde-pyobject.workspace = true thiserror.workspace = true workflow-core.workspace = true +workflow-log.workspace = true workflow-rpc.workspace = true \ No newline at end of file diff --git a/rpc/wrpc/python/src/client.rs b/rpc/wrpc/python/src/client.rs index 8c74cdb1ff..a48cfb9270 100644 --- a/rpc/wrpc/python/src/client.rs +++ b/rpc/wrpc/python/src/client.rs @@ -1,29 +1,174 @@ +use ahash::AHashMap; +use futures::*; +use kaspa_addresses::Address; +use kaspa_notify::listener::ListenerId; +use kaspa_notify::notification::Notification; +use kaspa_notify::scope::{Scope, UtxosChangedScope, VirtualChainChangedScope, VirtualDaaScoreChangedScope}; +use kaspa_notify::{connection::ChannelType, events::EventType}; use kaspa_python_macros::py_async; use kaspa_rpc_core::api::rpc::RpcApi; use kaspa_rpc_core::model::*; -use kaspa_rpc_macros::build_wrpc_python_interface; +use kaspa_rpc_core::notify::connection::ChannelConnection; +use kaspa_rpc_macros::{build_wrpc_python_interface, build_wrpc_python_subscriptions}; use kaspa_wrpc_client::{ client::{ConnectOptions, ConnectStrategy}, + error::Error, + prelude::*, + result::Result, KaspaRpcClient, WrpcEncoding, }; -use pyo3::{prelude::*, types::PyDict}; -use std::time::Duration; +use pyo3::{ + exceptions::PyException, + prelude::*, + types::{PyDict, PyTuple}, +}; +use std::str::FromStr; +use std::{ + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, Mutex, + }, + time::Duration, +}; +use workflow_core::channel::{Channel, DuplexChannel}; +use workflow_log::*; +use workflow_rpc::client::Ctl; + +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +enum NotificationEvent { + All, + Notification(EventType), + RpcCtl(Ctl), +} + +impl FromStr for NotificationEvent { + type Err = Error; + fn from_str(s: &str) -> Result { + if s == "all" { + Ok(NotificationEvent::All) + } else if let Ok(ctl) = Ctl::from_str(s) { + Ok(NotificationEvent::RpcCtl(ctl)) + } else if let Ok(event) = EventType::from_str(s) { + Ok(NotificationEvent::Notification(event)) + } else { + Err(Error::custom(format!("Invalid notification event type: `{}`", s))) + } + } +} + +#[derive(Clone)] +struct PyCallback { + callback: PyObject, + args: Option>, + kwargs: Option>, +} + +impl PyCallback { + fn append_to_args(&self, py: Python, event: Bound) -> PyResult> { + match &self.args { + Some(existing_args) => { + let tuple_ref = existing_args.bind(py); + + let mut new_args: Vec = tuple_ref.iter().map(|arg| arg.to_object(py)).collect(); + new_args.push(event.into()); + + Ok(Py::from(PyTuple::new_bound(py, new_args))) + } + None => Ok(Py::from(PyTuple::new_bound(py, [event]))), + } + } + + fn execute(&self, py: Python, event: Bound) -> PyResult { + let args = self.append_to_args(py, event).unwrap(); + let kwargs = self.kwargs.as_ref().map(|kw| kw.bind(py)); + + let result = self + .callback + .call_bound(py, args.bind(py), kwargs) + .map_err(|e| pyo3::exceptions::PyException::new_err(format!("Error while executing RPC notification callback: {}", e))) + .unwrap(); + + Ok(result) + } +} + +pub struct Inner { + client: Arc, + // resolver TODO + notification_task: Arc, + notification_ctl: DuplexChannel, + callbacks: Arc>>>, + listener_id: Arc>>, + notification_channel: Channel, +} + +impl Inner { + fn notification_callbacks(&self, event: NotificationEvent) -> Option> { + let notification_callbacks = self.callbacks.lock().unwrap(); + let all = notification_callbacks.get(&NotificationEvent::All).cloned(); + let target = notification_callbacks.get(&event).cloned(); + match (all, target) { + (Some(mut vec_all), Some(vec_target)) => { + vec_all.extend(vec_target); + Some(vec_all) + } + (Some(vec_all), None) => Some(vec_all), + (None, Some(vec_target)) => Some(vec_target), + (None, None) => None, + } + } +} #[pyclass] +#[derive(Clone)] pub struct RpcClient { - client: KaspaRpcClient, - // url: String, - // encoding: Option, - // verbose : Option, - // timeout: Option, + inner: Arc, +} + +impl RpcClient { + fn new(url: Option, encoding: Option) -> Result { + let encoding = encoding.unwrap_or(WrpcEncoding::Borsh); + + let client = Arc::new(KaspaRpcClient::new(encoding, url.as_deref(), None, None, None).unwrap()); + + let rpc_client = RpcClient { + inner: Arc::new(Inner { + client, + notification_task: Arc::new(AtomicBool::new(false)), + notification_ctl: DuplexChannel::oneshot(), + callbacks: Arc::new(Default::default()), + listener_id: Arc::new(Mutex::new(None)), + notification_channel: Channel::unbounded(), + }), + }; + + Ok(rpc_client) + } } #[pymethods] impl RpcClient { - #[staticmethod] - fn connect(py: Python, url: Option) -> PyResult> { - let client = KaspaRpcClient::new(WrpcEncoding::Borsh, url.as_deref(), None, None, None)?; + #[new] + fn ctor(url: Option) -> PyResult { + // TODO expose args to Python similar to WASM wRPC Client IRpcConfig + + Ok(Self::new(url, None)?) + } + + fn url(&self) -> Option { + self.inner.client.url() + } + + fn is_connected(&self) -> bool { + self.inner.client.is_connected() + } + + fn encoding(&self) -> String { + self.inner.client.encoding().to_string() + } + fn connect(&self, py: Python) -> PyResult> { + // TODO expose args to Python similar to WASM wRPC Client IConnectOptions let options = ConnectOptions { block_async_connect: true, connect_timeout: Some(Duration::from_millis(5_000)), @@ -31,23 +176,27 @@ impl RpcClient { ..Default::default() }; - pyo3_asyncio_0_21::tokio::future_into_py(py, async move { - client.connect(Some(options)).await.map_err(|e| pyo3::exceptions::PyException::new_err(e.to_string()))?; + self.start_notification_task(py).unwrap(); - Python::with_gil(|py| { - Py::new(py, RpcClient { client }) - .map(|py_rpc_client| py_rpc_client.into_py(py)) - .map_err(|e| pyo3::exceptions::PyException::new_err(e.to_string())) - }) - }) + let client = self.inner.client.clone(); + py_async! {py, async move { + let _ = client.connect(Some(options)).await.map_err(|e| pyo3::exceptions::PyException::new_err(e.to_string())); + Ok(()) + }} } - fn is_connected(&self) -> bool { - self.client.is_connected() + fn disconnect(&self, py: Python) -> PyResult> { + let client = self.clone(); + + py_async! {py, async move { + client.inner.client.disconnect().await?; + client.stop_notification_task().await?; + Ok(()) + }} } fn get_server_info(&self, py: Python) -> PyResult> { - let client = self.client.clone(); + let client = self.inner.client.clone(); py_async! {py, async move { let response = client.get_server_info_call(GetServerInfoRequest { }).await?; Python::with_gil(|py| { @@ -57,7 +206,7 @@ impl RpcClient { } fn get_block_dag_info(&self, py: Python) -> PyResult> { - let client = self.client.clone(); + let client = self.inner.client.clone(); py_async! {py, async move { let response = client.get_block_dag_info_call(GetBlockDagInfoRequest { }).await?; Python::with_gil(|py| { @@ -65,15 +214,273 @@ impl RpcClient { }) }} } + + #[pyo3(signature = (event, callback, *args, **kwargs))] + fn add_event_listener( + &self, + py: Python, + event: String, + callback: PyObject, + args: &Bound<'_, PyTuple>, + kwargs: Option<&Bound<'_, PyDict>>, + ) -> PyResult<()> { + let event = NotificationEvent::from_str(event.as_str()).unwrap(); + + let args = args.to_object(py).extract::>(py).unwrap(); + let kwargs = kwargs.unwrap().to_object(py).extract::>(py).unwrap(); + + let py_callback = PyCallback { callback, args: Some(args), kwargs: Some(kwargs) }; + + self.inner.callbacks.lock().unwrap().entry(event).or_default().push(py_callback); + Ok(()) + } + + fn remove_event_listener(&self, py: Python, event: String, callback: Option) -> PyResult<()> { + let event = NotificationEvent::from_str(event.as_str()).unwrap(); + let mut callbacks = self.inner.callbacks.lock().unwrap(); + + match (&event, callback) { + (NotificationEvent::All, None) => { + // Remove all callbacks from "all" events + callbacks.clear(); + } + (NotificationEvent::All, Some(callback)) => { + // Remove given callback from "all" events + for callbacks in callbacks.values_mut() { + callbacks.retain(|c| { + let cb_ref = c.callback.bind(py); + let callback_ref = callback.bind(py); + cb_ref.as_ref().ne(callback_ref.as_ref()).unwrap_or(true) + }); + } + } + (_, None) => { + // Remove all callbacks from given event + callbacks.remove(&event); + } + (_, Some(callback)) => { + // Remove given callback from given event + if let Some(callbacks) = callbacks.get_mut(&event) { + callbacks.retain(|c| { + let cb_ref = c.callback.bind(py); + let callback_ref = callback.bind(py); + cb_ref.as_ref().ne(callback_ref.as_ref()).unwrap_or(true) + }); + } + } + } + Ok(()) + } + + // fn clear_event_listener() TODO + + fn remove_all_event_listeners(&self) -> PyResult<()> { + *self.inner.callbacks.lock().unwrap() = Default::default(); + Ok(()) + } +} + +impl RpcClient { + pub fn listener_id(&self) -> Option { + *self.inner.listener_id.lock().unwrap() + } + + async fn stop_notification_task(&self) -> Result<()> { + if self.inner.notification_task.load(Ordering::SeqCst) { + self.inner.notification_ctl.signal(()).await?; + self.inner.notification_task.store(false, Ordering::SeqCst); + } + Ok(()) + } + + fn start_notification_task(&self, py: Python) -> Result<()> { + if self.inner.notification_task.load(Ordering::SeqCst) { + return Ok(()); + } + + self.inner.notification_task.store(true, Ordering::SeqCst); + + let ctl_receiver = self.inner.notification_ctl.request.receiver.clone(); + let ctl_sender = self.inner.notification_ctl.response.sender.clone(); + let notification_receiver = self.inner.notification_channel.receiver.clone(); + let ctl_multiplexer_channel = + self.inner.client.rpc_client().ctl_multiplexer().as_ref().expect("Python RpcClient ctl_multiplexer is None").channel(); + let this = self.clone(); + + let _ = pyo3_asyncio_0_21::tokio::future_into_py(py, async move { + loop { + select_biased! { + msg = ctl_multiplexer_channel.recv().fuse() => { + if let Ok(ctl) = msg { + + match ctl { + Ctl::Connect => { + let listener_id = this.inner.client.register_new_listener(ChannelConnection::new( + "kaspapy-wrpc-client-python", + this.inner.notification_channel.sender.clone(), + ChannelType::Persistent, + )); + *this.inner.listener_id.lock().unwrap() = Some(listener_id); + } + Ctl::Disconnect => { + let listener_id = this.inner.listener_id.lock().unwrap().take(); + if let Some(listener_id) = listener_id { + if let Err(err) = this.inner.client.unregister_listener(listener_id).await { + log_error!("Error in unregister_listener: {:?}",err); + } + } + } + } + + let event = NotificationEvent::RpcCtl(ctl); + if let Some(handlers) = this.inner.notification_callbacks(event) { + for handler in handlers.into_iter() { + Python::with_gil(|py| { + let event = PyDict::new_bound(py); + event.set_item("type", ctl.to_string()).unwrap(); + // objectdict.set_item("rpc", ).unwrap(); TODO + + handler.execute(py, event).unwrap(); + }); + } + } + } + }, + msg = notification_receiver.recv().fuse() => { + if let Ok(notification) = &msg { + match ¬ification { + kaspa_rpc_core::Notification::UtxosChanged(utxos_changed_notification) => { + let event_type = notification.event_type(); + let notification_event = NotificationEvent::Notification(event_type); + if let Some(handlers) = this.inner.notification_callbacks(notification_event) { + let UtxosChangedNotification { added, removed } = utxos_changed_notification; + + for handler in handlers.into_iter() { + Python::with_gil(|py| { + let added = serde_pyobject::to_pyobject(py, added).unwrap(); + let removed = serde_pyobject::to_pyobject(py, removed).unwrap(); + + let event = PyDict::new_bound(py); + event.set_item("type", event_type.to_string()).unwrap(); + event.set_item("added", &added.to_object(py)).unwrap(); + event.set_item("removed", &removed.to_object(py)).unwrap(); + + handler.execute(py, event).unwrap(); + }) + } + } + }, + _ => { + let event_type = notification.event_type(); + let notification_event = NotificationEvent::Notification(event_type); + if let Some(handlers) = this.inner.notification_callbacks(notification_event) { + for handler in handlers.into_iter() { + Python::with_gil(|py| { + let event = PyDict::new_bound(py); + event.set_item("type", event_type.to_string()).unwrap(); + event.set_item("data", ¬ification.to_pyobject(py).unwrap()).unwrap(); + + handler.execute(py, event).unwrap(); + }); + } + } + } + } + } + } + _ = ctl_receiver.recv().fuse() => { + break; + }, + + } + } + + if let Some(listener_id) = this.listener_id() { + this.inner.listener_id.lock().unwrap().take(); + if let Err(err) = this.inner.client.unregister_listener(listener_id).await { + log_error!("Error in unregister_listener: {:?}", err); + } + } + + ctl_sender.send(()).await.ok(); + + Python::with_gil(|_| Ok(())) + }); + + Ok(()) + } +} + +#[pymethods] +impl RpcClient { + fn subscribe_utxos_changed(&self, py: Python, addresses: Vec
) -> PyResult> { + if let Some(listener_id) = self.listener_id() { + let client = self.inner.client.clone(); + py_async! {py, async move { + client.start_notify(listener_id, Scope::UtxosChanged(UtxosChangedScope { addresses })).await?; + Ok(()) + }} + } else { + Err(PyErr::new::("RPC subscribe on a closed connection")) + } + } + + fn unsubscribe_utxos_changed(&self, py: Python, addresses: Vec
) -> PyResult> { + if let Some(listener_id) = self.listener_id() { + let client = self.inner.client.clone(); + py_async! {py, async move { + client.stop_notify(listener_id, Scope::UtxosChanged(UtxosChangedScope { addresses })).await?; + Ok(()) + }} + } else { + Err(PyErr::new::("RPC unsubscribe on a closed connection")) + } + } + + fn subscribe_virtual_chain_changed(&self, py: Python, include_accepted_transaction_ids: bool) -> PyResult> { + if let Some(listener_id) = self.listener_id() { + let client = self.inner.client.clone(); + py_async! {py, async move { + client.start_notify(listener_id, Scope::VirtualChainChanged(VirtualChainChangedScope { include_accepted_transaction_ids })).await?; + Ok(()) + }} + } else { + Err(PyErr::new::("RPC subscribe on a closed connection")) + } + } + + fn unsubscribe_virtual_chain_changed(&self, py: Python, include_accepted_transaction_ids: bool) -> PyResult> { + if let Some(listener_id) = self.listener_id() { + let client = self.inner.client.clone(); + py_async! {py, async move { + client.stop_notify(listener_id, Scope::VirtualChainChanged(VirtualChainChangedScope { include_accepted_transaction_ids })).await?; + Ok(()) + }} + } else { + Err(PyErr::new::("RPC unsubscribe on a closed connection")) + } + } } #[pymethods] impl RpcClient { fn is_connected_test(&self) -> bool { - self.client.is_connected() + self.inner.client.is_connected() } } +build_wrpc_python_subscriptions!([ + // UtxosChanged - added above due to parameter `addresses: Vec
`` + // VirtualChainChanged - added above due to paramter `include_accepted_transaction_ids: bool` + BlockAdded, + FinalityConflict, + FinalityConflictResolved, + NewBlockTemplate, + PruningPointUtxoSetOverride, + SinkBlueScoreChanged, + VirtualDaaScoreChanged, +]); + build_wrpc_python_interface!([ AddPeer, Ban,