@@ -6,7 +6,8 @@ use std::io::{Read, Write};
66use std:: ops:: { Deref , DerefMut } ;
77use std:: str;
88use std:: sync:: mpsc;
9- use std:: sync:: Arc ;
9+
10+ #[ cfg( feature = "rsasl" ) ]
1011use rsasl:: prelude:: { Mechname , SASLClient , SASLConfig , Session as SASLSession , State as SASLState } ;
1112
1213use super :: authenticator:: Authenticator ;
@@ -367,7 +368,7 @@ impl<T: Read + Write> Client<T> {
367368 /// match client.login("user", "pass") {
368369 /// Ok(s) => {
369370 /// // you are successfully authenticated!
370- /// },
371+ /// }
371372 /// Err((e, orig_client)) => {
372373 /// eprintln!("error logging in: {}", e);
373374 /// // prompt user and try again with orig_client here
@@ -425,7 +426,7 @@ impl<T: Read + Write> Client<T> {
425426 /// match client.authenticate("XOAUTH2", &auth) {
426427 /// Ok(session) => {
427428 /// // you are successfully authenticated!
428- /// },
429+ /// }
429430 /// Err((e, orig_client)) => {
430431 /// eprintln!("error authenticating: {}", e);
431432 /// // prompt user and try again with orig_client here
@@ -434,9 +435,82 @@ impl<T: Read + Write> Client<T> {
434435 /// };
435436 /// }
436437 /// ```
437- pub fn authenticate (
438+ pub fn authenticate < A : Authenticator > (
438439 mut self ,
439- config : Arc < SASLConfig > ,
440+ auth_type : impl AsRef < str > ,
441+ authenticator : & A ,
442+ ) -> :: std:: result:: Result < Session < T > , ( Error , Client < T > ) > {
443+ ok_or_unauth_client_err ! (
444+ self . run_command( & format!( "AUTHENTICATE {}" , auth_type. as_ref( ) ) ) ,
445+ self
446+ ) ;
447+ self . do_auth_handshake ( authenticator)
448+ }
449+
450+ /// This func does the handshake process once the authenticate command is made.
451+ fn do_auth_handshake < A : Authenticator > (
452+ mut self ,
453+ authenticator : & A ,
454+ ) -> :: std:: result:: Result < Session < T > , ( Error , Client < T > ) > {
455+ // TODO Clean up this code
456+ loop {
457+ let mut line = Vec :: new ( ) ;
458+
459+ // explicit match blocks neccessary to convert error to tuple and not bind self too
460+ // early (see also comment on `login`)
461+ ok_or_unauth_client_err ! ( self . readline( & mut line) , self ) ;
462+
463+ // ignore server comments
464+ if line. starts_with ( b"* " ) {
465+ continue ;
466+ }
467+
468+ // Some servers will only send `+\r\n`.
469+ if line. starts_with ( b"+ " ) || & line == b"+\r \n " {
470+ let challenge = if & line == b"+\r \n " {
471+ Vec :: new ( )
472+ } else {
473+ let line_str = ok_or_unauth_client_err ! (
474+ match str :: from_utf8( line. as_slice( ) ) {
475+ Ok ( line_str) => Ok ( line_str) ,
476+ Err ( e) => Err ( Error :: Parse ( ParseError :: DataNotUtf8 ( line, e) ) ) ,
477+ } ,
478+ self
479+ ) ;
480+ let data =
481+ ok_or_unauth_client_err ! ( parse_authenticate_response( line_str) , self ) ;
482+ ok_or_unauth_client_err ! (
483+ base64:: decode( data) . map_err( |e| Error :: Parse ( ParseError :: Authentication (
484+ data. to_string( ) ,
485+ Some ( e)
486+ ) ) ) ,
487+ self
488+ )
489+ } ;
490+
491+ let raw_response = & authenticator. process ( & challenge) ;
492+ let auth_response = base64:: encode ( raw_response) ;
493+ ok_or_unauth_client_err ! (
494+ self . write_line( auth_response. into_bytes( ) . as_slice( ) ) ,
495+ self
496+ ) ;
497+ } else {
498+ ok_or_unauth_client_err ! ( self . read_response_onto( & mut line) , self ) ;
499+ return Ok ( Session :: new ( self . conn ) ) ;
500+ }
501+ }
502+ }
503+ }
504+
505+ #[ cfg( feature = "rsasl" ) ]
506+ impl < T : Read + Write > Client < T > {
507+
508+ /// Authenticate with the server using the given custom SASLConfig to handle the server's
509+ /// challenge.
510+ ///
511+ pub fn sasl_auth (
512+ mut self ,
513+ config : :: std:: sync:: Arc < SASLConfig > ,
440514 mechanism : & Mechname ,
441515 ) -> :: std:: result:: Result < Session < T > , ( Error , Client < T > ) > {
442516 let client = SASLClient :: new ( config) ;
@@ -450,11 +524,11 @@ impl<T: Read + Write> Client<T> {
450524 self . run_command( & format!( "AUTHENTICATE {}" , mechanism. as_str( ) ) ) ,
451525 self
452526 ) ;
453- self . do_auth_handshake ( session)
527+ self . do_sasl_handshake ( session)
454528 }
455529
456- /// This func does the handshake process once the authenticate command is made.
457- fn do_auth_handshake (
530+ /// This func does the SASL handshake process once the authenticate command is made.
531+ fn do_sasl_handshake (
458532 mut self ,
459533 mut authenticator : SASLSession ,
460534 ) -> :: std:: result:: Result < Session < T > , ( Error , Client < T > ) > {
0 commit comments