diff --git a/Cargo.lock b/Cargo.lock index 95199b79..2948de15 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -458,6 +458,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + [[package]] name = "either" version = "1.9.0" @@ -535,6 +541,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "fragile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" + [[package]] name = "futures" version = "0.3.28" @@ -951,6 +963,33 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mockall" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43766c2b5203b10de348ffe19f7e54564b64f3d6018ff7648d1e2d6d3a0f0a48" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "lazy_static", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af7cbce79ec385a1d4f54baa90a76401eb15d9cab93685f62e7e9f942aa00ae2" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.38", +] + [[package]] name = "mpsc" version = "0.2.3" @@ -1150,6 +1189,32 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "predicates" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68b87bfd4605926cdfefc1c3b5f8fe560e3feca9d5552cf68c466d3d8236c7e8" +dependencies = [ + "anstyle", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b794032607612e7abeb4db69adb4e33590fa6cf1149e95fd7cb00e634b92f174" + +[[package]] +name = "predicates-tree" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368ba315fb8c5052ab692e68a0eefec6ec57b23a36959c14496f0b0df2c0cecf" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "prettyplease" version = "0.1.25" @@ -1713,6 +1778,7 @@ dependencies = [ "hex", "lightning", "log", + "mockall", "mpsc", "ntest", "rand", @@ -1841,6 +1907,12 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "termtree" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" + [[package]] name = "thiserror" version = "1.0.50" diff --git a/sim-lib/Cargo.toml b/sim-lib/Cargo.toml index 91819433..0e0d9ec1 100644 --- a/sim-lib/Cargo.toml +++ b/sim-lib/Cargo.toml @@ -29,6 +29,7 @@ hex = "0.4.3" csv = "1.2.2" serde_millis = "0.1.1" rand_distr = "0.4.3" +mockall = "0.12.1" [dev-dependencies] ntest = "0.9.0" \ No newline at end of file diff --git a/sim-lib/src/sim_node.rs b/sim-lib/src/sim_node.rs index 32e4f628..363764e4 100644 --- a/sim-lib/src/sim_node.rs +++ b/sim-lib/src/sim_node.rs @@ -225,25 +225,28 @@ impl ChannelState { Ok(()) } - /// Removes the HTLC from our set of outgoing in-flight HTLCs, failing if the payment hash is not found. If the - /// HTLC failed, the balance is returned to our local liquidity. Note that this function is not responsible for - /// reflecting that the balance has moved to the other side of the channel in the success-case, calling code is - /// responsible for that. - fn remove_outgoing_htlc( - &mut self, - hash: &PaymentHash, - success: bool, - ) -> Result { - match self.in_flight.remove(hash) { - Some(v) => { - // If the HTLC failed, pending balance returns to local balance. - if !success { - self.local_balance_msat += v.amount_msat; - } + /// Removes the HTLC from our set of outgoing in-flight HTLCs, failing if the payment hash is not found. + fn remove_outgoing_htlc(&mut self, hash: &PaymentHash) -> Result { + self.in_flight + .remove(hash) + .ok_or(ForwardingError::PaymentHashNotFound(*hash)) + } - Ok(v) - }, - None => Err(ForwardingError::PaymentHashNotFound(*hash)), + // Updates channel state to account for the resolution of an outgoing in-flight HTLC. If the HTLC failed, the + // balance is failed back to the channel's local balance. If not, the in-flight balance is settled to the other + // node, so there is no operation. + fn settle_outgoing_htlc(&mut self, amt: u64, success: bool) { + if !success { + self.local_balance_msat += amt + } + } + + // Updates channel state to account for the resolution of an incoming in-flight HTLC. If the HTLC succeeded, + // the balance is settled to the channel's local balance. If not, the in-flight balance is failed back to the + // other node, so there is no operation. + fn settle_incoming_htlc(&mut self, amt: u64, success: bool) { + if success { + self.local_balance_msat += amt } } } @@ -349,7 +352,7 @@ impl SimulatedChannel { ) -> Result<(), ForwardingError> { let htlc = self .get_node_mut(sending_node)? - .remove_outgoing_htlc(hash, success)?; + .remove_outgoing_htlc(hash)?; self.settle_htlc(sending_node, htlc.amount_msat, success)?; self.sanity_check() } @@ -363,20 +366,13 @@ impl SimulatedChannel { amount_msat: u64, success: bool, ) -> Result<(), ForwardingError> { - // Successful payments push balance to the receiver, failures return it to the sender. - let (sender_delta_msat, receiver_delta_msat) = if success { - (0, amount_msat) - } else { - (amount_msat, 0) - }; - if sending_node == &self.node_1.policy.pubkey { - self.node_1.local_balance_msat += sender_delta_msat; - self.node_2.local_balance_msat += receiver_delta_msat; + self.node_1.settle_outgoing_htlc(amount_msat, success); + self.node_2.settle_incoming_htlc(amount_msat, success); Ok(()) } else if sending_node == &self.node_2.policy.pubkey { - self.node_2.local_balance_msat += sender_delta_msat; - self.node_1.local_balance_msat += receiver_delta_msat; + self.node_2.settle_outgoing_htlc(amount_msat, success); + self.node_1.settle_incoming_htlc(amount_msat, success); Ok(()) } else { Err(ForwardingError::NodeNotFound(*sending_node)) @@ -909,7 +905,7 @@ async fn remove_htlcs( payment_hash: PaymentHash, success: bool, ) -> Result<(), ForwardingError> { - for (i, hop) in route.hops[0..resolution_idx].iter().enumerate().rev() { + for (i, hop) in route.hops[0..=resolution_idx].iter().enumerate().rev() { // When we add HTLCs, we do so on the state of the node that sent the htlc along the channel so we need to // look up our incoming node so that we can remove it when we go backwards. For the first htlc, this is just // the sending node, otherwise it's the hop before. @@ -982,8 +978,15 @@ async fn propagate_payment( } } else { // If we successfully added the htlc, go ahead and remove all the htlcs in the route with successful resolution. - if let Err(e) = - remove_htlcs(nodes, route.hops.len(), source, route, payment_hash, true).await + if let Err(e) = remove_htlcs( + nodes, + route.hops.len() - 1, + source, + route, + payment_hash, + true, + ) + .await { if e.is_critical() { shutdown.trigger(); @@ -1035,3 +1038,739 @@ impl UtxoLookup for UtxoValidator { })) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_utils::get_random_keypair; + use bitcoin::secp256k1::PublicKey; + use lightning::routing::router::Route; + use mockall::mock; + use std::time::Duration; + use tokio::sync::oneshot; + use tokio::time::timeout; + + /// Creates a test channel policy with its maximum HTLC size set to half of the in flight limit of the channel. + /// The minimum HTLC size is hardcoded to 2 so that we can fall beneath this value with a 1 msat htlc. + fn create_test_policy(max_in_flight_msat: u64) -> ChannelPolicy { + let (_, pk) = get_random_keypair(); + ChannelPolicy { + pubkey: pk, + max_htlc_count: 10, + max_in_flight_msat, + min_htlc_size_msat: 2, + max_htlc_size_msat: max_in_flight_msat / 2, + cltv_expiry_delta: 10, + base_fee: 1000, + fee_rate_prop: 5000, + } + } + + /// Creates a set of n simulated channels connected in a chain of channels, where the short channel ID of each + /// channel is its index in the chain of channels and all capacity is on the side of the node that opened the + /// channel. + /// + /// For example if n = 3 it will produce: node_1 -- node_2 -- node_3 -- node_4, connected by channels. + fn create_simulated_channels(n: u64, capacity_msat: u64) -> Vec { + let mut channels: Vec = vec![]; + let (_, first_node) = get_random_keypair(); + + // Create channels in a ring so that we'll get long payment paths. + let mut node_1 = first_node; + for i in 0..n { + // Generate a new random node pubkey. + let (_, node_2) = get_random_keypair(); + + let node_1_to_2 = ChannelPolicy { + pubkey: node_1, + max_htlc_count: 483, + max_in_flight_msat: capacity_msat / 2, + min_htlc_size_msat: 1, + max_htlc_size_msat: capacity_msat / 2, + cltv_expiry_delta: 40, + base_fee: 1000 * i, + fee_rate_prop: 1500 * i, + }; + + let node_2_to_1 = ChannelPolicy { + pubkey: node_2, + max_htlc_count: 483, + max_in_flight_msat: capacity_msat / 2, + min_htlc_size_msat: 1, + max_htlc_size_msat: capacity_msat / 2, + cltv_expiry_delta: 40 + 10 * i as u32, + base_fee: 2000 * i, + fee_rate_prop: i, + }; + + channels.push(SimulatedChannel { + capacity_msat, + // Unique channel ID per link. + short_channel_id: ShortChannelID::from(i), + node_1: ChannelState::new(node_1_to_2, capacity_msat), + node_2: ChannelState::new(node_2_to_1, 0), + }); + + // Progress source ID to create a chain of nodes. + node_1 = node_2; + } + + channels + } + + macro_rules! assert_channel_balances { + ($channel_state:expr, $local_balance:expr, $in_flight_len:expr, $in_flight_total:expr) => { + assert_eq!($channel_state.local_balance_msat, $local_balance); + assert_eq!($channel_state.in_flight.len(), $in_flight_len); + assert_eq!($channel_state.in_flight_total(), $in_flight_total); + }; + } + + /// Tests state updates related to adding and removing HTLCs to a channel. + #[test] + fn test_channel_state_transitions() { + let local_balance = 100_000_000; + let mut channel_state = + ChannelState::new(create_test_policy(local_balance / 2), local_balance); + + // Basic sanity check that we Initialize the channel correctly. + assert_channel_balances!(channel_state, local_balance, 0, 0); + + // Add a few HTLCs to our internal state and assert that balances are as expected. We'll test + // `check_outgoing_addition` in more detail in another test, so we just assert that we can add the htlc in + // this test. + let hash_1 = PaymentHash { 0: [1; 32] }; + let htlc_1 = Htlc { + amount_msat: 1000, + cltv_expiry: 40, + }; + + assert!(channel_state.add_outgoing_htlc(hash_1, htlc_1).is_ok()); + assert_channel_balances!( + channel_state, + local_balance - htlc_1.amount_msat, + 1, + htlc_1.amount_msat + ); + + // Try to add a htlc with the same payment hash and assert that we fail because we enforce one htlc per hash + // at present. + assert!(matches!( + channel_state.add_outgoing_htlc(hash_1, htlc_1), + Err(ForwardingError::PaymentHashExists(_)) + )); + + // Add a second, distinct htlc to our in-flight state. + let hash_2 = PaymentHash { 0: [2; 32] }; + let htlc_2 = Htlc { + amount_msat: 1000, + cltv_expiry: 40, + }; + + assert!(channel_state.add_outgoing_htlc(hash_2, htlc_2).is_ok()); + assert_channel_balances!( + channel_state, + local_balance - htlc_1.amount_msat - htlc_2.amount_msat, + 2, + htlc_1.amount_msat + htlc_2.amount_msat + ); + + // Remove our second htlc with a failure so that our in-flight drops and we return the balance. + assert!(channel_state.remove_outgoing_htlc(&hash_2).is_ok()); + channel_state.settle_outgoing_htlc(htlc_2.amount_msat, false); + assert_channel_balances!( + channel_state, + local_balance - htlc_1.amount_msat, + 1, + htlc_1.amount_msat + ); + + // Try to remove the same htlc and assert that we fail because the htlc can't be found. + assert!(matches!( + channel_state.remove_outgoing_htlc(&hash_2), + Err(ForwardingError::PaymentHashNotFound(_)) + )); + + // Finally, remove our original htlc with success and assert that our local balance is accordingly updated. + assert!(channel_state.remove_outgoing_htlc(&hash_1).is_ok()); + channel_state.settle_outgoing_htlc(htlc_1.amount_msat, true); + assert_channel_balances!(channel_state, local_balance - htlc_1.amount_msat, 0, 0); + } + + /// Tests policy checks applied when forwarding a htlc over a channel. + #[test] + fn test_htlc_forward() { + let local_balance = 140_000; + let channel_state = ChannelState::new(create_test_policy(local_balance / 2), local_balance); + + // CLTV delta insufficient (one less than required). + assert!(matches!( + channel_state.check_htlc_forward(channel_state.policy.cltv_expiry_delta - 1, 0, 0), + Err(ForwardingError::InsufficientCltvDelta(_, _)) + )); + + // Test insufficient fee. + let htlc_amount = 1000; + let htlc_fee = channel_state.policy.base_fee + + (channel_state.policy.fee_rate_prop * htlc_amount) / 1e6 as u64; + + assert!(matches!( + channel_state.check_htlc_forward( + channel_state.policy.cltv_expiry_delta, + htlc_amount, + htlc_fee - 1 + ), + Err(ForwardingError::InsufficientFee(_, _, _, _)) + )); + + // Test exact and over-estimation of required policy. + assert!(channel_state + .check_htlc_forward( + channel_state.policy.cltv_expiry_delta, + htlc_amount, + htlc_fee, + ) + .is_ok()); + + assert!(channel_state + .check_htlc_forward( + channel_state.policy.cltv_expiry_delta * 2, + htlc_amount, + htlc_fee * 3 + ) + .is_ok()); + } + + /// Test addition of outgoing htlc to local state. + #[test] + fn test_check_outgoing_addition() { + // Create test channel with low local liquidity so that we run into failures. + let local_balance = 100_000; + let mut channel_state = + ChannelState::new(create_test_policy(local_balance / 2), local_balance); + + let mut htlc = Htlc { + amount_msat: channel_state.policy.max_htlc_size_msat + 1, + cltv_expiry: channel_state.policy.cltv_expiry_delta, + }; + // HTLC maximum size exceeded. + assert!(matches!( + channel_state.check_outgoing_addition(&htlc), + Err(ForwardingError::MoreThanMaximum(_, _)) + )); + + // Beneath HTLC minimum size. + htlc.amount_msat = channel_state.policy.min_htlc_size_msat - 1; + assert!(matches!( + channel_state.check_outgoing_addition(&htlc), + Err(ForwardingError::LessThanMinimum(_, _)) + )); + + // Add two large htlcs so that we will start to run into our in-flight total amount limit. + let hash_1 = PaymentHash { 0: [1; 32] }; + let htlc_1 = Htlc { + amount_msat: channel_state.policy.max_in_flight_msat / 2, + cltv_expiry: channel_state.policy.cltv_expiry_delta, + }; + + assert!(channel_state.check_outgoing_addition(&htlc_1).is_ok()); + assert!(channel_state.add_outgoing_htlc(hash_1, htlc_1).is_ok()); + + let hash_2 = PaymentHash { 0: [2; 32] }; + let htlc_2 = Htlc { + amount_msat: channel_state.policy.max_in_flight_msat / 2, + cltv_expiry: channel_state.policy.cltv_expiry_delta, + }; + + assert!(channel_state.check_outgoing_addition(&htlc_2).is_ok()); + assert!(channel_state.add_outgoing_htlc(hash_2, htlc_2).is_ok()); + + // Now, assert that we can't add even our smallest htlc size, because we're hit our in-flight amount limit. + htlc.amount_msat = channel_state.policy.min_htlc_size_msat; + assert!(matches!( + channel_state.check_outgoing_addition(&htlc), + Err(ForwardingError::ExceedsInFlightTotal(_, _)) + )); + + // Resolve both of the htlcs successfully so that the local liquidity is no longer available. + assert!(channel_state.remove_outgoing_htlc(&hash_1).is_ok()); + channel_state.settle_outgoing_htlc(htlc_1.amount_msat, true); + + assert!(channel_state.remove_outgoing_htlc(&hash_2).is_ok()); + channel_state.settle_outgoing_htlc(htlc_2.amount_msat, true); + + // Now we're going to add many htlcs so that we hit our in-flight count limit (unique payment hash per htlc). + for i in 0..channel_state.policy.max_htlc_count { + let hash = PaymentHash { + 0: [i.try_into().unwrap(); 32], + }; + assert!(channel_state.check_outgoing_addition(&htlc).is_ok()); + assert!(channel_state.add_outgoing_htlc(hash, htlc).is_ok()); + } + + // Try to add one more htlc and we should be rejected. + let htlc_3 = Htlc { + amount_msat: channel_state.policy.min_htlc_size_msat, + cltv_expiry: channel_state.policy.cltv_expiry_delta, + }; + + assert!(matches!( + channel_state.check_outgoing_addition(&htlc_3), + Err(ForwardingError::ExceedsInFlightCount(_, _)) + )); + + // Resolve all in-flight htlcs. + for i in 0..channel_state.policy.max_htlc_count { + let hash = PaymentHash { + 0: [i.try_into().unwrap(); 32], + }; + assert!(channel_state.remove_outgoing_htlc(&hash).is_ok()); + channel_state.settle_outgoing_htlc(htlc.amount_msat, true) + } + + // Add and settle another htlc to move more liquidity away from our local balance. + let hash_4 = PaymentHash { 0: [1; 32] }; + let htlc_4 = Htlc { + amount_msat: channel_state.policy.max_htlc_size_msat, + cltv_expiry: channel_state.policy.cltv_expiry_delta, + }; + assert!(channel_state.check_outgoing_addition(&htlc_4).is_ok()); + assert!(channel_state.add_outgoing_htlc(hash_4, htlc_4).is_ok()); + assert!(channel_state.remove_outgoing_htlc(&hash_4).is_ok()); + channel_state.settle_outgoing_htlc(htlc_4.amount_msat, true); + + // Finally, assert that we don't have enough balance to forward our largest possible htlc (because of all the + // htlcs that we've settled) and assert that we fail to a large htlc. The balance assertion here is just a + // sanity check for the test, which will fail if we change the amounts settled/failed in the test. + assert!(channel_state.local_balance_msat < channel_state.policy.max_htlc_size_msat); + assert!(matches!( + channel_state.check_outgoing_addition(&htlc_4), + Err(ForwardingError::InsufficientBalance(_, _)) + )); + } + + /// Tests basic functionality of a `SimulatedChannel` but does no endeavor to test the underlying + /// `ChannelState`, as this is covered elsewhere in our tests. + #[test] + fn test_simulated_channel() { + // Create a test channel with all balance available to node 1 as local liquidity, and none for node_2 to begin + // with. + let capacity_msat = 500_000_000; + let node_1 = ChannelState::new(create_test_policy(capacity_msat / 2), capacity_msat); + let node_2 = ChannelState::new(create_test_policy(capacity_msat / 2), 0); + + let mut simulated_channel = SimulatedChannel { + capacity_msat, + short_channel_id: ShortChannelID::from(123), + node_1: node_1.clone(), + node_2: node_2.clone(), + }; + + // Assert that we're not able to send a htlc over node_2 -> node_1 (no liquidity). + let hash_1 = PaymentHash { 0: [1; 32] }; + let htlc_1 = Htlc { + amount_msat: node_2.policy.min_htlc_size_msat, + cltv_expiry: node_1.policy.cltv_expiry_delta, + }; + + assert!(matches!( + simulated_channel.add_htlc(&node_2.policy.pubkey, hash_1, htlc_1), + Err(ForwardingError::InsufficientBalance(_, _)) + )); + + // Assert that we can send a htlc over node_1 -> node_2. + let hash_2 = PaymentHash { 0: [1; 32] }; + let htlc_2 = Htlc { + amount_msat: node_1.policy.max_htlc_size_msat, + cltv_expiry: node_2.policy.cltv_expiry_delta, + }; + assert!(simulated_channel + .add_htlc(&node_1.policy.pubkey, hash_2, htlc_2) + .is_ok()); + + // Settle the htlc and then assert that we can send from node_2 -> node_2 because the balance has been shifted + // across channels. + assert!(simulated_channel + .remove_htlc(&node_1.policy.pubkey, &hash_2, true) + .is_ok()); + + assert!(simulated_channel + .add_htlc(&node_2.policy.pubkey, hash_2, htlc_2) + .is_ok()); + + // Finally, try to add/remove htlcs for a pubkey that is not participating in the channel and assert that we + // fail. + let (_, pk) = get_random_keypair(); + assert!(matches!( + simulated_channel.add_htlc(&pk, hash_2, htlc_2), + Err(ForwardingError::NodeNotFound(_)) + )); + + assert!(matches!( + simulated_channel.remove_htlc(&pk, &hash_2, true), + Err(ForwardingError::NodeNotFound(_)) + )); + } + + mock! { + Network{} + + #[async_trait] + impl SimNetwork for Network{ + fn dispatch_payment( + &mut self, + source: PublicKey, + route: Route, + payment_hash: PaymentHash, + sender: Sender>, + ); + + async fn lookup_node(&self, node: &PublicKey) -> Result<(NodeInfo, Vec), LightningError>; + } + } + + /// Tests the functionality of a `SimNode`, mocking out the `SimNetwork` that is responsible for payment + /// propagation to isolate testing to just the implementation of `LightningNode`. + #[tokio::test] + async fn test_simulated_node() { + // Mock out our network and create a routing graph with 5 hops. + let mock = MockNetwork::new(); + let sim_network = Arc::new(Mutex::new(mock)); + let channels = create_simulated_channels(5, 300000000); + let graph = populate_network_graph(channels.clone()).unwrap(); + + // Create a simulated node for the first channel in our network. + let pk = channels[0].node_1.policy.pubkey; + let mut node = SimNode::new(pk, sim_network.clone(), Arc::new(graph)); + + // Prime mock to return node info from lookup and assert that we get the pubkey we're expecting. + let lookup_pk = channels[3].node_1.policy.pubkey; + sim_network + .lock() + .await + .expect_lookup_node() + .returning(move |_| Ok((node_info(lookup_pk), vec![1, 2, 3]))); + + // Assert that we get three channels from the mock. + let node_info = node.get_node_info(&lookup_pk).await.unwrap(); + assert_eq!(lookup_pk, node_info.pubkey); + assert_eq!(node.list_channels().await.unwrap().len(), 3); + + // Next, we're going to test handling of in-flight payments. To do this, we'll mock out calls to our dispatch + // function to send different results depending on the destination. + let dest_1 = channels[2].node_1.policy.pubkey; + let dest_2 = channels[4].node_1.policy.pubkey; + + sim_network + .lock() + .await + .expect_dispatch_payment() + .returning( + move |_, route: Route, _, sender: Sender>| { + // If we've reached dispatch, we must have at least one path, grab the last hop to match the + // receiver. + let receiver = route.paths[0].hops.last().unwrap().pubkey; + let result = if receiver == dest_1 { + PaymentResult { + htlc_count: 2, + payment_outcome: PaymentOutcome::Success, + } + } else if receiver == dest_2 { + PaymentResult { + htlc_count: 0, + payment_outcome: PaymentOutcome::InsufficientBalance, + } + } else { + panic!("unknown mocked receiver"); + }; + + let _ = sender.send(Ok(result)).unwrap(); + }, + ); + + // Dispatch payments to different destinations and assert that our track payment results are as expected. + let hash_1 = node.send_payment(dest_1, 10_000).await.unwrap(); + let hash_2 = node.send_payment(dest_2, 15_000).await.unwrap(); + + let (_, shutdown_listener) = triggered::trigger(); + + let result_1 = node + .track_payment(hash_1, shutdown_listener.clone()) + .await + .unwrap(); + assert!(matches!(result_1.payment_outcome, PaymentOutcome::Success)); + + let result_2 = node + .track_payment(hash_2, shutdown_listener.clone()) + .await + .unwrap(); + assert!(matches!( + result_2.payment_outcome, + PaymentOutcome::InsufficientBalance + )); + } + + /// Contains elements required to test dispatch_payment functionality. + struct DispatchPaymentTestKit<'a> { + graph: SimGraph, + nodes: Vec, + routing_graph: NetworkGraph<&'a WrappedLog>, + shutdown: triggered::Trigger, + } + + impl<'a> DispatchPaymentTestKit<'a> { + /// Creates a test graph with a set of nodes connected by three channels, with all the capacity of the channel + /// on the side of the first node. For example, if called with capacity = 100 it will set up the following + /// network: + /// Alice (100) --- (0) Bob (100) --- (0) Carol (100) --- (0) Dave + /// + /// The nodes pubkeys in this chain of channels are provided in-order for easy access. + async fn new(capacity: u64) -> Self { + let (shutdown, _listener) = triggered::trigger(); + let channels = create_simulated_channels(3, capacity); + + // Collect pubkeys in-order, pushing the last node on separately because they don't have an outgoing + // channel (they are not node_1 in any channel, only node_2). + let mut nodes = channels + .iter() + .map(|c| c.node_1.policy.pubkey) + .collect::>(); + nodes.push(channels.last().unwrap().node_2.policy.pubkey); + + let kit = DispatchPaymentTestKit { + graph: SimGraph::new(channels.clone(), shutdown.clone()) + .expect("could not create test graph"), + nodes, + routing_graph: populate_network_graph(channels).unwrap(), + shutdown, + }; + + // Assert that our channel balance is all on the side of the channel opener when we start up. + assert_eq!( + kit.channel_balances().await, + vec![(capacity, 0), (capacity, 0), (capacity, 0)] + ); + + kit + } + + /// Returns a vector of local/remote channel balances for channels in the network. + async fn channel_balances(&self) -> Vec<(u64, u64)> { + let mut balances = vec![]; + + // We can't iterate through our hashmap of channels in-order, so we take advantage of our short channel id + // being the index in our chain of channels. This allows us to look up channels in-order. + let chan_count = self.graph.channels.lock().await.len(); + + for i in 0..chan_count { + let chan_lock = self.graph.channels.lock().await; + let channel = chan_lock.get(&ShortChannelID::from(i as u64)).unwrap(); + + // Take advantage of our test setup, which always makes node_1 the channel initiator to get our + // "in order" balances for the chain of channels. + balances.push(( + channel.node_1.local_balance_msat, + channel.node_2.local_balance_msat, + )); + } + + balances + } + + // Sends a test payment from source to destination and waits for the payment to complete, returning the route + // used. + async fn send_test_payemnt( + &mut self, + source: PublicKey, + dest: PublicKey, + amt: u64, + ) -> Route { + let route = find_payment_route(&source, dest, amt, &self.routing_graph).unwrap(); + + let (sender, receiver) = oneshot::channel(); + self.graph + .dispatch_payment(source, route.clone(), PaymentHash { 0: [1; 32] }, sender); + + // Assert that we receive from the channel or fail. + assert!(timeout(Duration::from_millis(10), receiver).await.is_ok()); + + route + } + + // Sets the balance on the channel to the tuple provided, used to arrange liquidity for testing. + async fn set_channel_balance(&mut self, scid: &ShortChannelID, balance: (u64, u64)) { + let mut channels_lock = self.graph.channels.lock().await; + let channel = channels_lock.get_mut(scid).unwrap(); + + channel.node_1.local_balance_msat = balance.0; + channel.node_2.local_balance_msat = balance.1; + + assert!(channel.sanity_check().is_ok()); + } + } + + /// Tests dispatch of a successfully settled payment across a test network of simulated channels: + /// Alice --- Bob --- Carol --- Dave + #[tokio::test] + async fn test_successful_dispatch() { + let chan_capacity = 500_000_000; + let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await; + + // Send a payment that should succeed from Alice -> Dave. + let mut amt = 20_000; + let route = test_kit + .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], amt) + .await; + + let route_total = amt + route.get_total_fees(); + let hop_1_amt = amt + route.paths[0].hops[1].fee_msat; + + // The sending node should have pushed the amount + total fee to the intermediary. + let alice_to_bob = (chan_capacity - route_total, route_total); + // The middle hop should include fees for the outgoing link. + let mut bob_to_carol = (chan_capacity - hop_1_amt, hop_1_amt); + // The receiving node should have the payment amount pushed to them. + let carol_to_dave = (chan_capacity - amt, amt); + + let mut expected_balances = vec![alice_to_bob, bob_to_carol, carol_to_dave]; + assert_eq!(test_kit.channel_balances().await, expected_balances); + + // Next, we'll test the case where a payment fails on the first hop. This is an edge case in our state + // machine, so we want to specifically hit it. To do this, we'll try to send double the amount that we just + // pushed to Dave back to Bob, expecting a failure on Dave's outgoing link due to insufficient liquidity. + let _ = test_kit + .send_test_payemnt(test_kit.nodes[3], test_kit.nodes[1], amt * 2) + .await; + assert_eq!(test_kit.channel_balances().await, expected_balances); + + // Now, test a successful single-hop payment from Bob -> Carol. We'll do this twice, so that we can drain all + // the liquidity on Bob's side (to prepare for a multi-hop failure test). Our pathfinding only allows us to + // use 50% of the channel's capacity, so we need to do two payments. + amt = bob_to_carol.0 / 2; + let _ = test_kit + .send_test_payemnt(test_kit.nodes[1], test_kit.nodes[2], amt) + .await; + + bob_to_carol = (bob_to_carol.0 / 2, bob_to_carol.1 + amt); + expected_balances = vec![alice_to_bob, bob_to_carol, carol_to_dave]; + assert_eq!(test_kit.channel_balances().await, expected_balances); + + // When we push this amount a second time, all the liquidity should be moved to Carol's end. + let _ = test_kit + .send_test_payemnt(test_kit.nodes[1], test_kit.nodes[2], amt) + .await; + bob_to_carol = (0, chan_capacity); + expected_balances = vec![alice_to_bob, bob_to_carol, carol_to_dave]; + assert_eq!(test_kit.channel_balances().await, expected_balances); + + // Finally, we'll test a multi-hop failure by trying to send from Alice -> Dave. Since Bob's liquidity is + // drained, we expect a failure and unchanged balances along the route. + let _ = test_kit + .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], 20_000) + .await; + assert_eq!(test_kit.channel_balances().await, expected_balances); + + test_kit.shutdown.trigger(); + test_kit.graph.wait_for_shutdown().await; + } + + /// Tests successful dispatch of a multi-hop payment. + #[tokio::test] + async fn test_successful_multi_hop() { + let chan_capacity = 500_000_000; + let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await; + + // Send a payment that should succeed from Alice -> Dave. + let amt = 20_000; + let route = test_kit + .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], amt) + .await; + + let route_total = amt + route.get_total_fees(); + let hop_1_amt = amt + route.paths[0].hops[1].fee_msat; + + let expected_balances = vec![ + // The sending node should have pushed the amount + total fee to the intermediary. + (chan_capacity - route_total, route_total), + // The middle hop should include fees for the outgoing link. + (chan_capacity - hop_1_amt, hop_1_amt), + // The receiving node should have the payment amount pushed to them. + (chan_capacity - amt, amt), + ]; + assert_eq!(test_kit.channel_balances().await, expected_balances); + + test_kit.shutdown.trigger(); + test_kit.graph.wait_for_shutdown().await; + } + + /// Tests success and failure for single hop payments, which are an edge case in our state machine. + #[tokio::test] + async fn test_single_hop_payments() { + let chan_capacity = 500_000_000; + let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await; + + // Send a single hop payment from Alice -> Bob, it will succeed because Alice has all the liquidity. + let amt = 150_000; + let _ = test_kit + .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[1], amt) + .await; + + let expected_balances = vec![ + (chan_capacity - amt, amt), + (chan_capacity, 0), + (chan_capacity, 0), + ]; + assert_eq!(test_kit.channel_balances().await, expected_balances); + + // Send a single hop payment from Dave -> Carol that will fail due to lack of liquidity, balances should be + // unchanged. + let _ = test_kit + .send_test_payemnt(test_kit.nodes[3], test_kit.nodes[2], amt) + .await; + + assert_eq!(test_kit.channel_balances().await, expected_balances); + + test_kit.shutdown.trigger(); + test_kit.graph.wait_for_shutdown().await; + } + + /// Tests failing back of multi-hop payments at various failure indexes. + #[tokio::test] + async fn test_multi_hop_faiulre() { + let chan_capacity = 500_000_000; + let mut test_kit = DispatchPaymentTestKit::new(chan_capacity).await; + + // Drain liquidity between Bob and Carol to force failures on Bob's outgoing linke. + test_kit + .set_channel_balance(&ShortChannelID::from(1), (0, chan_capacity)) + .await; + + let mut expected_balances = + vec![(chan_capacity, 0), (0, chan_capacity), (chan_capacity, 0)]; + assert_eq!(test_kit.channel_balances().await, expected_balances); + + // Send a payment from Alice -> Dave which we expect to fail leaving balances unaffected. + let amt = 150_000; + let _ = test_kit + .send_test_payemnt(test_kit.nodes[0], test_kit.nodes[3], amt) + .await; + + assert_eq!(test_kit.channel_balances().await, expected_balances); + + // Push liquidity to Dave so that we can send a payment which will fail on Bob's outgoing link, leaving + // balances unaffected. + expected_balances[2] = (0, chan_capacity); + test_kit + .set_channel_balance(&ShortChannelID::from(2), (0, chan_capacity)) + .await; + + let _ = test_kit + .send_test_payemnt(test_kit.nodes[3], test_kit.nodes[0], amt) + .await; + + assert_eq!(test_kit.channel_balances().await, expected_balances); + + test_kit.shutdown.trigger(); + test_kit.graph.wait_for_shutdown().await; + } +}