Skip to content

Commit ceb2430

Browse files
committed
refactored auth, tried not breaking public API while adding re-auth [token gen] mechanism when needed
Signed-off-by: Aviram Hassan <[email protected]>
1 parent 2de8286 commit ceb2430

File tree

2 files changed

+105
-59
lines changed

2 files changed

+105
-59
lines changed

src/client.rs

Lines changed: 96 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use std::collections::HashMap;
3333
use std::convert::TryFrom;
3434
use std::sync::Arc;
3535
use tokio::io::{AsyncWrite, AsyncWriteExt};
36+
use tokio::sync::RwLock;
3637
use tracing::{debug, trace, warn};
3738

3839
const MIME_TYPES_DISTRIBUTION_MANIFEST: &[&str] = &[
@@ -205,6 +206,8 @@ impl TryFrom<Config> for ConfigFile {
205206
#[derive(Clone)]
206207
pub struct Client {
207208
config: Arc<ClientConfig>,
209+
// Registry -> RegistryAuth
210+
auth_store: Arc<RwLock<HashMap<String, RegistryAuth>>>,
208211
tokens: TokenCache,
209212
client: reqwest::Client,
210213
push_chunk_size: usize,
@@ -213,9 +216,10 @@ pub struct Client {
213216
impl Default for Client {
214217
fn default() -> Self {
215218
Self {
216-
config: Arc::new(ClientConfig::default()),
217-
tokens: TokenCache::new(),
218-
client: reqwest::Client::new(),
219+
config: Arc::default(),
220+
auth_store: Arc::default(),
221+
tokens: TokenCache::default(),
222+
client: reqwest::Client::default(),
219223
push_chunk_size: PUSH_CHUNK_MAX_SIZE,
220224
}
221225
}
@@ -257,9 +261,9 @@ impl TryFrom<ClientConfig> for Client {
257261

258262
Ok(Self {
259263
config: Arc::new(config),
260-
tokens: TokenCache::new(),
261264
client: client_builder.build()?,
262265
push_chunk_size: PUSH_CHUNK_MAX_SIZE,
266+
..Default::default()
263267
})
264268
}
265269
}
@@ -271,10 +275,8 @@ impl Client {
271275
warn!("Cannot create OCI client from config: {:?}", err);
272276
warn!("Creating client with default configuration");
273277
Self {
274-
config: Arc::new(ClientConfig::default()),
275-
tokens: TokenCache::new(),
276-
client: reqwest::Client::new(),
277278
push_chunk_size: PUSH_CHUNK_MAX_SIZE,
279+
..Default::default()
278280
}
279281
})
280282
}
@@ -284,6 +286,41 @@ impl Client {
284286
Self::new(config_source.client_config())
285287
}
286288

289+
async fn store_auth(&self, registry: &str, auth: RegistryAuth) {
290+
self.auth_store
291+
.write()
292+
.await
293+
.insert(registry.to_string(), auth);
294+
}
295+
296+
async fn is_stored_auth(&self, registry: &str) -> bool {
297+
self.auth_store.read().await.contains_key(registry)
298+
}
299+
300+
async fn store_auth_if_needed(&self, registry: &str, auth: &RegistryAuth) {
301+
if !self.is_stored_auth(registry).await {
302+
self.store_auth(registry, auth.clone()).await;
303+
}
304+
}
305+
306+
/// Checks if we got a token, if we don't - create it and store it in cache.
307+
async fn get_auth_token(
308+
&self,
309+
reference: &Reference,
310+
op: RegistryOperation,
311+
) -> Option<RegistryTokenType> {
312+
let registry = reference.resolve_registry();
313+
let auth = self.auth_store.read().await.get(registry)?.clone();
314+
match self.tokens.get(reference, op).await {
315+
Some(token) => Some(token),
316+
None => {
317+
let token = self._auth(reference, &auth, op).await.ok()??;
318+
self.tokens.insert(reference, op, token.clone()).await;
319+
Some(token)
320+
}
321+
}
322+
}
323+
287324
/// Fetches the available Tags for the given Reference
288325
///
289326
/// The client will check if it's already been authenticated and if
@@ -298,9 +335,8 @@ impl Client {
298335
let op = RegistryOperation::Pull;
299336
let url = self.to_list_tags_url(image);
300337

301-
if !self.tokens.contains_key(image, op).await {
302-
self.auth(image, auth, op).await?;
303-
}
338+
self.store_auth_if_needed(image.resolve_registry(), auth)
339+
.await;
304340

305341
let request = self.client.get(&url);
306342
let request = if let Some(num) = n {
@@ -342,10 +378,8 @@ impl Client {
342378
accepted_media_types: Vec<&str>,
343379
) -> Result<ImageData> {
344380
debug!("Pulling image: {:?}", image);
345-
let op = RegistryOperation::Pull;
346-
if !self.tokens.contains_key(image, op).await {
347-
self.auth(image, auth, op).await?;
348-
}
381+
self.store_auth_if_needed(image.resolve_registry(), auth)
382+
.await;
349383

350384
let (manifest, digest, config) = self._pull_manifest_and_config(image).await?;
351385

@@ -400,10 +434,8 @@ impl Client {
400434
manifest: Option<OciImageManifest>,
401435
) -> Result<PushResponse> {
402436
debug!("Pushing image: {:?}", image_ref);
403-
let op = RegistryOperation::Push;
404-
if !self.tokens.contains_key(image_ref, op).await {
405-
self.auth(image_ref, auth, op).await?;
406-
}
437+
self.store_auth_if_needed(image_ref.resolve_registry(), auth)
438+
.await;
407439

408440
let manifest: OciImageManifest = match manifest {
409441
Some(m) => m,
@@ -502,6 +534,36 @@ impl Client {
502534
authentication: &RegistryAuth,
503535
operation: RegistryOperation,
504536
) -> Result<Option<String>> {
537+
// preserve old caching behavior
538+
match self._auth(image, authentication, operation).await {
539+
Ok(Some(RegistryTokenType::Bearer(token))) => {
540+
self.tokens
541+
.insert(image, operation, RegistryTokenType::Bearer(token.clone()))
542+
.await;
543+
Ok(Some(token.token().to_string()))
544+
}
545+
Ok(Some(RegistryTokenType::Basic(username, password))) => {
546+
self.tokens
547+
.insert(
548+
image,
549+
operation,
550+
RegistryTokenType::Basic(username, password),
551+
)
552+
.await;
553+
Ok(None)
554+
}
555+
Ok(None) => Ok(None),
556+
Err(e) => Err(e),
557+
}
558+
}
559+
560+
/// Internal auth that retrieves token.
561+
async fn _auth(
562+
&self,
563+
image: &Reference,
564+
authentication: &RegistryAuth,
565+
operation: RegistryOperation,
566+
) -> Result<Option<RegistryTokenType>> {
505567
debug!("Authorizing for image: {:?}", image);
506568
// The version request will tell us where to go.
507569
let url = format!(
@@ -521,13 +583,10 @@ impl Client {
521583
Err(e) => {
522584
debug!(error = ?e, "Falling back to HTTP Basic Auth");
523585
if let RegistryAuth::Basic(username, password) = authentication {
524-
self.tokens
525-
.insert(
526-
image,
527-
operation,
528-
RegistryTokenType::Basic(username.to_string(), password.to_string()),
529-
)
530-
.await;
586+
return Ok(Some(RegistryTokenType::Basic(
587+
username.to_string(),
588+
password.to_string(),
589+
)));
531590
}
532591
return Ok(None);
533592
}
@@ -566,11 +625,7 @@ impl Client {
566625
let token: RegistryToken = serde_json::from_str(&text)
567626
.map_err(|e| OciDistributionError::RegistryTokenDecodeError(e.to_string()))?;
568627
debug!("Successfully authorized for image '{:?}'", image);
569-
let oauth_token = token.token().to_string();
570-
self.tokens
571-
.insert(image, operation, RegistryTokenType::Bearer(token))
572-
.await;
573-
Ok(Some(oauth_token))
628+
Ok(Some(RegistryTokenType::Bearer(token)))
574629
}
575630
_ => {
576631
let reason = auth_res.text().await?;
@@ -593,10 +648,8 @@ impl Client {
593648
image: &Reference,
594649
auth: &RegistryAuth,
595650
) -> Result<String> {
596-
let op = RegistryOperation::Pull;
597-
if !self.tokens.contains_key(image, op).await {
598-
self.auth(image, auth, op).await?;
599-
}
651+
self.store_auth_if_needed(image.resolve_registry(), auth)
652+
.await;
600653

601654
let url = self.to_v2_manifest_url(image);
602655
debug!("HEAD image manifest from {}", url);
@@ -670,10 +723,8 @@ impl Client {
670723
image: &Reference,
671724
auth: &RegistryAuth,
672725
) -> Result<(OciImageManifest, String)> {
673-
let op = RegistryOperation::Pull;
674-
if !self.tokens.contains_key(image, op).await {
675-
self.auth(image, auth, op).await?;
676-
}
726+
self.store_auth_if_needed(image.resolve_registry(), auth)
727+
.await;
677728

678729
self._pull_image_manifest(image).await
679730
}
@@ -690,10 +741,8 @@ impl Client {
690741
image: &Reference,
691742
auth: &RegistryAuth,
692743
) -> Result<(OciManifest, String)> {
693-
let op = RegistryOperation::Pull;
694-
if !self.tokens.contains_key(image, op).await {
695-
self.auth(image, auth, op).await?;
696-
}
744+
self.store_auth_if_needed(image.resolve_registry(), auth)
745+
.await;
697746

698747
self._pull_manifest(image).await
699748
}
@@ -811,10 +860,8 @@ impl Client {
811860
image: &Reference,
812861
auth: &RegistryAuth,
813862
) -> Result<(OciImageManifest, String, String)> {
814-
let op = RegistryOperation::Pull;
815-
if !self.tokens.contains_key(image, op).await {
816-
self.auth(image, auth, op).await?;
817-
}
863+
self.store_auth_if_needed(image.resolve_registry(), auth)
864+
.await;
818865

819866
self._pull_manifest_and_config(image)
820867
.await
@@ -856,7 +903,8 @@ impl Client {
856903
auth: &RegistryAuth,
857904
manifest: OciImageIndex,
858905
) -> Result<String> {
859-
self.auth(reference, auth, RegistryOperation::Push).await?;
906+
self.store_auth_if_needed(reference.resolve_registry(), auth)
907+
.await;
860908
self.push_manifest(reference, &OciManifest::ImageIndex(manifest))
861909
.await
862910
}
@@ -1368,7 +1416,7 @@ impl<'a> RequestBuilderWrapper<'a> {
13681416
) -> Result<RequestBuilderWrapper> {
13691417
let mut headers = HeaderMap::new();
13701418

1371-
if let Some(token) = self.client.tokens.get(image, op).await {
1419+
if let Some(token) = self.client.get_auth_token(image, op).await {
13721420
match token {
13731421
RegistryTokenType::Bearer(token) => {
13741422
debug!("Using bearer token authentication.");

src/token_cache.rs

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,15 @@ pub enum RegistryOperation {
5959
Pull,
6060
}
6161

62-
type CacheType = BTreeMap<(String, String, RegistryOperation), (RegistryTokenType, u64)>;
62+
// Types to allow better naming
63+
type Registry = String;
64+
type Repository = String;
65+
type TokenCacheKey = (Registry, Repository, RegistryOperation);
66+
type TokenExpiration = u64;
67+
type TokenCacheValue = (RegistryTokenType, TokenExpiration);
68+
69+
// (registry, repository, scope) -> (token, expiration)
70+
type CacheType = BTreeMap<TokenCacheKey, TokenCacheValue>;
6371

6472
#[derive(Default, Clone)]
6573
pub(crate) struct TokenCache {
@@ -68,12 +76,6 @@ pub(crate) struct TokenCache {
6876
}
6977

7078
impl TokenCache {
71-
pub(crate) fn new() -> Self {
72-
TokenCache {
73-
tokens: Arc::new(RwLock::new(BTreeMap::new())),
74-
}
75-
}
76-
7779
pub(crate) async fn insert(
7880
&self,
7981
reference: &Reference,
@@ -158,8 +160,4 @@ impl TokenCache {
158160
}
159161
}
160162
}
161-
162-
pub(crate) async fn contains_key(&self, reference: &Reference, op: RegistryOperation) -> bool {
163-
self.get(reference, op).await.is_some()
164-
}
165163
}

0 commit comments

Comments
 (0)