diff --git a/src/Ocelot/Middleware/OcelotPipelineConfiguration.cs b/src/Ocelot/Middleware/OcelotPipelineConfiguration.cs index 403b6e844..273a97786 100644 --- a/src/Ocelot/Middleware/OcelotPipelineConfiguration.cs +++ b/src/Ocelot/Middleware/OcelotPipelineConfiguration.cs @@ -31,6 +31,14 @@ public class OcelotPipelineConfiguration /// public Func, Task> AuthenticationMiddleware { get; set; } + /// + /// This is to allow the user to run any extra authentication after the Ocelot authentication kicks in. + /// + /// + /// A delegate object. + /// + public Func, Task> AfterAuthenticationMiddleware { get; set; } + /// /// This is to allow the user to run any extra authorization before the Ocelot authentication kicks in. /// @@ -47,6 +55,14 @@ public class OcelotPipelineConfiguration /// public Func, Task> AuthorizationMiddleware { get; set; } + /// + /// This is to allow the user to run any extra authorization after the Ocelot authorization kicks in. + /// + /// + /// A delegate object. + /// + public Func, Task> AfterAuthorizationMiddleware { get; set; } + /// /// This allows the user to implement there own query string manipulation logic. /// diff --git a/src/Ocelot/Middleware/OcelotPipelineExtensions.cs b/src/Ocelot/Middleware/OcelotPipelineExtensions.cs index 16f4a5cff..4e601e33d 100644 --- a/src/Ocelot/Middleware/OcelotPipelineExtensions.cs +++ b/src/Ocelot/Middleware/OcelotPipelineExtensions.cs @@ -100,6 +100,9 @@ public static RequestDelegate BuildOcelotPipeline(this IApplicationBuilder app, app.Use(pipelineConfiguration.AuthenticationMiddleware); } + // Allow After authentication logic. The idea being people might want to run something custom after what is built in. + app.UseIfNotNull(pipelineConfiguration.AfterAuthenticationMiddleware); + // The next thing we do is look at any claims transforms in case this is important for authorization app.UseClaimsToClaimsMiddleware(); @@ -119,6 +122,9 @@ public static RequestDelegate BuildOcelotPipeline(this IApplicationBuilder app, app.Use(pipelineConfiguration.AuthorizationMiddleware); } + // Allow after authorization logic. The idea being people might want to run something custom after what is built in. + app.UseIfNotNull(pipelineConfiguration.AfterAuthorizationMiddleware); + // Now we can run the claims to headers transformation middleware app.UseClaimsToHeadersMiddleware(); diff --git a/test/Ocelot.AcceptanceTests/CustomMiddlewareTests.cs b/test/Ocelot.AcceptanceTests/CustomMiddlewareTests.cs index 5fc91dd8d..65ea66589 100644 --- a/test/Ocelot.AcceptanceTests/CustomMiddlewareTests.cs +++ b/test/Ocelot.AcceptanceTests/CustomMiddlewareTests.cs @@ -2,7 +2,7 @@ using Ocelot.Configuration.File; using Ocelot.Middleware; using System.Diagnostics; - + namespace Ocelot.AcceptanceTests { public class CustomMiddlewareTests : IDisposable @@ -19,7 +19,7 @@ public CustomMiddlewareTests() } [Fact] - public void should_call_pre_query_string_builder_middleware() + public void Should_call_pre_query_string_builder_middleware() { var configuration = new OcelotPipelineConfiguration { @@ -28,8 +28,8 @@ public void should_call_pre_query_string_builder_middleware() _counter++; await next.Invoke(); }, - }; - + }; + var port = PortFinder.GetRandomPort(); var fileConfiguration = new FileConfiguration @@ -64,7 +64,7 @@ public void should_call_pre_query_string_builder_middleware() } [Fact] - public void should_call_authorization_middleware() + public void Should_call_authorization_middleware() { var configuration = new OcelotPipelineConfiguration { @@ -73,8 +73,8 @@ public void should_call_authorization_middleware() _counter++; await next.Invoke(); }, - }; - + }; + var port = PortFinder.GetRandomPort(); var fileConfiguration = new FileConfiguration @@ -109,7 +109,7 @@ public void should_call_authorization_middleware() } [Fact] - public void should_call_authentication_middleware() + public void Should_call_authentication_middleware() { var configuration = new OcelotPipelineConfiguration { @@ -118,8 +118,8 @@ public void should_call_authentication_middleware() _counter++; await next.Invoke(); }, - }; - + }; + var port = PortFinder.GetRandomPort(); var fileConfiguration = new FileConfiguration @@ -154,7 +154,7 @@ public void should_call_authentication_middleware() } [Fact] - public void should_call_pre_error_middleware() + public void Should_call_pre_error_middleware() { var configuration = new OcelotPipelineConfiguration { @@ -163,8 +163,8 @@ public void should_call_pre_error_middleware() _counter++; await next.Invoke(); }, - }; - + }; + var port = PortFinder.GetRandomPort(); var fileConfiguration = new FileConfiguration @@ -199,7 +199,7 @@ public void should_call_pre_error_middleware() } [Fact] - public void should_call_pre_authorization_middleware() + public void Should_call_pre_authorization_middleware() { var configuration = new OcelotPipelineConfiguration { @@ -208,8 +208,8 @@ public void should_call_pre_authorization_middleware() _counter++; await next.Invoke(); }, - }; - + }; + var port = PortFinder.GetRandomPort(); var fileConfiguration = new FileConfiguration @@ -244,7 +244,52 @@ public void should_call_pre_authorization_middleware() } [Fact] - public void should_call_pre_http_authentication_middleware() + public void Should_call_after_authorization_middleware() + { + var configuration = new OcelotPipelineConfiguration + { + AfterAuthorizationMiddleware = async (ctx, next) => + { + _counter++; + await next.Invoke(); + }, + }; + + var port = PortFinder.GetRandomPort(); + + var fileConfiguration = new FileConfiguration + { + Routes = new List + { + new() + { + DownstreamPathTemplate = "/", + DownstreamHostAndPorts = new List + { + new() + { + Host = "localhost", + Port = port, + }, + }, + DownstreamScheme = "http", + UpstreamPathTemplate = "/", + UpstreamHttpMethod = new List { "Get" }, + }, + }, + }; + + this.Given(x => x.GivenThereIsAServiceRunningOn($"http://localhost:{port}", 200, "")) + .And(x => _steps.GivenThereIsAConfiguration(fileConfiguration)) + .And(x => _steps.GivenOcelotIsRunning(configuration)) + .When(x => _steps.WhenIGetUrlOnTheApiGateway("/")) + .Then(x => _steps.ThenTheStatusCodeShouldBe(HttpStatusCode.OK)) + .And(x => x.ThenTheCounterIs(1)) + .BDDfy(); + } + + [Fact] + public void Should_call_pre_http_authentication_middleware() { var configuration = new OcelotPipelineConfiguration { @@ -253,8 +298,8 @@ public void should_call_pre_http_authentication_middleware() _counter++; await next.Invoke(); }, - }; - + }; + var port = PortFinder.GetRandomPort(); var fileConfiguration = new FileConfiguration @@ -332,24 +377,69 @@ public void should_not_throw_when_pipeline_terminates_early() .Then(x => _steps.ThenTheStatusCodeShouldBe(HttpStatusCode.OK)) .And(x => x.ThenTheCounterIs(1)) .BDDfy(); - } - + } + + [Fact] + public void Should_call_after_http_authentication_middleware() + { + var configuration = new OcelotPipelineConfiguration + { + AfterAuthenticationMiddleware = async (ctx, next) => + { + _counter++; + await next.Invoke(); + }, + }; + + var port = PortFinder.GetRandomPort(); + + var fileConfiguration = new FileConfiguration + { + Routes = new List + { + new() + { + DownstreamPathTemplate = "/", + DownstreamHostAndPorts = new List + { + new() + { + Host = "localhost", + Port = port, + }, + }, + DownstreamScheme = "http", + UpstreamPathTemplate = "/", + UpstreamHttpMethod = new List { "Get" }, + }, + }, + }; + + this.Given(x => x.GivenThereIsAServiceRunningOn($"http://localhost:{port}", 200, "")) + .And(x => _steps.GivenThereIsAConfiguration(fileConfiguration)) + .And(x => _steps.GivenOcelotIsRunning(configuration)) + .When(x => _steps.WhenIGetUrlOnTheApiGateway("/")) + .Then(x => _steps.ThenTheStatusCodeShouldBe(HttpStatusCode.OK)) + .And(x => x.ThenTheCounterIs(1)) + .BDDfy(); + } + [Fact(Skip = "This is just an example to show how you could hook into Ocelot pipeline with your own middleware. At the moment you must use Response.OnCompleted callback and cannot change the response :( I will see if this can be changed one day!")] - public void should_fix_issue_237() + public void Should_fix_issue_237() { Func callback = state => { var httpContext = (HttpContext)state; if (httpContext.Response.StatusCode > 400) - { + { Debug.WriteLine("COUNT CALLED"); Console.WriteLine("COUNT CALLED"); } return Task.CompletedTask; - }; - + }; + var port = PortFinder.GetRandomPort(); var fileConfiguration = new FileConfiguration @@ -406,15 +496,16 @@ private void GivenThereIsAServiceRunningOn(string url, int statusCode, string ba public void Dispose() { - _serviceHandler?.Dispose(); - _steps.Dispose(); + _serviceHandler.Dispose(); + _steps.Dispose(); + GC.SuppressFinalize(this); } public class FakeMiddleware { private readonly RequestDelegate _next; - private readonly Func _callback; - + private readonly Func _callback; + public FakeMiddleware(RequestDelegate next, Func callback) { _next = next; @@ -423,10 +514,10 @@ public FakeMiddleware(RequestDelegate next, Func callback) public async Task Invoke(HttpContext context) { - await _next(context); - + await _next(context); + context.Response.OnCompleted(_callback, context); } } } -} +}