Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 187 additions & 10 deletions src/auth/src/credentials/mds.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ use google_cloud_gax::retry_policy::RetryPolicyArg;
use google_cloud_gax::retry_throttler::RetryThrottlerArg;
use http::{Extensions, HeaderMap};
use std::default::Default;
use std::sync::Arc;
use std::sync::{Arc, OnceLock};

// TODO(#2235) - Improve this message by talking about retries when really running with MDS
const MDS_NOT_FOUND_ERROR: &str = concat!(
Expand All @@ -105,7 +105,10 @@ where
T: CachedTokenProvider,
{
quota_project_id: Option<String>,
universe_domain_override: Option<String>,
universe_domain: OnceLock<Option<String>>,
token_provider: T,
mds_client: MDSClient,
}

/// Creates [Credentials] instances backed by the [Metadata Service].
Expand All @@ -123,6 +126,7 @@ where
pub struct Builder {
endpoint: Option<String>,
quota_project_id: Option<String>,
universe_domain: Option<String>,
scopes: Option<Vec<String>>,
created_by_adc: bool,
retry_builder: RetryTokenProviderBuilder,
Expand All @@ -135,6 +139,7 @@ impl Default for Builder {
Self {
endpoint: None,
quota_project_id: None,
universe_domain: None,
scopes: None,
created_by_adc: false,
retry_builder: RetryTokenProviderBuilder::default(),
Expand Down Expand Up @@ -177,6 +182,15 @@ impl Builder {
self
}

/// Sets the Google Cloud universe domain for these credentials.
///
/// Any value provided here overrides a `universe_domain` value from the input service account JSON.
#[allow(dead_code)]
pub(crate) fn with_universe_domain<S: Into<String>>(mut self, universe_domain: S) -> Self {
self.universe_domain = Some(universe_domain.into());
self
}

/// Sets the [scopes] for this credentials.
///
/// Metadata server issues tokens based on the requested scopes.
Expand Down Expand Up @@ -319,7 +333,10 @@ impl Builder {
let mds_client = MDSClient::new(self.endpoint.clone());
let mdsc = MDSCredentials {
quota_project_id: self.quota_project_id.clone(),
universe_domain_override: self.universe_domain.clone(),
universe_domain: OnceLock::new(),
token_provider: TokenCache::new(self.build_token_provider()),
mds_client: mds_client.clone(),
};
if !is_access_boundary_enabled {
return Ok(CredentialsWithAccessBoundary::new_no_op(mdsc));
Expand Down Expand Up @@ -375,6 +392,33 @@ where
.maybe_quota_project_id(self.quota_project_id.as_deref())
.build()
}

async fn universe_domain(&self) -> Option<String> {
if let Some(ud) = &self.universe_domain_override {
return Some(ud.clone());
}
if let Some(ud) = self.universe_domain.get() {
return ud.clone();
}

// No overrides and no cache. Try to fetch from MDS.
let response = self.mds_client.universe_domain().send().await;
match response {
Ok(universe_domain) => {
let _ = self.universe_domain.set(Some(universe_domain.clone()));
Some(universe_domain)
}
Err(e) => {
if !e.is_transient() {
// Only cache None if the error is permanent (e.g., 404 on GDU)
let _ = self.universe_domain.set(None);
}
// Return None but do not cache it if it's transient,
// allowing subsequent calls to retry or try again.
None
}
}
}
}

#[async_trait::async_trait]
Expand Down Expand Up @@ -469,7 +513,6 @@ impl TokenProvider for MDSAccessTokenProvider {
#[cfg(test)]
mod tests {
use super::*;
use crate::constants::DEFAULT_UNIVERSE_DOMAIN;
use crate::credentials::QUOTA_PROJECT_KEY;
use crate::credentials::tests::{
find_source_error, get_headers_from_cache, get_mock_auth_retry_policy,
Expand All @@ -479,8 +522,11 @@ mod tests {
use crate::errors;
use crate::errors::CredentialsError;
use crate::mds::client::MDSTokenResponse;
use crate::mds::{GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, METADATA_ROOT};
use crate::mds::{
GCE_METADATA_HOST_ENV_VAR, MDS_DEFAULT_URI, MDS_UNIVERSE_DOMAIN_URI, METADATA_ROOT,
};
use crate::token::tests::MockTokenProvider;
use crate::token_cache::TokenCache;
use base64::{Engine, prelude::BASE64_STANDARD};
use http::HeaderValue;
use http::header::AUTHORIZATION;
Expand Down Expand Up @@ -611,6 +657,9 @@ mod tests {
let mdsc = MDSCredentials {
quota_project_id: None,
token_provider: TokenCache::new(mock),
universe_domain_override: None,
universe_domain: OnceLock::new(),
mds_client: MDSClient::new(None),
};

let mut extensions = Extensions::new();
Expand Down Expand Up @@ -672,6 +721,9 @@ mod tests {
let mdsc = MDSCredentials {
quota_project_id: None,
token_provider: TokenCache::new(mock),
universe_domain_override: None,
universe_domain: OnceLock::new(),
mds_client: MDSClient::new(None),
};
let result = mdsc.headers(Extensions::new()).await;
assert!(result.is_err(), "{result:?}");
Expand Down Expand Up @@ -846,7 +898,7 @@ mod tests {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
#[parallel]
async fn token_caching() -> TestResult {
let mut server = Server::run();
let server = Server::run();
let scopes = vec!["scope1".to_string()];
let response = MDSTokenResponse {
access_token: "test-access-token".to_string(),
Expand Down Expand Up @@ -878,9 +930,6 @@ mod tests {
"test-access-token"
);

// validate that the inner token provider is called only once
server.verify_and_clear();

Ok(())
}

Expand Down Expand Up @@ -1077,9 +1126,132 @@ mod tests {

#[tokio::test]
#[parallel]
async fn get_default_universe_domain_success() -> TestResult {
let universe_domain_response = Builder::default().build()?.universe_domain().await.unwrap();
assert_eq!(universe_domain_response, DEFAULT_UNIVERSE_DOMAIN);
async fn get_default_universe_domain() -> TestResult {
let server = Server::run();
server.expect(
Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
.respond_with(status_code(404)),
);

let mut mock = MockTokenProvider::new();
mock.expect_token()
.returning(|| Err(crate::errors::non_retryable_from_str("fail")));

let creds = MDSCredentials {
quota_project_id: None,
universe_domain_override: None,
universe_domain: OnceLock::new(),
token_provider: TokenCache::new(mock),
mds_client: crate::mds::client::Client::new(Some(format!("http://{}", server.addr()))),
};

let universe_domain = creds.universe_domain().await;
assert!(universe_domain.is_none());
Ok(())
}

#[tokio::test]
#[parallel]
async fn get_universe_domain_override() -> TestResult {
let creds = Builder::default()
.with_universe_domain("my-universe-domain.com")
.without_access_boundary()
.build()?;
let universe_domain = creds.universe_domain().await;
assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com"));
Ok(())
}

#[tokio::test]
#[parallel]
async fn get_universe_domain_from_mds() -> TestResult {
let server = Server::run();
server.expect(
Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
.respond_with(status_code(200).body("my-universe-domain.com")),
);

let mut mock = MockTokenProvider::new();
mock.expect_token()
.returning(|| Err(crate::errors::non_retryable_from_str("fail")));

let creds = MDSCredentials {
quota_project_id: None,
universe_domain_override: None,
universe_domain: OnceLock::new(),
token_provider: TokenCache::new(mock),
mds_client: crate::mds::client::Client::new(Some(format!("http://{}", server.addr()))),
};
let universe_domain = creds.universe_domain().await;
assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com"));
Ok(())
}

#[tokio::test]
#[parallel]
async fn get_universe_domain_caching() -> TestResult {
let server = Server::run();
server.expect(
Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
.times(2)
.respond_with(cycle![
status_code(503).body("transient error"),
status_code(200).body("my-universe-domain.com"),
]),
);

let mut mock = MockTokenProvider::new();
mock.expect_token()
.returning(|| Err(crate::errors::non_retryable_from_str("fail")));

let creds = MDSCredentials {
quota_project_id: None,
universe_domain_override: None,
universe_domain: OnceLock::new(),
token_provider: TokenCache::new(mock),
mds_client: crate::mds::client::Client::new(Some(format!("http://{}", server.addr()))),
};

let universe_domain = creds.universe_domain().await;
assert_eq!(universe_domain, None);

let universe_domain = creds.universe_domain().await;
assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com"));

let universe_domain = creds.universe_domain().await;
assert_eq!(universe_domain.as_deref(), Some("my-universe-domain.com"));

Ok(())
}

#[tokio::test]
#[parallel]
async fn get_universe_domain_caching_permanent_error() -> TestResult {
let server = Server::run();
server.expect(
Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
.times(1)
.respond_with(status_code(404).body("permanent error")),
);

let mut mock = MockTokenProvider::new();
mock.expect_token()
.returning(|| Err(crate::errors::non_retryable_from_str("fail")));

let creds = MDSCredentials {
quota_project_id: None,
universe_domain_override: None,
universe_domain: OnceLock::new(),
token_provider: TokenCache::new(mock),
mds_client: crate::mds::client::Client::new(Some(format!("http://{}", server.addr()))),
};

let universe_domain = creds.universe_domain().await;
assert_eq!(universe_domain, None);

let universe_domain = creds.universe_domain().await;
assert_eq!(universe_domain, None);

Ok(())
}

Expand Down Expand Up @@ -1134,6 +1306,7 @@ mod tests {
#[cfg(google_cloud_unstable_trusted_boundaries)]
async fn e2e_access_boundary() -> TestResult {
use crate::credentials::tests::get_access_boundary_from_headers;
use crate::mds::MDS_UNIVERSE_DOMAIN_URI;

let server = Server::run();
server.expect(
Expand All @@ -1148,6 +1321,10 @@ mod tests {
Expectation::matching(all_of![request::path(format!("{MDS_DEFAULT_URI}/email")),])
.respond_with(status_code(200).body("test-client-email")),
);
server.expect(
Expectation::matching(all_of![request::path(MDS_UNIVERSE_DOMAIN_URI),])
.respond_with(status_code(404)),
);
server.expect(
Expectation::matching(all_of![
request::method_path(
Expand Down
Loading