Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental support for Azure managed service identities #4680

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion quaint/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ postgresql-native = [
]
postgresql = []

mssql-native = ["mssql", "tiberius", "tokio-util", "tokio/time", "tokio/net"]
mssql-native = ["mssql", "tiberius", "tokio-util", "tokio/time", "tokio/net", "reqwest/json"]
mssql = []

mysql-native = ["mysql", "mysql_async", "tokio/time", "lru-cache"]
Expand All @@ -77,6 +77,7 @@ futures = "0.3"
url = "2.1"
hex = "0.4"
itertools.workspace = true
reqwest = "0.11"

either = { version = "1.6" }
base64 = { version = "0.12.3" }
Expand Down
40 changes: 39 additions & 1 deletion quaint/src/connector/mssql/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,16 @@ use crate::connector::{timeout, IsolationLevel, Transaction, TransactionOptions}
use crate::{
ast::{Query, Value},
connector::{metrics, queryable::*, DefaultTransaction, ResultSet},
error::{Error, ErrorKind},
visitor::{self, Visitor},
};
use async_trait::async_trait;
use connection_string::JdbcString;
use futures::lock::Mutex;
use std::{
collections::HashMap,
convert::TryFrom,
env,
future::Future,
sync::atomic::{AtomicBool, Ordering},
time::Duration,
Expand Down Expand Up @@ -64,7 +68,37 @@ pub struct Mssql {
impl Mssql {
/// Creates a new connection to SQL Server.
pub async fn new(url: MssqlUrl) -> crate::Result<Self> {
let config = Config::from_jdbc_string(&url.connection_string)?;
let mut config = Config::from_jdbc_string(&url.connection_string)?;

// TODO: should I change Config so I don't need to parse this twice, once here
// and again inside Config::from_jdbc_string?
// TODO: maybe this code belongs in tiberius instead of up here?
//
// This code follows MS's documentation at https://learn.microsoft.com/en-us/azure/app-service/overview-managed-identity?tabs=portal%2Chttp#connect-to-azure-services-in-app-code
// TODO: I should actually be re-fetching the token after msi_response.get("expires_on") has passed
let jdbc_config: JdbcString = url.connection_string.parse()?;
if let Some(authentication_type) = jdbc_config.properties().get("authentication") {
if authentication_type == "ActiveDirectoryMsi" {
let mut msi_url = get_required_env_var("IDENTITY_ENDPOINT")?;
msi_url.push_str("?resource=https%3A%2F%2Fdatabase.windows.net%2F&api-version=2019-08-01");
let identity_header = get_required_env_var("IDENTITY_HEADER")?;
let client = reqwest::Client::new();
let msi_response = client
.get(msi_url)
.header("X-IDENTITY-HEADER", identity_header)
.timeout(std::time::Duration::new(30, 0))
.send()
.await
.map_err(|e| Error::builder(ErrorKind::AuthTokenFetchFailure(Box::new(e))).build())?
.json::<HashMap<String, String>>()
.await
.map_err(|e| Error::builder(ErrorKind::AuthTokenFetchFailure(Box::new(e))).build())?;
if let Some(token) = msi_response.get("access_token") {
config.authentication(tiberius::AuthMethod::AADToken(token.clone()));
}
}
}

let tcp = TcpStream::connect_named(&config).await?;
let socket_timeout = url.socket_timeout();

Expand Down Expand Up @@ -121,6 +155,10 @@ impl Mssql {
}
}

fn get_required_env_var(name: &str) -> std::result::Result<String, Error> {
env::var(name).map_err(|_| Error::builder(ErrorKind::MissingEnvironmentVariable { name: name.into() }).build())
}

#[async_trait]
impl Queryable for Mssql {
async fn query(&self, q: Query<'_>) -> crate::Result<ResultSet> {
Expand Down
6 changes: 6 additions & 0 deletions quaint/src/error/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,12 @@ pub enum ErrorKind {
#[error("Foreign key constraint failed: {}", constraint)]
ForeignKeyConstraintViolation { constraint: DatabaseConstraint },

#[error("Error fetching auth token: {}", _0)]
AuthTokenFetchFailure(Box<dyn std::error::Error + Send + Sync + 'static>),

#[error("Missing environment variable: {}", name)]
MissingEnvironmentVariable { name: String },

#[error("Error reading the column value: {}", _0)]
ColumnReadFailure(Box<dyn std::error::Error + Send + Sync + 'static>),

Expand Down
2 changes: 2 additions & 0 deletions query-engine/connectors/sql-query-connector/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ impl From<quaint::error::Error> for SqlError {
e @ QuaintKind::DatabaseAccessDenied { .. } => SqlError::ConnectionError(e),
e @ QuaintKind::DatabaseAlreadyExists { .. } => SqlError::ConnectionError(e),
e @ QuaintKind::InvalidConnectionArguments => SqlError::ConnectionError(e),
e @ QuaintKind::AuthTokenFetchFailure { .. } => SqlError::ConnectionError(e),
e @ QuaintKind::MissingEnvironmentVariable { .. } => SqlError::ConnectionError(e),
e @ QuaintKind::SocketTimeout => SqlError::ConnectionError(e),
}
}
Expand Down