Skip to content

Commit 1a2285b

Browse files
authored
xds: ensure server interceptors are created in a sync context (grpc#11930)
`XdsServerWrapper#generatePerRouteInterceptors` was always intended to be executed within a sync context. This PR ensures that by calling `syncContext.throwIfNotInThisSynchronizationContext()`. This change is needed for upcoming xDS filter state retention because the new tests in XdsServerWrapperTest flake with this NPE: > `Cannot invoke "io.grpc.xds.client.XdsClient$ResourceWatcher.onChanged(io.grpc.xds.client.XdsClient$ResourceUpdate)" because "this.ldsWatcher" is null`
1 parent cdab410 commit 1a2285b

File tree

3 files changed

+79
-36
lines changed

3 files changed

+79
-36
lines changed

xds/src/main/java/io/grpc/xds/XdsServerWrapper.java

+1-3
Original file line numberDiff line numberDiff line change
@@ -524,9 +524,7 @@ private AtomicReference<ServerRoutingConfig> generateRoutingConfig(FilterChain f
524524

525525
private ImmutableMap<Route, ServerInterceptor> generatePerRouteInterceptors(
526526
@Nullable List<NamedFilterConfig> filterConfigs, List<VirtualHost> virtualHosts) {
527-
// This should always be called from the sync context.
528-
// Ideally we'd want to throw otherwise, but this breaks the tests now.
529-
// syncContext.throwIfNotInThisSynchronizationContext();
527+
syncContext.throwIfNotInThisSynchronizationContext();
530528

531529
ImmutableMap.Builder<Route, ServerInterceptor> perRouteInterceptors =
532530
new ImmutableMap.Builder<>();

xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java

+65-18
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,18 @@
3838
import io.grpc.xds.client.XdsClient;
3939
import io.grpc.xds.client.XdsInitializationException;
4040
import io.grpc.xds.client.XdsResourceType;
41+
import java.time.Duration;
4142
import java.util.ArrayList;
4243
import java.util.Arrays;
4344
import java.util.Collections;
4445
import java.util.HashMap;
4546
import java.util.List;
4647
import java.util.Map;
4748
import java.util.concurrent.CountDownLatch;
49+
import java.util.concurrent.ExecutionException;
4850
import java.util.concurrent.Executor;
51+
import java.util.concurrent.TimeUnit;
52+
import java.util.concurrent.TimeoutException;
4953
import javax.annotation.Nullable;
5054

5155
/**
@@ -174,12 +178,18 @@ public List<String> getTargets() {
174178
}
175179
}
176180

181+
// Implementation details:
182+
// 1. Use `synchronized` in methods where XdsClientImpl uses its own `syncContext`.
183+
// 2. Use `serverExecutor` via `execute()` in methods where XdsClientImpl uses watcher's executor.
177184
static final class FakeXdsClient extends XdsClient {
178-
boolean shutdown;
179-
SettableFuture<String> ldsResource = SettableFuture.create();
180-
ResourceWatcher<LdsUpdate> ldsWatcher;
181-
CountDownLatch rdsCount = new CountDownLatch(1);
185+
public static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(5);
186+
187+
private boolean shutdown;
188+
@Nullable SettableFuture<String> ldsResource = SettableFuture.create();
189+
@Nullable ResourceWatcher<LdsUpdate> ldsWatcher;
190+
private CountDownLatch rdsCount = new CountDownLatch(1);
182191
final Map<String, ResourceWatcher<RdsUpdate>> rdsWatchers = new HashMap<>();
192+
@Nullable private volatile Executor serverExecutor;
183193

184194
@Override
185195
public TlsContextManager getSecurityConfig() {
@@ -193,14 +203,20 @@ public BootstrapInfo getBootstrapInfo() {
193203

194204
@Override
195205
@SuppressWarnings("unchecked")
196-
public <T extends ResourceUpdate> void watchXdsResource(XdsResourceType<T> resourceType,
197-
String resourceName,
198-
ResourceWatcher<T> watcher,
199-
Executor syncContext) {
206+
public synchronized <T extends ResourceUpdate> void watchXdsResource(
207+
XdsResourceType<T> resourceType,
208+
String resourceName,
209+
ResourceWatcher<T> watcher,
210+
Executor executor) {
211+
if (serverExecutor != null) {
212+
assertThat(executor).isEqualTo(serverExecutor);
213+
}
214+
200215
switch (resourceType.typeName()) {
201216
case "LDS":
202217
assertThat(ldsWatcher).isNull();
203218
ldsWatcher = (ResourceWatcher<LdsUpdate>) watcher;
219+
serverExecutor = executor;
204220
ldsResource.set(resourceName);
205221
break;
206222
case "RDS":
@@ -213,14 +229,14 @@ public <T extends ResourceUpdate> void watchXdsResource(XdsResourceType<T> resou
213229
}
214230

215231
@Override
216-
public <T extends ResourceUpdate> void cancelXdsResourceWatch(XdsResourceType<T> type,
217-
String resourceName,
218-
ResourceWatcher<T> watcher) {
232+
public synchronized <T extends ResourceUpdate> void cancelXdsResourceWatch(
233+
XdsResourceType<T> type, String resourceName, ResourceWatcher<T> watcher) {
219234
switch (type.typeName()) {
220235
case "LDS":
221236
assertThat(ldsWatcher).isNotNull();
222237
ldsResource = null;
223238
ldsWatcher = null;
239+
serverExecutor = null;
224240
break;
225241
case "RDS":
226242
rdsWatchers.remove(resourceName);
@@ -230,27 +246,58 @@ public <T extends ResourceUpdate> void cancelXdsResourceWatch(XdsResourceType<T>
230246
}
231247

232248
@Override
233-
public void shutdown() {
249+
public synchronized void shutdown() {
234250
shutdown = true;
235251
}
236252

237253
@Override
238-
public boolean isShutDown() {
254+
public synchronized boolean isShutDown() {
239255
return shutdown;
240256
}
241257

258+
public void awaitRds(Duration timeout) throws InterruptedException, TimeoutException {
259+
if (!rdsCount.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) {
260+
throw new TimeoutException("Timeout " + timeout + " waiting for RDSs");
261+
}
262+
}
263+
264+
public void setExpectedRdsCount(int count) {
265+
rdsCount = new CountDownLatch(count);
266+
}
267+
268+
private void execute(Runnable action) {
269+
// This method ensures that all watcher updates:
270+
// - Happen after the server started watching LDS.
271+
// - Are executed within the sync context of the server.
272+
//
273+
// Note that this doesn't guarantee that any of the RDS watchers are created.
274+
// Tests should use setExpectedRdsCount(int) and awaitRds() for that.
275+
if (ldsResource == null) {
276+
throw new IllegalStateException("xDS resource update after watcher cancel");
277+
}
278+
try {
279+
ldsResource.get(DEFAULT_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS);
280+
} catch (ExecutionException | TimeoutException e) {
281+
throw new RuntimeException("Can't resolve LDS resource name in " + DEFAULT_TIMEOUT, e);
282+
} catch (InterruptedException e) {
283+
Thread.currentThread().interrupt();
284+
throw new RuntimeException(e);
285+
}
286+
serverExecutor.execute(action);
287+
}
288+
242289
void deliverLdsUpdate(List<FilterChain> filterChains,
243290
FilterChain defaultFilterChain) {
244-
ldsWatcher.onChanged(LdsUpdate.forTcpListener(Listener.create(
245-
"listener", "0.0.0.0:1", ImmutableList.copyOf(filterChains), defaultFilterChain)));
291+
deliverLdsUpdate(LdsUpdate.forTcpListener(Listener.create(
292+
"listener", "0.0.0.0:1", ImmutableList.copyOf(filterChains), defaultFilterChain)));
246293
}
247294

248295
void deliverLdsUpdate(LdsUpdate ldsUpdate) {
249-
ldsWatcher.onChanged(ldsUpdate);
296+
execute(() -> ldsWatcher.onChanged(ldsUpdate));
250297
}
251298

252-
void deliverRdsUpdate(String rdsName, List<VirtualHost> virtualHosts) {
253-
rdsWatchers.get(rdsName).onChanged(new RdsUpdate(virtualHosts));
299+
void deliverRdsUpdate(String resourceName, List<VirtualHost> virtualHosts) {
300+
execute(() -> rdsWatchers.get(resourceName).onChanged(new RdsUpdate(virtualHosts)));
254301
}
255302
}
256303
}

xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java

+13-15
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@
7474
import java.util.Arrays;
7575
import java.util.Collections;
7676
import java.util.List;
77-
import java.util.concurrent.CountDownLatch;
7877
import java.util.concurrent.ExecutionException;
7978
import java.util.concurrent.Executors;
8079
import java.util.concurrent.TimeUnit;
@@ -252,7 +251,7 @@ public void run() {
252251
FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual);
253252
FilterChain f1 = createFilterChain("filter-chain-1", createRds("rds"));
254253
xdsClient.deliverLdsUpdate(Collections.singletonList(f0), f1);
255-
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
254+
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
256255
xdsClient.deliverRdsUpdate("rds",
257256
Collections.singletonList(createVirtualHost("virtual-host-1")));
258257
verify(listener, timeout(5000)).onServing();
@@ -261,7 +260,7 @@ public void run() {
261260
xdsServerWrapper.shutdown();
262261
assertThat(xdsServerWrapper.isShutdown()).isTrue();
263262
assertThat(xdsClient.ldsResource).isNull();
264-
assertThat(xdsClient.shutdown).isTrue();
263+
assertThat(xdsClient.isShutDown()).isTrue();
265264
verify(mockServer).shutdown();
266265
assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue();
267266
assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue();
@@ -303,7 +302,7 @@ public void run() {
303302
verify(mockServer, never()).start();
304303
assertThat(xdsServerWrapper.isShutdown()).isTrue();
305304
assertThat(xdsClient.ldsResource).isNull();
306-
assertThat(xdsClient.shutdown).isTrue();
305+
assertThat(xdsClient.isShutDown()).isTrue();
307306
verify(mockServer).shutdown();
308307
assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue();
309308
assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue();
@@ -342,7 +341,7 @@ public void run() {
342341
xdsServerWrapper.shutdown();
343342
assertThat(xdsServerWrapper.isShutdown()).isTrue();
344343
assertThat(xdsClient.ldsResource).isNull();
345-
assertThat(xdsClient.shutdown).isTrue();
344+
assertThat(xdsClient.isShutDown()).isTrue();
346345
verify(mockBuilder, times(1)).build();
347346
verify(mockServer, times(1)).shutdown();
348347
xdsServerWrapper.awaitTermination(1, TimeUnit.SECONDS);
@@ -367,7 +366,7 @@ public void run() {
367366
FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds"));
368367
SslContextProviderSupplier sslSupplier = filterChain.sslContextProviderSupplier();
369368
xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null);
370-
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
369+
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
371370
xdsClient.deliverRdsUpdate("rds",
372371
Collections.singletonList(createVirtualHost("virtual-host-1")));
373372
try {
@@ -434,7 +433,7 @@ public void run() {
434433
xdsClient.ldsResource.get(5, TimeUnit.SECONDS);
435434
FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds"));
436435
xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null);
437-
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
436+
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
438437
xdsClient.deliverRdsUpdate("rds",
439438
Collections.singletonList(createVirtualHost("virtual-host-1")));
440439
try {
@@ -544,7 +543,7 @@ public void run() {
544543
0L, Collections.singletonList(virtualHost), new ArrayList<NamedFilterConfig>());
545544
EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual);
546545
EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0"));
547-
xdsClient.rdsCount = new CountDownLatch(3);
546+
xdsClient.setExpectedRdsCount(3);
548547
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null);
549548
assertThat(start.isDone()).isFalse();
550549
assertThat(selectorManager.getSelectorToUpdateSelector()).isNull();
@@ -556,7 +555,7 @@ public void run() {
556555
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f2), f3);
557556
verify(mockServer, never()).start();
558557
verify(listener, never()).onServing();
559-
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
558+
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
560559

561560
xdsClient.deliverRdsUpdate("r1",
562561
Collections.singletonList(createVirtualHost("virtual-host-1")));
@@ -602,12 +601,11 @@ public void run() {
602601
EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0"));
603602
EnvoyServerProtoData.FilterChain f2 = createFilterChain("filter-chain-2", createRds("r0"));
604603

605-
xdsClient.rdsCount = new CountDownLatch(1);
606604
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), f2);
607605
assertThat(start.isDone()).isFalse();
608606
assertThat(selectorManager.getSelectorToUpdateSelector()).isNull();
609607

610-
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
608+
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
611609
xdsClient.deliverRdsUpdate("r0",
612610
Collections.singletonList(createVirtualHost("virtual-host-0")));
613611
start.get(5000, TimeUnit.MILLISECONDS);
@@ -633,9 +631,9 @@ public void run() {
633631
EnvoyServerProtoData.FilterChain f3 = createFilterChain("filter-chain-3", createRds("r0"));
634632
EnvoyServerProtoData.FilterChain f4 = createFilterChain("filter-chain-4", createRds("r1"));
635633
EnvoyServerProtoData.FilterChain f5 = createFilterChain("filter-chain-4", createRds("r1"));
636-
xdsClient.rdsCount = new CountDownLatch(1);
634+
xdsClient.setExpectedRdsCount(1);
637635
xdsClient.deliverLdsUpdate(Arrays.asList(f5, f3), f4);
638-
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
636+
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
639637
xdsClient.deliverRdsUpdate("r1",
640638
Collections.singletonList(createVirtualHost("virtual-host-1")));
641639
xdsClient.deliverRdsUpdate("r0",
@@ -688,7 +686,7 @@ public void run() {
688686
EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual);
689687
EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0"));
690688
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null);
691-
xdsClient.rdsCount.await();
689+
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
692690
xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED);
693691
start.get(5000, TimeUnit.MILLISECONDS);
694692
assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size())
@@ -1235,7 +1233,7 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
12351233
VirtualHost virtualHost = VirtualHost.create(
12361234
"v1", Collections.singletonList("foo.google.com"), Arrays.asList(route),
12371235
ImmutableMap.of("filter-config-name-0", f0Override));
1238-
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
1236+
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
12391237
xdsClient.deliverRdsUpdate("r0", Collections.singletonList(virtualHost));
12401238
start.get(5000, TimeUnit.MILLISECONDS);
12411239
verify(mockServer).start();

0 commit comments

Comments
 (0)