9191//! # Ok(())
9292//! # }
9393//! ```
94+ //!
95+ //! ## Customizing extensions
96+ //!
97+ //! You can use [`FollowRedirectLayer::with_policy_extension()`]
98+ //! to also set the [`FollowedPolicy`] extension on the response.
99+ //!
100+ //! ```
101+ //! use http::{Request, Response};
102+ //! use bytes::Bytes;
103+ //! use http_body_util::Full;
104+ //! use tower::{Service, ServiceBuilder, ServiceExt};
105+ //! use tower_http::follow_redirect::{FollowRedirectLayer, FollowedPolicy, policy};
106+ //!
107+ //! # #[tokio::main]
108+ //! # async fn main() -> Result<(), std::convert::Infallible> {
109+ //! # let http_client =
110+ //! # tower::service_fn(|_: Request<Full<Bytes>>| async { Ok::<_, std::convert::Infallible>(Response::new(Full::<Bytes>::default())) });
111+ //! let mut client = ServiceBuilder::new()
112+ //! .layer(FollowRedirectLayer::with_policy_extension(policy::Limited::new(10)))
113+ //! .service(http_client);
114+ //!
115+ //! let res = client.ready().await?.call(Request::default()).await?;
116+ //! assert_eq!(
117+ //! res.extensions()
118+ //! .get::<FollowedPolicy<policy::Limited>>()
119+ //! .unwrap()
120+ //! .0
121+ //! .remaining,
122+ //! 10
123+ //! );
124+ //! # Ok(())
125+ //! # }
126+ //! ```
94127
95128pub mod policy;
96129
@@ -120,9 +153,9 @@ use tower_service::Service;
120153///
121154/// See the [module docs](self) for more details.
122155#[ derive( Clone , Copy , Debug , Default ) ]
123- pub struct FollowRedirectLayer < P = Standard , CB = NoOp > {
156+ pub struct FollowRedirectLayer < P = Standard , CB = UriExtension > {
124157 policy : P ,
125- callback : CB ,
158+ handler : CB ,
126159}
127160
128161impl FollowRedirectLayer {
@@ -137,12 +170,12 @@ impl<P> FollowRedirectLayer<P> {
137170 pub fn with_policy ( policy : P ) -> Self {
138171 Self {
139172 policy,
140- callback : NoOp :: default ( ) ,
173+ handler : UriExtension :: default ( ) ,
141174 }
142175 }
143176}
144177
145- impl < P > FollowRedirectLayer < P , PolicyExtension >
178+ impl < P > FollowRedirectLayer < P , UriAndPolicyExtensions >
146179where
147180 P : Send + Sync + ' static ,
148181{
@@ -151,7 +184,7 @@ where
151184 pub fn with_policy_extension ( policy : P ) -> Self {
152185 Self {
153186 policy,
154- callback : PolicyExtension :: default ( ) ,
187+ handler : UriAndPolicyExtensions :: default ( ) ,
155188 }
156189 }
157190}
@@ -165,18 +198,18 @@ where
165198 type Service = FollowRedirect < S , P , CB > ;
166199
167200 fn layer ( & self , inner : S ) -> Self :: Service {
168- FollowRedirect :: with_policy_callback ( inner, self . policy . clone ( ) , self . callback )
201+ FollowRedirect :: with_policy_handler ( inner, self . policy . clone ( ) , self . handler )
169202 }
170203}
171204
172205/// Middleware that retries requests with a [`Service`] to follow redirection responses.
173206///
174207/// See the [module docs](self) for more details.
175208#[ derive( Clone , Copy , Debug ) ]
176- pub struct FollowRedirect < S , P = Standard , CB = NoOp > {
209+ pub struct FollowRedirect < S , P = Standard , CB = UriExtension > {
177210 inner : S ,
178211 policy : P ,
179- callback : CB ,
212+ handler : CB ,
180213}
181214
182215impl < S > FollowRedirect < S > {
@@ -193,18 +226,22 @@ impl<S> FollowRedirect<S> {
193226 }
194227}
195228
196- impl < S > FollowRedirect < S , Standard , PolicyExtension > {
229+ impl < S > FollowRedirect < S , Standard , UriAndPolicyExtensions > {
197230 /// Create a new [`FollowRedirect`] with a [`Standard`] redirection policy,
198231 /// that inserts the [`FollowedPolicy`] extension.
199232 pub fn with_extension ( inner : S ) -> Self {
200- Self :: with_policy_callback ( inner, Standard :: default ( ) , PolicyExtension :: default ( ) )
233+ Self :: with_policy_handler (
234+ inner,
235+ Standard :: default ( ) ,
236+ UriAndPolicyExtensions :: default ( ) ,
237+ )
201238 }
202239
203240 /// Returns a new [`Layer`] that wraps services with a `FollowRedirect` middleware
204241 /// that inserts the [`FollowedPolicy`] extension.
205242 ///
206243 /// [`Layer`]: tower_layer::Layer
207- pub fn layer_with_extension ( ) -> FollowRedirectLayer < Standard , PolicyExtension > {
244+ pub fn layer_with_extension ( ) -> FollowRedirectLayer < Standard , UriAndPolicyExtensions > {
208245 FollowRedirectLayer :: with_policy_extension ( Standard :: default ( ) )
209246 }
210247}
@@ -218,7 +255,7 @@ where
218255 FollowRedirect {
219256 inner,
220257 policy,
221- callback : NoOp :: default ( ) ,
258+ handler : UriExtension :: default ( ) ,
222259 }
223260 }
224261
@@ -235,53 +272,53 @@ impl<S, P, CB> FollowRedirect<S, P, CB>
235272where
236273 P : Clone ,
237274{
238- /// Create a new [`FollowRedirect`] with the given redirection [`Policy`] and [`ResponseCallback `].
239- fn with_policy_callback ( inner : S , policy : P , callback : CB ) -> Self {
275+ /// Create a new [`FollowRedirect`] with the given redirection [`Policy`] and [`ResponseHandler `].
276+ fn with_policy_handler ( inner : S , policy : P , handler : CB ) -> Self {
240277 FollowRedirect {
241278 inner,
242279 policy,
243- callback ,
280+ handler ,
244281 }
245282 }
246283
247284 define_inner_service_accessors ! ( ) ;
248285}
249286
250287/// Called on each new response, can be used for example to add [`http::Extensions`]
251- trait ResponseCallback < ReqBody , ResBody , S , P > : Sized
288+ trait ResponseHandler < ReqBody , ResBody , S , P > : Sized
252289where
253290 S : Service < Request < ReqBody > > ,
254291{
255- fn handle ( res : & mut Response < ResBody > , req : & RedirectingRequest < S , ReqBody , P > ) ;
292+ fn on_response ( res : & mut Response < ResBody > , req : & RedirectingRequest < S , ReqBody , P > ) ;
256293}
257294
258- /// Default behavior: doesn't do anything
295+ /// Default behavior: adds a [`RequestUri`] extension to the response.
259296#[ derive( Default , Clone , Copy ) ]
260- pub struct NoOp { }
297+ pub struct UriExtension { }
261298
262- impl < ReqBody , ResBody , S , P > ResponseCallback < ReqBody , ResBody , S , P > for NoOp
299+ impl < ReqBody , ResBody , S , P > ResponseHandler < ReqBody , ResBody , S , P > for UriExtension
263300where
264301 S : Service < Request < ReqBody > > ,
265302{
266- fn handle ( _res : & mut Response < ResBody > , _req : & RedirectingRequest < S , ReqBody , P > ) { }
303+ #[ inline]
304+ fn on_response ( res : & mut Response < ResBody > , req : & RedirectingRequest < S , ReqBody , P > ) {
305+ res. extensions_mut ( ) . insert ( RequestUri ( req. uri . clone ( ) ) ) ;
306+ }
267307}
268308
269- /// Response [`Extensions`][http::Extensions] value that contains the redirect [`Policy`] that
270- /// was run before the last request of the redirect chain by a [`FollowRedirectExtension`] middleware.
271- #[ derive( Clone ) ]
272- pub struct FollowedPolicy < P > ( pub P ) ;
273-
274- /// Adds a [`FollowedPolicy`] extension to the response
275-
309+ /// Adds a [`FollowedPolicy`] and [`RequestUri`] extension to the response.
276310#[ derive( Default , Clone , Copy ) ]
277- pub struct PolicyExtension { }
311+ pub struct UriAndPolicyExtensions { }
278312
279- impl < ReqBody , ResBody , S , P > ResponseCallback < ReqBody , ResBody , S , P > for PolicyExtension
313+ impl < ReqBody , ResBody , S , P > ResponseHandler < ReqBody , ResBody , S , P > for UriAndPolicyExtensions
280314where
281315 S : Service < Request < ReqBody > > ,
282316 P : Clone + Send + Sync + ' static ,
283317{
284- fn handle ( res : & mut Response < ResBody > , req : & RedirectingRequest < S , ReqBody , P > ) {
318+ #[ inline]
319+ fn on_response ( res : & mut Response < ResBody > , req : & RedirectingRequest < S , ReqBody , P > ) {
320+ UriExtension :: on_response ( res, req) ;
321+
285322 res. extensions_mut ( )
286323 . insert ( FollowedPolicy ( req. policy . clone ( ) ) ) ;
287324 }
@@ -292,7 +329,7 @@ where
292329 S : Service < Request < ReqBody > , Response = Response < ResBody > > + Clone ,
293330 ReqBody : Body + Default ,
294331 P : Policy < ReqBody , S :: Error > + Clone ,
295- CB : ResponseCallback < ReqBody , ResBody , S , P > + Copy ,
332+ CB : ResponseHandler < ReqBody , ResBody , S , P > + Copy ,
296333{
297334 type Response = Response < ResBody > ;
298335 type Error = S :: Error ;
@@ -312,7 +349,7 @@ where
312349 ResponseFuture {
313350 future : Either :: Left ( request. service . call ( req) ) ,
314351 request,
315- callback : self . callback ,
352+ handler : self . handler ,
316353 }
317354 }
318355}
@@ -327,7 +364,7 @@ pin_project! {
327364 #[ pin]
328365 future: Either <S :: Future , Oneshot <S , Request <B >>>,
329366 request: RedirectingRequest <S , B , P >,
330- callback : CB
367+ handler : CB
331368 }
332369}
333370
@@ -336,14 +373,14 @@ where
336373 S : Service < Request < ReqBody > , Response = Response < ResBody > > + Clone ,
337374 ReqBody : Body + Default ,
338375 P : Policy < ReqBody , S :: Error > ,
339- CB : ResponseCallback < ReqBody , ResBody , S , P > ,
376+ CB : ResponseHandler < ReqBody , ResBody , S , P > ,
340377{
341378 type Output = Result < Response < ResBody > , S :: Error > ;
342379
343380 fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
344381 let mut this = self . project ( ) ;
345382 let mut res = ready ! ( this. future. as_mut( ) . poll( cx) ?) ;
346- CB :: handle ( & mut res, & this. request ) ;
383+ CB :: on_response ( & mut res, & this. request ) ;
347384
348385 match this. request . handle_response ( & mut res) {
349386 Ok ( Some ( pending) ) => {
@@ -402,8 +439,6 @@ where
402439 & mut self ,
403440 res : & mut Response < ResBody > ,
404441 ) -> Result < Option < Oneshot < S , Request < ReqBody > > > , S :: Error > {
405- res. extensions_mut ( ) . insert ( RequestUri ( self . uri . clone ( ) ) ) ;
406-
407442 let drop_payload_headers = |headers : & mut HeaderMap | {
408443 for header in & [
409444 CONTENT_TYPE ,
@@ -483,6 +518,11 @@ where
483518#[ derive( Clone ) ]
484519pub struct RequestUri ( pub Uri ) ;
485520
521+ /// Response [`Extensions`][http::Extensions] value that contains the redirect [`Policy`] that
522+ /// was run before the last request of the redirect chain by a [`FollowRedirectExtension`] middleware.
523+ #[ derive( Clone ) ]
524+ pub struct FollowedPolicy < P > ( pub P ) ;
525+
486526#[ derive( Debug ) ]
487527enum BodyRepr < B > {
488528 Some ( B ) ,
@@ -551,7 +591,7 @@ mod tests {
551591 #[ tokio:: test]
552592 async fn follows ( ) {
553593 let svc = ServiceBuilder :: new ( )
554- . layer ( FollowRedirectLayer :: with_policy ( Action :: Follow ) )
594+ . layer ( FollowRedirectLayer :: with_policy_extension ( Action :: Follow ) )
555595 . buffer ( 1 )
556596 . service_fn ( handle) ;
557597 let req = Request :: builder ( )
@@ -564,12 +604,18 @@ mod tests {
564604 res. extensions( ) . get:: <RequestUri >( ) . unwrap( ) . 0 ,
565605 "http://example.com/0"
566606 ) ;
607+ assert ! ( res
608+ . extensions( )
609+ . get:: <FollowedPolicy <Action >>( )
610+ . unwrap( )
611+ . 0
612+ . is_follow( ) ) ;
567613 }
568614
569615 #[ tokio:: test]
570616 async fn stops ( ) {
571617 let svc = ServiceBuilder :: new ( )
572- . layer ( FollowRedirectLayer :: with_policy ( Action :: Stop ) )
618+ . layer ( FollowRedirectLayer :: with_policy_extension ( Action :: Stop ) )
573619 . buffer ( 1 )
574620 . service_fn ( handle) ;
575621 let req = Request :: builder ( )
@@ -582,12 +628,18 @@ mod tests {
582628 res. extensions( ) . get:: <RequestUri >( ) . unwrap( ) . 0 ,
583629 "http://example.com/42"
584630 ) ;
631+ assert ! ( res
632+ . extensions( )
633+ . get:: <FollowedPolicy <Action >>( )
634+ . unwrap( )
635+ . 0
636+ . is_stop( ) ) ;
585637 }
586638
587639 #[ tokio:: test]
588640 async fn limited ( ) {
589641 let svc = ServiceBuilder :: new ( )
590- . layer ( FollowRedirectLayer :: with_policy ( Limited :: new ( 10 ) ) )
642+ . layer ( FollowRedirectLayer :: with_policy_extension ( Limited :: new ( 10 ) ) )
591643 . buffer ( 1 )
592644 . service_fn ( handle) ;
593645 let req = Request :: builder ( )
@@ -600,6 +652,14 @@ mod tests {
600652 res. extensions( ) . get:: <RequestUri >( ) . unwrap( ) . 0 ,
601653 "http://example.com/32"
602654 ) ;
655+ assert_eq ! (
656+ res. extensions( )
657+ . get:: <FollowedPolicy <Limited >>( )
658+ . unwrap( )
659+ . 0
660+ . remaining,
661+ 0
662+ ) ;
603663 }
604664
605665 /// A server with an endpoint `GET /{n}` which redirects to `/{n-1}` unless `n` equals zero,
0 commit comments