From 8ab19dc3103c9ca117f40d3742166c47e00d8829 Mon Sep 17 00:00:00 2001 From: Roman Khafizianov Date: Mon, 18 Aug 2025 15:15:42 +0200 Subject: [PATCH] Refactor crypto: add master node derivation API --- go.mod | 2 +- go.sum | 2 + util/crypto/mnemonic.go | 83 +++++++++++++--- util/crypto/mnemonic_test.go | 183 +++++++++++++++++++++++++++++++++++ 4 files changed, 258 insertions(+), 12 deletions(-) diff --git a/go.mod b/go.mod index b32edecbc..01dcb6b04 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( filippo.io/edwards25519 v1.1.0 github.com/anyproto/any-store v0.3.1 github.com/anyproto/go-chash v0.1.0 - github.com/anyproto/go-slip10 v1.0.0 + github.com/anyproto/go-slip10 v1.0.1-0.20250818123350-f910c27dd080 github.com/anyproto/go-slip21 v1.0.0 github.com/anyproto/lexid v0.0.4 github.com/anyproto/protobuf v1.3.3-0.20240814124528-72b8c7e0e0f5 diff --git a/go.sum b/go.sum index 15e4d1f21..cb603ff94 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,8 @@ github.com/anyproto/go-chash v0.1.0 h1:I9meTPjXFRfXZHRJzjOHC/XF7Q5vzysKkiT/grsog github.com/anyproto/go-chash v0.1.0/go.mod h1:0UjNQi3PDazP0fINpFYu6VKhuna+W/V+1vpXHAfNgLY= github.com/anyproto/go-slip10 v1.0.0 h1:uAEtSuudR3jJBOfkOXf3bErxVoxbuKwdoJN55M1i6IA= github.com/anyproto/go-slip10 v1.0.0/go.mod h1:BCmIlM1KB8wX6K4/8pOvxPl9oVKfEvZ5vsmO5rkK6vg= +github.com/anyproto/go-slip10 v1.0.1-0.20250818123350-f910c27dd080 h1:bbHmaibcUbctrXG6LT6136H0oDlBUDoDANX2qBpqhkU= +github.com/anyproto/go-slip10 v1.0.1-0.20250818123350-f910c27dd080/go.mod h1:BCmIlM1KB8wX6K4/8pOvxPl9oVKfEvZ5vsmO5rkK6vg= github.com/anyproto/go-slip21 v1.0.0 h1:CI7lUqTIwmPOEGVAj4jyNLoICvueh++0U2HoAi3m2ZY= github.com/anyproto/go-slip21 v1.0.0/go.mod h1:gbIJt7HAdr5DuT4f2pFTKCBSUWYsm/fysHBNqgsuxT0= github.com/anyproto/go-sqlite v1.4.2-any h1:ZTIcq/u2mYYJ6rJB4I3Ds5QH/7IlONebMiG14FyZcD4= diff --git a/util/crypto/mnemonic.go b/util/crypto/mnemonic.go index 9aed90229..d13ed7b78 100644 --- a/util/crypto/mnemonic.go +++ b/util/crypto/mnemonic.go @@ -49,6 +49,58 @@ type DerivationResult struct { EthereumIdentity ecdsa.PrivateKey } +// DeriveKeysFromMasterNode derives master key and identity from a master node +// The master node should be at path m/44'/2046'/index' +func DeriveKeysFromMasterNode(masterNode slip10.Node) (res DerivationResult, err error) { + res.MasterNode = masterNode + + // Derive master key from the node + res.MasterKey, err = genKey(masterNode) + if err != nil { + return + } + + // Derive identity at m/44'/2046'/index'/0' + identityNode, err := masterNode.Derive(slip10.FirstHardenedIndex) + if err != nil { + return + } + res.Identity, err = genKey(identityNode) + + return +} + +// DeriveMasterNode derives a master node at the specified index +// Returns the node at path m/44'/2046'/index' +func (m Mnemonic) DeriveMasterNode(index uint32) (masterNode slip10.Node, err error) { + seed, err := m.Seed() + if err != nil { + return + } + + prefixNode, err := slip10.DeriveForPath(anytypeAccountNewPrefix, seed) + if err != nil { + return + } + + // m/44'/2046'/index' + masterNode, err = prefixNode.Derive(slip10.FirstHardenedIndex + index) + return +} + +// DeriveMasterNodeFromSeed derives a master node from a seed at the specified index +// This creates a node at path m/44'/2046'/index' +func DeriveMasterNodeFromSeed(seed []byte, index uint32) (masterNode slip10.Node, err error) { + prefixNode, err := slip10.DeriveForPath(anytypeAccountNewPrefix, seed) + if err != nil { + return + } + + // m/44'/2046'/index' + masterNode, err = prefixNode.Derive(slip10.FirstHardenedIndex + index) + return +} + type MnemonicGenerator struct { mnemonic string } @@ -112,31 +164,40 @@ func (m Mnemonic) deriveForPath(onlyMaster bool, index uint32, path string) (res if err != nil { return } - res.MasterKey, err = genKey(res.MasterNode) - if err != nil || onlyMaster { + + if onlyMaster { + // Only derive the master key + res.MasterKey, err = genKey(res.MasterNode) return } - // m/44'/code'/index'/0' - identityNode, err := res.MasterNode.Derive(slip10.FirstHardenedIndex) - if err != nil { - return - } - res.Identity, err = genKey(identityNode) - return + + // Use the public method to derive both master key and identity + return DeriveKeysFromMasterNode(res.MasterNode) } func (m Mnemonic) DeriveKeys(index uint32) (res DerivationResult, err error) { + // Derive old account key for backward compatibility oldRes, err := m.deriveForPath(true, index, anytypeAccountOldPrefix) if err != nil { return } - res, err = m.deriveForPath(false, index, anytypeAccountNewPrefix) + + // Derive master node using the new public method + masterNode, err := m.DeriveMasterNode(index) + if err != nil { + return + } + + // Derive keys from master node using the public method + res, err = DeriveKeysFromMasterNode(masterNode) if err != nil { return } + + // Add old account key for backward compatibility res.OldAccountKey = oldRes.MasterKey - // now derive ethereum key + // Derive ethereum key pk, err := m.ethereumKeyFromMnemonic(index, defaultEthereumDerivation) if err != nil { return diff --git a/util/crypto/mnemonic_test.go b/util/crypto/mnemonic_test.go index 95af7c42b..6856183bc 100644 --- a/util/crypto/mnemonic_test.go +++ b/util/crypto/mnemonic_test.go @@ -137,3 +137,186 @@ func TestMnemonic_ethereumKeyFromMnemonic(t *testing.T) { pkStr = Encode(bytes)[2:] require.Equal(t, "b31048b0aa87649bdb9016c0ee28c788ddfc45e52cd71cc0da08c47cb4390ae7", pkStr) } + +func TestDeriveMasterNode(t *testing.T) { + phrase, err := NewMnemonicGenerator().WithWordCount(12) + require.NoError(t, err) + + // Test deriving master node for index 0 + masterNode0, err := phrase.DeriveMasterNode(0) + require.NoError(t, err) + require.NotNil(t, masterNode0) + + // Test deriving master node for index 1 + masterNode1, err := phrase.DeriveMasterNode(1) + require.NoError(t, err) + require.NotNil(t, masterNode1) + + // Verify that different indices produce different nodes + raw0, err := masterNode0.RawSeed(), nil + require.NoError(t, err) + raw1, err := masterNode1.RawSeed(), nil + require.NoError(t, err) + require.NotEqual(t, raw0, raw1) +} + +func TestDeriveKeysFromMasterNode(t *testing.T) { + phrase, err := NewMnemonicGenerator().WithWordCount(12) + require.NoError(t, err) + + // Get master node + masterNode, err := phrase.DeriveMasterNode(0) + require.NoError(t, err) + + // Derive keys from master node + result, err := DeriveKeysFromMasterNode(masterNode) + require.NoError(t, err) + require.NotNil(t, result.MasterKey) + require.NotNil(t, result.Identity) + + // Verify the keys can sign and verify + testData := []byte("test data for signing") + + // Test master key + masterSig, err := result.MasterKey.Sign(testData) + require.NoError(t, err) + verified, err := result.MasterKey.GetPublic().Verify(testData, masterSig) + require.NoError(t, err) + require.True(t, verified) + + // Test identity key + identitySig, err := result.Identity.Sign(testData) + require.NoError(t, err) + verified, err = result.Identity.GetPublic().Verify(testData, identitySig) + require.NoError(t, err) + require.True(t, verified) +} + +func TestMasterNodeDerivationConsistency(t *testing.T) { + // Use a fixed mnemonic for consistency test + var phrase Mnemonic = "tag volcano eight thank tide danger coast health above argue embrace heavy" + + // Derive using the traditional method + traditionalResult, err := phrase.DeriveKeys(0) + require.NoError(t, err) + + // Derive using the new master node method + masterNode, err := phrase.DeriveMasterNode(0) + require.NoError(t, err) + newMethodResult, err := DeriveKeysFromMasterNode(masterNode) + require.NoError(t, err) + + // Verify that both methods produce the same master key + require.True(t, traditionalResult.MasterKey.Equals(newMethodResult.MasterKey)) + + // Verify that both methods produce the same identity + require.True(t, traditionalResult.Identity.Equals(newMethodResult.Identity)) +} + +func TestBackwardCompatibility(t *testing.T) { + // Test that existing functionality still works + phrase, err := NewMnemonicGenerator().WithWordCount(12) + require.NoError(t, err) + + // Test traditional DeriveKeys method + result, err := phrase.DeriveKeys(0) + require.NoError(t, err) + require.NotNil(t, result.MasterKey) + require.NotNil(t, result.Identity) + require.NotNil(t, result.OldAccountKey) + require.NotNil(t, result.MasterNode) + + // Verify Ethereum identity is still derived + publicKey := result.EthereumIdentity.Public() + _, ok := publicKey.(*ecdsa.PublicKey) + require.True(t, ok) +} + +func TestMasterNodeSerialization(t *testing.T) { + // Generate a test mnemonic + phrase, err := NewMnemonicGenerator().WithWordCount(12) + require.NoError(t, err) + + // Derive a master node + originalNode, err := phrase.DeriveMasterNode(0) + require.NoError(t, err) + + // Serialize the node using slip10's MarshalBinary + serialized, err := originalNode.MarshalBinary() + require.NoError(t, err) + require.Len(t, serialized, 64) // Should be exactly 64 bytes + + // Deserialize the node using slip10's UnmarshalNode + deserializedNode, err := slip10.UnmarshalNode(serialized) + require.NoError(t, err) + + // Verify the deserialized node produces the same keys + originalResult, err := DeriveKeysFromMasterNode(originalNode) + require.NoError(t, err) + + deserializedResult, err := DeriveKeysFromMasterNode(deserializedNode) + require.NoError(t, err) + + // Compare master keys + require.True(t, originalResult.MasterKey.Equals(deserializedResult.MasterKey)) + + // Compare identity keys + require.True(t, originalResult.Identity.Equals(deserializedResult.Identity)) + + // Verify the deserialized node can still derive child keys + childNode, err := deserializedNode.Derive(slip10.FirstHardenedIndex + 1) + require.NoError(t, err) + require.NotNil(t, childNode) +} + +func TestMasterNodeSerializationConsistency(t *testing.T) { + // Use a fixed mnemonic for consistency + var phrase Mnemonic = "tag volcano eight thank tide danger coast health above argue embrace heavy" + + // Derive master node at index 0 + node0, err := phrase.DeriveMasterNode(0) + require.NoError(t, err) + + // Serialize and deserialize using slip10 methods + serialized0, err := node0.MarshalBinary() + require.NoError(t, err) + + deserialized0, err := slip10.UnmarshalNode(serialized0) + require.NoError(t, err) + + // Derive a child from both original and deserialized + originalChild, err := node0.Derive(slip10.FirstHardenedIndex) + require.NoError(t, err) + + deserializedChild, err := deserialized0.Derive(slip10.FirstHardenedIndex) + require.NoError(t, err) + + // Verify both children produce the same key + originalKey, err := genKey(originalChild) + require.NoError(t, err) + + deserializedKey, err := genKey(deserializedChild) + require.NoError(t, err) + + require.True(t, originalKey.Equals(deserializedKey)) +} + +func TestInvalidSerialization(t *testing.T) { + // Test with invalid data lengths + testCases := []struct { + name string + data []byte + }{ + {"empty", []byte{}}, + {"too short", make([]byte, 32)}, + {"too long", make([]byte, 128)}, + {"almost correct", make([]byte, 63)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := slip10.UnmarshalNode(tc.data) + require.Error(t, err) + }) + } +}