Skip to content

Commit

Permalink
feat(devSrv): Prefer to connect on last known endpoint first
Browse files Browse the repository at this point in the history
  • Loading branch information
dr1rrb committed Aug 1, 2024
1 parent cf4fc16 commit c8ebb9d
Showing 1 changed file with 60 additions and 26 deletions.
86 changes: 60 additions & 26 deletions src/Uno.UI.RemoteControl/RemoteControlClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Data.Common;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
using System.IO;
using System.Linq;
Expand All @@ -21,6 +22,7 @@
using Uno.UI.RemoteControl.HotReload.Messages;
using Uno.UI.RemoteControl.Messages;
using Windows.Networking.Sockets;
using Windows.Storage;
using static Uno.UI.RemoteControl.RemoteControlStatus;

namespace Uno.UI.RemoteControl;
Expand Down Expand Up @@ -205,7 +207,7 @@ public void RegisterPreProcessor(IRemoteControlPreProcessor preprocessor)
{
if (this.Log().IsEnabled(LogLevel.Warning))
{
this.Log().LogWarning($"No server addresses provided, skipping.");
this.Log().LogWarning("No server addresses provided, skipping.");
}

_status.Report(ConnectionState.NoServer);
Expand All @@ -220,27 +222,38 @@ public void RegisterPreProcessor(IRemoteControlPreProcessor preprocessor)

_status.Report(ConnectionState.Connecting);


const string lastEndpointKey = "__UNO__" + nameof(RemoteControlClient) + "__last_endpoint";
var preferred = ApplicationData.Current.LocalSettings.Values.TryGetValue(lastEndpointKey, out var lastValue) && lastValue is string lastEp
? _serverAddresses.FirstOrDefault(srv => srv.endpoint.Equals(lastEp, StringComparison.OrdinalIgnoreCase)).endpoint
: default;
var pending = _serverAddresses
.Where(adr => adr.port != 0 || Uri.TryCreate(adr.endpoint, UriKind.Absolute, out _))
.Select(s =>
.Select(srv =>
{
var cts = new CancellationTokenSource();
var task = Connect(s.endpoint, s.port, isHttps, cts.Token);
if (TryParse(srv.endpoint, srv.port, isHttps, out var serverUri))
{
var cts = new CancellationTokenSource();
var delay = preferred is null || preferred.Equals(srv.endpoint, StringComparison.OrdinalIgnoreCase) ? 0 : 1000;
var task = Connect(serverUri, delay, cts.Token);

return (task, cts);
return (task, srv.endpoint, cts);
}

return default;
})
.Where(c => c.task is not null)
.ToDictionary(c => c.task as Task);
var timeout = Task.Delay(30000);

// Ensure to await all connection tasks to avoid UnobservedTaskException
CleanupConnections(pending.Keys);

// Wait for the first connection to succeed
Connection? connected = default;
while (connected is null && pending is { Count: > 0 })
Connection? connection = default;
while (connection is null && pending is { Count: > 0 })
{
var completed = await Task.WhenAny(pending.Keys.Concat(timeout));
if (completed == timeout)
var task = await Task.WhenAny(pending.Keys.Concat(timeout));
if (task == timeout)
{
if (this.Log().IsEnabled(LogLevel.Error))
{
Expand All @@ -254,23 +267,26 @@ public void RegisterPreProcessor(IRemoteControlPreProcessor preprocessor)
}

// Remove the completed task from the pending list, no matter its completion status
pending.Remove(completed);
var (_, endpoint, _) = pending[task];
pending.Remove(task);

// If the connection is successful, break the loop
if (completed.IsCompleted
&& ((Task<Connection?>)completed).Result is { Socket: not null } successfulConnection)
if (task.IsCompleted
&& ((Task<Connection?>)task).Result is { Socket: not null } successfulConnection)
{
connected = successfulConnection;
ApplicationData.Current.LocalSettings.Values[lastEndpointKey] = endpoint;

connection = successfulConnection;
break;
}
}

// Abort all other pending connections
AbortPending();

_status.ReportActiveConnection(connected);
_status.ReportActiveConnection(connection);

if (connected is null)
if (connection is null)
{
if (this.Log().IsEnabled(LogLevel.Error))
{
Expand All @@ -283,19 +299,19 @@ public void RegisterPreProcessor(IRemoteControlPreProcessor preprocessor)
{
if (this.Log().IsEnabled(LogLevel.Debug))
{
this.Log().LogDebug($"Connected to {connected!.EndPoint}");
this.Log().LogDebug($"Connected to {connection!.EndPoint}");
}

connected.EnsureActive();
connection.EnsureActive();

return connected;
return connection;
}

void AbortPending()
{
foreach (var connection in pending.Values)
{
connection.cts.Cancel();
connection.cts.Cancel(throwOnFirstException: false);
if (connection is { task.Status: TaskStatus.RanToCompletion, task.Result.Socket: { } socket })
{
_ = socket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None);
Expand All @@ -321,11 +337,10 @@ void AbortPending()
}
}

private async Task<Connection?> Connect(string endpoint, int port, bool isHttps, CancellationToken ct)
private bool TryParse(string endpoint, int port, bool isHttps, [NotNullWhen(true)] out Uri? serverUri)
{
try
{
Uri serverUri;
if (Uri.TryCreate(endpoint, UriKind.Absolute, out var fullUri))
{
var wsScheme = fullUri.Scheme switch
Expand Down Expand Up @@ -353,13 +368,18 @@ void AbortPending()

serverUri = new Uri($"wss://{endpoint}/rc");
}
else
else if (port is not 0)
{
var scheme = isHttps ? "wss" : "ws";
serverUri = new Uri($"{scheme}://{endpoint}:{port}/rc");
}
else
{
serverUri = default;
return false;
}

return await Connect(serverUri, ct);
return true;
}
catch (Exception e)
{
Expand All @@ -368,17 +388,31 @@ void AbortPending()
this.Log().Trace($"Connecting to [{endpoint}:{port}] failed: {e.Message}");
}

return null;
serverUri = default;
return false;
}
}

private async Task<Connection> Connect(Uri serverUri, CancellationToken ct)
private Task<Connection> Connect(Uri serverUri, CancellationToken ct)
=> Connect(serverUri, 0, ct);

private async Task<Connection> Connect(Uri serverUri, int delay, CancellationToken ct)
{
// Note: This method **MUST NOT** throw any exception as it being used for re-connection

var watch = Stopwatch.StartNew();
try
{
if (delay > 0)
{
await Task.Delay(delay, ct);
}

if (ct.IsCancellationRequested)
{
return new(this, serverUri, watch, null);
}

if (this.Log().IsEnabled(LogLevel.Trace))
{
this.Log().Trace($"Connecting to [{serverUri}]");
Expand Down

0 comments on commit c8ebb9d

Please sign in to comment.