Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Android] Remove repeated calls to beginHandshake #78849

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ jmethodID g_SSLEngineBeginHandshake;
jmethodID g_SSLEngineCloseOutbound;
jmethodID g_SSLEngineGetApplicationProtocol;
jmethodID g_SSLEngineGetHandshakeStatus;
jmethodID g_SSLEngineGetHandshakeSession;
jmethodID g_SSLEngineGetSession;
jmethodID g_SSLEngineGetSSLParameters;
jmethodID g_SSLEngineGetSupportedProtocols;
Expand Down Expand Up @@ -1002,6 +1003,7 @@ JNI_OnLoad(JavaVM *vm, void *reserved)
g_SSLEngineGetApplicationProtocol = GetOptionalMethod(env, false, g_SSLEngine, "getApplicationProtocol", "()Ljava/lang/String;");
g_SSLEngineGetHandshakeStatus = GetMethod(env, false, g_SSLEngine, "getHandshakeStatus", "()Ljavax/net/ssl/SSLEngineResult$HandshakeStatus;");
g_SSLEngineGetSession = GetMethod(env, false, g_SSLEngine, "getSession", "()Ljavax/net/ssl/SSLSession;");
g_SSLEngineGetHandshakeSession = GetOptionalMethod(env, false, g_SSLEngine, "getHandshakeSession", "()Ljavax/net/ssl/SSLSession;");
g_SSLEngineGetSSLParameters = GetMethod(env, false, g_SSLEngine, "getSSLParameters", "()Ljavax/net/ssl/SSLParameters;");
g_SSLEngineGetSupportedProtocols = GetMethod(env, false, g_SSLEngine, "getSupportedProtocols", "()[Ljava/lang/String;");
g_SSLEngineSetEnabledProtocols = GetMethod(env, false, g_SSLEngine, "setEnabledProtocols", "([Ljava/lang/String;)V");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,7 @@ extern jmethodID g_SSLEngineBeginHandshake;
extern jmethodID g_SSLEngineCloseOutbound;
extern jmethodID g_SSLEngineGetApplicationProtocol;
extern jmethodID g_SSLEngineGetHandshakeStatus;
extern jmethodID g_SSLEngineGetHandshakeSession;
extern jmethodID g_SSLEngineGetSession;
extern jmethodID g_SSLEngineGetSSLParameters;
extern jmethodID g_SSLEngineGetSupportedProtocols;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,27 @@ static bool IsHandshaking(int handshakeStatus)
return handshakeStatus != HANDSHAKE_STATUS__NOT_HANDSHAKING && handshakeStatus != HANDSHAKE_STATUS__FINISHED;
}

static jobject GetSslSessionForHandshakeStatus(JNIEnv* env, SSLStream* sslStream, int handshakeStatus)
{
// SSLEngine.getHandshakeSession() is available since API 24
jobject sslSession = IsHandshaking(handshakeStatus) && g_SSLEngineGetHandshakeSession != NULL
? (*env)->CallObjectMethod(env, sslStream->sslEngine, g_SSLEngineGetHandshakeSession)
: (*env)->CallObjectMethod(env, sslStream->sslEngine, g_SSLEngineGetSession);
if (CheckJNIExceptions(env))
return NULL;

return sslSession;
}

static jobject GetCurrentSslSession(JNIEnv* env, SSLStream* sslStream)
{
int handshakeStatus = GetEnumAsInt(env, (*env)->CallObjectMethod(env, sslStream->sslEngine, g_SSLEngineGetHandshakeStatus));
if (CheckJNIExceptions(env))
return NULL;

return GetSslSessionForHandshakeStatus(env, sslStream, handshakeStatus);
}

ARGS_NON_NULL_ALL static PAL_SSLStreamStatus Close(JNIEnv* env, SSLStream* sslStream)
{
// Call wrap to clear any remaining data before closing
Expand Down Expand Up @@ -523,10 +544,13 @@ PAL_SSLStreamStatus AndroidCryptoNative_SSLStreamHandshake(SSLStream* sslStream)
abort_if_invalid_pointer_argument (sslStream);
JNIEnv* env = GetJNIEnv();

// sslEngine.beginHandshake();
(*env)->CallVoidMethod(env, sslStream->sslEngine, g_SSLEngineBeginHandshake);
if (CheckJNIExceptions(env))
return SSLStreamStatus_Error;
int handshakeStatus = GetEnumAsInt(env, (*env)->CallObjectMethod(env, sslStream->sslEngine, g_SSLEngineGetHandshakeStatus));
if (!IsHandshaking(handshakeStatus)) {
// sslEngine.beginHandshake();
(*env)->CallVoidMethod(env, sslStream->sslEngine, g_SSLEngineBeginHandshake);
if (CheckJNIExceptions(env))
return SSLStreamStatus_Error;
}

return DoHandshake(env, sslStream);
}
Expand Down Expand Up @@ -705,14 +729,16 @@ int32_t AndroidCryptoNative_SSLStreamGetCipherSuite(SSLStream* sslStream, uint16
*out = NULL;

// String cipherSuite = sslSession.getCipherSuite();
jstring cipherSuite = (*env)->CallObjectMethod(env, sslStream->sslSession, g_SSLSessionGetCipherSuite);
jobject sslSession = GetCurrentSslSession(env, sslStream);
jstring cipherSuite = (*env)->CallObjectMethod(env, sslSession, g_SSLSessionGetCipherSuite);
ON_EXCEPTION_PRINT_AND_GOTO(cleanup);
*out = AllocateString(env, cipherSuite);

ret = SUCCESS;

cleanup:
(*env)->DeleteLocalRef(env, cipherSuite);
ReleaseLRef(env, sslSession);
ReleaseLRef(env, cipherSuite);
return ret;
}

