1
1
use std:: collections:: HashMap ;
2
+ use std:: sync:: Arc ;
2
3
3
4
use crate :: error:: MutinyError ;
4
5
use crate :: storage:: MutinyStorage ;
5
6
use core:: time:: Duration ;
7
+ use gloo_net:: websocket:: futures:: WebSocket ;
6
8
use hex_conservative:: DisplayHex ;
7
9
use once_cell:: sync:: Lazy ;
8
10
use payjoin:: receive:: v2:: Enrolled ;
@@ -69,16 +71,73 @@ impl<S: MutinyStorage> PayjoinStorage for S {
69
71
}
70
72
}
71
73
72
- pub async fn fetch_ohttp_keys ( _ohttp_relay : Url , directory : Url ) -> Result < OhttpKeys , Error > {
73
- let http_client = reqwest :: Client :: builder ( ) . build ( ) ? ;
74
+ pub async fn fetch_ohttp_keys ( ohttp_relay : Url , directory : Url ) -> Result < OhttpKeys , Error > {
75
+ use futures_util :: { AsyncReadExt , AsyncWriteExt } ;
74
76
75
- let ohttp_keys_res = http_client
76
- . get ( format ! ( "{}/ohttp-keys" , directory. as_ref( ) ) )
77
- . send ( )
78
- . await ?
79
- . bytes ( )
80
- . await ?;
81
- Ok ( OhttpKeys :: decode ( ohttp_keys_res. as_ref ( ) ) . map_err ( |_| Error :: OhttpDecodeFailed ) ?)
77
+ let tls_connector = {
78
+ let root_store = futures_rustls:: rustls:: RootCertStore {
79
+ roots : webpki_roots:: TLS_SERVER_ROOTS . to_vec ( ) ,
80
+ } ;
81
+ let config = futures_rustls:: rustls:: ClientConfig :: builder ( )
82
+ . with_root_certificates ( root_store)
83
+ . with_no_client_auth ( ) ;
84
+ futures_rustls:: TlsConnector :: from ( Arc :: new ( config) )
85
+ } ;
86
+ let directory_host = directory. host_str ( ) . ok_or ( Error :: BadDirectoryHost ) ?;
87
+ let domain = futures_rustls:: rustls:: pki_types:: ServerName :: try_from ( directory_host)
88
+ . map_err ( |_| Error :: BadDirectoryHost ) ?
89
+ . to_owned ( ) ;
90
+
91
+ let ws = WebSocket :: open ( & format ! (
92
+ "wss://{}:443" ,
93
+ ohttp_relay. host_str( ) . ok_or( Error :: BadOhttpWsHost ) ?
94
+ ) )
95
+ . map_err ( |_| Error :: BadOhttpWsHost ) ?;
96
+
97
+ let mut tls_stream = tls_connector
98
+ . connect ( domain, ws)
99
+ . await
100
+ . map_err ( |e| Error :: RequestFailed ( e. to_string ( ) ) ) ?;
101
+ let ohttp_keys_req = format ! (
102
+ "GET /ohttp-keys HTTP/1.1\r \n Host: {}\r \n Connection: close\r \n \r \n " ,
103
+ directory_host
104
+ ) ;
105
+ tls_stream
106
+ . write_all ( ohttp_keys_req. as_bytes ( ) )
107
+ . await
108
+ . map_err ( |e| Error :: RequestFailed ( e. to_string ( ) ) ) ?;
109
+ tls_stream
110
+ . flush ( )
111
+ . await
112
+ . map_err ( |e| Error :: RequestFailed ( e. to_string ( ) ) ) ?;
113
+ let mut response_bytes = Vec :: new ( ) ;
114
+ tls_stream
115
+ . read_to_end ( & mut response_bytes)
116
+ . await
117
+ . map_err ( |e| Error :: RequestFailed ( e. to_string ( ) ) ) ?;
118
+ let ( _headers, res_body) = separate_headers_and_body ( & response_bytes) ?;
119
+ payjoin:: OhttpKeys :: decode ( res_body) . map_err ( |_| Error :: OhttpDecodeFailed )
120
+ }
121
+
122
+ fn separate_headers_and_body ( response_bytes : & [ u8 ] ) -> Result < ( & [ u8 ] , & [ u8 ] ) , Error > {
123
+ let separator = b"\r \n \r \n " ;
124
+
125
+ // Search for the separator
126
+ if let Some ( position) = response_bytes
127
+ . windows ( separator. len ( ) )
128
+ . position ( |window| window == separator)
129
+ {
130
+ // The body starts immediately after the separator
131
+ let body_start_index = position + separator. len ( ) ;
132
+ let headers = & response_bytes[ ..position] ;
133
+ let body = & response_bytes[ body_start_index..] ;
134
+
135
+ Ok ( ( headers, body) )
136
+ } else {
137
+ Err ( Error :: RequestFailed (
138
+ "No header-body separator found in the response" . to_string ( ) ,
139
+ ) )
140
+ }
82
141
}
83
142
84
143
#[ derive( Debug ) ]
@@ -89,6 +148,9 @@ pub enum Error {
89
148
OhttpDecodeFailed ,
90
149
Shutdown ,
91
150
SessionExpired ,
151
+ BadDirectoryHost ,
152
+ BadOhttpWsHost ,
153
+ RequestFailed ( String ) ,
92
154
}
93
155
94
156
impl std:: error:: Error for Error { }
@@ -102,6 +164,9 @@ impl std::fmt::Display for Error {
102
164
Error :: OhttpDecodeFailed => write ! ( f, "Failed to decode ohttp keys" ) ,
103
165
Error :: Shutdown => write ! ( f, "Payjoin stopped by application shutdown" ) ,
104
166
Error :: SessionExpired => write ! ( f, "Payjoin session expired. Create a new payment request and have the sender try again." ) ,
167
+ Error :: BadDirectoryHost => write ! ( f, "Bad directory host" ) ,
168
+ Error :: BadOhttpWsHost => write ! ( f, "Bad ohttp ws host" ) ,
169
+ Error :: RequestFailed ( e) => write ! ( f, "Request failed: {}" , e) ,
105
170
}
106
171
}
107
172
}
0 commit comments