diff --git a/packages/horizon/contracts/interfaces/internal/IHorizonStakingMain.sol b/packages/horizon/contracts/interfaces/internal/IHorizonStakingMain.sol index 05bd5ad7a..1e4f95a08 100644 --- a/packages/horizon/contracts/interfaces/internal/IHorizonStakingMain.sol +++ b/packages/horizon/contracts/interfaces/internal/IHorizonStakingMain.sol @@ -541,10 +541,10 @@ interface IHorizonStakingMain { */ function stakeTo(address serviceProvider, uint256 tokens) external; - // can be called by anyone if the service provider has provisioned stake to this verifier /** * @notice Deposit tokens on the service provider stake, on behalf of the service provider, * provisioned to a specific verifier. + * @dev This function can be called by the service provider, by an authorized operator or by the verifier itself. * @dev Requirements: * - The `serviceProvider` must have previously provisioned stake to `verifier`. * - `_tokens` cannot be zero. diff --git a/packages/horizon/contracts/staking/HorizonStaking.sol b/packages/horizon/contracts/staking/HorizonStaking.sol index 05af15880..dfb2aae97 100644 --- a/packages/horizon/contracts/staking/HorizonStaking.sol +++ b/packages/horizon/contracts/staking/HorizonStaking.sol @@ -57,6 +57,19 @@ contract HorizonStaking is HorizonStakingBase, IHorizonStakingMain { _; } + /** + * @notice Checks that the caller is authorized to operate over a provision or it is the verifier. + * @param serviceProvider The address of the service provider. + * @param verifier The address of the verifier. + */ + modifier onlyAuthorizedOrVerifier(address serviceProvider, address verifier) { + require( + _isAuthorized(serviceProvider, verifier, msg.sender) || msg.sender == verifier, + HorizonStakingNotAuthorized(serviceProvider, verifier, msg.sender) + ); + _; + } + /** * @dev The staking contract is upgradeable however we still use the constructor to set * a few immutable variables. @@ -121,7 +134,11 @@ contract HorizonStaking is HorizonStakingBase, IHorizonStakingMain { } /// @inheritdoc IHorizonStakingMain - function stakeToProvision(address serviceProvider, address verifier, uint256 tokens) external override notPaused { + function stakeToProvision( + address serviceProvider, + address verifier, + uint256 tokens + ) external override notPaused onlyAuthorizedOrVerifier(serviceProvider, verifier) { _stakeTo(serviceProvider, tokens); _addToProvision(serviceProvider, verifier, tokens); } diff --git a/packages/horizon/test/shared/horizon-staking/HorizonStakingShared.t.sol b/packages/horizon/test/shared/horizon-staking/HorizonStakingShared.t.sol index 52ed55830..2c845f0c4 100644 --- a/packages/horizon/test/shared/horizon-staking/HorizonStakingShared.t.sol +++ b/packages/horizon/test/shared/horizon-staking/HorizonStakingShared.t.sol @@ -181,6 +181,62 @@ abstract contract HorizonStakingSharedTest is GraphBaseTest { ); } + function _stakeToProvision(address serviceProvider, address verifier, uint256 tokens) internal { + (, address msgSender, ) = vm.readCallers(); + + // before + uint256 beforeStakingBalance = token.balanceOf(address(staking)); + uint256 beforeSenderBalance = token.balanceOf(msgSender); + ServiceProviderInternal memory beforeServiceProvider = _getStorage_ServiceProviderInternal(serviceProvider); + Provision memory beforeProvision = staking.getProvision(serviceProvider, verifier); + + // stakeTo + token.approve(address(staking), tokens); + vm.expectEmit(); + emit IHorizonStakingBase.HorizonStakeDeposited(serviceProvider, tokens); + vm.expectEmit(); + emit IHorizonStakingMain.ProvisionIncreased(serviceProvider, verifier, tokens); + staking.stakeToProvision(serviceProvider, verifier, tokens); + + // after + uint256 afterStakingBalance = token.balanceOf(address(staking)); + uint256 afterSenderBalance = token.balanceOf(msgSender); + ServiceProviderInternal memory afterServiceProvider = _getStorage_ServiceProviderInternal(serviceProvider); + Provision memory afterProvision = staking.getProvision(serviceProvider, verifier); + + // assert - stakeTo + assertEq(afterStakingBalance, beforeStakingBalance + tokens); + assertEq(afterSenderBalance, beforeSenderBalance - tokens); + assertEq(afterServiceProvider.tokensStaked, beforeServiceProvider.tokensStaked + tokens); + assertEq(afterServiceProvider.tokensProvisioned, beforeServiceProvider.tokensProvisioned + tokens); + assertEq(afterServiceProvider.__DEPRECATED_tokensAllocated, beforeServiceProvider.__DEPRECATED_tokensAllocated); + assertEq(afterServiceProvider.__DEPRECATED_tokensLocked, beforeServiceProvider.__DEPRECATED_tokensLocked); + assertEq( + afterServiceProvider.__DEPRECATED_tokensLockedUntil, + beforeServiceProvider.__DEPRECATED_tokensLockedUntil + ); + + // assert - addToProvision + assertEq(afterProvision.tokens, beforeProvision.tokens + tokens); + assertEq(afterProvision.tokensThawing, beforeProvision.tokensThawing); + assertEq(afterProvision.sharesThawing, beforeProvision.sharesThawing); + assertEq(afterProvision.maxVerifierCut, beforeProvision.maxVerifierCut); + assertEq(afterProvision.thawingPeriod, beforeProvision.thawingPeriod); + assertEq(afterProvision.createdAt, beforeProvision.createdAt); + assertEq(afterProvision.lastParametersStagedAt, beforeProvision.lastParametersStagedAt); + assertEq(afterProvision.maxVerifierCutPending, beforeProvision.maxVerifierCutPending); + assertEq(afterProvision.thawingPeriodPending, beforeProvision.thawingPeriodPending); + assertEq(afterProvision.thawingNonce, beforeProvision.thawingNonce); + assertEq(afterServiceProvider.tokensStaked, beforeServiceProvider.tokensStaked + tokens); + assertEq(afterServiceProvider.tokensProvisioned, beforeServiceProvider.tokensProvisioned + tokens); + assertEq(afterServiceProvider.__DEPRECATED_tokensAllocated, beforeServiceProvider.__DEPRECATED_tokensAllocated); + assertEq(afterServiceProvider.__DEPRECATED_tokensLocked, beforeServiceProvider.__DEPRECATED_tokensLocked); + assertEq( + afterServiceProvider.__DEPRECATED_tokensLockedUntil, + beforeServiceProvider.__DEPRECATED_tokensLockedUntil + ); + } + function _unstake(uint256 _tokens) internal { (, address msgSender, ) = vm.readCallers(); diff --git a/packages/horizon/test/staking/provision/provision.t.sol b/packages/horizon/test/staking/provision/provision.t.sol index ba580be25..c87e13a45 100644 --- a/packages/horizon/test/staking/provision/provision.t.sol +++ b/packages/horizon/test/staking/provision/provision.t.sol @@ -80,19 +80,6 @@ contract HorizonStakingProvisionTest is HorizonStakingTest { staking.provision(users.indexer, subgraphDataServiceAddress, amount / 2, maxVerifierCut, thawingPeriod); } - function testProvision_OperatorAddTokensToProvision( - uint256 amount, - uint32 maxVerifierCut, - uint64 thawingPeriod, - uint256 tokensToAdd - ) public useIndexer useProvision(amount, maxVerifierCut, thawingPeriod) useOperator { - tokensToAdd = bound(tokensToAdd, 1, MAX_STAKING_TOKENS); - - // Add more tokens to the provision - _stakeTo(users.indexer, tokensToAdd); - _addToProvision(users.indexer, subgraphDataServiceAddress, tokensToAdd); - } - function testProvision_RevertWhen_OperatorNotAuthorized( uint256 amount, uint32 maxVerifierCut, @@ -124,4 +111,113 @@ contract HorizonStakingProvisionTest is HorizonStakingTest { vm.expectRevert(expectedError); staking.provision(users.indexer, subgraphDataServiceAddress, amount, 0, 0); } + + function testProvision_AddTokensToProvision( + uint256 amount, + uint32 maxVerifierCut, + uint64 thawingPeriod, + uint256 tokensToAdd + ) public useIndexer useProvision(amount, maxVerifierCut, thawingPeriod) { + tokensToAdd = bound(tokensToAdd, 1, MAX_STAKING_TOKENS); + + // Add more tokens to the provision + _stakeTo(users.indexer, tokensToAdd); + _addToProvision(users.indexer, subgraphDataServiceAddress, tokensToAdd); + } + + function testProvision_OperatorAddTokensToProvision( + uint256 amount, + uint32 maxVerifierCut, + uint64 thawingPeriod, + uint256 tokensToAdd + ) public useIndexer useProvision(amount, maxVerifierCut, thawingPeriod) useOperator { + tokensToAdd = bound(tokensToAdd, 1, MAX_STAKING_TOKENS); + + // Add more tokens to the provision + _stakeTo(users.indexer, tokensToAdd); + _addToProvision(users.indexer, subgraphDataServiceAddress, tokensToAdd); + } + + function testProvision_AddTokensToProvision_RevertWhen_NotAuthorized( + uint256 amount, + uint32 maxVerifierCut, + uint64 thawingPeriod, + uint256 tokensToAdd + ) public useIndexer useProvision(amount, maxVerifierCut, thawingPeriod) { + tokensToAdd = bound(tokensToAdd, 1, MAX_STAKING_TOKENS); + + // Add more tokens to the provision + _stakeTo(users.indexer, tokensToAdd); + + // use delegator as a non authorized operator + vm.startPrank(users.delegator); + bytes memory expectedError = abi.encodeWithSignature( + "HorizonStakingNotAuthorized(address,address,address)", + users.indexer, + subgraphDataServiceAddress, + users.delegator + ); + vm.expectRevert(expectedError); + staking.addToProvision(users.indexer, subgraphDataServiceAddress, amount); + } + + function testProvision_StakeToProvision( + uint256 amount, + uint32 maxVerifierCut, + uint64 thawingPeriod, + uint256 tokensToAdd + ) public useIndexer useProvision(amount, maxVerifierCut, thawingPeriod) { + tokensToAdd = bound(tokensToAdd, 1, MAX_STAKING_TOKENS); + + // Add more tokens to the provision + _stakeToProvision(users.indexer, subgraphDataServiceAddress, tokensToAdd); + } + + function testProvision_Operator_StakeToProvision( + uint256 amount, + uint32 maxVerifierCut, + uint64 thawingPeriod, + uint256 tokensToAdd + ) public useIndexer useProvision(amount, maxVerifierCut, thawingPeriod) useOperator { + tokensToAdd = bound(tokensToAdd, 1, MAX_STAKING_TOKENS); + + // Add more tokens to the provision + _stakeToProvision(users.indexer, subgraphDataServiceAddress, tokensToAdd); + } + + function testProvision_Verifier_StakeToProvision( + uint256 amount, + uint32 maxVerifierCut, + uint64 thawingPeriod, + uint256 tokensToAdd + ) public useIndexer useProvision(amount, maxVerifierCut, thawingPeriod) { + tokensToAdd = bound(tokensToAdd, 1, MAX_STAKING_TOKENS); + + // Ensure the verifier has enough tokens to then stake to the provision + token.transfer(subgraphDataServiceAddress, tokensToAdd); + + // Add more tokens to the provision + resetPrank(subgraphDataServiceAddress); + _stakeToProvision(users.indexer, subgraphDataServiceAddress, tokensToAdd); + } + + function testProvision_StakeToProvision_RevertWhen_NotAuthorized( + uint256 amount, + uint32 maxVerifierCut, + uint64 thawingPeriod, + uint256 tokensToAdd + ) public useIndexer useProvision(amount, maxVerifierCut, thawingPeriod) { + tokensToAdd = bound(tokensToAdd, 1, MAX_STAKING_TOKENS); + + // Add more tokens to the provision + vm.startPrank(users.delegator); + bytes memory expectedError = abi.encodeWithSignature( + "HorizonStakingNotAuthorized(address,address,address)", + users.indexer, + subgraphDataServiceAddress, + users.delegator + ); + vm.expectRevert(expectedError); + staking.stakeToProvision(users.indexer, subgraphDataServiceAddress, tokensToAdd); + } }