From 67f810e7206595e3b1102e1f53a890732da673e3 Mon Sep 17 00:00:00 2001 From: Sergi Delgado Segura Date: Fri, 16 Feb 2024 14:08:15 -0500 Subject: [PATCH] review suggestions --- sim-cli/src/main.rs | 25 +-- sim-lib/src/lib.rs | 43 ++++- sim-lib/src/sim_node.rs | 416 +++++++++++++++++----------------------- 3 files changed, 220 insertions(+), 264 deletions(-) diff --git a/sim-cli/src/main.rs b/sim-cli/src/main.rs index 474ff5c2..6e2e84b2 100644 --- a/sim-cli/src/main.rs +++ b/sim-cli/src/main.rs @@ -8,9 +8,7 @@ use rand::Rng; use sim_lib::{ cln::ClnNode, lnd::LndNode, - sim_node::{ - ln_node_from_graph, populate_network_graph, ChannelPolicy, SimGraph, SimulatedChannel, - }, + sim_node::{ChannelPolicy, SimulatedChannel}, ActivityDefinition, LightningError, LightningNode, NodeConnection, NodeId, SimParams, Simulation, WriteResults, }; @@ -204,28 +202,15 @@ async fn main() -> anyhow::Result<()> { None }; - let (shutdown_trigger, shutdown_listener) = triggered::trigger(); - - let channels = generate_sim_nodes(); - let graph = match SimGraph::new(channels.clone(), shutdown_trigger.clone()) { - Ok(graph) => Arc::new(Mutex::new(graph)), - Err(e) => anyhow::bail!("failed: {:?}", e), - }; - - let routing_graph = match populate_network_graph(channels) { - Ok(r) => r, - Err(e) => anyhow::bail!("failed: {:?}", e), - }; - - let sim = Simulation::new( - ln_node_from_graph(graph.clone(), Arc::new(routing_graph)).await, + let (sim, graph) = Simulation::from_sim_channels( + generate_sim_nodes(), validated_activities, cli.total_time, cli.expected_pmt_amt, cli.capacity_multiplier, write_results, - (shutdown_trigger, shutdown_listener), - ); + ) + .await?; let sim2 = sim.clone(); ctrlc::set_handler(move || { diff --git a/sim-lib/src/lib.rs b/sim-lib/src/lib.rs index 336a8dcb..ca9949bc 100644 --- a/sim-lib/src/lib.rs +++ b/sim-lib/src/lib.rs @@ -6,6 +6,7 @@ use lightning::ln::features::NodeFeatures; use lightning::ln::PaymentHash; use random_activity::RandomActivityError; use serde::{Deserialize, Serialize}; +use sim_node::{SimGraph, SimulatedChannel}; use std::collections::HashSet; use std::fmt::{Display, Formatter}; use std::marker::Send; @@ -21,6 +22,7 @@ use triggered::{Listener, Trigger}; use self::defined_activity::DefinedPaymentActivity; use self::random_activity::{NetworkGraphView, RandomPaymentActivity}; +use self::sim_node::{ln_node_from_graph, populate_network_graph}; pub mod cln; mod defined_activity; @@ -134,6 +136,8 @@ pub enum SimulationError { FileError, #[error("{0}")] RandomActivityError(RandomActivityError), + #[error("{0}")] + RandomGraphError(String), } #[derive(Debug, Error)] @@ -368,13 +372,14 @@ impl Simulation { expected_payment_msat: u64, activity_multiplier: f64, write_results: Option, - shutdown: (Trigger, Listener), ) -> Self { + let (shutdown_trigger, shutdown_listener) = triggered::trigger(); + Self { nodes, activity, - shutdown_trigger: shutdown.0, - shutdown_listener: shutdown.1, + shutdown_trigger, + shutdown_listener, total_time: total_time.map(|x| Duration::from_secs(x as u64)), expected_payment_msat, activity_multiplier, @@ -382,6 +387,38 @@ impl Simulation { } } + pub async fn from_sim_channels( + channels: Vec, + activity: Vec, + total_time: Option, + expected_payment_msat: u64, + activity_multiplier: f64, + write_results: Option, + ) -> Result<(Self, Arc>), SimulationError> { + let (shutdown_trigger, shutdown_listener) = triggered::trigger(); + + let sim_graph = SimGraph::new(channels.clone(), shutdown_trigger.clone()) + .map(|graph| Arc::new(Mutex::new(graph))) + .map_err(|e| SimulationError::RandomGraphError(e.err))?; + + let routing_graph = populate_network_graph(channels) + .map_err(|e| SimulationError::RandomGraphError(e.err))?; + + Ok(( + Self { + nodes: ln_node_from_graph(sim_graph.clone(), Arc::new(routing_graph)).await, + activity, + shutdown_trigger, + shutdown_listener, + total_time: total_time.map(|x| Duration::from_secs(x as u64)), + expected_payment_msat, + activity_multiplier, + write_results, + }, + sim_graph.clone(), + )) + } + /// validate_activity validates that the user-provided activity description is achievable for the network that /// we're working with. If no activity description is provided, then it ensures that we have configured a network /// that is suitable for random activity generation. diff --git a/sim-lib/src/sim_node.rs b/sim-lib/src/sim_node.rs index 46dde1b5..9b110a3e 100644 --- a/sim-lib/src/sim_node.rs +++ b/sim-lib/src/sim_node.rs @@ -16,11 +16,9 @@ use lightning::ln::msgs::{ use lightning::ln::{PaymentHash, PaymentPreimage}; use lightning::routing::gossip::{NetworkGraph, NodeId}; use lightning::routing::router::{ - find_route, Path, Payee, PaymentParameters, Route, RouteParameters, -}; -use lightning::routing::scoring::{ - ProbabilisticScorer, ProbabilisticScoringDecayParameters, ProbabilisticScoringFeeParameters, + find_route, Path, PaymentParameters, Route, RouteParameters, }; +use lightning::routing::scoring::ProbabilisticScorer; use lightning::routing::utxo::{UtxoLookup, UtxoResult}; use lightning::util::logger::{Level, Logger, Record}; use thiserror::Error; @@ -143,9 +141,7 @@ impl ChannelState { /// Returns the sum of all the *in flight outgoing* HTLCs on the channel. fn in_flight_total(&self) -> u64 { - self.in_flight - .iter() - .fold(0, |sum, val| sum + val.1.amount_msat) + self.in_flight.values().map(|h| h.amount_msat).sum() } /// Checks whether the proposed HTLC abides by the channel policy advertised for using this channel as the @@ -164,9 +160,7 @@ impl ChannelState { } // As u64 will round expected fee down to nearest msat (this is what the protocol dictates). - let expected_fee = (self.policy.base_fee as f64 - + ((self.policy.fee_rate_prop as f64 * amt as f64) / 1000000.0)) - as u64; + let expected_fee = self.policy.base_fee + ((self.policy.fee_rate_prop * amt) / 1000000); if fee < expected_fee { return Err(ForwardingError::InsufficientFee( fee, @@ -184,19 +178,22 @@ impl ChannelState { /// the addition of the HTLC. Specification sanity checks (such as reasonable CLTV) are also included, as this /// is where we'd check it in real life. fn check_outgoing_addition(&self, htlc: &Htlc) -> Result<(), ForwardingError> { - // Fails if the value provided fails its inequality check without policy. macro_rules! fail_policy_inequality { - ($value:expr, $op:tt, $field:ident, $error_variant:ident) => { - if $value $op self.policy.$field { - return Err(ForwardingError::$error_variant( - $value, - self.policy.$field, - )); - } - }; - } + ($value:expr, $op:tt, $field:ident, $error_variant:ident $(, $opt:ident)*) => { + if $value $op self.policy.$field { + return Err(ForwardingError::$error_variant( + $value, + self.policy.$field, + $( + self.policy.$opt, + )* + )); + } + }; + } fail_policy_inequality!(htlc.amount_msat, >, max_htlc_size_msat, MoreThanMaximum); + fail_policy_inequality!(htlc.amount_msat, >, max_htlc_size_msat, InsufficientFee, max_htlc_size_msat, max_htlc_size_msat); fail_policy_inequality!(htlc.amount_msat, <, min_htlc_size_msat, LessThanMinimum); fail_policy_inequality!(self.in_flight.len() as u64 + 1, >, max_htlc_count, ExceedsInFlightCount); fail_policy_inequality!( @@ -227,14 +224,13 @@ impl ChannelState { fn add_outgoing_htlc(&mut self, htlc: Htlc) -> Result<(), ForwardingError> { self.check_outgoing_addition(&htlc)?; - match self.in_flight.get(&htlc.hash) { - Some(_) => Err(ForwardingError::PaymentHashExists(htlc.hash)), - None => { - self.local_balance_msat -= htlc.amount_msat; - self.in_flight.insert(htlc.hash, htlc); - Ok(()) - } + if self.in_flight.get(&htlc.hash).is_some() { + return Err(ForwardingError::PaymentHashExists(htlc.hash)); } + self.local_balance_msat -= htlc.amount_msat; + self.in_flight.insert(htlc.hash, htlc); + + Ok(()) } /// Removes the HTLC from our set of outgoing in-flight HTLCs, failing if the payment hash is not found. If the @@ -250,7 +246,7 @@ impl ChannelState { Some(v) => { // If the HTLC failed, pending balance returns to local balance. if !success { - self.local_balance_msat += v.amount_msat + self.local_balance_msat += v.amount_msat; } Ok(v) @@ -300,23 +296,38 @@ impl SimulatedChannel { } } - /// Adds a htlc to the appropriate side of the simulated channel, checking its policy and balance are okay. - fn add_htlc(&mut self, node: PublicKey, htlc: Htlc) -> Result<(), ForwardingError> { - if htlc.amount_msat == 0 { - return Err(ForwardingError::ZeroAmountHtlc); + fn get_src(&self, src_pk: PublicKey) -> Result<&ChannelState, ForwardingError> { + if src_pk == self.node_1.policy.pubkey { + Ok(&self.node_1) + } else if src_pk == self.node_2.policy.pubkey { + Ok(&self.node_2) + } else { + Err(ForwardingError::NodeNotFound(src_pk)) } + } - if node == self.node_1.policy.pubkey { - self.node_1.add_outgoing_htlc(htlc)?; - return self.sanity_check(); + fn get_src_dst_mut( + &mut self, + src_pk: PublicKey, + ) -> Result<(&mut ChannelState, &mut ChannelState), ForwardingError> { + if src_pk == self.node_1.policy.pubkey { + Ok((&mut self.node_1, &mut self.node_2)) + } else if src_pk == self.node_2.policy.pubkey { + Ok((&mut self.node_2, &mut self.node_1)) + } else { + Err(ForwardingError::NodeNotFound(src_pk)) } + } - if node == self.node_2.policy.pubkey { - self.node_2.add_outgoing_htlc(htlc)?; - return self.sanity_check(); + /// Adds a htlc to the appropriate side of the simulated channel, checking its policy and balance are okay. + fn add_htlc(&mut self, node: PublicKey, htlc: Htlc) -> Result<(), ForwardingError> { + if htlc.amount_msat == 0 { + return Err(ForwardingError::ZeroAmountHtlc); } - Err(ForwardingError::NodeNotFound(node)) + let (src, _) = self.get_src_dst_mut(node)?; + src.add_outgoing_htlc(htlc)?; + self.sanity_check() } /// Performs a sanity check on the total balances in a channel. Note that we do not currently include on-chain @@ -344,35 +355,14 @@ impl SimulatedChannel { hash: PaymentHash, success: bool, ) -> Result<(), ForwardingError> { - // Removes the HTLC from the node that it was added to as an outgoing HTLC. If it succeeded, move the balance - // over to the other side of the channel. The HTLC removal will handle returning balance to the local channel - // in the case of a failure. - macro_rules! process_outgoing_htlc { - ($self:ident, $sender:ident, $receiver:ident, $hash:expr, $success:expr) => { - match $self.$sender.remove_outgoing_htlc($hash, $success){ - // If the HTLC was settled, its amount is transferred to the remote party's local balance. - // If it was failed, the above removal has already dealt with balance management. - Ok(htlc) => { - if $success { - $self.$receiver.local_balance_msat += htlc.amount_msat; - } - - return $self.sanity_check(); - }, - Err(e) => Err(e), - } - }; - } - - if incoming_node == self.node_1.policy.pubkey { - return process_outgoing_htlc!(self, node_1, node_2, hash, success); + let (src, dst) = self.get_src_dst_mut(incoming_node)?; + let htlc = src.remove_outgoing_htlc(hash, success)?; + if success { + dst.local_balance_msat += htlc.amount_msat } - if incoming_node == self.node_2.policy.pubkey { - return process_outgoing_htlc!(self, node_2, node_1, hash, success); - } + self.sanity_check() - Err(ForwardingError::NodeNotFound(incoming_node)) } /// Checks a htlc forward against the outgoing policy of the node provided. @@ -383,19 +373,7 @@ impl SimulatedChannel { amount_msat: u64, fee_msat: u64, ) -> Result<(), ForwardingError> { - if node == self.node_1.policy.pubkey { - return self - .node_1 - .check_htlc_forward(cltv_delta, amount_msat, fee_msat); - } - - if node == self.node_2.policy.pubkey { - return self - .node_2 - .check_htlc_forward(cltv_delta, amount_msat, fee_msat); - } - - Err(ForwardingError::NodeNotFound(node)) + self.get_src(node)?.check_htlc_forward(cltv_delta, amount_msat, fee_msat) } } @@ -464,34 +442,24 @@ fn node_info(pk: PublicKey) -> NodeInfo { /// Uses LDK's pathfinding algorithm with default parameters to find a path from source to destination, with no /// restrictions on fee budget. fn find_payment_route( - source: PublicKey, + source: &PublicKey, dest: PublicKey, amount_msat: u64, pathfinding_graph: &NetworkGraph<&WrappedLog>, ) -> Result { - let params = ProbabilisticScoringDecayParameters::default(); - let scorer = ProbabilisticScorer::new(params, pathfinding_graph, &WrappedLog {}); + let scorer = ProbabilisticScorer::new(Default::default(), pathfinding_graph, &WrappedLog {}); find_route( - &source, + source, &RouteParameters { - payment_params: PaymentParameters { - payee: Payee::Clear { - node_id: dest, - route_hints: Vec::new(), - features: None, - // We don't currently bother with final CLTV delta. - final_cltv_expiry_delta: 0, - }, - expiry_time: None, - max_total_cltv_expiry_delta: u32::MAX, + payment_params: + // We don't currently bother with final CLTV delta. + PaymentParameters::from_node_id(dest, 0) + .with_max_total_cltv_expiry_delta(u32::MAX) // TODO: set non-zero value to support MPP. - max_path_count: 1, + .with_max_path_count(1) // Allow sending htlcs up to 50% of the channel's capacity. - max_channel_saturation_power_of_half: 1, - previously_failed_channels: Vec::new(), - previously_failed_blinded_path_idxs: Vec::new(), - }, + .with_max_channel_saturation_power_of_half(1), final_value_msat: amount_msat, max_total_routing_fee_msat: None, }, @@ -499,7 +467,7 @@ fn find_payment_route( None, &WrappedLog {}, &scorer, - &ProbabilisticScoringFeeParameters::default(), + &Default::default(), &[0; 32], ) } @@ -531,7 +499,7 @@ impl LightningNode for SimNode<'_, T> { self.in_flight.insert(payment_hash, receiver); let route = match find_payment_route( - self.info.pubkey, + &self.info.pubkey, dest, amount_msat, &self.pathfinding_graph, @@ -569,7 +537,7 @@ impl LightningNode for SimNode<'_, T> { hash: PaymentHash, listener: Listener, ) -> Result { - match self.in_flight.get_mut(&hash) { + match self.in_flight.remove(&hash) { Some(receiver) => { select! { biased; @@ -577,7 +545,6 @@ impl LightningNode for SimNode<'_, T> { // If we get a payment result back, remove from our in flight set of payments and return the result. res = receiver => { - self.in_flight.remove(&hash); res.map_err(|e| LightningError::TrackPaymentError(format!("channel receive err: {}", e)))? }, } @@ -629,22 +596,17 @@ impl SimGraph { let mut nodes: HashMap> = HashMap::new(); let mut channels = HashMap::new(); - for channel in graph_channels.iter() { - channels.insert(channel.short_channel_id, channel.clone()); - - macro_rules! insert_node_entry { - ($pubkey:expr) => {{ - match nodes.entry($pubkey) { - Entry::Occupied(o) => o.into_mut().push(channel.capacity_msat), - Entry::Vacant(v) => { - v.insert(vec![channel.capacity_msat]); - } + for channel in graph_channels.into_iter() { + for pubkey in [channel.node_1.policy.pubkey, channel.node_2.policy.pubkey]{ + match nodes.entry(pubkey) { + Entry::Occupied(o) => o.into_mut().push(channel.capacity_msat), + Entry::Vacant(v) => { + v.insert(vec![channel.capacity_msat]); } - }}; - } + } + } - insert_node_entry!(channel.node_1.policy.pubkey); - insert_node_entry!(channel.node_2.policy.pubkey); + channels.insert(channel.short_channel_id, channel); } Ok(SimGraph { @@ -729,32 +691,27 @@ pub fn populate_network_graph<'a>( graph.update_channel_from_unsigned_announcement(&announcement, &Some(&utxo_validator))?; - macro_rules! generate_and_update_channel { - ($node:expr, $flags:expr) => {{ - let update = UnsignedChannelUpdate { - chain_hash, - short_channel_id: channel.short_channel_id, - timestamp: SystemTime::now() - .duration_since(UNIX_EPOCH) - .unwrap() - .as_secs() as u32, - flags: $flags, - cltv_expiry_delta: $node.policy.cltv_expiry_delta as u16, - htlc_minimum_msat: $node.policy.min_htlc_size_msat, - htlc_maximum_msat: $node.policy.max_htlc_size_msat, - fee_base_msat: $node.policy.base_fee as u32, - fee_proportional_millionths: $node.policy.fee_rate_prop as u32, - excess_data: Vec::new(), - }; - - graph.update_channel_unsigned(&update)?; - }}; - } - // The least significant bit of the channel flag field represents the direction that the channel update // applies to. This value is interpreted as node_1 if it is zero, and node_2 otherwise. - generate_and_update_channel!(channel.node_1, 0); - generate_and_update_channel!(channel.node_2, 1); + for (i, node) in [channel.node_1, channel.node_2].iter().enumerate() { + let update = UnsignedChannelUpdate { + chain_hash, + short_channel_id: channel.short_channel_id, + timestamp: SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() as u32, + flags: i as u8, + cltv_expiry_delta: node.policy.cltv_expiry_delta as u16, + htlc_minimum_msat: node.policy.min_htlc_size_msat, + htlc_maximum_msat: node.policy.max_htlc_size_msat, + fee_base_msat: node.policy.base_fee as u32, + fee_proportional_millionths: node.policy.fee_rate_prop as u32, + excess_data: Vec::new(), + }; + + graph.update_channel_unsigned(&update)?; + } } Ok(graph) @@ -802,12 +759,7 @@ impl SimNetwork for SimGraph { /// lookup_node fetches a node's information and channel capacities. async fn lookup_node(&self, node: &PublicKey) -> Result<(NodeInfo, Vec), LightningError> { - match self.nodes.get(node) { - Some(channels) => Ok((node_info(*node), channels.clone())), - None => Err(LightningError::GetNodeInfoError( - "Node not found".to_string(), - )), - } + self.nodes.get(node).map(|channels| (node_info(*node), channels.clone())).ok_or(LightningError::GetNodeInfoError("Node not found".to_string())) } } @@ -836,61 +788,51 @@ async fn add_htlcs( ) -> Result<(), (Option, ForwardingError)> { let mut outgoing_node = source; let mut outgoing_amount = route.fee_msat() + route.final_value_msat(); - let mut outgoing_cltv = route - .hops - .iter() - .fold(0, |sum, value| sum + value.cltv_expiry_delta); + let mut outgoing_cltv = route.hops.iter().map(|hop| hop.cltv_expiry_delta).sum(); let mut fail_idx = None; + let mut node_lock = nodes.lock().await; for (i, hop) in route.hops.iter().enumerate() { - let mut node_lock = nodes.lock().await; - match node_lock.get_mut(&hop.short_channel_id) { - Some(channel) => { - if let Err(e) = channel.add_htlc( - outgoing_node, - Htlc { - amount_msat: outgoing_amount, - cltv_expiry: outgoing_cltv, - hash: payment_hash, - }, - ) { - // If we couldn't add to this HTLC, we only need to fail back from the preceeding hop, so we don't - // have to progress our fail_idx. - return Err((fail_idx, e)); - } - - // If the HTLC was successfully added, then we'll need to remove the HTLC from this channel if we fail, - // so we progress our failure index to include this node. - fail_idx = Some(i); - - // Once we've added the HTLC on this hop's channel, we want to check whether it has sufficient fee - // and CLTV delta per the _next_ channel's policy (because fees and CLTV delta in LN are charged on - // the outgoing link). We check the policy belonging to the node that we just forwarded to, which - // represents the fee in that direction. - // - // Note that we don't check the final hop's requirements for CLTV delta at present. - if i != route.hops.len() - 1 { - if let Some(channel) = node_lock.get(&route.hops[i + 1].short_channel_id) { - if let Err(e) = channel.check_htlc_forward( - hop.pubkey, - hop.cltv_expiry_delta, - outgoing_amount - hop.fee_msat, - hop.fee_msat, - ) { - // If we haven't met forwarding conditions for the next channel's policy, then we fail at - // the current index, because we've already added the HTLC as outgoing. - return Err((fail_idx, e)); - } - } + if let Some(channel) = node_lock.get_mut(&hop.short_channel_id) { + channel.add_htlc( + outgoing_node, + Htlc { + amount_msat: outgoing_amount, + cltv_expiry: outgoing_cltv, + hash: payment_hash, + }, + // If we couldn't add to this HTLC, we only need to fail back from the preceeding hop, so we don't + // have to progress our fail_idx. + ).map_err(|e| (fail_idx, e))?; + + // If the HTLC was successfully added, then we'll need to remove the HTLC from this channel if we fail, + // so we progress our failure index to include this node. + fail_idx = Some(i); + + // Once we've added the HTLC on this hop's channel, we want to check whether it has sufficient fee + // and CLTV delta per the _next_ channel's policy (because fees and CLTV delta in LN are charged on + // the outgoing link). We check the policy belonging to the node that we just forwarded to, which + // represents the fee in that direction. + // + // Note that we don't check the final hop's requirements for CLTV delta at present. + if i != route.hops.len() - 1 { + if let Some(channel) = node_lock.get_mut(&route.hops[i + 1].short_channel_id) { + channel.check_htlc_forward( + hop.pubkey, + hop.cltv_expiry_delta, + outgoing_amount - hop.fee_msat, + hop.fee_msat, + // If we haven't met forwarding conditions for the next channel's policy, then we fail at + // the current index, because we've already added the HTLC as outgoing. + ).map_err(|e| (fail_idx, e))?; } } - None => { - return Err(( - fail_idx, - ForwardingError::ChannelNotFound(hop.short_channel_id), - )) - } + } else { + return Err(( + fail_idx, + ForwardingError::ChannelNotFound(hop.short_channel_id), + )) } // Once we've taken the "hop" to the destination pubkey, it becomes the source of the next outgoing htlc. @@ -921,9 +863,8 @@ async fn remove_htlcs( payment_hash: PaymentHash, success: bool, ) -> Result<(), ForwardingError> { - for i in (0..resolution_idx).rev() { - let hop = &route.hops[i]; - + let mut locked_nodes = nodes.lock().await; + for (i, hop) in route.hops[0..resolution_idx].iter().rev().enumerate() { // When we add HTLCs, we do so on the state of the node that sent the htlc along the channel so we need to // look up our incoming node so that we can remove it when we go backwards. For the first htlc, this is just // the sending node, otherwise it's the hop before. @@ -933,7 +874,7 @@ async fn remove_htlcs( route.hops[i - 1].pubkey }; - match nodes.lock().await.get_mut(&hop.short_channel_id) { + match locked_nodes.get_mut(&hop.short_channel_id) { Some(channel) => channel.remove_htlc(incoming_node, payment_hash, success)?, None => return Err(ForwardingError::ChannelNotFound(hop.short_channel_id)), } @@ -957,60 +898,53 @@ async fn propagate_payment( let preimage_bytes = Sha256::hash(&preimage.0[..]).to_byte_array(); let payment_hash = PaymentHash(preimage_bytes); - let notify_result = match add_htlcs(nodes.clone(), source, route.clone(), payment_hash).await { - // If we successfully added the htlc, go ahead and remove all the htlcs in the route with successful resolution. - Ok(_) => { - if let Err(e) = remove_htlcs( - nodes, - route.hops.len() - 1, - source, - route, - payment_hash, - true, - ) - .await + let notify_result = if let Err((fail_idx, err)) = add_htlcs(nodes.clone(), source, route.clone(), payment_hash).await { + if err.is_critical() { + shutdown.trigger(); + } + + if let Some(resolution_idx) = fail_idx { + if let Err(e) = + remove_htlcs(nodes, resolution_idx, source, route, payment_hash, false).await { if e.is_critical() { shutdown.trigger(); } - - log::error!("Could not remove htlcs from channel: {e}."); } + } - PaymentResult { - htlc_count: 1, - payment_outcome: PaymentOutcome::Success, - } + // We have more information about failures because we're in control of the whole route, so we log the + // actual failure reason and then fail back with unknown failure type. + log::debug!( + "Forwarding failure for simulated payment {}: {err}", + hex::encode(payment_hash.0) + ); + + PaymentResult { + htlc_count: 0, + payment_outcome: PaymentOutcome::Unknown, } - // If we partially added HTLCs along the route, we need to fail them back to the source to clean up our - // partial state. It's possible that we failed with the very first add, and then we don't need to clean - // anything up. - Err((fail_idx, err)) => { - if err.is_critical() { + } else { + if let Err(e) = remove_htlcs( + nodes, + route.hops.len() - 1, + source, + route, + payment_hash, + true, + ) + .await + { + if e.is_critical() { shutdown.trigger(); } - if let Some(resolution_idx) = fail_idx { - if let Err(e) = - remove_htlcs(nodes, resolution_idx, source, route, payment_hash, false).await - { - if e.is_critical() { - shutdown.trigger(); - } - } - } - - // We have more information about failures because we're in control of the whole route, so we log the - // actual failure reason and then fail back with unknown failure type. - log::debug!( - "Forwarding failure for simulated payment {}: {err}", - hex::encode(payment_hash.0) - ); + log::error!("Could not remove htlcs from channel: {e}."); + } - PaymentResult { - htlc_count: 0, - payment_outcome: PaymentOutcome::Unknown, - } + PaymentResult { + htlc_count: 1, + payment_outcome: PaymentOutcome::Success, } };