Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix for correctly JSON-marshalling subscription input variables within WebSocketConnectionManager #365

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ private AWSAppSyncClient(AWSAppSyncClient.Builder builder) {
builder.mServerUrl,
subscriptionAuthorizer,
new ApolloResponseBuilder(builder.customTypeAdapters, mApolloClient.apolloStore().networkResponseNormalizer()),
new ScalarTypeAdapters(builder.customTypeAdapters),
builder.mSubscriptionsAutoReconnect);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,15 @@
import com.apollographql.apollo.api.Operation;
import com.apollographql.apollo.api.Subscription;
import com.apollographql.apollo.exception.ApolloException;
import com.apollographql.apollo.internal.json.InputFieldJsonWriter;
import com.apollographql.apollo.internal.json.JsonWriter;
import com.apollographql.apollo.internal.response.ScalarTypeAdapters;

import org.jetbrains.annotations.NotNull;
import org.json.JSONException;
import org.json.JSONObject;

import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.HashSet;
Expand All @@ -37,6 +42,7 @@
import okhttp3.Response;
import okhttp3.WebSocket;
import okhttp3.WebSocketListener;
import okio.Buffer;

/**
* Manages the lifecycle of a single WebSocket connection,
Expand All @@ -45,6 +51,7 @@
final class WebSocketConnectionManager {
private static final String TAG = WebSocketConnectionManager.class.getName();
private static final int NORMAL_CLOSURE_STATUS = 1000;
private final ScalarTypeAdapters scalarTypeAdapters;

private final Context applicationContext;
private final String serverUrl;
Expand All @@ -60,13 +67,15 @@ final class WebSocketConnectionManager {
String serverUrl,
SubscriptionAuthorizer subscriptionAuthorizer,
ApolloResponseBuilder apolloResponseBuilder,
ScalarTypeAdapters scalarTypeAdapters,
boolean subscriptionsAutoReconnect) {
this.applicationContext = context.getApplicationContext();
this.serverUrl = serverUrl;
this.subscriptionAuthorizer = subscriptionAuthorizer;
this.subscriptions = new ConcurrentHashMap<>();
this.apolloResponseBuilder = apolloResponseBuilder;
this.watchdog = new TimeoutWatchdog();
this.scalarTypeAdapters = scalarTypeAdapters;
this.subscriptionsAutoReconnect = subscriptionsAutoReconnect;
}

Expand Down Expand Up @@ -110,9 +119,7 @@ private synchronized void startSubscription(
.put("id", subscriptionId)
.put("type", "start")
.put("payload", new JSONObject()
.put("data", (new JSONObject()
.put("query", subscription.queryDocument())
.put("variables", new JSONObject(subscription.variables().valueMap()))).toString())
.put("data", httpRequestBody(subscription))
.put("extensions", new JSONObject()
.put("authorization", subscriptionAuthorizer.getAuthorizationDetails(false, subscription))))
.toString()
Expand Down Expand Up @@ -409,6 +416,34 @@ private String getConnectionRequestUrl() throws JSONException {
.toString();
}

/**
* Produces a JSON string of the subscription query and its marshalled variables formatted
* correctly for a websocket subscription initialization call.
* On IO exception from the JsonWriter, the JSON string is instead created using subscription
* variable valueMap. Which may throw a JSONException to be handled by the caller function.
* @param subscription the subscription object whose queryDocument and variables will be used
* @return A JSON String containing fields for the header and formatted input variables
* @throws JSONException
*/
private String httpRequestBody(@NotNull Subscription subscription) throws JSONException {
try {
Buffer buffer = new Buffer();
JsonWriter jsonWriter = JsonWriter.of(buffer);
jsonWriter.beginObject();
jsonWriter.name("query").value(subscription.queryDocument().replaceAll("\\n", ""));
jsonWriter.name("variables").beginObject();
subscription.variables().marshaller().marshal(new InputFieldJsonWriter(jsonWriter, scalarTypeAdapters));
jsonWriter.endObject();
jsonWriter.endObject();
jsonWriter.close();
return buffer.readUtf8();
} catch (IOException e) {
return (new JSONObject()
.put("query", subscription.queryDocument())
.put("variables", new JSONObject(subscription.variables().valueMap()))).toString();
}
}

