Skip to content

Commit

Permalink
Bug 1868177: Add ProofOfRef arguments to WebSocketImpl methods r=smaug
Browse files Browse the repository at this point in the history
  • Loading branch information
jesup committed Dec 4, 2023
1 parent e50c3a8 commit d976964
Showing 1 changed file with 70 additions and 49 deletions.
119 changes: 70 additions & 49 deletions dom/websocket/WebSocket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,14 @@ class WebSocketImpl final : public nsIInterfaceRequestor,
nsICookieJarSettings* aCookieJarSettings);

// These methods when called can release the WebSocket object
void FailConnection(uint16_t reasonCode,
void FailConnection(const RefPtr<WebSocketImpl>& aProofOfRef,
uint16_t reasonCode,
const nsACString& aReasonString = ""_ns);
nsresult CloseConnection(uint16_t reasonCode,
nsresult CloseConnection(const RefPtr<WebSocketImpl>& aProofOfRef,
uint16_t reasonCode,
const nsACString& aReasonString = ""_ns);
void Disconnect();
void DisconnectInternal();
void Disconnect(const RefPtr<WebSocketImpl>& aProofOfRef);
void DisconnectInternal(const RefPtr<WebSocketImpl>& aProofOfRef);

nsresult ConsoleError();
void PrintErrorOnConsole(const char* aBundleURI, const char* aError,
Expand All @@ -185,7 +187,7 @@ class WebSocketImpl final : public nsIInterfaceRequestor,
nsresult ScheduleConnectionCloseEvents(nsISupports* aContext,
nsresult aStatusCode);
// 2nd half of ScheduleConnectionCloseEvents, run in its own event.
void DispatchConnectionCloseEvents();
void DispatchConnectionCloseEvents(const RefPtr<WebSocketImpl>& aProofOfRef);

nsresult UpdateURI();

Expand Down Expand Up @@ -265,7 +267,8 @@ class WebSocketImpl final : public nsIInterfaceRequestor,

// If we threw during Init we never called disconnect
if (!mDisconnectingOrDisconnected) {
Disconnect();
RefPtr<WebSocketImpl> self(this);
Disconnect(self);
}
}
};
Expand Down Expand Up @@ -306,7 +309,7 @@ class CallDispatchConnectionCloseEvents final : public DiscardableRunnable {

NS_IMETHOD Run() override {
mWebSocketImpl->AssertIsOnTargetThread();
mWebSocketImpl->DispatchConnectionCloseEvents();
mWebSocketImpl->DispatchConnectionCloseEvents(mWebSocketImpl);
return NS_OK;
}

Expand Down Expand Up @@ -452,12 +455,12 @@ class MOZ_STACK_CLASS MaybeDisconnect {
}

if (toDisconnect) {
mImpl->Disconnect();
mImpl->Disconnect(mImpl);
}
}

private:
WebSocketImpl* mImpl;
RefPtr<WebSocketImpl> mImpl;
};

