Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions app/Env.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import Data.ByteString.Char8 qualified as BS
import Data.Char
import Data.Char qualified as Char
import Data.Either.Combinators
import Data.Map qualified as Map
import Data.Functor
import Data.HashMap.Strict qualified as HM
import Data.Set qualified as Set
Expand Down Expand Up @@ -73,6 +74,8 @@ withEnv action = do
maxParallelismPerDownloadRequest <- fromEnv "SHARE_MAX_PARALLELISM_PER_DOWNLOAD_REQUEST" (pure . maybeToEither "Invalid SHARE_MAX_PARALLELISM_PER_DOWNLOAD_REQUEST" . readMaybe)
maxParallelismPerUploadRequest <- fromEnv "SHARE_MAX_PARALLELISM_PER_UPLOAD_REQUEST" (pure . maybeToEither "Invalid SHARE_MAX_PARALLELISM_PER_UPLOAD_REQUEST" . readMaybe)
cloudWebsiteOrigin <- fromEnv "SHARE_CLOUD_HOMEPAGE_ORIGIN" (pure . maybeToEither "Invalid SHARE_CLOUD_HOMEPAGE_ORIGIN" . parseURI)
cloudAPIOrigin <- fromEnv "SHARE_CLOUD_API_ORIGIN" (pure . maybeToEither "Invalid SHARE_CLOUD_API_ORIGIN" . parseURI)
cloudAPIJWKEndpoint <- fromEnv "SHARE_CLOUD_API_JWKS_ENDPOINT" (pure . maybeToEither "Invalid SHARE_CLOUD_API_JWKS_ENDPOINT" . parseURI)

sentryService <-
lookupEnv "SHARE_SENTRY_DSN" >>= \case
Expand All @@ -90,18 +93,23 @@ withEnv action = do
| Deployment.onLocal = Nothing
| otherwise = Nothing
in r {Redis.connectTLSParams = tlsParams}
let acceptedAudiences = Set.singleton apiOrigin
let acceptedIssuers = Set.singleton apiOrigin
let shareAudience = JWT.Audience apiOrigin
let shareIssuer = JWT.Issuer apiOrigin
let cloudIssuer = JWT.Issuer cloudAPIOrigin
let acceptedAudiences = Set.singleton $ shareAudience
let acceptedIssuers = Set.fromList [shareIssuer, cloudIssuer]
let legacyKey = JWT.KeyDescription {JWT.key = hs256Key, JWT.alg = JWT.HS256}
let signingKey = JWT.KeyDescription {JWT.key = edDSAKey, JWT.alg = JWT.Ed25519}
let externalJWKs = Map.fromList [ (cloudIssuer, Left cloudAPIJWKEndpoint)
]
hashJWTJWK <- case JWT.keyDescToJWK legacyKey of
Left err -> throwIO err
Right (_thumbprint, jwk) -> pure jwk
Right jwk -> pure jwk
-- I explicitly add the legacy key to the validation keys, so that the thumbprinted
-- version of the key is used for validation, which is needed for HashJWTs which are signed
-- with a 'kid'.
let validationKeys = Set.fromList [legacyKey]
jwtSettings <- case JWT.defaultJWTSettings signingKey (Just legacyKey) validationKeys acceptedAudiences acceptedIssuers of
jwtSettings <- JWT.defaultJWTSettings shareIssuer signingKey (Just legacyKey) validationKeys acceptedAudiences acceptedIssuers externalJWKs >>= \case
Left cryptoError -> throwIO cryptoError
Right settings -> pure settings
let cookieSettings = Cookies.defaultCookieSettings Deployment.onLocal (Just (realToFrac cookieSessionTTL))
Expand Down
2 changes: 2 additions & 0 deletions docker/docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ services:
- SHARE_CLOUD_UI_ORIGIN=http://localhost:5678
- SHARE_HOMEPAGE_ORIGIN=http://localhost:1111
- SHARE_CLOUD_HOMEPAGE_ORIGIN=http://localhost:2222
- SHARE_CLOUD_API_ORIGIN=http://localhost:3333
- SHARE_CLOUD_API_JWKS_ENDPOINT=http://localhost:3333/.well-known/jwks.json
- SHARE_LOG_LEVEL=DEBUG
- SHARE_COMMIT=dev
- SHARE_MAX_PARALLELISM_PER_DOWNLOAD_REQUEST=1
Expand Down
2 changes: 2 additions & 0 deletions local.env
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ export SHARE_SHARE_UI_ORIGIN="http://localhost:1234"
export SHARE_CLOUD_UI_ORIGIN="http://localhost:5678"
export SHARE_HOMEPAGE_ORIGIN="http://localhost:1111"
export SHARE_CLOUD_HOMEPAGE_ORIGIN="http://localhost:2222"
export SHARE_CLOUD_API_ORIGIN="http://localhost:3333"
export SHARE_CLOUD_API_JWKS_ENDPOINT="http://localhost:3333/.well-known/jwks.json"
export SHARE_LOG_LEVEL="DEBUG"
export SHARE_COMMIT="dev"
export SHARE_MAX_PARALLELISM_PER_DOWNLOAD_REQUEST="1"
Expand Down
1 change: 0 additions & 1 deletion share-api.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ library
Share.Ticket
Share.User
Share.UserProfile
Share.Utils.API
Share.Utils.Caching
Share.Utils.Caching.JSON
Share.Utils.Data
Expand Down
3 changes: 3 additions & 0 deletions share-auth/example/package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ dependencies:
- containers
- share-auth
- share-utils
- jose
- aeson
- hedis
- network-uri
- raw-strings-qq
- servant
- servant-auth-server
- servant-server
Expand Down
49 changes: 43 additions & 6 deletions share-auth/example/src/Lib.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

