Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ let package = Package(
.package(url: "https://github.com/swift-server/async-http-client.git", from: "1.20.1"),
.package(url: "https://github.com/orlandos-nl/DNSClient.git", from: "2.4.1"),
.package(url: "https://github.com/Bouke/DNS.git", from: "1.2.0"),
.package(url: "https://github.com/apple/containerization.git", exact: Version(stringLiteral: scVersion)),
.package(url: "https://github.com/apple/containerization.git", branch: "main"),
],
targets: [
.executableTarget(
Expand Down
7 changes: 4 additions & 3 deletions Sources/ContainerClient/Core/ClientImage.swift
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ extension ClientImage {
})
}

public static func pull(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil) async throws -> ClientImage {
public static func pull(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil, maxConcurrentDownloads: Int = 3) async throws -> ClientImage {
let client = newXPCClient()
let request = newRequest(.imagePull)

Expand All @@ -234,6 +234,7 @@ extension ClientImage {

let insecure = try scheme.schemeFor(host: host) == .http
request.set(key: .insecureFlag, value: insecure)
request.set(key: .maxConcurrentDownloads, value: Int64(maxConcurrentDownloads))

var progressUpdateClient: ProgressUpdateClient?
if let progressUpdate {
Expand Down Expand Up @@ -293,7 +294,7 @@ extension ClientImage {
return (digests, size)
}

public static func fetch(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil) async throws -> ClientImage
public static func fetch(reference: String, platform: Platform? = nil, scheme: RequestScheme = .auto, progressUpdate: ProgressUpdateHandler? = nil, maxConcurrentDownloads: Int = 3) async throws -> ClientImage
{
do {
let match = try await self.get(reference: reference)
Expand All @@ -307,7 +308,7 @@ extension ClientImage {
guard err.isCode(.notFound) else {
throw err
}
return try await Self.pull(reference: reference, platform: platform, scheme: scheme, progressUpdate: progressUpdate)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm somewhat confused by the commit description however. I thought we had found we do pull layers in parallel, but that it's just a hardcoded 8.

Copy link
Author

@sbhavani sbhavani Oct 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry I originally thought it wasn't pulling in parallel but I noticed it was hardcoded later

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated original issue to more accurately reflect this: #715

return try await Self.pull(reference: reference, platform: platform, scheme: scheme, progressUpdate: progressUpdate, maxConcurrentDownloads: maxConcurrentDownloads)
}
}
}
Expand Down
8 changes: 8 additions & 0 deletions Sources/ContainerClient/Flags.swift
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,15 @@ public struct Flags {
self.disableProgressUpdates = disableProgressUpdates
}

public init(disableProgressUpdates: Bool, maxConcurrentDownloads: Int) {
self.disableProgressUpdates = disableProgressUpdates
self.maxConcurrentDownloads = maxConcurrentDownloads
}

@Flag(name: .long, help: "Disable progress bar updates")
public var disableProgressUpdates = false

@Option(name: .long, help: "Maximum number of concurrent layer downloads (default: 3)")
public var maxConcurrentDownloads: Int = 3
}
}
2 changes: 1 addition & 1 deletion Sources/ContainerCommands/Image/ImagePull.swift
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ extension Application {
let taskManager = ProgressTaskCoordinator()
let fetchTask = await taskManager.startTask()
let image = try await ClientImage.pull(
reference: processedReference, platform: p, scheme: scheme, progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progress.handler)
reference: processedReference, platform: p, scheme: scheme, progressUpdate: ProgressTaskCoordinator.handler(for: fetchTask, from: progress.handler), maxConcurrentDownloads: self.progressFlags.maxConcurrentDownloads
)

progress.set(description: "Unpacking image")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ public enum ImagesServiceXPCKeys: String {
case ociPlatform
case insecureFlag
case garbageCollect
case maxConcurrentDownloads

/// ContentStore
case digest
Expand All @@ -54,6 +55,10 @@ extension XPCMessage {
self.set(key: key.rawValue, value: value)
}

public func set(key: ImagesServiceXPCKeys, value: Int64) {
self.set(key: key.rawValue, value: value)
}

public func set(key: ImagesServiceXPCKeys, value: Data) {
self.set(key: key.rawValue, value: value)
}
Expand All @@ -78,6 +83,10 @@ extension XPCMessage {
self.uint64(key: key.rawValue)
}

public func int64(key: ImagesServiceXPCKeys) -> Int64 {
self.int64(key: key.rawValue)
}

public func bool(key: ImagesServiceXPCKeys) -> Bool {
self.bool(key: key.rawValue)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ public actor ImagesService {
return try await imageStore.list().map { $0.description.fromCZ }
}

public func pull(reference: String, platform: Platform?, insecure: Bool, progressUpdate: ProgressUpdateHandler?) async throws -> ImageDescription {
self.log.info("ImagesService: \(#function) - ref: \(reference), platform: \(String(describing: platform)), insecure: \(insecure)")
public func pull(reference: String, platform: Platform?, insecure: Bool, progressUpdate: ProgressUpdateHandler?, maxConcurrentDownloads: Int = 3) async throws -> ImageDescription {
self.log.info("ImagesService: \(#function) - ref: \(reference), platform: \(String(describing: platform)), insecure: \(insecure), maxConcurrentDownloads: \(maxConcurrentDownloads)")
let img = try await Self.withAuthentication(ref: reference) { auth in
try await self.imageStore.pull(
reference: reference, platform: platform, insecure: insecure, auth: auth, progress: ContainerizationProgressAdapter.handler(from: progressUpdate))
reference: reference, platform: platform, insecure: insecure, auth: auth, progress: ContainerizationProgressAdapter.handler(from: progressUpdate), maxConcurrentDownloads: maxConcurrentDownloads)
}
guard let img else {
throw ContainerizationError(.internalError, message: "Failed to pull image \(reference)")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ public struct ImagesServiceHarness: Sendable {
platform = try JSONDecoder().decode(ContainerizationOCI.Platform.self, from: platformData)
}
let insecure = message.bool(key: .insecureFlag)
let maxConcurrentDownloads = message.int64(key: .maxConcurrentDownloads)

let progressUpdateService = ProgressUpdateService(message: message)
let imageDescription = try await service.pull(reference: ref, platform: platform, insecure: insecure, progressUpdate: progressUpdateService?.handler)
let imageDescription = try await service.pull(reference: ref, platform: platform, insecure: insecure, progressUpdate: progressUpdateService?.handler, maxConcurrentDownloads: Int(maxConcurrentDownloads))

let imageData = try JSONEncoder().encode(imageDescription)
let reply = message.reply()
Expand Down
49 changes: 47 additions & 2 deletions Sources/TerminalProgress/ProgressTaskCoordinator.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
import Foundation

/// A type that represents a task whose progress is being monitored.
public struct ProgressTask: Sendable, Equatable {
public struct ProgressTask: Sendable, Equatable, Hashable {
private var id = UUID()
private var coordinator: ProgressTaskCoordinator
internal var coordinator: ProgressTaskCoordinator

init(manager: ProgressTaskCoordinator) {
self.coordinator = manager
Expand All @@ -29,6 +29,10 @@ public struct ProgressTask: Sendable, Equatable {
lhs.id == rhs.id
}

public func hash(into hasher: inout Hasher) {
hasher.combine(id)
}

/// Returns `true` if this task is the currently active task, `false` otherwise.
public func isCurrent() async -> Bool {
guard let currentTask = await coordinator.currentTask else {
Expand All @@ -41,6 +45,7 @@ public struct ProgressTask: Sendable, Equatable {
/// A type that coordinates progress tasks to ignore updates from completed tasks.
public actor ProgressTaskCoordinator {
var currentTask: ProgressTask?
var activeTasks: Set<ProgressTask> = []

/// Creates an instance of `ProgressTaskCoordinator`.
public init() {}
Expand All @@ -52,9 +57,36 @@ public actor ProgressTaskCoordinator {
return newTask
}

/// Starts multiple concurrent tasks and returns them.
/// - Parameter count: The number of concurrent tasks to start.
/// - Returns: An array of ProgressTask instances.
public func startConcurrentTasks(count: Int) -> [ProgressTask] {
var tasks: [ProgressTask] = []
for _ in 0..<count {
let task = ProgressTask(manager: self)
tasks.append(task)
activeTasks.insert(task)
}
return tasks
}

/// Marks a specific task as completed and removes it from active tasks.
/// - Parameter task: The task to mark as completed.
public func completeTask(_ task: ProgressTask) {
activeTasks.remove(task)
}

/// Checks if a task is currently active.
/// - Parameter task: The task to check.
/// - Returns: `true` if the task is active, `false` otherwise.
public func isTaskActive(_ task: ProgressTask) -> Bool {
activeTasks.contains(task)
}

/// Performs cleanup when the monitored tasks complete.
public func finish() {
currentTask = nil
activeTasks.removeAll()
}

/// Returns a handler that updates the progress of a given task.
Expand All @@ -69,4 +101,17 @@ public actor ProgressTaskCoordinator {
}
}
}

/// Returns a handler that updates the progress for concurrent tasks.
/// - Parameters:
/// - task: The task whose progress is being updated.
/// - progressUpdate: The handler to invoke when progress updates are received.
public static func concurrentHandler(for task: ProgressTask, from progressUpdate: @escaping ProgressUpdateHandler) -> ProgressUpdateHandler {
{ events in
// Only process updates if the task is still active
if await task.coordinator.isTaskActive(task) {
await progressUpdate(events)
}
}
}
}
109 changes: 109 additions & 0 deletions test_concurrency.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
#!/usr/bin/env swift
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this and test_parameter_flow.swift in the top level directory?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should they belong in Tests/CLITests/Subcommands/Images/?


import Foundation

func testConcurrentDownloads() async throws {
print("Testing concurrent download behavior...\n")

// Track concurrent task count
actor ConcurrencyTracker {
var currentCount = 0
var maxObservedCount = 0
var completedTasks = 0

func taskStarted() {
currentCount += 1
maxObservedCount = max(maxObservedCount, currentCount)
}

func taskCompleted() {
currentCount -= 1
completedTasks += 1
}

func getStats() -> (max: Int, completed: Int) {
return (maxObservedCount, completedTasks)
}

func reset() {
currentCount = 0
maxObservedCount = 0
completedTasks = 0
}
}

let tracker = ConcurrencyTracker()

// Test with different concurrency limits
for maxConcurrent in [1, 3, 6] {
await tracker.reset()

// Simulate downloading 20 layers
let layerCount = 20
let layers = Array(0..<layerCount)

print("Testing maxConcurrent=\(maxConcurrent) with \(layerCount) layers...")

let startTime = Date()

try await withThrowingTaskGroup(of: Void.self) { group in
var iterator = layers.makeIterator()

// Start initial batch based on maxConcurrent
for _ in 0..<maxConcurrent {
if iterator.next() != nil {
group.addTask {
await tracker.taskStarted()
try await Task.sleep(nanoseconds: 10_000_000)
await tracker.taskCompleted()
}
}
}
for try await _ in group {
if iterator.next() != nil {
group.addTask {
await tracker.taskStarted()
try await Task.sleep(nanoseconds: 10_000_000)
await tracker.taskCompleted()
}
}
}
}

let duration = Date().timeIntervalSince(startTime)
let stats = await tracker.getStats()

print(" ✓ Completed: \(stats.completed)/\(layerCount)")
print(" ✓ Max concurrent: \(stats.max)")
print(" ✓ Duration: \(String(format: "%.3f", duration))s")

guard stats.max <= maxConcurrent + 1 else {
throw TestError.concurrencyLimitExceeded
}

guard stats.completed == layerCount else {
throw TestError.incompleteTasks
}

print(" ✅ PASSED\n")
}

print("All tests passed!")
}

enum TestError: Error {
case concurrencyLimitExceeded
case incompleteTasks
}

Task {
do {
try await testConcurrentDownloads()
exit(0)
} catch {
print("Test failed: \(error)")
exit(1)
}
}

RunLoop.main.run()
Loading