diff --git a/Tests.Vpn.Service/DownloaderTest.cs b/Tests.Vpn.Service/DownloaderTest.cs index 985e331..8b55f50 100644 --- a/Tests.Vpn.Service/DownloaderTest.cs +++ b/Tests.Vpn.Service/DownloaderTest.cs @@ -284,6 +284,34 @@ public async Task Download(CancellationToken ct) Assert.That(await File.ReadAllTextAsync(destPath, ct), Is.EqualTo("test")); } + [Test(Description = "Perform 2 downloads with the same destination")] + [CancelAfter(30_000)] + public async Task DownloadSameDest(CancellationToken ct) + { + using var httpServer = EchoServer(); + var url0 = new Uri(httpServer.BaseUrl + "/test0"); + var url1 = new Uri(httpServer.BaseUrl + "/test1"); + var destPath = Path.Combine(_tempDir, "test"); + + var manager = new Downloader(NullLogger.Instance); + var startTask0 = manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url0), destPath, + NullDownloadValidator.Instance, ct); + var startTask1 = manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url1), destPath, + NullDownloadValidator.Instance, ct); + var dlTask0 = await startTask0; + await dlTask0.Task; + Assert.That(dlTask0.TotalBytes, Is.EqualTo(5)); + Assert.That(dlTask0.BytesRead, Is.EqualTo(5)); + Assert.That(dlTask0.Progress, Is.EqualTo(1)); + Assert.That(dlTask0.IsCompleted, Is.True); + var dlTask1 = await startTask1; + await dlTask1.Task; + Assert.That(dlTask1.TotalBytes, Is.EqualTo(5)); + Assert.That(dlTask1.BytesRead, Is.EqualTo(5)); + Assert.That(dlTask1.Progress, Is.EqualTo(1)); + Assert.That(dlTask1.IsCompleted, Is.True); + } + [Test(Description = "Download with custom headers")] [CancelAfter(30_000)] public async Task WithHeaders(CancellationToken ct) @@ -347,17 +375,17 @@ public async Task DownloadExistingDifferentContent(CancellationToken ct) [Test(Description = "Unexpected response code from server")] [CancelAfter(30_000)] - public void UnexpectedResponseCode(CancellationToken ct) + public async Task UnexpectedResponseCode(CancellationToken ct) { using var httpServer = new TestHttpServer(ctx => { ctx.Response.StatusCode = 404; }); var url = new Uri(httpServer.BaseUrl + "/test"); var destPath = Path.Combine(_tempDir, "test"); var manager = new Downloader(NullLogger.Instance); - // The "outer" Task should fail. - var ex = Assert.ThrowsAsync(async () => - await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, - NullDownloadValidator.Instance, ct)); + // The "inner" Task should fail. + var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + NullDownloadValidator.Instance, ct); + var ex = Assert.ThrowsAsync(async () => await dlTask.Task); Assert.That(ex.Message, Does.Contain("404")); } @@ -384,22 +412,6 @@ public async Task MismatchedETag(CancellationToken ct) Assert.That(ex.Message, Does.Contain("ETag does not match SHA1 hash of downloaded file").And.Contains("beef")); } - [Test(Description = "Timeout on response headers")] - [CancelAfter(30_000)] - public void CancelledOuter(CancellationToken ct) - { - using var httpServer = new TestHttpServer(async _ => { await Task.Delay(TimeSpan.FromSeconds(5), ct); }); - var url = new Uri(httpServer.BaseUrl + "/test"); - var destPath = Path.Combine(_tempDir, "test"); - - var manager = new Downloader(NullLogger.Instance); - // The "outer" Task should fail. - var smallerCt = new CancellationTokenSource(TimeSpan.FromSeconds(1)).Token; - Assert.ThrowsAsync( - async () => await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, - NullDownloadValidator.Instance, smallerCt)); - } - [Test(Description = "Timeout on response body")] [CancelAfter(30_000)] public async Task CancelledInner(CancellationToken ct) @@ -451,12 +463,10 @@ public async Task ValidationFailureExistingFile(CancellationToken ct) await File.WriteAllTextAsync(destPath, "test", ct); var manager = new Downloader(NullLogger.Instance); - // The "outer" Task should fail because the inner task never starts. - var ex = Assert.ThrowsAsync(async () => - { - await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, - new TestDownloadValidator(new Exception("test exception")), ct); - }); + var dlTask = await manager.StartDownloadAsync(new HttpRequestMessage(HttpMethod.Get, url), destPath, + new TestDownloadValidator(new Exception("test exception")), ct); + // The "inner" Task should fail. + var ex = Assert.ThrowsAsync(async () => { await dlTask.Task; }); Assert.That(ex.Message, Does.Contain("Existing file failed validation")); Assert.That(ex.InnerException, Is.Not.Null); Assert.That(ex.InnerException!.Message, Is.EqualTo("test exception")); diff --git a/Vpn.Service/Downloader.cs b/Vpn.Service/Downloader.cs index a37a1ec..6a665ae 100644 --- a/Vpn.Service/Downloader.cs +++ b/Vpn.Service/Downloader.cs @@ -3,6 +3,7 @@ using System.Formats.Asn1; using System.Net; using System.Runtime.CompilerServices; +using System.Runtime.ExceptionServices; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; using Coder.Desktop.Vpn.Utilities; @@ -288,7 +289,26 @@ public async Task StartDownloadAsync(HttpRequestMessage req, strin { var task = _downloads.GetOrAdd(destinationPath, _ => new DownloadTask(_logger, req, destinationPath, validator)); - await task.EnsureStartedAsync(ct); + // EnsureStarted is a no-op if we didn't create a new DownloadTask. + // So, we will only remove the destination once for each time we start a new task. + task.EnsureStarted(tsk => + { + // remove the key first, before checking the exception, to ensure + // we still clean up. + _downloads.TryRemove(destinationPath, out _); + if (tsk.Exception == null) + { + return; + } + + if (tsk.Exception.InnerException != null) + { + ExceptionDispatchInfo.Capture(tsk.Exception.InnerException).Throw(); + } + + // not sure if this is hittable, but just in case: + throw tsk.Exception; + }, ct); // If the existing (or new) task is for the same URL, return it. if (task.Request.RequestUri == req.RequestUri) @@ -357,13 +377,11 @@ internal DownloadTask(ILogger logger, HttpRequestMessage req, string destination ".download-" + Path.GetRandomFileName()); } - internal async Task EnsureStartedAsync(CancellationToken ct = default) + internal void EnsureStarted(Action continuation, CancellationToken ct = default) { - using var _ = await _semaphore.LockAsync(ct); + using var _ = _semaphore.Lock(); if (Task == null!) - Task = await StartDownloadAsync(ct); - - return Task; + Task = Start(ct).ContinueWith(continuation, ct); } /// @@ -371,7 +389,7 @@ internal async Task EnsureStartedAsync(CancellationToken ct = default) /// and the download will continue in the background. The provided CancellationToken can be used to cancel the /// download. /// - private async Task StartDownloadAsync(CancellationToken ct = default) + private async Task Start(CancellationToken ct = default) { Directory.CreateDirectory(_destinationDirectory); @@ -398,8 +416,7 @@ private async Task StartDownloadAsync(CancellationToken ct = default) throw new Exception("Existing file failed validation after 304 Not Modified", e); } - Task = Task.CompletedTask; - return Task; + return; } if (res.StatusCode != HttpStatusCode.OK) @@ -432,11 +449,11 @@ private async Task StartDownloadAsync(CancellationToken ct = default) throw; } - Task = DownloadAsync(res, tempFile, ct); - return Task; + await Download(res, tempFile, ct); + return; } - private async Task DownloadAsync(HttpResponseMessage res, FileStream tempFile, CancellationToken ct) + private async Task Download(HttpResponseMessage res, FileStream tempFile, CancellationToken ct) { try {