module Lib (main) where

import Data.Aeson qualified as Aeson
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Maybe (fromJust, fromMaybe)
import Data.Set qualified as Set
import Data.Text (Text)
Expand All @@ -22,6 +25,7 @@ import Share.OAuth.ServiceProvider qualified as Auth
import Share.OAuth.Session (AuthCheckCtx, AuthenticatedUserId, MaybeAuthenticatedUserId, addAuthCheckCtx)
import Share.OAuth.Types (OAuthClientId (..), OAuthClientSecret (OAuthClientSecret), RedirectReceiverErr, UserId)
import Share.Utils.Servant.Cookies qualified as Cookies
import Text.RawString.QQ (r)
import UnliftIO

-- | An example application endpoint which is optionally authenticated.
Expand Down Expand Up @@ -78,10 +82,11 @@ main = do
redisConn <- R.checkedConnect R.defaultConnectInfo
putStrLn "booting up"

jwtSettings <- case JWT.defaultJWTSettings signingKey (Just legacyKey) rotatedKeys acceptedAudiences acceptedIssuers of
Left cryptoError -> throwIO cryptoError
Right jwtS -> do
pure jwtS
jwtSettings <-
JWT.defaultJWTSettings issuer signingKey (Just legacyKey) rotatedKeys acceptedAudiences acceptedIssuers externalJWKs >>= \case
Left cryptoError -> throwIO cryptoError
Right jwtS -> do
pure jwtS

Warp.run 3030 $ serveWithContext (Proxy @MyAPI) (ctx jwtSettings) (myServer redisConn jwtSettings)
putStrLn "exiting"
Expand Down Expand Up @@ -135,7 +140,39 @@ main = do
signingKey = JWT.KeyDescription {JWT.key = edDSAKey, JWT.alg = JWT.Ed25519}
rotatedKeys = Set.empty
api = unsafeURI "http://cloud:3030"
serviceAudience = api
serviceAudience = JWT.Audience api
acceptedAudiences = Set.singleton serviceAudience
issuer = unsafeURI "http://localhost:5424"
issuer = JWT.Issuer $ unsafeURI "http://localhost:5424"
acceptedIssuers = Set.singleton issuer
externalJWKs :: Map JWT.Issuer (Either URI JWT.JWKSet)
externalJWKs =
Map.fromList
[ -- This will fetch jwks from the identity provider directly, and keep them up to
-- date.
( JWT.Issuer $ unsafeURI "http://cloud:3030",
Left $ unsafeURI "http://cloud:3030/.well-known/jwks.json"
),
-- This will use the provided static JWK set.
( JWT.Issuer $ unsafeURI "https://api.unison.cloud",
Right
. fromJust
. Aeson.decode @JWT.JWKSet
$
-- This is a sample JWK set, replace with your own.
-- The key is an Ed25519 key, which is used for signing JWTs.
[r|
{
"keys": [
{
"alg": "EdDSA",
"crv": "Ed25519",
"kid": "ZGRwKNuN0LlKkg2WCm4ZSQ1IRzBS2ej5NCTJW1KhFOY",
"kty": "OKP",
"use": "sig",
"x": "rl4D9BawfhIP5M2UEKn30QG1BD3rjQMSLE9oFiUEJpo"
}
]
}
|]
)
]
17 changes: 13 additions & 4 deletions share-auth/example/test-auth-app.cabal
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cabal-version: 1.12

-- This file has been generated from package.yaml by hpack version 0.35.2.
-- This file has been generated from package.yaml by hpack version 0.37.0.
--
-- see: https://github.com/sol/hpack

