Skip to content

Commit 4c80c9c

Browse files
authored
Merge pull request #111 from aviramha/re-auth
Refactor auth - add re-auth mechanism
2 parents 5b0ab7a + f8c23d2 commit 4c80c9c

File tree

2 files changed

+139
-74
lines changed

2 files changed

+139
-74
lines changed

src/client.rs

Lines changed: 106 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use std::collections::HashMap;
3333
use std::convert::TryFrom;
3434
use std::sync::Arc;
3535
use tokio::io::{AsyncWrite, AsyncWriteExt};
36+
use tokio::sync::RwLock;
3637
use tracing::{debug, trace, warn};
3738

3839
const MIME_TYPES_DISTRIBUTION_MANIFEST: &[&str] = &[
@@ -205,6 +206,8 @@ impl TryFrom<Config> for ConfigFile {
205206
#[derive(Clone)]
206207
pub struct Client {
207208
config: Arc<ClientConfig>,
209+
// Registry -> RegistryAuth
210+
auth_store: Arc<RwLock<HashMap<String, RegistryAuth>>>,
208211
tokens: TokenCache,
209212
client: reqwest::Client,
210213
push_chunk_size: usize,
@@ -213,9 +216,10 @@ pub struct Client {
213216
impl Default for Client {
214217
fn default() -> Self {
215218
Self {
216-
config: Arc::new(ClientConfig::default()),
217-
tokens: TokenCache::new(),
218-
client: reqwest::Client::new(),
219+
config: Arc::default(),
220+
auth_store: Arc::default(),
221+
tokens: TokenCache::default(),
222+
client: reqwest::Client::default(),
219223
push_chunk_size: PUSH_CHUNK_MAX_SIZE,
220224
}
221225
}
@@ -257,9 +261,9 @@ impl TryFrom<ClientConfig> for Client {
257261

258262
Ok(Self {
259263
config: Arc::new(config),
260-
tokens: TokenCache::new(),
261264
client: client_builder.build()?,
262265
push_chunk_size: PUSH_CHUNK_MAX_SIZE,
266+
..Default::default()
263267
})
264268
}
265269
}
@@ -271,10 +275,8 @@ impl Client {
271275
warn!("Cannot create OCI client from config: {:?}", err);
272276
warn!("Creating client with default configuration");
273277
Self {
274-
config: Arc::new(ClientConfig::default()),
275-
tokens: TokenCache::new(),
276-
client: reqwest::Client::new(),
277278
push_chunk_size: PUSH_CHUNK_MAX_SIZE,
279+
..Default::default()
278280
}
279281
})
280282
}
@@ -284,6 +286,41 @@ impl Client {
284286
Self::new(config_source.client_config())
285287
}
286288

289+
async fn store_auth(&self, registry: &str, auth: RegistryAuth) {
290+
self.auth_store
291+
.write()
292+
.await
293+
.insert(registry.to_string(), auth);
294+
}
295+
296+
async fn is_stored_auth(&self, registry: &str) -> bool {
297+
self.auth_store.read().await.contains_key(registry)
298+
}
299+
300+
async fn store_auth_if_needed(&self, registry: &str, auth: &RegistryAuth) {
301+
if !self.is_stored_auth(registry).await {
302+
self.store_auth(registry, auth.clone()).await;
303+
}
304+
}
305+
306+
/// Checks if we got a token, if we don't - create it and store it in cache.
307+
async fn get_auth_token(
308+
&self,
309+
reference: &Reference,
310+
op: RegistryOperation,
311+
) -> Option<RegistryTokenType> {
312+
let registry = reference.resolve_registry();
313+
let auth = self.auth_store.read().await.get(registry)?.clone();
314+
match self.tokens.get(reference, op).await {
315+
Some(token) => Some(token),
316+
None => {
317+
let token = self._auth(reference, &auth, op).await.ok()??;
318+
self.tokens.insert(reference, op, token.clone()).await;
319+
Some(token)
320+
}
321+
}
322+
}
323+
287324
/// Fetches the available Tags for the given Reference
288325
///
289326
/// The client will check if it's already been authenticated and if
@@ -298,9 +335,8 @@ impl Client {
298335
let op = RegistryOperation::Pull;
299336
let url = self.to_list_tags_url(image);
300337

301-
if !self.tokens.contains_key(image, op).await {
302-
self.auth(image, auth, op).await?;
303-
}
338+
self.store_auth_if_needed(image.resolve_registry(), auth)
339+
.await;
304340

305341
let request = self.client.get(&url);
306342
let request = if let Some(num) = n {
@@ -342,10 +378,8 @@ impl Client {
342378
accepted_media_types: Vec<&str>,
343379
) -> Result<ImageData> {
344380
debug!("Pulling image: {:?}", image);
345-
let op = RegistryOperation::Pull;
346-
if !self.tokens.contains_key(image, op).await {
347-
self.auth(image, auth, op).await?;
348-
}
381+
self.store_auth_if_needed(image.resolve_registry(), auth)
382+
.await;
349383

350384
let (manifest, digest, config) = self._pull_manifest_and_config(image).await?;
351385

@@ -400,10 +434,8 @@ impl Client {
400434
manifest: Option<OciImageManifest>,
401435
) -> Result<PushResponse> {
402436
debug!("Pushing image: {:?}", image_ref);
403-
let op = RegistryOperation::Push;
404-
if !self.tokens.contains_key(image_ref, op).await {
405-
self.auth(image_ref, auth, op).await?;
406-
}
437+
self.store_auth_if_needed(image_ref.resolve_registry(), auth)
438+
.await;
407439

408440
let manifest: OciImageManifest = match manifest {
409441
Some(m) => m,
@@ -502,6 +534,38 @@ impl Client {
502534
authentication: &RegistryAuth,
503535
operation: RegistryOperation,
504536
) -> Result<Option<String>> {
537+
self.store_auth_if_needed(image.resolve_registry(), authentication)
538+
.await;
539+
// preserve old caching behavior
540+
match self._auth(image, authentication, operation).await {
541+
Ok(Some(RegistryTokenType::Bearer(token))) => {
542+
self.tokens
543+
.insert(image, operation, RegistryTokenType::Bearer(token.clone()))
544+
.await;
545+
Ok(Some(token.token().to_string()))
546+
}
547+
Ok(Some(RegistryTokenType::Basic(username, password))) => {
548+
self.tokens
549+
.insert(
550+
image,
551+
operation,
552+
RegistryTokenType::Basic(username, password),
553+
)
554+
.await;
555+
Ok(None)
556+
}
557+
Ok(None) => Ok(None),
558+
Err(e) => Err(e),
559+
}
560+
}
561+
562+
/// Internal auth that retrieves token.
563+
async fn _auth(
564+
&self,
565+
image: &Reference,
566+
authentication: &RegistryAuth,
567+
operation: RegistryOperation,
568+
) -> Result<Option<RegistryTokenType>> {
505569
debug!("Authorizing for image: {:?}", image);
506570
// The version request will tell us where to go.
507571
let url = format!(
@@ -521,13 +585,10 @@ impl Client {
521585
Err(e) => {
522586
debug!(error = ?e, "Falling back to HTTP Basic Auth");
523587
if let RegistryAuth::Basic(username, password) = authentication {
524-
self.tokens
525-
.insert(
526-
image,
527-
operation,
528-
RegistryTokenType::Basic(username.to_string(), password.to_string()),
529-
)
530-
.await;
588+
return Ok(Some(RegistryTokenType::Basic(
589+
username.to_string(),
590+
password.to_string(),
591+
)));
531592
}
532593
return Ok(None);
533594
}
@@ -566,11 +627,7 @@ impl Client {
566627
let token: RegistryToken = serde_json::from_str(&text)
567628
.map_err(|e| OciDistributionError::RegistryTokenDecodeError(e.to_string()))?;
568629
debug!("Successfully authorized for image '{:?}'", image);
569-
let oauth_token = token.token().to_string();
570-
self.tokens
571-
.insert(image, operation, RegistryTokenType::Bearer(token))
572-
.await;
573-
Ok(Some(oauth_token))
630+
Ok(Some(RegistryTokenType::Bearer(token)))
574631
}
575632
_ => {
576633
let reason = auth_res.text().await?;
@@ -593,10 +650,8 @@ impl Client {
593650
image: &Reference,
594651
auth: &RegistryAuth,
595652
) -> Result<String> {
596-
let op = RegistryOperation::Pull;
597-
if !self.tokens.contains_key(image, op).await {
598-
self.auth(image, auth, op).await?;
599-
}
653+
self.store_auth_if_needed(image.resolve_registry(), auth)
654+
.await;
600655

601656
let url = self.to_v2_manifest_url(image);
602657
debug!("HEAD image manifest from {}", url);
@@ -670,10 +725,8 @@ impl Client {
670725
image: &Reference,
671726
auth: &RegistryAuth,
672727
) -> Result<(OciImageManifest, String)> {
673-
let op = RegistryOperation::Pull;
674-
if !self.tokens.contains_key(image, op).await {
675-
self.auth(image, auth, op).await?;
676-
}
728+
self.store_auth_if_needed(image.resolve_registry(), auth)
729+
.await;
677730

678731
self._pull_image_manifest(image).await
679732
}
@@ -690,10 +743,8 @@ impl Client {
690743
image: &Reference,
691744
auth: &RegistryAuth,
692745
) -> Result<(OciManifest, String)> {
693-
let op = RegistryOperation::Pull;
694-
if !self.tokens.contains_key(image, op).await {
695-
self.auth(image, auth, op).await?;
696-
}
746+
self.store_auth_if_needed(image.resolve_registry(), auth)
747+
.await;
697748

698749
self._pull_manifest(image).await
699750
}
@@ -811,10 +862,8 @@ impl Client {
811862
image: &Reference,
812863
auth: &RegistryAuth,
813864
) -> Result<(OciImageManifest, String, String)> {
814-
let op = RegistryOperation::Pull;
815-
if !self.tokens.contains_key(image, op).await {
816-
self.auth(image, auth, op).await?;
817-
}
865+
self.store_auth_if_needed(image.resolve_registry(), auth)
866+
.await;
818867

819868
self._pull_manifest_and_config(image)
820869
.await
@@ -855,7 +904,8 @@ impl Client {
855904
auth: &RegistryAuth,
856905
manifest: OciImageIndex,
857906
) -> Result<String> {
858-
self.auth(reference, auth, RegistryOperation::Push).await?;
907+
self.store_auth_if_needed(reference.resolve_registry(), auth)
908+
.await;
859909
self.push_manifest(reference, &OciManifest::ImageIndex(manifest))
860910
.await
861911
}
@@ -1418,7 +1468,7 @@ impl<'a> RequestBuilderWrapper<'a> {
14181468
) -> Result<RequestBuilderWrapper> {
14191469
let mut headers = HeaderMap::new();
14201470

1421-
if let Some(token) = self.client.tokens.get(image, op).await {
1471+
if let Some(token) = self.client.get_auth_token(image, op).await {
14221472
match token {
14231473
RegistryTokenType::Bearer(token) => {
14241474
debug!("Using bearer token authentication.");
@@ -1816,6 +1866,14 @@ mod test {
18161866
.as_str()
18171867
.to_string();
18181868

1869+
// we have to have it in the stored auth so we'll get to the token cache check.
1870+
client
1871+
.store_auth(
1872+
&Reference::try_from(HELLO_IMAGE_TAG)?.resolve_registry(),
1873+
RegistryAuth::Anonymous,
1874+
)
1875+
.await;
1876+
18191877
client
18201878
.tokens
18211879
.insert(

src/token_cache.rs

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -59,21 +59,25 @@ pub enum RegistryOperation {
5959
Pull,
6060
}
6161

62-
type CacheType = BTreeMap<(String, String, RegistryOperation), (RegistryTokenType, u64)>;
62+
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
63+
struct TokenCacheKey {
64+
registry: String,
65+
repository: String,
66+
operation: RegistryOperation,
67+
}
68+
69+
struct TokenCacheValue {
70+
token: RegistryTokenType,
71+
expiration: u64,
72+
}
6373

6474
#[derive(Default, Clone)]
6575
pub(crate) struct TokenCache {
6676
// (registry, repository, scope) -> (token, expiration)
67-
tokens: Arc<RwLock<CacheType>>,
77+
tokens: Arc<RwLock<BTreeMap<TokenCacheKey, TokenCacheValue>>>,
6878
}
6979

7080
impl TokenCache {
71-
pub(crate) fn new() -> Self {
72-
TokenCache {
73-
tokens: Arc::new(RwLock::new(BTreeMap::new())),
74-
}
75-
}
76-
7781
pub(crate) async fn insert(
7882
&self,
7983
reference: &Reference,
@@ -119,10 +123,14 @@ impl TokenCache {
119123
let registry = reference.resolve_registry().to_string();
120124
let repository = reference.repository().to_string();
121125
debug!(%registry, %repository, ?op, %expiration, "Inserting token");
122-
self.tokens
123-
.write()
124-
.await
125-
.insert((registry, repository, op), (token, expiration));
126+
self.tokens.write().await.insert(
127+
TokenCacheKey {
128+
registry,
129+
repository,
130+
operation: op,
131+
},
132+
TokenCacheValue { token, expiration },
133+
);
126134
}
127135

128136
pub(crate) async fn get(
@@ -132,34 +140,33 @@ impl TokenCache {
132140
) -> Option<RegistryTokenType> {
133141
let registry = reference.resolve_registry().to_string();
134142
let repository = reference.repository().to_string();
135-
match self
136-
.tokens
137-
.read()
138-
.await
139-
.get(&(registry.clone(), repository.clone(), op))
140-
{
141-
Some((ref token, expiration)) => {
143+
let key = TokenCacheKey {
144+
registry,
145+
repository,
146+
operation: op,
147+
};
148+
match self.tokens.read().await.get(&key) {
149+
Some(TokenCacheValue {
150+
ref token,
151+
expiration,
152+
}) => {
142153
let now = SystemTime::now();
143154
let epoch = now
144155
.duration_since(UNIX_EPOCH)
145156
.expect("Time went backwards")
146157
.as_secs();
147158
if epoch > *expiration {
148-
debug!(%registry, %repository, ?op, %expiration, miss=false, expired=true, "Fetching token");
159+
debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=true, "Fetching token");
149160
None
150161
} else {
151-
debug!(%registry, %repository, ?op, %expiration, miss=false, expired=false, "Fetching token");
162+
debug!(%key.registry, %key.repository, ?key.operation, %expiration, miss=false, expired=false, "Fetching token");
152163
Some(token.clone())
153164
}
154165
}
155166
None => {
156-
debug!(%registry, %repository, ?op, miss=true, "Fetching token");
167+
debug!(%key.registry, %key.repository, ?key.operation, miss = true, "Fetching token");
157168
None
158169
}
159170
}
160171
}
161-
162-
pub(crate) async fn contains_key(&self, reference: &Reference, op: RegistryOperation) -> bool {
163-
self.get(reference, op).await.is_some()
164-
}
165172
}

0 commit comments

Comments
 (0)