-
Notifications
You must be signed in to change notification settings - Fork 3.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
xds: Implement GcpAuthenticationFilter (#11638)
- Loading branch information
1 parent
a5db67d
commit 76705c2
Showing
5 changed files
with
410 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
222 changes: 222 additions & 0 deletions
222
xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
/* | ||
* Copyright 2021 The gRPC Authors | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package io.grpc.xds; | ||
|
||
import com.google.auth.oauth2.ComputeEngineCredentials; | ||
import com.google.auth.oauth2.IdTokenCredentials; | ||
import com.google.common.primitives.UnsignedLongs; | ||
import com.google.protobuf.Any; | ||
import com.google.protobuf.InvalidProtocolBufferException; | ||
import com.google.protobuf.Message; | ||
import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig; | ||
import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.TokenCacheConfig; | ||
import io.grpc.CallCredentials; | ||
import io.grpc.CallOptions; | ||
import io.grpc.Channel; | ||
import io.grpc.ClientCall; | ||
import io.grpc.ClientInterceptor; | ||
import io.grpc.CompositeCallCredentials; | ||
import io.grpc.LoadBalancer.PickSubchannelArgs; | ||
import io.grpc.Metadata; | ||
import io.grpc.MethodDescriptor; | ||
import io.grpc.Status; | ||
import io.grpc.auth.MoreCallCredentials; | ||
import io.grpc.xds.Filter.ClientInterceptorBuilder; | ||
import java.util.LinkedHashMap; | ||
import java.util.Map; | ||
import java.util.concurrent.ScheduledExecutorService; | ||
import java.util.function.Function; | ||
import javax.annotation.Nullable; | ||
|
||
/** | ||
* A {@link Filter} that injects a {@link CallCredentials} to handle | ||
* authentication for xDS credentials. | ||
*/ | ||
final class GcpAuthenticationFilter implements Filter, ClientInterceptorBuilder { | ||
|
||
static final String TYPE_URL = | ||
"type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig"; | ||
|
||
@Override | ||
public String[] typeUrls() { | ||
return new String[] { TYPE_URL }; | ||
} | ||
|
||
@Override | ||
public ConfigOrError<? extends FilterConfig> parseFilterConfig(Message rawProtoMessage) { | ||
GcpAuthnFilterConfig gcpAuthnProto; | ||
if (!(rawProtoMessage instanceof Any)) { | ||
return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); | ||
} | ||
Any anyMessage = (Any) rawProtoMessage; | ||
|
||
try { | ||
gcpAuthnProto = anyMessage.unpack(GcpAuthnFilterConfig.class); | ||
} catch (InvalidProtocolBufferException e) { | ||
return ConfigOrError.fromError("Invalid proto: " + e); | ||
} | ||
|
||
long cacheSize = 10; | ||
// Validate cache_config | ||
TokenCacheConfig cacheConfig = gcpAuthnProto.getCacheConfig(); | ||
if (cacheConfig != null) { | ||
cacheSize = cacheConfig.getCacheSize().getValue(); | ||
if (cacheSize == 0) { | ||
return ConfigOrError.fromError( | ||
"cache_config.cache_size must be greater than zero"); | ||
} | ||
// LruCache's size is an int and briefly exceeds its maximum size before evicting entries | ||
cacheSize = UnsignedLongs.min(cacheSize, Integer.MAX_VALUE - 1); | ||
} | ||
|
||
GcpAuthenticationConfig config = new GcpAuthenticationConfig((int) cacheSize); | ||
return ConfigOrError.fromConfig(config); | ||
} | ||
|
||
@Override | ||
public ConfigOrError<? extends FilterConfig> parseFilterConfigOverride(Message rawProtoMessage) { | ||
return parseFilterConfig(rawProtoMessage); | ||
} | ||
|
||
@Nullable | ||
@Override | ||
public ClientInterceptor buildClientInterceptor(FilterConfig config, | ||
@Nullable FilterConfig overrideConfig, PickSubchannelArgs args, | ||
ScheduledExecutorService scheduler) { | ||
|
||
ComputeEngineCredentials credentials = ComputeEngineCredentials.create(); | ||
LruCache<String, CallCredentials> callCredentialsCache = | ||
new LruCache<>(((GcpAuthenticationConfig) config).getCacheSize()); | ||
return new ClientInterceptor() { | ||
@Override | ||
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall( | ||
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) { | ||
|
||
/*String clusterName = callOptions.getOption(InternalXdsAttributes.ATTR_CLUSTER_NAME); | ||
if (clusterName == null) { | ||
return next.newCall(method, callOptions); | ||
}*/ | ||
|
||
// TODO: Fetch the CDS resource for the cluster. | ||
// If the CDS resource is not available, fail the RPC with Status.UNAVAILABLE. | ||
|
||
// TODO: Extract the audience from the CDS resource metadata. | ||
// If the audience is not found or is in the wrong format, fail the RPC. | ||
String audience = "TEST_AUDIENCE"; | ||
|
||
try { | ||
CallCredentials existingCallCredentials = callOptions.getCredentials(); | ||
CallCredentials newCallCredentials = | ||
getCallCredentials(callCredentialsCache, audience, credentials); | ||
if (existingCallCredentials != null) { | ||
callOptions = callOptions.withCallCredentials( | ||
new CompositeCallCredentials(existingCallCredentials, newCallCredentials)); | ||
} else { | ||
callOptions = callOptions.withCallCredentials(newCallCredentials); | ||
} | ||
} | ||
catch (Exception e) { | ||
// If we fail to attach CallCredentials due to any reason, return a FailingClientCall | ||
return new FailingClientCall<>(Status.UNAUTHENTICATED | ||
.withDescription("Failed to attach CallCredentials.") | ||
.withCause(e)); | ||
} | ||
return next.newCall(method, callOptions); | ||
} | ||
}; | ||
} | ||
|
||
private CallCredentials getCallCredentials(LruCache<String, CallCredentials> cache, | ||
String audience, ComputeEngineCredentials credentials) { | ||
|
||
synchronized (cache) { | ||
return cache.getOrInsert(audience, key -> { | ||
IdTokenCredentials creds = IdTokenCredentials.newBuilder() | ||
.setIdTokenProvider(credentials) | ||
.setTargetAudience(audience) | ||
.build(); | ||
return MoreCallCredentials.from(creds); | ||
}); | ||
} | ||
} | ||
|
||
static final class GcpAuthenticationConfig implements FilterConfig { | ||
|
||
private final int cacheSize; | ||
|
||
public GcpAuthenticationConfig(int cacheSize) { | ||
this.cacheSize = cacheSize; | ||
} | ||
|
||
public int getCacheSize() { | ||
return cacheSize; | ||
} | ||
|
||
@Override | ||
public String typeUrl() { | ||
return GcpAuthenticationFilter.TYPE_URL; | ||
} | ||
} | ||
|
||
/** An implementation of {@link ClientCall} that fails when started. */ | ||
private static final class FailingClientCall<ReqT, RespT> extends ClientCall<ReqT, RespT> { | ||
|
||
private final Status error; | ||
|
||
public FailingClientCall(Status error) { | ||
this.error = error; | ||
} | ||
|
||
@Override | ||
public void start(ClientCall.Listener<RespT> listener, Metadata headers) { | ||
listener.onClose(error, new Metadata()); | ||
} | ||
|
||
@Override | ||
public void request(int numMessages) {} | ||
|
||
@Override | ||
public void cancel(String message, Throwable cause) {} | ||
|
||
@Override | ||
public void halfClose() {} | ||
|
||
@Override | ||
public void sendMessage(ReqT message) {} | ||
} | ||
|
||
private static final class LruCache<K, V> { | ||
|
||
private final Map<K, V> cache; | ||
|
||
LruCache(int maxSize) { | ||
this.cache = new LinkedHashMap<K, V>( | ||
maxSize, | ||
0.75f, | ||
true) { | ||
@Override | ||
protected boolean removeEldestEntry(Map.Entry<K, V> eldest) { | ||
return size() > maxSize; | ||
} | ||
}; | ||
} | ||
|
||
V getOrInsert(K key, Function<K, V> create) { | ||
return cache.computeIfAbsent(key, create); | ||
} | ||
} | ||
} |
121 changes: 121 additions & 0 deletions
121
xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
/* | ||
* Copyright 2024 The gRPC Authors | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package io.grpc.xds; | ||
|
||
import static com.google.common.truth.Truth.assertThat; | ||
import static org.junit.Assert.assertEquals; | ||
import static org.junit.Assert.assertNotNull; | ||
import static org.junit.Assert.assertNull; | ||
import static org.junit.Assert.assertSame; | ||
import static org.junit.Assert.assertTrue; | ||
import static org.mockito.ArgumentMatchers.eq; | ||
|
||
import com.google.protobuf.Any; | ||
import com.google.protobuf.Empty; | ||
import com.google.protobuf.Message; | ||
import com.google.protobuf.UInt64Value; | ||
import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig; | ||
import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.TokenCacheConfig; | ||
import io.grpc.CallOptions; | ||
import io.grpc.Channel; | ||
import io.grpc.ClientInterceptor; | ||
import io.grpc.MethodDescriptor; | ||
import io.grpc.testing.TestMethodDescriptors; | ||
import org.junit.Test; | ||
import org.junit.runner.RunWith; | ||
import org.junit.runners.JUnit4; | ||
import org.mockito.ArgumentCaptor; | ||
import org.mockito.Mockito; | ||
|
||
@RunWith(JUnit4.class) | ||
public class GcpAuthenticationFilterTest { | ||
|
||
@Test | ||
public void testParseFilterConfig_withValidConfig() { | ||
GcpAuthnFilterConfig config = GcpAuthnFilterConfig.newBuilder() | ||
.setCacheConfig(TokenCacheConfig.newBuilder().setCacheSize(UInt64Value.of(20))) | ||
.build(); | ||
Any anyMessage = Any.pack(config); | ||
|
||
GcpAuthenticationFilter filter = new GcpAuthenticationFilter(); | ||
ConfigOrError<? extends Filter.FilterConfig> result = filter.parseFilterConfig(anyMessage); | ||
|
||
assertNotNull(result.config); | ||
assertNull(result.errorDetail); | ||
assertEquals(20L, | ||
((GcpAuthenticationFilter.GcpAuthenticationConfig) result.config).getCacheSize()); | ||
} | ||
|
||
@Test | ||
public void testParseFilterConfig_withZeroCacheSize() { | ||
GcpAuthnFilterConfig config = GcpAuthnFilterConfig.newBuilder() | ||
.setCacheConfig(TokenCacheConfig.newBuilder().setCacheSize(UInt64Value.of(0))) | ||
.build(); | ||
Any anyMessage = Any.pack(config); | ||
|
||
GcpAuthenticationFilter filter = new GcpAuthenticationFilter(); | ||
ConfigOrError<? extends Filter.FilterConfig> result = filter.parseFilterConfig(anyMessage); | ||
|
||
assertNull(result.config); | ||
assertNotNull(result.errorDetail); | ||
assertTrue(result.errorDetail.contains("cache_config.cache_size must be greater than zero")); | ||
} | ||
|
||
@Test | ||
public void testParseFilterConfig_withInvalidMessageType() { | ||
GcpAuthenticationFilter filter = new GcpAuthenticationFilter(); | ||
Message invalidMessage = Empty.getDefaultInstance(); | ||
ConfigOrError<? extends Filter.FilterConfig> result = filter.parseFilterConfig(invalidMessage); | ||
|
||
assertNull(result.config); | ||
assertThat(result.errorDetail).contains("Invalid config type"); | ||
} | ||
|
||
@Test | ||
public void testClientInterceptor_createsAndReusesCachedCredentials() { | ||
GcpAuthenticationFilter.GcpAuthenticationConfig config = | ||
new GcpAuthenticationFilter.GcpAuthenticationConfig(10); | ||
GcpAuthenticationFilter filter = new GcpAuthenticationFilter(); | ||
|
||
// Create interceptor | ||
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null, null); | ||
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod(); | ||
|
||
// Mock channel and capture CallOptions | ||
Channel mockChannel = Mockito.mock(Channel.class); | ||
ArgumentCaptor<CallOptions> callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); | ||
|
||
// Execute interception twice to check caching | ||
interceptor.interceptCall(methodDescriptor, CallOptions.DEFAULT, mockChannel); | ||
interceptor.interceptCall(methodDescriptor, CallOptions.DEFAULT, mockChannel); | ||
|
||
// Capture and verify CallOptions for CallCredentials presence | ||
Mockito.verify(mockChannel, Mockito.times(2)) | ||
.newCall(eq(methodDescriptor), callOptionsCaptor.capture()); | ||
|
||
// Retrieve the CallOptions captured from both calls | ||
CallOptions firstCapturedOptions = callOptionsCaptor.getAllValues().get(0); | ||
CallOptions secondCapturedOptions = callOptionsCaptor.getAllValues().get(1); | ||
|
||
// Ensure that CallCredentials was added | ||
assertNotNull(firstCapturedOptions.getCredentials()); | ||
assertNotNull(secondCapturedOptions.getCredentials()); | ||
|
||
// Ensure that the CallCredentials from both calls are the same, indicating caching | ||
assertSame(firstCapturedOptions.getCredentials(), secondCapturedOptions.getCredentials()); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.