Skip to content

Commit 2fa0023

Browse files
sainitarun063dawn-scoped@luci-project-accounts.iam.gserviceaccount.com
authored andcommitted
[android][webgpu] Fix thread safety and resource leaks in WebGpu helper.
Migrates the WebGpu helper and associated instrumentation tests to run sequentially on a dedicated single-threaded CoroutineDispatcher. - Prevents thread leaks by closing the dispatcher service upon close() if owned by the instance. - Prevents reentrancy deadlocks during teardown by using a ThreadLocal to run cleanup inline when already on the WebGPU thread. - Updates all tests to run within the dispatcher's execution context and clean up dispatchers on teardown. Bug: 490019860 Change-Id: I716ce0e616105c65a8dce7d489c62d9b3156df69 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/311456 Reviewed-by: Loko Kung <lokokung@google.com> Commit-Queue: Tarun Saini <sainitarun@google.com> Reviewed-by: Mridul Goyal <mridulgoyal@google.com>
1 parent 5c52c5d commit 2fa0023

26 files changed

Lines changed: 2824 additions & 1977 deletions

tools/android/webgpu/src/androidTest/java/androidx/webgpu/ApiRequirement.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ import kotlin.annotation.Target
2222
@Target(AnnotationTarget.FUNCTION)
2323
annotation class ApiRequirement(
2424
val minApi: Int,
25-
val onlySkipOnEmulator: Boolean = false
25+
val onlySkipOnEmulator: Boolean = false,
2626
)

tools/android/webgpu/src/androidTest/java/androidx/webgpu/AsyncHelperTest.kt

