From d49f627ef975fc158eff8f7a3fc66f8d4d1b3b25 Mon Sep 17 00:00:00 2001 From: LegrandNico Date: Wed, 13 Nov 2024 11:12:46 +0100 Subject: [PATCH] improve generalised filtering nodes --- .../notebooks/0.3-Generalised_filtering.ipynb | 4 +-- examples/exponential.rs | 2 +- src/lib.rs | 2 +- src/math.rs | 3 --- src/maths/mod.rs | 1 + src/maths/sufficient_statistics.rs | 15 +++++++++++ src/model.rs | 25 ++++++++++++++++--- src/updates/prediction/continuous.rs | 2 +- src/updates/prediction_error/exponential.rs | 4 +-- src/utils/set_sequence.rs | 1 + 10 files changed, 46 insertions(+), 13 deletions(-) delete mode 100644 src/math.rs create mode 100644 src/maths/mod.rs create mode 100644 src/maths/sufficient_statistics.rs diff --git a/docs/source/notebooks/0.3-Generalised_filtering.ipynb b/docs/source/notebooks/0.3-Generalised_filtering.ipynb index 959326490..6c04b3bff 100644 --- a/docs/source/notebooks/0.3-Generalised_filtering.ipynb +++ b/docs/source/notebooks/0.3-Generalised_filtering.ipynb @@ -391,7 +391,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "1798765e-3d65-4bfd-964b-7f9b6b0902be", "metadata": {}, "outputs": [ @@ -441,7 +441,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "2d921e51-a940-42b2-88f2-e25bd7ab5a01", "metadata": { "editable": true, diff --git a/examples/exponential.rs b/examples/exponential.rs index dd58f6368..9451c52a5 100644 --- a/examples/exponential.rs +++ b/examples/exponential.rs @@ -7,7 +7,7 @@ fn main() { // create a network with two exponential family state nodes network.add_nodes( - "exponential-state", + "ef-state", None, None, None, diff --git a/src/lib.rs b/src/lib.rs index 55ad98cd0..f45ae2192 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ pub mod model; pub mod utils; -pub mod math; +pub mod maths; pub mod updates; \ No newline at end of file diff --git a/src/math.rs b/src/math.rs deleted file mode 100644 index 80cd8d4dd..000000000 --- a/src/math.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub fn sufficient_statistics(x: &f64) -> Vec { - vec![*x, x.powf(2.0)] -} \ No newline at end of file diff --git a/src/maths/mod.rs b/src/maths/mod.rs new file mode 100644 index 000000000..2ce73d356 --- /dev/null +++ b/src/maths/mod.rs @@ -0,0 +1 @@ +pub mod sufficient_statistics; \ No newline at end of file diff --git a/src/maths/sufficient_statistics.rs b/src/maths/sufficient_statistics.rs new file mode 100644 index 000000000..5bcb4d96e --- /dev/null +++ b/src/maths/sufficient_statistics.rs @@ -0,0 +1,15 @@ +pub fn normal(x: &f64) -> Vec { + vec![*x, x.powf(2.0)] +} + +pub fn multivariate_normal(x: &Vec) -> Vec { + vec![*x, x.powf(2.0)] +} + +pub fn get_sufficient_statistics_fn(distribution: String) { + if distribution == "normal" { + normal; + } else if distribution == "multivariate_normal" { + multivariate_normal; + } +} \ No newline at end of file diff --git a/src/model.rs b/src/model.rs index 15e9fee99..80d62bf1c 100644 --- a/src/model.rs +++ b/src/model.rs @@ -73,8 +73,21 @@ impl Network { /// * `value_children` - The indexes of the node's value children. /// * `volatility_children` - The indexes of the node's volatility children. /// * `volatility_parents` - The indexes of the node's volatility parents. - #[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_parents=None, volatility_children=None,))] - pub fn add_nodes(&mut self, kind: &str, value_parents: Option>, + #[pyo3(signature = ( + kind="continuous-state", + value_parents=None, + value_children=None, + volatility_parents=None, + volatility_children=None, + ef_dimension=None, + ef_distribution=None, + ef_learning=None, + ) + )] + pub fn add_nodes( + &mut self, + kind: &str, + value_parents: Option>, value_children: Option>, volatility_parents: Option>, volatility_children: Option>, ) { @@ -86,6 +99,7 @@ impl Network { self.inputs.push(node_id); } + // Update the edges variable let edges = AdjacencyLists{ node_type: String::from(kind), value_parents: value_parents, @@ -94,6 +108,11 @@ impl Network { volatility_children: volatility_children, }; + // Add emtpy adjacency lists in the new node + self.edges.insert(node_id, edges); + + // TODO: Update the edges of parents and children accordingly + // add edges and attributes if kind == "continuous-state" { @@ -107,7 +126,6 @@ impl Network { (String::from("autoconnection_strength"), 1.0)].into_iter().collect(); self.attributes.floats.insert(node_id, attributes); - self.edges.insert(node_id, edges); } else if kind == "ef-state" { @@ -123,6 +141,7 @@ impl Network { } } +} pub fn set_update_sequence(&mut self) { self.update_sequence = set_update_sequence(self); diff --git a/src/updates/prediction/continuous.rs b/src/updates/prediction/continuous.rs index f40d2c35e..1a2ff3b6d 100644 --- a/src/updates/prediction/continuous.rs +++ b/src/updates/prediction/continuous.rs @@ -3,7 +3,7 @@ use crate::model::Network; /// Prediction from a continuous state node /// /// # Arguments -/// * `network` - The main network containing the node. +/// * `network` - The main network structure. /// * `node_idx` - The node index. /// /// # Returns diff --git a/src/updates/prediction_error/exponential.rs b/src/updates/prediction_error/exponential.rs index 82239ed2a..84ac38d92 100644 --- a/src/updates/prediction_error/exponential.rs +++ b/src/updates/prediction_error/exponential.rs @@ -1,15 +1,15 @@ use crate::model::Network; -use crate::math::sufficient_statistics; /// Updating an exponential family state node /// /// # Arguments /// * `network` - The main network containing the node. /// * `node_idx` - The node index. +/// * `sufficient_statistics` - A function computing the sufficient statistics of an exponential family distribution. /// /// # Returns /// * `network` - The network after message passing. -pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize) { +pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize, sufficient_statistics: fn(&f64) -> Vec) { let floats_attributes = network.attributes.floats.get_mut(&node_idx).expect("No floats attributes"); let vectors_attributes = network.attributes.vectors.get_mut(&node_idx).expect("No vector attributes"); diff --git a/src/utils/set_sequence.rs b/src/utils/set_sequence.rs index 436467367..83a595056 100644 --- a/src/utils/set_sequence.rs +++ b/src/utils/set_sequence.rs @@ -1,5 +1,6 @@ use crate::{model::{AdjacencyLists, Network, UpdateSequence}, updates::{posterior::continuous::posterior_update_continuous_state_node, prediction::continuous::prediction_continuous_state_node, prediction_error::{continuous::prediction_error_continuous_state_node, exponential::prediction_error_exponential_state_node}}}; use crate::utils::function_pointer::FnType; +use crate::maths::sufficient_statistics::get_sufficient_statistics_fn; pub fn set_update_sequence(network: &Network) -> UpdateSequence { let predictions = get_predictions_sequence(network);