Expand All @@ -726,14 +752,16 @@ int32_t AndroidCryptoNative_SSLStreamGetProtocol(SSLStream* sslStream, uint16_t*
*out = NULL;

// String protocol = sslSession.getProtocol();
jstring protocol = (*env)->CallObjectMethod(env, sslStream->sslSession, g_SSLSessionGetProtocol);
jobject sslSession = GetCurrentSslSession(env, sslStream);
jstring protocol = (*env)->CallObjectMethod(env, sslSession, g_SSLSessionGetProtocol);
ON_EXCEPTION_PRINT_AND_GOTO(cleanup);
*out = AllocateString(env, protocol);

ret = SUCCESS;

cleanup:
(*env)->DeleteLocalRef(env, protocol);
ReleaseLRef(env, sslSession);
ReleaseLRef(env, protocol);
return ret;
}

Expand All @@ -746,7 +774,8 @@ jobject /*X509Certificate*/ AndroidCryptoNative_SSLStreamGetPeerCertificate(SSLS

// Certificate[] certs = sslSession.getPeerCertificates();
// out = certs[0];
jobjectArray certs = (*env)->CallObjectMethod(env, sslStream->sslSession, g_SSLSessionGetPeerCertificates);
jobject sslSession = GetCurrentSslSession(env, sslStream);
jobjectArray certs = (*env)->CallObjectMethod(env, sslSession, g_SSLSessionGetPeerCertificates);

// If there are no peer certificates, getPeerCertificates will throw. Return null to indicate no certificate.
if (TryClearJNIExceptions(env))
Expand All @@ -761,7 +790,8 @@ jobject /*X509Certificate*/ AndroidCryptoNative_SSLStreamGetPeerCertificate(SSLS
}

cleanup:
(*env)->DeleteLocalRef(env, certs);
ReleaseLRef(env, sslSession);
ReleaseLRef(env, certs);
return ret;
}

Expand All @@ -779,7 +809,8 @@ void AndroidCryptoNative_SSLStreamGetPeerCertificates(SSLStream* sslStream, jobj
// for (int i = 0; i < certs.length; i++) {
// out[i] = certs[i];
// }
jobjectArray certs = (*env)->CallObjectMethod(env, sslStream->sslSession, g_SSLSessionGetPeerCertificates);
jobject sslSession = GetCurrentSslSession(env, sslStream);
jobjectArray certs = (*env)->CallObjectMethod(env, sslSession, g_SSLSessionGetPeerCertificates);

// If there are no peer certificates, getPeerCertificates will throw. Return null and length of zero to indicate no certificates.
if (TryClearJNIExceptions(env))
Expand All @@ -798,7 +829,8 @@ void AndroidCryptoNative_SSLStreamGetPeerCertificates(SSLStream* sslStream, jobj
}

cleanup:
(*env)->DeleteLocalRef(env, certs);
ReleaseLRef(env, sslSession);
ReleaseLRef(env, certs);
}

void AndroidCryptoNative_SSLStreamRequestClientAuthentication(SSLStream* sslStream)
Expand Down Expand Up @@ -912,14 +944,15 @@ bool AndroidCryptoNative_SSLStreamVerifyHostname(SSLStream* sslStream, char* hos
JNIEnv* env = GetJNIEnv();

bool ret = false;
INIT_LOCALS(loc, name, verifier);
INIT_LOCALS(loc, name, verifier, sslSession);

// HostnameVerifier verifier = HttpsURLConnection.getDefaultHostnameVerifier();
// return verifier.verify(hostname, sslSession);
loc[name] = make_java_string(env, hostname);
loc[sslSession] = GetCurrentSslSession(env, sslStream);
loc[verifier] =
(*env)->CallStaticObjectMethod(env, g_HttpsURLConnection, g_HttpsURLConnectionGetDefaultHostnameVerifier);
ret = (*env)->CallBooleanMethod(env, loc[verifier], g_HostnameVerifierVerify, loc[name], sslStream->sslSession);
ret = (*env)->CallBooleanMethod(env, loc[verifier], g_HostnameVerifierVerify, loc[name], loc[sslSession]);

RELEASE_LOCALS(loc, env);
return ret;
Expand Down