1515
1616package software .amazon .awssdk .services .s3 .internal .multipart ;
1717
18+ import java .util .Collections ;
19+ import java .util .HashSet ;
1820import java .util .Map ;
21+ import java .util .Optional ;
1922import java .util .Queue ;
23+ import java .util .Set ;
2024import java .util .concurrent .CompletableFuture ;
2125import java .util .concurrent .ConcurrentHashMap ;
2226import java .util .concurrent .ConcurrentLinkedQueue ;
3034import software .amazon .awssdk .services .s3 .model .GetObjectRequest ;
3135import software .amazon .awssdk .services .s3 .model .GetObjectResponse ;
3236import software .amazon .awssdk .utils .CompletableFutureUtils ;
37+ import software .amazon .awssdk .utils .ContentRangeParser ;
3338import software .amazon .awssdk .utils .Logger ;
3439import software .amazon .awssdk .utils .Pair ;
3540
@@ -66,7 +71,7 @@ public class ParallelMultipartDownloaderSubscriber
6671 * The total number of completed parts. A part is considered complete once the completable future associated with its request
6772 * completes successfully.
6873 */
69- private final AtomicInteger completedParts = new AtomicInteger () ;
74+ private final AtomicInteger completedParts ;
7075
7176 /**
7277 * The future returned to the user when calling
@@ -80,7 +85,7 @@ public class ParallelMultipartDownloaderSubscriber
8085 * The {@link GetObjectResponse} to be returned in the completed future to the user. It corresponds to the response of first
8186 * part GetObject
8287 */
83- private GetObjectResponse getObjectResponse ;
88+ private volatile GetObjectResponse getObjectResponse ;
8489
8590 /**
8691 * The subscription received from the publisher this subscriber subscribes to.
@@ -135,12 +140,17 @@ public class ParallelMultipartDownloaderSubscriber
135140 private final AtomicInteger partNumber = new AtomicInteger (0 );
136141
137142 /**
138- * Tracks if one of the parts requests future completed exceptionally. If this occurs, it means all retries were
139- * attempted for that part, but it still failed. This is a failure state, the error should be reported back to the user
140- * and any more request should be ignored.
143+ * Tracks if one of the parts requests future completed exceptionally. If this occurs, it means all retries were attempted for
144+ * that part, but it still failed. This is a failure state, the error should be reported back to the user and any more request
145+ * should be ignored.
141146 */
142147 private final AtomicBoolean isCompletedExceptionally = new AtomicBoolean (false );
143148
149+ /**
150+ * When resuming a paused download, indicates which parts were already completed before pausing.
151+ */
152+ private final Set <Integer > initialCompletedParts ;
153+
144154 public ParallelMultipartDownloaderSubscriber (S3AsyncClient s3 ,
145155 GetObjectRequest getObjectRequest ,
146156 CompletableFuture <GetObjectResponse > resultFuture ,
@@ -149,6 +159,36 @@ public ParallelMultipartDownloaderSubscriber(S3AsyncClient s3,
149159 this .getObjectRequest = getObjectRequest ;
150160 this .resultFuture = resultFuture ;
151161 this .maxInFlightParts = maxInFlightParts ;
162+ this .initialCompletedParts = initialCompletedParts (getObjectRequest );
163+ this .completedParts = new AtomicInteger (initialCompletedParts .size ());
164+
165+ if (resumingDownload ()) {
166+ int totalPartsFromInitialRequest = MultipartDownloadUtils .multipartDownloadResumeContext (getObjectRequest )
167+ .map (MultipartDownloadResumeContext ::totalParts )
168+ .orElse (0 );
169+ if (totalPartsFromInitialRequest > 0 ) {
170+ totalPartsFuture .complete (totalPartsFromInitialRequest );
171+ }
172+ getObjectResponse = MultipartDownloadUtils .multipartDownloadResumeContext (getObjectRequest )
173+ .map (MultipartDownloadResumeContext ::response )
174+ .orElse (null );
175+ }
176+ }
177+
178+ private static Set <Integer > initialCompletedParts (GetObjectRequest getObjectRequest ) {
179+ return Collections .unmodifiableSet (
180+ MultipartDownloadUtils .multipartDownloadResumeContext (getObjectRequest )
181+ .map (MultipartDownloadResumeContext ::completedParts )
182+ .<Set <Integer >>map (HashSet ::new )
183+ .orElse (Collections .emptySet ())
184+ );
185+ }
186+
187+ private boolean resumingDownload () {
188+ Optional <Boolean > hasAlreadyCompletedParts =
189+ MultipartDownloadUtils .multipartDownloadResumeContext (getObjectRequest )
190+ .map (ctx -> !ctx .completedParts ().isEmpty ());
191+ return hasAlreadyCompletedParts .orElse (false );
152192 }
153193
154194 @ Override
@@ -176,7 +216,7 @@ public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse
176216 + " - Total pending transformers: " + pendingTransformers .size ()
177217 + " - Current in flight requests: " + inFlightRequests .keySet ());
178218
179- int currentPartNum = partNumber . incrementAndGet ();
219+ int currentPartNum = nextPart ();
180220
181221 if (currentPartNum == 1 ) {
182222 sendFirstRequest (asyncResponseTransformer );
@@ -188,7 +228,7 @@ public void onNext(AsyncResponseTransformer<GetObjectResponse, GetObjectResponse
188228 }
189229
190230 private void processingRequests (AsyncResponseTransformer <GetObjectResponse , GetObjectResponse > asyncResponseTransformer ,
191- int currentPartNum , Integer totalParts ) {
231+ int currentPartNum , int totalParts ) {
192232
193233 if (currentPartNum > totalParts ) {
194234 // Do not process requests above total parts.
@@ -203,6 +243,7 @@ private void processingRequests(AsyncResponseTransformer<GetObjectResponse, GetO
203243 return ;
204244 }
205245
246+ sendNextRequest (asyncResponseTransformer , currentPartNum , totalParts );
206247 processPendingTransformers (totalParts );
207248 }
208249
@@ -233,11 +274,14 @@ private void sendNextRequest(AsyncResponseTransformer<GetObjectResponse, GetObje
233274 inFlightRequests .remove (currentPartNumber );
234275 inFlightRequestsNum .decrementAndGet ();
235276 completedParts .incrementAndGet ();
277+ MultipartDownloadUtils .multipartDownloadResumeContext (getObjectRequest )
278+ .ifPresent (ctx -> ctx .addCompletedPart (currentPartNumber ));
236279
237280 if (completedParts .get () >= totalParts ) {
238281 if (completedParts .get () > totalParts ) {
239282 resultFuture .completeExceptionally (new IllegalStateException ("Total parts exceeded" ));
240283 } else {
284+ updateResumeContextForCompletion (res );
241285 resultFuture .complete (getObjectResponse );
242286 }
243287
@@ -254,6 +298,14 @@ private void sendNextRequest(AsyncResponseTransformer<GetObjectResponse, GetObje
254298 });
255299 }
256300
301+ private void updateResumeContextForCompletion (GetObjectResponse response ) {
302+ ContentRangeParser .totalBytes (response .contentRange ())
303+ .ifPresent (total -> MultipartDownloadUtils
304+ .multipartDownloadResumeContext (getObjectRequest )
305+ .ifPresent (ctx ->
306+ ctx .addToBytesToLastCompletedParts (total )));
307+ }
308+
257309 private void sendFirstRequest (AsyncResponseTransformer <GetObjectResponse , GetObjectResponse > asyncResponseTransformer ) {
258310 log .debug (() -> "Sending first request" );
259311 GetObjectRequest request = nextRequest (1 );
@@ -282,6 +334,13 @@ private void sendFirstRequest(AsyncResponseTransformer<GetObjectResponse, GetObj
282334 getObjectResponse = res ;
283335
284336 processPendingTransformers (res .partsCount ());
337+ MultipartDownloadUtils .multipartDownloadResumeContext (getObjectRequest )
338+ .ifPresent (ctx -> {
339+ ctx .addCompletedPart (1 );
340+ ctx .response (res );
341+ ctx .totalParts (res .partsCount ());
342+ });
343+
285344 synchronized (subscriptionLock ) {
286345 subscription .request (1 );
287346 }
@@ -312,7 +371,7 @@ private void setInitialPartCountAndEtag(GetObjectResponse response) {
312371
313372 private void handlePartError (Throwable e , int part ) {
314373 isCompletedExceptionally .set (true );
315- log .debug (() -> "Error on part " + part , e );
374+ log .debug (() -> "Error on part " + part , e );
316375 resultFuture .completeExceptionally (e );
317376 inFlightRequests .values ().forEach (future -> future .cancel (true ));
318377 }
@@ -334,9 +393,12 @@ private void processPendingTransformers(int totalParts) {
334393
335394 private void doProcessPendingTransformers (int totalParts ) {
336395 while (shouldProcessPendingTransformers ()) {
337- Pair <Integer , AsyncResponseTransformer <GetObjectResponse , GetObjectResponse >> transformer =
338- pendingTransformers .poll ();
339- sendNextRequest (transformer .right (), transformer .left (), totalParts );
396+ Pair <Integer , AsyncResponseTransformer <GetObjectResponse , GetObjectResponse >> pair = pendingTransformers .poll ();
397+ Integer part = pair .left ();
398+ AsyncResponseTransformer <GetObjectResponse , GetObjectResponse > transformer = pair .right ();
399+ if (part <= totalParts ) {
400+ sendNextRequest (transformer , part , totalParts );
401+ }
340402 }
341403 }
342404
@@ -372,4 +434,18 @@ private GetObjectRequest nextRequest(int nextPartToGet) {
372434 });
373435 }
374436
437+ private int nextPart () {
438+ if (initialCompletedParts .isEmpty ()) {
439+ return partNumber .incrementAndGet ();
440+ }
441+
442+ synchronized (initialCompletedParts ) {
443+ int part = partNumber .incrementAndGet ();
444+ while (initialCompletedParts .contains (part )) {
445+ part = partNumber .incrementAndGet ();
446+ }
447+ return part ;
448+ }
449+ }
450+
375451}
0 commit comments