Skip to content

Commit

Permalink
xds: Implement GcpAuthenticationFilter (#11638)
Browse files Browse the repository at this point in the history
  • Loading branch information
shivaspeaks authored Nov 6, 2024
1 parent a5db67d commit 76705c2
Show file tree
Hide file tree
Showing 5 changed files with 410 additions and 0 deletions.
3 changes: 3 additions & 0 deletions xds/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ java_library(
"//:auto_value_annotations",
"//alts",
"//api",
"//auth",
"//context",
"//core:internal",
"//netty",
Expand All @@ -45,6 +46,7 @@ java_library(
"@com_google_googleapis//google/rpc:rpc_java_proto",
"@com_google_protobuf//:protobuf_java",
"@com_google_protobuf//:protobuf_java_util",
"@maven//:com_google_auth_google_auth_library_oauth2_http",
artifact("com.google.code.findbugs:jsr305"),
artifact("com.google.code.gson:gson"),
artifact("com.google.errorprone:error_prone_annotations"),
Expand Down Expand Up @@ -73,6 +75,7 @@ java_proto_library(
"@envoy_api//envoy/extensions/clusters/aggregate/v3:pkg",
"@envoy_api//envoy/extensions/filters/common/fault/v3:pkg",
"@envoy_api//envoy/extensions/filters/http/fault/v3:pkg",
"@envoy_api//envoy/extensions/filters/http/gcp_authn/v3:pkg",
"@envoy_api//envoy/extensions/filters/http/rbac/v3:pkg",
"@envoy_api//envoy/extensions/filters/http/router/v3:pkg",
"@envoy_api//envoy/extensions/filters/network/http_connection_manager/v3:pkg",
Expand Down
222 changes: 222 additions & 0 deletions xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
/*
* Copyright 2021 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.grpc.xds;

import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.auth.oauth2.IdTokenCredentials;
import com.google.common.primitives.UnsignedLongs;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.google.protobuf.Message;
import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig;
import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.TokenCacheConfig;
import io.grpc.CallCredentials;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.CompositeCallCredentials;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
import io.grpc.auth.MoreCallCredentials;
import io.grpc.xds.Filter.ClientInterceptorBuilder;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import java.util.function.Function;
import javax.annotation.Nullable;

/**
* A {@link Filter} that injects a {@link CallCredentials} to handle
* authentication for xDS credentials.
*/
final class GcpAuthenticationFilter implements Filter, ClientInterceptorBuilder {

static final String TYPE_URL =
"type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig";

@Override
public String[] typeUrls() {
return new String[] { TYPE_URL };
}

@Override
public ConfigOrError<? extends FilterConfig> parseFilterConfig(Message rawProtoMessage) {
GcpAuthnFilterConfig gcpAuthnProto;
if (!(rawProtoMessage instanceof Any)) {
return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass());
}
Any anyMessage = (Any) rawProtoMessage;

try {
gcpAuthnProto = anyMessage.unpack(GcpAuthnFilterConfig.class);
} catch (InvalidProtocolBufferException e) {
return ConfigOrError.fromError("Invalid proto: " + e);
}

long cacheSize = 10;
// Validate cache_config
TokenCacheConfig cacheConfig = gcpAuthnProto.getCacheConfig();
if (cacheConfig != null) {
cacheSize = cacheConfig.getCacheSize().getValue();
if (cacheSize == 0) {
return ConfigOrError.fromError(
"cache_config.cache_size must be greater than zero");
}
// LruCache's size is an int and briefly exceeds its maximum size before evicting entries
cacheSize = UnsignedLongs.min(cacheSize, Integer.MAX_VALUE - 1);
}

GcpAuthenticationConfig config = new GcpAuthenticationConfig((int) cacheSize);
return ConfigOrError.fromConfig(config);
}

@Override
public ConfigOrError<? extends FilterConfig> parseFilterConfigOverride(Message rawProtoMessage) {
return parseFilterConfig(rawProtoMessage);
}

@Nullable
@Override
public ClientInterceptor buildClientInterceptor(FilterConfig config,
@Nullable FilterConfig overrideConfig, PickSubchannelArgs args,
ScheduledExecutorService scheduler) {

ComputeEngineCredentials credentials = ComputeEngineCredentials.create();
LruCache<String, CallCredentials> callCredentialsCache =
new LruCache<>(((GcpAuthenticationConfig) config).getCacheSize());
return new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {

/*String clusterName = callOptions.getOption(InternalXdsAttributes.ATTR_CLUSTER_NAME);
if (clusterName == null) {
return next.newCall(method, callOptions);
}*/

// TODO: Fetch the CDS resource for the cluster.
// If the CDS resource is not available, fail the RPC with Status.UNAVAILABLE.

// TODO: Extract the audience from the CDS resource metadata.
// If the audience is not found or is in the wrong format, fail the RPC.
String audience = "TEST_AUDIENCE";

try {
CallCredentials existingCallCredentials = callOptions.getCredentials();
CallCredentials newCallCredentials =
getCallCredentials(callCredentialsCache, audience, credentials);
if (existingCallCredentials != null) {
callOptions = callOptions.withCallCredentials(
new CompositeCallCredentials(existingCallCredentials, newCallCredentials));
} else {
callOptions = callOptions.withCallCredentials(newCallCredentials);
}
}
catch (Exception e) {
// If we fail to attach CallCredentials due to any reason, return a FailingClientCall
return new FailingClientCall<>(Status.UNAUTHENTICATED
.withDescription("Failed to attach CallCredentials.")
.withCause(e));
}
return next.newCall(method, callOptions);
}
};
}

private CallCredentials getCallCredentials(LruCache<String, CallCredentials> cache,
String audience, ComputeEngineCredentials credentials) {

synchronized (cache) {
return cache.getOrInsert(audience, key -> {
IdTokenCredentials creds = IdTokenCredentials.newBuilder()
.setIdTokenProvider(credentials)
.setTargetAudience(audience)
.build();
return MoreCallCredentials.from(creds);
});
}
}

static final class GcpAuthenticationConfig implements FilterConfig {

private final int cacheSize;

public GcpAuthenticationConfig(int cacheSize) {
this.cacheSize = cacheSize;
}

public int getCacheSize() {
return cacheSize;
}

@Override
public String typeUrl() {
return GcpAuthenticationFilter.TYPE_URL;
}
}

/** An implementation of {@link ClientCall} that fails when started. */
private static final class FailingClientCall<ReqT, RespT> extends ClientCall<ReqT, RespT> {

private final Status error;

public FailingClientCall(Status error) {
this.error = error;
}

@Override
public void start(ClientCall.Listener<RespT> listener, Metadata headers) {
listener.onClose(error, new Metadata());
}

@Override
public void request(int numMessages) {}

@Override
public void cancel(String message, Throwable cause) {}

@Override
public void halfClose() {}

@Override
public void sendMessage(ReqT message) {}
}

private static final class LruCache<K, V> {

private final Map<K, V> cache;

LruCache(int maxSize) {
this.cache = new LinkedHashMap<K, V>(
maxSize,
0.75f,
true) {
@Override
protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
return size() > maxSize;
}
};
}

V getOrInsert(K key, Function<K, V> create) {
return cache.computeIfAbsent(key, create);
}
}
}
121 changes: 121 additions & 0 deletions xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright 2024 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.grpc.xds;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.eq;

import com.google.protobuf.Any;
import com.google.protobuf.Empty;
import com.google.protobuf.Message;
import com.google.protobuf.UInt64Value;
import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig;
import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.TokenCacheConfig;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientInterceptor;
import io.grpc.MethodDescriptor;
import io.grpc.testing.TestMethodDescriptors;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;

@RunWith(JUnit4.class)
public class GcpAuthenticationFilterTest {

@Test
public void testParseFilterConfig_withValidConfig() {
GcpAuthnFilterConfig config = GcpAuthnFilterConfig.newBuilder()
.setCacheConfig(TokenCacheConfig.newBuilder().setCacheSize(UInt64Value.of(20)))
.build();
Any anyMessage = Any.pack(config);

GcpAuthenticationFilter filter = new GcpAuthenticationFilter();
ConfigOrError<? extends Filter.FilterConfig> result = filter.parseFilterConfig(anyMessage);

assertNotNull(result.config);
assertNull(result.errorDetail);
assertEquals(20L,
((GcpAuthenticationFilter.GcpAuthenticationConfig) result.config).getCacheSize());
}

@Test
public void testParseFilterConfig_withZeroCacheSize() {
GcpAuthnFilterConfig config = GcpAuthnFilterConfig.newBuilder()
.setCacheConfig(TokenCacheConfig.newBuilder().setCacheSize(UInt64Value.of(0)))
.build();
Any anyMessage = Any.pack(config);

GcpAuthenticationFilter filter = new GcpAuthenticationFilter();
ConfigOrError<? extends Filter.FilterConfig> result = filter.parseFilterConfig(anyMessage);

assertNull(result.config);
assertNotNull(result.errorDetail);
assertTrue(result.errorDetail.contains("cache_config.cache_size must be greater than zero"));
}

@Test
public void testParseFilterConfig_withInvalidMessageType() {
GcpAuthenticationFilter filter = new GcpAuthenticationFilter();
Message invalidMessage = Empty.getDefaultInstance();
ConfigOrError<? extends Filter.FilterConfig> result = filter.parseFilterConfig(invalidMessage);

assertNull(result.config);
assertThat(result.errorDetail).contains("Invalid config type");
}

@Test
public void testClientInterceptor_createsAndReusesCachedCredentials() {
GcpAuthenticationFilter.GcpAuthenticationConfig config =
new GcpAuthenticationFilter.GcpAuthenticationConfig(10);
GcpAuthenticationFilter filter = new GcpAuthenticationFilter();

// Create interceptor
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();

// Mock channel and capture CallOptions
Channel mockChannel = Mockito.mock(Channel.class);
ArgumentCaptor<CallOptions> callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class);

// Execute interception twice to check caching
interceptor.interceptCall(methodDescriptor, CallOptions.DEFAULT, mockChannel);
interceptor.interceptCall(methodDescriptor, CallOptions.DEFAULT, mockChannel);

// Capture and verify CallOptions for CallCredentials presence
Mockito.verify(mockChannel, Mockito.times(2))
.newCall(eq(methodDescriptor), callOptionsCaptor.capture());

// Retrieve the CallOptions captured from both calls
CallOptions firstCapturedOptions = callOptionsCaptor.getAllValues().get(0);
CallOptions secondCapturedOptions = callOptionsCaptor.getAllValues().get(1);

// Ensure that CallCredentials was added
assertNotNull(firstCapturedOptions.getCredentials());
assertNotNull(secondCapturedOptions.getCredentials());

// Ensure that the CallCredentials from both calls are the same, indicating caching
assertSame(firstCapturedOptions.getCredentials(), secondCapturedOptions.getCredentials());
}
}
1 change: 1 addition & 0 deletions xds/third_party/envoy/import.sh
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ envoy/extensions/clusters/aggregate/v3/cluster.proto
envoy/extensions/filters/common/fault/v3/fault.proto
envoy/extensions/filters/http/fault/v3/fault.proto
envoy/extensions/filters/http/rate_limit_quota/v3/rate_limit_quota.proto
envoy/extensions/filters/http/gcp_authn/v3/gcp_authn.proto
envoy/extensions/filters/http/rbac/v3/rbac.proto
envoy/extensions/filters/http/router/v3/router.proto
envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto
Expand Down
Loading

0 comments on commit 76705c2

Please sign in to comment.