Skip to content

Commit 04f6181

Browse files
committed
RUST-802 Support Unix Domain Sockets
1 parent a44f669 commit 04f6181

File tree

6 files changed

+341
-39
lines changed

6 files changed

+341
-39
lines changed

src/client/executor.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ impl Client {
9999
op: T,
100100
session: impl Into<Option<&mut ClientSession>>,
101101
) -> Result<ExecutionDetails<T>> {
102-
Box::pin(async {
102+
async {
103103
// TODO RUST-9: allow unacknowledged write concerns
104104
if !op.is_acknowledged() {
105105
return Err(ErrorKind::InvalidArgument {
@@ -130,7 +130,7 @@ impl Client {
130130
}
131131
}
132132
self.execute_operation_with_retry(op, session).await
133-
})
133+
}
134134
.await
135135
}
136136

@@ -141,7 +141,7 @@ impl Client {
141141
where
142142
Op: Operation<O = CursorSpecification>,
143143
{
144-
Box::pin(async {
144+
async {
145145
let mut details = self.execute_operation_with_details(op, None).await?;
146146
let pinned =
147147
self.pin_connection_for_cursor(&details.output, &mut details.connection)?;
@@ -151,7 +151,7 @@ impl Client {
151151
details.implicit_session,
152152
pinned,
153153
))
154-
})
154+
}
155155
.await
156156
}
157157

@@ -212,7 +212,7 @@ impl Client {
212212
where
213213
T: DeserializeOwned + Unpin + Send + Sync,
214214
{
215-
Box::pin(async {
215+
async {
216216
let pipeline: Vec<_> = pipeline.into_iter().collect();
217217
let args = WatchArgs {
218218
pipeline,
@@ -235,7 +235,7 @@ impl Client {
235235
let cursor = Cursor::new(self.clone(), cursor_spec, details.implicit_session, pinned);
236236

237237
Ok(ChangeStream::new(cursor, args, cs_data))
238-
})
238+
}
239239
.await
240240
}
241241

@@ -250,7 +250,7 @@ impl Client {
250250
where
251251
T: DeserializeOwned + Unpin + Send + Sync,
252252
{
253-
Box::pin(async {
253+
async {
254254
let pipeline: Vec<_> = pipeline.into_iter().collect();
255255
let args = WatchArgs {
256256
pipeline,
@@ -268,7 +268,7 @@ impl Client {
268268
let cursor = SessionCursor::new(self.clone(), cursor_spec, pinned);
269269

270270
Ok(SessionChangeStream::new(cursor, args, cs_data))
271-
})
271+
}
272272
.await
273273
}
274274

src/client/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ impl Client {
485485
&self,
486486
criteria: Option<&SelectionCriteria>,
487487
) -> Result<ServerAddress> {
488-
let server = self.select_server(criteria, "Test select server").await?;
488+
let server = self.select_server(criteria, "Test select server").await?;
489489
Ok(server.address.clone())
490490
}
491491

src/client/options/mod.rs

Lines changed: 77 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,6 @@ lazy_static! {
9696
}
9797

9898
/// An enum representing the address of a MongoDB server.
99-
///
100-
/// Currently this just supports addresses that can be connected to over TCP, but alternative
101-
/// address types may be supported in the future (e.g. Unix Domain Socket paths).
10299
#[derive(Clone, Debug, Eq, Serialize)]
103100
#[non_exhaustive]
104101
pub enum ServerAddress {
@@ -112,6 +109,12 @@ pub enum ServerAddress {
112109
/// The default is 27017.
113110
port: Option<u16>,
114111
},
112+
/// A Unix Domain Socket path.
113+
#[cfg(unix)]
114+
Unix {
115+
/// The path to the Unix Domain Socket.
116+
path: PathBuf,
117+
},
115118
}
116119

117120
impl<'de> Deserialize<'de> for ServerAddress {
@@ -144,6 +147,10 @@ impl PartialEq for ServerAddress {
144147
port: other_port,
145148
},
146149
) => host == other_host && port.unwrap_or(27017) == other_port.unwrap_or(27017),
150+
#[cfg(unix)]
151+
(Self::Unix { path }, Self::Unix { path: other_path }) => path == other_path,
152+
#[cfg(unix)]
153+
_ => false,
147154
}
148155
}
149156
}
@@ -158,6 +165,8 @@ impl Hash for ServerAddress {
158165
host.hash(state);
159166
port.unwrap_or(27017).hash(state);
160167
}
168+
#[cfg(unix)]
169+
Self::Unix { path } => path.hash(state),
161170
}
162171
}
163172
}
@@ -173,6 +182,15 @@ impl ServerAddress {
173182
/// Parses an address string into a `ServerAddress`.
174183
pub fn parse(address: impl AsRef<str>) -> Result<Self> {
175184
let address = address.as_ref();
185+
// checks if the address is a unix domain socket
186+
#[cfg(unix)]
187+
{
188+
if address.starts_with('/') {
189+
return Ok(ServerAddress::Unix {
190+
path: PathBuf::from(address),
191+
});
192+
}
193+
}
176194
let mut parts = address.split(':');
177195
let hostname = match parts.next() {
178196
Some(part) => {
@@ -243,18 +261,28 @@ impl ServerAddress {
243261
"port": port.map(|i| Bson::Int32(i.into())).unwrap_or(Bson::Null)
244262
}
245263
}
264+
#[cfg(unix)]
265+
Self::Unix { path } => {
266+
doc! {
267+
"path": path.to_str().unwrap(),
268+
}
269+
}
246270
}
247271
}
248272

249273
pub(crate) fn host(&self) -> &str {
250274
match self {
251275
Self::Tcp { host, .. } => host.as_str(),
276+
#[cfg(unix)]
277+
Self::Unix { path } => path.to_str().unwrap(),
252278
}
253279
}
254280

255281
pub(crate) fn port(&self) -> Option<u16> {
256282
match self {
257283
Self::Tcp { port, .. } => *port,
284+
#[cfg(unix)]
285+
Self::Unix { .. } => None,
258286
}
259287
}
260288
}
@@ -265,6 +293,8 @@ impl fmt::Display for ServerAddress {
265293
Self::Tcp { host, port } => {
266294
write!(fmt, "{}:{}", host, port.unwrap_or(DEFAULT_PORT))
267295
}
296+
#[cfg(unix)]
297+
Self::Unix { path } => write!(fmt, "{}", path.display()),
268298
}
269299
}
270300
}
@@ -1592,16 +1622,26 @@ impl ConnectionString {
15921622
}
15931623
.into());
15941624
}
1595-
// Unwrap safety: the `len` check above guarantees this can't fail.
1596-
let ServerAddress::Tcp { host, port } = host_list.into_iter().next().unwrap();
15971625

