Skip to content

Commit 7a08fdb

Browse files
authored
xds: float LRU cache across interceptors (grpc#11992)
1 parent 84bd014 commit 7a08fdb

File tree

2 files changed

+175
-31
lines changed

2 files changed

+175
-31
lines changed

xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java

+40-20
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,17 @@ final class GcpAuthenticationFilter implements Filter {
5959

6060
static final String TYPE_URL =
6161
"type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig";
62-
62+
private final LruCache<String, CallCredentials> callCredentialsCache;
6363
final String filterInstanceName;
6464

65-
GcpAuthenticationFilter(String name) {
65+
GcpAuthenticationFilter(String name, int cacheSize) {
6666
filterInstanceName = checkNotNull(name, "name");
67+
this.callCredentialsCache = new LruCache<>(cacheSize);
6768
}
6869

69-
7070
static final class Provider implements Filter.Provider {
71+
private final int cacheSize = 10;
72+
7173
@Override
7274
public String[] typeUrls() {
7375
return new String[]{TYPE_URL};
@@ -80,7 +82,7 @@ public boolean isClientFilter() {
8082

8183
@Override
8284
public GcpAuthenticationFilter newInstance(String name) {
83-
return new GcpAuthenticationFilter(name);
85+
return new GcpAuthenticationFilter(name, cacheSize);
8486
}
8587

8688
@Override
@@ -101,11 +103,14 @@ public ConfigOrError<GcpAuthenticationConfig> parseFilterConfig(Message rawProto
101103
// Validate cache_config
102104
if (gcpAuthnProto.hasCacheConfig()) {
103105
TokenCacheConfig cacheConfig = gcpAuthnProto.getCacheConfig();
104-
cacheSize = cacheConfig.getCacheSize().getValue();
105-
if (cacheSize == 0) {
106-
return ConfigOrError.fromError(
107-
"cache_config.cache_size must be greater than zero");
106+
if (cacheConfig.hasCacheSize()) {
107+
cacheSize = cacheConfig.getCacheSize().getValue();
108+
if (cacheSize == 0) {
109+
return ConfigOrError.fromError(
110+
"cache_config.cache_size must be greater than zero");
111+
}
108112
}
113+
109114
// LruCache's size is an int and briefly exceeds its maximum size before evicting entries
110115
cacheSize = UnsignedLongs.min(cacheSize, Integer.MAX_VALUE - 1);
111116
}
@@ -127,8 +132,9 @@ public ClientInterceptor buildClientInterceptor(FilterConfig config,
127132
@Nullable FilterConfig overrideConfig, ScheduledExecutorService scheduler) {
128133

129134
ComputeEngineCredentials credentials = ComputeEngineCredentials.create();
130-
LruCache<String, CallCredentials> callCredentialsCache =
131-
new LruCache<>(((GcpAuthenticationConfig) config).getCacheSize());
135+
synchronized (callCredentialsCache) {
136+
callCredentialsCache.resizeCache(((GcpAuthenticationConfig) config).getCacheSize());
137+
}
132138
return new ClientInterceptor() {
133139
@Override
134140
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
@@ -254,23 +260,37 @@ public void sendMessage(ReqT message) {}
254260

255261
private static final class LruCache<K, V> {
256262

257-
private final Map<K, V> cache;
263+
private Map<K, V> cache;
264+
private int maxSize;
258265

259266
LruCache(int maxSize) {
260-
this.cache = new LinkedHashMap<K, V>(
261-
maxSize,
262-
0.75f,
263-
true) {
264-
@Override
265-
protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
266-
return size() > maxSize;
267-
}
268-
};
267+
this.maxSize = maxSize;
268+
this.cache = createEvictingMap(maxSize);
269269
}
270270

271271
V getOrInsert(K key, Function<K, V> create) {
272272
return cache.computeIfAbsent(key, create);
273273
}
274+
275+
private void resizeCache(int newSize) {
276+
if (newSize >= maxSize) {
277+
maxSize = newSize;
278+
return;
279+
}
280+
Map<K, V> newCache = createEvictingMap(newSize);
281+
maxSize = newSize;
282+
newCache.putAll(cache);
283+
cache = newCache;
284+
}
285+
286+
private Map<K, V> createEvictingMap(int size) {
287+
return new LinkedHashMap<K, V>(size, 0.75f, true) {
288+
@Override
289+
protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
290+
return size() > LruCache.this.maxSize;
291+
}
292+
};
293+
}
274294
}
275295

276296
static class AudienceMetadataParser implements MetadataValueParser {

xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java

+135-11
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@
2828
import static io.grpc.xds.XdsTestUtils.getWrrLbConfigAsMap;
2929
import static org.junit.Assert.assertEquals;
3030
import static org.junit.Assert.assertNotNull;
31+
import static org.junit.Assert.assertNotSame;
3132
import static org.junit.Assert.assertNull;
3233
import static org.junit.Assert.assertSame;
3334
import static org.junit.Assert.assertTrue;
3435
import static org.mockito.ArgumentMatchers.eq;
3536
import static org.mockito.Mockito.mock;
37+
import static org.mockito.Mockito.times;
3638
import static org.mockito.Mockito.verify;
3739

3840
import com.google.common.collect.ImmutableList;
@@ -89,8 +91,8 @@ public class GcpAuthenticationFilterTest {
8991

9092
@Test
9193
public void testNewFilterInstancesPerFilterName() {
92-
assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1"))
93-
.isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1"));
94+
assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10))
95+
.isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1", 10));
9496
}
9597

9698
@Test
@@ -152,7 +154,7 @@ public void testClientInterceptor_success() throws IOException, ResourceInvalidE
152154
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
153155
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
154156
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
155-
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
157+
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
156158
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
157159
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
158160
Channel mockChannel = Mockito.mock(Channel.class);
@@ -181,7 +183,7 @@ public void testClientInterceptor_createsAndReusesCachedCredentials()
181183
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
182184
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
183185
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
184-
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
186+
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
185187
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
186188
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
187189
Channel mockChannel = Mockito.mock(Channel.class);
@@ -190,7 +192,7 @@ public void testClientInterceptor_createsAndReusesCachedCredentials()
190192
interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
191193
interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
192194

193-
verify(mockChannel, Mockito.times(2))
195+
verify(mockChannel, times(2))
194196
.newCall(eq(methodDescriptor), callOptionsCaptor.capture());
195197
CallOptions firstCapturedOptions = callOptionsCaptor.getAllValues().get(0);
196198
CallOptions secondCapturedOptions = callOptionsCaptor.getAllValues().get(1);
@@ -202,7 +204,7 @@ public void testClientInterceptor_createsAndReusesCachedCredentials()
202204
@Test
203205
public void testClientInterceptor_withoutClusterSelectionKey() throws Exception {
204206
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
205-
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
207+
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
206208
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
207209
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
208210
Channel mockChannel = mock(Channel.class);
@@ -233,7 +235,7 @@ public void testClientInterceptor_clusterSelectionKeyWithoutPrefix() throws Exce
233235
Channel mockChannel = mock(Channel.class);
234236

235237
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
236-
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
238+
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
237239
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
238240
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
239241
interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
@@ -244,7 +246,7 @@ public void testClientInterceptor_clusterSelectionKeyWithoutPrefix() throws Exce
244246
@Test
245247
public void testClientInterceptor_xdsConfigDoesNotExist() throws Exception {
246248
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
247-
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
249+
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
248250
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
249251
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
250252
Channel mockChannel = mock(Channel.class);
@@ -274,7 +276,7 @@ public void testClientInterceptor_incorrectClusterName() throws Exception {
274276
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster")
275277
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
276278
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
277-
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
279+
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
278280
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
279281
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
280282
Channel mockChannel = mock(Channel.class);
@@ -300,7 +302,7 @@ public void testClientInterceptor_statusOrError() throws Exception {
300302
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
301303
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
302304
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
303-
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
305+
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
304306
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
305307
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
306308
Channel mockChannel = mock(Channel.class);
@@ -329,7 +331,7 @@ public void testClientInterceptor_notAudienceWrapper()
329331
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
330332
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
331333
GcpAuthenticationConfig config = new GcpAuthenticationConfig(10);
332-
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME");
334+
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 10);
333335
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
334336
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
335337
Channel mockChannel = Mockito.mock(Channel.class);
@@ -342,6 +344,115 @@ public void testClientInterceptor_notAudienceWrapper()
342344
assertThat(clientCall.error.getDescription()).contains("GCP Authn found wrong type");
343345
}
344346

347+
@Test
348+
public void testLruCacheAcrossInterceptors() throws IOException, ResourceInvalidException {
349+
XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig(
350+
CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate)));
351+
XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
352+
.setListener(ldsUpdate)
353+
.setRoute(rdsUpdate)
354+
.setVirtualHost(rdsUpdate.virtualHosts.get(0))
355+
.addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
356+
CallOptions callOptionsWithXds = CallOptions.DEFAULT
357+
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
358+
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
359+
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2);
360+
ClientInterceptor interceptor1
361+
= filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null);
362+
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
363+
Channel mockChannel = Mockito.mock(Channel.class);
364+
ArgumentCaptor<CallOptions> callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class);
365+
366+
interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
367+
verify(mockChannel).newCall(eq(methodDescriptor), callOptionsCaptor.capture());
368+
CallOptions capturedOptions1 = callOptionsCaptor.getAllValues().get(0);
369+
assertNotNull(capturedOptions1.getCredentials());
370+
ClientInterceptor interceptor2
371+
= filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
372+
interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel);
373+
verify(mockChannel, times(2))
374+
.newCall(eq(methodDescriptor), callOptionsCaptor.capture());
375+
CallOptions capturedOptions2 = callOptionsCaptor.getAllValues().get(1);
376+
assertNotNull(capturedOptions2.getCredentials());
377+
378+
assertSame(capturedOptions1.getCredentials(), capturedOptions2.getCredentials());
379+
}
380+
381+
@Test
382+
public void testLruCacheEvictionOnResize() throws IOException, ResourceInvalidException {
383+
XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig(
384+
CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate)));
385+
XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
386+
.setListener(ldsUpdate)
387+
.setRoute(rdsUpdate)
388+
.setVirtualHost(rdsUpdate.virtualHosts.get(0))
389+
.addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
390+
CallOptions callOptionsWithXds = CallOptions.DEFAULT
391+
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
392+
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
393+
GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME", 2);
394+
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
395+
396+
ClientInterceptor interceptor1 =
397+
filter.buildClientInterceptor(new GcpAuthenticationConfig(2), null, null);
398+
Channel mockChannel1 = Mockito.mock(Channel.class);
399+
ArgumentCaptor<CallOptions> captor = ArgumentCaptor.forClass(CallOptions.class);
400+
interceptor1.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel1);
401+
verify(mockChannel1).newCall(eq(methodDescriptor), captor.capture());
402+
CallOptions options1 = captor.getValue();
403+
// This will recreate the cache with max size of 1 and copy the credential for audience1.
404+
ClientInterceptor interceptor2 =
405+
filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
406+
Channel mockChannel2 = Mockito.mock(Channel.class);
407+
interceptor2.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel2);
408+
verify(mockChannel2).newCall(eq(methodDescriptor), captor.capture());
409+
CallOptions options2 = captor.getValue();
410+
411+
assertSame(options1.getCredentials(), options2.getCredentials());
412+
413+
clusterConfig = new XdsConfig.XdsClusterConfig(
414+
CLUSTER_NAME, getCdsUpdate2(), new EndpointConfig(StatusOr.fromValue(edsUpdate)));
415+
defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
416+
.setListener(ldsUpdate)
417+
.setRoute(rdsUpdate)
418+
.setVirtualHost(rdsUpdate.virtualHosts.get(0))
419+
.addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
420+
callOptionsWithXds = CallOptions.DEFAULT
421+
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
422+
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
423+
424+
// This will evict the credential for audience1 and add new credential for audience2
425+
ClientInterceptor interceptor3 =
426+
filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
427+
Channel mockChannel3 = Mockito.mock(Channel.class);
428+
interceptor3.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel3);
429+
verify(mockChannel3).newCall(eq(methodDescriptor), captor.capture());
430+
CallOptions options3 = captor.getValue();
431+
432+
assertNotSame(options1.getCredentials(), options3.getCredentials());
433+
434+
clusterConfig = new XdsConfig.XdsClusterConfig(
435+
CLUSTER_NAME, cdsUpdate, new EndpointConfig(StatusOr.fromValue(edsUpdate)));
436+
defaultXdsConfig = new XdsConfig.XdsConfigBuilder()
437+
.setListener(ldsUpdate)
438+
.setRoute(rdsUpdate)
439+
.setVirtualHost(rdsUpdate.virtualHosts.get(0))
440+
.addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build();
441+
callOptionsWithXds = CallOptions.DEFAULT
442+
.withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0")
443+
.withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig);
444+
445+
// This will create new credential for audience1 because it has been evicted
446+
ClientInterceptor interceptor4 =
447+
filter.buildClientInterceptor(new GcpAuthenticationConfig(1), null, null);
448+
Channel mockChannel4 = Mockito.mock(Channel.class);
449+
interceptor4.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel4);
450+
verify(mockChannel4).newCall(eq(methodDescriptor), captor.capture());
451+
CallOptions options4 = captor.getValue();
452+
453+
assertNotSame(options1.getCredentials(), options4.getCredentials());
454+
}
455+
345456
private static LdsUpdate getLdsUpdate() {
346457
Filter.NamedFilterConfig routerFilterConfig = new Filter.NamedFilterConfig(
347458
serverName, RouterFilter.ROUTER_CONFIG);
@@ -384,6 +495,19 @@ private static CdsUpdate getCdsUpdate() {
384495
}
385496
}
386497

498+
private static CdsUpdate getCdsUpdate2() {
499+
ImmutableMap.Builder<String, Object> parsedMetadata = ImmutableMap.builder();
500+
parsedMetadata.put("FILTER_INSTANCE_NAME", new AudienceWrapper("NEW_TEST_AUDIENCE"));
501+
try {
502+
CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds(
503+
CLUSTER_NAME, EDS_NAME, null, null, null, null, false)
504+
.lbPolicyConfig(getWrrLbConfigAsMap());
505+
return cdsUpdate.parsedMetadata(parsedMetadata.build()).build();
506+
} catch (IOException ex) {
507+
return null;
508+
}
509+
}
510+
387511
private static CdsUpdate getCdsUpdateWithIncorrectAudienceWrapper() throws IOException {
388512
ImmutableMap.Builder<String, Object> parsedMetadata = ImmutableMap.builder();
389513
parsedMetadata.put("FILTER_INSTANCE_NAME", "TEST_AUDIENCE");

0 commit comments

Comments
 (0)