static final class SubscriptionResponseDispatcher<D extends Operation.Data, T, V extends Operation.Variables> {
private final Subscription<D, T, V> subscription;
private final AppSyncSubscriptionCall.Callback<T> callback;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package com.amazonaws.mobileconnectors.appsync;

import android.content.Context;

import com.amazonaws.mobileconnectors.appsync.util.subscriptions.EnumFieldSubscription;
import com.amazonaws.mobileconnectors.appsync.util.subscriptions.TestEnum;
import com.apollographql.apollo.api.Subscription;
import com.google.gson.Gson;

import org.json.JSONException;
import org.json.JSONObject;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.robolectric.RobolectricTestRunner;
import org.robolectric.annotation.Config;

import java.lang.reflect.Field;

import okhttp3.WebSocket;

import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;

/**
* tests for WebSocketConnectionManager
*/
@RunWith(RobolectricTestRunner.class)
@Config(manifest = Config.NONE, sdk = 28)
public class WebSocketConnectionManagerTest {

Context mockContext;
Subscription<?, ?, ?> testSubscription;
AppSyncSubscriptionCall.Callback mockCallback;
WebSocketConnectionManager webSocketConnectionManager;
SubscriptionAuthorizer mockSubscriptionAuthorizer;
WebSocket mockWebSocket;

@Before
public void beforeEachTest() {
// set up mocks
mockContext = Mockito.mock(Context.class);
testSubscription = new EnumFieldSubscription(TestEnum.TEST_ENUM);
mockCallback = Mockito.mock(AppSyncSubscriptionCall.Callback.class);
mockWebSocket = Mockito.mock(WebSocket.class);
mockSubscriptionAuthorizer = Mockito.mock(SubscriptionAuthorizer.class);
try {
Mockito.when(mockSubscriptionAuthorizer.getAuthorizationDetails(Mockito.anyBoolean(), Mockito.<Subscription>any())).thenReturn(null);
} catch (JSONException e) {
fail("This shouldn't happen.");
}

// set up webSocketConnectionManager
webSocketConnectionManager = new WebSocketConnectionManager(mockContext,
null,
mockSubscriptionAuthorizer,
null,
null,
true);

// set webSocketConnectionManager's websocket to mockWebSocket
try {
Field reader = WebSocketConnectionManager.class.getDeclaredField("websocket");
reader.setAccessible(true);
reader.set(webSocketConnectionManager, mockWebSocket);
} catch (NoSuchFieldException e) {
fail("WebSocketConnectionManager's websocket field has changed.");
} catch (IllegalAccessException e) {
fail("This shouldn't happen.");
}
}

/**
* Test to check whether a subscription request from [webSocketConnectionManager] will correctly
* marshall the enum field of the [EnumFieldSubscription]. If the "testEnum" field of the JSON
* string sent to the mockWebSocket is null, this test will fail.
*/
@Test
public void testWebSocketConnectionManagerCorrectlyMarshalsSubscriptionsWithEnums() {
webSocketConnectionManager.requestSubscription(testSubscription, mockCallback);
ArgumentCaptor<String> sentStringCaptor = ArgumentCaptor.forClass(String.class);
Mockito.verify(mockWebSocket).send(sentStringCaptor.capture());

try {
JSONObject sentJSON = new JSONObject(sentStringCaptor.getValue());
JSONObject payload = sentJSON.getJSONObject("payload");
String data = payload.getString("data");
EnumFieldTestSubscriptionData subscriptionData = new Gson().fromJson(data, EnumFieldTestSubscriptionData.class);
assertNotNull(subscriptionData.variables.testEnum);
} catch (JSONException e) {
fail("invalid JSON was sent: " + e.getLocalizedMessage());
}
}

static class EnumFieldTestSubscriptionData {
String query;
EnumFieldTestVariables variables;

static class EnumFieldTestVariables {
String testEnum;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package com.amazonaws.mobileconnectors.appsync.util.subscriptions;

import com.apollographql.apollo.api.InputFieldMarshaller;
import com.apollographql.apollo.api.InputFieldWriter;
import com.apollographql.apollo.api.Operation;
import com.apollographql.apollo.api.OperationName;
import com.apollographql.apollo.api.ResponseFieldMapper;
import com.apollographql.apollo.api.ResponseFieldMarshaller;
import com.apollographql.apollo.api.Subscription;
import com.apollographql.apollo.api.internal.Utils;

import java.io.IOException;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;

import javax.annotation.Nonnull;

/**
* A testing class of a subscription with an enum field. Follows how apollo-generated subscriptions
* would generate (as of 3.1.4) - with some methods removed that aren't relevant to testing.
*/
public final class EnumFieldSubscription implements Subscription<EnumFieldSubscription.Data, EnumFieldSubscription.Data, EnumFieldSubscription.Variables> {

private final Variables variables;

private static final OperationName OPERATION_NAME = new OperationName() {
@Override
public String name() {
return "";
}
};

public EnumFieldSubscription(@Nonnull TestEnum testEnum) {
Utils.checkNotNull(testEnum, "testEnum == null");
variables = new Variables(testEnum);
}

@Override
public String operationId() { return ""; }

@Override
public String queryDocument() {
return "";
}

@Override
public Data wrapData(Data data) {
return data;
}

@Override
public Variables variables() {
return variables;
}

@Override
public ResponseFieldMapper<Data> responseFieldMapper() {
return null;
}

@Override
public OperationName name() {
return OPERATION_NAME;
}

public static final class Variables extends Operation.Variables {

private final @Nonnull TestEnum testEnum;

private final transient Map<String, Object> valueMap = new LinkedHashMap<>();

Variables(@Nonnull TestEnum testEnum) {
this.testEnum = testEnum;
this.valueMap.put("testEnum", testEnum);
}

public @Nonnull TestEnum testEnum() {
return testEnum;
}

@Override
public Map<String, Object> valueMap() {
return Collections.unmodifiableMap(valueMap);
}

@Override
public InputFieldMarshaller marshaller() {
return new InputFieldMarshaller() {
@Override
public void marshal(InputFieldWriter writer) throws IOException {
writer.writeString("testEnum", testEnum.name());
}
};
}
}

public static class Data implements Operation.Data {

@Override
public ResponseFieldMarshaller marshaller() {
return null;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package com.amazonaws.mobileconnectors.appsync.util.subscriptions;

/**
* a very simple enum class for testing
*/
public enum TestEnum {
TEST_ENUM
}