diff --git a/storage/adls_gen1_mount.go b/storage/adls_gen1_mount.go index f2e49d1f4b..c06c64e924 100644 --- a/storage/adls_gen1_mount.go +++ b/storage/adls_gen1_mount.go @@ -34,7 +34,7 @@ func (m AzureADLSGen1Mount) ValidateAndApplyDefaults(d *schema.ResourceData, cli // Config ... func (m AzureADLSGen1Mount) Config(client *common.DatabricksClient) map[string]string { - aadEndpoint := client.Config.Environment().AzureActiveDirectoryEndpoint() + aadEndpoint := azureActiveDirectoryEndpoint(client.Config) return map[string]string{ m.PrefixType + ".oauth2.access.token.provider.type": "ClientCredential", diff --git a/storage/adls_gen2_mount.go b/storage/adls_gen2_mount.go index dd912885d2..5528468c8d 100644 --- a/storage/adls_gen2_mount.go +++ b/storage/adls_gen2_mount.go @@ -2,7 +2,6 @@ package storage import ( "fmt" - "strings" "github.com/databricks/terraform-provider-databricks/common" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" @@ -21,23 +20,9 @@ type AzureADLSGen2Mount struct { InitializeFileSystem bool `json:"initialize_file_system"` } -func getAzureDomain(client *common.DatabricksClient) string { - domains := map[string]string{ - "PUBLIC": "core.windows.net", - "USGOVERNMENT": "core.usgovcloudapi.net", - "CHINA": "core.chinacloudapi.cn", - } - azureEnvironment := client.Config.Environment().AzureEnvironment.Name - domain, ok := domains[strings.ToUpper(azureEnvironment)] - if !ok { - panic(fmt.Sprintf("Unknown Azure environment: '%s'", azureEnvironment)) - } - return domain -} - // Source returns ABFSS URI backing the mount func (m AzureADLSGen2Mount) Source(client *common.DatabricksClient) string { - return fmt.Sprintf("abfss://%s@%s.dfs.%s%s", m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory) + return fmt.Sprintf("abfss://%s@%s.dfs.%s%s", m.ContainerName, m.StorageAccountName, azureDomain(client.Config), m.Directory) } func (m AzureADLSGen2Mount) Name() string { @@ -50,7 +35,7 @@ func (m AzureADLSGen2Mount) ValidateAndApplyDefaults(d *schema.ResourceData, cli // Config returns mount configurations func (m AzureADLSGen2Mount) Config(client *common.DatabricksClient) map[string]string { - aadEndpoint := client.Config.Environment().AzureActiveDirectoryEndpoint() + aadEndpoint := azureActiveDirectoryEndpoint(client.Config) return map[string]string{ "fs.azure.account.auth.type": "OAuth", "fs.azure.account.oauth.provider.type": "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider", diff --git a/storage/azure_blob_mount.go b/storage/azure_blob_mount.go index 1b4832be2e..e47f9507e6 100644 --- a/storage/azure_blob_mount.go +++ b/storage/azure_blob_mount.go @@ -21,7 +21,7 @@ type AzureBlobMount struct { // Source ... func (m AzureBlobMount) Source(client *common.DatabricksClient) string { return fmt.Sprintf("wasbs://%[1]s@%[2]s.blob.%[3]s%[4]s", - m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory) + m.ContainerName, m.StorageAccountName, azureDomain(client.Config), m.Directory) } func (m AzureBlobMount) Name() string { diff --git a/storage/environments.go b/storage/environments.go new file mode 100644 index 0000000000..9370ee4d10 --- /dev/null +++ b/storage/environments.go @@ -0,0 +1,80 @@ +package storage + +import ( + "fmt" + "log" + "strings" + + "github.com/databricks/databricks-sdk-go/config" +) + +type azureEnvironment struct { + Name string + Domain string + ActiveDirectoryEndpoint string +} + +var AzurePublicCloud = azureEnvironment{ + Name: "PUBLIC", + Domain: "core.windows.net", + ActiveDirectoryEndpoint: "https://login.microsoftonline.com/", +} + +var AzureUsGovernmentCloud = azureEnvironment{ + Name: "USGOVERNMENT", + Domain: "core.usgovcloudapi.net", + ActiveDirectoryEndpoint: "https://login.microsoftonline.us/", +} + +var AzureChinaCloud = azureEnvironment{ + Name: "CHINA", + Domain: "core.chinacloudapi.cn", + ActiveDirectoryEndpoint: "https://login.chinacloudapi.cn/", +} + +func azureActiveDirectoryEndpoint(cfg *config.Config) string { + env, err := environment(cfg) + if err != nil { + // TODO: The error is swallowed for backward compatibility. We should + // consider returning it to the caller. + log.Printf("[DEBUG] Failed to get Azure Active Directory endpoint: %s", err) + return "" + } + return env.ActiveDirectoryEndpoint +} + +func azureDomain(cfg *config.Config) string { + env, err := environment(cfg) + if err != nil { + panic(fmt.Sprintf("Failed to get Azure domain: %s", err)) + } + return env.Domain +} + +func environment(cfg *config.Config) (azureEnvironment, error) { + switch strings.ToUpper(cfg.AzureEnvironment) { + case "PUBLIC", "": + return AzurePublicCloud, nil + case "USGOVERNMENT": + return AzureUsGovernmentCloud, nil + case "CHINA": + return AzureChinaCloud, nil + } + + // If the environment is not specified, infer the environment from + // the host. + switch { + case strings.HasSuffix(cfg.Host, ".dev.azuredatabricks.net"): + return AzurePublicCloud, nil + case strings.HasSuffix(cfg.Host, ".staging.azuredatabricks.net"): + return AzurePublicCloud, nil + case strings.HasSuffix(cfg.Host, ".azuredatabricks.net"): + return AzurePublicCloud, nil + case strings.HasSuffix(cfg.Host, ".databricks.azure.us"): + return AzureUsGovernmentCloud, nil + case strings.HasSuffix(cfg.Host, ".databricks.azure.cn"): + return AzureChinaCloud, nil + } + + return azureEnvironment{}, fmt.Errorf("unable to infer Azure environment") +} diff --git a/storage/generic_mounts.go b/storage/generic_mounts.go index 72e8d4dedb..da5b42c7cd 100644 --- a/storage/generic_mounts.go +++ b/storage/generic_mounts.go @@ -135,7 +135,7 @@ type AzureADLSGen2MountGeneric struct { // Source returns ABFSS URI backing the mount func (m *AzureADLSGen2MountGeneric) Source(client *common.DatabricksClient) string { - return fmt.Sprintf("abfss://%s@%s.dfs.%s%s", m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory) + return fmt.Sprintf("abfss://%s@%s.dfs.%s%s", m.ContainerName, m.StorageAccountName, azureDomain(client.Config), m.Directory) } func (m *AzureADLSGen2MountGeneric) Name() string { @@ -168,7 +168,7 @@ func (m *AzureADLSGen2MountGeneric) ValidateAndApplyDefaults(d *schema.ResourceD // Config returns mount configurations func (m *AzureADLSGen2MountGeneric) Config(client *common.DatabricksClient) map[string]string { - aadEndpoint := client.Config.Environment().AzureActiveDirectoryEndpoint() + aadEndpoint := azureActiveDirectoryEndpoint(client.Config) return map[string]string{ "fs.azure.account.auth.type": "OAuth", "fs.azure.account.oauth.provider.type": "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider", @@ -233,7 +233,7 @@ func (m *AzureADLSGen1MountGeneric) ValidateAndApplyDefaults(d *schema.ResourceD // Config ... func (m *AzureADLSGen1MountGeneric) Config(client *common.DatabricksClient) map[string]string { - aadEndpoint := client.Config.Environment().AzureActiveDirectoryEndpoint() + aadEndpoint := azureActiveDirectoryEndpoint(client.Config) return map[string]string{ m.PrefixType + ".oauth2.access.token.provider.type": "ClientCredential", m.PrefixType + ".oauth2.client.id": m.ClientID, @@ -257,7 +257,7 @@ type AzureBlobMountGeneric struct { // Source ... func (m *AzureBlobMountGeneric) Source(client *common.DatabricksClient) string { return fmt.Sprintf("wasbs://%[1]s@%[2]s.blob.%[3]s%[4]s", - m.ContainerName, m.StorageAccountName, getAzureDomain(client), m.Directory) + m.ContainerName, m.StorageAccountName, azureDomain(client.Config), m.Directory) } func (m *AzureBlobMountGeneric) Name() string {