class CloseConnectionRunnable final : public Runnable {
Expand All @@ -470,7 +473,7 @@ class CloseConnectionRunnable final : public Runnable {
mReasonString(aReasonString) {}

NS_IMETHOD Run() override {
return mImpl->CloseConnection(mReasonCode, mReasonString);
return mImpl->CloseConnection(mImpl, mReasonCode, mReasonString);
}

private:
Expand All @@ -481,8 +484,9 @@ class CloseConnectionRunnable final : public Runnable {

} // namespace

nsresult WebSocketImpl::CloseConnection(uint16_t aReasonCode,
const nsACString& aReasonString) {
nsresult WebSocketImpl::CloseConnection(
const RefPtr<WebSocketImpl>& aProofOfRef, uint16_t aReasonCode,
const nsACString& aReasonString) {
if (!IsTargetThread()) {
nsCOMPtr<nsIRunnable> runnable =
new CloseConnectionRunnable(this, aReasonCode, aReasonString);
Expand Down Expand Up @@ -564,7 +568,8 @@ nsresult WebSocketImpl::ConsoleError() {
return NS_OK;
}

void WebSocketImpl::FailConnection(uint16_t aReasonCode,
void WebSocketImpl::FailConnection(const RefPtr<WebSocketImpl>& aProofOfRef,
uint16_t aReasonCode,
const nsACString& aReasonString) {
AssertIsOnTargetThread();

Expand All @@ -574,7 +579,7 @@ void WebSocketImpl::FailConnection(uint16_t aReasonCode,

ConsoleError();
mFailed = true;
CloseConnection(aReasonCode, aReasonString);
CloseConnection(aProofOfRef, aReasonCode, aReasonString);

if (NS_IsMainThread() && mImplProxy) {
mImplProxy->Disconnect();
Expand All @@ -586,36 +591,35 @@ namespace {

class DisconnectInternalRunnable final : public WorkerMainThreadRunnable {
public:
explicit DisconnectInternalRunnable(WebSocketImpl* aImpl)
explicit DisconnectInternalRunnable(const RefPtr<WebSocketImpl>& aImpl)
: WorkerMainThreadRunnable(GetCurrentThreadWorkerPrivate(),
"WebSocket :: disconnect"_ns),
mImpl(aImpl) {}

bool MainThreadRun() override {
mImpl->DisconnectInternal();
mImpl->DisconnectInternal(mImpl);
return true;
}

private:
// A raw pointer because this runnable is sync.
WebSocketImpl* mImpl;
RefPtr<WebSocketImpl> mImpl;
};

} // namespace

void WebSocketImpl::Disconnect() {
void WebSocketImpl::Disconnect(const RefPtr<WebSocketImpl>& aProofOfRef) {
MOZ_RELEASE_ASSERT(NS_IsMainThread() == mIsMainThread);

if (mDisconnectingOrDisconnected) {
return;
}

// DontKeepAliveAnyMore() and DisconnectInternal() can release the object. So
// hold a reference to this until the end of the method.
RefPtr<WebSocketImpl> kungfuDeathGrip = this;
// DontKeepAliveAnyMore() and DisconnectInternal() can release the
// object. aProofOfRef ensures we're holding a reference to this until
// the end of the method.

// Disconnect can be called from some control event (such as a callback from
// StrongWorkerRef). This will be schedulated before any other sync/async
// StrongWorkerRef). This will be scheduled before any other sync/async
// runnable. In order to prevent some double Disconnect() calls, we use this
// boolean.
mDisconnectingOrDisconnected = true;
Expand All @@ -624,7 +628,7 @@ void WebSocketImpl::Disconnect() {
// the main thread.

if (NS_IsMainThread()) {
DisconnectInternal();
DisconnectInternal(aProofOfRef);

// If we haven't called WebSocket::DisconnectFromOwner yet, update
// web socket count here.
Expand All @@ -633,7 +637,7 @@ void WebSocketImpl::Disconnect() {
}
} else {
RefPtr<DisconnectInternalRunnable> runnable =
new DisconnectInternalRunnable(this);
new DisconnectInternalRunnable(aProofOfRef);
ErrorResult rv;
runnable->Dispatch(Killing, rv);
// XXXbz this seems totally broken. We should be propagating this out, but
Expand All @@ -655,12 +659,13 @@ void WebSocketImpl::Disconnect() {
mWebSocket = nullptr;
}

void WebSocketImpl::DisconnectInternal() {
void WebSocketImpl::DisconnectInternal(
const RefPtr<WebSocketImpl>& aProofOfRef) {
AssertIsOnMainThread();

nsCOMPtr<nsILoadGroup> loadGroup = do_QueryReferent(mWeakLoadGroup);
if (loadGroup) {
loadGroup->RemoveRequest(this, nullptr, NS_OK);
loadGroup->RemoveRequest(aProofOfRef, nullptr, NS_OK);
// mWeakLoadGroup has to be release on main-thread because WeakReferences
// are not thread-safe.
mWeakLoadGroup = nullptr;
Expand Down Expand Up @@ -777,7 +782,8 @@ WebSocketImpl::OnStart(nsISupports* aContext) {
// Attempt to kill "ghost" websocket: but usually too early for check to fail
nsresult rv = mWebSocket->CheckCurrentGlobalCorrectness();
if (NS_FAILED(rv)) {
CloseConnection(nsIWebSocketChannel::CLOSE_GOING_AWAY);
RefPtr<WebSocketImpl> self(this);
CloseConnection(self, nsIWebSocketChannel::CLOSE_GOING_AWAY);
return rv;
}

Expand All @@ -798,7 +804,7 @@ WebSocketImpl::OnStart(nsISupports* aContext) {
mChannel->HttpChannelId());

// Let's keep the object alive because the webSocket can be CCed in the
// onopen callback.
// onopen callback
RefPtr<WebSocket> webSocket = mWebSocket;

// Call 'onopen'
Expand Down Expand Up @@ -909,10 +915,11 @@ WebSocketImpl::OnServerClose(nsISupports* aContext, uint16_t aCode,
// RFC 6455, 5.5.1: "When sending a Close frame in response, the endpoint
// typically echos the status code it received".
// But never send certain codes, per section 7.4.1
RefPtr<WebSocketImpl> self(this);
if (aCode == 1005 || aCode == 1006 || aCode == 1015) {
CloseConnection(0, ""_ns);
CloseConnection(self, 0, ""_ns);
} else {
CloseConnection(aCode, aReason);
CloseConnection(self, aCode, aReason);
}
} else {
// We initiated close, and server has replied: OnStop does rest of the work.
Expand All @@ -929,13 +936,14 @@ WebSocketImpl::OnError() {
NS_NewRunnableFunction("dom::FailConnectionRunnable",
[self = RefPtr{this}]() {
self->FailConnection(
nsIWebSocketChannel::CLOSE_ABNORMAL);
self, nsIWebSocketChannel::CLOSE_ABNORMAL);
}),
NS_DISPATCH_NORMAL);
}

AssertIsOnTargetThread();
FailConnection(nsIWebSocketChannel::CLOSE_ABNORMAL);
RefPtr<WebSocketImpl> self(this);
FailConnection(self, nsIWebSocketChannel::CLOSE_ABNORMAL);
return NS_OK;
}

Expand Down Expand Up @@ -1420,7 +1428,8 @@ already_AddRefed<WebSocket> WebSocket::ConstructorCommon(
// We don't return an error if the connection just failed. Instead we dispatch
// an event.
if (connectionFailed) {
webSocket->mImpl->FailConnection(nsIWebSocketChannel::CLOSE_ABNORMAL);
webSocketImpl->FailConnection(webSocketImpl,
nsIWebSocketChannel::CLOSE_ABNORMAL);
}

// If we don't have a channel, the connection is failed and onerror() will be
Expand All @@ -1439,11 +1448,12 @@ already_AddRefed<WebSocket> WebSocket::ConstructorCommon(
~ClearWebSocket() {
if (!mDone) {
mWebSocketImpl->mChannel = nullptr;
mWebSocketImpl->FailConnection(nsIWebSocketChannel::CLOSE_ABNORMAL);
mWebSocketImpl->FailConnection(mWebSocketImpl,
nsIWebSocketChannel::CLOSE_ABNORMAL);
}
}

WebSocketImpl* mWebSocketImpl;
RefPtr<WebSocketImpl> mWebSocketImpl;
bool mDone;
};

Expand Down Expand Up @@ -1531,7 +1541,8 @@ NS_IMPL_CYCLE_COLLECTION_TRAVERSE_END
NS_IMPL_CYCLE_COLLECTION_UNLINK_BEGIN_INHERITED(WebSocket, DOMEventTargetHelper)
if (tmp->mImpl) {
NS_IMPL_CYCLE_COLLECTION_UNLINK(mImpl->mChannel)
tmp->mImpl->Disconnect();
RefPtr<WebSocketImpl> pin(tmp->mImpl);
pin->Disconnect(pin);
MOZ_ASSERT(!tmp->mImpl);
}
NS_IMPL_CYCLE_COLLECTION_UNLINK_END
Expand All @@ -1555,7 +1566,8 @@ void WebSocket::DisconnectFromOwner() {
DOMEventTargetHelper::DisconnectFromOwner();

if (mImpl) {
mImpl->CloseConnection(nsIWebSocketChannel::CLOSE_GOING_AWAY);
RefPtr<WebSocketImpl> pin(mImpl);
pin->CloseConnection(pin, nsIWebSocketChannel::CLOSE_GOING_AWAY);
}

DontKeepAliveAnyMore();
Expand Down Expand Up @@ -1842,7 +1854,7 @@ class nsAutoCloseWS final {
~nsAutoCloseWS() {
if (!mWebSocketImpl->mChannel) {
mWebSocketImpl->CloseConnection(
nsIWebSocketChannel::CLOSE_INTERNAL_ERROR);
mWebSocketImpl, nsIWebSocketChannel::CLOSE_INTERNAL_ERROR);
}
}

Expand Down Expand Up @@ -1923,7 +1935,8 @@ nsresult WebSocketImpl::InitializeConnection(
return NS_OK;
}

void WebSocketImpl::DispatchConnectionCloseEvents() {
void WebSocketImpl::DispatchConnectionCloseEvents(
const RefPtr<WebSocketImpl>& aProofOfRef) {
AssertIsOnTargetThread();

if (mDisconnectingOrDisconnected) {
Expand All @@ -1933,7 +1946,7 @@ void WebSocketImpl::DispatchConnectionCloseEvents() {
mWebSocket->SetReadyState(WebSocket::CLOSED);

// Let's keep the object alive because the webSocket can be CCed in the
// onerror or in the onclose callback.
// onerror or in the onclose callback
RefPtr<WebSocket> webSocket = mWebSocket;

// Call 'onerror' if needed
Expand All @@ -1951,7 +1964,7 @@ void WebSocketImpl::DispatchConnectionCloseEvents() {
}

webSocket->UpdateMustKeepAlive();
Disconnect();
Disconnect(aProofOfRef);
}

nsresult WebSocket::CreateAndDispatchSimpleEvent(const nsAString& aName) {
Expand Down Expand Up @@ -2216,6 +2229,8 @@ void WebSocket::UpdateMustKeepAlive() {
if (mKeepingAlive && !shouldKeepAlive) {
mKeepingAlive = false;
mImpl->ReleaseObject();
// Note that this could be made 'alive' again if another listener is
// added.
} else if (!mKeepingAlive && shouldKeepAlive) {
mKeepingAlive = true;
mImpl->AddRefObject();
Expand Down Expand Up @@ -2249,7 +2264,7 @@ void WebSocketImpl::ReleaseObject() {
bool WebSocketImpl::RegisterWorkerRef(WorkerPrivate* aWorkerPrivate) {
MOZ_ASSERT(aWorkerPrivate);

RefPtr<WebSocketImpl> self = this;
RefPtr<WebSocketImpl> self(this);

// In workers we have to keep the worker alive using a strong reference in
// order to dispatch messages correctly.
Expand All @@ -2260,7 +2275,8 @@ bool WebSocketImpl::RegisterWorkerRef(WorkerPrivate* aWorkerPrivate) {
self->mWorkerShuttingDown = true;
}

self->CloseConnection(nsIWebSocketChannel::CLOSE_GOING_AWAY, ""_ns);
self->CloseConnection(self, nsIWebSocketChannel::CLOSE_GOING_AWAY,
""_ns);
});
if (NS_WARN_IF(!workerRef)) {
return false;
Expand Down Expand Up @@ -2519,14 +2535,17 @@ void WebSocket::Close(const Optional<uint16_t>& aCode,
return;
}

RefPtr<WebSocketImpl> impl = mImpl;
// These could cause the mImpl to be released (and so this to be
// released); make sure it stays valid through the call
RefPtr<WebSocketImpl> pin(mImpl);

if (readyState == CONNECTING) {
impl->FailConnection(closeCode, closeReason);
pin->FailConnection(pin, closeCode, closeReason);
return;
}

MOZ_ASSERT(readyState == OPEN);
impl->CloseConnection(closeCode, closeReason);
pin->CloseConnection(pin, closeCode, closeReason);
}

//-----------------------------------------------------------------------------
Expand All @@ -2550,7 +2569,8 @@ WebSocketImpl::Observe(nsISupports* aSubject, const char* aTopic,

if ((strcmp(aTopic, DOM_WINDOW_FROZEN_TOPIC) == 0) ||
(strcmp(aTopic, DOM_WINDOW_DESTROYED_TOPIC) == 0)) {
CloseConnection(nsIWebSocketChannel::CLOSE_GOING_AWAY);
RefPtr<WebSocketImpl> self(this);
CloseConnection(self, nsIWebSocketChannel::CLOSE_GOING_AWAY);
}

return NS_OK;
Expand Down Expand Up @@ -2649,7 +2669,8 @@ nsresult WebSocketImpl::CancelInternal() {
return NS_OK;
}

return CloseConnection(nsIWebSocketChannel::CLOSE_GOING_AWAY);
RefPtr<WebSocketImpl> self(this);
return CloseConnection(self, nsIWebSocketChannel::CLOSE_GOING_AWAY);
}

NS_IMETHODIMP
Expand Down

0 comments on commit d976964

Please sign in to comment.