1
- #if NET7_0_OR_GREATER
1
+ using Microsoft . AspNetCore . Http ;
2
+ #if NET7_0_OR_GREATER
2
3
using Microsoft . AspNetCore . RateLimiting ;
3
4
#endif
4
- using Microsoft . AspNetCore . Http ;
5
+ using Microsoft . Net . Http . Headers ;
5
6
using Ocelot . Configuration ;
6
7
using Ocelot . Logging ;
7
8
using Ocelot . Middleware ;
@@ -13,15 +14,18 @@ public class RateLimitingMiddleware : OcelotMiddleware
13
14
{
14
15
private readonly RequestDelegate _next ;
15
16
private readonly IRateLimiting _limiter ;
17
+ private readonly IHttpContextAccessor _contextAccessor ;
16
18
17
19
public RateLimitingMiddleware (
18
20
RequestDelegate next ,
19
21
IOcelotLoggerFactory factory ,
20
- IRateLimiting limiter )
22
+ IRateLimiting limiter ,
23
+ IHttpContextAccessor contextAccessor )
21
24
: base ( factory . CreateLogger < RateLimitingMiddleware > ( ) )
22
25
{
23
26
_next = next ;
24
27
_limiter = limiter ;
28
+ _contextAccessor = contextAccessor ;
25
29
}
26
30
27
31
public async Task Invoke ( HttpContext httpContext )
@@ -83,11 +87,15 @@ public async Task Invoke(HttpContext httpContext)
83
87
}
84
88
}
85
89
86
- //set X-Rate-Limit headers for the longest period
90
+ // Set X-Rate-Limit headers for the longest period
87
91
if ( ! options . DisableRateLimitHeaders )
88
92
{
89
- var headers = _limiter . GetHeaders ( httpContext , identity , options ) ;
90
- httpContext . Response . OnStarting ( SetRateLimitHeaders , state : headers ) ;
93
+ var originalContext = _contextAccessor ? . HttpContext ;
94
+ if ( originalContext != null )
95
+ {
96
+ var headers = _limiter . GetHeaders ( originalContext , identity , options ) ;
97
+ originalContext . Response . OnStarting ( SetRateLimitHeaders , state : headers ) ;
98
+ }
91
99
}
92
100
93
101
await _next . Invoke ( httpContext ) ;
@@ -108,15 +116,8 @@ public virtual ClientRequestIdentity SetIdentity(HttpContext httpContext, RateLi
108
116
) ;
109
117
}
110
118
111
- public bool IsWhitelisted ( ClientRequestIdentity requestIdentity , RateLimitOptions option )
112
- {
113
- if ( option . ClientWhitelist . Contains ( requestIdentity . ClientId ) )
114
- {
115
- return true ;
116
- }
117
-
118
- return false ;
119
- }
119
+ public static bool IsWhitelisted ( ClientRequestIdentity requestIdentity , RateLimitOptions option )
120
+ => option . ClientWhitelist . Contains ( requestIdentity . ClientId ) ;
120
121
121
122
public virtual void LogBlockedRequest ( HttpContext httpContext , ClientRequestIdentity identity , RateLimitCounter counter , RateLimitRule rule , DownstreamRoute downstreamRoute )
122
123
{
@@ -127,14 +128,15 @@ public virtual void LogBlockedRequest(HttpContext httpContext, ClientRequestIden
127
128
public virtual DownstreamResponse ReturnQuotaExceededResponse ( HttpContext httpContext , RateLimitOptions option , string retryAfter )
128
129
{
129
130
var message = GetResponseMessage ( option ) ;
130
-
131
- var http = new HttpResponseMessage ( ( HttpStatusCode ) option . HttpStatusCode ) ;
132
-
133
- http . Content = new StringContent ( message ) ;
131
+ var http = new HttpResponseMessage ( ( HttpStatusCode ) option . HttpStatusCode )
132
+ {
133
+ Content = new StringContent ( message ) ,
134
+ } ;
134
135
135
136
if ( ! option . DisableRateLimitHeaders )
136
137
{
137
- http . Headers . TryAddWithoutValidation ( "Retry-After" , retryAfter ) ; // in seconds, not date string
138
+ http . Headers . TryAddWithoutValidation ( HeaderNames . RetryAfter , retryAfter ) ; // in seconds, not date string
139
+ httpContext . Response . Headers [ HeaderNames . RetryAfter ] = retryAfter ;
138
140
}
139
141
140
142
return new DownstreamResponse ( http ) ;
@@ -148,14 +150,17 @@ private static string GetResponseMessage(RateLimitOptions option)
148
150
return message ;
149
151
}
150
152
151
- private static Task SetRateLimitHeaders ( object rateLimitHeaders )
153
+ /// <summary>TODO: Produced Ocelot's headers don't follow industry standards.</summary>
154
+ /// <remarks>More details in <see cref="RateLimitingHeaders"/> docs.</remarks>
155
+ /// <param name="state">Captured state as a <see cref="RateLimitHeaders"/> object.</param>
156
+ /// <returns>The <see cref="Task.CompletedTask"/> object.</returns>
157
+ private static Task SetRateLimitHeaders ( object state )
152
158
{
153
- var headers = ( RateLimitHeaders ) rateLimitHeaders ;
154
-
155
- headers . Context . Response . Headers [ "X-Rate-Limit-Limit" ] = headers . Limit ;
156
- headers . Context . Response . Headers [ "X-Rate-Limit-Remaining" ] = headers . Remaining ;
157
- headers . Context . Response . Headers [ "X-Rate-Limit-Reset" ] = headers . Reset ;
158
-
159
+ var limitHeaders = ( RateLimitHeaders ) state ;
160
+ var headers = limitHeaders . Context . Response . Headers ;
161
+ headers [ RateLimitingHeaders . X_Rate_Limit_Limit ] = limitHeaders . Limit ;
162
+ headers [ RateLimitingHeaders . X_Rate_Limit_Remaining ] = limitHeaders . Remaining ;
163
+ headers [ RateLimitingHeaders . X_Rate_Limit_Reset ] = limitHeaders . Reset ;
159
164
return Task . CompletedTask ;
160
165
}
161
166
}
0 commit comments