Expand Down Expand Up @@ -59,10 +59,13 @@ library
ImportQualifiedPost
ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints
build-depends:
base >=4.7 && <5
aeson
, base >=4.7 && <5
, containers
, hedis
, jose
, network-uri
, raw-strings-qq
, servant
, servant-auth-server
, servant-server
Expand Down Expand Up @@ -110,10 +113,13 @@ executable test-auth-app-exe
ImportQualifiedPost
ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N
build-depends:
base >=4.7 && <5
aeson
, base >=4.7 && <5
, containers
, hedis
, jose
, network-uri
, raw-strings-qq
, servant
, servant-auth-server
, servant-server
Expand Down Expand Up @@ -163,10 +169,13 @@ test-suite test-auth-app-test
ImportQualifiedPost
ghc-options: -Wall -Wcompat -Widentities -Wincomplete-record-updates -Wincomplete-uni-patterns -Wmissing-export-lists -Wmissing-home-modules -Wpartial-fields -Wredundant-constraints -threaded -rtsopts -with-rtsopts=-N
build-depends:
base >=4.7 && <5
aeson
, base >=4.7 && <5
, containers
, hedis
, jose
, network-uri
, raw-strings-qq
, servant
, servant-auth-server
, servant-server
Expand Down
77 changes: 54 additions & 23 deletions share-auth/src/Share/JWT.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,22 @@ module Share.JWT

-- * Utilities
JWTParam (..),
Issuer (..),
Audience (..),
textToSignedJWT,
signedJWTToText,
createSignedCookie,

-- * Re-exports
CryptoError (..),
JWK.JWK,
JWK.JWKSet,
)
where

import Control.Lens
import Control.Monad.Except
import Control.Monad.Trans.Except (except)
import Crypto.Error (CryptoError (..), CryptoFailable (..))
import Crypto.JOSE.JWA.JWS qualified as JWS
import Crypto.JOSE.JWK qualified as JWK
Expand All @@ -50,32 +55,41 @@ import Data.Aeson qualified as Aeson
import Data.ByteArray qualified as ByteArray
import Data.ByteString qualified as BS
import Data.ByteString.Base64.URL qualified as Base64URL
import Data.Map (Map)
import Data.Map qualified as Map
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Text (Text)
import Data.Text qualified as Text
import Data.Text.Encoding qualified as Text
import Data.Traversable (for)
import Servant
import Share.JWT.Types
import Share.OAuth.Orphans ()
import Share.Utils.Servant.Cookies qualified as Cookies
import UnliftIO (MonadIO (..))
import UnliftIO.STM

-- | Get the JWK Set value which is safe to expose to the public, e.g. in a JWKS endpoint.
-- | Get the JWK Set for an issuer which is safe to expose to the public, e.g. in a JWKS endpoint.
-- This will only include public keys.
--
-- Note that this will not include the legacy key or any HS256 keys, since those don't have a
-- safe public component.
publicJWKSet :: JWTSettings -> JWK.JWKSet
publicJWKSet JWTSettings {validationKeys = KeyMap {byKeyId}} =
JWK.JWKSet
( byKeyId
& foldMap (\jwk -> jwk ^.. JWK.asPublicKey . _Just)
)
publicJWKSet :: (MonadIO m) => JWTSettings -> Issuer -> m JWK.JWKSet
publicJWKSet JWTSettings {validationKeys = KeyMap {keysVar}} issuer = do
keyMap <- liftIO $ readTVarIO keysVar
pure $
JWK.JWKSet
( keyMap
& Map.lookup issuer
& foldMap (\jwk -> jwk ^.. folded . JWK.asPublicKey . _Just)
)

-- | Create a 'JWTSettings' using the required information.
defaultJWTSettings ::
(MonadIO m) =>
-- | Which issuer is the current service
Issuer ->
-- | The key used to sign JWTs.
KeyDescription ->
-- | The legacy key used to verify old JWTs from before key IDs were used. This will be used to verify tokens that don't have a key id.
Expand All @@ -86,25 +100,42 @@ defaultJWTSettings ::
-- Tokens must have an audience which is present in this set.
--
-- E.g. https://api.unison.cloud
Set URI ->
Set Audience ->
-- | Valid issuers when validating tokens
Set URI ->
Either CryptoError JWTSettings
defaultJWTSettings signingKey legacyKey oldValidKeys acceptedAudiences acceptedIssuers = do
sjwk@(_, signingJWK) <- keyDescToJWK signingKey
verificationJWKs <- (sjwk :) <$> traverse keyDescToJWK (Set.toList oldValidKeys)
let byKeyId = Map.fromList verificationJWKs
legacyKey <- traverse keyDescToJWK legacyKey <&> fmap snd
Set Issuer ->
-- | Mapping of issuers to either their:
-- * JWK json endpoint.
-- * JWK set.
--
-- If a JWK URI is provided for an issuer it will be fetched and kept up to date as needed.
Map Issuer (Either URI JWT.JWKSet) ->
m (Either CryptoError JWTSettings)
defaultJWTSettings myIssuer signingKey legacyKey oldValidKeys acceptedAudiences acceptedIssuers externalJWKsMap = runExceptT $ do
signingJWK <- except $ keyDescToJWK signingKey
myVerificationJWKs <- (signingJWK :) <$> traverse (except . keyDescToJWK) (Set.toList oldValidKeys)
let (externalJWKs, externalJWKLocations) =
externalJWKsMap
& foldMap \case
Left uri -> (mempty, Map.singleton myIssuer uri)
Right (JWT.JWKSet jwks) -> (Map.singleton myIssuer jwks, mempty)

