diff --git a/sha3/sha3_test.go b/sha3/sha3_test.go index 83bd6195d6..e155b5c420 100644 --- a/sha3/sha3_test.go +++ b/sha3/sha3_test.go @@ -43,7 +43,7 @@ var testDigests = map[string]func() hash.Hash{ // testShakes contains functions that return sha3.ShakeHash instances for // with output-length equal to the KAT length. var testShakes = map[string]struct { - constructor func(N []byte, S []byte) ShakeHash + constructor func(N []byte, S []byte) (ShakeHash, error) defAlgoName string defCustomStr string }{ @@ -136,7 +136,10 @@ func TestKeccakKats(t *testing.T) { if err != nil { t.Errorf("error decoding KAT: %s", err) } - d := v.constructor(N, S) + d, err := v.constructor(N, S) + if err != nil { + t.Errorf("error creating cSHAKE: %s", err) + } in, err := hex.DecodeString(kat.Message) if err != nil { t.Errorf("error decoding KAT: %s", err) @@ -221,7 +224,10 @@ func TestUnalignedWrite(t *testing.T) { for alg, df := range testShakes { want := make([]byte, 16) got := make([]byte, 16) - d := df.constructor([]byte(df.defAlgoName), []byte(df.defCustomStr)) + d, err := df.constructor([]byte(df.defAlgoName), []byte(df.defCustomStr)) + if err != nil { + t.Errorf("error creating cSHAKE: %s", err) + } d.Reset() d.Write(buf) @@ -286,12 +292,19 @@ func TestAppendNoRealloc(t *testing.T) { func TestSqueezing(t *testing.T) { testUnalignedAndGeneric(t, func(impl string) { for algo, v := range testShakes { - d0 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) + d0, err := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) + if err != nil { + t.Errorf("error creating cSHAKE: %s", err) + } + d0.Write([]byte(testString)) ref := make([]byte, 32) d0.Read(ref) - d1 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) + d1, err := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr)) + if err != nil { + t.Errorf("error creating cSHAKE: %s", err) + } d1.Write([]byte(testString)) var multiple []byte for range ref { @@ -327,7 +340,10 @@ func TestReset(t *testing.T) { for _, v := range testShakes { // Calculate hash for the first time - c := v.constructor(nil, []byte{0x99, 0x98}) + c, err := v.constructor(nil, []byte{0x99, 0x98}) + if err != nil { + t.Errorf("error creating cSHAKE: %s", err) + } c.Write(sequentialBytes(0x100)) c.Read(out1) @@ -350,7 +366,10 @@ func TestClone(t *testing.T) { for _, size := range []int{0x1, 0x100} { in := sequentialBytes(size) for _, v := range testShakes { - h1 := v.constructor(nil, []byte{0x01}) + h1, err := v.constructor(nil, []byte{0x01}) + if err != nil { + t.Errorf("error creating cSHAKE: %s", err) + } h1.Write([]byte{0x01}) h2 := h1.Clone() @@ -368,6 +387,26 @@ func TestClone(t *testing.T) { } } +// TestValidity tests the length validity checks for cSHAKE. +func TestValidity(t *testing.T) { + inValidBytes := make([]byte, 256) + + for _, v := range testShakes { + _, err := v.constructor(nil, inValidBytes) + if err == nil { + t.Error("expected error for S length") + } + _, err = v.constructor(inValidBytes, nil) + if err == nil { + t.Error("expected error for N length") + } + _, err = v.constructor(inValidBytes, inValidBytes) + if err == nil { + t.Error("expected error for N and S length") + } + } +} + // BenchmarkPermutationFunction measures the speed of the permutation function // with no input data. func BenchmarkPermutationFunction(b *testing.B) { @@ -460,20 +499,32 @@ func ExampleNewCShake256() { msg := []byte("The quick brown fox jumps over the lazy dog") // Example 1: Simple cshake - c1 := NewCShake256([]byte("NAME"), []byte("Partition1")) + c1, err := NewCShake256([]byte("NAME"), []byte("Partition1")) + if err != nil { + fmt.Println(err) + return + } c1.Write(msg) c1.Read(out) fmt.Println(hex.EncodeToString(out)) // Example 2: Different customization string produces different digest - c1 = NewCShake256([]byte("NAME"), []byte("Partition2")) + c1, err = NewCShake256([]byte("NAME"), []byte("Partition2")) + if err != nil { + fmt.Println(err) + return + } c1.Write(msg) c1.Read(out) fmt.Println(hex.EncodeToString(out)) // Example 3: Longer output length produces longer digest out = make([]byte, 64) - c1 = NewCShake256([]byte("NAME"), []byte("Partition1")) + c1, err = NewCShake256([]byte("NAME"), []byte("Partition1")) + if err != nil { + fmt.Println(err) + return + } c1.Write(msg) c1.Read(out) fmt.Println(hex.EncodeToString(out)) diff --git a/sha3/shake.go b/sha3/shake.go index bb69984027..cc1d23df53 100644 --- a/sha3/shake.go +++ b/sha3/shake.go @@ -17,6 +17,7 @@ package sha3 import ( "encoding/binary" + "errors" "hash" "io" ) @@ -80,7 +81,11 @@ func leftEncode(value uint64) []byte { return b[i-1:] } -func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) ShakeHash { +func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) (ShakeHash, error) { + if len(N) >= 256 || len(S) >= 256 { + return nil, errors.New("crypto/cSHAKE: N and S can be at most 255 bytes long") + } + c := cshakeState{state: &state{rate: rate, outputLen: outputLen, dsbyte: dsbyte}} // leftEncode returns max 9 bytes @@ -90,7 +95,7 @@ func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) ShakeHash { c.initBlock = append(c.initBlock, leftEncode(uint64(len(S)*8))...) c.initBlock = append(c.initBlock, S...) c.Write(bytepad(c.initBlock, c.rate)) - return &c + return &c, nil } // Reset resets the hash to initial state. @@ -137,9 +142,10 @@ func NewShake256() ShakeHash { // desired. S is a customization byte string used for domain separation - two cSHAKE // computations on same input with different S yield unrelated outputs. // When N and S are both empty, this is equivalent to NewShake128. -func NewCShake128(N, S []byte) ShakeHash { +// N and S can be at most 255 bytes long. +func NewCShake128(N, S []byte) (ShakeHash, error) { if len(N) == 0 && len(S) == 0 { - return NewShake128() + return NewShake128(), nil } return newCShake(N, S, rate128, 32, dsbyteCShake) } @@ -150,9 +156,10 @@ func NewCShake128(N, S []byte) ShakeHash { // desired. S is a customization byte string used for domain separation - two cSHAKE // computations on same input with different S yield unrelated outputs. // When N and S are both empty, this is equivalent to NewShake256. -func NewCShake256(N, S []byte) ShakeHash { +// N and S can be at most 255 bytes long. +func NewCShake256(N, S []byte) (ShakeHash, error) { if len(N) == 0 && len(S) == 0 { - return NewShake256() + return NewShake256(), nil } return newCShake(N, S, rate256, 64, dsbyteCShake) }