diff --git a/cardano-diffusion/changelog.d/20251016_142053_coot_dmq_signature_validation.md b/cardano-diffusion/changelog.d/20251016_142053_coot_dmq_signature_validation.md new file mode 100644 index 00000000000..2f864c81146 --- /dev/null +++ b/cardano-diffusion/changelog.d/20251016_142053_coot_dmq_signature_validation.md @@ -0,0 +1,4 @@ +### Non-Breaking + +- Addapted tests to changes in the `Ouroboros.Network.TxSubmission.Mempool.Simple` API + diff --git a/cardano-diffusion/tests/lib/Test/Cardano/Network/Diffusion/Testnet/MiniProtocols.hs b/cardano-diffusion/tests/lib/Test/Cardano/Network/Diffusion/Testnet/MiniProtocols.hs index e9e3a13c985..85129d03268 100644 --- a/cardano-diffusion/tests/lib/Test/Cardano/Network/Diffusion/Testnet/MiniProtocols.hs +++ b/cardano-diffusion/tests/lib/Test/Cardano/Network/Diffusion/Testnet/MiniProtocols.hs @@ -682,7 +682,7 @@ applications debugTracer txSubmissionInboundTracer txSubmissionInboundDebug node txSubmissionInitiator :: TxDecisionPolicy - -> Mempool m (Tx TxId) + -> Mempool m TxId (Tx TxId) -> MiniProtocolCb (ExpandedInitiatorContext NtNAddr m) ByteString m () txSubmissionInitiator txDecisionPolicy mempool = MiniProtocolCb $ @@ -709,7 +709,7 @@ applications debugTracer txSubmissionInboundTracer txSubmissionInboundDebug node (txSubmissionClientPeer client) txSubmissionResponder - :: Mempool m (Tx TxId) + :: Mempool m TxId (Tx TxId) -> TxChannelsVar m NtNAddr Int (Tx Int) -> TxMempoolSem m -> SharedTxStateVar m NtNAddr Int (Tx Int) diff --git a/dmq-node/app/Main.hs b/dmq-node/app/Main.hs index edbda5e1d5f..5cb0e4b6ca4 100644 --- a/dmq-node/app/Main.hs +++ b/dmq-node/app/Main.hs @@ -6,6 +6,7 @@ module Main where +import Control.Exception (throwIO) import Control.Monad (void, when) import Control.Tracer (Tracer (..), nullTracer, traceWith) @@ -22,6 +23,7 @@ import System.Exit (exitSuccess) import System.Random (newStdGen, split) import Cardano.Git.Rev (gitRev) +import Cardano.KESAgent.KES.Evolution qualified as KES import Cardano.KESAgent.Protocols.StandardCrypto (StandardCrypto) import DMQ.Configuration @@ -68,6 +70,7 @@ runDMQ commandLineConfig = do let dmqConfig@Configuration { dmqcPrettyLog = I prettyLog, dmqcTopologyFile = I topologyFile, + dmqcShelleyGenesisFile = I genesisFile, dmqcHandshakeTracer = I handshakeTracer, dmqcLocalHandshakeTracer = I localHandshakeTracer, dmqcVersion = I version @@ -95,6 +98,12 @@ runDMQ commandLineConfig = do ] exitSuccess + res <- KES.evolutionConfigFromGenesisFile genesisFile + evolutionConfig <- case res of + Left err -> traceWith tracer (WithEventType "ShelleyGenesisFile" err) + >> throwIO (userError $ err) + Right ev -> return ev + traceWith tracer (WithEventType "Configuration" dmqConfig) nt <- readTopologyFileOrError topologyFile traceWith tracer (WithEventType "NetworkTopology" nt) @@ -102,7 +111,11 @@ runDMQ commandLineConfig = do stdGen <- newStdGen let (psRng, policyRng) = split stdGen - withNodeKernel @StandardCrypto tracer dmqConfig psRng $ \nodeKernel -> do + withNodeKernel @StandardCrypto + tracer + dmqConfig + evolutionConfig + psRng $ \nodeKernel -> do dmqDiffusionConfiguration <- mkDiffusionConfiguration dmqConfig nt let dmqNtNApps = @@ -110,10 +123,8 @@ runDMQ commandLineConfig = do dmqConfig nodeKernel (dmqCodecs - -- TODO: `maxBound :: Cardano.Network.NodeToNode.NodeToNodeVersion` - -- is unsafe here! - (encodeRemoteAddress (maxBound :: NodeToNodeVersion)) - (decodeRemoteAddress (maxBound :: NodeToNodeVersion))) + (encodeRemoteAddress (maxBound @NodeToNodeVersion)) + (decodeRemoteAddress (maxBound @NodeToNodeVersion))) dmqLimitsAndTimeouts defaultSigDecisionPolicy dmqNtCApps = diff --git a/dmq-node/cddl/specs/sig.cddl b/dmq-node/cddl/specs/sig.cddl index 994bfda841f..10740d1a0ec 100644 --- a/dmq-node/cddl/specs/sig.cddl +++ b/dmq-node/cddl/specs/sig.cddl @@ -13,7 +13,7 @@ messagePayload = [ messageId = bstr messageBody = bstr -kesSignature = bstr +kesSignature = bstr .size 448 kesPeriod = word64 operationalCertificate = [ bstr .size 32, word64, word64, bstr .size 64 ] coldVerificationKey = bstr .size 32 diff --git a/dmq-node/changelog.d/20251016_142205_coot_dmq_signature_validation.md b/dmq-node/changelog.d/20251016_142205_coot_dmq_signature_validation.md new file mode 100644 index 00000000000..a32740c07e3 --- /dev/null +++ b/dmq-node/changelog.d/20251016_142205_coot_dmq_signature_validation.md @@ -0,0 +1,20 @@ + + +### Breaking + +- Using `KESPeriod` from `Cardano.Crypto.KES` instead of `SigKESPeriod` + newtype. `KESPeriod` is used by `SigRaw` data type. +- `SigKESSignature` holds `SigKES (KES crypto)` instead of a `ByteString`. +- `SigColdKey` holds `VerKeyDSIGN` instead of a `ByteString`. +- `ntnApps` constraints changed in order to use `sigValidate` function. + +### Non-Breaking + +- `Sig` codec decodes KES signatures, and the cold key. +- Added `DMQ.SigSubmission.Type.validateSig` and `SigValidationError`. + diff --git a/dmq-node/dmq-node.cabal b/dmq-node/dmq-node.cabal index 42cf76e6d8a..a4f3352dcc4 100644 --- a/dmq-node/dmq-node.cabal +++ b/dmq-node/dmq-node.cabal @@ -20,10 +20,15 @@ extra-doc-files: CHANGELOG.md flag cddl description: Enable CDDL based tests of the CBOR encoding - manual: True -- These tests need the cddl and the cbor-diag Ruby-package default: True +flag standardcrypto-tests + description: Enable StandardCrypto tests + -- these tests are flaky on GH Windows instances + manual: True + default: True + common extensions default-extensions: BangPatterns @@ -186,6 +191,9 @@ test-suite dmq-tests -T -RTS + if flag(standardcrypto-tests) + cpp-options: -DSTANDARDCRYPTO_TESTS + test-suite dmq-cddl import: warnings, diff --git a/dmq-node/src/DMQ/Configuration.hs b/dmq-node/src/DMQ/Configuration.hs index d3b9c8c69aa..b524497c969 100644 --- a/dmq-node/src/DMQ/Configuration.hs +++ b/dmq-node/src/DMQ/Configuration.hs @@ -90,6 +90,10 @@ data Configuration' f = dmqcPortNumber :: f PortNumber, dmqcConfigFile :: f FilePath, dmqcTopologyFile :: f FilePath, + dmqcShelleyGenesisFile :: f FilePath, + -- ^ shelley genesis file, e.g. + -- `/configuration/cardano/mainnet-shelley-genesis.json` in `cardano-node` + -- repo. dmqcAcceptedConnectionsLimit :: f AcceptedConnectionsLimit, dmqcDiffusionMode :: f DiffusionMode, dmqcTargetOfRootPeers :: f Int, @@ -210,6 +214,7 @@ defaultConfiguration = Configuration { dmqcPortNumber = I 3_141, dmqcConfigFile = I "dmq.configuration.yaml", dmqcTopologyFile = I "dmq.topology.json", + dmqcShelleyGenesisFile = I "mainnet-shelley-genesis.json", dmqcAcceptedConnectionsLimit = I defaultAcceptedConnectionsLimit, dmqcDiffusionMode = I InitiatorAndResponderDiffusionMode, dmqcTargetOfRootPeers = I targetNumberOfRootPeers, @@ -300,6 +305,8 @@ instance FromJSON PartialConfig where dmqcDiffusionMode <- Last <$> v .:? "DiffusionMode" dmqcPeerSharing <- Last <$> v .:? "PeerSharing" + dmqcShelleyGenesisFile <- Last <$> v .:? "ShelleyGenesisFile" + dmqcTargetOfRootPeers <- Last <$> v .:? "TargetNumberOfRootPeers" dmqcTargetOfKnownPeers <- Last <$> v .:? "TargetNumberOfKnownPeers" dmqcTargetOfEstablishedPeers <- Last <$> v .:? "TargetNumberOfEstablishedPeers" @@ -376,6 +383,7 @@ instance ToJSON Configuration where , "LocalAddress" .= unI dmqcLocalAddress , "ConfigFile" .= unI dmqcConfigFile , "TopologyFile" .= unI dmqcTopologyFile + , "ShelleyGenesisFile" .= unI dmqcShelleyGenesisFile , "AcceptedConnectionsLimit" .= unI dmqcAcceptedConnectionsLimit , "DiffusionMode" .= unI dmqcDiffusionMode , "TargetOfRootPeers" .= unI dmqcTargetOfRootPeers diff --git a/dmq-node/src/DMQ/Diffusion/NodeKernel.hs b/dmq-node/src/DMQ/Diffusion/NodeKernel.hs index da282e989a6..3a50905ce54 100644 --- a/dmq-node/src/DMQ/Diffusion/NodeKernel.hs +++ b/dmq-node/src/DMQ/Diffusion/NodeKernel.hs @@ -19,7 +19,10 @@ import Data.Aeson qualified as Aeson import Data.Function (on) import Data.Functor.Contravariant ((>$<)) import Data.Hashable +import Data.Sequence (Seq) import Data.Sequence qualified as Seq +import Data.Set (Set) +import Data.Set qualified as Set import Data.Time.Clock.POSIX (POSIXTime) import Data.Time.Clock.POSIX qualified as Time import Data.Void (Void) @@ -27,6 +30,7 @@ import System.Random (StdGen) import System.Random qualified as Random import Cardano.KESAgent.KES.Crypto (Crypto (..)) +import Cardano.KESAgent.KES.Evolution qualified as KES import Ouroboros.Network.BlockFetch (FetchClientRegistry, newFetchClientRegistry) @@ -37,11 +41,12 @@ import Ouroboros.Network.PeerSharing (PeerSharingAPI, PeerSharingRegistry, newPeerSharingAPI, newPeerSharingRegistry, ps_POLICY_PEER_SHARE_MAX_PEERS, ps_POLICY_PEER_SHARE_STICKY_TIME) import Ouroboros.Network.TxSubmission.Inbound.V2 -import Ouroboros.Network.TxSubmission.Mempool.Simple (Mempool (..)) +import Ouroboros.Network.TxSubmission.Mempool.Simple (Mempool (..), + MempoolSeq (..)) import Ouroboros.Network.TxSubmission.Mempool.Simple qualified as Mempool import DMQ.Configuration -import DMQ.Protocol.SigSubmission.Type (Sig (sigExpiresAt), SigId) +import DMQ.Protocol.SigSubmission.Type (Sig (sigExpiresAt, sigId), SigId) import DMQ.Tracer @@ -54,7 +59,8 @@ data NodeKernel crypto ntnAddr m = -- the PeerSharing protocol , peerSharingRegistry :: !(PeerSharingRegistry ntnAddr m) , peerSharingAPI :: !(PeerSharingAPI ntnAddr StdGen m) - , mempool :: !(Mempool m (Sig crypto)) + , mempool :: !(Mempool m SigId (Sig crypto)) + , evolutionConfig :: !(KES.EvolutionConfig) , sigChannelVar :: !(TxChannelsVar m ntnAddr SigId (Sig crypto)) , sigMempoolSem :: !(TxMempoolSem m) , sigSharedTxStateVar :: !(SharedTxStateVar m ntnAddr SigId (Sig crypto)) @@ -64,9 +70,10 @@ newNodeKernel :: ( MonadLabelledSTM m , MonadMVar m , Ord ntnAddr ) - => StdGen + => KES.EvolutionConfig + -> StdGen -> m (NodeKernel crypto ntnAddr m) -newNodeKernel rng = do +newNodeKernel evolutionConfig rng = do publicPeerSelectionStateVar <- makePublicPeerSelectionStateVar fetchClientRegistry <- newFetchClientRegistry @@ -89,6 +96,7 @@ newNodeKernel rng = do , peerSharingRegistry , peerSharingAPI , mempool + , evolutionConfig , sigChannelVar , sigMempoolSem , sigSharedTxStateVar @@ -110,6 +118,7 @@ withNodeKernel :: forall crypto ntnAddr m a. ) => (forall ev. Aeson.ToJSON ev => Tracer m (WithEventType ev)) -> Configuration + -> KES.EvolutionConfig -> StdGen -> (NodeKernel crypto ntnAddr m -> m a) -- ^ as soon as the callback exits the `mempoolWorker` and all @@ -119,12 +128,13 @@ withNodeKernel tracer Configuration { dmqcSigSubmissionLogicTracer = I sigSubmissionLogicTracer } + evolutionConfig rng k = do nodeKernel@NodeKernel { mempool, sigChannelVar, sigSharedTxStateVar } - <- newNodeKernel rng + <- newNodeKernel evolutionConfig rng withAsync (mempoolWorker mempool) $ \mempoolThread -> withAsync (decisionLogicThreads @@ -146,22 +156,36 @@ mempoolWorker :: forall crypto m. , MonadSTM m , MonadTime m ) - => Mempool m (Sig crypto) + => Mempool m SigId (Sig crypto) -> m Void mempoolWorker (Mempool v) = loop where loop = do now <- getCurrentPOSIXTime rt <- atomically $ do - (sigs :: Seq.Seq (Sig crypto)) <- readTVar v - let sigs' :: Seq.Seq (Sig crypto) - (resumeTime, sigs') = - foldr (\a (rt, as) -> if sigExpiresAt a <= now - then (rt, as) - else (rt `min` sigExpiresAt a, a Seq.<| as)) - (now, Seq.empty) - sigs - writeTVar v sigs' + MempoolSeq { mempoolSeq, mempoolSet } <- readTVar v + let mempoolSeq' :: Seq (Sig crypto) + mempoolSet', expiredSet' :: Set SigId + + (resumeTime, expiredSet', mempoolSeq') = + foldr (\sig (rt, expiredSet, sigs) -> + if sigExpiresAt sig <= now + then ( rt + , sigId sig `Set.insert` expiredSet + , sigs + ) + else ( rt `min` sigExpiresAt sig + , expiredSet + , sig Seq.<| sigs + ) + ) + (now, Set.empty, Seq.empty) + mempoolSeq + + mempoolSet' = mempoolSet `Set.difference` expiredSet' + + writeTVar v MempoolSeq { mempoolSet = mempoolSet', + mempoolSeq = mempoolSeq' } return resumeTime now' <- getCurrentPOSIXTime diff --git a/dmq-node/src/DMQ/NodeToNode.hs b/dmq-node/src/DMQ/NodeToNode.hs index 9ff7a3acc7c..71322107797 100644 --- a/dmq-node/src/DMQ/NodeToNode.hs +++ b/dmq-node/src/DMQ/NodeToNode.hs @@ -4,6 +4,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeOperators #-} module DMQ.NodeToNode ( RemoteAddress @@ -40,6 +41,7 @@ import Codec.CBOR.Encoding qualified as CBOR import Codec.CBOR.Read qualified as CBOR import Codec.CBOR.Term qualified as CBOR import Data.Aeson qualified as Aeson +import Data.ByteString qualified as BS import Data.ByteString.Lazy qualified as BL import Data.Functor.Contravariant ((>$<)) import Data.Hashable (Hashable) @@ -52,7 +54,10 @@ import Network.Mux.Types (Mode (..)) import Network.Mux.Types qualified as Mx import Network.TypedProtocol.Codec (AnnotatedCodec, Codec) +import Cardano.Crypto.DSIGN.Class qualified as DSIGN +import Cardano.Crypto.KES.Class qualified as KES import Cardano.KESAgent.KES.Crypto (Crypto (..)) +import Cardano.KESAgent.KES.OCert (OCertSignable) import DMQ.Configuration (Configuration, Configuration' (..), I (..)) import DMQ.Diffusion.NodeKernel (NodeKernel (..)) @@ -147,6 +152,10 @@ data Apps addr m a b = ntnApps :: forall crypto m addr . ( Crypto crypto + , DSIGN.ContextDSIGN (DSIGN crypto) ~ () + , DSIGN.Signable (DSIGN crypto) (OCertSignable crypto) + , KES.ContextKES (KES crypto) ~ () + , KES.Signable (KES crypto) BS.ByteString , Typeable crypto , Alternative (STM m) , MonadAsync m @@ -187,6 +196,7 @@ ntnApps , peerSharingRegistry , peerSharingAPI , mempool + , evolutionConfig , sigChannelVar , sigMempoolSem , sigSharedTxStateVar @@ -224,8 +234,8 @@ ntnApps -- connection if we receive one, rather than validate them in the -- mempool. mempoolWriter = Mempool.getWriter sigId - (pure ()) - (\_ _ -> Right () :: Either Void ()) + (pure ()) -- TODO not needed + (\_ -> validateSig evolutionConfig) (\_ -> True) mempool diff --git a/dmq-node/src/DMQ/Protocol/SigSubmission/Codec.hs b/dmq-node/src/DMQ/Protocol/SigSubmission/Codec.hs index b45ee5f702f..aa09d119152 100644 --- a/dmq-node/src/DMQ/Protocol/SigSubmission/Codec.hs +++ b/dmq-node/src/DMQ/Protocol/SigSubmission/Codec.hs @@ -1,7 +1,6 @@ {-# LANGUAGE DataKinds #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE PolyKinds #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -33,8 +32,9 @@ import Codec.CBOR.Read qualified as CBOR import Network.TypedProtocol.Codec.CBOR import Cardano.Binary (FromCBOR (..), ToCBOR (..)) -import Cardano.Crypto.DSIGN.Class (decodeSignedDSIGN, encodeSignedDSIGN) -import Cardano.Crypto.KES.Class (decodeVerKeyKES, encodeVerKeyKES) +import Cardano.Crypto.DSIGN.Class (decodeSignedDSIGN, decodeVerKeyDSIGN, + encodeSignedDSIGN) +import Cardano.Crypto.KES.Class (decodeSigKES, decodeVerKeyKES, encodeVerKeyKES) import Cardano.KESAgent.KES.Crypto (Crypto (..)) import Cardano.KESAgent.KES.OCert (OCert (..)) @@ -159,9 +159,9 @@ decodeSig = do endOffset <- CBOR.peekByteOffset -- end of signed data - sigRawKESSignature <- SigKESSignature <$> CBOR.decodeBytes + sigRawKESSignature <- SigKESSignature <$> decodeSigKES sigRawOpCertificate <- decodeSigOpCertificate - sigRawColdKey <- SigColdKey <$> CBOR.decodeBytes + sigRawColdKey <- SigColdKey <$> decodeVerKeyDSIGN return $ \bytes -- ^ full bytes of the message, not just the sig part -> SigRawWithSignedBytes { sigRawSignedBytes = Utils.bytesBetweenOffsets startOffset endOffset bytes, @@ -176,13 +176,13 @@ decodeSig = do } } where - decodePayload :: CBOR.Decoder s (SigId, SigBody, SigKESPeriod, POSIXTime) + decodePayload :: CBOR.Decoder s (SigId, SigBody, KESPeriod, POSIXTime) decodePayload = do a <- CBOR.decodeListLen when (a /= 4) $ fail (printf "decodeSig: unexpected number of parameters %d for Sig's payload" a) (,,,) <$> decodeSigId <*> (SigBody <$> CBOR.decodeBytes) - <*> CBOR.decodeWord + <*> (KESPeriod <$> CBOR.decodeWord) <*> (realToFrac <$> CBOR.decodeWord32) diff --git a/dmq-node/src/DMQ/Protocol/SigSubmission/Type.hs b/dmq-node/src/DMQ/Protocol/SigSubmission/Type.hs index e5d25651d87..feb9e028b27 100644 --- a/dmq-node/src/DMQ/Protocol/SigSubmission/Type.hs +++ b/dmq-node/src/DMQ/Protocol/SigSubmission/Type.hs @@ -4,8 +4,9 @@ {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} module DMQ.Protocol.SigSubmission.Type @@ -14,21 +15,24 @@ module DMQ.Protocol.SigSubmission.Type , SigId (..) , SigBody (..) , SigKESSignature (..) - , SigKESPeriod , SigOpCertificate (..) , SigColdKey (..) , SigRaw (..) , SigRawWithSignedBytes (..) , Sig (Sig, SigWithBytes, sigRawWithSignedBytes, sigRawBytes, sigId, sigBody, sigExpiresAt, sigOpCertificate, sigKESPeriod, sigKESSignature, sigColdKey, sigSignedBytes, sigBytes) + , validateSig -- * `TxSubmission` mini-protocol , SigSubmission , module SigSubmission , POSIXTime -- * Utilities , CBORBytes (..) + -- * Re-exports from `kes-agent` + , KESPeriod (..) ) where import Data.Aeson +import Data.Bifunctor (first) import Data.ByteString (ByteString) import Data.ByteString.Base16 as BS.Base16 import Data.ByteString.Base16.Lazy as LBS.Base16 @@ -37,12 +41,15 @@ import Data.ByteString.Lazy.Char8 qualified as LBS.Char8 import Data.Text.Encoding qualified as Text import Data.Time.Clock.POSIX (POSIXTime) import Data.Typeable +import Data.Word (Word64) -import Cardano.Crypto.DSIGN.Class (DSIGNAlgorithm) -import Cardano.Crypto.KES.Class (VerKeyKES) --- import Cardano.Crypto.Util (SignableRepresentation (..)) +import Cardano.Crypto.DSIGN.Class (ContextDSIGN, DSIGNAlgorithm, VerKeyDSIGN) +import Cardano.Crypto.DSIGN.Class qualified as DSIGN +import Cardano.Crypto.KES.Class (KESAlgorithm (..), Signable) import Cardano.KESAgent.KES.Crypto as KES -import Cardano.KESAgent.KES.OCert (OCert (..)) +import Cardano.KESAgent.KES.Evolution qualified as KES +import Cardano.KESAgent.KES.OCert (KESPeriod (..), OCert (..), OCertSignable, + validateOCert) import Ouroboros.Network.Protocol.TxSubmission2.Type as SigSubmission hiding (TxSubmission2) @@ -66,13 +73,13 @@ newtype SigBody = SigBody { getSigBody :: ByteString } deriving stock (Show, Eq) --- TODO: --- This type should be something like: `SignedKES (KES crypto) SigPayload` -newtype SigKESSignature = SigKESSignature { getSigKESSignature :: ByteString } - deriving stock (Show, Eq) +newtype SigKESSignature crypto = SigKESSignature { getSigKESSignature :: SigKES (KES crypto) } + +deriving instance Show (SigKES (KES crypto)) + => Show (SigKESSignature crypto) +deriving instance Eq (SigKES (KES crypto)) + => Eq (SigKESSignature crypto) --- TODO: --- This type should be more than just a `ByteString`. newtype SigOpCertificate crypto = SigOpCertificate { getSigOpCertificate :: OCert crypto } deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) @@ -81,13 +88,16 @@ deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) => Show (SigOpCertificate crypto) deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) , Eq (VerKeyKES (KES crypto)) - ) => Eq (SigOpCertificate crypto) + ) => Eq (SigOpCertificate crypto) -type SigKESPeriod = Word +newtype SigColdKey crypto = SigColdKey { getSigColdKey :: VerKeyDSIGN (KES.DSIGN crypto) } -newtype SigColdKey = SigColdKey { getSigColdKey :: ByteString } - deriving stock (Show, Eq) +deriving instance Show (VerKeyDSIGN (KES.DSIGN crypto)) + => Show (SigColdKey crypto) + +deriving instance Eq (VerKeyDSIGN (KES.DSIGN crypto)) + => Eq (SigColdKey crypto) -- | Sig type consists of payload and its KES signature. -- @@ -95,23 +105,28 @@ newtype SigColdKey = SigColdKey { getSigColdKey :: ByteString } data SigRaw crypto = SigRaw { sigRawId :: SigId, sigRawBody :: SigBody, - sigRawKESPeriod :: SigKESPeriod, + sigRawKESPeriod :: KESPeriod, -- ^ KES period when this signature was created. -- -- NOTE: `kes-agent` library is using `Word` for KES period, CIP-137 -- requires `Word64`, thus we're only supporting 64-bit architectures. - sigRawExpiresAt :: POSIXTime, - sigRawKESSignature :: SigKESSignature, sigRawOpCertificate :: SigOpCertificate crypto, - sigRawColdKey :: SigColdKey + sigRawColdKey :: SigColdKey crypto, + sigRawExpiresAt :: POSIXTime, + sigRawKESSignature :: SigKESSignature crypto + -- ^ KES signature of all previous fields. + -- + -- NOTE: this field must be lazy, otetherwise tests will fail. } deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) , Show (VerKeyKES (KES crypto)) + , Show (SigKES (KES crypto)) ) => Show (SigRaw crypto) deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) , Eq (VerKeyKES (KES crypto)) + , Eq (SigKES (KES crypto)) ) => Eq (SigRaw crypto) @@ -151,10 +166,12 @@ data SigRawWithSignedBytes crypto = SigRawWithSignedBytes { deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) , Show (VerKeyKES (KES crypto)) + , Show (SigKES (KES crypto)) ) => Show (SigRawWithSignedBytes crypto) deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) , Eq (VerKeyKES (KES crypto)) + , Eq (SigKES (KES crypto)) ) => Eq (SigRawWithSignedBytes crypto) @@ -162,6 +179,7 @@ instance Crypto crypto => ToJSON (SigRawWithSignedBytes crypto) where toJSON SigRawWithSignedBytes {sigRaw} = toJSON sigRaw + data Sig crypto = SigWithBytes { sigRawBytes :: LBS.ByteString, -- ^ encoded `SigRaw` data type @@ -171,10 +189,12 @@ data Sig crypto = SigWithBytes { deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) , Show (VerKeyKES (KES crypto)) + , Show (SigKES (KES crypto)) ) => Show (Sig crypto) deriving instance ( DSIGNAlgorithm (KES.DSIGN crypto) , Eq (VerKeyKES (KES crypto)) + , Eq (SigKES (KES crypto)) ) => Eq (Sig crypto) @@ -187,10 +207,10 @@ instance Crypto crypto pattern Sig :: SigId -> SigBody - -> SigKESSignature - -> SigKESPeriod + -> SigKESSignature crypto + -> KESPeriod -> SigOpCertificate crypto - -> SigColdKey + -> SigColdKey crypto -> POSIXTime -> LBS.ByteString -> LBS.ByteString @@ -253,6 +273,62 @@ pattern instance Typeable crypto => ShowProxy (Sig crypto) where + +data SigValidationError = + InvalidKESSignature KESPeriod KESPeriod String + | InvalidSignatureOCERT + !Word64 -- OCert counter + !KESPeriod -- OCert KES period + !String -- DSIGN error message + | KESBeforeStartOCERT KESPeriod KESPeriod + | KESAfterEndOCERT KESPeriod KESPeriod + deriving Show + +validateSig :: forall crypto. + ( Crypto crypto + , ContextDSIGN (KES.DSIGN crypto) ~ () + , DSIGN.Signable (DSIGN crypto) (OCertSignable crypto) + , ContextKES (KES crypto) ~ () + , Signable (KES crypto) ByteString + ) + => KES.EvolutionConfig + -> Sig crypto + -> Either SigValidationError () +validateSig _ec + Sig { sigSignedBytes = signedBytes, + sigKESPeriod, + sigOpCertificate = SigOpCertificate ocert@OCert { + ocertKESPeriod, + ocertVkHot, + ocertN + }, + sigColdKey = SigColdKey coldKey, + sigKESSignature = SigKESSignature kesSig + } + = do + sigKESPeriod < endKESPeriod + ?! KESAfterEndOCERT endKESPeriod sigKESPeriod + sigKESPeriod >= startKESPeriod + ?! KESBeforeStartOCERT startKESPeriod sigKESPeriod + + -- validate OCert, which includes verifying its signature + validateOCert coldKey ocertVkHot ocert + ?!: InvalidSignatureOCERT ocertN sigKESPeriod + -- validate KES signature of the payload + verifyKES () ocertVkHot + (unKESPeriod sigKESPeriod - unKESPeriod startKESPeriod) + (LBS.toStrict signedBytes) + kesSig + ?!: InvalidKESSignature ocertKESPeriod sigKESPeriod + where + startKESPeriod, endKESPeriod :: KESPeriod + + startKESPeriod = ocertKESPeriod + -- TODO: is `totalPeriodsKES` the same as `praosMaxKESEvo` + -- or `sgMaxKESEvolution` in the genesis file? + endKESPeriod = KESPeriod $ unKESPeriod startKESPeriod + + totalPeriodsKES (Proxy :: Proxy (KES crypto)) + type SigSubmission crypto = TxSubmission2.TxSubmission2 SigId (Sig crypto) @@ -267,3 +343,19 @@ newtype CBORBytes = CBORBytes { getCBORBytes :: LBS.ByteString } instance Show CBORBytes where show = LBS.Char8.unpack . LBS.Base16.encode . getCBORBytes + + +-- +-- Utility functions +-- + +(?!:) :: Either e1 a -> (e1 -> e2) -> Either e2 a +(?!:) = flip first + +infix 1 ?!: + +(?!) :: Bool -> e -> Either e () +(?!) True _ = Right () +(?!) False e = Left e + +infix 1 ?! diff --git a/dmq-node/test/DMQ/Protocol/SigSubmission/Test.hs b/dmq-node/test/DMQ/Protocol/SigSubmission/Test.hs index bbccc6b1611..d4b6898c5e6 100644 --- a/dmq-node/test/DMQ/Protocol/SigSubmission/Test.hs +++ b/dmq-node/test/DMQ/Protocol/SigSubmission/Test.hs @@ -1,4 +1,6 @@ {-# LANGUAGE BangPatterns #-} +{-# LANGUAGE BlockArguments #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE DeriveFunctor #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE FlexibleInstances #-} @@ -9,21 +11,28 @@ {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeOperators #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -Wno-orphans #-} +#ifndef STANDARDCRYPTO_TESTS +{-# OPTIONS_GHC -Wno-unused-top-binds #-} +#endif -module DMQ.Protocol.SigSubmission.Test where +module DMQ.Protocol.SigSubmission.Test (tests) where import Codec.CBOR.Encoding qualified as CBOR import Codec.CBOR.Read qualified as CBOR import Codec.CBOR.Write qualified as CBOR +import Control.Monad (zipWithM, (>=>)) import Control.Monad.ST (runST) import Data.Bifunctor (second) +import Data.ByteString (ByteString) import Data.ByteString.Lazy qualified as BL import Data.List.NonEmpty qualified as NonEmpty +import Data.Typeable import Data.Word (Word32) import GHC.TypeNats (KnownNat) import System.IO.Unsafe (unsafePerformIO) @@ -31,12 +40,15 @@ import System.IO.Unsafe (unsafePerformIO) import Network.TypedProtocol.Codec import Network.TypedProtocol.Codec.Properties hiding (prop_codec) +import Cardano.Crypto.DSIGN.Class (DSIGNAlgorithm, SignKeyDSIGN, + deriveVerKeyDSIGN, encodeVerKeyDSIGN) import Cardano.Crypto.DSIGN.Class qualified as DSIGN -import Cardano.Crypto.KES.Class (KESAlgorithm (..), VerKeyKES) +import Cardano.Crypto.KES.Class (KESAlgorithm (..), VerKeyKES, encodeSigKES) import Cardano.Crypto.KES.Class qualified as KES import Cardano.Crypto.PinnedSizedBytes (PinnedSizedBytes, psbToByteString) import Cardano.Crypto.Seed (mkSeedFromBytes) import Cardano.KESAgent.KES.Crypto (Crypto (..)) +import Cardano.KESAgent.KES.Evolution qualified as KES import Cardano.KESAgent.KES.OCert (OCert (..)) import Cardano.KESAgent.KES.OCert qualified as KES import Cardano.KESAgent.Protocols.StandardCrypto (MockCrypto, StandardCrypto) @@ -48,7 +60,7 @@ import DMQ.Protocol.SigSubmission.Type import Ouroboros.Network.Protocol.TxSubmission2.Test (labelMsg) import Test.Ouroboros.Network.Protocol.Utils (prop_codec_cborM, - prop_codec_valid_cbor_encoding, splits2, splits3) + prop_codec_valid_cbor_encoding, splits2) import Test.QuickCheck.Instances.ByteString () import Test.Tasty @@ -59,37 +71,55 @@ tests :: TestTree tests = testGroup "DMQ.Protocol" [ testGroup "SigSubmission" - [ testGroup "mockcrypto" - [ testProperty "OCert" prop_codec_ocert_mockcrypto - , testProperty "Sig" prop_codec_sig_mockcrypto - , testProperty "codec" prop_codec_mockcrypto - , testProperty "codec id" prop_codec_id_mockcrypto - , testProperty "codec 2-splits" $ withMaxSize 20 - $ withMaxSuccess 20 - prop_codec_splits2_mockcrypto - , testProperty "codec 3-splits" $ withMaxSize 10 - $ withMaxSuccess 10 - prop_codec_splits3_mockcrypto - , testProperty "codec cbor" prop_codec_cbor_mockcrypto - , testProperty "codec valid cbor" prop_codec_valid_cbor_mockcrypto + [ testGroup "Codec" + [ testGroup "MockCrypto" + [ testProperty "OCert" prop_codec_ocert_mockcrypto + , testProperty "Sig" prop_codec_sig_mockcrypto + , testProperty "codec" prop_codec_mockcrypto + , testProperty "codec id" prop_codec_id_mockcrypto + , testProperty "codec 2-splits" $ withMaxSize 20 + $ withMaxSuccess 20 + prop_codec_splits2_mockcrypto + -- MockCrypto produces too large messages for this test to run: + -- , testProperty "codec 3-splits" $ withMaxSize 10 + -- $ withMaxSuccess 10 + -- prop_codec_splits3_mockcrypto + , testProperty "codec cbor" prop_codec_cbor_mockcrypto + , testProperty "codec valid cbor" prop_codec_valid_cbor_mockcrypto + , testProperty "OCert" prop_codec_cbor_mockcrypto + ] +#ifdef STANDARDCRYPTO_TESTS + , testGroup "StandardCrypto" + [ testProperty "OCert" prop_codec_ocert_standardcrypto + , testProperty "Sig" prop_codec_sig_standardcrypto + , testProperty "codec" prop_codec_standardcrypto + , testProperty "codec id" prop_codec_id_standardcrypto + , testProperty "codec 2-splits" $ withMaxSize 20 + $ withMaxSuccess 20 + prop_codec_splits2_standardcrypto + -- StandardCrypto produces too large messages for this test to run: + {- + , testProperty "codec 3-splits" $ withMaxSize 10 + $ withMaxSuccess 10 + prop_codec_splits3_standardcrypto + -} + , testProperty "codec cbor" prop_codec_cbor_standardcrypto + , testProperty "codec valid cbor" prop_codec_valid_cbor_standardcrypto + ] +#endif + ] + ] + , testGroup "Crypto" + [ testGroup "MockCrypto" + [ testProperty "KES sign verify" prop_sign_verify_mockcrypto + , testProperty "validateSig" prop_validateSig_mockcrypto ] - , testGroup "standardcrypto" - [ testProperty "OCert" prop_codec_ocert_standardcrypto - , testProperty "Sig" prop_codec_sig_standardcrypto - , testProperty "codec" prop_codec_standardcrypto - , testProperty "codec id" prop_codec_id_standardcrypto - , testProperty "codec 2-splits" $ withMaxSize 20 - $ withMaxSuccess 20 - prop_codec_splits2_standardcrypto - -- StandardCrypt produces too large messages for this test to run: - {- - , testProperty "codec 3-splits" $ withMaxSize 10 - $ withMaxSuccess 10 - prop_codec_splits3_standardcrypto - -} - , testProperty "codec cbor" prop_codec_cbor_standardcrypto - , testProperty "codec valid cbor" prop_codec_valid_cbor_standardcrypto +#ifdef STANDARDCRYPTO_TESTS + , testGroup "StandardCrypto" + [ testProperty "KES sign verify" prop_sign_verify_standardcrypto + , testProperty "validateSig" prop_validateSig_standardcrypto ] +#endif ] ] @@ -111,31 +141,65 @@ instance Arbitrary POSIXTime where -- shrink via Word32 (e.g. in seconds) shrink posix = realToFrac <$> shrink (floor @_ @Word32 posix) -instance Arbitrary SigKESSignature where - arbitrary = SigKESSignature <$> arbitrary - shrink = map SigKESSignature . shrink . getSigKESSignature -mkVerKeyKES +-- | Make a KES key pair. +-- +mkKeysKES :: forall kesCrypto. KESAlgorithm kesCrypto => PinnedSizedBytes (SeedSizeKES kesCrypto) - -> IO (VerKeyKES kesCrypto) -mkVerKeyKES seed = do - withMLockedSeedFromPSB seed $ \mseed -> - KES.genKeyKES mseed >>= deriveVerKeyKES + -> IO (SignKeyKES kesCrypto, VerKeyKES kesCrypto) +mkKeysKES seed = + withMLockedSeedFromPSB seed $ \mseed -> do + snKESKey <- KES.genKeyKES mseed + (snKESKey,) <$> deriveVerKeyKES snKESKey +-- | The idea of this data type is to go around limitation of QuickCheck `Gen` +-- type, which does not allow IO actions. So instead we generate some random +-- context (e.g. key seed) and then the data is created when the property +-- runs. +-- +-- Keeping the `key` seprate allows to have access to it when shrinking, see +-- `shrinkWithConstr`, this is important when the signed data is shrinked and +-- we need to update a KES signature as well. +-- +-- However the limitation is shrinking: it requires `unsafePerformIO` anyway, +-- see `shrinkWithConstr`. +-- +-- TODO: to avoid complexity can we use `UnsoundPureKESAlgorithm` instead of +-- `KESAlgorithm`? +-- data WithConstr ctx key a = - WithConstr { constr :: key -> a, + WithConstr { constr :: key -> IO a, mkKey :: ctx -> IO key, ctx :: ctx } deriving instance Functor (WithConstr ctx key) +withConstrBind :: WithConstr ctx key a -> (a -> IO b) -> WithConstr ctx key b +withConstrBind WithConstr { constr, mkKey, ctx } fn = + WithConstr { constr = constr >=> fn, + mkKey, + ctx + } + +runWithConstr :: WithConstr ctx key a -> IO a +runWithConstr WithConstr { constr, mkKey, ctx } = mkKey ctx >>= constr + +constrWithKeys + :: (keys -> IO a) + -> WithConstr ctx keys keys + -> WithConstr ctx keys a +constrWithKeys f WithConstr { constr, mkKey, ctx } = + WithConstr { constr = constr >=> f, + mkKey, + ctx + } constWithConstr :: a -> WithConstr [ctx] [key] a constWithConstr a = - WithConstr { constr = const a, + WithConstr { constr = const (pure a), mkKey = \_ -> pure [], ctx = [] } @@ -146,16 +210,16 @@ listWithConstr :: forall ctx key a b. -> WithConstr [ctx] [key] b listWithConstr constr' as = WithConstr { - constr = \keys -> constr' (zipWith ($) constrs keys), - mkKey = \ctxs -> sequence (zipWith ($) mkKeys ctxs), + constr = \keys -> constr' <$> zipWithM ($) constrs keys, + mkKey = \ctxs -> zipWithM ($) mkKeys ctxs, ctx = ctx <$> as } where - constrs :: [(key -> a)] + constrs :: [key -> IO a] constrs = constr <$> as - mkKeys :: [(ctx -> IO key)] + mkKeys :: [ctx -> IO key] mkKeys = mkKey <$> as @@ -168,7 +232,7 @@ shrinkWithConstrCtx constr@WithConstr { ctx } = sequenceWithConstr - :: (a -> key -> a) + :: (a -> key -> IO a) -> WithConstr ctx key [a] -> IO [WithConstr ctx key a] sequenceWithConstr update constr@WithConstr { mkKey, ctx } = do @@ -180,33 +244,30 @@ sequenceWithConstr update constr@WithConstr { mkKey, ctx } = do -- unsafePerformIO :( shrinkWithConstr :: Arbitrary ctx - => (a -> key -> a) + => (a -> key -> IO a) -> (a -> [a]) - -> WithConstr ctx key a + -> WithConstr ctx key a -> [WithConstr ctx key a] shrinkWithConstr update shrinker constr = unsafePerformIO (sequenceWithConstr update $ shrinker <$> constr) ++ shrinkWithConstrCtx constr -runWithConstr :: WithConstr ctx key a -> IO a -runWithConstr WithConstr { constr, mkKey, ctx } = constr <$> mkKey ctx - +type KESCTX size = PinnedSizedBytes size +type WithConstrKES size crypto a = WithConstr (KESCTX size) (SignKeyKES crypto, VerKeyKES crypto) a +type WithConstrKESList size crypto a = WithConstr [KESCTX size] [(SignKeyKES crypto, VerKeyKES crypto)] a -type VerKeyKESCTX size = PinnedSizedBytes size -type WithConstrVerKeyKES size crypto a = WithConstr (VerKeyKESCTX size) (VerKeyKES crypto) a -type WithConstrVerKeyKESList size crypto a = WithConstr [VerKeyKESCTX size] [VerKeyKES crypto] a -mkVerKeyKESConstr +mkKeysKESConstr :: forall kesCrypto. KESAlgorithm kesCrypto - => VerKeyKESCTX (SeedSizeKES kesCrypto) - -> WithConstrVerKeyKES (SeedSizeKES kesCrypto) - kesCrypto - (VerKeyKES kesCrypto) -mkVerKeyKESConstr ctx = - WithConstr { constr = id, - mkKey = mkVerKeyKES, + => KESCTX (SeedSizeKES kesCrypto) + -> WithConstrKES (SeedSizeKES kesCrypto) + kesCrypto + (SignKeyKES kesCrypto, VerKeyKES kesCrypto) +mkKeysKESConstr ctx = + WithConstr { constr = pure, + mkKey = mkKeysKES, ctx } @@ -214,71 +275,115 @@ instance ( size ~ SeedSizeKES kesCrypto , KnownNat size , KESAlgorithm kesCrypto ) - => Arbitrary (WithConstrVerKeyKES size kesCrypto (VerKeyKES kesCrypto)) where - arbitrary = mkVerKeyKESConstr <$> arbitrary + => Arbitrary (WithConstrKES size kesCrypto (SignKeyKES kesCrypto, VerKeyKES kesCrypto)) where + arbitrary = mkKeysKESConstr <$> arbitrary shrink = shrinkWithConstrCtx +-- | An auxiliary data type to hold KES keys along with an OCert, payload and +-- its KES signature. +data CryptoCtx crypto = CryptoCtx { + snKESKey :: SignKeyKES (KES crypto), + -- ^ signing KES key + vnKESKey :: VerKeyKES (KES crypto), + -- ^ verification KES key + coldKey :: SignKeyDSIGN (DSIGN crypto), + -- ^ signing cold key + ocert :: OCert crypto + -- ^ ocert + } + + instance ( Crypto crypto , DSIGN.Signable (DSIGN crypto) (KES.OCertSignable crypto) , DSIGN.ContextDSIGN (DSIGN crypto) ~ () + , ContextKES (KES crypto) ~ () , kesCrypto ~ KES crypto + , KESAlgorithm kesCrypto + , Signable kesCrypto ByteString , size ~ SeedSizeKES kesCrypto , KnownNat size ) - => Arbitrary (WithConstrVerKeyKES size kesCrypto (OCert crypto)) where + => Arbitrary (WithConstrKES size kesCrypto (CryptoCtx crypto)) where arbitrary = do - verKeyKES <- arbitrary + withKeys <- arbitrary n <- arbitrary seedColdKey :: PinnedSizedBytes (DSIGN.SeedSizeDSIGN (DSIGN crypto)) <- arbitrary - let !skCold = DSIGN.genKeyDSIGN (mkSeedFromBytes . psbToByteString $ seedColdKey) + let !coldKey = DSIGN.genKeyDSIGN (mkSeedFromBytes . psbToByteString $ seedColdKey) period <- KES.KESPeriod <$> arbitrary - return $ fmap (\vkKES -> KES.makeOCert vkKES n period skCold) verKeyKES - shrink = shrinkWithConstrCtx - - -instance ( kesCrypto ~ KES crypto - , size ~ SeedSizeKES kesCrypto - , KnownNat size - , Arbitrary (WithConstrVerKeyKES size kesCrypto (OCert crypto)) - ) - => Arbitrary (WithConstrVerKeyKES size kesCrypto (SigOpCertificate crypto)) where - arbitrary = fmap SigOpCertificate <$> arbitrary + return $ constrWithKeys + (\(snKESKey, vnKESKey) -> + return $ CryptoCtx { + snKESKey, + vnKESKey, + coldKey, + ocert = KES.makeOCert vnKESKey n period coldKey + }) + withKeys shrink = shrinkWithConstrCtx instance ( Crypto crypto , kesCrypto ~ KES crypto + , ContextKES kesCrypto ~ () , size ~ SeedSizeKES kesCrypto - , Arbitrary (WithConstrVerKeyKES size kesCrypto (OCert crypto)) + , Signable kesCrypto ByteString + , dsignCrypto ~ DSIGN crypto + , DSIGNAlgorithm dsignCrypto + , Arbitrary (WithConstrKES size kesCrypto (CryptoCtx crypto)) ) - => Arbitrary (WithConstrVerKeyKES size kesCrypto (SigRawWithSignedBytes crypto)) where + => Arbitrary (WithConstrKES size kesCrypto (SigRawWithSignedBytes crypto)) where arbitrary = do sigRawId <- arbitrary - sigRawBody <- arbitrary sigRawExpiresAt <- arbitrary - opCert <- arbitrary - sigRawKESPeriod <- arbitrary - sigRawKESSignature <- arbitrary - sigRawColdKey <- arbitrary - return $ fmap (\cert -> let sigRawOpCertificate = SigOpCertificate cert - sigRaw = SigRaw { - sigRawId, - sigRawBody, - sigRawKESPeriod, - sigRawOpCertificate, - sigRawColdKey, - sigRawExpiresAt, - sigRawKESSignature = undefined -- to be filled below - } - signedBytes = CBOR.toStrictByteString (encodeSigRaw' sigRaw) - in - SigRawWithSignedBytes { - sigRawSignedBytes = BL.fromStrict signedBytes, - sigRaw = sigRaw { sigRawKESSignature } - } - ) opCert + let maxKESOffset :: Word + maxKESOffset = totalPeriodsKES (Proxy :: Proxy kesCrypto) + -- offset since `ocertKESPeriod`, so that the signature is still valid + kesOffset <- arbitrary `suchThat` (< maxKESOffset) + payload <- arbitrary + crypto <- arbitrary + return $ withConstrBind crypto \CryptoCtx {ocert, coldKey, snKESKey} -> do + let sigRawOpCertificate :: SigOpCertificate crypto + sigRawOpCertificate = SigOpCertificate ocert + + sigRawBody :: SigBody + sigRawBody = SigBody payload + + sigRawColdKey :: SigColdKey crypto + sigRawColdKey = SigColdKey $ deriveVerKeyDSIGN coldKey + + sigRawKESPeriod :: KESPeriod + sigRawKESPeriod = KESPeriod $ unKESPeriod (ocertKESPeriod ocert) + + kesOffset + + sigRaw = SigRaw { + sigRawId, + sigRawBody, + sigRawKESPeriod, + sigRawOpCertificate, + sigRawColdKey, + sigRawExpiresAt, + sigRawKESSignature = undefined -- to be filled below + } + signedBytes = CBOR.toStrictByteString (encodeSigRaw' sigRaw) + + -- evolve the key to the target period + mbSnKESKey <- KES.updateKESTo () sigRawKESPeriod ocert (KES.SignKeyWithPeriodKES snKESKey 0) + case mbSnKESKey of + Just (KES.SignKeyWithPeriodKES snKESKey' _) -> do + -- signed bytes with the snKESKey' + sigRawKESSignature + <- SigKESSignature + <$> KES.signKES () kesOffset signedBytes snKESKey' + return SigRawWithSignedBytes { + sigRawSignedBytes = BL.fromStrict signedBytes, + sigRaw = sigRaw { sigRawKESSignature } + } + Nothing -> + error $ "arbitrary SigRawWithSignedBytes: could not evolve KES key to the target period by KES offset: " + ++ show kesOffset + shrink = shrinkWithConstrSigRawWithSignedBytes @@ -289,26 +394,38 @@ instance ( Crypto crypto -- shrinkWithConstrSigRawWithSignedBytes :: forall crypto. - Crypto crypto - => WithConstrVerKeyKES (SeedSizeKES (KES crypto)) (KES crypto) (SigRawWithSignedBytes crypto) - -> [WithConstrVerKeyKES (SeedSizeKES (KES crypto)) (KES crypto) (SigRawWithSignedBytes crypto)] + ( Crypto crypto + , ContextKES (KES crypto) ~ () + , Signable (KES crypto) ByteString + ) + => WithConstrKES (SeedSizeKES (KES crypto)) (KES crypto) (SigRawWithSignedBytes crypto) + -> [WithConstrKES (SeedSizeKES (KES crypto)) (KES crypto) (SigRawWithSignedBytes crypto)] shrinkWithConstrSigRawWithSignedBytes = shrinkWithConstr updateFn shrinkSigRawWithSignedBytesFn where updateFn :: SigRawWithSignedBytes crypto - -> VerKeyKES (KES crypto) - -> SigRawWithSignedBytes crypto + -> (SignKeyKES (KES crypto), VerKeyKES (KES crypto)) + -> IO (SigRawWithSignedBytes crypto) updateFn SigRawWithSignedBytes { - sigRaw = sigRaw@SigRaw { sigRawOpCertificate = SigOpCertificate ocert }, + sigRaw = sigRaw@SigRaw { sigRawOpCertificate = SigOpCertificate ocert, + sigRawKESPeriod + }, sigRawSignedBytes } - ocertVkHot - = + (snKeyKES, ocertVkHot) + = do let sigRaw' = sigRaw { sigRawOpCertificate = SigOpCertificate ocert { ocertVkHot } } - in SigRawWithSignedBytes { - sigRaw = sigRaw', + -- update KES key to sigRawKESPeriod + Just (KES.SignKeyWithPeriodKES snKeyKES' _) + <- KES.updateKESTo () sigRawKESPeriod ocert (KES.SignKeyWithPeriodKES snKeyKES 0) + -- sign the message + sign <- KES.signKES () (KES.unKESPeriod sigRawKESPeriod - KES.unKESPeriod (ocertKESPeriod ocert)) + (BL.toStrict sigRawSignedBytes) + snKeyKES' + pure $ SigRawWithSignedBytes { + sigRaw = sigRaw' { sigRawKESSignature = SigKESSignature sign }, sigRawSignedBytes } @@ -326,11 +443,18 @@ shrinkSigRawWithSignedBytesFn SigRawWithSignedBytes { sigRaw } = | sigRaw' <- shrinkSigRawFn sigRaw , let sigRawSignedBytes' = CBOR.toLazyByteString (encodeSigRaw' sigRaw') ] + + +-- | Pure shrinking function for `SigRaw`. It does not update the KES +-- signature. +-- shrinkSigRawFn :: SigRaw crypto -> [SigRaw crypto] shrinkSigRawFn sig@SigRaw { sigRawId, - sigRawBody, - sigRawExpiresAt - } = + sigRawBody, + sigRawKESPeriod, + sigRawExpiresAt, + sigRawOpCertificate = SigOpCertificate ocert + } = [ sig { sigRawId = sigRawId' } | sigRawId' <- shrink sigRawId ] @@ -339,23 +463,15 @@ shrinkSigRawFn sig@SigRaw { sigRawId, | sigRawBody' <- shrink sigRawBody ] ++ + [ sig { sigRawKESPeriod = sigRawKESPeriod' } + | sigRawKESPeriod' <- KESPeriod <$> shrink (unKESPeriod sigRawKESPeriod) + , sigRawKESPeriod' >= ocertKESPeriod ocert + ] + ++ [ sig { sigRawExpiresAt = sigRawExpiresAt' } | sigRawExpiresAt' <- shrink sigRawExpiresAt ] -instance Arbitrary SigColdKey where - arbitrary = SigColdKey <$> arbitrary - shrink = map SigColdKey . shrink . getSigColdKey - - -mkSigRawWithSignedBytes :: SigRaw crypto -> SigRawWithSignedBytes crypto -mkSigRawWithSignedBytes sigRaw = - SigRawWithSignedBytes { - sigRaw, - sigRawSignedBytes - } - where - sigRawSignedBytes = CBOR.toLazyByteString (encodeSigRaw' sigRaw) -- NOTE: this function is not exposed in the main library on purpose. We -- should never construct `Sig` by serialising `SigRaw`. @@ -368,7 +484,7 @@ mkSig sigRawWithSignedBytes@SigRawWithSignedBytes { sigRaw } = sigRawWithSignedBytes } where - sigRawBytes = CBOR.toLazyByteString (encodeSigRaw sigRaw) + sigRawBytes = CBOR.toLazyByteString (encodeSigRaw sigRaw) -- encode only signed part @@ -383,7 +499,7 @@ encodeSigRaw' SigRaw { = CBOR.encodeListLen 4 <> encodeSigId sigRawId <> CBOR.encodeBytes (getSigBody sigRawBody) - <> CBOR.encodeWord sigRawKESPeriod + <> CBOR.encodeWord (unKESPeriod sigRawKESPeriod) <> CBOR.encodeWord32 (floor sigRawExpiresAt) -- encode together with KES signature, OCert and cold key. @@ -393,41 +509,60 @@ encodeSigRaw :: Crypto crypto encodeSigRaw sigRaw@SigRaw { sigRawKESSignature, sigRawOpCertificate, sigRawColdKey } = CBOR.encodeListLen 4 <> encodeSigRaw' sigRaw - <> CBOR.encodeBytes (getSigKESSignature sigRawKESSignature) + <> encodeSigKES (getSigKESSignature sigRawKESSignature) <> encodeSigOpCertificate sigRawOpCertificate - <> CBOR.encodeBytes (getSigColdKey sigRawColdKey) - + <> encodeVerKeyDSIGN (getSigColdKey sigRawColdKey) -shrinkSigFn :: forall crypto. Crypto crypto +-- note: KES signature is updated by updateSigFn +shrinkSigFn :: forall crypto. + Crypto crypto => Sig crypto -> [Sig crypto] shrinkSigFn SigWithBytes {sigRawWithSignedBytes = SigRawWithSignedBytes { sigRaw, sigRawSignedBytes } } = mkSig . (\sigRaw' -> SigRawWithSignedBytes { sigRaw = sigRaw', sigRawSignedBytes }) <$> shrinkSigRawFn sigRaw + +updateSigFn :: forall crypto. + KESAlgorithm (KES crypto) + => ContextKES (KES crypto) ~ () + => Signable (KES crypto) ByteString + => Sig crypto + -> (SignKeyKES (KES crypto), VerKeyKES (KES crypto)) + -> IO (Sig crypto) +updateSigFn + sig@Sig { sigOpCertificate = SigOpCertificate opCert, + sigBody = SigBody body + } + (snKESKey, vnKESKey) + = do + signature <- KES.signKES () (KES.unKESPeriod (ocertKESPeriod opCert)) body snKESKey + return $ sig { sigOpCertificate = SigOpCertificate opCert { ocertVkHot = vnKESKey}, + sigKESSignature = SigKESSignature signature + } + + instance ( Crypto crypto , DSIGN.ContextDSIGN (DSIGN crypto) ~ () , DSIGN.Signable (DSIGN crypto) (KES.OCertSignable crypto) , kesCrypto ~ KES crypto + , ContextKES kesCrypto ~ () + , Signable kesCrypto ByteString , size ~ SeedSizeKES kesCrypto , KnownNat size ) - => Arbitrary (WithConstrVerKeyKES size kesCrypto (Sig crypto)) where + => Arbitrary (WithConstrKES size kesCrypto (Sig crypto)) where arbitrary = fmap mkSig <$> arbitrary shrink = shrinkWithConstr updateSigFn shrinkSigFn -updateSigFn :: Sig crypto -> VerKeyKES (KES crypto) -> Sig crypto -updateSigFn - sig@Sig {sigOpCertificate = SigOpCertificate opCert} - ocertVkHot - = - sig { sigOpCertificate = SigOpCertificate opCert { ocertVkHot } } - instance ( kesCrypto ~ KES crypto + , KESAlgorithm kesCrypto + , ContextKES kesCrypto ~ () + , Signable kesCrypto ByteString , size ~ SeedSizeKES kesCrypto , KnownNat size - , Arbitrary (WithConstrVerKeyKES size kesCrypto (Sig crypto)) + , Arbitrary (WithConstrKES size kesCrypto (Sig crypto)) ) - => Arbitrary (WithConstrVerKeyKESList size kesCrypto (AnyMessage (SigSubmission crypto))) where + => Arbitrary (WithConstrKESList size kesCrypto (AnyMessage (SigSubmission crypto))) where arbitrary = oneof [ pure . constWithConstr $ AnyMessage MsgInit , constWithConstr . AnyMessage <$> @@ -451,15 +586,19 @@ instance ( kesCrypto ~ KES crypto , constWithConstr . AnyMessage <$> MsgRequestTxs <$> arbitrary , listWithConstr (AnyMessage . MsgReplyTxs) - <$> (arbitrary :: Gen [WithConstrVerKeyKES size kesCrypto (Sig crypto)]) + <$> (arbitrary :: Gen [WithConstrKES size kesCrypto (Sig crypto)]) , constWithConstr . AnyMessage <$> pure MsgDone ] shrink = shrinkWithConstr updateFn shrinkFn where - updateFn :: AnyMessage (SigSubmission crypto) -> [VerKeyKES kesCrypto] -> AnyMessage (SigSubmission crypto) - updateFn (AnyMessage (MsgReplyTxs txs)) vkKeyKESs = AnyMessage (MsgReplyTxs (zipWith updateSigFn txs vkKeyKESs)) - updateFn msg _ = msg + updateFn :: AnyMessage (SigSubmission crypto) + -> [(SignKeyKES kesCrypto, VerKeyKES kesCrypto)] + -> IO (AnyMessage (SigSubmission crypto)) + updateFn (AnyMessage (MsgReplyTxs txs)) keys = do + sigs <- traverse (uncurry updateSigFn) (zip txs keys) + return $ AnyMessage (MsgReplyTxs sigs) + updateFn msg _ = pure msg shrinkFn :: AnyMessage (SigSubmission crypto) -> [AnyMessage (SigSubmission crypto)] shrinkFn = \case @@ -494,10 +633,10 @@ instance ( kesCrypto ~ KES crypto prop_codec_ocert :: forall crypto. Crypto crypto - => WithConstrVerKeyKES (SeedSizeKES (KES crypto)) (KES crypto) (OCert crypto) + => WithConstrKES (SeedSizeKES (KES crypto)) (KES crypto) (CryptoCtx crypto) -> Property prop_codec_ocert constr = ioProperty $ do - ocert <- runWithConstr constr + CryptoCtx { ocert } <- runWithConstr constr return . counterexample (show ocert) $ let encoded = CBOR.toLazyByteString (encodeSigOpCertificate (SigOpCertificate ocert)) in case CBOR.deserialiseFromBytes decodeSigOpCertificate encoded of @@ -507,12 +646,12 @@ prop_codec_ocert constr = ioProperty $ do .&&. BL.null bytes prop_codec_ocert_mockcrypto - :: Blind (WithConstrVerKeyKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (OCert MockCrypto)) + :: Blind (WithConstrKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (CryptoCtx MockCrypto)) -> Property prop_codec_ocert_mockcrypto = prop_codec_ocert . getBlind prop_codec_ocert_standardcrypto - :: Blind (WithConstrVerKeyKES (SeedSizeKES (KES StandardCrypto)) (KES StandardCrypto) (OCert StandardCrypto)) + :: Blind (WithConstrKES (SeedSizeKES (KES StandardCrypto)) (KES StandardCrypto) (CryptoCtx StandardCrypto)) -> Property prop_codec_ocert_standardcrypto = prop_codec_ocert . getBlind @@ -523,7 +662,7 @@ prop_codec_ocert_standardcrypto = prop_codec_ocert . getBlind -- * signed bytes match the encoding of `encodeSigRaw'`. prop_codec_sig :: forall crypto. Crypto crypto - => WithConstrVerKeyKES (SeedSizeKES (KES crypto)) (KES crypto) (Sig crypto) + => WithConstrKES (SeedSizeKES (KES crypto)) (KES crypto) (Sig crypto) -> Property prop_codec_sig constr = ioProperty $ do sig <- runWithConstr constr @@ -556,17 +695,17 @@ prop_codec_sig constr = ioProperty $ do .&&. BL.null leftovers prop_codec_sig_mockcrypto - :: Blind (WithConstrVerKeyKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (Sig MockCrypto)) + :: Blind (WithConstrKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (Sig MockCrypto)) -> Property prop_codec_sig_mockcrypto = prop_codec_sig . getBlind prop_codec_sig_standardcrypto - :: Blind (WithConstrVerKeyKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (Sig MockCrypto)) + :: Blind (WithConstrKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (Sig MockCrypto)) -> Property prop_codec_sig_standardcrypto = prop_codec_sig . getBlind -type AnySigMessage crypto = WithConstrVerKeyKESList (SeedSizeKES (KES crypto)) (KES crypto) (AnyMessage (SigSubmission crypto)) +type AnySigMessage crypto = WithConstrKESList (SeedSizeKES (KES crypto)) (KES crypto) (AnyMessage (SigSubmission crypto)) prop_codec :: forall crypto. Crypto crypto @@ -619,6 +758,9 @@ prop_codec_splits2_standardcrypto :: Blind (AnySigMessage StandardCrypto) -> Pro prop_codec_splits2_standardcrypto = prop_codec_splits2 . getBlind +{- +-- TODO: we need a different splits3 function that does not explore all the +-- ways of splitting a message into three chunks. prop_codec_splits3 :: forall crypto. Crypto crypto => AnySigMessage crypto -> Property prop_codec_splits3 constr = ioProperty $ do @@ -629,9 +771,12 @@ prop_codec_splits3 constr = ioProperty $ do prop_codec_splits3_mockcrypto :: Blind (AnySigMessage MockCrypto) -> Property prop_codec_splits3_mockcrypto = prop_codec_splits3 . getBlind +-} +{- prop_codec_splits3_standardcrypto :: Blind (AnySigMessage StandardCrypto) -> Property prop_codec_splits3_standardcrypto = prop_codec_splits3 . getBlind +-} prop_codec_cbor @@ -672,3 +817,74 @@ prop_codec_valid_cbor_standardcrypto :: Blind (AnySigMessage StandardCrypto) -> Property prop_codec_valid_cbor_standardcrypto = prop_codec_valid_cbor . getBlind + + +-- | Check that the KES signature is valid. +-- +prop_validateSig + :: ( Crypto crypto + , DSIGN.ContextDSIGN (DSIGN crypto) ~ () + , DSIGN.Signable (DSIGN crypto) (KES.OCertSignable crypto) + , KES.ContextKES (KES crypto) ~ () + , KES.Signable (KES crypto) ByteString + ) + => WithConstrKES size kesCrypt (Sig crypto) + -> Property +prop_validateSig constr = ioProperty $ do + sig <- runWithConstr constr + return $ case validateSig KES.defEvolutionConfig sig of + Left err -> counterexample ("KES seed: " ++ show (ctx constr)) + . counterexample ("KES vk key: " ++ show (ocertVkHot . getSigOpCertificate . sigOpCertificate $ sig)) + . counterexample (show sig) + . counterexample (show err) + $ False + Right () -> property True + +prop_validateSig_mockcrypto + :: Blind (WithConstrKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (Sig MockCrypto)) + -> Property +prop_validateSig_mockcrypto = prop_validateSig . getBlind + +-- TODO: FAILS, why? +prop_validateSig_standardcrypto + :: Blind (WithConstrKES (SeedSizeKES (KES StandardCrypto)) (KES StandardCrypto) (Sig StandardCrypto)) + -> Property +prop_validateSig_standardcrypto = prop_validateSig . getBlind + + +-- | Sign & verify a payload with KES keys. +-- +prop_sign_verify + :: ( Crypto crypto + , ContextKES (KES crypto) ~ () + , Signable (KES crypto) ByteString + ) + => WithConstrKES (SeedSizeKES (KES crypto)) (KES crypto) (CryptoCtx crypto) + -- ^ KES keys + -> ByteString + -- ^ payload + -> Property +prop_sign_verify constr payload = ioProperty $ do + CryptoCtx { snKESKey, vnKESKey } <- runWithConstr constr + signed <- KES.signKES () 0 payload snKESKey + let res = KES.verifyKES () vnKESKey 0 payload signed + return $ counterexample "KES signature does not verify" + $ case res of + Left err -> counterexample (show err) + . counterexample ("vnKESKey: " ++ show vnKESKey) + . counterexample ("signature: " ++ show signed) + $ False + Right () -> property True + + +prop_sign_verify_mockcrypto + :: Blind (WithConstrKES (SeedSizeKES (KES MockCrypto)) (KES MockCrypto) (CryptoCtx MockCrypto)) + -> ByteString + -> Property +prop_sign_verify_mockcrypto = prop_sign_verify . getBlind + +prop_sign_verify_standardcrypto + :: Blind (WithConstrKES (SeedSizeKES (KES StandardCrypto)) (KES StandardCrypto) (CryptoCtx StandardCrypto)) + -> ByteString + -> Property +prop_sign_verify_standardcrypto = prop_sign_verify . getBlind diff --git a/nix/ouroboros-network.nix b/nix/ouroboros-network.nix index c819a4d09d0..f3a54858a59 100644 --- a/nix/ouroboros-network.nix +++ b/nix/ouroboros-network.nix @@ -46,7 +46,7 @@ let compiler-nix-name = lib.mkDefault defaultCompiler; cabalProjectLocal = if pkgs.stdenv.hostPlatform.isWindows - then lib.readFile ../scripts/ci/cabal.project.local.Windows + then lib.readFile ../scripts/ci/cabal.project.local.Nix.Windows else lib.readFile ../scripts/ci/cabal.project.local.Linux; # diff --git a/ouroboros-network/changelog.d/20251016_141814_coot_dmq_signature_validation.md b/ouroboros-network/changelog.d/20251016_141814_coot_dmq_signature_validation.md new file mode 100644 index 00000000000..7994264a47f --- /dev/null +++ b/ouroboros-network/changelog.d/20251016_141814_coot_dmq_signature_validation.md @@ -0,0 +1,6 @@ +### Breaking + +- Ouroboros.Network.TxSubmission.Mempool.Simple API changes: + - `Mempool` is parametrised over `txid` and `tx` types + - `new` takes `tx -> txid` getter function + diff --git a/ouroboros-network/demo/ping-pong.hs b/ouroboros-network/demo/ping-pong.hs index 47639af2ba4..78e24e992b5 100644 --- a/ouroboros-network/demo/ping-pong.hs +++ b/ouroboros-network/demo/ping-pong.hs @@ -57,7 +57,8 @@ main = do rmIfExists defaultLocalSocketAddrPath void serverPingPong - "pingpong2":"client":[] -> clientPingPong2 + "pingpong2":"client":[] -> clientPingPong2 False + "pingpong2":"client-flood":[] -> clientPingPong2 True "pingpong2":"server":[] -> do rmIfExists defaultLocalSocketAddrPath void serverPingPong2 @@ -69,7 +70,8 @@ instance ShowProxy PingPong where usage :: IO () usage = do - hPutStrLn stderr "usage: demo-ping-pong [pingpong|pingpong2] {client|server} [addr]" + hPutStrLn stderr $ "usage: demo-ping-pong pingpong {client|client-pipelined|server}\n" + ++ " demo-ping-pong pingpong2 {client|client-flood|server}" exitFailure defaultLocalSocketAddrPath :: FilePath @@ -143,7 +145,7 @@ clientPingPong pipelined = mkMiniProtocolCbFromPeerPipelined $ \_ctx -> ( contramap show stdoutTracer , codecPingPong - , void $ pingPongClientPeerPipelined (pingPongClientPipelinedMax 5) + , void $ pingPongClientPeerPipelined (pingPongClientPipelinedMax 15) ) | otherwise = @@ -151,7 +153,7 @@ clientPingPong pipelined = mkMiniProtocolCbFromPeer $ \_ctx -> ( contramap show stdoutTracer , codecPingPong - , pingPongClientPeer (pingPongClientCount 5) + , pingPongClientPeer (pingPongClientCount 15) ) @@ -211,8 +213,8 @@ demoProtocol1 pingPong pingPong' = ] -clientPingPong2 :: IO () -clientPingPong2 = +clientPingPong2 :: Bool -> IO () +clientPingPong2 flood = withIOManager $ \iomgr -> void $ do connectToNode (Snocket.localSnocket iomgr) @@ -233,12 +235,17 @@ clientPingPong2 = Mx.InitiatorMode addr LBS.ByteString IO () Void app = demoProtocol1 pingpong pingpong' + client :: PingPongClient IO () + client = if flood + then pingPongClientFlood + else pingPongClientCount 15 + pingpong = InitiatorProtocolOnly $ mkMiniProtocolCbFromPeer $ \_ctx -> ( contramap (show . (,) (1 :: Int)) tracer , codecPingPong - , pingPongClientPeer (pingPongClientCount 5) + , pingPongClientPeer client ) pingpong'= @@ -246,7 +253,7 @@ clientPingPong2 = mkMiniProtocolCbFromPeer $ \_ctx -> ( contramap (show . (,) (2 :: Int)) tracer , codecPingPong - , pingPongClientPeer (pingPongClientCount 5) + , pingPongClientPeer client ) diff --git a/ouroboros-network/lib/Ouroboros/Network/TxSubmission/Mempool/Simple.hs b/ouroboros-network/lib/Ouroboros/Network/TxSubmission/Mempool/Simple.hs index 94a4bece425..75e49ace3e2 100644 --- a/ouroboros-network/lib/Ouroboros/Network/TxSubmission/Mempool/Simple.hs +++ b/ouroboros-network/lib/Ouroboros/Network/TxSubmission/Mempool/Simple.hs @@ -8,6 +8,7 @@ -- module Ouroboros.Network.TxSubmission.Mempool.Simple ( Mempool (..) + , MempoolSeq (..) , empty , new , read @@ -30,6 +31,7 @@ import Data.List (find, nubBy) import Data.Maybe (isJust) import Data.Sequence (Seq) import Data.Sequence qualified as Seq +import Data.Set (Set) import Data.Set qualified as Set import Data.Typeable (Typeable) @@ -38,25 +40,38 @@ import Ouroboros.Network.TxSubmission.Inbound.V2.Types import Ouroboros.Network.TxSubmission.Mempool.Reader +data MempoolSeq txid tx = MempoolSeq { + mempoolSet :: !(Set txid), + -- ^ cached set of `txid`s in the mempool + mempoolSeq :: !(Seq tx) + -- ^ sequence of all `tx`s + } + -- | A simple in-memory mempool implementation. -- -newtype Mempool m tx = Mempool (StrictTVar m (Seq tx)) +newtype Mempool m txid tx = Mempool (StrictTVar m (MempoolSeq txid tx)) -empty :: MonadSTM m => m (Mempool m tx) -empty = Mempool <$> newTVarIO Seq.empty +empty :: MonadSTM m => m (Mempool m txid tx) +empty = Mempool <$> newTVarIO (MempoolSeq Set.empty Seq.empty) -new :: MonadSTM m - => [tx] - -> m (Mempool m tx) -new = fmap Mempool - . newTVarIO - . Seq.fromList +new :: ( MonadSTM m + , Ord txid + ) + => (tx -> txid) + -> [tx] + -> m (Mempool m txid tx) +new getTxId txs = + fmap Mempool + . newTVarIO + $ MempoolSeq { mempoolSet = Set.fromList (getTxId <$> txs), + mempoolSeq = Seq.fromList txs + } -read :: MonadSTM m => Mempool m tx -> m [tx] -read (Mempool mempool) = toList <$> readTVarIO mempool +read :: MonadSTM m => Mempool m txid tx -> m [tx] +read (Mempool mempool) = toList . mempoolSeq <$> readTVarIO mempool getReader :: forall tx txid m. @@ -65,7 +80,7 @@ getReader :: forall tx txid m. ) => (tx -> txid) -> (tx -> SizeInBytes) - -> Mempool m tx + -> Mempool m txid tx -> TxSubmissionMempoolReader txid tx Int m getReader getTxId getTxSize (Mempool mempool) = -- Using `0`-based index. `mempoolZeroIdx = -1` so that @@ -75,7 +90,7 @@ getReader getTxId getTxSize (Mempool mempool) = } where mempoolGetSnapshot :: STM m (MempoolSnapshot txid tx Int) - mempoolGetSnapshot = getSnapshot <$> readTVar mempool + mempoolGetSnapshot = getSnapshot . mempoolSeq <$> readTVar mempool getSnapshot :: Seq tx -> MempoolSnapshot txid tx Int @@ -124,7 +139,7 @@ getWriter :: forall tx txid ctx failure m. -- ^ validate a tx, any failing `tx` throws an exception. -> (failure -> Bool) -- ^ return `True` when a failure should throw an exception - -> Mempool m tx + -> Mempool m txid tx -> TxSubmissionMempoolWriter txid tx Int m getWriter getTxId getValidationCtx validateTx failureFilterFn (Mempool mempool) = TxSubmissionMempoolWriter { @@ -133,11 +148,8 @@ getWriter getTxId getValidationCtx validateTx failureFilterFn (Mempool mempool) mempoolAddTxs = \txs -> do ctx <- getValidationCtx (invalidTxIds, validTxs) <- atomically $ do - mempoolTxs <- readTVar mempool - let -- TODO: set of current ids should be constructed incrementally, - -- e.g. it should be part of mempoolTxs - currentIds = Set.fromList (map getTxId (toList mempoolTxs)) - (invalidTxIds, validTxs) = + MempoolSeq { mempoolSet, mempoolSeq } <- readTVar mempool + let (invalidTxIds, validTxs) = bimap (filter (failureFilterFn . snd)) (nubBy (on (==) getTxId)) . partitionEithers @@ -145,9 +157,14 @@ getWriter getTxId getValidationCtx validateTx failureFilterFn (Mempool mempool) Left e -> Left (getTxId tx, e) Right _ -> Right tx ) - . filter (\tx -> getTxId tx `Set.notMember` currentIds) + . filter (\tx -> getTxId tx `Set.notMember` mempoolSet) $ txs - mempoolTxs' = Foldable.foldl' (Seq.|>) mempoolTxs validTxs + mempoolTxs' = MempoolSeq { + mempoolSet = Foldable.foldl' (\s tx -> getTxId tx `Set.insert` s) + mempoolSet + validTxs, + mempoolSeq = Foldable.foldl' (Seq.|>) mempoolSeq validTxs + } writeTVar mempool mempoolTxs' return (invalidTxIds, map getTxId validTxs) when (not (null invalidTxIds)) $ diff --git a/ouroboros-network/protocols/tests-lib/Ouroboros/Network/Protocol/TxSubmission2/Test.hs b/ouroboros-network/protocols/tests-lib/Ouroboros/Network/Protocol/TxSubmission2/Test.hs index 6de56ad4951..6c0fc651fc0 100644 --- a/ouroboros-network/protocols/tests-lib/Ouroboros/Network/Protocol/TxSubmission2/Test.hs +++ b/ouroboros-network/protocols/tests-lib/Ouroboros/Network/Protocol/TxSubmission2/Test.hs @@ -435,8 +435,8 @@ labelMsg (AnyMessage msg) = label (case msg of MsgInit -> "MsgInit" MsgRequestTxIds {} -> "MsgRequestTxIds" - MsgReplyTxIds as -> "MsgReplyTxIds " ++ renderRanges 3 (length as) - MsgRequestTxs as -> "MsgRequestTxs " ++ renderRanges 3 (length as) - MsgReplyTxs as -> "MsgReplyTxs " ++ renderRanges 3 (length as) + MsgReplyTxIds as -> "MsgReplyTxIds " ++ renderRanges 25 (length as) + MsgRequestTxs as -> "MsgRequestTxs " ++ renderRanges 25 (length as) + MsgReplyTxs as -> "MsgReplyTxs " ++ renderRanges 25 (length as) MsgDone -> "MsgDone" ) diff --git a/ouroboros-network/tests/lib/Test/Ouroboros/Network/Diffusion/Node/Kernel.hs b/ouroboros-network/tests/lib/Test/Ouroboros/Network/Diffusion/Node/Kernel.hs index 6ab63cd8955..f6a6065565d 100644 --- a/ouroboros-network/tests/lib/Test/Ouroboros/Network/Diffusion/Node/Kernel.hs +++ b/ouroboros-network/tests/lib/Test/Ouroboros/Network/Diffusion/Node/Kernel.hs @@ -310,7 +310,7 @@ data NodeKernel header block s txid m = NodeKernel { :: StrictTVar m (PublicPeerSelectionState NtNAddr), nkMempool - :: Mempool m (Tx txid), + :: Mempool m txid (Tx txid), nkTxChannelsVar :: TxChannelsVar m NtNAddr txid (Tx txid), @@ -325,6 +325,7 @@ data NodeKernel header block s txid m = NodeKernel { newNodeKernel :: ( MonadSTM m , Strict.MonadMVar m , RandomGen rng + , Ord txid , Eq txid ) => rng @@ -426,6 +427,7 @@ withNodeKernelThread , HasFullHeader block , RandomGen seed , Eq txid + , Ord txid ) => NtNAddr -- ^ just for naming a thread diff --git a/ouroboros-network/tests/lib/Test/Ouroboros/Network/TxSubmission/AppV1.hs b/ouroboros-network/tests/lib/Test/Ouroboros/Network/TxSubmission/AppV1.hs index bc154f5762b..a28dec40e2f 100644 --- a/ouroboros-network/tests/lib/Test/Ouroboros/Network/TxSubmission/AppV1.hs +++ b/ouroboros-network/tests/lib/Test/Ouroboros/Network/TxSubmission/AppV1.hs @@ -119,7 +119,7 @@ txSubmissionSimulation tracer maxUnacked outboundTxs return (inmp, outmp) where - outboundPeer :: Mempool m (Tx txid) -> TxSubmissionClient txid (Tx txid) m () + outboundPeer :: Mempool m txid (Tx txid) -> TxSubmissionClient txid (Tx txid) m () outboundPeer outboundMempool = txSubmissionOutbound nullTracer @@ -128,7 +128,7 @@ txSubmissionSimulation tracer maxUnacked outboundTxs (maxBound :: TestVersion) controlMessageSTM - inboundPeer :: Mempool m (Tx txid) -> TxSubmissionServerPipelined txid (Tx txid) m () + inboundPeer :: Mempool m txid (Tx txid) -> TxSubmissionServerPipelined txid (Tx txid) m () inboundPeer inboundMempool = txSubmissionInbound nullTracer diff --git a/ouroboros-network/tests/lib/Test/Ouroboros/Network/TxSubmission/Types.hs b/ouroboros-network/tests/lib/Test/Ouroboros/Network/TxSubmission/Types.hs index e686cc706f1..ca765c2aa56 100644 --- a/ouroboros-network/tests/lib/Test/Ouroboros/Network/TxSubmission/Types.hs +++ b/ouroboros-network/tests/lib/Test/Ouroboros/Network/TxSubmission/Types.hs @@ -103,13 +103,14 @@ maxTxSize = 65536 type TxId = Int -emptyMempool :: MonadSTM m => m (Mempool m (Tx txid)) +emptyMempool :: MonadSTM m => m (Mempool m txid (Tx txid)) emptyMempool = Mempool.empty -newMempool :: MonadSTM m => [Tx txid] -> m (Mempool m (Tx txid)) -newMempool = Mempool.new +newMempool :: (MonadSTM m, Ord txid) + => [Tx txid] -> m (Mempool m txid (Tx txid)) +newMempool = Mempool.new getTxId -readMempool :: MonadSTM m => Mempool m (Tx txid) -> m [Tx txid] +readMempool :: MonadSTM m => Mempool m txid (Tx txid) -> m [Tx txid] readMempool = Mempool.read getMempoolReader :: forall txid m. @@ -117,7 +118,7 @@ getMempoolReader :: forall txid m. , Eq txid , Show txid ) - => Mempool m (Tx txid) + => Mempool m txid (Tx txid) -> TxSubmissionMempoolReader txid (Tx txid) Int m getMempoolReader = Mempool.getReader getTxId getTxAdvSize @@ -130,7 +131,7 @@ getMempoolWriter :: forall txid m. , Typeable txid , Show txid ) - => Mempool m (Tx txid) + => Mempool m txid (Tx txid) -> TxSubmissionMempoolWriter txid (Tx txid) Int m getMempoolWriter = Mempool.getWriter getTxId (pure ()) diff --git a/scripts/ci/cabal.project.local.Nix.Windows b/scripts/ci/cabal.project.local.Nix.Windows new file mode 100644 index 00000000000..af847f4bb59 --- /dev/null +++ b/scripts/ci/cabal.project.local.Nix.Windows @@ -0,0 +1,32 @@ +max-backjumps: 5000 +reorder-goals: True +tests: True +benchmarks: True + +-- IPv6 and nothunks tests are DISABLED on Windows + +program-options + ghc-options: -fno-ignore-asserts -Werror + +package strict-checked-vars + flags: -checktvarinvariants -checkmvarinvariants + +package ntp-client + flags: +demo + +package network-mux + flags: -ipv6 + +package ouroboros-network + flags: +asserts -ipv6 + +-- +-- cddl is disabled on Windows due to missing build tool support in cross +-- compilation +-- + +package dmq-node + flags: -cddl + +package cardano-diffusion + flags: +asserts -cddl diff --git a/scripts/ci/cabal.project.local.Windows b/scripts/ci/cabal.project.local.Windows index af847f4bb59..82912a6529a 100644 --- a/scripts/ci/cabal.project.local.Windows +++ b/scripts/ci/cabal.project.local.Windows @@ -1,32 +1,4 @@ -max-backjumps: 5000 -reorder-goals: True -tests: True -benchmarks: True - --- IPv6 and nothunks tests are DISABLED on Windows - -program-options - ghc-options: -fno-ignore-asserts -Werror - -package strict-checked-vars - flags: -checktvarinvariants -checkmvarinvariants - -package ntp-client - flags: +demo - -package network-mux - flags: -ipv6 - -package ouroboros-network - flags: +asserts -ipv6 - --- --- cddl is disabled on Windows due to missing build tool support in cross --- compilation --- +import ./scripts/cabal.project.local.Nix.Windows package dmq-node - flags: -cddl - -package cardano-diffusion - flags: +asserts -cddl + flags: -standardcrypto-tests