1598-
if port.is_some() {
1599-
return Err(ErrorKind::InvalidArgument {
1600-
message: "a port cannot be specified with 'mongodb+srv'".into(),
1626+
// Unwrap safety: the `len` check above guarantees this can't fail.
1627+
match host_list.into_iter().next().unwrap() {
1628+
ServerAddress::Tcp { host, port } => {
1629+
if port.is_some() {
1630+
return Err(ErrorKind::InvalidArgument {
1631+
message: "a port cannot be specified with 'mongodb+srv'".into(),
1632+
}
1633+
.into());
1634+
}
1635+
HostInfo::DnsRecord(host)
1636+
}
1637+
#[cfg(unix)]
1638+
ServerAddress::Unix { .. } => {
1639+
return Err(ErrorKind::InvalidArgument {
1640+
message: "unix sockets cannot be used with 'mongodb+srv'".into(),
1641+
}
1642+
.into());
16011643
}
1602-
.into());
16031644
}
1604-
HostInfo::DnsRecord(host)
16051645
} else {
16061646
HostInfo::HostIdentifiers(host_list)
16071647
};
@@ -2299,18 +2339,39 @@ mod tests {
22992339
#[test]
23002340
fn test_parse_address_with_from_str() {
23012341
let x = "localhost:27017".parse::<ServerAddress>().unwrap();
2302-
let ServerAddress::Tcp { host, port } = x;
2303-
assert_eq!(host, "localhost");
2304-
assert_eq!(port, Some(27017));
2342+
match x {
2343+
ServerAddress::Tcp { host, port } => {
2344+
assert_eq!(host, "localhost");
2345+
assert_eq!(port, Some(27017));
2346+
}
2347+
#[cfg(unix)]
2348+
_ => panic!("expected ServerAddress::Tcp"),
2349+
}
23052350

23062351
// Port defaults to 27017 (so this doesn't fail)
23072352
let x = "localhost".parse::<ServerAddress>().unwrap();
2308-
let ServerAddress::Tcp { host, port } = x;
2309-
assert_eq!(host, "localhost");
2310-
assert_eq!(port, None);
2353+
match x {
2354+
ServerAddress::Tcp { host, port } => {
2355+
assert_eq!(host, "localhost");
2356+
assert_eq!(port, None);
2357+
}
2358+
#[cfg(unix)]
2359+
_ => panic!("expected ServerAddress::Tcp"),
2360+
}
23112361

23122362
let x = "localhost:not a number".parse::<ServerAddress>();
23132363
assert!(x.is_err());
2364+
2365+
#[cfg(unix)]
2366+
{
2367+
let x = "/path/to/socket.sock".parse::<ServerAddress>().unwrap();
2368+
match x {
2369+
ServerAddress::Unix { path } => {
2370+
assert_eq!(path.to_str().unwrap(), "/path/to/socket.sock");
2371+
}
2372+
_ => panic!("expected ServerAddress::Unix"),
2373+
}
2374+
}
23142375
}
23152376

23162377
#[cfg_attr(feature = "tokio-runtime", tokio::test)]

0 commit comments

Comments
 (0)