diff --git a/go/tdh2/tdh2/tdh2.go b/go/tdh2/tdh2/tdh2.go index c56e3d3..f2d2b4c 100644 --- a/go/tdh2/tdh2/tdh2.go +++ b/go/tdh2/tdh2/tdh2.go @@ -14,13 +14,13 @@ import ( "github.com/smartcontractkit/tdh2/go/tdh2/lib/group/share" ) -var ( - // defaultHash is the default hash function used. Note, its output size - // determines the input size in TDH2. - defaultHash = sha256.New - // InputSize determines the size of messages and labels. - InputSize = defaultHash().Size() -) +// defaultHash is the default hash function used. Note, its output size +// determines the input size in TDH2. +var defaultHash = sha256.New + +// InputSize determines the size of messages and labels. +// It is fixed to 32 bytes (SHA-256 output size). +const InputSize = sha256.Size func parseGroup(group string) (group.Group, error) { switch group { @@ -426,6 +426,13 @@ func (a *Ciphertext) Equal(b *Ciphertext) bool { } +// Label returns a copy of the label associated with the ciphertext. +func (c *Ciphertext) Label() [InputSize]byte { + var label [InputSize]byte + copy(label[:], c.label) + return label +} + // Decrypt decrypts a ciphertext using a secret key share x_i according to TDH2 paper. // The caller has to ensure that the ciphertext is validated. func (ctxt *Ciphertext) Decrypt(group group.Group, x_i *PrivateShare, rand cipher.Stream) (*DecryptionShare, error) { diff --git a/go/tdh2/tdh2/tdh2_test.go b/go/tdh2/tdh2/tdh2_test.go index 8a5778c..c07008a 100644 --- a/go/tdh2/tdh2/tdh2_test.go +++ b/go/tdh2/tdh2/tdh2_test.go @@ -327,6 +327,10 @@ func TestEncrypt(t *testing.T) { } else if err != nil { return } + gotLabel := ctxt.Label() + if !bytes.Equal(gotLabel[:], label) { + t.Errorf("label mismatch got=%v want=%v", gotLabel, label) + } if diff := cmp.Diff(label, ctxt.label); diff != "" { t.Errorf("label/ctx.Label diff: %v", diff) } @@ -451,6 +455,19 @@ func TestCtxtVerify(t *testing.T) { }, err: cmpopts.AnyError, }, + { + name: "wrong Label, good length", + ctxt: &Ciphertext{ + group: group, + c: ctxt.c, + label: make([]byte, InputSize), + u: ctxt.u, + u_bar: ctxt.u_bar, + e: ctxt.e, + f: ctxt.f, + }, + err: cmpopts.AnyError, + }, { name: "broken U", ctxt: &Ciphertext{ diff --git a/go/tdh2/tdh2easy/tdh2easy.go b/go/tdh2/tdh2easy/tdh2easy.go index 6fe61bb..132b800 100644 --- a/go/tdh2/tdh2easy/tdh2easy.go +++ b/go/tdh2/tdh2easy/tdh2easy.go @@ -266,9 +266,14 @@ func Redeal(pk *PublicKey, ms *MasterSecret, k, n int) (*PublicKey, []*PrivateSh } // Encrypt generates a fresh symmetric key, encrypts and authenticates -// the message with it, and encrypts the key using TDH2. It returns a -// struct encoding the generated ciphertexts. +// the message with it, and encrypts the key using TDH2 with empty label. +// It returns a struct encoding the generated ciphertexts. func Encrypt(pk *PublicKey, msg []byte) (*Ciphertext, error) { + return EncryptWithLabel(pk, msg, [32]byte{}) +} + +// EncryptWithLabel is identical to Encrypt but allows passing a non-empty label. +func EncryptWithLabel(pk *PublicKey, msg []byte, label [tdh2.InputSize]byte) (*Ciphertext, error) { if aes256KeySize != tdh2.InputSize { return nil, fmt.Errorf("incorrect key size") } @@ -288,8 +293,8 @@ func Encrypt(pk *PublicKey, msg []byte) (*Ciphertext, error) { if err != nil { return nil, err } - // encrypt the key with TDH2 using empty label - tdh2Ctxt, err := tdh2.Encrypt(pk.p, key, make([]byte, tdh2.InputSize), r) + // encrypt the key with TDH2 using the provided label + tdh2Ctxt, err := tdh2.Encrypt(pk.p, key, label[:], r) if err != nil { return nil, fmt.Errorf("cannot TDH2 encrypt: %w", err) } @@ -299,3 +304,8 @@ func Encrypt(pk *PublicKey, msg []byte) (*Ciphertext, error) { nonce: nonce, }, nil } + +// Label returns a defensive copy of the ciphertext's TDH2 label. +func (c *Ciphertext) Label() [tdh2.InputSize]byte { + return c.tdh2Ctxt.Label() +} diff --git a/go/tdh2/tdh2easy/tdh2easy_test.go b/go/tdh2/tdh2easy/tdh2easy_test.go index c2507c4..07121ba 100644 --- a/go/tdh2/tdh2easy/tdh2easy_test.go +++ b/go/tdh2/tdh2easy/tdh2easy_test.go @@ -595,3 +595,30 @@ func FuzzCiphertextUnmarshal(f *testing.F) { } }) } + +// TestEncryptWithLabel ensures non-empty labels are preserved, and default Encrypt uses empty label. +func TestEncryptWithLabel(t *testing.T) { + _, pk, _, err := GenerateKeys(2, 3) + if err != nil { + t.Fatalf("GenerateKeys: %v", err) + } + var label [tdh2.InputSize]byte + for i := range label { + label[i] = byte(i + 1) + } + c, err := EncryptWithLabel(pk, []byte("msg"), label) + if err != nil { + t.Fatalf("EncryptWithLabel: %v", err) + } + if got := c.Label(); got != label { + t.Errorf("label mismatch got=%v want=%v", got, label) + } + // Ensure regular Encrypt produces all-zero label. + cZero, err := Encrypt(pk, []byte("msg")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + if got := cZero.Label(); got != [tdh2.InputSize]byte{} { + t.Errorf("expected zero label got=%v", got) + } +}