|
| 1 | +/* |
| 2 | + * Copyright (C) 2024 The Android Open Source Project |
| 3 | + * |
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | + * you may not use this file except in compliance with the License. |
| 6 | + * You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | + |
| 17 | +package androidx.test.services.shellexecutor |
| 18 | + |
| 19 | +import android.net.LocalServerSocket |
| 20 | +import android.net.LocalSocket |
| 21 | +import android.net.LocalSocketAddress |
| 22 | +import android.os.Process as AndroidProcess |
| 23 | +import android.util.Log |
| 24 | +import androidx.test.services.shellexecutor.LocalSocketProtocol.asBinderKey |
| 25 | +import androidx.test.services.shellexecutor.LocalSocketProtocol.readRequest |
| 26 | +import androidx.test.services.shellexecutor.LocalSocketProtocol.sendResponse |
| 27 | +import androidx.test.services.shellexecutor.LocalSocketProtocolProto.RunCommandRequest |
| 28 | +import java.io.IOException |
| 29 | +import java.io.InterruptedIOException |
| 30 | +import java.security.SecureRandom |
| 31 | +import java.util.concurrent.Executors |
| 32 | +import java.util.concurrent.atomic.AtomicBoolean |
| 33 | +import kotlin.time.Duration |
| 34 | +import kotlin.time.Duration.Companion.milliseconds |
| 35 | +import kotlinx.coroutines.CoroutineScope |
| 36 | +import kotlinx.coroutines.Job |
| 37 | +import kotlinx.coroutines.SupervisorJob |
| 38 | +import kotlinx.coroutines.TimeoutCancellationException |
| 39 | +import kotlinx.coroutines.asCoroutineDispatcher |
| 40 | +import kotlinx.coroutines.async |
| 41 | +import kotlinx.coroutines.coroutineScope |
| 42 | +import kotlinx.coroutines.delay |
| 43 | +import kotlinx.coroutines.launch |
| 44 | +import kotlinx.coroutines.runBlocking |
| 45 | +import kotlinx.coroutines.runInterruptible |
| 46 | +import kotlinx.coroutines.withTimeout |
| 47 | + |
| 48 | +/** Server that run shell commands for a client talking over a LocalSocket. */ |
| 49 | +final class ShellCommandLocalSocketExecutorServer |
| 50 | +@JvmOverloads |
| 51 | +constructor( |
| 52 | + private val scope: CoroutineScope = |
| 53 | + CoroutineScope(Executors.newCachedThreadPool().asCoroutineDispatcher()) |
| 54 | +) { |
| 55 | + // Use the same secret generation as SpeakEasy does. |
| 56 | + private val secret = java.lang.Long.toHexString(SecureRandom().nextLong()) |
| 57 | + lateinit var socket: LocalServerSocket |
| 58 | + lateinit var address: LocalSocketAddress |
| 59 | + // Since LocalServerSocket.accept() has to be interrupted, we keep that in its own Job... |
| 60 | + lateinit var serverJob: Job |
| 61 | + // ...while all the child jobs are under a single SupervisorJob that we can join later. |
| 62 | + val shellJobs = SupervisorJob() |
| 63 | + val running = AtomicBoolean(true) |
| 64 | + |
| 65 | + /** Returns the binder key to pass to client processes. */ |
| 66 | + fun binderKey(): String { |
| 67 | + // The address can contain spaces, and since it gets passed through a command line, we need to |
| 68 | + // encode it. java.net.URLEncoder is conveniently available in all SDK versions. |
| 69 | + return address.asBinderKey(secret) |
| 70 | + } |
| 71 | + |
| 72 | + /** Runs a simple server. */ |
| 73 | + private suspend fun server() = coroutineScope { |
| 74 | + while (running.get()) { |
| 75 | + val connection = |
| 76 | + try { |
| 77 | + runInterruptible { socket.accept() } |
| 78 | + } catch (x: Exception) { |
| 79 | + // None of my tests have managed to trigger this one. |
| 80 | + Log.e(TAG, "LocalServerSocket.accept() failed", x) |
| 81 | + break |
| 82 | + } |
| 83 | + launch(scope.coroutineContext + shellJobs) { handleConnection(connection) } |
| 84 | + } |
| 85 | + } |
| 86 | + |
| 87 | + /** |
| 88 | + * Relays the output of process to connection with a series of RunCommandResponses. |
| 89 | + * |
| 90 | + * @param process The process to relay output from. |
| 91 | + * @param connection The connection to relay output to. |
| 92 | + * @return false if there was a problem, true otherwise. |
| 93 | + */ |
| 94 | + private suspend fun relay(process: Process, connection: LocalSocket): Boolean { |
| 95 | + // Experiment shows that 64K is *much* faster than 4K, especially on API 21-23. Streaming 1MB |
| 96 | + // takes 3s with 4K buffers and 2s with 64K on API 23. 22 is a bit faster (2.6s -> 1.5s), |
| 97 | + // 21 faster still (630ms -> 545ms). Higher API levels are *much* faster (24 is 119 ms -> |
| 98 | + // 75ms). |
| 99 | + val buffer = ByteArray(65536) |
| 100 | + var size: Int |
| 101 | + |
| 102 | + // LocalSocket.isOutputShutdown() throws UnsupportedOperationException, so we can't use |
| 103 | + // that as our loop constraint. |
| 104 | + while (true) { |
| 105 | + try { |
| 106 | + size = runInterruptible { process.inputStream.read(buffer) } |
| 107 | + if (size < 0) return true // EOF |
| 108 | + if (size == 0) { |
| 109 | + delay(1.milliseconds) |
| 110 | + continue |
| 111 | + } |
| 112 | + } catch (x: InterruptedIOException) { |
| 113 | + // We start getting these at API 24 when the timeout handling kicks in. |
| 114 | + Log.i(TAG, "Interrupted while reading from ${process}: ${x.message}") |
| 115 | + return false |
| 116 | + } catch (x: IOException) { |
| 117 | + Log.i(TAG, "Error reading from ${process}; did it time out?", x) |
| 118 | + return false |
| 119 | + } |
| 120 | + |
| 121 | + if (!connection.sendResponse(buffer = buffer, size = size)) { |
| 122 | + return false |
| 123 | + } |
| 124 | + } |
| 125 | + } |
| 126 | + |
| 127 | + /** Handle one connection. */ |
| 128 | + private suspend fun handleConnection(connection: LocalSocket) { |
| 129 | + // connection.localSocketAddress is always null, so no point in logging it. |
| 130 | + |
| 131 | + // Close the connection when done. |
| 132 | + connection.use { |
| 133 | + val request = connection.readRequest() |
| 134 | + |
| 135 | + if (request.secret.compareTo(secret) != 0) { |
| 136 | + Log.w(TAG, "Ignoring request with wrong secret: $request") |
| 137 | + return |
| 138 | + } |
| 139 | + |
| 140 | + val pb = request.toProcessBuilder() |
| 141 | + pb.redirectErrorStream(true) |
| 142 | + |
| 143 | + val process: Process |
| 144 | + try { |
| 145 | + process = pb.start() |
| 146 | + } catch (x: IOException) { |
| 147 | + Log.e(TAG, "Failed to start process", x) |
| 148 | + connection.sendResponse( |
| 149 | + buffer = x.stackTraceToString().toByteArray(), |
| 150 | + exitCode = EXIT_CODE_FAILED_TO_START, |
| 151 | + ) |
| 152 | + return |
| 153 | + } |
| 154 | + |
| 155 | + // We will not be writing anything to the process' stdin. |
| 156 | + process.outputStream.close() |
| 157 | + |
| 158 | + // Close the process' stdout when we're done reading. |
| 159 | + process.inputStream.use { |
| 160 | + // Launch a coroutine to relay the process' output to the client. If it times out, kill the |
| 161 | + // process and cancel the job. This is more coroutine-friendly than using waitFor() to |
| 162 | + // handle timeouts. |
| 163 | + val ioJob = scope.async { relay(process, connection) } |
| 164 | + |
| 165 | + try { |
| 166 | + withTimeout(request.timeout()) { |
| 167 | + if (!ioJob.await()) { |
| 168 | + Log.w(TAG, "Relaying ${process} output failed") |
| 169 | + } |
| 170 | + runInterruptible { process.waitFor() } |
| 171 | + } |
| 172 | + } catch (x: TimeoutCancellationException) { |
| 173 | + Log.e(TAG, "Process ${process} timed out after ${request.timeout()}") |
| 174 | + process.destroy() |
| 175 | + ioJob.cancel() |
| 176 | + connection.sendResponse(exitCode = EXIT_CODE_TIMED_OUT) |
| 177 | + return |
| 178 | + } |
| 179 | + |
| 180 | + connection.sendResponse(exitCode = process.exitValue()) |
| 181 | + } |
| 182 | + } |
| 183 | + } |
| 184 | + |
| 185 | + /** Starts the server. */ |
| 186 | + fun start() { |
| 187 | + socket = LocalServerSocket("androidx.test.services ${AndroidProcess.myPid()}") |
| 188 | + address = socket.localSocketAddress |
| 189 | + Log.i(TAG, "Starting server on ${address.name}") |
| 190 | + |
| 191 | + // Launch a coroutine to call socket.accept() |
| 192 | + serverJob = scope.launch { server() } |
| 193 | + } |
| 194 | + |
| 195 | + /** Stops the server. */ |
| 196 | + fun stop(timeout: Duration) { |
| 197 | + running.set(false) |
| 198 | + // Closing the socket does not interrupt accept()... |
| 199 | + socket.close() |
| 200 | + runBlocking(scope.coroutineContext) { |
| 201 | + try { |
| 202 | + // ...so we simply cancel that job... |
| 203 | + serverJob.cancel() |
| 204 | + // ...and play nicely with all the shell jobs underneath. |
| 205 | + withTimeout(timeout) { |
| 206 | + shellJobs.complete() |
| 207 | + shellJobs.join() |
| 208 | + } |
| 209 | + } catch (x: TimeoutCancellationException) { |
| 210 | + Log.w(TAG, "Shell jobs did not stop after $timeout", x) |
| 211 | + shellJobs.cancel() |
| 212 | + } |
| 213 | + } |
| 214 | + } |
| 215 | + |
| 216 | + private fun RunCommandRequest.timeout(): Duration = |
| 217 | + if (timeoutMs <= 0) { |
| 218 | + Duration.INFINITE |
| 219 | + } else { |
| 220 | + timeoutMs.milliseconds |
| 221 | + } |
| 222 | + |
| 223 | + /** |
| 224 | + * Sets up a ProcessBuilder with information from the request; other configuration is up to the |
| 225 | + * caller. |
| 226 | + */ |
| 227 | + private fun RunCommandRequest.toProcessBuilder(): ProcessBuilder { |
| 228 | + val pb = ProcessBuilder(argvList) |
| 229 | + val redacted = argvList.map { it.replace(secret, "(SECRET)") } // Don't log the secret! |
| 230 | + Log.i(TAG, "Command to execute: [${redacted.joinToString("] [")}] within ${timeout()}") |
| 231 | + if (environmentMap.isNotEmpty()) { |
| 232 | + pb.environment().putAll(environmentMap) |
| 233 | + val env = environmentMap.entries.map { (k, v) -> "$k=$v" }.joinToString(", ") |
| 234 | + Log.i(TAG, "Environment: $env") |
| 235 | + } |
| 236 | + return pb |
| 237 | + } |
| 238 | + |
| 239 | + private companion object { |
| 240 | + const val TAG = "SCLSEServer" // up to 23 characters |
| 241 | + |
| 242 | + const val EXIT_CODE_FAILED_TO_START = -1 |
| 243 | + const val EXIT_CODE_TIMED_OUT = -2 |
| 244 | + } |
| 245 | +} |
0 commit comments