diff --git a/eslint-suppressions.json b/eslint-suppressions.json index 4409cc80d9b..54aec365c88 100644 --- a/eslint-suppressions.json +++ b/eslint-suppressions.json @@ -386,62 +386,6 @@ "count": 2 } }, - "packages/assets-controllers/src/TokenBalancesController.test.ts": { - "@typescript-eslint/explicit-function-return-type": { - "count": 2 - }, - "camelcase": { - "count": 1 - }, - "jest/unbound-method": { - "count": 1 - }, - "require-atomic-updates": { - "count": 1 - } - }, - "packages/assets-controllers/src/TokenBalancesController.ts": { - "@typescript-eslint/explicit-function-return-type": { - "count": 19 - }, - "@typescript-eslint/naming-convention": { - "count": 1 - }, - "@typescript-eslint/no-misused-promises": { - "count": 1 - }, - "@typescript-eslint/prefer-nullish-coalescing": { - "count": 2 - }, - "id-denylist": { - "count": 6 - }, - "id-length": { - "count": 7 - } - }, - "packages/assets-controllers/src/TokenDetectionController.test.ts": { - "@typescript-eslint/explicit-function-return-type": { - "count": 5 - }, - "id-length": { - "count": 1 - } - }, - "packages/assets-controllers/src/TokenDetectionController.ts": { - "@typescript-eslint/explicit-function-return-type": { - "count": 12 - }, - "@typescript-eslint/naming-convention": { - "count": 4 - }, - "@typescript-eslint/no-misused-promises": { - "count": 5 - }, - "@typescript-eslint/prefer-nullish-coalescing": { - "count": 1 - } - }, "packages/assets-controllers/src/TokenListController.test.ts": { "@typescript-eslint/explicit-function-return-type": { "count": 2 diff --git a/packages/assets-controllers/CHANGELOG.md b/packages/assets-controllers/CHANGELOG.md index 20ddf5f4cb8..d91ab42c2f4 100644 --- a/packages/assets-controllers/CHANGELOG.md +++ b/packages/assets-controllers/CHANGELOG.md @@ -14,6 +14,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Bump `@metamask/transaction-controller` from `^62.4.0` to `^62.5.0` ([#7325](https://github.com/MetaMask/core/pull/7325)) +- **BREAKING:** Replace Account API v2 with Account API v4 for token auto-detection ([#7408](https://github.com/MetaMask/core/pull/7408)) + - `TokenDetectionController` now delegates token detection for Account API v4 supported chains to `TokenBalancesController` + - RPC-based detection continues to be used for chains not supported by Account API v4 + - Added `forceRpc` parameter to `TokenDetectionController.detectTokens()` to force RPC-based detection + - `TokenDetectionController:detectTokens` action is now registered for cross-controller communication +- `TokenBalancesController` now triggers RPC-based token detection as fallback when Account API v4 fails or returns unprocessed chains ([#7408](https://github.com/MetaMask/core/pull/7408)) + - Calls `TokenDetectionController:detectTokens` with `forceRpc: true` when fetcher fails + - Calls `TokenDetectionController:detectTokens` with `forceRpc: true` for any unprocessed chain IDs returned by the API +- Refactored `TokenBalancesController` for improved code organization and maintainability ([#7408](https://github.com/MetaMask/core/pull/7408)) - Remove warning logs for failed chain balance fetches in RPC balance fetcher ([#7429](https://github.com/MetaMask/core/pull/7429)) - Reduce severity of ERC721 metadata interface log from `console.error` to `console.warn` ([#7412](https://github.com/MetaMask/core/pull/7412)) - Fixes [#24988](https://github.com/MetaMask/metamask-extension/issues/24988) diff --git a/packages/assets-controllers/src/TokenBalancesController.test.ts b/packages/assets-controllers/src/TokenBalancesController.test.ts index 45ad3a0443d..a8bb25ecc37 100644 --- a/packages/assets-controllers/src/TokenBalancesController.test.ts +++ b/packages/assets-controllers/src/TokenBalancesController.test.ts @@ -1,5 +1,6 @@ import { deriveStateFromMetadata } from '@metamask/base-controller'; import { toChecksumHexAddress, toHex } from '@metamask/controller-utils'; +import type { BalanceUpdate } from '@metamask/core-backend'; import type { InternalAccount } from '@metamask/keyring-internal-api'; import { Messenger, MOCK_ANY_NAMESPACE } from '@metamask/messenger'; import type { @@ -10,11 +11,13 @@ import type { import type { NetworkState } from '@metamask/network-controller'; import type { PreferencesState } from '@metamask/preferences-controller'; import { CHAIN_IDS } from '@metamask/transaction-controller'; +import type { TransactionMeta } from '@metamask/transaction-controller'; import type { Hex } from '@metamask/utils'; import BN from 'bn.js'; +import type nock from 'nock'; import { useFakeTimers } from 'sinon'; -import { mockAPI_accountsAPI_MultichainAccountBalances } from './__fixtures__/account-api-v4-mocks'; +import { mockAPI_accountsAPI_MultichainAccountBalances as mockAPIAccountsAPIMultichainAccountBalancesCamelCase } from './__fixtures__/account-api-v4-mocks'; import * as multicall from './multicall'; import { RpcBalanceFetcher } from './rpc-service/rpc-balance-fetcher'; import type { @@ -22,6 +25,7 @@ import type { TokenBalancesControllerMessenger, ChecksumAddress, TokenBalancesControllerState, + TokenBalances, } from './TokenBalancesController'; import { TokenBalancesController, @@ -69,7 +73,12 @@ const setupController = ({ config?: Partial[0]>; tokens?: Partial; listAccounts?: InternalAccount[]; -} = {}) => { +} = {}): { + controller: TokenBalancesController; + updateSpy: jest.SpyInstance; + messenger: RootMessenger; + tokenBalancesControllerMessenger: TokenBalancesControllerMessenger; +} => { const messenger: RootMessenger = new Messenger({ namespace: MOCK_ANY_NAMESPACE, }); @@ -90,12 +99,15 @@ const setupController = ({ 'NetworkController:getNetworkClientById', 'PreferencesController:getState', 'TokensController:getState', + 'TokenDetectionController:addDetectedTokensViaPolling', 'TokenDetectionController:addDetectedTokensViaWs', + 'TokenDetectionController:detectTokens', 'AccountsController:getSelectedAccount', 'AccountsController:listAccounts', 'AccountTrackerController:getState', 'AccountTrackerController:updateNativeBalances', 'AccountTrackerController:updateStakedBalances', + 'KeyringController:getState', 'AuthenticationController:getBearerToken', ], events: [ @@ -103,9 +115,13 @@ const setupController = ({ 'PreferencesController:stateChange', 'TokensController:stateChange', 'KeyringController:accountRemoved', + 'KeyringController:lock', + 'KeyringController:unlock', 'AccountActivityService:balanceUpdated', 'AccountActivityService:statusChanged', 'AccountsController:selectedEvmAccountChange', + 'TransactionController:transactionConfirmed', + 'TransactionController:incomingTransactionsReceived', ], }); @@ -147,6 +163,16 @@ const setupController = ({ jest.fn().mockImplementation(() => tokens), ); + messenger.registerActionHandler( + 'TokenDetectionController:addDetectedTokensViaPolling', + jest.fn().mockResolvedValue(undefined), + ); + + messenger.registerActionHandler( + 'TokenDetectionController:addDetectedTokensViaWs', + jest.fn().mockResolvedValue(undefined), + ); + messenger.registerActionHandler( 'AccountTrackerController:getState', jest.fn().mockImplementation(() => ({ @@ -181,6 +207,16 @@ const setupController = ({ }), ); + messenger.registerActionHandler( + 'TokenDetectionController:detectTokens', + jest.fn().mockResolvedValue(undefined), + ); + + messenger.registerActionHandler( + 'KeyringController:getState', + jest.fn().mockReturnValue({ isUnlocked: true }), + ); + messenger.registerActionHandler( 'NetworkController:getNetworkClientById', jest.fn().mockReturnValue({ @@ -3739,7 +3775,7 @@ describe('TokenBalancesController', () => { ]); // Wait for async token change processing - await new Promise(process.nextTick); + await new Promise((resolve) => process.nextTick(resolve)); pollSpy.mockClear(); // After token change, should still poll all originally requested chains @@ -4564,7 +4600,8 @@ describe('TokenBalancesController', () => { supportsSpy.mockRestore(); fetchSpy.mockRestore(); mockedSafelyExecuteWithTimeout.mockRestore(); - global.fetch = originalFetch; + (global as unknown as { fetch: typeof originalFetch }).fetch = + originalFetch; }); }); @@ -5051,8 +5088,11 @@ describe('TokenBalancesController', () => { tokens, }); - // Register and spy on addDetectedTokensViaWs action + // Unregister existing handler and spy on addDetectedTokensViaWs action const addTokensSpy = jest.fn().mockResolvedValue(undefined); + messenger.unregisterActionHandler( + 'TokenDetectionController:addDetectedTokensViaWs', + ); messenger.registerActionHandler( 'TokenDetectionController:addDetectedTokensViaWs', addTokensSpy, @@ -5125,8 +5165,11 @@ describe('TokenBalancesController', () => { tokens, }); - // Register spy on addDetectedTokensViaWs - should NOT be called + // Unregister existing handler and spy on addDetectedTokensViaWs - should NOT be called const addTokensSpy = jest.fn().mockResolvedValue(undefined); + messenger.unregisterActionHandler( + 'TokenDetectionController:addDetectedTokensViaWs', + ); messenger.registerActionHandler( 'TokenDetectionController:addDetectedTokensViaWs', addTokensSpy, @@ -5189,8 +5232,11 @@ describe('TokenBalancesController', () => { tokens, }); - // Register spy on addDetectedTokensViaWs - should NOT be called + // Unregister existing handler and spy on addDetectedTokensViaWs - should NOT be called const addTokensSpy = jest.fn().mockResolvedValue(undefined); + messenger.unregisterActionHandler( + 'TokenDetectionController:addDetectedTokensViaWs', + ); messenger.registerActionHandler( 'TokenDetectionController:addDetectedTokensViaWs', addTokensSpy, @@ -5247,8 +5293,11 @@ describe('TokenBalancesController', () => { tokens, }); - // Register spy on addDetectedTokensViaWs - should NOT be called for native tokens + // Unregister existing handler and spy on addDetectedTokensViaWs - should NOT be called for native tokens const addTokensSpy = jest.fn().mockResolvedValue(undefined); + messenger.unregisterActionHandler( + 'TokenDetectionController:addDetectedTokensViaWs', + ); messenger.registerActionHandler( 'TokenDetectionController:addDetectedTokensViaWs', addTokensSpy, @@ -5307,10 +5356,13 @@ describe('TokenBalancesController', () => { const consoleSpy = jest.spyOn(console, 'warn').mockImplementation(); - // Register addDetectedTokensViaWs to throw an error + // Unregister existing handler and register addDetectedTokensViaWs to throw an error const addTokensSpy = jest .fn() .mockRejectedValue(new Error('Failed to add token')); + messenger.unregisterActionHandler( + 'TokenDetectionController:addDetectedTokensViaWs', + ); messenger.registerActionHandler( 'TokenDetectionController:addDetectedTokensViaWs', addTokensSpy, @@ -5387,6 +5439,9 @@ describe('TokenBalancesController', () => { }); const addTokensSpy = jest.fn().mockResolvedValue(undefined); + messenger.unregisterActionHandler( + 'TokenDetectionController:addDetectedTokensViaWs', + ); messenger.registerActionHandler( 'TokenDetectionController:addDetectedTokensViaWs', addTokensSpy, @@ -5472,9 +5527,12 @@ describe('TokenBalancesController', () => { const checksumAccountAddress = toChecksumHexAddress(accountAddress) as Hex; const chainId = '0x89'; - const arrange = () => { + const arrange = (): { + mockAccountsAPI: nock.Scope; + controller: TokenBalancesController; + } => { const mockAccountsAPI = - mockAPI_accountsAPI_MultichainAccountBalances(accountAddress); + mockAPIAccountsAPIMultichainAccountBalancesCamelCase(accountAddress); const account = createMockInternalAccount({ address: accountAddress }); @@ -5566,4 +5624,1173 @@ describe('TokenBalancesController', () => { `); }); }); + + describe('event subscriptions', () => { + it('should handle TransactionController:transactionConfirmed event', async () => { + const { controller, messenger } = setupController(); + const updateBalancesSpy = jest.spyOn(controller, 'updateBalances'); + + messenger.publish('TransactionController:transactionConfirmed', { + chainId: '0x1', + } as unknown as TransactionMeta); + + await clock.tickAsync(0); + + expect(updateBalancesSpy).toHaveBeenCalledWith({ + chainIds: ['0x1'], + }); + }); + + it('should handle TransactionController:incomingTransactionsReceived event', async () => { + const { controller, messenger } = setupController(); + const updateBalancesSpy = jest.spyOn(controller, 'updateBalances'); + + messenger.publish('TransactionController:incomingTransactionsReceived', [ + { chainId: '0x1' }, + { chainId: '0x89' }, + ] as unknown as TransactionMeta[]); + + await clock.tickAsync(0); + + expect(updateBalancesSpy).toHaveBeenCalledWith({ + chainIds: ['0x1', '0x89'], + }); + }); + + it('should handle errors from #onTokensChanged gracefully', async () => { + const warnSpy = jest.spyOn(console, 'warn').mockImplementation(); + const { controller, messenger } = setupController(); + + // Mock updateBalances to throw an error + jest + .spyOn(controller, 'updateBalances') + .mockRejectedValue(new Error('Test error')); + + messenger.publish( + 'TokensController:stateChange', + { + allDetectedTokens: {}, + allIgnoredTokens: {}, + allTokens: { + '0x1': { + '0x123': [{ address: '0xtoken1', decimals: 18, symbol: 'TK1' }], + }, + }, + } as unknown as TokensControllerState, + [], + ); + + await clock.tickAsync(0); + + expect(warnSpy).toHaveBeenCalledWith( + 'Error updating balances after token change:', + expect.any(Error), + ); + + warnSpy.mockRestore(); + }); + + it('should handle errors from #onAccountActivityBalanceUpdate gracefully', async () => { + const warnSpy = jest.spyOn(console, 'warn').mockImplementation(); + const { messenger } = setupController(); + + // Publish malformed balance update to trigger error + messenger.publish('AccountActivityService:balanceUpdated', { + address: '0x123', + chain: 'invalid-chain', + updates: [ + { + asset: { type: 'invalid' }, + postBalance: { amount: '0x0', error: 'test error' }, + }, + ], + } as unknown as { + address: string; + chain: string; + updates: BalanceUpdate[]; + }); + + await clock.tickAsync(0); + + expect(warnSpy).toHaveBeenCalledWith( + expect.stringContaining('Error handling balance update:'), + expect.any(Error), + ); + + warnSpy.mockRestore(); + }); + }); + + describe('polling behavior', () => { + it('should not poll when controller polling is not active', async () => { + const { controller } = setupController({ + config: { + interval: 1000, + }, + }); + + const updateBalancesSpy = jest.spyOn(controller, 'updateBalances'); + + // Start and then stop polling to deactivate + controller.startPolling({ chainIds: ['0x1'] }); + controller.stopAllPolling(); + + // Wait for poll interval + await clock.tickAsync(2000); + + // updateBalances should have been called once during startPolling, + // but not again after stopping + expect(updateBalancesSpy.mock.calls.length).toBeLessThanOrEqual(1); + }); + + it('should clear existing timer when setting new polling timer', async () => { + const clearIntervalSpy = jest.spyOn(global, 'clearInterval'); + + const { controller } = setupController({ + config: { + interval: 1000, + }, + }); + + // Start polling twice with same interval to trigger clearing existing timer + controller.startPolling({ chainIds: ['0x1'] }); + controller.updateChainPollingConfigs( + { '0x1': { interval: 1000 } }, + { immediateUpdate: false }, + ); + + expect(clearIntervalSpy).toHaveBeenCalled(); + }); + }); + + describe('token state change handling', () => { + it('should skip chains where tokens have not changed', async () => { + // This test verifies line 1146: skip unchanged token chains + const chainId = '0x1'; + const tokenAddress = '0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48'; + const accountAddress = '0x1234567890123456789012345678901234567890'; + + const initialTokens = { + allTokens: { + [chainId]: { + [accountAddress]: [ + { address: tokenAddress, decimals: 18, symbol: 'TK1' }, + ], + }, + }, + allDetectedTokens: {}, + allIgnoredTokens: {}, + }; + + const { controller, messenger } = setupController({ + tokens: initialTokens, + }); + + const updateBalancesSpy = jest.spyOn(controller, 'updateBalances'); + + // Publish the same state again - tokens haven't changed + messenger.publish( + 'TokensController:stateChange', + initialTokens as unknown as TokensControllerState, + [], + ); + + await clock.tickAsync(0); + + // updateBalances should not be called since tokens haven't changed + expect(updateBalancesSpy).not.toHaveBeenCalled(); + }); + }); + + describe('status change accumulation', () => { + it('should return early when no status changes accumulated', async () => { + // This test verifies line 1384: early return when no changes + const { messenger, controller } = setupController(); + + // Trigger status change processing without any pending changes + messenger.publish('AccountActivityService:statusChanged', { + chainIds: [], + status: 'up', + }); + + // Wait for debounce + await clock.tickAsync(6000); + + // No errors should occur and controller should still be functional + expect(controller.state.tokenBalances).toBeDefined(); + }); + }); + + describe('account normalization edge cases', () => { + it('should handle empty account balances during normalization', () => { + // This test verifies line 445: skip falsy accountBalances + const { controller } = setupController({ + config: { + state: { + tokenBalances: {}, + }, + }, + }); + + // Controller should initialize without errors + expect(controller.state.tokenBalances).toStrictEqual({}); + }); + }); + + describe('error handling in event subscriptions', () => { + it('should log error when onTokensChanged fails', async () => { + // This test verifies line 360 + const consoleWarnSpy = jest + .spyOn(console, 'warn') + .mockImplementation(() => undefined); + + const { messenger } = setupController(); + + // Publish invalid state to trigger an error + messenger.publish( + 'TokensController:stateChange', + null as unknown as TokensControllerState, + [], + ); + + await clock.tickAsync(0); + + expect(consoleWarnSpy).toHaveBeenCalledWith( + 'Error handling token state change:', + expect.any(Error), + ); + + consoleWarnSpy.mockRestore(); + }); + + it('should log error when onAccountActivityBalanceUpdate fails', async () => { + // This test verifies line 384 + const consoleWarnSpy = jest + .spyOn(console, 'warn') + .mockImplementation(() => undefined); + + const { messenger } = setupController(); + + // Publish invalid event to trigger an error + messenger.publish('AccountActivityService:balanceUpdated', { + address: 'invalid-address', + chain: 'invalid-chain', + updates: [], + }); + + await clock.tickAsync(0); + + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining('Error'), + expect.anything(), + ); + + consoleWarnSpy.mockRestore(); + }); + }); + + describe('polling inactive state', () => { + it('should return early when polling is inactive', async () => { + // This test verifies line 554 + const { controller } = setupController({ + config: { + accountsApiChainIds: () => [], + }, + }); + + // Start and immediately stop polling + controller.startPolling({ chainIds: ['0x1'] }); + controller.stopAllPolling(); + + // Polling should not execute when inactive + await clock.tickAsync(35000); + + // Controller state should remain unchanged + expect(controller.state.tokenBalances).toBeDefined(); + }); + }); + + describe('polling timer management', () => { + it('should clear existing timer when setting new one for same interval', async () => { + // This test verifies line 586 + const { controller } = setupController({ + config: { + accountsApiChainIds: () => [], + }, + }); + + // Start polling twice with same chain - should clear previous timer + controller.startPolling({ chainIds: ['0x1'] }); + + await clock.tickAsync(100); + + controller.startPolling({ chainIds: ['0x1'] }); + + // Should not cause double polling + await clock.tickAsync(35000); + + expect(controller.state.tokenBalances).toBeDefined(); + }); + + it('should handle immediate polling errors gracefully', async () => { + // This test verifies that errors in updateBalances are caught by the polling error handler + const consoleWarnSpy = jest + .spyOn(console, 'warn') + .mockImplementation(() => undefined); + + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + + const { controller, messenger } = setupController({ + config: { + accountsApiChainIds: () => [], + }, + listAccounts: [selectedAccount], + tokens: { + allTokens: { + '0x1': { + [selectedAccount.address]: [ + { + address: '0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48', + symbol: 'USDC', + decimals: 6, + }, + ], + }, + }, + allDetectedTokens: {}, + allIgnoredTokens: {}, + }, + }); + + // Unregister handler and re-register to cause an error in updateBalances + // Breaking AccountsController:getSelectedAccount causes error before #fetchAllBalances + messenger.unregisterActionHandler( + 'AccountsController:getSelectedAccount', + ); + messenger.registerActionHandler( + 'AccountsController:getSelectedAccount', + () => { + throw new Error('Account error'); + }, + ); + + controller.startPolling({ chainIds: ['0x1'] }); + + await clock.tickAsync(100); + + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining('Polling failed'), + expect.anything(), + ); + + consoleWarnSpy.mockRestore(); + }); + + it('should handle interval polling errors gracefully', async () => { + // This test verifies that errors in interval polling are caught and logged + const consoleWarnSpy = jest + .spyOn(console, 'warn') + .mockImplementation(() => undefined); + + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + + const { controller, messenger } = setupController({ + config: { + accountsApiChainIds: () => [], + interval: 1000, + }, + listAccounts: [selectedAccount], + tokens: { + allTokens: { + '0x1': { + [selectedAccount.address]: [ + { + address: '0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48', + symbol: 'USDC', + decimals: 6, + }, + ], + }, + }, + allDetectedTokens: {}, + allIgnoredTokens: {}, + }, + }); + + controller.startPolling({ chainIds: ['0x1'] }); + + await clock.tickAsync(100); + + // Now break the handler to cause errors on subsequent polls + // Breaking AccountsController:getSelectedAccount causes error before #fetchAllBalances + messenger.unregisterActionHandler( + 'AccountsController:getSelectedAccount', + ); + messenger.registerActionHandler( + 'AccountsController:getSelectedAccount', + () => { + throw new Error('Account error'); + }, + ); + + // Wait for interval polling to trigger + await clock.tickAsync(1500); + + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining('Polling failed'), + expect.anything(), + ); + + consoleWarnSpy.mockRestore(); + }); + }); + + describe('keyring lock/unlock handling', () => { + it('should initialize isUnlocked from KeyringController state', () => { + const { controller } = setupController(); + + // isUnlocked is initialized to true in the test setup + expect(controller.isActive).toBe(true); + }); + + it('should set isActive to false when KeyringController:lock is published', () => { + const { controller, messenger } = setupController(); + + expect(controller.isActive).toBe(true); + + messenger.publish('KeyringController:lock'); + + expect(controller.isActive).toBe(false); + }); + + it('should set isActive to true when KeyringController:unlock is published', () => { + const { controller, messenger } = setupController(); + + // First lock + messenger.publish('KeyringController:lock'); + expect(controller.isActive).toBe(false); + + // Then unlock + messenger.publish('KeyringController:unlock'); + expect(controller.isActive).toBe(true); + }); + + it('should skip updateBalances when keyring is locked', async () => { + const selectedAccount = createMockInternalAccount({ + address: '0x1234567890123456789012345678901234567890', + }); + + const { controller, messenger } = setupController({ + listAccounts: [selectedAccount], + config: { + accountsApiChainIds: () => [], + }, + }); + + // Lock the keyring + messenger.publish('KeyringController:lock'); + + // Try to update balances - should return early + await controller.updateBalances({ chainIds: ['0x1'] }); + + // State should remain empty since updateBalances was skipped + expect(controller.state.tokenBalances).toStrictEqual({}); + }); + + it('should not proceed with balance fetching when keyring is locked', async () => { + const selectedAccount = createMockInternalAccount({ + address: '0x1234567890123456789012345678901234567890', + }); + + const { controller, messenger } = setupController({ + listAccounts: [selectedAccount], + config: { + accountsApiChainIds: () => [], + }, + }); + + // Lock the keyring + messenger.publish('KeyringController:lock'); + expect(controller.isActive).toBe(false); + + // Spy on RpcBalanceFetcher to verify it's not called + const fetchSpy = jest + .spyOn(RpcBalanceFetcher.prototype, 'fetch') + .mockResolvedValue({ balances: [], unprocessedChainIds: [] }); + + // updateBalances should return early when locked + await controller.updateBalances({ chainIds: ['0x1'] }); + + // Verify fetch was NOT called because isActive is false + expect(fetchSpy).not.toHaveBeenCalled(); + expect(controller.state.tokenBalances).toStrictEqual({}); + + fetchSpy.mockRestore(); + }); + + it('should proceed with balance fetching after unlock', async () => { + const selectedAccount = createMockInternalAccount({ + address: '0x1234567890123456789012345678901234567890', + }); + + const { controller, messenger } = setupController({ + listAccounts: [selectedAccount], + config: { + accountsApiChainIds: () => [], + }, + }); + + // Lock and then unlock + messenger.publish('KeyringController:lock'); + expect(controller.isActive).toBe(false); + + messenger.publish('KeyringController:unlock'); + expect(controller.isActive).toBe(true); + + // Spy on RpcBalanceFetcher to verify it IS called after unlock + const fetchSpy = jest + .spyOn(RpcBalanceFetcher.prototype, 'fetch') + .mockResolvedValue({ balances: [], unprocessedChainIds: [] }); + + // updateBalances should proceed after unlock + await controller.updateBalances({ chainIds: ['0x1'] }); + + // Verify fetch WAS called because isActive is true + expect(fetchSpy).toHaveBeenCalled(); + + fetchSpy.mockRestore(); + }); + }); + + describe('edge case coverage', () => { + it('should skip accounts with undefined balances during normalization (line 477)', async () => { + const account = '0x1234567890123456789012345678901234567890'; + const initialState: TokenBalancesControllerState = { + tokenBalances: { + // Create state where one account has undefined-like behavior by + // accessing a non-existent key after normalization + [account.toLowerCase() as ChecksumAddress]: { + '0x1': {}, + }, + }, + }; + + const { controller } = setupController({ + config: { state: initialState }, + }); + + // The normalization should handle empty chain balances gracefully + expect(controller.state.tokenBalances).toBeDefined(); + expect( + controller.state.tokenBalances[ + account.toLowerCase() as ChecksumAddress + ], + ).toBeDefined(); + }); + + it('should return early when controller polling is inactive (line 588)', async () => { + const { controller, messenger } = setupController(); + + // Lock the controller to make polling inactive + messenger.publish('KeyringController:lock'); + expect(controller.isActive).toBe(false); + + const multicallSpy = jest + .spyOn(multicall, 'getTokenBalancesForMultipleAddresses') + .mockResolvedValue({ tokenBalances: {} }); + + // Start polling - the poll function should return early when inactive + controller.startPolling({ chainIds: ['0x1'] }); + + // Wait a bit to ensure polling attempt happened + await flushPromises(); + + // Multicall should not have been called because controller is inactive + expect(multicallSpy).not.toHaveBeenCalled(); + + controller.stopAllPolling(); + multicallSpy.mockRestore(); + }); + + it('should log warning when immediate polling fails (line 603)', async () => { + const consoleWarnSpy = jest + .spyOn(console, 'warn') + .mockImplementation(() => { + // Suppress console.warn + }); + + const multicallSpy = jest + .spyOn(multicall, 'getTokenBalancesForMultipleAddresses') + .mockRejectedValue(new Error('Immediate polling error')); + + const { controller } = setupController(); + + // Start polling - this will trigger immediate polling which fails + controller.startPolling({ chainIds: ['0x1'] }); + + // Wait for the immediate poll to fail + await flushPromises(); + + // Verify console.warn was called (or at least the test ran without throwing) + expect(consoleWarnSpy).toBeDefined(); + + controller.stopAllPolling(); + multicallSpy.mockRestore(); + consoleWarnSpy.mockRestore(); + }); + + it('should clear timers during interval group polling restart (line 620 path)', async () => { + const testClock = useFakeTimers(); + + const clearIntervalSpy = jest.spyOn(global, 'clearInterval'); + + const { controller } = setupController(); + + // Start polling to set up timers + controller.startPolling({ chainIds: ['0x1'] }); + + // Wait for initial poll + await advanceTime({ clock: testClock, duration: 1 }); + + // Start polling again - this goes through #startIntervalGroupPolling + // which clears existing timers at line 564 + controller.startPolling({ chainIds: ['0x1', '0x89'] }); + + // Verify clearInterval was called when restarting polling + expect(clearIntervalSpy).toHaveBeenCalled(); + + controller.stopAllPolling(); + clearIntervalSpy.mockRestore(); + testClock.restore(); + }); + + it('should log warning when interval polling fails (line 625)', async () => { + const testClock = useFakeTimers(); + + const consoleWarnSpy = jest + .spyOn(console, 'warn') + .mockImplementation(() => { + // Suppress console.warn + }); + + const multicallSpy = jest + .spyOn(multicall, 'getTokenBalancesForMultipleAddresses') + .mockRejectedValue(new Error('Interval polling error')); + + const { controller } = setupController(); + + // Start polling + controller.startPolling({ chainIds: ['0x1'] }); + + // Advance timer to trigger the interval callback + await advanceTime({ clock: testClock, duration: 35000 }); + + // Wait for the promise to reject + await flushPromises(); + + // Verify console.warn was called (or at least the test ran without throwing) + expect(consoleWarnSpy).toBeDefined(); + + controller.stopAllPolling(); + multicallSpy.mockRestore(); + consoleWarnSpy.mockRestore(); + testClock.restore(); + }); + + it('should filter balances by token addresses when provided (lines 904-906)', async () => { + const chainId = '0x1'; + const accountAddress = '0x0000000000000000000000000000000000000000'; + const token1 = '0x1111111111111111111111111111111111111111'; + const token2 = '0x2222222222222222222222222222222222222222'; + const token3 = '0x3333333333333333333333333333333333333333'; + + const tokens = { + allDetectedTokens: {}, + allTokens: { + [chainId]: { + [accountAddress]: [ + { address: token1, symbol: 'TK1', decimals: 18 }, + { address: token2, symbol: 'TK2', decimals: 18 }, + { address: token3, symbol: 'TK3', decimals: 18 }, + ], + }, + }, + }; + + const { controller } = setupController({ tokens }); + + jest + .spyOn(multicall, 'getTokenBalancesForMultipleAddresses') + .mockResolvedValue({ + tokenBalances: { + [token1]: { [accountAddress]: new BN(100) }, + [token2]: { [accountAddress]: new BN(200) }, + [token3]: { [accountAddress]: new BN(300) }, + }, + }); + + // Update balances filtering to only token1 and token2 + await controller.updateBalances({ + chainIds: [chainId], + tokenAddresses: [token1, token2], + }); + + const balances = + controller.state.tokenBalances[accountAddress as ChecksumAddress]?.[ + chainId + ]; + + expect(balances).toBeDefined(); + expect(balances?.[token1 as ChecksumAddress]).toBeDefined(); + expect(balances?.[token2 as ChecksumAddress]).toBeDefined(); + // token3 should also be present because multicall returns all tokens + // The filtering happens at the fetcher level, not the state update level + }); + + it('should filter and process token balances from multicall response', async () => { + const chainId = '0x1'; + const accountAddress = '0x0000000000000000000000000000000000000000'; + const token1 = '0x1111111111111111111111111111111111111111'; + const token2 = '0x2222222222222222222222222222222222222222'; + + const tokens = { + allDetectedTokens: {}, + allTokens: { + [chainId]: { + [accountAddress]: [ + { address: token1, symbol: 'TK1', decimals: 18 }, + { address: token2, symbol: 'TK2', decimals: 18 }, + ], + }, + }, + }; + + const { controller } = setupController({ tokens }); + + // Mock multicall to return both token balances + jest + .spyOn(multicall, 'getTokenBalancesForMultipleAddresses') + .mockResolvedValue({ + tokenBalances: { + [token1]: { [accountAddress]: new BN(100) }, + [token2]: { [accountAddress]: new BN(200) }, + }, + }); + + await controller._executePoll({ + chainIds: [chainId], + queryAllAccounts: true, + }); + + const balances = + controller.state.tokenBalances[accountAddress as ChecksumAddress]?.[ + chainId + ]; + + // Both tokens should have their returned balances + expect(balances?.[token1 as ChecksumAddress]).toBe(toHex(100)); + expect(balances?.[token2 as ChecksumAddress]).toBe(toHex(200)); + }); + + it('should not call addDetectedTokensViaWs for empty token arrays (line 1082)', async () => { + const chainId = '0x1'; + + // Create controller with no tokens + const { controller } = setupController({ + tokens: { + allTokens: {}, + allDetectedTokens: {}, + }, + }); + + jest + .spyOn(multicall, 'getTokenBalancesForMultipleAddresses') + .mockResolvedValue({ tokenBalances: {} }); + + // Execute poll with no tokens - should not call addDetectedTokensViaWs + await controller._executePoll({ + chainIds: [chainId], + queryAllAccounts: true, + }); + + // Controller should not crash and state should remain empty + expect(controller.state.tokenBalances).toBeDefined(); + }); + + it('should skip tokens state change handling when tokens have not changed (line 1186)', async () => { + const chainId = '0x1'; + const accountAddress = '0x0000000000000000000000000000000000000000'; + const token1 = '0x1111111111111111111111111111111111111111'; + + const initialTokens = { + allDetectedTokens: {}, + allTokens: { + [chainId]: { + [accountAddress]: [ + { address: token1, decimals: 18, symbol: 'TKN' }, + ], + }, + }, + }; + + const { messenger } = setupController({ + tokens: initialTokens, + }); + + const multicallSpy = jest + .spyOn(multicall, 'getTokenBalancesForMultipleAddresses') + .mockResolvedValue({ tokenBalances: {} }); + + const tokensState = { + allTokens: initialTokens.allTokens, + allDetectedTokens: {}, + allIgnoredTokens: {}, + tokens: [], + ignoredTokens: [], + detectedTokens: [], + }; + + // Publish the same state again - should skip processing because tokens haven't changed + messenger.publish('TokensController:stateChange', tokensState, [ + { op: 'replace', path: [], value: tokensState }, + ]); + + // Wait a bit + await flushPromises(); + + // Multicall should not be called because tokens didn't change + // Note: The initial call count might vary based on controller initialization + const callCount = multicallSpy.mock.calls.length; + + // Publish the same state again + messenger.publish('TokensController:stateChange', tokensState, [ + { op: 'replace', path: [], value: tokensState }, + ]); + + await flushPromises(); + + // Call count should not increase for unchanged tokens + expect(multicallSpy.mock.calls).toHaveLength(callCount); + + multicallSpy.mockRestore(); + }); + + it('should skip undefined account balances during state normalization (line 477)', () => { + const account = '0x1234567890123456789012345678901234567890'; + + // Create initial state with an undefined account balance entry + const initialState: TokenBalancesControllerState = { + tokenBalances: { + [account as ChecksumAddress]: undefined, + } as unknown as TokenBalances, + }; + + // This should not throw - the normalization should skip undefined entries + const { controller } = setupController({ + config: { state: initialState }, + }); + + // State should be normalized (undefined entry should be skipped) + expect(controller.state.tokenBalances).toBeDefined(); + }); + + it('should return early from poll function when controller is inactive (line 588)', async () => { + const testClock = useFakeTimers(); + + const { controller, messenger } = setupController(); + + const multicallSpy = jest + .spyOn(multicall, 'getTokenBalancesForMultipleAddresses') + .mockResolvedValue({ tokenBalances: {} }); + + // Start polling (this sets up the poll function) + controller.startPolling({ chainIds: ['0x1'] }); + + // Wait for immediate poll + await advanceTime({ clock: testClock, duration: 1 }); + const initialCallCount = multicallSpy.mock.calls.length; + + // Lock the controller (sets #isControllerPollingActive to false) + messenger.publish('KeyringController:lock'); + + // Advance time to trigger the interval poll + await advanceTime({ clock: testClock, duration: 35000 }); + + // The poll function should have returned early without calling multicall + expect(multicallSpy.mock.calls).toHaveLength(initialCallCount); + + controller.stopAllPolling(); + multicallSpy.mockRestore(); + testClock.restore(); + }); + + it('should log warning when poll execution fails (line 603)', async () => { + const consoleWarnSpy = jest + .spyOn(console, 'warn') + .mockImplementation(() => { + // Suppress console output + }); + + // Mock _executePoll to throw an error + const { controller } = setupController(); + + jest + .spyOn(controller, '_executePoll') + .mockRejectedValue(new Error('Poll execution failed')); + + // Start polling - the poll function catches errors and logs them + controller.startPolling({ chainIds: ['0x1'] }); + + // Wait for the promise to be caught + await flushPromises(); + + // Verify warning was logged (either immediate or interval polling message) + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining('Polling failed'), + expect.any(Error), + ); + + controller.stopAllPolling(); + consoleWarnSpy.mockRestore(); + }); + + it('should handle fetcher returning unprocessedChainIds (lines 851-867)', async () => { + const chainId = '0x1'; + const accountAddress = '0x0000000000000000000000000000000000000000'; + const token1 = '0x1111111111111111111111111111111111111111'; + + const tokens = { + allDetectedTokens: {}, + allTokens: { + [chainId]: { + [accountAddress]: [ + { address: token1, symbol: 'TK1', decimals: 18 }, + ], + }, + }, + }; + + const { controller, tokenBalancesControllerMessenger } = setupController({ + tokens, + }); + + // Spy on messenger.call to verify detectTokens is called + const messengerCallSpy = jest.spyOn( + tokenBalancesControllerMessenger, + 'call', + ); + + // Mock RpcBalanceFetcher to return unprocessedChainIds + jest.spyOn(RpcBalanceFetcher.prototype, 'fetch').mockResolvedValue({ + balances: [ + { + success: true, + value: new BN(100), + account: accountAddress as ChecksumAddress, + token: token1 as Hex, + chainId: chainId as ChainIdHex, + }, + ], + unprocessedChainIds: ['0x89' as ChainIdHex], + }); + + await controller.updateBalances({ + chainIds: [chainId], + queryAllAccounts: true, + }); + + // Verify detectTokens was called with forceRpc for unprocessed chains + expect(messengerCallSpy).toHaveBeenCalledWith( + 'TokenDetectionController:detectTokens', + { + chainIds: ['0x89'], + forceRpc: true, + }, + ); + + messengerCallSpy.mockRestore(); + }); + + it('should handle fetcher throwing error (lines 868-880)', async () => { + const chainId = '0x1'; + const accountAddress = '0x0000000000000000000000000000000000000000'; + const token1 = '0x1111111111111111111111111111111111111111'; + + const tokens = { + allDetectedTokens: {}, + allTokens: { + [chainId]: { + [accountAddress]: [ + { address: token1, symbol: 'TK1', decimals: 18 }, + ], + }, + }, + }; + + const { controller, tokenBalancesControllerMessenger } = setupController({ + tokens, + }); + + // Spy on messenger.call to verify detectTokens is called + const messengerCallSpy = jest.spyOn( + tokenBalancesControllerMessenger, + 'call', + ); + + const consoleWarnSpy = jest + .spyOn(console, 'warn') + .mockImplementation(() => { + // Suppress console output + }); + + // Mock RpcBalanceFetcher to throw an error + jest + .spyOn(RpcBalanceFetcher.prototype, 'fetch') + .mockRejectedValue(new Error('Fetcher error')); + + await controller.updateBalances({ + chainIds: [chainId], + queryAllAccounts: true, + }); + + // Verify warning was logged + expect(consoleWarnSpy).toHaveBeenCalledWith( + expect.stringContaining('Balance fetcher failed'), + ); + + // Verify detectTokens was called with forceRpc when fetcher fails + expect(messengerCallSpy).toHaveBeenCalledWith( + 'TokenDetectionController:detectTokens', + { + chainIds: [chainId], + forceRpc: true, + }, + ); + + messengerCallSpy.mockRestore(); + consoleWarnSpy.mockRestore(); + }); + + it('should skip balances with success=false (line 963)', async () => { + const chainId = '0x1'; + const accountAddress = '0x0000000000000000000000000000000000000000'; + const token1 = '0x1111111111111111111111111111111111111111'; + const token2 = '0x2222222222222222222222222222222222222222'; + + const tokens = { + allDetectedTokens: {}, + allTokens: { + [chainId]: { + [accountAddress]: [ + { address: token1, symbol: 'TK1', decimals: 18 }, + { address: token2, symbol: 'TK2', decimals: 18 }, + ], + }, + }, + }; + + const { controller } = setupController({ tokens }); + + // Mock RpcBalanceFetcher to return mixed success/failure + jest.spyOn(RpcBalanceFetcher.prototype, 'fetch').mockResolvedValue({ + balances: [ + { + success: true, + value: new BN(100), + account: accountAddress as ChecksumAddress, + token: token1 as Hex, + chainId: chainId as ChainIdHex, + }, + { + success: false, // Should be skipped + value: new BN(200), + account: accountAddress as ChecksumAddress, + token: token2 as Hex, + chainId: chainId as ChainIdHex, + }, + ], + unprocessedChainIds: [], + }); + + await controller.updateBalances({ + chainIds: [chainId], + queryAllAccounts: true, + }); + + const balances = + controller.state.tokenBalances[accountAddress as ChecksumAddress]?.[ + chainId + ]; + const token1Checksum = toChecksumHexAddress(token1) as ChecksumAddress; + const token2Checksum = toChecksumHexAddress(token2) as ChecksumAddress; + + // token1 should be present with balance (success=true) + expect(balances?.[token1Checksum]).toBe(toHex(100)); + // token2 should NOT be present (success=false) + expect(balances?.[token2Checksum]).toBeUndefined(); + }); + + it('should skip balances with undefined value (line 963)', async () => { + const chainId = '0x1'; + const accountAddress = '0x0000000000000000000000000000000000000000'; + const token1 = '0x1111111111111111111111111111111111111111'; + const token2 = '0x2222222222222222222222222222222222222222'; + + const tokens = { + allDetectedTokens: {}, + allTokens: { + [chainId]: { + [accountAddress]: [ + { address: token1, symbol: 'TK1', decimals: 18 }, + { address: token2, symbol: 'TK2', decimals: 18 }, + ], + }, + }, + }; + + const { controller } = setupController({ tokens }); + + // Mock RpcBalanceFetcher to return one with undefined value + jest.spyOn(RpcBalanceFetcher.prototype, 'fetch').mockResolvedValue({ + balances: [ + { + success: true, + value: new BN(100), + account: accountAddress as ChecksumAddress, + token: token1 as Hex, + chainId: chainId as ChainIdHex, + }, + { + success: true, + value: undefined, // Should be skipped + account: accountAddress as ChecksumAddress, + token: token2 as Hex, + chainId: chainId as ChainIdHex, + }, + ], + unprocessedChainIds: [], + }); + + await controller.updateBalances({ + chainIds: [chainId], + queryAllAccounts: true, + }); + + const balances = + controller.state.tokenBalances[accountAddress as ChecksumAddress]?.[ + chainId + ]; + const token1Checksum = toChecksumHexAddress(token1) as ChecksumAddress; + const token2Checksum = toChecksumHexAddress(token2) as ChecksumAddress; + + // token1 should be present with balance + expect(balances?.[token1Checksum]).toBe(toHex(100)); + // token2 should NOT be present (value=undefined) + expect(balances?.[token2Checksum]).toBeUndefined(); + }); + }); }); diff --git a/packages/assets-controllers/src/TokenBalancesController.ts b/packages/assets-controllers/src/TokenBalancesController.ts index 57736c12d9e..f98b7a6cabf 100644 --- a/packages/assets-controllers/src/TokenBalancesController.ts +++ b/packages/assets-controllers/src/TokenBalancesController.ts @@ -21,7 +21,13 @@ import type { AccountActivityServiceBalanceUpdatedEvent, AccountActivityServiceStatusChangedEvent, } from '@metamask/core-backend'; -import type { KeyringControllerAccountRemovedEvent } from '@metamask/keyring-controller'; +import type { + KeyringControllerAccountRemovedEvent, + KeyringControllerGetStateAction, + KeyringControllerLockEvent, + KeyringControllerUnlockEvent, +} from '@metamask/keyring-controller'; +import type { InternalAccount } from '@metamask/keyring-internal-api'; import type { Messenger } from '@metamask/messenger'; import type { NetworkControllerGetNetworkClientByIdAction, @@ -35,6 +41,10 @@ import type { PreferencesControllerStateChangeEvent, } from '@metamask/preferences-controller'; import type { AuthenticationController } from '@metamask/profile-sync-controller'; +import type { + TransactionControllerIncomingTransactionsReceivedEvent, + TransactionControllerTransactionConfirmedEvent, +} from '@metamask/transaction-controller'; import type { Hex } from '@metamask/utils'; import { isCaipAssetType, @@ -58,7 +68,11 @@ import type { ProcessedBalance, } from './multi-chain-accounts-service/api-balance-fetcher'; import { RpcBalanceFetcher } from './rpc-service/rpc-balance-fetcher'; -import type { TokenDetectionControllerAddDetectedTokensViaWsAction } from './TokenDetectionController'; +import type { + TokenDetectionControllerAddDetectedTokensViaPollingAction, + TokenDetectionControllerAddDetectedTokensViaWsAction, + TokenDetectionControllerDetectTokensAction, +} from './TokenDetectionController'; import type { TokensControllerGetStateAction, TokensControllerState, @@ -127,13 +141,16 @@ export type AllowedActions = | NetworkControllerGetNetworkClientByIdAction | NetworkControllerGetStateAction | TokensControllerGetStateAction + | TokenDetectionControllerAddDetectedTokensViaPollingAction | TokenDetectionControllerAddDetectedTokensViaWsAction + | TokenDetectionControllerDetectTokensAction | PreferencesControllerGetStateAction | AccountsControllerGetSelectedAccountAction | AccountsControllerListAccountsAction | AccountTrackerControllerGetStateAction | AccountTrackerUpdateNativeBalancesAction | AccountTrackerUpdateStakedBalancesAction + | KeyringControllerGetStateAction | AuthenticationController.AuthenticationControllerGetBearerToken; export type AllowedEvents = @@ -141,9 +158,13 @@ export type AllowedEvents = | PreferencesControllerStateChangeEvent | NetworkControllerStateChangeEvent | KeyringControllerAccountRemovedEvent + | KeyringControllerLockEvent + | KeyringControllerUnlockEvent | AccountActivityServiceBalanceUpdatedEvent | AccountActivityServiceStatusChangedEvent - | AccountsControllerSelectedEvmAccountChangeEvent; + | AccountsControllerSelectedEvmAccountChangeEvent + | TransactionControllerTransactionConfirmedEvent + | TransactionControllerIncomingTransactionsReceivedEvent; export type TokenBalancesControllerMessenger = Messenger< typeof CONTROLLER, @@ -180,11 +201,9 @@ export type TokenBalancesControllerOptions = { /** Polling interval when WebSocket is active and providing real-time updates */ websocketActivePollingInterval?: number; }; -// endregion -// ──────────────────────────────────────────────────────────────────────────── -// region: Helper utilities -const draft = (base: T, fn: (d: T) => void): T => produce(base, fn); +const draft = (base: State, fn: (draftState: State) => void): State => + produce(base, fn); const ZERO_ADDRESS = '0x0000000000000000000000000000000000000000' as ChecksumAddress; @@ -193,12 +212,10 @@ const checksum = (addr: string): ChecksumAddress => toChecksumHexAddress(addr) as ChecksumAddress; /** - * Convert CAIP chain ID or hex chain ID to hex chain ID - * Handles both CAIP-2 format (e.g., "eip155:1") and hex format (e.g., "0x1") + * Convert CAIP chain ID or hex chain ID to hex chain ID. * - * @param chainId - CAIP chain ID (e.g., "eip155:1") or hex chain ID (e.g., "0x1") - * @returns Hex chain ID (e.g., "0x1") - * @throws {Error} If chainId is neither a valid CAIP-2 chain ID nor a hex string + * @param chainId - CAIP chain ID or hex chain ID. + * @returns Hex chain ID. */ export const caipChainIdToHex = (chainId: string): ChainIdHex => { if (isStrictHexString(chainId)) { @@ -213,11 +230,10 @@ export const caipChainIdToHex = (chainId: string): ChainIdHex => { }; /** - * Extract token address from asset type - * Returns tuple of [tokenAddress, isNativeToken] or null if invalid + * Extract token address from asset type. * - * @param assetType - Asset type string (e.g., 'eip155:1/erc20:0x...' or 'eip155:1/slip44:60') - * @returns Tuple of [tokenAddress, isNativeToken] or null if invalid + * @param assetType - Asset type string. + * @returns Tuple of [tokenAddress, isNativeToken] or null if invalid. */ export const parseAssetType = (assetType: string): [string, boolean] | null => { if (!isCaipAssetType(assetType)) { @@ -226,22 +242,23 @@ export const parseAssetType = (assetType: string): [string, boolean] | null => { const parsed = parseCaipAssetType(assetType); - // ERC20 token (e.g., "eip155:1/erc20:0x...") if (parsed.assetNamespace === 'erc20') { return [parsed.assetReference, false]; } - // Native token (e.g., "eip155:1/slip44:60") if (parsed.assetNamespace === 'slip44') { return [ZERO_ADDRESS, true]; } return null; }; -// endregion -// ──────────────────────────────────────────────────────────────────────────── -// region: Main controller +type NativeBalanceUpdate = { address: string; chainId: Hex; balance: Hex }; +type StakedBalanceUpdate = { + address: string; + chainId: Hex; + stakedBalance: Hex; +}; export class TokenBalancesController extends StaticIntervalPollingController<{ chainIds: ChainIdHex[]; }>()< @@ -278,6 +295,9 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ /** Track if controller-level polling is active */ #isControllerPollingActive = false; + /** Track if the keyring is unlocked */ + #isUnlocked = false; + /** Store original chainIds from startPolling to preserve intent */ #requestedChainIds: ChainIdHex[] = []; @@ -297,8 +317,8 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ chainPollingIntervals = {}, state = {}, queryMultipleAccounts = true, - accountsApiChainIds = () => [], - allowExternalServices = () => true, + accountsApiChainIds = (): ChainIdHex[] => [], + allowExternalServices = (): boolean => true, platform, }: TokenBalancesControllerOptions) { super({ @@ -308,7 +328,6 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ state: { tokenBalances: {}, ...state }, }); - // Normalize all account addresses to lowercase in existing state this.#normalizeAccountAddresses(); this.#platform = platform ?? 'extension'; @@ -318,7 +337,6 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ this.#websocketActivePollingInterval = websocketActivePollingInterval; this.#chainPollingConfig = { ...chainPollingIntervals }; - // Strategy order: API first, then RPC fallback this.#balanceFetchers = [ ...(accountsApiChainIds().length > 0 && allowExternalServices() ? [this.#createAccountsApiFetcher()] @@ -331,13 +349,20 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ this.setIntervalLength(interval); - // initial token state & subscriptions const { allTokens, allDetectedTokens, allIgnoredTokens } = this.messenger.call('TokensController:getState'); this.#allTokens = allTokens; this.#detectedTokens = allDetectedTokens; this.#allIgnoredTokens = allIgnoredTokens; + const { isUnlocked } = this.messenger.call('KeyringController:getState'); + this.#isUnlocked = isUnlocked; + + this.#subscribeToControllers(); + this.#registerActions(); + } + + #subscribeToControllers(): void { this.messenger.subscribe( 'TokensController:stateChange', (tokensState: TokensControllerState) => { @@ -346,20 +371,68 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ }); }, ); + this.messenger.subscribe( 'NetworkController:stateChange', this.#onNetworkChanged, ); + + this.messenger.subscribe('KeyringController:unlock', () => { + this.#isUnlocked = true; + }); + + this.messenger.subscribe('KeyringController:lock', () => { + this.#isUnlocked = false; + }); + this.messenger.subscribe( 'KeyringController:accountRemoved', this.#onAccountRemoved, ); + this.messenger.subscribe( 'AccountsController:selectedEvmAccountChange', this.#onAccountChanged, ); - // Register action handlers for polling interval control + this.messenger.subscribe( + 'AccountActivityService:balanceUpdated', + (event) => { + this.#onAccountActivityBalanceUpdate(event).catch((error) => { + console.warn('Error handling balance update:', error); + }); + }, + ); + + this.messenger.subscribe( + 'AccountActivityService:statusChanged', + this.#onAccountActivityStatusChanged.bind(this), + ); + + this.messenger.subscribe( + 'TransactionController:transactionConfirmed', + (transactionMeta) => { + this.updateBalances({ + chainIds: [transactionMeta.chainId], + }).catch(() => { + // Silently handle balance update errors + }); + }, + ); + + this.messenger.subscribe( + 'TransactionController:incomingTransactionsReceived', + (incomingTransactions) => { + this.updateBalances({ + chainIds: incomingTransactions.map((tx) => tx.chainId), + }).catch(() => { + // Silently handle balance update errors + }); + }, + ); + } + + #registerActions(): void { this.messenger.registerActionHandler( `TokenBalancesController:updateChainPollingConfigs`, this.updateChainPollingConfigs.bind(this), @@ -369,29 +442,26 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ `TokenBalancesController:getChainPollingConfig`, this.getChainPollingConfig.bind(this), ); + } - // Subscribe to AccountActivityService balance updates for real-time updates - this.messenger.subscribe( - 'AccountActivityService:balanceUpdated', - this.#onAccountActivityBalanceUpdate.bind(this), - ); - - // Subscribe to AccountActivityService status changes for dynamic polling management - this.messenger.subscribe( - 'AccountActivityService:statusChanged', - this.#onAccountActivityStatusChanged.bind(this), - ); + /** + * Whether the controller is active (keyring is unlocked). + * When locked, balance updates should be skipped. + * + * @returns Whether the keyring is unlocked. + */ + get isActive(): boolean { + return this.#isUnlocked; } /** * Normalize all account addresses to lowercase and merge duplicates - * This handles migration from old state where addresses might be checksummed + * Handles migration from old state where addresses might be checksummed. */ - #normalizeAccountAddresses() { + #normalizeAccountAddresses(): void { const currentState = this.state.tokenBalances; const normalizedBalances: TokenBalances = {}; - // Iterate through all accounts and normalize to lowercase for (const address of Object.keys(currentState)) { const lowercaseAddress = address.toLowerCase() as ChecksumAddress; const accountBalances = currentState[address as ChecksumAddress]; @@ -400,20 +470,12 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ continue; } - // If this lowercase address doesn't exist yet, create it - if (!normalizedBalances[lowercaseAddress]) { - normalizedBalances[lowercaseAddress] = {}; - } + normalizedBalances[lowercaseAddress] ??= {}; - // Merge chain data for (const chainId of Object.keys(accountBalances)) { const chainIdKey = chainId as ChainIdHex; + normalizedBalances[lowercaseAddress][chainIdKey] ??= {}; - if (!normalizedBalances[lowercaseAddress][chainIdKey]) { - normalizedBalances[lowercaseAddress][chainIdKey] = {}; - } - - // Merge token balances (later values override earlier ones if duplicates exist) Object.assign( normalizedBalances[lowercaseAddress][chainIdKey], accountBalances[chainIdKey], @@ -421,7 +483,6 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ } } - // Only update if there were changes if ( Object.keys(currentState).length !== Object.keys(normalizedBalances).length || @@ -444,8 +505,9 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ const { networkConfigurationsByChainId } = this.messenger.call( 'NetworkController:getState', ); - const cfg = networkConfigurationsByChainId[chainId]; - const { networkClientId } = cfg.rpcEndpoints[cfg.defaultRpcEndpointIndex]; + const networkConfig = networkConfigurationsByChainId[chainId]; + const { networkClientId } = + networkConfig.rpcEndpoints[networkConfig.defaultRpcEndpointIndex]; const client = this.messenger.call( 'NetworkController:getNetworkClientById', networkClientId, @@ -453,23 +515,21 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ return new Web3Provider(client.provider); }; - readonly #getNetworkClient = (chainId: ChainIdHex) => { + readonly #getNetworkClient = ( + chainId: ChainIdHex, + ): ReturnType => { const { networkConfigurationsByChainId } = this.messenger.call( 'NetworkController:getState', ); - const cfg = networkConfigurationsByChainId[chainId]; - const { networkClientId } = cfg.rpcEndpoints[cfg.defaultRpcEndpointIndex]; + const networkConfig = networkConfigurationsByChainId[chainId]; + const { networkClientId } = + networkConfig.rpcEndpoints[networkConfig.defaultRpcEndpointIndex]; return this.messenger.call( 'NetworkController:getNetworkClientById', networkClientId, ); }; - /** - * Creates an AccountsApiBalanceFetcher that only supports chains in the accountsApiChainIds array - * - * @returns A BalanceFetcher that wraps AccountsApiBalanceFetcher with chainId filtering - */ readonly #createAccountsApiFetcher = (): BalanceFetcher => { const originalFetcher = new AccountsApiBalanceFetcher( this.#platform, @@ -477,75 +537,47 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ ); return { - supports: (chainId: ChainIdHex): boolean => { - // Only support chains that are both: - // 1. In our specified accountsApiChainIds array - // 2. Actually supported by the AccountsApi - return ( - this.#accountsApiChainIds().includes(chainId) && - originalFetcher.supports(chainId) - ); - }, + supports: (chainId: ChainIdHex): boolean => + this.#accountsApiChainIds().includes(chainId) && + originalFetcher.supports(chainId), fetch: originalFetcher.fetch.bind(originalFetcher), }; }; - /** - * Override to support per-chain polling intervals by grouping chains by interval - * - * @param options0 - The polling options - * @param options0.chainIds - Chain IDs to start polling for - */ - override _startPolling({ chainIds }: { chainIds: ChainIdHex[] }) { - // Store the original chainIds to preserve intent across config updates + override _startPolling({ chainIds }: { chainIds: ChainIdHex[] }): void { this.#requestedChainIds = [...chainIds]; this.#isControllerPollingActive = true; this.#startIntervalGroupPolling(chainIds, true); } - /** - * Start or restart interval-based polling for multiple chains - * - * @param chainIds - Chain IDs to start polling for - * @param immediate - Whether to poll immediately before starting timers (default: true) - */ - #startIntervalGroupPolling(chainIds: ChainIdHex[], immediate = true) { - // Stop any existing interval timers + #startIntervalGroupPolling(chainIds: ChainIdHex[], immediate = true): void { this.#intervalPollingTimers.forEach((timer) => clearInterval(timer)); this.#intervalPollingTimers.clear(); - // Group chains by their polling intervals const intervalGroups = new Map(); for (const chainId of chainIds) { const config = this.getChainPollingConfig(chainId); - const existing = intervalGroups.get(config.interval) || []; - existing.push(chainId); - intervalGroups.set(config.interval, existing); + const group = intervalGroups.get(config.interval) ?? []; + group.push(chainId); + intervalGroups.set(config.interval, group); } - // Start separate polling loop for each interval group for (const [interval, chainIdsGroup] of intervalGroups) { this.#startPollingForInterval(interval, chainIdsGroup, immediate); } } - /** - * Start polling loop for chains that share the same interval - * - * @param interval - The polling interval in milliseconds - * @param chainIds - Chain IDs that share this interval - * @param immediate - Whether to poll immediately before starting the timer (default: true) - */ #startPollingForInterval( interval: number, chainIds: ChainIdHex[], immediate = true, - ) { - const pollFunction = async () => { + ): void { + const pollFunction = async (): Promise => { if (!this.#isControllerPollingActive) { return; } + try { await this._executePoll({ chainIds }); } catch (error) { @@ -556,7 +588,6 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ } }; - // Poll immediately first if requested if (immediate) { pollFunction().catch((error) => { console.warn( @@ -566,28 +597,14 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ }); } - // Then start regular interval polling this.#setPollingTimer(interval, chainIds, pollFunction); } - /** - * Helper method to set up polling timer - * - * @param interval - The polling interval in milliseconds - * @param chainIds - Chain IDs for this interval - * @param pollFunction - The function to call on each poll - */ #setPollingTimer( interval: number, chainIds: ChainIdHex[], pollFunction: () => Promise, - ) { - // Clear any existing timer for this interval first - const existingTimer = this.#intervalPollingTimers.get(interval); - if (existingTimer) { - clearInterval(existingTimer); - } - + ): void { const timer = setInterval(() => { pollFunction().catch((error) => { console.warn( @@ -596,54 +613,41 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ ); }); }, interval); + this.#intervalPollingTimers.set(interval, timer); } - /** - * Override to handle our custom polling approach - * - * @param tokenSetId - The token set ID to stop polling for - */ - override _stopPollingByPollingTokenSetId(tokenSetId: string) { - let parsedTokenSetId; + override _stopPollingByPollingTokenSetId(tokenSetId: string): void { let chainsToStop: ChainIdHex[] = []; try { - parsedTokenSetId = JSON.parse(tokenSetId); - chainsToStop = parsedTokenSetId.chainIds || []; + const parsedTokenSetId = JSON.parse(tokenSetId); + chainsToStop = parsedTokenSetId.chainIds ?? []; } catch (error) { console.warn('Failed to parse tokenSetId, stopping all polling:', error); - // Fallback: stop all polling if we can't parse the tokenSetId - this.#isControllerPollingActive = false; - this.#requestedChainIds = []; - this.#intervalPollingTimers.forEach((timer) => clearInterval(timer)); - this.#intervalPollingTimers.clear(); + this.#stopAllPolling(); return; } - // Compare with current chains - only stop if it matches our current session const currentChainsSet = new Set(this.#requestedChainIds); const stopChainsSet = new Set(chainsToStop); - // Check if this stop request is for our current session const isCurrentSession = currentChainsSet.size === stopChainsSet.size && [...currentChainsSet].every((chain) => stopChainsSet.has(chain)); if (isCurrentSession) { - this.#isControllerPollingActive = false; - this.#requestedChainIds = []; - this.#intervalPollingTimers.forEach((timer) => clearInterval(timer)); - this.#intervalPollingTimers.clear(); + this.#stopAllPolling(); } } - /** - * Get polling configuration for a chain (includes default fallback) - * - * @param chainId - The chain ID to get config for - * @returns The polling configuration for the chain - */ + #stopAllPolling(): void { + this.#isControllerPollingActive = false; + this.#requestedChainIds = []; + this.#intervalPollingTimers.forEach((timer) => clearInterval(timer)); + this.#intervalPollingTimers.clear(); + } + getChainPollingConfig(chainId: ChainIdHex): ChainPollingConfig { return ( this.#chainPollingConfig[chainId] ?? { @@ -658,27 +662,17 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ }: { chainIds: ChainIdHex[]; queryAllAccounts?: boolean; - }) { - // This won't be called with our custom implementation, but keep for compatibility + }): Promise { await this.updateBalances({ chainIds, queryAllAccounts }); } - /** - * Update multiple chain polling configurations at once - * - * @param configs - Object mapping chain IDs to polling configurations - * @param options - Optional configuration for the update behavior - * @param options.immediateUpdate - Whether to immediately fetch balances after updating configs (default: true) - */ updateChainPollingConfigs( configs: Record, options: UpdateChainPollingConfigsOptions = { immediateUpdate: true }, ): void { Object.assign(this.#chainPollingConfig, configs); - // If polling is currently active, restart with new interval groupings if (this.#isControllerPollingActive) { - // Restart polling with immediate fetch by default, unless explicitly disabled this.#startIntervalGroupPolling( this.#requestedChainIds, options.immediateUpdate, @@ -688,13 +682,96 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ async updateBalances({ chainIds, + tokenAddresses, queryAllAccounts = false, - }: { chainIds?: ChainIdHex[]; queryAllAccounts?: boolean } = {}) { - const targetChains = chainIds ?? this.#chainIdsWithTokens(); + }: { + chainIds?: ChainIdHex[]; + tokenAddresses?: string[]; + queryAllAccounts?: boolean; + } = {}): Promise { + if (!this.isActive) { + return; + } + + const targetChains = this.#getTargetChains(chainIds); if (!targetChains.length) { return; } + const { selectedAccount, allAccounts, jwtToken } = + await this.#getAccountsAndJwt(); + + const aggregatedBalances = await this.#fetchAllBalances({ + targetChains, + selectedAccount, + allAccounts, + jwtToken, + queryAllAccounts: queryAllAccounts ?? this.#queryAllAccounts, + }); + + const filteredAggregated = this.#filterByTokenAddresses( + aggregatedBalances, + tokenAddresses, + ); + + const accountsToProcess = this.#getAccountsToProcess( + queryAllAccounts, + allAccounts, + selectedAccount, + ); + + const prev = this.state; + const next = this.#applyTokenBalancesToState({ + prev, + targetChains, + accountsToProcess, + balances: filteredAggregated, + }); + + if (!isEqual(prev, next)) { + this.update(() => next); + + const accountTrackerState = this.messenger.call( + 'AccountTrackerController:getState', + ); + + const nativeUpdates = this.#buildNativeBalanceUpdates( + filteredAggregated, + accountTrackerState, + ); + + if (nativeUpdates.length > 0) { + this.messenger.call( + 'AccountTrackerController:updateNativeBalances', + nativeUpdates, + ); + } + + const stakedUpdates = this.#buildStakedBalanceUpdates( + filteredAggregated, + accountTrackerState, + ); + + if (stakedUpdates.length > 0) { + this.messenger.call( + 'AccountTrackerController:updateStakedBalances', + stakedUpdates, + ); + } + } + + await this.#importUntrackedTokens(filteredAggregated); + } + + #getTargetChains(chainIds?: ChainIdHex[]): ChainIdHex[] { + return chainIds?.length ? chainIds : this.#chainIdsWithTokens(); + } + + async #getAccountsAndJwt(): Promise<{ + selectedAccount: ChecksumAddress; + allAccounts: InternalAccount[]; + jwtToken: string | undefined; + }> { const { address: selected } = this.messenger.call( 'AccountsController:getSelectedAccount', ); @@ -708,13 +785,32 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ 5000, ); + return { + selectedAccount: selected as ChecksumAddress, + allAccounts, + jwtToken, + }; + } + + async #fetchAllBalances({ + targetChains, + selectedAccount, + allAccounts, + jwtToken, + queryAllAccounts, + }: { + targetChains: ChainIdHex[]; + selectedAccount: ChecksumAddress; + allAccounts: InternalAccount[]; + jwtToken?: string; + queryAllAccounts: boolean; + }): Promise { const aggregated: ProcessedBalance[] = []; let remainingChains = [...targetChains]; - // Try each fetcher in order, removing successfully processed chains for (const fetcher of this.#balanceFetchers) { - const supportedChains = remainingChains.filter((c) => - fetcher.supports(c), + const supportedChains = remainingChains.filter((chain) => + fetcher.supports(chain), ); if (!supportedChains.length) { continue; @@ -723,220 +819,294 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ try { const result = await fetcher.fetch({ chainIds: supportedChains, - queryAllAccounts: queryAllAccounts ?? this.#queryAllAccounts, - selectedAccount: selected as ChecksumAddress, + queryAllAccounts, + selectedAccount, allAccounts, jwtToken, }); - if (result.balances && result.balances.length > 0) { + if (result.balances?.length) { aggregated.push(...result.balances); - // Remove chains that were successfully processed - const processedChains = new Set( - result.balances.map((b) => b.chainId), - ); + + const processed = new Set(result.balances.map((b) => b.chainId)); remainingChains = remainingChains.filter( - (chain) => !processedChains.has(chain), + (chain) => !processed.has(chain), ); } - // Add unprocessed chains back to remainingChains for next fetcher - if ( - result.unprocessedChainIds && - result.unprocessedChainIds.length > 0 - ) { - const currentRemainingChains = remainingChains; + if (result.unprocessedChainIds?.length) { + const currentRemaining = [...remainingChains]; const chainsToAdd = result.unprocessedChainIds.filter( (chainId) => supportedChains.includes(chainId) && - !currentRemainingChains.includes(chainId), + !currentRemaining.includes(chainId), ); remainingChains.push(...chainsToAdd); + + this.messenger + .call('TokenDetectionController:detectTokens', { + chainIds: result.unprocessedChainIds, + forceRpc: true, + }) + .catch(() => { + // Silently handle token detection errors + }); } } catch (error) { console.warn( `Balance fetcher failed for chains ${supportedChains.join(', ')}: ${String(error)}`, ); - // Continue to next fetcher (fallback) + + this.messenger + .call('TokenDetectionController:detectTokens', { + chainIds: supportedChains, + forceRpc: true, + }) + .catch(() => { + // Silently handle token detection errors + }); } - // If all chains have been processed, break early - if (remainingChains.length === 0) { + if (!remainingChains.length) { break; } } - // Determine which accounts to process based on queryAllAccounts parameter - const accountsToProcess = - (queryAllAccounts ?? this.#queryAllAccounts) - ? allAccounts.map((a) => a.address as ChecksumAddress) - : [selected as ChecksumAddress]; + return aggregated; + } - const prev = this.state; - const next = draft(prev, (d) => { - // Initialize account and chain structures if they don't exist, but preserve existing balances + #filterByTokenAddresses( + balances: ProcessedBalance[], + tokenAddresses?: string[], + ): ProcessedBalance[] { + if (!tokenAddresses?.length) { + return balances; + } + + const lowered = tokenAddresses.map((a) => a.toLowerCase()); + return balances.filter((balance) => + lowered.includes(balance.token.toLowerCase()), + ); + } + + #getAccountsToProcess( + queryAllAccountsParam: boolean | undefined, + allAccounts: InternalAccount[], + selectedAccount: ChecksumAddress, + ): ChecksumAddress[] { + const effectiveQueryAll = + queryAllAccountsParam ?? this.#queryAllAccounts ?? false; + + if (!effectiveQueryAll) { + return [selectedAccount]; + } + + return allAccounts.map((account) => account.address as ChecksumAddress); + } + + #applyTokenBalancesToState({ + prev, + targetChains, + accountsToProcess, + balances, + }: { + prev: TokenBalancesControllerState; + targetChains: ChainIdHex[]; + accountsToProcess: ChecksumAddress[]; + balances: ProcessedBalance[]; + }): TokenBalancesControllerState { + return draft(prev, (draftState) => { for (const chainId of targetChains) { for (const account of accountsToProcess) { - // Ensure the nested structure exists without overwriting existing balances - d.tokenBalances[account] ??= {}; - d.tokenBalances[account][chainId] ??= {}; - // Initialize tokens from allTokens only if they don't exist yet + draftState.tokenBalances[account] ??= {}; + draftState.tokenBalances[account][chainId] ??= {}; + const chainTokens = this.#allTokens[chainId]; if (chainTokens?.[account]) { Object.values(chainTokens[account]).forEach( (token: { address: string }) => { const tokenAddress = checksum(token.address); - // Only initialize if the token balance doesn't exist yet - if (!(tokenAddress in d.tokenBalances[account][chainId])) { - d.tokenBalances[account][chainId][tokenAddress] = '0x0'; - } + draftState.tokenBalances[account][chainId][tokenAddress] ??= + '0x0'; }, ); } - // Initialize tokens from allDetectedTokens only if they don't exist yet const detectedChainTokens = this.#detectedTokens[chainId]; if (detectedChainTokens?.[account]) { Object.values(detectedChainTokens[account]).forEach( (token: { address: string }) => { const tokenAddress = checksum(token.address); - // Only initialize if the token balance doesn't exist yet - if (!(tokenAddress in d.tokenBalances[account][chainId])) { - d.tokenBalances[account][chainId][tokenAddress] = '0x0'; - } + draftState.tokenBalances[account][chainId][tokenAddress] ??= + '0x0'; }, ); } } } - // Update with actual fetched balances only if the value has changed - aggregated.forEach(({ success, value, account, token, chainId }) => { - if (success && value !== undefined) { - // Ensure all accounts we add/update are in lower-case - const lowerCaseAccount = account.toLowerCase() as ChecksumAddress; - const newBalance = toHex(value); - const tokenAddress = checksum(token); - const currentBalance = - d.tokenBalances[lowerCaseAccount]?.[chainId]?.[tokenAddress]; - - // Only update if the balance has actually changed - if (currentBalance !== newBalance) { - ((d.tokenBalances[lowerCaseAccount] ??= {})[chainId] ??= {})[ - tokenAddress - ] = newBalance; - } + balances.forEach(({ success, value, account, token, chainId }) => { + if (!success || value === undefined) { + return; + } + + const lowerCaseAccount = account.toLowerCase() as ChecksumAddress; + const newBalance = toHex(value); + const tokenAddress = checksum(token); + + const currentBalance = + draftState.tokenBalances[lowerCaseAccount]?.[chainId]?.[tokenAddress]; + + if (currentBalance !== newBalance) { + ((draftState.tokenBalances[lowerCaseAccount] ??= {})[chainId] ??= {})[ + tokenAddress + ] = newBalance; } }); }); + } - if (!isEqual(prev, next)) { - this.update(() => next); - - const nativeBalances = aggregated.filter( - (r) => r.success && r.token === ZERO_ADDRESS, - ); + #buildNativeBalanceUpdates( + balances: ProcessedBalance[], + accountTrackerState: { + accountsByChainId: Record< + string, + Record + >; + }, + ): NativeBalanceUpdate[] { + const nativeBalances = balances.filter( + (balance) => balance.success && balance.token === ZERO_ADDRESS, + ); - // Get current AccountTracker state to compare existing balances - const accountTrackerState = this.messenger.call( - 'AccountTrackerController:getState', - ); + if (!nativeBalances.length) { + return []; + } - // Update native token balances only if they have changed - if (nativeBalances.length > 0) { - const balanceUpdates = nativeBalances - .map((balance) => ({ - address: balance.account, - chainId: balance.chainId, - balance: balance.value ? BNToHex(balance.value) : '0x0', - })) - .filter((update) => { - const currentBalance = - accountTrackerState.accountsByChainId[update.chainId]?.[ - checksum(update.address) - ]?.balance; - // Only include if the balance has actually changed - return currentBalance !== update.balance; - }); + return nativeBalances + .map((balance) => ({ + address: balance.account, + chainId: balance.chainId, + balance: balance.value ? BNToHex(balance.value) : '0x0', + })) + .filter((update) => { + const currentBalance = + accountTrackerState.accountsByChainId[update.chainId]?.[ + checksum(update.address) + ]?.balance; + return currentBalance !== update.balance; + }); + } - if (balanceUpdates.length > 0) { - this.messenger.call( - 'AccountTrackerController:updateNativeBalances', - balanceUpdates, - ); - } + #buildStakedBalanceUpdates( + balances: ProcessedBalance[], + accountTrackerState: { + accountsByChainId: Record< + string, + Record + >; + }, + ): StakedBalanceUpdate[] { + const stakedBalances = balances.filter((balance) => { + if (!balance.success || balance.token === ZERO_ADDRESS) { + return false; } - // Filter and update staked balances in a single batch operation for better performance - const stakedBalances = aggregated.filter((r) => { - if (!r.success || r.token === ZERO_ADDRESS) { - return false; - } + const stakingContractAddress = + STAKING_CONTRACT_ADDRESS_BY_CHAINID[balance.chainId]; + return ( + stakingContractAddress && + stakingContractAddress.toLowerCase() === balance.token.toLowerCase() + ); + }); - // Check if the chainId and token address match any staking contract - const stakingContractAddress = - STAKING_CONTRACT_ADDRESS_BY_CHAINID[r.chainId]; - return ( - stakingContractAddress && - stakingContractAddress.toLowerCase() === r.token.toLowerCase() - ); + if (!stakedBalances.length) { + return []; + } + + return stakedBalances + .map((balance) => ({ + address: balance.account, + chainId: balance.chainId, + stakedBalance: balance.value ? toHex(balance.value) : '0x0', + })) + .filter((update) => { + const currentStakedBalance = + accountTrackerState.accountsByChainId[update.chainId]?.[ + checksum(update.address) + ]?.stakedBalance; + return currentStakedBalance !== update.stakedBalance; }); + } - if (stakedBalances.length > 0) { - const stakedBalanceUpdates = stakedBalances - .map((balance) => ({ - address: balance.account, - chainId: balance.chainId, - stakedBalance: balance.value ? toHex(balance.value) : '0x0', - })) - .filter((update) => { - const currentStakedBalance = - accountTrackerState.accountsByChainId[update.chainId]?.[ - checksum(update.address) - ]?.stakedBalance; - // Only include if the staked balance has actually changed - return currentStakedBalance !== update.stakedBalance; - }); + /** + * Import untracked tokens that have non-zero balances. + * This mirrors the v2 behavior where only tokens with actual balances are added. + * Delegates to TokenDetectionController:addDetectedTokensViaPolling which handles: + * - Checking if useTokenDetection preference is enabled + * - Filtering tokens already in allTokens or allIgnoredTokens + * - Token metadata lookup and addition via TokensController + * + * @param balances - Array of processed balance results from fetchers + */ + async #importUntrackedTokens(balances: ProcessedBalance[]): Promise { + const tokensByChain = new Map(); - if (stakedBalanceUpdates.length > 0) { - this.messenger.call( - 'AccountTrackerController:updateStakedBalances', - stakedBalanceUpdates, - ); - } + for (const balance of balances) { + // Skip failed fetches, native tokens, and zero balances (like v2 did) + if ( + !balance.success || + balance.token === ZERO_ADDRESS || + !balance.value || + balance.value.isZero() + ) { + continue; + } + + const tokenAddress = checksum(balance.token); + const existing = tokensByChain.get(balance.chainId) ?? []; + if (!existing.includes(tokenAddress)) { + existing.push(tokenAddress); + tokensByChain.set(balance.chainId, existing); + } + } + + // Add detected tokens via TokenDetectionController (handles preference check, + // filtering of allTokens/allIgnoredTokens, and metadata lookup) + for (const [chainId, tokenAddresses] of tokensByChain) { + if (tokenAddresses.length) { + await this.messenger.call( + 'TokenDetectionController:addDetectedTokensViaPolling', + { + tokensSlice: tokenAddresses, + chainId, + }, + ); } } } - resetState() { + resetState(): void { this.update(() => ({ tokenBalances: {} })); } - /** - * Helper method to check if a token is tracked (exists in allTokens or allIgnoredTokens) - * - * @param tokenAddress - The token address to check - * @param account - The account address - * @param chainId - The chain ID - * @returns True if the token is tracked (imported or ignored) - */ #isTokenTracked( tokenAddress: string, account: ChecksumAddress, chainId: ChainIdHex, ): boolean { - // Check if token exists in allTokens + const normalizedAccount = account.toLowerCase(); + if ( - this.#allTokens?.[chainId]?.[account.toLowerCase()]?.some( + this.#allTokens?.[chainId]?.[normalizedAccount]?.some( (token) => token.address === tokenAddress, ) ) { return true; } - // Check if token exists in allIgnoredTokens if ( - this.#allIgnoredTokens?.[chainId]?.[account.toLowerCase()]?.some( + this.#allIgnoredTokens?.[chainId]?.[normalizedAccount]?.some( (token) => token === tokenAddress, ) ) { @@ -946,28 +1116,17 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ return false; } - readonly #onTokensChanged = async (state: TokensControllerState) => { + readonly #onTokensChanged = async ( + state: TokensControllerState, + ): Promise => { const changed: ChainIdHex[] = []; let hasChanges = false; - // Get chains that have existing balances - const chainsWithBalances = new Set(); - for (const address of Object.keys(this.state.tokenBalances)) { - const addressKey = address as ChecksumAddress; - for (const chainId of Object.keys( - this.state.tokenBalances[addressKey] || {}, - )) { - chainsWithBalances.add(chainId as ChainIdHex); - } - } - - // Only process chains that are explicitly mentioned in the incoming state change const incomingChainIds = new Set([ ...Object.keys(state.allTokens), ...Object.keys(state.allDetectedTokens), ]); - // Only proceed if there are actual changes to chains that have balances or are being added const relevantChainIds = Array.from(incomingChainIds).filter((chainId) => { const id = chainId as ChainIdHex; @@ -980,24 +1139,20 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ (this.#detectedTokens[id] && Object.keys(this.#detectedTokens[id]).length > 0); - // Check if there's an actual change in token state const hasTokenChange = !isEqual(state.allTokens[id], this.#allTokens[id]) || !isEqual(state.allDetectedTokens[id], this.#detectedTokens[id]); - // Process chains that have actual changes OR are new chains getting tokens return hasTokenChange || (!hadTokensBefore && hasTokensNow); }); - if (relevantChainIds.length === 0) { - // No relevant changes, just update internal state + if (!relevantChainIds.length) { this.#allTokens = state.allTokens; this.#detectedTokens = state.allDetectedTokens; return; } - // Handle both cleanup and updates in a single state update - this.update((s) => { + this.update((currentState) => { for (const chainId of relevantChainIds) { const id = chainId as ChainIdHex; const hasTokensNow = @@ -1011,21 +1166,22 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ (this.#detectedTokens[id] && Object.keys(this.#detectedTokens[id]).length > 0); - if ( + const tokensChanged = !isEqual(state.allTokens[id], this.#allTokens[id]) || - !isEqual(state.allDetectedTokens[id], this.#detectedTokens[id]) - ) { - if (hasTokensNow) { - // Chain still has tokens - mark for async balance update - changed.push(id); - } else if (hadTokensBefore) { - // Chain had tokens before but doesn't now - clean up balances immediately - for (const address of Object.keys(s.tokenBalances)) { - const addressKey = address as ChecksumAddress; - if (s.tokenBalances[addressKey]?.[id]) { - s.tokenBalances[addressKey][id] = {}; - hasChanges = true; - } + !isEqual(state.allDetectedTokens[id], this.#detectedTokens[id]); + + if (!tokensChanged) { + continue; + } + + if (hasTokensNow) { + changed.push(id); + } else if (hadTokensBefore) { + for (const address of Object.keys(currentState.tokenBalances)) { + const addressKey = address as ChecksumAddress; + if (currentState.tokenBalances[addressKey]?.[id]) { + currentState.tokenBalances[addressKey][id] = {}; + hasChanges = true; } } } @@ -1036,7 +1192,6 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ this.#detectedTokens = state.allDetectedTokens; this.#allIgnoredTokens = state.allIgnoredTokens; - // Only update balances for chains that still have tokens (and only if we haven't already updated state) if (changed.length && !hasChanges) { this.updateBalances({ chainIds: changed }).catch((error) => { console.warn('Error updating balances after token change:', error); @@ -1044,13 +1199,11 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ } }; - readonly #onNetworkChanged = (state: NetworkState) => { - // Check if any networks were removed by comparing with previous state + readonly #onNetworkChanged = (state: NetworkState): void => { const currentNetworks = new Set( Object.keys(state.networkConfigurationsByChainId), ); - // Get all networks that currently have balances const networksWithBalances = new Set(); for (const address of Object.keys(this.state.tokenBalances)) { const addressKey = address as ChecksumAddress; @@ -1061,65 +1214,47 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ } } - // Find networks that were removed const removedNetworks = Array.from(networksWithBalances).filter( (network) => !currentNetworks.has(network), ); - if (removedNetworks.length > 0) { - this.update((s) => { - // Remove balances for all accounts on the deleted networks - for (const address of Object.keys(s.tokenBalances)) { - const addressKey = address as ChecksumAddress; - for (const removedNetwork of removedNetworks) { - const networkKey = removedNetwork as ChainIdHex; - if (s.tokenBalances[addressKey]?.[networkKey]) { - delete s.tokenBalances[addressKey][networkKey]; - } + if (!removedNetworks.length) { + return; + } + + this.update((currentState) => { + for (const address of Object.keys(currentState.tokenBalances)) { + const addressKey = address as ChecksumAddress; + for (const removedNetwork of removedNetworks) { + const networkKey = removedNetwork as ChainIdHex; + if (currentState.tokenBalances[addressKey]?.[networkKey]) { + delete currentState.tokenBalances[addressKey][networkKey]; } } - }); - } + } + }); }; - readonly #onAccountRemoved = (addr: string) => { + readonly #onAccountRemoved = (addr: string): void => { if (!isStrictHexString(addr) || !isValidHexAddress(addr)) { return; } - this.update((s) => { - delete s.tokenBalances[addr]; + this.update((currentState) => { + delete currentState.tokenBalances[addr]; }); }; - /** - * Handle account selection changes - * Triggers immediate balance fetch to ensure we have the latest balances - * since WebSocket only provides updates for changes going forward - */ - readonly #onAccountChanged = () => { - // Fetch balances for all chains with tokens when account changes + readonly #onAccountChanged = (): void => { const chainIds = this.#chainIdsWithTokens(); - if (chainIds.length > 0) { - this.updateBalances({ chainIds }).catch(() => { - // Silently handle polling errors - }); + if (!chainIds.length) { + return; } - }; - // ──────────────────────────────────────────────────────────────────────────── - // AccountActivityService integration helpers + this.updateBalances({ chainIds }).catch(() => { + // Silently handle polling errors + }); + }; - /** - * Prepare balance updates from AccountActivityService - * Processes all updates and returns categorized results - * Throws an error if any updates have validation/parsing issues - * - * @param updates - Array of balance updates from AccountActivityService - * @param account - Lowercase account address (for consistency with tokenBalances state format) - * @param chainId - Hex chain ID - * @returns Object containing arrays of token balances, new token addresses to add, and native balance updates - * @throws Error if any balance update has validation or parsing errors - */ #prepareBalanceUpdates( updates: BalanceUpdate[], account: ChecksumAddress, @@ -1127,25 +1262,19 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ ): { tokenBalances: { tokenAddress: ChecksumAddress; balance: Hex }[]; newTokens: string[]; - nativeBalanceUpdates: { address: string; chainId: Hex; balance: Hex }[]; + nativeBalanceUpdates: NativeBalanceUpdate[]; } { const tokenBalances: { tokenAddress: ChecksumAddress; balance: Hex }[] = []; const newTokens: string[] = []; - const nativeBalanceUpdates: { - address: string; - chainId: Hex; - balance: Hex; - }[] = []; + const nativeBalanceUpdates: NativeBalanceUpdate[] = []; for (const update of updates) { const { asset, postBalance } = update; - // Throw if balance update has an error if (postBalance.error) { throw new Error('Balance update has error'); } - // Parse token address from asset type const parsed = parseAssetType(asset.type); if (!parsed) { throw new Error('Failed to parse asset type'); @@ -1153,7 +1282,6 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ const [tokenAddress, isNativeToken] = parsed; - // Validate token address if ( !isStrictHexString(tokenAddress) || !isValidHexAddress(tokenAddress) @@ -1168,16 +1296,13 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ chainId, ); - // postBalance.amount is in hex format (raw units) const balanceHex = postBalance.amount as Hex; - // Add token balance (tracked tokens, ignored tokens, and native tokens all get balance updates) tokenBalances.push({ tokenAddress: checksumTokenAddress, balance: balanceHex, }); - // Add native balance update if this is a native token if (isNativeToken) { nativeBalanceUpdates.push({ address: account, @@ -1186,7 +1311,6 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ }); } - // Handle untracked ERC20 tokens - queue for import if (!isNativeToken && !isTracked) { newTokens.push(checksumTokenAddress); } @@ -1195,19 +1319,6 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ return { tokenBalances, newTokens, nativeBalanceUpdates }; } - // ──────────────────────────────────────────────────────────────────────────── - // AccountActivityService event handlers - - /** - * Handle real-time balance updates from AccountActivityService - * Processes balance updates and updates the token balance state - * If any balance update has an error, triggers fallback polling for the chain - * - * @param options0 - Balance update parameters - * @param options0.address - Account address - * @param options0.chain - CAIP chain identifier - * @param options0.updates - Array of balance updates for the account - */ readonly #onAccountActivityBalanceUpdate = async ({ address, chain, @@ -1216,25 +1327,21 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ address: string; chain: string; updates: BalanceUpdate[]; - }) => { + }): Promise => { const chainId = caipChainIdToHex(chain); const checksummedAccount = checksum(address); try { - // Process all balance updates at once const { tokenBalances, newTokens, nativeBalanceUpdates } = this.#prepareBalanceUpdates(updates, checksummedAccount, chainId); - // Update state once with all token balances if (tokenBalances.length > 0) { this.update((state) => { - // Temporary until ADR to normalize all keys - tokenBalances state requires: account in lowercase, token in checksum const lowercaseAccount = checksummedAccount.toLowerCase() as ChecksumAddress; state.tokenBalances[lowercaseAccount] ??= {}; state.tokenBalances[lowercaseAccount][chainId] ??= {}; - // Apply all token balance updates for (const { tokenAddress, balance } of tokenBalances) { state.tokenBalances[lowercaseAccount][chainId][tokenAddress] = balance; @@ -1242,7 +1349,6 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ }); } - // Update native balances in AccountTrackerController if (nativeBalanceUpdates.length > 0) { this.messenger.call( 'AccountTrackerController:updateNativeBalances', @@ -1250,7 +1356,6 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ ); } - // Import any new tokens that were discovered (balance already updated from websocket) if (newTokens.length > 0) { await this.messenger.call( 'TokenDetectionController:addDetectedTokensViaWs', @@ -1267,47 +1372,32 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ ); console.warn('Balance update data:', JSON.stringify(updates, null, 2)); - // On error, trigger fallback polling await this.updateBalances({ chainIds: [chainId] }).catch(() => { // Silently handle polling errors }); } }; - /** - * Handle status changes from AccountActivityService - * Uses aggressive debouncing to prevent excessive HTTP calls from rapid up/down changes - * - * @param options0 - Status change event data - * @param options0.chainIds - Array of chain identifiers - * @param options0.status - Connection status ('up' for connected, 'down' for disconnected) - */ readonly #onAccountActivityStatusChanged = ({ chainIds, status, }: { chainIds: string[]; status: 'up' | 'down'; - }) => { - // Update pending changes (latest status wins for each chain) + }): void => { for (const chainId of chainIds) { this.#statusChangeDebouncer.pendingChanges.set(chainId, status); } - // Clear existing timer to extend debounce window if (this.#statusChangeDebouncer.timer) { clearTimeout(this.#statusChangeDebouncer.timer); } - // Set new timer - only process changes after activity settles this.#statusChangeDebouncer.timer = setTimeout(() => { this.#processAccumulatedStatusChanges(); - }, 5000); // 5-second debounce window + }, 5000); }; - /** - * Process all accumulated status changes in one batch to minimize HTTP calls - */ #processAccumulatedStatusChanges(): void { const changes = Array.from( this.#statusChangeDebouncer.pendingChanges.entries(), @@ -1315,52 +1405,38 @@ export class TokenBalancesController extends StaticIntervalPollingController<{ this.#statusChangeDebouncer.pendingChanges.clear(); this.#statusChangeDebouncer.timer = null; - if (changes.length === 0) { + if (!changes.length) { return; } - // Calculate final polling configurations const chainConfigs: Record = {}; for (const [chainId, status] of changes) { - // Convert CAIP format (eip155:1) to hex format (0x1) - // chainId is always in CAIP format from AccountActivityService const hexChainId = caipChainIdToHex(chainId); - if (status === 'down') { - // Chain is down - use default polling since no real-time updates available - chainConfigs[hexChainId] = { interval: this.#defaultInterval }; - } else { - // Chain is up - use longer intervals since WebSocket provides real-time updates - chainConfigs[hexChainId] = { - interval: this.#websocketActivePollingInterval, - }; - } + chainConfigs[hexChainId] = + status === 'down' + ? { interval: this.#defaultInterval } + : { interval: this.#websocketActivePollingInterval }; } - // Add jitter to prevent synchronized requests across instances - const jitterDelay = Math.random() * this.#defaultInterval; // 0 to default interval + const jitterDelay = Math.random() * this.#defaultInterval; setTimeout(() => { this.updateChainPollingConfigs(chainConfigs, { immediateUpdate: true }); }, jitterDelay); } - /** - * Clean up all timers and resources when controller is destroyed - */ override destroy(): void { this.#isControllerPollingActive = false; this.#intervalPollingTimers.forEach((timer) => clearInterval(timer)); this.#intervalPollingTimers.clear(); - // Clean up debouncing timer if (this.#statusChangeDebouncer.timer) { clearTimeout(this.#statusChangeDebouncer.timer); this.#statusChangeDebouncer.timer = null; } - // Unregister action handlers this.messenger.unregisterActionHandler( `TokenBalancesController:updateChainPollingConfigs`, ); diff --git a/packages/assets-controllers/src/TokenDetectionController.test.ts b/packages/assets-controllers/src/TokenDetectionController.test.ts index ff0308a9557..fba1aebab21 100644 --- a/packages/assets-controllers/src/TokenDetectionController.test.ts +++ b/packages/assets-controllers/src/TokenDetectionController.test.ts @@ -32,26 +32,15 @@ import nock from 'nock'; import sinon from 'sinon'; import { formatAggregatorNames } from './assetsUtil'; -import * as MutliChainAccountsServiceModule from './multi-chain-accounts-service'; -import { - MOCK_GET_BALANCES_RESPONSE, - createMockGetBalancesResponse, -} from './multi-chain-accounts-service/mocks/mock-get-balances'; -import { MOCK_GET_SUPPORTED_NETWORKS_RESPONSE } from './multi-chain-accounts-service/mocks/mock-get-supported-networks'; import { TOKEN_END_POINT_API } from './token-service'; import type { TokenDetectionControllerMessenger } from './TokenDetectionController'; import { - STATIC_MAINNET_TOKEN_LIST, TokenDetectionController, controllerName, mapChainIdWithTokenListMap, } from './TokenDetectionController'; import { getDefaultTokenListState } from './TokenListController'; -import type { - TokenListMap, - TokenListState, - TokenListToken, -} from './TokenListController'; +import type { TokenListState, TokenListToken } from './TokenListController'; import type { Token } from './TokenRatesController'; import type { TokensController, @@ -64,8 +53,6 @@ import { buildCustomRpcEndpoint, buildInfuraNetworkConfiguration, } from '../../network-controller/tests/helpers'; -import type { TransactionMeta } from '../../transaction-controller/src/types'; -import { TransactionStatus } from '../../transaction-controller/src/types'; const DEFAULT_INTERVAL = 180000; @@ -142,11 +129,33 @@ const mockNetworkConfigurations: Record = { rpcEndpoints: [ buildCustomRpcEndpoint({ url: 'https://polygon-mainnet.infura.io/v3/fakekey', + networkClientId: 'polygon', + }), + ], + }, + avalanche: { + blockExplorerUrls: ['https://snowtrace.io/'], + chainId: '0xa86a', + defaultBlockExplorerUrlIndex: 0, + defaultRpcEndpointIndex: 0, + name: 'Avalanche C-Chain', + nativeCurrency: 'AVAX', + rpcEndpoints: [ + buildCustomRpcEndpoint({ + url: 'https://api.avax.network/ext/bc/C/rpc', + networkClientId: 'avalanche', }), ], }, }; +// Network configurations keyed by chain ID (for use when testing with explicit chainIds) +const mockNetworkConfigurationsByChainId: Record = + { + '0xa86a': mockNetworkConfigurations.avalanche, + '0x89': mockNetworkConfigurations.polygon, + }; + type AllTokenDetectionControllerActions = MessengerActions; @@ -201,7 +210,6 @@ function buildTokenDetectionControllerMessenger( 'PreferencesController:getState', 'TokensController:addTokens', 'NetworkController:findNetworkClientIdByChainId', - 'AuthenticationController:getBearerToken', ], events: [ 'AccountsController:selectedEvmAccountChange', @@ -216,25 +224,9 @@ function buildTokenDetectionControllerMessenger( return tokenDetectionControllerMessenger; } -const mockMultiChainAccountsService = () => { - const mockFetchSupportedNetworks = jest - .spyOn(MutliChainAccountsServiceModule, 'fetchSupportedNetworks') - .mockResolvedValue(MOCK_GET_SUPPORTED_NETWORKS_RESPONSE.fullSupport); - const mockFetchMultiChainBalances = jest - .spyOn(MutliChainAccountsServiceModule, 'fetchMultiChainBalances') - .mockResolvedValue(MOCK_GET_BALANCES_RESPONSE); - - return { - mockFetchSupportedNetworks, - mockFetchMultiChainBalances, - }; -}; - describe('TokenDetectionController', () => { const defaultSelectedAccount = createMockInternalAccount(); - mockMultiChainAccountsService(); - beforeEach(async () => { nock(TOKEN_END_POINT_API) .get(getTokensPath(ChainId.mainnet)) @@ -305,7 +297,6 @@ describe('TokenDetectionController', () => { await controller.start(); triggerKeyringUnlock(); - expect(mockTokens.calledOnce).toBe(true); await advanceTime({ clock, duration: DEFAULT_INTERVAL * 1.5 }); expect(mockTokens.calledTwice).toBe(false); }, @@ -378,7 +369,6 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: defaultSelectedAccount, @@ -413,7 +403,6 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getAccount: selectedAccount, @@ -421,12 +410,30 @@ describe('TokenDetectionController', () => { }, }, - async ({ controller, mockTokenListGetState, callActionSpy }) => { - mockMultiChainAccountsService(); + async ({ + controller, + mockTokenListGetState, + callActionSpy, + mockGetNetworkClientById, + mockNetworkState, + }) => { + // Set selectedNetworkClientId to avalanche so the detection uses the right network + mockNetworkState({ + ...getDefaultNetworkControllerState(), + selectedNetworkClientId: 'avalanche', + }); + // Mock getNetworkClientById to return Avalanche chain ID + mockGetNetworkClientById( + () => + ({ + configuration: { chainId: '0xa86a' }, + }) as unknown as AutoManagedNetworkClient, + ); + mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -448,7 +455,7 @@ describe('TokenDetectionController', () => { expect(callActionSpy).toHaveBeenCalledWith( 'TokensController:addTokens', [sampleTokenA], - 'mainnet', + 'avalanche', ); }, ); @@ -466,7 +473,6 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getAccount: selectedAccount, @@ -475,29 +481,10 @@ describe('TokenDetectionController', () => { }, async ({ controller, mockTokenListGetState, callActionSpy }) => { - mockMultiChainAccountsService(); - - const mockAPI = mockMultiChainAccountsService(); - mockAPI.mockFetchMultiChainBalances.mockResolvedValue({ - count: 0, - balances: [ - { - object: 'token', - address: '0xaddress', - name: 'Mock Token', - symbol: 'MOCK', - decimals: 18, - balance: '10.18', - chainId: 2, - }, - ], - unprocessedNetworks: [], - }); - mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { test: { @@ -520,7 +507,7 @@ describe('TokenDetectionController', () => { 'TokensController:addDetectedTokens', [sampleTokenA], { - chainId: ChainId.mainnet, + chainId: ChainId.sepolia, selectedAddress: selectedAccount.address, }, ); @@ -528,7 +515,7 @@ describe('TokenDetectionController', () => { ); }); - it('should detect tokens correctly on the Polygon network', async () => { + it('should detect tokens correctly on the Sepolia network', async () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); @@ -539,7 +526,6 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getAccount: selectedAccount, @@ -554,22 +540,22 @@ describe('TokenDetectionController', () => { mockFindNetworkClientIdByChainId, callActionSpy, }) => { - mockMultiChainAccountsService(); + // Use Sepolia (0xaa36a7) which is not in SUPPORTED_NETWORKS_ACCOUNTS_API_V4 mockNetworkState({ ...getDefaultNetworkControllerState(), - selectedNetworkClientId: 'polygon', + selectedNetworkClientId: 'avalanche', }); mockGetNetworkClientById( () => ({ - configuration: { chainId: '0x89' }, + configuration: { chainId: '0xa86a' }, }) as unknown as AutoManagedNetworkClient, ); - mockFindNetworkClientIdByChainId(() => 'polygon'); + mockFindNetworkClientIdByChainId(() => 'avalanche'); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x89': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -591,7 +577,7 @@ describe('TokenDetectionController', () => { expect(callActionSpy).toHaveBeenCalledWith( 'TokensController:addTokens', [sampleTokenA], - 'polygon', + 'avalanche', ); }, ); @@ -617,12 +603,20 @@ describe('TokenDetectionController', () => { getSelectedAccount: selectedAccount, }, }, - async ({ controller, mockTokenListGetState, callActionSpy }) => { - mockMultiChainAccountsService(); + async ({ + controller, + mockTokenListGetState, + callActionSpy, + mockNetworkState, + }) => { + mockNetworkState({ + ...getDefaultNetworkControllerState(), + selectedNetworkClientId: 'avalanche', + }); const tokenListState = { ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -641,7 +635,9 @@ describe('TokenDetectionController', () => { mockTokenListGetState(tokenListState); await controller.start(); - tokenListState.tokensChainsCache['0x1'].data[sampleTokenB.address] = { + tokenListState.tokensChainsCache['0xa86a'].data[ + sampleTokenB.address + ] = { name: sampleTokenB.name, symbol: sampleTokenB.symbol, decimals: sampleTokenB.decimals, @@ -656,7 +652,7 @@ describe('TokenDetectionController', () => { expect(callActionSpy).toHaveBeenCalledWith( 'TokensController:addTokens', [sampleTokenA, sampleTokenB], - 'mainnet', + 'avalanche', ); }, ); @@ -673,7 +669,6 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getAccount: selectedAccount, @@ -686,14 +681,13 @@ describe('TokenDetectionController', () => { mockTokenListGetState, callActionSpy, }) => { - mockMultiChainAccountsService(); mockTokensGetState({ ...getDefaultTokensState(), }); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -727,18 +721,16 @@ describe('TokenDetectionController', () => { { options: { getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: defaultSelectedAccount, }, }, async ({ controller, mockTokenListGetState, callActionSpy }) => { - mockMultiChainAccountsService(); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -791,7 +783,6 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: firstSelectedAccount, @@ -802,12 +793,22 @@ describe('TokenDetectionController', () => { mockTokenListGetState, triggerSelectedAccountChange, callActionSpy, + mockNetworkState, }) => { - mockMultiChainAccountsService(); + // Set selectedNetworkClientId to avalanche and include it in networkConfigurationsByChainId + const defaultState = getDefaultNetworkControllerState(); + mockNetworkState({ + ...defaultState, + selectedNetworkClientId: 'avalanche', + networkConfigurationsByChainId: { + ...defaultState.networkConfigurationsByChainId, + ...mockNetworkConfigurationsByChainId, + }, + }); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -831,7 +832,7 @@ describe('TokenDetectionController', () => { expect(callActionSpy).toHaveBeenCalledWith( 'TokensController:addTokens', [sampleTokenA], - 'mainnet', + 'avalanche', ); }, ); @@ -849,7 +850,6 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -860,11 +860,10 @@ describe('TokenDetectionController', () => { triggerSelectedAccountChange, callActionSpy, }) => { - mockMultiChainAccountsService(); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -923,7 +922,7 @@ describe('TokenDetectionController', () => { mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -970,7 +969,6 @@ describe('TokenDetectionController', () => { options: { disabled: true, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: firstSelectedAccount, @@ -984,7 +982,7 @@ describe('TokenDetectionController', () => { mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -1041,7 +1039,6 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: firstSelectedAccount, @@ -1055,11 +1052,10 @@ describe('TokenDetectionController', () => { triggerSelectedAccountChange, callActionSpy, }) => { - mockMultiChainAccountsService(); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -1077,24 +1073,23 @@ describe('TokenDetectionController', () => { }); mockNetworkState({ networkConfigurationsByChainId: { - '0x1': { - name: 'ethereum', - nativeCurrency: 'ETH', + '0xa86a': { + name: 'avalanche', + nativeCurrency: 'AVAX', rpcEndpoints: [ { - networkClientId: 'mainnet', - type: RpcEndpointType.Infura, - url: 'https://mainnet.infura.io/v3/{infuraProjectId}', - failoverUrls: [], + networkClientId: 'avalanche', + type: RpcEndpointType.Custom, + url: 'https://api.avax.network/ext/bc/C/rpc', }, ], blockExplorerUrls: [], - chainId: '0x1', + chainId: '0xa86a', defaultRpcEndpointIndex: 0, }, }, networksMetadata: {}, - selectedNetworkClientId: 'mainnet', + selectedNetworkClientId: 'avalanche', }); triggerPreferencesStateChange({ @@ -1108,7 +1103,7 @@ describe('TokenDetectionController', () => { expect(callActionSpy).toHaveBeenLastCalledWith( 'TokensController:addTokens', [sampleTokenA], - 'mainnet', + 'avalanche', ); }, ); @@ -1129,7 +1124,6 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: firstSelectedAccount, @@ -1144,11 +1138,10 @@ describe('TokenDetectionController', () => { controller, }) => { const mockTokens = jest.spyOn(controller, 'detectTokens'); - mockMultiChainAccountsService(); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -1164,9 +1157,10 @@ describe('TokenDetectionController', () => { }, }, }); + // Set to avalanche which is not in SUPPORTED_NETWORKS_ACCOUNTS_API_V4 mockNetworkState({ ...getDefaultNetworkControllerState(), - selectedNetworkClientId: NetworkType.mainnet, + selectedNetworkClientId: 'avalanche', }); triggerPreferencesStateChange({ @@ -1178,21 +1172,9 @@ describe('TokenDetectionController', () => { await advanceTime({ clock, duration: 1 }); - expect(mockTokens).toHaveBeenNthCalledWith(1, { - chainIds: [ - '0x1', - '0xaa36a7', - '0xe705', - '0xe708', - '0x2105', - '0xa4b1', - '0x38', - '0xa', - '0x89', - '0x531', - ], - selectedAddress: secondSelectedAccount.address, - }); + // detectTokens is called once when account changes + // (preference change doesn't trigger since useTokenDetection was already true by default) + expect(mockTokens).toHaveBeenCalledTimes(1); }, ); }); @@ -1209,7 +1191,6 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -1220,13 +1201,18 @@ describe('TokenDetectionController', () => { mockTokenListGetState, triggerPreferencesStateChange, callActionSpy, + mockNetworkState, }) => { - mockMultiChainAccountsService(); + // Set selectedNetworkClientId to avalanche (not in SUPPORTED_NETWORKS_ACCOUNTS_API_V4) + mockNetworkState({ + ...getDefaultNetworkControllerState(), + selectedNetworkClientId: 'avalanche', + }); mockGetAccount(selectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -1258,7 +1244,7 @@ describe('TokenDetectionController', () => { expect(callActionSpy).toHaveBeenCalledWith( 'TokensController:addTokens', [sampleTokenA], - 'mainnet', + 'avalanche', ); }, ); @@ -1279,7 +1265,6 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: firstSelectedAccount, @@ -1292,12 +1277,11 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange, callActionSpy, }) => { - mockMultiChainAccountsService(); mockGetAccount(firstSelectedAccount); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -1355,7 +1339,7 @@ describe('TokenDetectionController', () => { mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - [ChainId.mainnet]: { + [ChainId.sepolia]: { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -1419,7 +1403,7 @@ describe('TokenDetectionController', () => { mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - [ChainId.mainnet]: { + [ChainId.sepolia]: { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -1478,7 +1462,7 @@ describe('TokenDetectionController', () => { mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - [ChainId.mainnet]: { + [ChainId.sepolia]: { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -1547,7 +1531,7 @@ describe('TokenDetectionController', () => { mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - [ChainId.mainnet]: { + [ChainId.sepolia]: { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -1605,7 +1589,7 @@ describe('TokenDetectionController', () => { mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - [ChainId.mainnet]: { + [ChainId.sepolia]: { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -1736,7 +1720,7 @@ describe('TokenDetectionController', () => { mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - [ChainId.mainnet]: { + [ChainId.sepolia]: { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -1755,7 +1739,7 @@ describe('TokenDetectionController', () => { triggerNetworkDidChange({ ...getDefaultNetworkControllerState(), - selectedNetworkClientId: 'mainnet', + selectedNetworkClientId: 'avalanche', }); await advanceTime({ clock, duration: 1 }); @@ -1794,7 +1778,7 @@ describe('TokenDetectionController', () => { mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - [ChainId.mainnet]: { + [ChainId.sepolia]: { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -1813,7 +1797,7 @@ describe('TokenDetectionController', () => { triggerNetworkDidChange({ ...getDefaultNetworkControllerState(), - selectedNetworkClientId: 'polygon', + selectedNetworkClientId: 'avalanche', }); await advanceTime({ clock, duration: 1 }); @@ -1853,7 +1837,7 @@ describe('TokenDetectionController', () => { mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - [ChainId.mainnet]: { + [ChainId.sepolia]: { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -1872,7 +1856,7 @@ describe('TokenDetectionController', () => { triggerNetworkDidChange({ ...getDefaultNetworkControllerState(), - selectedNetworkClientId: 'polygon', + selectedNetworkClientId: 'avalanche', }); await advanceTime({ clock, duration: 1 }); @@ -1908,7 +1892,6 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -1919,8 +1902,13 @@ describe('TokenDetectionController', () => { mockTokenListGetState, callActionSpy, triggerTokenListStateChange, + mockNetworkState, }) => { - mockMultiChainAccountsService(); + // Set selectedNetworkClientId to avalanche (not in SUPPORTED_NETWORKS_ACCOUNTS_API_V4) + mockNetworkState({ + ...getDefaultNetworkControllerState(), + selectedNetworkClientId: 'avalanche', + }); const tokenList = { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -1935,7 +1923,7 @@ describe('TokenDetectionController', () => { const tokenListState = { ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: tokenList, }, @@ -1949,7 +1937,7 @@ describe('TokenDetectionController', () => { expect(callActionSpy).toHaveBeenCalledWith( 'TokensController:addTokens', [sampleTokenA], - 'mainnet', + 'avalanche', ); }, ); @@ -2022,7 +2010,7 @@ describe('TokenDetectionController', () => { const tokenListState = { ...getDefaultTokenListState(), tokensChainsCache: { - [ChainId.mainnet]: { + [ChainId.sepolia]: { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -2079,7 +2067,7 @@ describe('TokenDetectionController', () => { const tokenListState = { ...getDefaultTokenListState(), tokensChainsCache: { - [ChainId.mainnet]: { + [ChainId.sepolia]: { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -2121,7 +2109,6 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -2133,11 +2120,10 @@ describe('TokenDetectionController', () => { triggerTokenListStateChange, controller, }) => { - mockMultiChainAccountsService(); const tokenListState = { ...getDefaultTokenListState(), tokensChainsCache: { - [ChainId.mainnet]: { + [ChainId.sepolia]: { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -2182,7 +2168,6 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -2194,11 +2179,10 @@ describe('TokenDetectionController', () => { triggerTokenListStateChange, controller, }) => { - mockMultiChainAccountsService(); const tokenListState = { ...getDefaultTokenListState(), tokensChainsCache: { - [ChainId.mainnet]: { + [ChainId.sepolia]: { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -2225,7 +2209,7 @@ describe('TokenDetectionController', () => { triggerTokenListStateChange({ ...tokenListState, tokensChainsCache: { - [ChainId.mainnet]: { + [ChainId.sepolia]: { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -2261,7 +2245,6 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -2273,11 +2256,10 @@ describe('TokenDetectionController', () => { triggerTokenListStateChange, controller, }) => { - mockMultiChainAccountsService(); const tokenListState = { ...getDefaultTokenListState(), tokensChainsCache: { - [ChainId.mainnet]: { + [ChainId.sepolia]: { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -2361,7 +2343,7 @@ describe('TokenDetectionController', () => { mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - [ChainId.mainnet]: { + [ChainId.sepolia]: { data: { [sampleTokenA.address]: { name: sampleTokenA.name, @@ -2384,11 +2366,11 @@ describe('TokenDetectionController', () => { }); controller.startPolling({ - chainIds: ['0x1'], + chainIds: ['0xa86a'], address: '0x1', }); controller.startPolling({ - chainIds: ['0xaa36a7'], + chainIds: ['0xa86a'], address: '0xdeadbeef', }); controller.startPolling({ @@ -2398,18 +2380,18 @@ describe('TokenDetectionController', () => { await advanceTime({ clock, duration: 0 }); expect(spy.mock.calls).toMatchObject([ - [{ chainIds: ['0x1'], selectedAddress: '0x1' }], - [{ chainIds: ['0xaa36a7'], selectedAddress: '0xdeadbeef' }], + [{ chainIds: ['0xa86a'], selectedAddress: '0x1' }], + [{ chainIds: ['0xa86a'], selectedAddress: '0xdeadbeef' }], [{ chainIds: ['0x5'], selectedAddress: '0x3' }], ]); await advanceTime({ clock, duration: DEFAULT_INTERVAL }); expect(spy.mock.calls).toMatchObject([ - [{ chainIds: ['0x1'], selectedAddress: '0x1' }], - [{ chainIds: ['0xaa36a7'], selectedAddress: '0xdeadbeef' }], + [{ chainIds: ['0xa86a'], selectedAddress: '0x1' }], + [{ chainIds: ['0xa86a'], selectedAddress: '0xdeadbeef' }], [{ chainIds: ['0x5'], selectedAddress: '0x3' }], - [{ chainIds: ['0x1'], selectedAddress: '0x1' }], - [{ chainIds: ['0xaa36a7'], selectedAddress: '0xdeadbeef' }], + [{ chainIds: ['0xa86a'], selectedAddress: '0x1' }], + [{ chainIds: ['0xa86a'], selectedAddress: '0xdeadbeef' }], [{ chainIds: ['0x5'], selectedAddress: '0x3' }], ]); }, @@ -2430,7 +2412,6 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -2443,7 +2424,6 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange, callActionSpy, }) => { - mockMultiChainAccountsService(); mockNetworkState({ ...getDefaultNetworkControllerState(), selectedNetworkClientId: NetworkType.sepolia, @@ -2463,59 +2443,9 @@ describe('TokenDetectionController', () => { ); }); - it('should detect and add tokens from the `@metamask/contract-metadata` legacy token list if token detection is disabled and current network is mainnet', async () => { - const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue( - Object.keys(STATIC_MAINNET_TOKEN_LIST).reduce>( - (acc, address) => { - acc[address] = new BN(1); - return acc; - }, - {}, - ), - ); - const selectedAccount = createMockInternalAccount({ - address: '0x0000000000000000000000000000000000000001', - }); - await withController( - { - options: { - disabled: false, - getBalancesInSingleCall: mockGetBalancesInSingleCall, - }, - mocks: { - getSelectedAccount: selectedAccount, - getAccount: selectedAccount, - }, - }, - async ({ - controller, - triggerPreferencesStateChange, - callActionSpy, - }) => { - mockMultiChainAccountsService(); - triggerPreferencesStateChange({ - ...getDefaultPreferencesState(), - useTokenDetection: false, - }); - await controller.detectTokens({ - chainIds: ['0x1'], - selectedAddress: selectedAccount.address, - }); - expect(callActionSpy).toHaveBeenLastCalledWith( - 'TokensController:addTokens', - Object.values(STATIC_MAINNET_TOKEN_LIST).map((token) => { - const { iconUrl, ...tokenMetadata } = token; - return { - ...tokenMetadata, - image: token.iconUrl, - isERC721: false, - }; - }), - 'mainnet', - ); - }, - ); - }); + // Note: Test for mainnet legacy token list detection has been removed. + // Mainnet is now in SUPPORTED_NETWORKS_ACCOUNTS_API_V4, so RPC detection is skipped. + // Token detection for mainnet is handled via TokenBalancesController (Accounts API). it('should detect and add tokens by networkClientId correctly', async () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ @@ -2529,19 +2459,31 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, getAccount: selectedAccount, }, }, - async ({ controller, mockTokenListGetState, callActionSpy }) => { - mockMultiChainAccountsService(); + async ({ + controller, + mockTokenListGetState, + callActionSpy, + mockNetworkState, + }) => { + // Include Avalanche in networkConfigurationsByChainId for explicit chainId lookup + const defaultState = getDefaultNetworkControllerState(); + mockNetworkState({ + ...defaultState, + networkConfigurationsByChainId: { + ...defaultState.networkConfigurationsByChainId, + ...mockNetworkConfigurationsByChainId, + }, + }); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -2559,14 +2501,14 @@ describe('TokenDetectionController', () => { }); await controller.detectTokens({ - chainIds: ['0x1'], + chainIds: ['0xa86a'], selectedAddress: selectedAccount.address, }); expect(callActionSpy).toHaveBeenCalledWith( 'TokensController:addTokens', [sampleTokenA], - 'mainnet', + 'avalanche', ); }, ); @@ -2587,19 +2529,26 @@ describe('TokenDetectionController', () => { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, trackMetaMetricsEvent: mockTrackMetaMetricsEvent, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, getAccount: selectedAccount, }, }, - async ({ controller, mockTokenListGetState }) => { - mockMultiChainAccountsService(); + async ({ controller, mockTokenListGetState, mockNetworkState }) => { + // Include Avalanche in networkConfigurationsByChainId for explicit chainId lookup + const defaultState = getDefaultNetworkControllerState(); + mockNetworkState({ + ...defaultState, + networkConfigurationsByChainId: { + ...defaultState.networkConfigurationsByChainId, + ...mockNetworkConfigurationsByChainId, + }, + }); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -2617,7 +2566,7 @@ describe('TokenDetectionController', () => { }); await controller.detectTokens({ - chainIds: ['0x1'], + chainIds: ['0xa86a'], selectedAddress: selectedAccount.address, }); @@ -2647,7 +2596,6 @@ describe('TokenDetectionController', () => { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, trackMetaMetricsEvent: mockTrackMetaMetricsEvent, - useAccountsAPI: true, // USING ACCOUNTS API }, }, async ({ @@ -2655,14 +2603,23 @@ describe('TokenDetectionController', () => { mockGetAccount, mockTokenListGetState, callActionSpy, + mockNetworkState, }) => { - mockMultiChainAccountsService(); + // Include Avalanche in networkConfigurationsByChainId for explicit chainId lookup + const defaultState = getDefaultNetworkControllerState(); + mockNetworkState({ + ...defaultState, + networkConfigurationsByChainId: { + ...defaultState.networkConfigurationsByChainId, + ...mockNetworkConfigurationsByChainId, + }, + }); // @ts-expect-error forcing an undefined value mockGetAccount(undefined); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -2680,7 +2637,7 @@ describe('TokenDetectionController', () => { }); await controller.detectTokens({ - chainIds: ['0x1'], + chainIds: ['0xa86a'], }); expect(callActionSpy).toHaveBeenLastCalledWith( @@ -2709,7 +2666,7 @@ describe('TokenDetectionController', () => { symbol: 'LINK', }, ], - 'mainnet', + 'avalanche', ); }, ); @@ -2727,7 +2684,6 @@ describe('TokenDetectionController', () => { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -2740,10 +2696,6 @@ describe('TokenDetectionController', () => { triggerPreferencesStateChange, callActionSpy, }) => { - const mockAPI = mockMultiChainAccountsService(); - mockAPI.mockFetchMultiChainBalances.mockRejectedValue( - new Error('Mock Error'), - ); mockNetworkState({ ...getDefaultNetworkControllerState(), selectedNetworkClientId: 'polygon', @@ -2763,134 +2715,43 @@ describe('TokenDetectionController', () => { ); }); - it('should timeout and fallback to RPC when Accounts API call takes longer than 30 seconds', async () => { - // Use fake timers to simulate the 30-second timeout - const clock = sinon.useFakeTimers(); - - try { - // Arrange - RPC Tokens Flow - Uses sampleTokenA - const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ - [sampleTokenA.address]: new BN(1), - }); - - // Mock a hanging API call that never resolves (simulates network timeout) - const mockAPI = mockMultiChainAccountsService(); - mockAPI.mockFetchSupportedNetworks.mockResolvedValue([1]); - mockAPI.mockFetchMultiChainBalances.mockImplementation( - () => - new Promise(() => { - // Promise that never resolves (simulating a hanging request) - }), - ); - - // Arrange - Selected Account - const selectedAccount = createMockInternalAccount({ - address: '0x0000000000000000000000000000000000000001', - }); - - // Arrange / Act - withController setup - await withController( - { - options: { - disabled: false, - getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API - }, - mocks: { - getSelectedAccount: selectedAccount, - getAccount: selectedAccount, - }, - }, - async ({ controller, mockTokenListGetState, callActionSpy }) => { - mockTokenListGetState({ - ...getDefaultTokenListState(), - tokensChainsCache: { - '0x1': { - timestamp: 0, - data: { - [sampleTokenA.address]: { - name: sampleTokenA.name, - symbol: sampleTokenA.symbol, - decimals: sampleTokenA.decimals, - address: sampleTokenA.address, - occurrences: 1, - aggregators: sampleTokenA.aggregators, - iconUrl: sampleTokenA.image, - }, - }, - }, - }, - }); - - // Start the detection process (don't await yet so we can advance time) - const detectPromise = controller.detectTokens({ - chainIds: ['0x1'], - selectedAddress: selectedAccount.address, - }); - - // Fast-forward time by 30 seconds to trigger the timeout - // This simulates the API call taking longer than the ACCOUNTS_API_TIMEOUT_MS (30000ms) - await advanceTime({ clock, duration: 30000 }); - - // Now await the result after the timeout has been triggered - await detectPromise; - - // Verify that the API was initially called - expect(mockAPI.mockFetchMultiChainBalances).toHaveBeenCalled(); - - // Verify that after timeout, RPC fallback was triggered - expect(mockGetBalancesInSingleCall).toHaveBeenCalled(); - - // Verify that tokens were added via RPC fallback method - expect(callActionSpy).toHaveBeenCalledWith( - 'TokensController:addTokens', - [sampleTokenA], - 'mainnet', - ); - }, - ); - } finally { - clock.restore(); - } - }); - - it('should fallback to RPC when Accounts API call fails with an error (safelyExecute returns undefined)', async () => { - // Arrange - RPC Tokens Flow - Uses sampleTokenA + it('should detect tokens when TransactionController:transactionConfirmed is triggered', async () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); - - // Mock an API call that throws an error inside safelyExecute - // This simulates a scenario where the API throws an error (network failure, parsing error, etc.) - const mockAPI = mockMultiChainAccountsService(); - mockAPI.mockFetchSupportedNetworks.mockResolvedValue([1]); - mockAPI.mockFetchMultiChainBalances.mockRejectedValue( - new Error('API Network Error'), - ); - - // Arrange - Selected Account const selectedAccount = createMockInternalAccount({ address: '0x0000000000000000000000000000000000000001', }); - - // Arrange / Act - withController setup await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, getAccount: selectedAccount, }, }, - async ({ controller, mockTokenListGetState, callActionSpy }) => { + async ({ + mockTokenListGetState, + mockNetworkState, + callActionSpy, + triggerTransactionConfirmed, + }) => { + const defaultState = getDefaultNetworkControllerState(); + mockNetworkState({ + ...defaultState, + selectedNetworkClientId: 'avalanche', + networkConfigurationsByChainId: { + ...defaultState.networkConfigurationsByChainId, + ...mockNetworkConfigurationsByChainId, + }, + }); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -2907,78 +2768,63 @@ describe('TokenDetectionController', () => { }, }); - // Execute detection - await controller.detectTokens({ - chainIds: ['0x1'], - selectedAddress: selectedAccount.address, - }); - - // Verify that the API was initially called - expect(mockAPI.mockFetchMultiChainBalances).toHaveBeenCalled(); - - // Verify that after API error (safelyExecute returns undefined), RPC fallback was triggered - expect(mockGetBalancesInSingleCall).toHaveBeenCalled(); + triggerTransactionConfirmed({ chainId: '0xa86a' }); + // Wait for async detection to complete + await new Promise((resolve) => setTimeout(resolve, 10)); - // Verify that tokens were added via RPC fallback method expect(callActionSpy).toHaveBeenCalledWith( 'TokensController:addTokens', [sampleTokenA], - 'mainnet', + 'avalanche', ); }, ); }); - /** - * Test Utility - Arrange and Act `detectTokens()` with the Accounts API feature - * RPC flow will return `sampleTokenA` and the Accounts API flow will use `sampleTokenB` - * - * @param props - options to modify these tests - * @param props.overrideMockTokensCache - change the tokens cache - * @param props.mockMultiChainAPI - change the Accounts API responses - * @param props.overrideMockTokenGetState - change the external TokensController state - * @returns properties that can be used for assertions - */ - const arrangeActTestDetectTokensWithAccountsAPI = async (props?: { - /** Overwrite the tokens cache inside Tokens Controller */ - overrideMockTokensCache?: (typeof sampleTokenA)[]; - mockMultiChainAPI?: ReturnType; - overrideMockTokenGetState?: Partial; - }) => { - const { - overrideMockTokensCache = [sampleTokenA, sampleTokenB], - mockMultiChainAPI, - overrideMockTokenGetState, - } = props ?? {}; - - // Arrange - RPC Tokens Flow - Uses sampleTokenA + it('should not detect tokens when useExternalServices returns false', async () => { const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ [sampleTokenA.address]: new BN(1), }); + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + await withController( + { + options: { + disabled: false, + getBalancesInSingleCall: mockGetBalancesInSingleCall, + useExternalServices: () => false, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, + }, + async ({ controller, callActionSpy }) => { + await controller.detectTokens(); - // Arrange - API Tokens Flow - Uses sampleTokenB - const { mockFetchSupportedNetworks, mockFetchMultiChainBalances } = - mockMultiChainAPI ?? mockMultiChainAccountsService(); - - if (!mockMultiChainAPI) { - mockFetchSupportedNetworks.mockResolvedValue([1]); - mockFetchMultiChainBalances.mockResolvedValue( - createMockGetBalancesResponse([sampleTokenB.address], 1), - ); - } + expect(callActionSpy).not.toHaveBeenCalledWith( + 'TokensController:addTokens', + ); + expect(callActionSpy).not.toHaveBeenCalledWith( + 'TokensController:addDetectedTokens', + ); + }, + ); + }); - // Arrange - Selected Account + it('should not detect tokens when no client networks are found', async () => { + const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ + [sampleTokenA.address]: new BN(1), + }); const selectedAccount = createMockInternalAccount({ address: '0x0000000000000000000000000000000000000001', }); - - // Arrange / Act - withController setup + invoke detectTokens - const { callAction } = await withController( + await withController( { options: { disabled: false, getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API }, mocks: { getSelectedAccount: selectedAccount, @@ -2987,156 +2833,324 @@ describe('TokenDetectionController', () => { }, async ({ controller, - mockTokenListGetState, + mockNetworkState, + mockGetNetworkConfigurationByNetworkClientId, callActionSpy, - mockTokensGetState, }) => { - const tokenCacheData: TokenListMap = {}; - overrideMockTokensCache.forEach( - (t) => - (tokenCacheData[t.address] = { - name: t.name, - symbol: t.symbol, - decimals: t.decimals, - address: t.address, - occurrences: 1, - aggregators: t.aggregators, - iconUrl: t.image, - }), - ); + mockNetworkState({ + ...getDefaultNetworkControllerState(), + selectedNetworkClientId: 'unknown-network', + }); + // Return undefined for unknown network to simulate no network config + mockGetNetworkConfigurationByNetworkClientId( + () => undefined as never, + ); + + await controller.detectTokens(); + + expect(callActionSpy).not.toHaveBeenCalledWith( + 'TokensController:addTokens', + ); + }, + ); + }); + it('should filter out tokens that are already owned by the user', async () => { + const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ + [sampleTokenA.address]: new BN(1), + }); + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + await withController( + { + options: { + disabled: false, + getBalancesInSingleCall: mockGetBalancesInSingleCall, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, + }, + async ({ + controller, + mockNetworkState, + mockTokenListGetState, + mockTokensGetState, + callActionSpy, + }) => { + const defaultState = getDefaultNetworkControllerState(); + mockNetworkState({ + ...defaultState, + selectedNetworkClientId: 'avalanche', + networkConfigurationsByChainId: { + ...defaultState.networkConfigurationsByChainId, + ...mockNetworkConfigurationsByChainId, + }, + }); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, - data: tokenCacheData, + data: { + [sampleTokenA.address]: { + name: sampleTokenA.name, + symbol: sampleTokenA.symbol, + decimals: sampleTokenA.decimals, + address: sampleTokenA.address, + occurrences: 1, + aggregators: sampleTokenA.aggregators, + iconUrl: sampleTokenA.image, + }, + }, }, }, }); - - if (overrideMockTokenGetState) { - mockTokensGetState({ - ...getDefaultTokensState(), - ...overrideMockTokenGetState, - }); - } - - // Act - await controller.detectTokens({ - chainIds: ['0x1'], - selectedAddress: selectedAccount.address, + // Mock that the user already owns this token + mockTokensGetState({ + ...getDefaultTokensState(), + allTokens: { + '0xa86a': { + [selectedAccount.address]: [ + { address: sampleTokenA.address } as Token, + ], + }, + }, }); - return { - callAction: callActionSpy, - }; + await controller.detectTokens(); + + // Should not call addTokens since token is already owned + expect(callActionSpy).not.toHaveBeenCalledWith( + 'TokensController:addTokens', + expect.anything(), + 'avalanche', + ); }, ); - - const assertAddedTokens = (token: Token) => - expect(callAction).toHaveBeenCalledWith( - 'TokensController:addTokens', - [token], - 'mainnet', - ); - - const assertTokensNeverAdded = () => - expect(callAction).not.toHaveBeenCalledWith( - 'TokensController:addTokens', - ); - - return { - assertAddedTokens, - assertTokensNeverAdded, - mockFetchMultiChainBalances, - mockGetBalancesInSingleCall, - rpcToken: sampleTokenA, - apiToken: sampleTokenB, - }; - }; - - it('should trigger and use Accounts API for detection', async () => { - const { - assertAddedTokens, - mockFetchMultiChainBalances, - apiToken, - mockGetBalancesInSingleCall, - } = await arrangeActTestDetectTokensWithAccountsAPI(); - - expect(mockFetchMultiChainBalances).toHaveBeenCalled(); - expect(mockGetBalancesInSingleCall).not.toHaveBeenCalled(); - assertAddedTokens(apiToken); }); - it('uses the Accounts API but does not add unknown tokens', async () => { - // API returns sampleTokenB - // As this is not a known token (in cache), then is not added - const { - assertTokensNeverAdded, - mockFetchMultiChainBalances, - mockGetBalancesInSingleCall, - } = await arrangeActTestDetectTokensWithAccountsAPI({ - overrideMockTokensCache: [sampleTokenA], + it('should use static mainnet token list when token detection is disabled for mainnet', async () => { + const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ + '0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48': new BN(1), // USDC on mainnet + }); + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', }); + await withController( + { + options: { + disabled: false, + getBalancesInSingleCall: mockGetBalancesInSingleCall, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, + }, + async ({ + controller, + mockNetworkState, + mockFindNetworkClientIdByChainId, + triggerPreferencesStateChange, + }) => { + const defaultState = getDefaultNetworkControllerState(); + mockNetworkState({ + ...defaultState, + selectedNetworkClientId: 'mainnet', + networkConfigurationsByChainId: { + ...defaultState.networkConfigurationsByChainId, + '0x1': { + chainId: '0x1', + name: 'Ethereum Mainnet', + nativeCurrency: 'ETH', + blockExplorerUrls: [], + defaultBlockExplorerUrlIndex: 0, + defaultRpcEndpointIndex: 0, + rpcEndpoints: [ + { + networkClientId: 'mainnet', + type: RpcEndpointType.Custom, + url: 'https://mainnet.infura.io/v3/test', + failoverUrls: [], + }, + ], + }, + }, + }); + mockFindNetworkClientIdByChainId(() => 'mainnet'); - expect(mockFetchMultiChainBalances).toHaveBeenCalled(); - expect(mockGetBalancesInSingleCall).not.toHaveBeenCalled(); - assertTokensNeverAdded(); - }); + // Disable token detection - this should trigger static mainnet token list usage + triggerPreferencesStateChange({ + ...getDefaultPreferencesState(), + useTokenDetection: false, + }); + + // Trigger detection with forceRpc to ensure we test the static token list path + await controller.detectTokens({ + chainIds: [ChainId.mainnet], + forceRpc: true, + }); - it('fallbacks from using the Accounts API if fails', async () => { - // Test 1 - fetch supported networks fails - let mockAPI = mockMultiChainAccountsService(); - mockAPI.mockFetchSupportedNetworks.mockRejectedValue( - new Error('Mock Error'), + // The detection should have been attempted (static token list is used internally) + // We verify the getBalancesInSingleCall was called, indicating detection ran + expect(mockGetBalancesInSingleCall).toHaveBeenCalled(); + }, ); - let actResult = await arrangeActTestDetectTokensWithAccountsAPI({ - mockMultiChainAPI: mockAPI, + }); + + it('should skip chains supported by Accounts API when forceRpc is false', async () => { + const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ + [sampleTokenA.address]: new BN(1), + }); + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', }); + await withController( + { + options: { + disabled: false, + getBalancesInSingleCall: mockGetBalancesInSingleCall, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, + }, + async ({ + controller, + mockNetworkState, + mockFindNetworkClientIdByChainId, + }) => { + const defaultState = getDefaultNetworkControllerState(); + mockNetworkState({ + ...defaultState, + selectedNetworkClientId: 'mainnet', + networkConfigurationsByChainId: { + ...defaultState.networkConfigurationsByChainId, + '0x1': { + chainId: '0x1', + name: 'Ethereum Mainnet', + nativeCurrency: 'ETH', + blockExplorerUrls: [], + defaultBlockExplorerUrlIndex: 0, + defaultRpcEndpointIndex: 0, + rpcEndpoints: [ + { + networkClientId: 'mainnet', + type: RpcEndpointType.Custom, + url: 'https://mainnet.infura.io/v3/test', + failoverUrls: [], + }, + ], + }, + }, + }); + mockFindNetworkClientIdByChainId(() => 'mainnet'); - expect(actResult.mockFetchMultiChainBalances).not.toHaveBeenCalled(); // never called as could not fetch supported networks... - expect(actResult.mockGetBalancesInSingleCall).toHaveBeenCalled(); // ...so then RPC flow was initiated - actResult.assertAddedTokens(actResult.rpcToken); + // Call detectTokens with mainnet (which is in SUPPORTED_NETWORKS_ACCOUNTS_API_V4) + // Without forceRpc, it should skip mainnet + await controller.detectTokens({ + chainIds: [ChainId.mainnet], + }); - // Test 2 - fetch multi chain fails - mockAPI = mockMultiChainAccountsService(); - mockAPI.mockFetchMultiChainBalances.mockRejectedValue( - new Error('Mock Error'), + // Should NOT call getBalancesInSingleCall since mainnet is skipped + expect(mockGetBalancesInSingleCall).not.toHaveBeenCalled(); + }, ); - actResult = await arrangeActTestDetectTokensWithAccountsAPI({ - mockMultiChainAPI: mockAPI, - }); - - expect(actResult.mockFetchMultiChainBalances).toHaveBeenCalled(); // API was called, but failed... - expect(actResult.mockGetBalancesInSingleCall).toHaveBeenCalled(); // ...so then RPC flow was initiated - actResult.assertAddedTokens(actResult.rpcToken); }); - it('uses the Accounts API but does not add tokens that are already added', async () => { - // Here we populate the token state with a token that exists in the tokenAPI. - // So the token retrieved from the API should not be added - const { assertTokensNeverAdded, mockFetchMultiChainBalances } = - await arrangeActTestDetectTokensWithAccountsAPI({ - overrideMockTokenGetState: { - allDetectedTokens: { + it('should detect tokens on Accounts API supported chains when forceRpc is true', async () => { + const mainnetUSDC = '0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48'; + const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ + [mainnetUSDC]: new BN(1), + }); + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); + await withController( + { + options: { + disabled: false, + getBalancesInSingleCall: mockGetBalancesInSingleCall, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, + }, + async ({ + controller, + mockNetworkState, + mockFindNetworkClientIdByChainId, + mockTokenListGetState, + triggerPreferencesStateChange, + }) => { + const defaultState = getDefaultNetworkControllerState(); + mockNetworkState({ + ...defaultState, + selectedNetworkClientId: 'mainnet', + networkConfigurationsByChainId: { + ...defaultState.networkConfigurationsByChainId, '0x1': { - '0x0000000000000000000000000000000000000001': [ + chainId: '0x1', + name: 'Ethereum Mainnet', + nativeCurrency: 'ETH', + blockExplorerUrls: [], + defaultBlockExplorerUrlIndex: 0, + defaultRpcEndpointIndex: 0, + rpcEndpoints: [ { - address: sampleTokenB.address, - name: sampleTokenB.name, - symbol: sampleTokenB.symbol, - decimals: sampleTokenB.decimals, - aggregators: sampleTokenB.aggregators, + networkClientId: 'mainnet', + type: RpcEndpointType.Custom, + url: 'https://mainnet.infura.io/v3/test', + failoverUrls: [], }, ], }, }, - }, - }); + }); + mockFindNetworkClientIdByChainId(() => 'mainnet'); + + // Provide token list data for mainnet + mockTokenListGetState({ + ...getDefaultTokenListState(), + tokensChainsCache: { + '0x1': { + timestamp: 0, + data: { + [mainnetUSDC]: { + name: 'USD Coin', + symbol: 'USDC', + decimals: 6, + address: mainnetUSDC, + occurrences: 1, + aggregators: [], + iconUrl: '', + }, + }, + }, + }, + }); + + // Enable token detection for mainnet + triggerPreferencesStateChange({ + ...getDefaultPreferencesState(), + useTokenDetection: true, + }); - expect(mockFetchMultiChainBalances).toHaveBeenCalled(); - assertTokensNeverAdded(); + // Call detectTokens with forceRpc: true to force RPC detection on mainnet + await controller.detectTokens({ + chainIds: [ChainId.mainnet], + forceRpc: true, + }); + + // Should call getBalancesInSingleCall since forceRpc bypasses Accounts API filter + expect(mockGetBalancesInSingleCall).toHaveBeenCalled(); + }, + ); }); }); @@ -3191,102 +3205,28 @@ describe('TokenDetectionController', () => { }); }); - describe('TransactionController:transactionConfirmed', () => { - let clock: sinon.SinonFakeTimers; - beforeEach(() => { - clock = sinon.useFakeTimers(); - }); + describe('constructor options', () => { + describe('useTokenDetection', () => { + it('should disable token detection when useTokenDetection is false', async () => { + const mockGetBalancesInSingleCall = jest.fn(); - afterEach(() => { - clock.restore(); - }); - it('calls detectTokens when a transaction is confirmed', async () => { - const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ - [sampleTokenA.address]: new BN(1), - }); - const firstSelectedAccount = createMockInternalAccount({ - address: '0x0000000000000000000000000000000000000001', - }); - const secondSelectedAccount = createMockInternalAccount({ - address: '0x0000000000000000000000000000000000000002', - }); - await withController( - { - options: { - disabled: false, - getBalancesInSingleCall: mockGetBalancesInSingleCall, - useAccountsAPI: true, // USING ACCOUNTS API + await withController( + { + options: { + useTokenDetection: () => false, + disabled: false, + getBalancesInSingleCall: mockGetBalancesInSingleCall, + }, + mocks: { + getSelectedAccount: defaultSelectedAccount, + }, }, - mocks: { - getSelectedAccount: firstSelectedAccount, - }, - }, - async ({ - mockGetAccount, - mockTokenListGetState, - triggerTransactionConfirmed, - callActionSpy, - }) => { - mockMultiChainAccountsService(); - mockTokenListGetState({ - ...getDefaultTokenListState(), - tokensChainsCache: { - '0x1': { - timestamp: 0, - data: { - [sampleTokenA.address]: { - name: sampleTokenA.name, - symbol: sampleTokenA.symbol, - decimals: sampleTokenA.decimals, - address: sampleTokenA.address, - occurrences: 1, - aggregators: sampleTokenA.aggregators, - iconUrl: sampleTokenA.image, - }, - }, - }, - }, - }); - - mockGetAccount(secondSelectedAccount); - triggerTransactionConfirmed({ - chainId: '0x1', - status: TransactionStatus.confirmed, - } as unknown as TransactionMeta); - await advanceTime({ clock, duration: 1 }); - - expect(callActionSpy).toHaveBeenCalledWith( - 'TokensController:addTokens', - [sampleTokenA], - 'mainnet', - ); - }, - ); - }); - }); - - describe('constructor options', () => { - describe('useTokenDetection', () => { - it('should disable token detection when useTokenDetection is false', async () => { - const mockGetBalancesInSingleCall = jest.fn(); - - await withController( - { - options: { - useTokenDetection: () => false, - disabled: false, - getBalancesInSingleCall: mockGetBalancesInSingleCall, - }, - mocks: { - getSelectedAccount: defaultSelectedAccount, - }, - }, - async ({ controller }) => { - // Try to detect tokens - await controller.detectTokens(); - - // Should not call getBalancesInSingleCall when useTokenDetection is false - expect(mockGetBalancesInSingleCall).not.toHaveBeenCalled(); + async ({ controller }) => { + // Try to detect tokens + await controller.detectTokens(); + + // Should not call getBalancesInSingleCall when useTokenDetection is false + expect(mockGetBalancesInSingleCall).not.toHaveBeenCalled(); }, ); }); @@ -3305,11 +3245,16 @@ describe('TokenDetectionController', () => { getSelectedAccount: defaultSelectedAccount, }, }, - async ({ controller, mockTokenListGetState }) => { + async ({ controller, mockTokenListGetState, mockNetworkState }) => { + // Set selectedNetworkClientId to avalanche (not in SUPPORTED_NETWORKS_ACCOUNTS_API_V4) + mockNetworkState({ + ...getDefaultNetworkControllerState(), + selectedNetworkClientId: 'avalanche', + }); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -3326,6 +3271,8 @@ describe('TokenDetectionController', () => { }, }); + // Start the controller to make it active + await controller.start(); // Try to detect tokens await controller.detectTokens(); @@ -3372,11 +3319,16 @@ describe('TokenDetectionController', () => { getSelectedAccount: defaultSelectedAccount, }, }, - async ({ controller, mockTokenListGetState }) => { + async ({ controller, mockTokenListGetState, mockNetworkState }) => { + // Set selectedNetworkClientId to avalanche (not in SUPPORTED_NETWORKS_ACCOUNTS_API_V4) + mockNetworkState({ + ...getDefaultNetworkControllerState(), + selectedNetworkClientId: 'avalanche', + }); mockTokenListGetState({ ...getDefaultTokenListState(), tokensChainsCache: { - '0x1': { + '0xa86a': { timestamp: 0, data: { [sampleTokenA.address]: { @@ -3401,350 +3353,319 @@ describe('TokenDetectionController', () => { ); }); }); + }); - describe('useExternalServices', () => { - it('should not use external services when useExternalServices is false (default)', async () => { - const mockFetchSupportedNetworks = jest.spyOn( - MutliChainAccountsServiceModule, - 'fetchSupportedNetworks', - ); + describe('addDetectedTokensViaWs', () => { + it('should add tokens detected from websocket with metadata from cache', async () => { + const mockTokenAddress = '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48'; + const checksummedTokenAddress = + '0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48'; + const chainId = '0xa86a'; - await withController( - { - options: { - useExternalServices: () => false, - disabled: false, - useAccountsAPI: true, - }, - mocks: { - getSelectedAccount: defaultSelectedAccount, - }, + await withController( + { + options: { + disabled: false, }, - async ({ controller }) => { - await controller.detectTokens(); - - // Should not call fetchSupportedNetworks when useExternalServices is false - expect(mockFetchSupportedNetworks).not.toHaveBeenCalled(); + mockTokenListState: { + tokensChainsCache: { + [chainId]: { + timestamp: 0, + data: { + [mockTokenAddress]: { + name: 'USD Coin', + symbol: 'USDC', + decimals: 6, + address: mockTokenAddress, + aggregators: [], + iconUrl: 'https://example.com/usdc.png', + occurrences: 11, + }, + }, + }, + }, }, - ); - }); - - it('should use external services when useExternalServices is true', async () => { - const mockFetchSupportedNetworks = jest - .spyOn(MutliChainAccountsServiceModule, 'fetchSupportedNetworks') - .mockResolvedValue([1, 137]); // Mainnet and Polygon + }, + async ({ controller, callActionSpy }) => { + await controller.addDetectedTokensViaWs({ + tokensSlice: [mockTokenAddress], + chainId: chainId as Hex, + }); - jest - .spyOn(MutliChainAccountsServiceModule, 'fetchMultiChainBalances') - .mockResolvedValue({ - count: 1, - balances: [ + expect(callActionSpy).toHaveBeenCalledWith( + 'TokensController:addTokens', + [ { - object: 'token_balance', - address: sampleTokenA.address, - symbol: sampleTokenA.symbol, - name: sampleTokenA.name, - decimals: sampleTokenA.decimals, - chainId: 1, - balance: '1000000000000000000', + address: checksummedTokenAddress, + decimals: 6, + symbol: 'USDC', + aggregators: [], + image: 'https://example.com/usdc.png', + isERC721: false, + name: 'USD Coin', }, ], - unprocessedNetworks: [], - }); + 'avalanche', + ); + }, + ); + }); - await withController( - { - options: { - useExternalServices: () => true, - disabled: false, - useAccountsAPI: true, - }, - mocks: { - getSelectedAccount: defaultSelectedAccount, - }, - }, - async ({ controller, mockTokenListGetState }) => { - mockTokenListGetState({ - ...getDefaultTokenListState(), - tokensChainsCache: { - '0x1': { - timestamp: 0, - data: { - [sampleTokenA.address]: { - name: sampleTokenA.name, - symbol: sampleTokenA.symbol, - decimals: sampleTokenA.decimals, - address: sampleTokenA.address, - aggregators: sampleTokenA.aggregators, - iconUrl: sampleTokenA.image, - occurrences: 11, - }, - }, - }, - }, - }); + it('should skip tokens not found in cache and log warning', async () => { + const mockTokenAddress = '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48'; + const chainId = '0xa86a'; - await controller.detectTokens(); + const consoleSpy = jest.spyOn(console, 'warn').mockImplementation(); - // Should call fetchSupportedNetworks when useExternalServices is true - expect(mockFetchSupportedNetworks).toHaveBeenCalled(); + await withController( + { + options: { + disabled: false, }, - ); - }); - - it('should not use external services when useAccountsAPI is false, regardless of useExternalServices', async () => { - const mockFetchSupportedNetworks = jest.spyOn( - MutliChainAccountsServiceModule, - 'fetchSupportedNetworks', - ); - - await withController( - { - options: { - useExternalServices: () => true, - disabled: false, - useAccountsAPI: false, - }, - mocks: { - getSelectedAccount: defaultSelectedAccount, + mockTokenListState: { + tokensChainsCache: { + [chainId]: { + timestamp: 0, + data: {}, + }, }, }, - async ({ controller }) => { - await controller.detectTokens(); + }, + async ({ controller, callActionSpy }) => { + await controller.addDetectedTokensViaWs({ + tokensSlice: [mockTokenAddress], + chainId: chainId as Hex, + }); - // Should not call fetchSupportedNetworks when useAccountsAPI is false - expect(mockFetchSupportedNetworks).not.toHaveBeenCalled(); - }, - ); + // Should log warning about missing token metadata + expect(consoleSpy).toHaveBeenCalledWith( + expect.stringContaining('Token metadata not found in cache'), + ); + + // Should not call addTokens if no tokens have metadata + expect(callActionSpy).not.toHaveBeenCalledWith( + 'TokensController:addTokens', + expect.anything(), + expect.anything(), + ); + + consoleSpy.mockRestore(); + }, + ); + }); + + it('should add all tokens provided without filtering (filtering is caller responsibility)', async () => { + const mockTokenAddress = '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48'; + const checksummedTokenAddress = + '0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48'; + const secondTokenAddress = '0x1f573d6fb3f13d689ff844b4ce37794d79a7ff1c'; + const checksummedSecondTokenAddress = + '0x1F573D6Fb3F13d689FF844B4cE37794d79a7FF1C'; + const chainId = '0xa86a'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', }); - it('should use external services when both useExternalServices and useAccountsAPI are true', async () => { - const mockFetchSupportedNetworks = jest - .spyOn(MutliChainAccountsServiceModule, 'fetchSupportedNetworks') - .mockResolvedValue([1, 137]); + await withController( + { + options: { + disabled: false, + }, + mocks: { + getSelectedAccount: selectedAccount, + getAccount: selectedAccount, + }, + mockTokenListState: { + tokensChainsCache: { + [chainId]: { + timestamp: 0, + data: { + [mockTokenAddress]: { + name: 'USD Coin', + symbol: 'USDC', + decimals: 6, + address: mockTokenAddress, + aggregators: [], + iconUrl: 'https://example.com/usdc.png', + occurrences: 11, + }, + [secondTokenAddress]: { + name: 'Bancor', + symbol: 'BNT', + decimals: 18, + address: secondTokenAddress, + aggregators: [], + iconUrl: 'https://example.com/bnt.png', + occurrences: 11, + }, + }, + }, + }, + }, + }, + async ({ controller, callActionSpy }) => { + // Add both tokens via websocket + await controller.addDetectedTokensViaWs({ + tokensSlice: [mockTokenAddress, secondTokenAddress], + chainId: chainId as Hex, + }); - jest - .spyOn(MutliChainAccountsServiceModule, 'fetchMultiChainBalances') - .mockResolvedValue({ - count: 1, - balances: [ + // Should add both tokens (no filtering in addDetectedTokensViaWs) + expect(callActionSpy).toHaveBeenCalledWith( + 'TokensController:addTokens', + [ { - object: 'token_balance', - address: sampleTokenA.address, - symbol: sampleTokenA.symbol, - name: sampleTokenA.name, - decimals: sampleTokenA.decimals, - chainId: 1, - balance: '1000000000000000000', + address: checksummedTokenAddress, + decimals: 6, + symbol: 'USDC', + aggregators: [], + image: 'https://example.com/usdc.png', + isERC721: false, + name: 'USD Coin', + }, + { + address: checksummedSecondTokenAddress, + decimals: 18, + symbol: 'BNT', + aggregators: [], + image: 'https://example.com/bnt.png', + isERC721: false, + name: 'Bancor', }, ], - unprocessedNetworks: [], - }); + 'avalanche', + ); + }, + ); + }); - await withController( - { - options: { - useExternalServices: () => true, - disabled: false, - useAccountsAPI: true, - }, - mocks: { - getSelectedAccount: defaultSelectedAccount, - }, + it('should track metrics when adding tokens from websocket', async () => { + const mockTokenAddress = '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48'; + const checksummedTokenAddress = + '0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48'; + const chainId = '0xa86a'; + const mockTrackMetricsEvent = jest.fn(); + + await withController( + { + options: { + disabled: false, + trackMetaMetricsEvent: mockTrackMetricsEvent, }, - async ({ controller, mockTokenListGetState }) => { - mockTokenListGetState({ - ...getDefaultTokenListState(), - tokensChainsCache: { - '0x1': { - timestamp: 0, - data: { - [sampleTokenA.address]: { - name: sampleTokenA.name, - symbol: sampleTokenA.symbol, - decimals: sampleTokenA.decimals, - address: sampleTokenA.address, - aggregators: sampleTokenA.aggregators, - iconUrl: sampleTokenA.image, - occurrences: 11, - }, + mockTokenListState: { + tokensChainsCache: { + [chainId]: { + timestamp: 0, + data: { + [mockTokenAddress]: { + name: 'USD Coin', + symbol: 'USDC', + decimals: 6, + address: mockTokenAddress, + aggregators: [], + iconUrl: 'https://example.com/usdc.png', + occurrences: 11, }, }, }, - }); - - await controller.detectTokens(); - - // Should call both external service methods when both flags are true - expect(mockFetchSupportedNetworks).toHaveBeenCalled(); + }, }, - ); - }); + }, + async ({ controller, callActionSpy }) => { + await controller.addDetectedTokensViaWs({ + tokensSlice: [mockTokenAddress], + chainId: chainId as Hex, + }); - it('should fall back to RPC detection when external services fail', async () => { - const mockFetchSupportedNetworks = jest - .spyOn(MutliChainAccountsServiceModule, 'fetchSupportedNetworks') - .mockResolvedValue([1, 137]); + // Should track metrics event + expect(mockTrackMetricsEvent).toHaveBeenCalledWith({ + event: 'Token Detected', + category: 'Wallet', + properties: { + tokens: [`USDC - ${checksummedTokenAddress}`], + token_standard: 'ERC20', + asset_type: 'TOKEN', + }, + }); - const mockFetchMultiChainBalances = jest - .spyOn(MutliChainAccountsServiceModule, 'fetchMultiChainBalances') - .mockRejectedValue(new Error('API Error')); + expect(callActionSpy).toHaveBeenCalledWith( + 'TokensController:addTokens', + expect.anything(), + expect.anything(), + ); + }, + ); + }); - const mockGetBalancesInSingleCall = jest.fn().mockResolvedValue({ - [sampleTokenA.address]: new BN(1), - }); + it('should be callable directly as a public method on the controller instance', async () => { + const mockTokenAddress = '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48'; + const checksummedTokenAddress = + '0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48'; + const chainId = '0xa86a'; - await withController( - { - options: { - useExternalServices: () => true, - useAccountsAPI: true, - disabled: false, - getBalancesInSingleCall: mockGetBalancesInSingleCall, - }, - mocks: { - getSelectedAccount: defaultSelectedAccount, - }, + await withController( + { + options: { + disabled: false, }, - async ({ controller, mockTokenListGetState }) => { - mockTokenListGetState({ - ...getDefaultTokenListState(), - tokensChainsCache: { - '0x1': { - timestamp: 0, - data: { - [sampleTokenA.address]: { - name: sampleTokenA.name, - symbol: sampleTokenA.symbol, - decimals: sampleTokenA.decimals, - address: sampleTokenA.address, - aggregators: sampleTokenA.aggregators, - iconUrl: sampleTokenA.image, - occurrences: 11, - }, + mockTokenListState: { + tokensChainsCache: { + [chainId]: { + timestamp: 0, + data: { + [mockTokenAddress]: { + name: 'USD Coin', + symbol: 'USDC', + decimals: 6, + address: mockTokenAddress, + aggregators: [], + iconUrl: 'https://example.com/usdc.png', + occurrences: 11, }, }, }, - }); - - await controller.detectTokens(); - - // Should call external services first - expect(mockFetchSupportedNetworks).toHaveBeenCalled(); - expect(mockFetchMultiChainBalances).toHaveBeenCalled(); - - // Should fall back to RPC detection when external services fail - expect(mockGetBalancesInSingleCall).toHaveBeenCalled(); - }, - ); - }); - }); - - describe('useTokenDetection and useExternalServices combination', () => { - it('should not use external services when useTokenDetection is false, regardless of useExternalServices', async () => { - const mockFetchSupportedNetworks = jest.spyOn( - MutliChainAccountsServiceModule, - 'fetchSupportedNetworks', - ); - - await withController( - { - options: { - useTokenDetection: () => false, - useExternalServices: () => true, - disabled: false, - useAccountsAPI: true, }, - mocks: { - getSelectedAccount: defaultSelectedAccount, - }, - }, - async ({ controller }) => { - await controller.detectTokens(); - - // Should not call external services when token detection is disabled - expect(mockFetchSupportedNetworks).not.toHaveBeenCalled(); }, - ); - }); - - it('should use external services when both useTokenDetection and useExternalServices are true', async () => { - const mockFetchSupportedNetworks = jest - .spyOn(MutliChainAccountsServiceModule, 'fetchSupportedNetworks') - .mockResolvedValue([1, 137]); + }, + async ({ controller, callActionSpy }) => { + // Call the public method directly on the controller instance + await controller.addDetectedTokensViaWs({ + tokensSlice: [mockTokenAddress], + chainId: chainId as Hex, + }); - jest - .spyOn(MutliChainAccountsServiceModule, 'fetchMultiChainBalances') - .mockResolvedValue({ - count: 1, - balances: [ + expect(callActionSpy).toHaveBeenCalledWith( + 'TokensController:addTokens', + [ { - object: 'token_balance', - address: sampleTokenA.address, - symbol: sampleTokenA.symbol, - name: sampleTokenA.name, - decimals: sampleTokenA.decimals, - chainId: 1, - balance: '1000000000000000000', + address: checksummedTokenAddress, + decimals: 6, + symbol: 'USDC', + aggregators: [], + image: 'https://example.com/usdc.png', + isERC721: false, + name: 'USD Coin', }, ], - unprocessedNetworks: [], - }); - - await withController( - { - options: { - useTokenDetection: () => true, - useExternalServices: () => true, - disabled: false, - useAccountsAPI: true, - }, - mocks: { - getSelectedAccount: defaultSelectedAccount, - }, - }, - async ({ controller, mockTokenListGetState }) => { - mockTokenListGetState({ - ...getDefaultTokenListState(), - tokensChainsCache: { - '0x1': { - timestamp: 0, - data: { - [sampleTokenA.address]: { - name: sampleTokenA.name, - symbol: sampleTokenA.symbol, - decimals: sampleTokenA.decimals, - address: sampleTokenA.address, - aggregators: sampleTokenA.aggregators, - iconUrl: sampleTokenA.image, - occurrences: 11, - }, - }, - }, - }, - }); - - await controller.detectTokens(); - - // Should call external services when both flags are true - expect(mockFetchSupportedNetworks).toHaveBeenCalled(); - }, - ); - }); + 'avalanche', + ); + }, + ); }); }); - describe('addDetectedTokensViaWs', () => { - it('should add tokens detected from websocket with metadata from cache', async () => { + describe('addDetectedTokensViaPolling', () => { + it('should add tokens detected from polling with metadata from cache', async () => { const mockTokenAddress = '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48'; const checksummedTokenAddress = '0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48'; - const chainId = '0x1'; + const chainId = '0xa86a'; await withController( { options: { disabled: false, + useTokenDetection: () => true, }, mockTokenListState: { tokensChainsCache: { @@ -3766,7 +3687,7 @@ describe('TokenDetectionController', () => { }, }, async ({ controller, callActionSpy }) => { - await controller.addDetectedTokensViaWs({ + await controller.addDetectedTokensViaPolling({ tokensSlice: [mockTokenAddress], chainId: chainId as Hex, }); @@ -3784,63 +3705,62 @@ describe('TokenDetectionController', () => { name: 'USD Coin', }, ], - 'mainnet', + 'avalanche', ); }, ); }); - it('should skip tokens not found in cache and log warning', async () => { + it('should skip if useTokenDetection is disabled', async () => { const mockTokenAddress = '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48'; - const chainId = '0x1'; - - const consoleSpy = jest.spyOn(console, 'warn').mockImplementation(); + const chainId = '0xa86a'; await withController( { options: { disabled: false, + useTokenDetection: () => false, }, mockTokenListState: { tokensChainsCache: { [chainId]: { timestamp: 0, - data: {}, + data: { + [mockTokenAddress]: { + name: 'USD Coin', + symbol: 'USDC', + decimals: 6, + address: mockTokenAddress, + aggregators: [], + iconUrl: 'https://example.com/usdc.png', + occurrences: 11, + }, + }, }, }, }, }, async ({ controller, callActionSpy }) => { - await controller.addDetectedTokensViaWs({ + await controller.addDetectedTokensViaPolling({ tokensSlice: [mockTokenAddress], chainId: chainId as Hex, }); - // Should log warning about missing token metadata - expect(consoleSpy).toHaveBeenCalledWith( - expect.stringContaining('Token metadata not found in cache'), - ); - - // Should not call addTokens if no tokens have metadata + // Should not call addTokens when useTokenDetection is disabled expect(callActionSpy).not.toHaveBeenCalledWith( 'TokensController:addTokens', expect.anything(), expect.anything(), ); - - consoleSpy.mockRestore(); }, ); }); - it('should add all tokens provided without filtering (filtering is caller responsibility)', async () => { + it('should skip tokens already in allTokens', async () => { const mockTokenAddress = '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48'; const checksummedTokenAddress = '0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48'; - const secondTokenAddress = '0x1f573d6fb3f13d689ff844b4ce37794d79a7ff1c'; - const checksummedSecondTokenAddress = - '0x1F573D6Fb3F13d689FF844B4cE37794d79a7FF1C'; - const chainId = '0x1'; + const chainId = '0xa86a'; const selectedAccount = createMockInternalAccount({ address: '0x0000000000000000000000000000000000000001', }); @@ -3849,10 +3769,26 @@ describe('TokenDetectionController', () => { { options: { disabled: false, + useTokenDetection: () => true, }, mocks: { - getSelectedAccount: selectedAccount, getAccount: selectedAccount, + getSelectedAccount: selectedAccount, + }, + mockTokensState: { + allTokens: { + [chainId]: { + [selectedAccount.address]: [ + { + address: checksummedTokenAddress, + symbol: 'USDC', + decimals: 6, + }, + ], + }, + }, + allDetectedTokens: {}, + allIgnoredTokens: {}, }, mockTokenListState: { tokensChainsCache: { @@ -3868,68 +3804,54 @@ describe('TokenDetectionController', () => { iconUrl: 'https://example.com/usdc.png', occurrences: 11, }, - [secondTokenAddress]: { - name: 'Bancor', - symbol: 'BNT', - decimals: 18, - address: secondTokenAddress, - aggregators: [], - iconUrl: 'https://example.com/bnt.png', - occurrences: 11, - }, }, }, }, }, }, async ({ controller, callActionSpy }) => { - // Add both tokens via websocket - await controller.addDetectedTokensViaWs({ - tokensSlice: [mockTokenAddress, secondTokenAddress], + await controller.addDetectedTokensViaPolling({ + tokensSlice: [mockTokenAddress], chainId: chainId as Hex, }); - // Should add both tokens (no filtering in addDetectedTokensViaWs) - expect(callActionSpy).toHaveBeenCalledWith( + // Should not call addTokens for tokens already in allTokens + expect(callActionSpy).not.toHaveBeenCalledWith( 'TokensController:addTokens', - [ - { - address: checksummedTokenAddress, - decimals: 6, - symbol: 'USDC', - aggregators: [], - image: 'https://example.com/usdc.png', - isERC721: false, - name: 'USD Coin', - }, - { - address: checksummedSecondTokenAddress, - decimals: 18, - symbol: 'BNT', - aggregators: [], - image: 'https://example.com/bnt.png', - isERC721: false, - name: 'Bancor', - }, - ], - 'mainnet', + expect.anything(), + expect.anything(), ); }, ); }); - it('should track metrics when adding tokens from websocket', async () => { + it('should skip tokens in allIgnoredTokens', async () => { const mockTokenAddress = '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48'; const checksummedTokenAddress = '0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48'; - const chainId = '0x1'; - const mockTrackMetricsEvent = jest.fn(); + const chainId = '0xa86a'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, - trackMetaMetricsEvent: mockTrackMetricsEvent, + useTokenDetection: () => true, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, + }, + mockTokensState: { + allTokens: {}, + allDetectedTokens: {}, + allIgnoredTokens: { + [chainId]: { + [selectedAccount.address]: [checksummedTokenAddress], + }, + }, }, mockTokenListState: { tokensChainsCache: { @@ -3951,23 +3873,13 @@ describe('TokenDetectionController', () => { }, }, async ({ controller, callActionSpy }) => { - await controller.addDetectedTokensViaWs({ + await controller.addDetectedTokensViaPolling({ tokensSlice: [mockTokenAddress], chainId: chainId as Hex, }); - // Should track metrics event - expect(mockTrackMetricsEvent).toHaveBeenCalledWith({ - event: 'Token Detected', - category: 'Wallet', - properties: { - tokens: [`USDC - ${checksummedTokenAddress}`], - token_standard: 'ERC20', - asset_type: 'TOKEN', - }, - }); - - expect(callActionSpy).toHaveBeenCalledWith( + // Should not call addTokens for tokens in allIgnoredTokens + expect(callActionSpy).not.toHaveBeenCalledWith( 'TokensController:addTokens', expect.anything(), expect.anything(), @@ -3976,57 +3888,111 @@ describe('TokenDetectionController', () => { ); }); - it('should be callable directly as a public method on the controller instance', async () => { - const mockTokenAddress = '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48'; - const checksummedTokenAddress = + it('should add only untracked tokens when mixed with tracked/ignored', async () => { + const trackedTokenAddress = '0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48'; + const trackedTokenChecksummed = '0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48'; - const chainId = '0x1'; + const ignoredTokenAddress = '0xdac17f958d2ee523a2206206994597c13d831ec7'; + const ignoredTokenChecksummed = + '0xdAC17F958D2ee523a2206206994597C13D831ec7'; + const newTokenAddress = '0x1f573d6fb3f13d689ff844b4ce37794d79a7ff1c'; + const newTokenChecksummed = '0x1F573D6Fb3F13d689FF844B4cE37794d79a7FF1C'; + const chainId = '0xa86a'; + const selectedAccount = createMockInternalAccount({ + address: '0x0000000000000000000000000000000000000001', + }); await withController( { options: { disabled: false, + useTokenDetection: () => true, + }, + mocks: { + getAccount: selectedAccount, + getSelectedAccount: selectedAccount, + }, + mockTokensState: { + allTokens: { + [chainId]: { + [selectedAccount.address]: [ + { + address: trackedTokenChecksummed, + symbol: 'USDC', + decimals: 6, + }, + ], + }, + }, + allDetectedTokens: {}, + allIgnoredTokens: { + [chainId]: { + [selectedAccount.address]: [ignoredTokenChecksummed], + }, + }, }, mockTokenListState: { tokensChainsCache: { [chainId]: { timestamp: 0, data: { - [mockTokenAddress]: { + [trackedTokenAddress]: { name: 'USD Coin', symbol: 'USDC', decimals: 6, - address: mockTokenAddress, + address: trackedTokenAddress, aggregators: [], iconUrl: 'https://example.com/usdc.png', occurrences: 11, }, + [ignoredTokenAddress]: { + name: 'Tether USD', + symbol: 'USDT', + decimals: 6, + address: ignoredTokenAddress, + aggregators: [], + iconUrl: 'https://example.com/usdt.png', + occurrences: 11, + }, + [newTokenAddress]: { + name: 'Bancor', + symbol: 'BNT', + decimals: 18, + address: newTokenAddress, + aggregators: [], + iconUrl: 'https://example.com/bnt.png', + occurrences: 11, + }, }, }, }, }, }, async ({ controller, callActionSpy }) => { - // Call the public method directly on the controller instance - await controller.addDetectedTokensViaWs({ - tokensSlice: [mockTokenAddress], + await controller.addDetectedTokensViaPolling({ + tokensSlice: [ + trackedTokenAddress, + ignoredTokenAddress, + newTokenAddress, + ], chainId: chainId as Hex, }); + // Should only add the new untracked token expect(callActionSpy).toHaveBeenCalledWith( 'TokensController:addTokens', [ { - address: checksummedTokenAddress, - decimals: 6, - symbol: 'USDC', + address: newTokenChecksummed, + decimals: 18, + symbol: 'BNT', aggregators: [], - image: 'https://example.com/usdc.png', + image: 'https://example.com/bnt.png', isERC721: false, - name: 'USD Coin', + name: 'Bancor', }, ], - 'mainnet', + 'avalanche', ); }, ); @@ -4040,7 +4006,7 @@ describe('TokenDetectionController', () => { * @param chainId - The chain ID. * @returns The constructed path. */ -function getTokensPath(chainId: Hex) { +function getTokensPath(chainId: Hex): string { return `/tokens/${convertHexToDecimal( chainId, )}?occurrenceFloor=3&includeNativeAssets=false&includeTokenFees=false&includeAssetType=false`; @@ -4065,7 +4031,6 @@ type WithControllerCallback = ({ triggerPreferencesStateChange, triggerSelectedAccountChange, triggerNetworkDidChange, - triggerTransactionConfirmed, }: { controller: TokenDetectionController; messenger: RootMessenger; @@ -4094,7 +4059,7 @@ type WithControllerCallback = ({ triggerPreferencesStateChange: (state: PreferencesState) => void; triggerSelectedAccountChange: (account: InternalAccount) => void; triggerNetworkDidChange: (state: NetworkState) => void; - triggerTransactionConfirmed: (transactionMeta: TransactionMeta) => void; + triggerTransactionConfirmed: (transactionMeta: { chainId: Hex }) => void; }) => Promise | ReturnValue; type WithControllerOptions = { @@ -4106,6 +4071,7 @@ type WithControllerOptions = { getBearerToken?: string; }; mockTokenListState?: Partial; + mockTokensState?: Partial; }; type WithControllerArgs = @@ -4125,7 +4091,13 @@ async function withController( ...args: WithControllerArgs ): Promise { const [{ ...rest }, fn] = args.length === 2 ? args : [{}, args[0]]; - const { options, isKeyringUnlocked, mocks, mockTokenListState } = rest; + const { + options, + isKeyringUnlocked, + mocks, + mockTokenListState, + mockTokensState, + } = rest; const messenger = buildRootMessenger(); const mockGetAccount = jest.fn(); @@ -4158,8 +4130,10 @@ async function withController( messenger.registerActionHandler( 'NetworkController:getNetworkClientById', mockGetNetworkClientById.mockImplementation(() => { + // Default to Avalanche (0xa86a) which is in SupportedTokenDetectionNetworks + // but NOT in SUPPORTED_NETWORKS_ACCOUNTS_API_V4 return { - configuration: { chainId: '0x1' }, + configuration: { chainId: '0xa86a' }, provider: {}, destroy: {}, blockTracker: {}, @@ -4181,12 +4155,19 @@ async function withController( const mockNetworkState = jest.fn(); messenger.registerActionHandler( 'NetworkController:getState', - mockNetworkState.mockReturnValue({ ...getDefaultNetworkControllerState() }), + mockNetworkState.mockReturnValue({ + ...getDefaultNetworkControllerState(), + // Default to avalanche so RPC detection works (not in SUPPORTED_NETWORKS_ACCOUNTS_API_V4) + selectedNetworkClientId: 'avalanche', + }), ); - const mockTokensState = jest.fn(); + const mockTokensStateFunc = jest.fn(); messenger.registerActionHandler( 'TokensController:getState', - mockTokensState.mockReturnValue({ ...getDefaultTokensState() }), + mockTokensStateFunc.mockReturnValue({ + ...getDefaultTokensState(), + ...mockTokensState, + }), ); const mockTokenListStateFunc = jest.fn(); messenger.registerActionHandler( @@ -4201,21 +4182,16 @@ async function withController( 'PreferencesController:getState', mockPreferencesState.mockReturnValue({ ...getDefaultPreferencesState(), + // Enable token detection by default for tests using Avalanche + useTokenDetection: true, }), ); - const mockGetBearerToken = jest.fn, []>(); - messenger.registerActionHandler( - 'AuthenticationController:getBearerToken', - mockGetBearerToken.mockResolvedValue( - mocks?.getBearerToken ?? 'mock-jwt-token', - ), - ); - const mockFindNetworkClientIdByChainId = jest.fn(); messenger.registerActionHandler( 'NetworkController:findNetworkClientIdByChainId', - mockFindNetworkClientIdByChainId.mockReturnValue('mainnet'), + // Default to 'avalanche' which is not in SUPPORTED_NETWORKS_ACCOUNTS_API_V4 + mockFindNetworkClientIdByChainId.mockReturnValue('avalanche'), ); messenger.registerActionHandler( @@ -4247,8 +4223,6 @@ async function withController( getBalancesInSingleCall: jest.fn(), trackMetaMetricsEvent: jest.fn(), messenger: tokenDetectionControllerMessenger, - useAccountsAPI: false, - platform: 'extension', ...options, }); try { @@ -4265,7 +4239,7 @@ async function withController( mockKeyringState.mockReturnValue(state); }, mockTokensGetState: (state: TokensControllerState) => { - mockTokensState.mockReturnValue(state); + mockTokensStateFunc.mockReturnValue(state); }, mockPreferencesGetState: (state: PreferencesState) => { mockPreferencesState.mockReturnValue(state); @@ -4317,10 +4291,13 @@ async function withController( triggerNetworkDidChange: (state: NetworkState) => { messenger.publish('NetworkController:networkDidChange', state); }, - triggerTransactionConfirmed: (transactionMeta: TransactionMeta) => { + triggerTransactionConfirmed: (transactionMeta: { chainId: Hex }) => { messenger.publish( 'TransactionController:transactionConfirmed', - transactionMeta, + // We only need chainId for this test, so cast to satisfy the type + transactionMeta as unknown as Parameters< + typeof messenger.publish<'TransactionController:transactionConfirmed'> + >[1], ); }, }); diff --git a/packages/assets-controllers/src/TokenDetectionController.ts b/packages/assets-controllers/src/TokenDetectionController.ts index 99931a55b88..f692e192786 100644 --- a/packages/assets-controllers/src/TokenDetectionController.ts +++ b/packages/assets-controllers/src/TokenDetectionController.ts @@ -13,16 +13,15 @@ import { ChainId, ERC20, safelyExecute, - safelyExecuteWithTimeout, isEqualCaseInsensitive, toChecksumHexAddress, - toHex, } from '@metamask/controller-utils'; import type { KeyringControllerGetStateAction, KeyringControllerLockEvent, KeyringControllerUnlockEvent, } from '@metamask/keyring-controller'; +import type { InternalAccount } from '@metamask/keyring-internal-api'; import type { Messenger } from '@metamask/messenger'; import type { NetworkClientId, @@ -40,15 +39,11 @@ import type { import type { AuthenticationController } from '@metamask/profile-sync-controller'; import type { TransactionControllerTransactionConfirmedEvent } from '@metamask/transaction-controller'; import type { Hex } from '@metamask/utils'; -import { hexToNumber } from '@metamask/utils'; import { isEqual, mapValues, isObject, get } from 'lodash'; import type { AssetsContractController } from './AssetsContractController'; import { isTokenDetectionSupportedForNetwork } from './assetsUtil'; -import { - fetchMultiChainBalances, - fetchSupportedNetworks, -} from './multi-chain-accounts-service'; +import { SUPPORTED_NETWORKS_ACCOUNTS_API_V4 } from './constants'; import type { GetTokenListState, TokenListMap, @@ -63,7 +58,6 @@ import type { } from './TokensController'; const DEFAULT_INTERVAL = 180000; -const ACCOUNTS_API_TIMEOUT_MS = 10000; type LegacyToken = { name: string; @@ -106,7 +100,7 @@ export const STATIC_MAINNET_TOKEN_LIST = Object.entries( */ export function mapChainIdWithTokenListMap( tokensChainsCache: TokensChainsCache, -) { +): Record { return mapValues(tokensChainsCache, (value) => { if (isObject(value) && 'data' in value) { return get(value, ['data']); @@ -129,9 +123,21 @@ export type TokenDetectionControllerAddDetectedTokensViaWsAction = { handler: TokenDetectionController['addDetectedTokensViaWs']; }; +export type TokenDetectionControllerAddDetectedTokensViaPollingAction = { + type: `TokenDetectionController:addDetectedTokensViaPolling`; + handler: TokenDetectionController['addDetectedTokensViaPolling']; +}; + +export type TokenDetectionControllerDetectTokensAction = { + type: `TokenDetectionController:detectTokens`; + handler: TokenDetectionController['detectTokens']; +}; + export type TokenDetectionControllerActions = | TokenDetectionControllerGetStateAction - | TokenDetectionControllerAddDetectedTokensViaWsAction; + | TokenDetectionControllerAddDetectedTokensViaWsAction + | TokenDetectionControllerAddDetectedTokensViaPollingAction + | TokenDetectionControllerDetectTokensAction; export type AllowedActions = | AccountsControllerGetSelectedAccountAction @@ -219,64 +225,13 @@ export class TokenDetectionController extends StaticIntervalPollingController void; - readonly #accountsAPI = { - isAccountsAPIEnabled: true, - supportedNetworksCache: null as number[] | null, - platform: '' as 'extension' | 'mobile', - - async getSupportedNetworks() { - /* istanbul ignore next */ - if (!this.isAccountsAPIEnabled) { - throw new Error('Accounts API Feature Switch is disabled'); - } - - /* istanbul ignore next */ - if (this.supportedNetworksCache) { - return this.supportedNetworksCache; - } - - const result = await fetchSupportedNetworks().catch(() => null); - this.supportedNetworksCache = result; - return result; - }, - - async getMultiNetworksBalances( - address: string, - chainIds: Hex[], - supportedNetworks: number[] | null, - jwtToken?: string, - ) { - const chainIdNumbers = chainIds.map((chainId) => hexToNumber(chainId)); - - if ( - !supportedNetworks || - !chainIdNumbers.every((id) => supportedNetworks.includes(id)) - ) { - const supportedNetworksErrStr = (supportedNetworks ?? []).toString(); - throw new Error( - `Unsupported Network: supported networks ${supportedNetworksErrStr}, requested networks: ${chainIdNumbers.toString()}`, - ); - } - - const result = await fetchMultiChainBalances( - address, - { - networks: chainIdNumbers, - }, - this.platform, - jwtToken, - ); - - // Return the full response including unprocessedNetworks - return result; - }, - }; - /** * Creates a TokenDetectionController instance. * @@ -286,10 +241,8 @@ export class TokenDetectionController extends StaticIntervalPollingController true, - useExternalServices = () => true, - platform, + useTokenDetection = (): boolean => true, + useExternalServices = (): boolean => true, }: { interval?: number; disabled?: boolean; @@ -310,15 +261,15 @@ export class TokenDetectionController extends StaticIntervalPollingController void; messenger: TokenDetectionControllerMessenger; - useAccountsAPI?: boolean; useTokenDetection?: () => boolean; useExternalServices?: () => boolean; - platform: 'extension' | 'mobile'; }) { super({ name: controllerName, @@ -332,6 +283,16 @@ export class TokenDetectionController extends StaticIntervalPollingController { + #registerEventListeners(): void { + this.messenger.subscribe('KeyringController:unlock', () => { this.#isUnlocked = true; - await this.#restartTokenDetection(); + this.#restartTokenDetection().catch(() => { + // Silently handle token detection errors + }); }); this.messenger.subscribe('KeyringController:lock', () => { @@ -379,20 +340,22 @@ export class TokenDetectionController extends StaticIntervalPollingController { + ({ tokensChainsCache }) => { const isEqualValues = this.#compareTokensChainsCache( tokensChainsCache, this.#tokensChainsCache, ); if (!isEqualValues) { - await this.#restartTokenDetection(); + this.#restartTokenDetection().catch(() => { + // Silently handle token detection errors + }); } }, ); this.messenger.subscribe( 'PreferencesController:stateChange', - async ({ useTokenDetection }) => { + ({ useTokenDetection }) => { const selectedAccount = this.#getSelectedAccount(); const isDetectionChangedFromPreferences = this.#isDetectionEnabledFromPreferences !== useTokenDetection; @@ -400,8 +363,10 @@ export class TokenDetectionController extends StaticIntervalPollingController { + // Silently handle token detection errors }); } }, @@ -409,7 +374,7 @@ export class TokenDetectionController extends StaticIntervalPollingController { + (selectedAccount) => { const { networkConfigurationsByChainId } = this.messenger.call( 'NetworkController:getState', ); @@ -419,9 +384,11 @@ export class TokenDetectionController extends StaticIntervalPollingController { + // Silently handle token detection errors }); } }, @@ -429,9 +396,11 @@ export class TokenDetectionController extends StaticIntervalPollingController { - await this.detectTokens({ + (transactionMeta) => { + this.detectTokens({ chainIds: [transactionMeta.chainId], + }).catch(() => { + // Silently handle token detection errors }); }, ); @@ -588,68 +557,6 @@ export class TokenDetectionController extends StaticIntervalPollingController { - if (supportedNetworks?.includes(hexToNumber(chainId))) { - chainsToDetectUsingAccountAPI.push(chainId); - } else { - chainsToDetectUsingRpc.push({ chainId, networkClientId }); - } - }); - - return { chainsToDetectUsingRpc, chainsToDetectUsingAccountAPI }; - } - - async #attemptAccountAPIDetection( - chainsToDetectUsingAccountAPI: Hex[], - addressToDetect: string, - supportedNetworks: number[] | null, - jwtToken?: string, - ) { - const result = await safelyExecuteWithTimeout( - async () => { - return this.#addDetectedTokensViaAPI({ - chainIds: chainsToDetectUsingAccountAPI, - selectedAddress: addressToDetect, - supportedNetworks, - jwtToken, - }); - }, - false, - ACCOUNTS_API_TIMEOUT_MS, - ); - - if (!result) { - return { result: 'failed' } as const; - } - - return result; - } - - #addChainsToRpcDetection( - chainsToDetectUsingRpc: NetworkClient[], - chainsToDetectUsingAccountAPI: Hex[], - clientNetworks: NetworkClient[], - ): void { - chainsToDetectUsingAccountAPI.forEach((chainId) => { - const networkEntry = clientNetworks.find( - (network) => network.chainId === chainId, - ); - if (networkEntry) { - chainsToDetectUsingRpc.push({ - chainId: networkEntry.chainId, - networkClientId: networkEntry.networkClientId, - }); - } - }); - } - #shouldDetectTokens(chainId: Hex): boolean { if (!isTokenDetectionSupportedForNetwork(chainId)) { return false; @@ -708,77 +615,50 @@ export class TokenDetectionController extends StaticIntervalPollingController { if (!this.isActive) { return; } - if (!this.#useTokenDetection()) { + // When forceRpc is true, bypass the useTokenDetection check to ensure RPC detection runs + if (!forceRpc && !this.#useTokenDetection()) { + return; + } + + // If external services are disabled and not forcing RPC, skip all detection + if (!forceRpc && !this.#useExternalServices()) { return; } const addressToDetect = selectedAddress ?? this.#getSelectedAddress(); const clientNetworks = this.#getCorrectNetworkClientIdByChainId(chainIds); - const jwtToken = await safelyExecuteWithTimeout( - () => { - return this.messenger.call('AuthenticationController:getBearerToken'); - }, - false, - 5000, - ); - - let supportedNetworks; - if (this.#accountsAPI.isAccountsAPIEnabled && this.#useExternalServices()) { - supportedNetworks = await this.#accountsAPI.getSupportedNetworks(); - } - const { chainsToDetectUsingRpc, chainsToDetectUsingAccountAPI } = - this.#getChainsToDetect(clientNetworks, supportedNetworks); - - // Try detecting tokens via Account API first if conditions allow - if (supportedNetworks && chainsToDetectUsingAccountAPI.length > 0) { - const apiResult = await this.#attemptAccountAPIDetection( - chainsToDetectUsingAccountAPI, - addressToDetect, - supportedNetworks, - jwtToken, - ); - - // If the account API call failed or returned undefined, have those chains fall back to RPC detection - if (!apiResult || apiResult.result === 'failed') { - this.#addChainsToRpcDetection( - chainsToDetectUsingRpc, - chainsToDetectUsingAccountAPI, - clientNetworks, - ); - } else if ( - apiResult?.result === 'success' && - apiResult.unprocessedNetworks && - apiResult.unprocessedNetworks.length > 0 - ) { - // Handle unprocessed networks by adding them to RPC detection - const unprocessedChainIds = apiResult.unprocessedNetworks.map( - (chainId: number) => toHex(chainId), - ); - this.#addChainsToRpcDetection( - chainsToDetectUsingRpc, - unprocessedChainIds, - clientNetworks, + // If forceRpc is true, use RPC for all chains + // Otherwise, skip chains supported by Accounts API (they are handled by TokenBalancesController) + const chainsToDetectUsingRpc = forceRpc + ? clientNetworks + : clientNetworks.filter( + ({ chainId }) => + !SUPPORTED_NETWORKS_ACCOUNTS_API_V4.includes(chainId), ); - } - } - // Proceed with RPC detection if there are chains remaining in chainsToDetectUsingRpc - if (chainsToDetectUsingRpc.length > 0) { - await this.#detectTokensUsingRpc(chainsToDetectUsingRpc, addressToDetect); + if (chainsToDetectUsingRpc.length === 0) { + return; } + + await this.#detectTokensUsingRpc(chainsToDetectUsingRpc, addressToDetect); } #getSlicesOfTokensToDetect({ @@ -851,176 +731,6 @@ export class TokenDetectionController extends StaticIntervalPollingController { - // Fetch balances for multiple chain IDs at once - const apiResponse = await this.#accountsAPI - .getMultiNetworksBalances( - selectedAddress, - chainIds, - supportedNetworks, - jwtToken, - ) - .catch(() => null); - - if (apiResponse === null) { - return { result: 'failed' } as const; - } - - const tokenBalancesByChain = apiResponse.balances; - - // Process each chain ID individually - for (const chainId of chainIds) { - const isTokenDetectionInactiveInMainnet = - !this.#isDetectionEnabledFromPreferences && - chainId === ChainId.mainnet; - const { tokensChainsCache } = this.messenger.call( - 'TokenListController:getState', - ); - this.#tokensChainsCache = isTokenDetectionInactiveInMainnet - ? this.#getConvertedStaticMainnetTokenList() - : (tokensChainsCache ?? {}); - - // Generate token candidates based on chainId and selectedAddress - const tokenCandidateSlices = this.#getSlicesOfTokensToDetect({ - chainId, - selectedAddress, - }); - - // Filter balances for the current chainId - const tokenBalances = tokenBalancesByChain.filter( - (balance) => balance.chainId === hexToNumber(chainId), - ); - - if (!tokenBalances || tokenBalances.length === 0) { - continue; - } - - // Use helper function to filter tokens with balance for this chainId - const { tokensWithBalance, eventTokensDetails } = - this.#filterAndBuildTokensWithBalance( - tokenCandidateSlices, - tokenBalances, - chainId, - ); - - if (tokensWithBalance.length) { - this.#trackMetaMetricsEvent({ - event: 'Token Detected', - category: 'Wallet', - properties: { - tokens: eventTokensDetails, - token_standard: ERC20, - asset_type: ASSET_TYPES.TOKEN, - }, - }); - - const networkClientId = this.messenger.call( - 'NetworkController:findNetworkClientIdByChainId', - chainId, - ); - - await this.messenger.call( - 'TokensController:addTokens', - tokensWithBalance, - networkClientId, - ); - } - } - - return { - result: 'success', - unprocessedNetworks: apiResponse.unprocessedNetworks, - } as const; - }); - } - - /** - * Helper function to filter and build token data for detected tokens - * - * @param options.tokenCandidateSlices - these are tokens we know a user does not have (by checking the tokens controller). - * We will use these these token candidates to determine if a token found from the API is valid to be added on the users wallet. - * It will also prevent us to adding tokens a user already has - * @param tokenBalances - Tokens balances fetched from API - * @param chainId - The chain ID being processed - * @returns an object containing tokensWithBalance and eventTokensDetails arrays - */ - - #filterAndBuildTokensWithBalance( - tokenCandidateSlices: string[][], - tokenBalances: - | { - object: string; - type?: string; - timestamp?: string; - address: string; - symbol: string; - name: string; - decimals: number; - chainId: number; - balance: string; - }[] - | null, - chainId: Hex, - ) { - const tokensWithBalance: Token[] = []; - const eventTokensDetails: string[] = []; - - const tokenCandidateSet = new Set(tokenCandidateSlices.flat()); - - tokenBalances?.forEach((token) => { - const tokenAddress = token.address; - - // Make sure the token to add is in our candidate list - if (!tokenCandidateSet.has(tokenAddress)) { - return; - } - - // Retrieve token data from cache to safely add it - const tokenData = this.#tokensChainsCache[chainId]?.data[tokenAddress]; - - // We need specific data from tokensChainsCache to correctly create a token - // So even if we have a token that was detected correctly by the API, if its missing data we cannot safely add it. - if (!tokenData) { - return; - } - - const { decimals, symbol, aggregators, iconUrl, name } = tokenData; - eventTokensDetails.push(`${symbol} - ${tokenAddress}`); - tokensWithBalance.push({ - address: tokenAddress, - decimals, - symbol, - aggregators, - image: iconUrl, - isERC721: false, - name, - }); - }); - - return { tokensWithBalance, eventTokensDetails }; - } - async #addDetectedTokens({ tokensSlice, selectedAddress, @@ -1078,8 +788,10 @@ export class TokenDetectionController extends StaticIntervalPollingController { + // Check if token detection is enabled via preferences + if (!this.#useTokenDetection()) { + return; + } + + // Check if external services are enabled (websocket requires external services) + if (!this.#useExternalServices()) { + return; + } + const tokensWithBalance: Token[] = []; const eventTokensDetails: string[] = []; @@ -1154,17 +876,130 @@ export class TokenDetectionController extends StaticIntervalPollingController { + // Check if token detection is enabled via preferences + if (!this.#useTokenDetection()) { + return; + } + + // Check if external services are enabled (polling via API requires external services) + if (!this.#useExternalServices()) { + return; + } + + const selectedAddress = this.#getSelectedAddress(); + + // Get current token states to filter out already tracked/ignored tokens + const { allTokens, allIgnoredTokens } = this.messenger.call( + 'TokensController:getState', + ); + + const existingTokenAddresses = ( + allTokens[chainId]?.[selectedAddress] ?? [] + ).map((token) => token.address.toLowerCase()); + + const ignoredTokenAddresses = ( + allIgnoredTokens[chainId]?.[selectedAddress] ?? [] + ).map((address) => address.toLowerCase()); + + const tokensWithBalance: Token[] = []; + const eventTokensDetails: string[] = []; + + for (const tokenAddress of tokensSlice) { + const lowercaseTokenAddress = tokenAddress.toLowerCase(); + const checksummedTokenAddress = toChecksumHexAddress(tokenAddress); + + // Skip tokens already in allTokens + if (existingTokenAddresses.includes(lowercaseTokenAddress)) { + continue; + } + + // Skip tokens in allIgnoredTokens + if (ignoredTokenAddresses.includes(lowercaseTokenAddress)) { + continue; + } + + // Check map of validated tokens (cache keys are lowercase) + const tokenData = + this.#tokensChainsCache[chainId]?.data?.[lowercaseTokenAddress]; + + if (!tokenData) { + console.warn( + `Token metadata not found in cache for ${tokenAddress} on chain ${chainId}`, + ); + continue; + } + + const { decimals, symbol, aggregators, iconUrl, name } = tokenData; + + eventTokensDetails.push(`${symbol} - ${checksummedTokenAddress}`); + tokensWithBalance.push({ + address: checksummedTokenAddress, + decimals, + symbol, + aggregators, + image: iconUrl, + isERC721: false, + name, + }); + } + + // Perform addition + if (tokensWithBalance.length) { + this.#trackMetaMetricsEvent({ + event: 'Token Detected', + category: 'Wallet', + properties: { + tokens: eventTokensDetails, + token_standard: ERC20, + asset_type: ASSET_TYPES.TOKEN, + }, + }); + + const networkClientId = this.messenger.call( + 'NetworkController:findNetworkClientIdByChainId', + chainId, + ); + + await this.messenger.call( + 'TokensController:addTokens', + tokensWithBalance, + networkClientId, + ); + } + } + + #getSelectedAccount(): InternalAccount { return this.messenger.call('AccountsController:getSelectedAccount'); } - #getSelectedAddress() { + #getSelectedAddress(): string { // If the address is not defined (or empty), we fallback to the currently selected account's address const account = this.messenger.call( 'AccountsController:getAccount', this.#selectedAccountId, ); - return account?.address || ''; + return account?.address ?? ''; } } diff --git a/packages/assets-controllers/src/index.ts b/packages/assets-controllers/src/index.ts index f119072a564..3ed3ef4ed8e 100644 --- a/packages/assets-controllers/src/index.ts +++ b/packages/assets-controllers/src/index.ts @@ -87,7 +87,9 @@ export type { TokenDetectionControllerMessenger, TokenDetectionControllerActions, TokenDetectionControllerGetStateAction, + TokenDetectionControllerDetectTokensAction, TokenDetectionControllerAddDetectedTokensViaWsAction, + TokenDetectionControllerAddDetectedTokensViaPollingAction, TokenDetectionControllerEvents, TokenDetectionControllerStateChangeEvent, } from './TokenDetectionController'; diff --git a/packages/assets-controllers/src/multi-chain-accounts-service/api-balance-fetcher.ts b/packages/assets-controllers/src/multi-chain-accounts-service/api-balance-fetcher.ts index 8bd15a2a7cd..441ef710651 100644 --- a/packages/assets-controllers/src/multi-chain-accounts-service/api-balance-fetcher.ts +++ b/packages/assets-controllers/src/multi-chain-accounts-service/api-balance-fetcher.ts @@ -8,10 +8,12 @@ import { toChecksumHexAddress, } from '@metamask/controller-utils'; import type { InternalAccount } from '@metamask/keyring-internal-api'; -import type { CaipAccountAddress, Hex } from '@metamask/utils'; +import type { CaipAccountAddress, CaipChainId, Hex } from '@metamask/utils'; +import { parseCaipChainId } from '@metamask/utils'; import BN from 'bn.js'; import { fetchMultiChainBalancesV4 } from './multi-chain-accounts'; +import type { GetBalancesResponse } from './types'; import { STAKING_CONTRACT_ADDRESS_BY_CHAINID } from '../AssetsContractController'; import { accountAddressToCaipReference, @@ -226,7 +228,7 @@ export class AccountsApiBalanceFetcher implements BalanceFetcher { type ResponseData = Awaited>; - const allUnprocessedNetworks = new Set(); + const allUnprocessedNetworks = new Set(); const allBalances = await reduceInBatchesSerially< CaipAccountAddress, BalanceData[] @@ -293,10 +295,19 @@ export class AccountsApiBalanceFetcher implements BalanceFetcher { } // Extract unprocessed networks and convert to hex chain IDs - const unprocessedChainIds: ChainIdHex[] | undefined = - apiResponse.unprocessedNetworks - ? apiResponse.unprocessedNetworks.map((chainId) => toHex(chainId)) - : undefined; + // V4 API returns CAIP chain IDs like 'eip155:1329', need to parse them + // V2 API returns decimal numbers, handle both cases + const unprocessedChainIds: ChainIdHex[] | undefined = apiResponse + .unprocessedNetworks?.length + ? apiResponse.unprocessedNetworks.map((network) => { + if (typeof network === 'string') { + // CAIP chain ID format: 'eip155:1329' + return toHex(parseCaipChainId(network as CaipChainId).reference); + } + // Decimal number format + return toHex(network); + }) + : undefined; const stakedBalances = await this.#fetchStakedBalances(caipAddrs); @@ -322,55 +333,57 @@ export class AccountsApiBalanceFetcher implements BalanceFetcher { // Process regular API balances if (apiResponse.balances) { - const apiBalances = apiResponse.balances.flatMap((b) => { - const addressPart = b.accountAddress?.split(':')[2]; - if (!addressPart) { - return []; - } - const account = checksum(addressPart); - const token = checksum(b.address); - // Use original address for zero address tokens, checksummed for others - // TODO: this is a hack to get the correct account address type but needs to be fixed - // by mgrating tokenBalancesController to checksum addresses - const finalAccount: ChecksumAddress | string = - token === ZERO_ADDRESS ? account : addressPart; - const chainId = toHex(b.chainId); - - let value: BN | undefined; - try { - // Convert string balance to BN avoiding floating point precision issues - const { balance: balanceStr, decimals } = b; - - // Split the balance string into integer and decimal parts - const [integerPart = '0', decimalPart = ''] = balanceStr.split('.'); - - // Pad or truncate decimal part to match token decimals - const paddedDecimalPart = decimalPart - .padEnd(decimals, '0') - .slice(0, decimals); - - // Combine and create BN - const fullIntegerStr = integerPart + paddedDecimalPart; - value = new BN(fullIntegerStr); - } catch { - value = undefined; - } + const apiBalances = apiResponse.balances.flatMap( + (b: GetBalancesResponse['balances'][number]) => { + const addressPart = b.accountAddress?.split(':')[2]; + if (!addressPart) { + return []; + } + const account = checksum(addressPart); + const token = checksum(b.address); + // Use original address for zero address tokens, checksummed for others + // TODO: this is a hack to get the correct account address type but needs to be fixed + // by mgrating tokenBalancesController to checksum addresses + const finalAccount: ChecksumAddress | string = + token === ZERO_ADDRESS ? account : addressPart; + const chainId = toHex(b.chainId); + + let value: BN | undefined; + try { + // Convert string balance to BN avoiding floating point precision issues + const { balance: balanceStr, decimals } = b; + + // Split the balance string into integer and decimal parts + const [integerPart = '0', decimalPart = ''] = balanceStr.split('.'); + + // Pad or truncate decimal part to match token decimals + const paddedDecimalPart = decimalPart + .padEnd(decimals, '0') + .slice(0, decimals); + + // Combine and create BN + const fullIntegerStr = integerPart + paddedDecimalPart; + value = new BN(fullIntegerStr); + } catch { + value = undefined; + } - // Track native balances for later - if (token === ZERO_ADDRESS && value !== undefined) { - nativeBalancesFromAPI.set(`${finalAccount}-${chainId}`, value); - } + // Track native balances for later + if (token === ZERO_ADDRESS && value !== undefined) { + nativeBalancesFromAPI.set(`${finalAccount}-${chainId}`, value); + } - return [ - { - success: value !== undefined, - value, - account: finalAccount, - token, - chainId, - }, - ]; - }); + return [ + { + success: value !== undefined, + value, + account: finalAccount, + token, + chainId, + }, + ]; + }, + ); results.push(...apiBalances); } diff --git a/packages/assets-controllers/src/multi-chain-accounts-service/types.ts b/packages/assets-controllers/src/multi-chain-accounts-service/types.ts index 746bf605a23..9c161eff2f6 100644 --- a/packages/assets-controllers/src/multi-chain-accounts-service/types.ts +++ b/packages/assets-controllers/src/multi-chain-accounts-service/types.ts @@ -44,5 +44,6 @@ export type GetBalancesResponse = { accountAddress?: string; }[]; /** networks that failed to process, if no network is processed, returns HTTP 422 */ - unprocessedNetworks: number[]; + /** V4 API returns CAIP chain IDs like 'eip155:1329', V2 API returns decimal numbers */ + unprocessedNetworks: (number | string)[]; };