@@ -14,15 +14,15 @@ use std::{
1414 fmt, future,
1515 io:: { self , Read , Write } ,
1616 pin:: Pin ,
17- task:: { Context , Poll } ,
17+ task:: { Context , Poll , Waker } ,
1818} ;
1919
2020#[ cfg( test) ]
2121mod test;
2222
2323struct StreamWrapper < S > {
2424 stream : S ,
25- context : usize ,
25+ waker : Waker ,
2626}
2727
2828impl < S > fmt:: Debug for StreamWrapper < S >
@@ -37,12 +37,10 @@ where
3737impl < S > StreamWrapper < S > {
3838 /// # Safety
3939 ///
40- /// Must be called with `context` set to a valid pointer to a live `Context` object, and the
41- /// wrapper must be pinned in memory.
42- unsafe fn parts ( & mut self ) -> ( Pin < & mut S > , & mut Context < ' _ > ) {
43- debug_assert_ne ! ( self . context, 0 ) ;
44- let stream = Pin :: new_unchecked ( & mut self . stream ) ;
45- let context = & mut * ( self . context as * mut _ ) ;
40+ /// The wrapper must be pinned in memory.
41+ unsafe fn parts ( & mut self ) -> ( Pin < & mut S > , Context < ' _ > ) {
42+ let stream = unsafe { Pin :: new_unchecked ( & mut self . stream ) } ;
43+ let context = Context :: from_waker ( & self . waker ) ;
4644 ( stream, context)
4745 }
4846}
5250 S : AsyncRead ,
5351{
5452 fn read ( & mut self , buf : & mut [ u8 ] ) -> io:: Result < usize > {
55- let ( stream, cx) = unsafe { self . parts ( ) } ;
56- match stream. poll_read ( cx, buf) ? {
53+ let ( stream, mut cx) = unsafe { self . parts ( ) } ;
54+ match stream. poll_read ( & mut cx, buf) ? {
5755 Poll :: Ready ( nread) => Ok ( nread) ,
5856 Poll :: Pending => Err ( io:: Error :: from ( io:: ErrorKind :: WouldBlock ) ) ,
5957 }
@@ -65,16 +63,16 @@ where
6563 S : AsyncWrite ,
6664{
6765 fn write ( & mut self , buf : & [ u8 ] ) -> io:: Result < usize > {
68- let ( stream, cx) = unsafe { self . parts ( ) } ;
69- match stream. poll_write ( cx, buf) {
66+ let ( stream, mut cx) = unsafe { self . parts ( ) } ;
67+ match stream. poll_write ( & mut cx, buf) {
7068 Poll :: Ready ( r) => r,
7169 Poll :: Pending => Err ( io:: Error :: from ( io:: ErrorKind :: WouldBlock ) ) ,
7270 }
7371 }
7472
7573 fn flush ( & mut self ) -> io:: Result < ( ) > {
76- let ( stream, cx) = unsafe { self . parts ( ) } ;
77- match stream. poll_flush ( cx) {
74+ let ( stream, mut cx) = unsafe { self . parts ( ) } ;
75+ match stream. poll_flush ( & mut cx) {
7876 Poll :: Ready ( r) => r,
7977 Poll :: Pending => Err ( io:: Error :: from ( io:: ErrorKind :: WouldBlock ) ) ,
8078 }
@@ -109,7 +107,14 @@ where
109107{
110108 /// Like [`SslStream::new`](ssl::SslStream::new).
111109 pub fn new ( ssl : Ssl , stream : S ) -> Result < Self , ErrorStack > {
112- ssl:: SslStream :: new ( ssl, StreamWrapper { stream, context : 0 } ) . map ( SslStream )
110+ ssl:: SslStream :: new (
111+ ssl,
112+ StreamWrapper {
113+ stream,
114+ waker : Waker :: noop ( ) . clone ( ) ,
115+ } ,
116+ )
117+ . map ( SslStream )
113118 }
114119
115120 /// Like [`SslStream::connect`](ssl::SslStream::connect).
@@ -227,10 +232,8 @@ impl<S> SslStream<S> {
227232 F : FnOnce ( & mut ssl:: SslStream < StreamWrapper < S > > ) -> R ,
228233 {
229234 let this = unsafe { self . get_unchecked_mut ( ) } ;
230- this. 0 . get_mut ( ) . context = ctx as * mut _ as usize ;
231- let r = f ( & mut this. 0 ) ;
232- this. 0 . get_mut ( ) . context = 0 ;
233- r
235+ this. 0 . get_mut ( ) . waker = ctx. waker ( ) . clone ( ) ;
236+ f ( & mut this. 0 )
234237 }
235238}
236239
0 commit comments