Skip to content

Commit 0c431c4

Browse files
Integration tests for LocalSocketShellMain.
PiperOrigin-RevId: 681170539
1 parent 35bdab8 commit 0c431c4

File tree

9 files changed

+605
-21
lines changed

9 files changed

+605
-21
lines changed

services/shellexecutor/java/androidx/test/services/shellexecutor/BUILD

+5
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,12 @@ kt_android_library(
6060
srcs = [
6161
"BlockingPublish.java",
6262
"FileObserverShellMain.kt",
63+
"LocalSocketShellMain.kt",
6364
"ShellCommand.java",
6465
"ShellCommandExecutor.java",
6566
"ShellCommandExecutorServer.java",
6667
"ShellCommandFileObserverExecutorServer.kt",
68+
"ShellCommandLocalSocketExecutorServer.kt",
6769
"ShellExecSharedConstants.java",
6870
"ShellMain.java",
6971
],
@@ -72,6 +74,8 @@ kt_android_library(
7274
deps = [
7375
":coroutine_file_observer",
7476
":file_observer_protocol",
77+
":local_socket_protocol",
78+
":local_socket_protocol_pb_java_proto_lite",
7579
"//services/speakeasy/java/androidx/test/services/speakeasy:protocol",
7680
"//services/speakeasy/java/androidx/test/services/speakeasy/client",
7781
"//services/speakeasy/java/androidx/test/services/speakeasy/client:tool_connection",
@@ -94,6 +98,7 @@ kt_android_library(
9498
"ShellExecutorFactory.java",
9599
"ShellExecutorFileObserverImpl.kt",
96100
"ShellExecutorImpl.java",
101+
"ShellExecutorLocalSocketImpl.kt",
97102
],
98103
idl_srcs = ["Command.aidl"],
99104
visibility = [":export"],
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright (C) 2024 The Android Open Source Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package androidx.test.services.shellexecutor
18+
19+
import android.util.Log
20+
import java.io.IOException
21+
import java.io.InputStream
22+
import java.io.OutputStream
23+
import java.util.concurrent.Executors
24+
import kotlin.time.Duration.Companion.milliseconds
25+
import kotlinx.coroutines.CoroutineScope
26+
import kotlinx.coroutines.asCoroutineDispatcher
27+
import kotlinx.coroutines.launch
28+
import kotlinx.coroutines.runBlocking
29+
import kotlinx.coroutines.runInterruptible
30+
31+
/** Variant of ShellMain that uses a LocalSocket to communicate with the client. */
32+
class LocalSocketShellMain {
33+
34+
suspend fun run(args: Array<String>): Int {
35+
val scope = CoroutineScope(Executors.newCachedThreadPool().asCoroutineDispatcher())
36+
val server = ShellCommandLocalSocketExecutorServer(scope = scope)
37+
server.start()
38+
39+
val processArgs = args.toMutableList()
40+
processArgs.addAll(
41+
processArgs.size - 1,
42+
listOf("-e", ShellExecSharedConstants.BINDER_KEY, server.binderKey()),
43+
)
44+
val pb = ProcessBuilder(processArgs.toList())
45+
46+
val exitCode: Int
47+
48+
try {
49+
val process = pb.start()
50+
51+
val stdinCopier = scope.launch { copyStream("stdin", System.`in`, process.outputStream) }
52+
val stdoutCopier = scope.launch { copyStream("stdout", process.inputStream, System.out) }
53+
val stderrCopier = scope.launch { copyStream("stderr", process.errorStream, System.err) }
54+
55+
runInterruptible { process.waitFor() }
56+
exitCode = process.exitValue()
57+
58+
stdinCopier.cancel() // System.`in`.close() does not force input.read() to return
59+
stdoutCopier.join()
60+
stderrCopier.join()
61+
} finally {
62+
server.stop(100.milliseconds)
63+
}
64+
return exitCode
65+
}
66+
67+
suspend fun copyStream(name: String, input: InputStream, output: OutputStream) {
68+
val buf = ByteArray(1024)
69+
try {
70+
while (true) {
71+
val size = input.read(buf)
72+
if (size == -1) break
73+
output.write(buf, 0, size)
74+
}
75+
output.flush()
76+
} catch (x: IOException) {
77+
Log.e(TAG, "IOException on $name. Terminating.", x)
78+
}
79+
}
80+
81+
companion object {
82+
private const val TAG = "LocalSocketShellMain"
83+
84+
@JvmStatic
85+
public fun main(args: Array<String>) {
86+
System.exit(runBlocking { LocalSocketShellMain().run(args) })
87+
}
88+
}
89+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
/*
2+
* Copyright (C) 2024 The Android Open Source Project
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package androidx.test.services.shellexecutor
18+
19+
import android.net.LocalServerSocket
20+
import android.net.LocalSocket
21+
import android.net.LocalSocketAddress
22+
import android.os.Process as AndroidProcess
23+
import android.util.Log
24+
import androidx.test.services.shellexecutor.LocalSocketProtocol.asBinderKey
25+
import androidx.test.services.shellexecutor.LocalSocketProtocol.readRequest
26+
import androidx.test.services.shellexecutor.LocalSocketProtocol.sendResponse
27+
import androidx.test.services.shellexecutor.LocalSocketProtocolProto.RunCommandRequest
28+
import java.io.IOException
29+
import java.io.InterruptedIOException
30+
import java.security.SecureRandom
31+
import java.util.concurrent.Executors
32+
import java.util.concurrent.atomic.AtomicBoolean
33+
import kotlin.time.Duration
34+
import kotlin.time.Duration.Companion.milliseconds
35+
import kotlinx.coroutines.CoroutineScope
36+
import kotlinx.coroutines.Job
37+
import kotlinx.coroutines.SupervisorJob
38+
import kotlinx.coroutines.TimeoutCancellationException
39+
import kotlinx.coroutines.asCoroutineDispatcher
40+
import kotlinx.coroutines.async
41+
import kotlinx.coroutines.coroutineScope
42+
import kotlinx.coroutines.delay
43+
import kotlinx.coroutines.launch
44+
import kotlinx.coroutines.runBlocking
45+
import kotlinx.coroutines.runInterruptible
46+
import kotlinx.coroutines.withTimeout
47+
48+
/** Server that run shell commands for a client talking over a LocalSocket. */
49+
final class ShellCommandLocalSocketExecutorServer
50+
@JvmOverloads
51+
constructor(
52+
private val scope: CoroutineScope =
53+
CoroutineScope(Executors.newCachedThreadPool().asCoroutineDispatcher())
54+
) {
55+
// Use the same secret generation as SpeakEasy does.
56+
private val secret = java.lang.Long.toHexString(SecureRandom().nextLong())
57+
lateinit var socket: LocalServerSocket
58+
lateinit var address: LocalSocketAddress
59+
// Since LocalServerSocket.accept() has to be interrupted, we keep that in its own Job...
60+
lateinit var serverJob: Job
61+
// ...while all the child jobs are under a single SupervisorJob that we can join later.
62+
val shellJobs = SupervisorJob()
63+
val running = AtomicBoolean(true)
64+
65+
/** Returns the binder key to pass to client processes. */
66+
fun binderKey(): String {
67+
// The address can contain spaces, and since it gets passed through a command line, we need to
68+
// encode it. java.net.URLEncoder is conveniently available in all SDK versions.
69+
return address.asBinderKey(secret)
70+
}
71+
72+
/** Runs a simple server. */
73+
private suspend fun server() = coroutineScope {
74+
while (running.get()) {
75+
val connection =
76+
try {
77+
runInterruptible { socket.accept() }
78+
} catch (x: Exception) {
79+
// None of my tests have managed to trigger this one.
80+
Log.e(TAG, "LocalServerSocket.accept() failed", x)
81+
break
82+
}
83+
launch(scope.coroutineContext + shellJobs) { handleConnection(connection) }
84+
}
85+
}
86+
87+
/**
88+
* Relays the output of process to connection with a series of RunCommandResponses.
89+
*
90+
* @param process The process to relay output from.
91+
* @param connection The connection to relay output to.
92+
* @return false if there was a problem, true otherwise.
93+
*/
94+
private suspend fun relay(process: Process, connection: LocalSocket): Boolean {
95+
// Experiment shows that 64K is *much* faster than 4K, especially on API 21-23. Streaming 1MB
96+
// takes 3s with 4K buffers and 2s with 64K on API 23. 22 is a bit faster (2.6s -> 1.5s),
97+
// 21 faster still (630ms -> 545ms). Higher API levels are *much* faster (24 is 119 ms ->
98+
// 75ms).
99+
val buffer = ByteArray(65536)
100+
var size: Int
101+
102+
// LocalSocket.isOutputShutdown() throws UnsupportedOperationException, so we can't use
103+
// that as our loop constraint.
104+
while (true) {
105+
try {
106+
size = runInterruptible { process.inputStream.read(buffer) }
107+
if (size < 0) return true // EOF
108+
if (size == 0) {
109+
delay(1.milliseconds)
110+
continue
111+
}
112+
} catch (x: InterruptedIOException) {
113+
// We start getting these at API 24 when the timeout handling kicks in.
114+
Log.i(TAG, "Interrupted while reading from ${process}: ${x.message}")
115+
return false
116+
} catch (x: IOException) {
117+
Log.i(TAG, "Error reading from ${process}; did it time out?", x)
118+
return false
119+
}
120+
121+
if (!connection.sendResponse(buffer = buffer, size = size)) {
122+
return false
123+
}
124+
}
125+
}
126+
127+
/** Handle one connection. */
128+
private suspend fun handleConnection(connection: LocalSocket) {
129+
// connection.localSocketAddress is always null, so no point in logging it.
130+
131+
// Close the connection when done.
132+
connection.use {
133+
val request = connection.readRequest()
134+
135+
if (request.secret.compareTo(secret) != 0) {
136+
Log.w(TAG, "Ignoring request with wrong secret: $request")
137+
return
138+
}
139+
140+
val pb = request.toProcessBuilder()
141+
pb.redirectErrorStream(true)
142+
143+
val process: Process
144+
try {
145+
process = pb.start()
146+
} catch (x: IOException) {
147+
Log.e(TAG, "Failed to start process", x)
148+
connection.sendResponse(
149+
buffer = x.stackTraceToString().toByteArray(),
150+
exitCode = EXIT_CODE_FAILED_TO_START,
151+
)
152+
return
153+
}
154+
155+
// We will not be writing anything to the process' stdin.
156+
process.outputStream.close()
157+
158+
// Close the process' stdout when we're done reading.
159+
process.inputStream.use {
160+
// Launch a coroutine to relay the process' output to the client. If it times out, kill the
161+
// process and cancel the job. This is more coroutine-friendly than using waitFor() to
162+
// handle timeouts.
163+
val ioJob = scope.async { relay(process, connection) }
164+
165+
try {
166+
withTimeout(request.timeout()) {
167+
if (!ioJob.await()) {
168+
Log.w(TAG, "Relaying ${process} output failed")
169+
}
170+
runInterruptible { process.waitFor() }
171+
}
172+
} catch (x: TimeoutCancellationException) {
173+
Log.e(TAG, "Process ${process} timed out after ${request.timeout()}")
174+
process.destroy()
175+
ioJob.cancel()
176+
connection.sendResponse(exitCode = EXIT_CODE_TIMED_OUT)
177+
return
178+
}
179+
180+
connection.sendResponse(exitCode = process.exitValue())
181+
}
182+
}
183+
}
184+
185+
/** Starts the server. */
186+
fun start() {
187+
socket = LocalServerSocket("androidx.test.services ${AndroidProcess.myPid()}")
188+
address = socket.localSocketAddress
189+
Log.i(TAG, "Starting server on ${address.name}")
190+
191+
// Launch a coroutine to call socket.accept()
192+
serverJob = scope.launch { server() }
193+
}
194+
195+
/** Stops the server. */
196+
fun stop(timeout: Duration) {
197+
running.set(false)
198+
// Closing the socket does not interrupt accept()...
199+
socket.close()
200+
runBlocking(scope.coroutineContext) {
201+
try {
202+
// ...so we simply cancel that job...
203+
serverJob.cancel()
204+
// ...and play nicely with all the shell jobs underneath.
205+
withTimeout(timeout) {
206+
shellJobs.complete()
207+
shellJobs.join()
208+
}
209+
} catch (x: TimeoutCancellationException) {
210+
Log.w(TAG, "Shell jobs did not stop after $timeout", x)
211+
shellJobs.cancel()
212+
}
213+
}
214+
}
215+
216+
private fun RunCommandRequest.timeout(): Duration =
217+
if (timeoutMs <= 0) {
218+
Duration.INFINITE
219+
} else {
220+
timeoutMs.milliseconds
221+
}
222+
223+
/**
224+
* Sets up a ProcessBuilder with information from the request; other configuration is up to the
225+
* caller.
226+
*/
227+
private fun RunCommandRequest.toProcessBuilder(): ProcessBuilder {
228+
val pb = ProcessBuilder(argvList)
229+
val redacted = argvList.map { it.replace(secret, "(SECRET)") } // Don't log the secret!
230+
Log.i(TAG, "Command to execute: [${redacted.joinToString("] [")}] within ${timeout()}")
231+
if (environmentMap.isNotEmpty()) {
232+
pb.environment().putAll(environmentMap)
233+
val env = environmentMap.entries.map { (k, v) -> "$k=$v" }.joinToString(", ")
234+
Log.i(TAG, "Environment: $env")
235+
}
236+
return pb
237+
}
238+
239+
private companion object {
240+
const val TAG = "SCLSEServer" // up to 23 characters
241+
242+
const val EXIT_CODE_FAILED_TO_START = -1
243+
const val EXIT_CODE_TIMED_OUT = -2
244+
}
245+
}

services/shellexecutor/java/androidx/test/services/shellexecutor/ShellExecutorFactory.java

+5-2
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,11 @@ public ShellExecutorFactory(Context context, String binderKey) {
3131

3232
public ShellExecutor create() {
3333
// Binder keys for SpeakEasy are a string of hex digits. Binder keys for the FileObserver
34-
// protocol are the absolute path of the directory that the server is watching.
35-
if (binderKey.startsWith("/")) {
34+
// protocol are the absolute path of the directory that the server is watching. Binder keys for
35+
// the LocalSocket protocol start and end with a colon.
36+
if (LocalSocketProtocol.isBinderKey(binderKey)) {
37+
return new ShellExecutorLocalSocketImpl(binderKey);
38+
} else if (binderKey.startsWith("/")) {
3639
return new ShellExecutorFileObserverImpl(binderKey);
3740
} else {
3841
return new ShellExecutorImpl(context, binderKey);

0 commit comments

Comments
 (0)