let myJWKs = Map.singleton myIssuer myVerificationJWKs
let keysMap = myJWKs <> externalJWKs
keysVar <- liftIO $ newTVarIO keysMap
lastCheckedVar <- liftIO $ newTVarIO Map.empty
legacyKey <- for legacyKey (except . keyDescToJWK)
pure $
JWTSettings
{ signingJWK,
validationKeys = KeyMap {byKeyId, legacyKey},
validationKeys = KeyMap {keysVar, legacyKey},
externalJWKLocations,
lastCheckedVar,
acceptedAudiences,
acceptedIssuers
}

-- | Converts a 'KeyDescription' to a 'JWK' and a 'KeyThumbprint'.
keyDescToJWK :: KeyDescription -> Either CryptoError (KeyThumbprint, JWK.JWK)
keyDescToJWK :: KeyDescription -> Either CryptoError JWK.JWK
keyDescToJWK (KeyDescription {key, alg}) = cryptoFailableToEither $ do
case alg of
HS256 -> do
Expand All @@ -113,7 +144,7 @@ keyDescToJWK (KeyDescription {key, alg}) = cryptoFailableToEither $ do
& JWK.jwkUse .~ Just JWK.Sig
& JWK.jwkAlg .~ Just (JWK.JWSAlg JWS.HS256)
let thumbprint = jwkThumbprint jwk
pure (KeyThumbprint thumbprint, jwk & JWK.jwkKid .~ Just thumbprint)
pure (jwk & JWK.jwkKid .~ Just thumbprint)
Ed25519 -> do
privKey <- Ed25519.secretKey key
let pubKey = Ed25519.toPublic privKey
Expand All @@ -124,7 +155,7 @@ keyDescToJWK (KeyDescription {key, alg}) = cryptoFailableToEither $ do
& JWK.jwkUse .~ Just JWK.Sig
& JWK.jwkAlg .~ Just (JWK.JWSAlg JWS.EdDSA)
let thumbprint = jwkThumbprint jwk
pure (KeyThumbprint thumbprint, jwk & JWK.jwkKid .~ Just thumbprint)
pure (jwk & JWK.jwkKid .~ Just thumbprint)
where
cryptoFailableToEither :: CryptoFailable a -> Either CryptoError a
cryptoFailableToEither (CryptoFailed err) = Left err
Expand Down Expand Up @@ -160,8 +191,8 @@ signJWTWithJWK jwk v = runExceptT $ do
--
-- Any other checks should be performed on the returned claims.
verifyJWT :: forall claims m. (AsJWTClaims claims, MonadIO m) => JWTSettings -> JWT.SignedJWT -> m (Either JWT.JWTError claims)
verifyJWT JWTSettings {validationKeys, acceptedAudiences, acceptedIssuers} signedJWT = runExceptT do
jwtClaimsMap <- ExceptT . liftIO . runExceptT $ JWT.verifyJWT validators validationKeys signedJWT
verifyJWT jwtSettings@JWTSettings {acceptedAudiences, acceptedIssuers} signedJWT = runExceptT do
jwtClaimsMap <- ExceptT . liftIO . runExceptT $ JWT.verifyJWT validators jwtSettings signedJWT
case fromClaims jwtClaimsMap of
Left err -> throwError $ JWT.JWTClaimsSetDecodeError (Text.unpack err)
Right claims -> pure claims
Expand All @@ -170,12 +201,12 @@ verifyJWT JWTSettings {validationKeys, acceptedAudiences, acceptedIssuers} signe
auds =
-- Annoyingly StringOrURI doesn't have an ord instance.
Set.toList acceptedAudiences
& map (review CryptoJWT.uri)
& map (\(Audience aud) -> review CryptoJWT.uri aud)
issuers :: [CryptoJWT.StringOrURI]
issuers =
-- Annoyingly StringOrURI doesn't have an ord instance.
Set.toList acceptedIssuers
& map (review CryptoJWT.uri)
& map (\(Issuer iss) -> review CryptoJWT.uri iss)
validators =
CryptoJWT.defaultJWTValidationSettings (`elem` auds)
& CryptoJWT.issuerPredicate .~ (`elem` issuers)
Expand Down
Loading
Loading