Skip to content

[WIP] DO NOT MERGE : Implement Lambda streaming with custom HTTP headers #521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions Examples/Streaming/Package.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// swift-tools-version:6.0
// swift-tools-version:6.1

import PackageDescription

// needed for CI to test the local version of the library
import struct Foundation.URL

Expand Down
11 changes: 11 additions & 0 deletions Examples/Streaming/Sources/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,24 @@

import AWSLambdaRuntime
import NIOCore
import NIOHTTP1

struct SendNumbersWithPause: StreamingLambdaHandler {
func handle(
_ event: ByteBuffer,
responseWriter: some LambdaResponseStreamWriter,
context: LambdaContext
) async throws {
context.logger.info("Received event: \(event)")
try await responseWriter.writeHeaders(
HTTPHeaders([
("X-Example-Header", "This is an example header")
])
)

try await responseWriter.write(
ByteBuffer(string: "Starting to send numbers with a pause...\n")
)
for i in 1...10 {
// Send partial data
try await responseWriter.write(ByteBuffer(string: "\(i)\n"))
Expand Down
3 changes: 3 additions & 0 deletions Examples/Streaming/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ Resources:
MemorySize: 128
Architectures:
- arm64
Environment:
Variables:
LOG_LEVEL: trace
FunctionUrlConfig:
AuthType: AWS_IAM
InvokeMode: RESPONSE_STREAM
Expand Down
1 change: 1 addition & 0 deletions Sources/AWSLambdaRuntime/Lambda+LocalServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ internal struct LambdaHTTPServer {
case .end:
precondition(requestHead != nil, "Received .end without .head")
// process the request
// FIXME: this do not support response streaming
let response = try await self.processRequest(
head: requestHead,
body: requestBody,
Expand Down
5 changes: 5 additions & 0 deletions Sources/AWSLambdaRuntime/LambdaHandlers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
//===----------------------------------------------------------------------===//

import NIOCore
import NIOHTTP1

/// The base handler protocol that receives a `ByteBuffer` representing the incoming event and returns the response as a `ByteBuffer` too.
/// This handler protocol supports response streaming. Bytes can be streamed outwards through the ``LambdaResponseStreamWriter``
Expand Down Expand Up @@ -46,6 +47,10 @@ public protocol StreamingLambdaHandler: _Lambda_SendableMetatype {
/// A writer object to write the Lambda response stream into. The HTTP response is started lazily.
/// before the first call to ``write(_:)`` or ``writeAndFinish(_:)``.
public protocol LambdaResponseStreamWriter {
/// Write the headers parts of the stream. This allows client to set headers before the first response part is written.
/// - Parameter buffer: The buffer to write.
func writeHeaders(_ headers: HTTPHeaders) async throws

/// Write a response part into the stream. Bytes written are streamed continually.
/// - Parameter buffer: The buffer to write.
func write(_ buffer: ByteBuffer) async throws
Expand Down
48 changes: 35 additions & 13 deletions Sources/AWSLambdaRuntime/LambdaRuntimeClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
self.runtimeClient = runtimeClient
}

@usableFromInline
func writeHeaders(_ headers: HTTPHeaders) async throws {
try await self.runtimeClient.appendHeaders(headers)
}

@usableFromInline
func write(_ buffer: NIOCore.ByteBuffer) async throws {
try await self.runtimeClient.write(buffer)
Expand Down Expand Up @@ -188,6 +193,13 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
}
}

// we can use a var here because we are always isolated to this actor
private var userHeaders = HTTPHeaders()
private func appendHeaders(_ headers: HTTPHeaders) async throws {
// buffer the data to send them when we will send the headers
userHeaders.add(contentsOf: headers)
}

private func write(_ buffer: NIOCore.ByteBuffer) async throws {
switch self.lambdaState {
case .idle, .sentResponse:
Expand All @@ -205,7 +217,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
guard case .sendingResponse(requestID) = self.lambdaState else {
fatalError("Invalid state: \(self.lambdaState)")
}
return try await handler.writeResponseBodyPart(buffer, requestID: requestID)
return try await handler.writeResponseBodyPart(self.userHeaders, buffer, requestID: requestID)
}
}

Expand All @@ -226,7 +238,7 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol {
guard case .sentResponse(requestID) = self.lambdaState else {
fatalError("Invalid state: \(self.lambdaState)")
}
try await handler.finishResponseRequest(finalData: buffer, requestID: requestID)
try await handler.finishResponseRequest(userHeaders: self.userHeaders, finalData: buffer, requestID: requestID)
guard case .sentResponse(requestID) = self.lambdaState else {
fatalError("Invalid state: \(self.lambdaState)")
}
Expand Down Expand Up @@ -484,9 +496,11 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>
"lambda-runtime-function-error-type": "Unhandled",
]
self.streamingHeaders = [
"host": "\(self.configuration.ip):\(self.configuration.port)",
"Host": "\(self.configuration.ip):\(self.configuration.port)",
"user-agent": .userAgent,
"transfer-encoding": "chunked",
// "Content-type": "application/vnd.awslambda.http-integration-response",
// "Transfer-encoding": "chunked",
// "Lambda-Runtime-Function-Response-Mode": "streaming",
]
}

Expand Down Expand Up @@ -555,6 +569,7 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>

func writeResponseBodyPart(
isolation: isolated (any Actor)? = #isolation,
_ userHeaders: HTTPHeaders,
_ byteBuffer: ByteBuffer,
requestID: String
) async throws {
Expand All @@ -564,10 +579,10 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>

case .connected(let context, .waitingForResponse):
self.state = .connected(context, .sendingResponse)
try await self.sendResponseBodyPart(byteBuffer, sendHeadWithRequestID: requestID, context: context)
try await self.sendResponseBodyPart(userHeaders, byteBuffer, sendHeadWithRequestID: requestID, context: context)

case .connected(let context, .sendingResponse):
try await self.sendResponseBodyPart(byteBuffer, sendHeadWithRequestID: nil, context: context)
try await self.sendResponseBodyPart(userHeaders, byteBuffer, sendHeadWithRequestID: nil, context: context)

case .connected(_, .idle),
.connected(_, .sentResponse):
Expand All @@ -583,6 +598,7 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>

func finishResponseRequest(
isolation: isolated (any Actor)? = #isolation,
userHeaders: HTTPHeaders,
finalData: ByteBuffer?,
requestID: String
) async throws {
Expand All @@ -594,13 +610,13 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>
case .connected(let context, .waitingForResponse):
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
self.state = .connected(context, .sentResponse(continuation))
self.sendResponseFinish(finalData, sendHeadWithRequestID: requestID, context: context)
self.sendResponseFinish(userHeaders, finalData, sendHeadWithRequestID: requestID, context: context)
}

case .connected(let context, .sendingResponse):
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<Void, any Error>) in
self.state = .connected(context, .sentResponse(continuation))
self.sendResponseFinish(finalData, sendHeadWithRequestID: nil, context: context)
self.sendResponseFinish(userHeaders, finalData, sendHeadWithRequestID: nil, context: context)
}

case .connected(_, .sentResponse):
Expand All @@ -616,6 +632,7 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>

private func sendResponseBodyPart(
isolation: isolated (any Actor)? = #isolation,
_ userHeaders: HTTPHeaders,
_ byteBuffer: ByteBuffer,
sendHeadWithRequestID: String?,
context: ChannelHandlerContext
Expand All @@ -625,13 +642,17 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>
// TODO: This feels super expensive. We should be able to make this cheaper. requestIDs are fixed length
let url = Consts.invocationURLPrefix + "/" + requestID + Consts.postResponseURLSuffix

var headers = HTTPHeaders()
headers.add(contentsOf: userHeaders)
headers.add(contentsOf: self.streamingHeaders)
logger.trace("sendResponseBodyPart : ========== Sending response headers: \(headers)")
let httpRequest = HTTPRequestHead(
version: .http1_1,
method: .POST,
uri: url,
headers: self.streamingHeaders
headers: headers // FIXME these are the headers returned to the control plane. I'm not sure if we should use the streaming headers here
)

context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil)
}

Expand All @@ -642,6 +663,7 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>

private func sendResponseFinish(
isolation: isolated (any Actor)? = #isolation,
_ userHeaders: HTTPHeaders,
_ byteBuffer: ByteBuffer?,
sendHeadWithRequestID: String?,
context: ChannelHandlerContext
Expand All @@ -652,7 +674,7 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>

// If we have less than 6MB, we don't want to use the streaming API. If we have more
// than 6MB we must use the streaming mode.
let headers: HTTPHeaders =
var headers: HTTPHeaders =
if byteBuffer?.readableBytes ?? 0 < 6_000_000 {
[
"host": "\(self.configuration.ip):\(self.configuration.port)",
Expand All @@ -662,14 +684,14 @@ private final class LambdaChannelHandler<Delegate: LambdaChannelHandlerDelegate>
} else {
self.streamingHeaders
}

headers.add(contentsOf: userHeaders)
logger.trace("sendResponseFinish : ========== Sending response headers: \(headers)")
let httpRequest = HTTPRequestHead(
version: .http1_1,
method: .POST,
uri: url,
headers: headers
)

context.write(self.wrapOutboundOut(.head(httpRequest)), promise: nil)
}

Expand Down
Loading