Lines changed: 124 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,16 @@ import androidx.test.ext.junit.runners.AndroidJUnit4
1919
import androidx.test.filters.SmallTest
2020
import androidx.webgpu.helper.WebGpu
2121
import androidx.webgpu.helper.createWebGpu
22+
import java.util.concurrent.Executors
23+
import kotlinx.coroutines.CoroutineDispatcher
24+
import kotlinx.coroutines.CoroutineScope
25+
import kotlinx.coroutines.Dispatchers
26+
import kotlinx.coroutines.ExecutorCoroutineDispatcher
27+
import kotlinx.coroutines.asCoroutineDispatcher
28+
import kotlinx.coroutines.cancel
2229
import kotlinx.coroutines.launch
2330
import kotlinx.coroutines.runBlocking
31+
import org.junit.After
2432
import org.junit.Assert.assertEquals
2533
import org.junit.Assert.assertThrows
2634
import org.junit.Assume.assumeFalse
@@ -33,10 +41,14 @@ import java.util.concurrent.atomic.AtomicBoolean
3341
@SmallTest
3442
class AsyncHelperTest {
3543

36-
private lateinit var webGpu: WebGpu
37-
private lateinit var device: GPUDevice
44+
private val dispatcher: CoroutineDispatcher = Executors.newSingleThreadExecutor { runnable ->
45+
Thread(runnable, "Test-WebGPU-Thread")
46+
}.asCoroutineDispatcher()
47+
private val testScope = CoroutineScope(dispatcher)
48+
private lateinit var device: GPUDevice
49+
private lateinit var webGpu: WebGpu
3850

39-
private val BASIC_SHADER = """
51+
private val BASIC_SHADER = """
4052
@vertex fn vertexMain(@builtin(vertex_index) i : u32) ->
4153
@builtin(position) vec4f {
4254
return vec4f();
@@ -45,113 +57,128 @@ class AsyncHelperTest {
4557
return vec4f();
4658
} """
4759

48-
@Before
49-
fun setup() {
50-
runBlocking {
51-
webGpu = createWebGpu()
52-
device = webGpu.device
53-
}
60+
@Before
61+
fun setup(): Unit = runBlocking {
62+
webGpu = createWebGpu(dispatcher)
63+
device = webGpu.device
64+
testScope.launch {
65+
webGpu.processEventsLoop()
5466
}
67+
}
5568

56-
@Test
57-
fun asyncMethodTest() {
58-
runBlocking {
59-
/* Set up a shader module to support the async call. */
60-
val shaderModule = device.createShaderModule(
61-
GPUShaderModuleDescriptor(shaderSourceWGSL = GPUShaderSourceWGSL(""))
62-
)
69+
@After
70+
fun teardown() {
71+
if (::webGpu.isInitialized) {
72+
webGpu.close()
73+
}
74+
testScope.cancel()
75+
(dispatcher as? ExecutorCoroutineDispatcher)?.close()
76+
}
77+
78+
@Test
79+
fun asyncMethodTest() {
80+
runBlocking {
81+
webGpu.execute {
82+
/* Set up a shader module to support the async call. */
83+
val shaderModule = device.createShaderModule(
84+
GPUShaderModuleDescriptor(shaderSourceWGSL = GPUShaderSourceWGSL(""))
85+
)
6386

64-
val exception = assertThrows(WebGpuException::class.java) {
65-
runBlocking {
66-
/* Call an asynchronous method, converted from a callback pattern by a helper. */
67-
device.createRenderPipelineAndAwait(
68-
GPURenderPipelineDescriptor(vertex = GPUVertexState(module = shaderModule))
69-
)
70-
}
71-
}
72-
73-
assertEquals(
74-
"""Create render pipeline (async) should fail when no shader entry point exists.
75-
The result was: ${exception.status}""",
76-
CreatePipelineAsyncStatus.ValidationError,
77-
exception.status
78-
)
87+
val exception = assertThrowsSuspend(WebGpuException::class.java) {
88+
/* Call an asynchronous method, converted from a callback pattern by a helper. */
89+
device.createRenderPipelineAndAwait(
90+
GPURenderPipelineDescriptor(vertex = GPUVertexState(module = shaderModule))
91+
)
7992
}
80-
}
8193

82-
@Test
83-
fun asyncMethodTestValidationPasses() {
84-
runBlocking {
85-
/* Set up a valid shader module and descriptor */
86-
val shaderModule = device.createShaderModule(
87-
GPUShaderModuleDescriptor(shaderSourceWGSL = GPUShaderSourceWGSL(BASIC_SHADER))
88-
)
94+
assertEquals(
95+
"""Create render pipeline (async) should fail when no shader entry point exists.
96+
The result was: ${exception.status}""",
97+
CreatePipelineAsyncStatus.ValidationError,
98+
exception.status
99+
)
100+
}
101+
}
102+
}
103+
104+
@Test
105+
fun asyncMethodTestValidationPasses() {
106+
runBlocking {
107+
webGpu.execute {
108+
/* Set up a valid shader module and descriptor */
109+
val shaderModule = device.createShaderModule(
110+
GPUShaderModuleDescriptor(shaderSourceWGSL = GPUShaderSourceWGSL(BASIC_SHADER))
111+
)
89112

90-
/* Call an asynchronous method, converted from a callback pattern by a helper. */
91-
val unused = device.createRenderPipelineAndAwait(
92-
GPURenderPipelineDescriptor(
93-
vertex = GPUVertexState(module = shaderModule), fragment = GPUFragmentState(
94-
module = shaderModule,
95-
targets = arrayOf(GPUColorTargetState(format = TextureFormat.RGBA8Unorm))
96-
)
97-
)
113+
/* Call an asynchronous method, converted from a callback pattern by a helper. */
114+
val unused = device.createRenderPipelineAndAwait(
115+
GPURenderPipelineDescriptor(
116+
vertex = GPUVertexState(module = shaderModule), fragment = GPUFragmentState(
117+
module = shaderModule,
118+
targets = arrayOf(GPUColorTargetState(format = TextureFormat.RGBA8Unorm))
98119
)
120+
)
121+
)
99122

100-
/* Create render pipeline (async) should pass with a simple shader.. */
101-
}
123+
/* Create render pipeline (async) should pass with a simple shader.. */
124+
}
102125
}
126+
}
103127

104-
private fun baseCancellationTest(doCancel: Boolean): Boolean {
105-
val hasReturned = AtomicBoolean(false)
128+
private fun baseCancellationTest(doCancel: Boolean): Boolean {
129+
val hasReturned = AtomicBoolean(false)
106130

107-
runBlocking {
108-
val shaderModule = device.createShaderModule(
109-
GPUShaderModuleDescriptor(shaderSourceWGSL = GPUShaderSourceWGSL(BASIC_SHADER))
110-
)
131+
runBlocking {
132+
webGpu.execute {
133+
val shaderModule = device.createShaderModule(
134+
GPUShaderModuleDescriptor(shaderSourceWGSL = GPUShaderSourceWGSL(BASIC_SHADER))
135+
)
111136

112-
/* Launch the function in a new coroutine, giving us a job handle we can cancel. */
113-
val job = launch {
114-
var unused = device.createRenderPipelineAndAwait(
115-
GPURenderPipelineDescriptor(vertex = GPUVertexState(module = shaderModule),
116-
fragment = GPUFragmentState(
117-
module = shaderModule,
118-
targets = arrayOf(GPUColorTargetState(format = TextureFormat.RGBA8Unorm))
119-
)
120-
)
121-
)
122-
hasReturned.set(true)
123-
}
124-
assumeFalse("The job completed before we could test it", hasReturned.get())
125-
126-
if (doCancel) {
127-
job.cancel()
128-
}
129-
job.join()
137+
/* Launch the function in a new coroutine, giving us a job handle we can cancel. */
138+
val job = launch {
139+
var unused = device.createRenderPipelineAndAwait(
140+
GPURenderPipelineDescriptor(
141+
vertex = GPUVertexState(module = shaderModule),
142+
fragment = GPUFragmentState(
143+
module = shaderModule,
144+
targets = arrayOf(GPUColorTargetState(format = TextureFormat.RGBA8Unorm))
145+
)
146+
)
147+
)
148+
hasReturned.set(true)
130149
}
131-
return hasReturned.get()
132-
}
133-
134-
/**
135-
* Test that the async-based job will complete if it's not cancelled.
136-
*/
137-
@Test
138-
fun asyncMethodCancellationTestControl() {
139-
assertEquals(
140-
"The async job should have completed but it failed to do so.",
141-
true,
142-
baseCancellationTest(false)
143-
)
144-
}
150+
assumeFalse("The job completed before we could test it", hasReturned.get())
145151

146-
/**
147-
* Test that the async-based job will not complete if it is cancelled.
148-
*/
149-
@Test
150-
fun asyncMethodCancellationTest() {
151-
assertEquals(
152-
"The async job should have been cancelled but it completed.",
153-
false,
154-
baseCancellationTest(true)
155-
)
152+
if (doCancel) {
153+
job.cancel()
154+
}
155+
job.join()
156+
}
156157
}
158+
return hasReturned.get()
159+
}
160+
161+
/**
162+
* Test that the async-based job will complete if it's not cancelled.
163+
*/
164+
@Test
165+
fun asyncMethodCancellationTestControl() {
166+
assertEquals(
167+
"The async job should have completed but it failed to do so.",
168+
true,
169+
baseCancellationTest(false)
170+
)
171+
}
172+
173+
/**
174+
* Test that the async-based job will not complete if it is cancelled.
175+
*/
176+
@Test
177+
fun asyncMethodCancellationTest() {
178+
assertEquals(
179+
"The async job should have been cancelled but it completed.",
180+
false,
181+
baseCancellationTest(true)
182+
)
183+
}
157184
}

tools/android/webgpu/src/androidTest/java/androidx/webgpu/BitmapUtils.kt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@ import java.io.File
2222
import java.io.FileOutputStream
2323

2424
fun writeReferenceImage(bitmap: Bitmap) {
25-
val path =
26-
Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS)
27-
val file = File("${path}${File.separator}${"reference.png"}")
28-
BufferedOutputStream(FileOutputStream(file)).use {
29-
bitmap.compress(Bitmap.CompressFormat.PNG, 100, it)
30-
it.close()
31-
}
25+
val path =
26+
Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_DOWNLOADS)
27+
val file = File("${path}${File.separator}${"reference.png"}")
28+
BufferedOutputStream(FileOutputStream(file)).use {
29+
bitmap.compress(Bitmap.CompressFormat.PNG, 100, it)
30+
it.close()
31+
}
3232
}

0 commit comments

Comments
 (0)