Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add generalised filtering nodes [Rust] #259

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/notebooks/0.3-Generalised_filtering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"id": "1798765e-3d65-4bfd-964b-7f9b6b0902be",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -441,7 +441,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "2d921e51-a940-42b2-88f2-e25bd7ab5a01",
"metadata": {
"editable": true,
Expand Down
2 changes: 1 addition & 1 deletion examples/exponential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ fn main() {

// create a network with two exponential family state nodes
network.add_nodes(
"exponential-state",
"ef-state",
None,
None,
None,
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod model;
pub mod utils;
pub mod math;
pub mod maths;
pub mod updates;
3 changes: 0 additions & 3 deletions src/math.rs

This file was deleted.

1 change: 1 addition & 0 deletions src/maths/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub mod sufficient_statistics;
15 changes: 15 additions & 0 deletions src/maths/sufficient_statistics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
pub fn normal(x: &f64) -> Vec<f64> {
vec![*x, x.powf(2.0)]
}

pub fn multivariate_normal(x: &Vec<f64>) -> Vec<f64> {
vec![*x, x.powf(2.0)]
}

pub fn get_sufficient_statistics_fn(distribution: String) {
if distribution == "normal" {
normal;
} else if distribution == "multivariate_normal" {
multivariate_normal;
}
}
25 changes: 22 additions & 3 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,21 @@ impl Network {
/// * `value_children` - The indexes of the node's value children.
/// * `volatility_children` - The indexes of the node's volatility children.
/// * `volatility_parents` - The indexes of the node's volatility parents.
#[pyo3(signature = (kind="continuous-state", value_parents=None, value_children=None, volatility_parents=None, volatility_children=None,))]
pub fn add_nodes(&mut self, kind: &str, value_parents: Option<Vec<usize>>,
#[pyo3(signature = (
kind="continuous-state",
value_parents=None,
value_children=None,
volatility_parents=None,
volatility_children=None,
ef_dimension=None,
ef_distribution=None,
ef_learning=None,
)
)]
pub fn add_nodes(
&mut self,
kind: &str,
value_parents: Option<Vec<usize>>,
value_children: Option<Vec<usize>>,
volatility_parents: Option<Vec<usize>>, volatility_children: Option<Vec<usize>>, ) {

Expand All @@ -86,6 +99,7 @@ impl Network {
self.inputs.push(node_id);
}

// Update the edges variable
let edges = AdjacencyLists{
node_type: String::from(kind),
value_parents: value_parents,
Expand All @@ -94,6 +108,11 @@ impl Network {
volatility_children: volatility_children,
};

// Add emtpy adjacency lists in the new node
self.edges.insert(node_id, edges);

// TODO: Update the edges of parents and children accordingly

// add edges and attributes
if kind == "continuous-state" {

Expand All @@ -107,7 +126,6 @@ impl Network {
(String::from("autoconnection_strength"), 1.0)].into_iter().collect();

self.attributes.floats.insert(node_id, attributes);
self.edges.insert(node_id, edges);

} else if kind == "ef-state" {

Expand All @@ -123,6 +141,7 @@ impl Network {

}
}
}

pub fn set_update_sequence(&mut self) {
self.update_sequence = set_update_sequence(self);
Expand Down
2 changes: 1 addition & 1 deletion src/updates/prediction/continuous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::model::Network;
/// Prediction from a continuous state node
///
/// # Arguments
/// * `network` - The main network containing the node.
/// * `network` - The main network structure.
/// * `node_idx` - The node index.
///
/// # Returns
Expand Down
4 changes: 2 additions & 2 deletions src/updates/prediction_error/exponential.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
use crate::model::Network;
use crate::math::sufficient_statistics;

/// Updating an exponential family state node
///
/// # Arguments
/// * `network` - The main network containing the node.
/// * `node_idx` - The node index.
/// * `sufficient_statistics` - A function computing the sufficient statistics of an exponential family distribution.
///
/// # Returns
/// * `network` - The network after message passing.
pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize) {
pub fn prediction_error_exponential_state_node(network: &mut Network, node_idx: usize, sufficient_statistics: fn(&f64) -> Vec<f64>) {

let floats_attributes = network.attributes.floats.get_mut(&node_idx).expect("No floats attributes");
let vectors_attributes = network.attributes.vectors.get_mut(&node_idx).expect("No vector attributes");
Expand Down
1 change: 1 addition & 0 deletions src/utils/set_sequence.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{model::{AdjacencyLists, Network, UpdateSequence}, updates::{posterior::continuous::posterior_update_continuous_state_node, prediction::continuous::prediction_continuous_state_node, prediction_error::{continuous::prediction_error_continuous_state_node, exponential::prediction_error_exponential_state_node}}};
use crate::utils::function_pointer::FnType;
use crate::maths::sufficient_statistics::get_sufficient_statistics_fn;

pub fn set_update_sequence(network: &Network) -> UpdateSequence {
let predictions = get_predictions_sequence(network);
Expand Down
Loading