Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

shake: add validity check #285

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 61 additions & 10 deletions sha3/sha3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ var testDigests = map[string]func() hash.Hash{
// testShakes contains functions that return sha3.ShakeHash instances for
// with output-length equal to the KAT length.
var testShakes = map[string]struct {
constructor func(N []byte, S []byte) ShakeHash
constructor func(N []byte, S []byte) (ShakeHash, error)
defAlgoName string
defCustomStr string
}{
Expand Down Expand Up @@ -136,7 +136,10 @@ func TestKeccakKats(t *testing.T) {
if err != nil {
t.Errorf("error decoding KAT: %s", err)
}
d := v.constructor(N, S)
d, err := v.constructor(N, S)
if err != nil {
t.Errorf("error creating cSHAKE: %s", err)
}
in, err := hex.DecodeString(kat.Message)
if err != nil {
t.Errorf("error decoding KAT: %s", err)
Expand Down Expand Up @@ -221,7 +224,10 @@ func TestUnalignedWrite(t *testing.T) {
for alg, df := range testShakes {
want := make([]byte, 16)
got := make([]byte, 16)
d := df.constructor([]byte(df.defAlgoName), []byte(df.defCustomStr))
d, err := df.constructor([]byte(df.defAlgoName), []byte(df.defCustomStr))
if err != nil {
t.Errorf("error creating cSHAKE: %s", err)
}

d.Reset()
d.Write(buf)
Expand Down Expand Up @@ -286,12 +292,19 @@ func TestAppendNoRealloc(t *testing.T) {
func TestSqueezing(t *testing.T) {
testUnalignedAndGeneric(t, func(impl string) {
for algo, v := range testShakes {
d0 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr))
d0, err := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr))
if err != nil {
t.Errorf("error creating cSHAKE: %s", err)
}

d0.Write([]byte(testString))
ref := make([]byte, 32)
d0.Read(ref)

d1 := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr))
d1, err := v.constructor([]byte(v.defAlgoName), []byte(v.defCustomStr))
if err != nil {
t.Errorf("error creating cSHAKE: %s", err)
}
d1.Write([]byte(testString))
var multiple []byte
for range ref {
Expand Down Expand Up @@ -327,7 +340,10 @@ func TestReset(t *testing.T) {

for _, v := range testShakes {
// Calculate hash for the first time
c := v.constructor(nil, []byte{0x99, 0x98})
c, err := v.constructor(nil, []byte{0x99, 0x98})
if err != nil {
t.Errorf("error creating cSHAKE: %s", err)
}
c.Write(sequentialBytes(0x100))
c.Read(out1)

Expand All @@ -350,7 +366,10 @@ func TestClone(t *testing.T) {
for _, size := range []int{0x1, 0x100} {
in := sequentialBytes(size)
for _, v := range testShakes {
h1 := v.constructor(nil, []byte{0x01})
h1, err := v.constructor(nil, []byte{0x01})
if err != nil {
t.Errorf("error creating cSHAKE: %s", err)
}
h1.Write([]byte{0x01})

h2 := h1.Clone()
Expand All @@ -368,6 +387,26 @@ func TestClone(t *testing.T) {
}
}

// TestValidity tests the length validity checks for cSHAKE.
func TestValidity(t *testing.T) {
inValidBytes := make([]byte, 256)

for _, v := range testShakes {
_, err := v.constructor(nil, inValidBytes)
if err == nil {
t.Error("expected error for S length")
}
_, err = v.constructor(inValidBytes, nil)
if err == nil {
t.Error("expected error for N length")
}
_, err = v.constructor(inValidBytes, inValidBytes)
if err == nil {
t.Error("expected error for N and S length")
}
}
}

// BenchmarkPermutationFunction measures the speed of the permutation function
// with no input data.
func BenchmarkPermutationFunction(b *testing.B) {
Expand Down Expand Up @@ -460,20 +499,32 @@ func ExampleNewCShake256() {
msg := []byte("The quick brown fox jumps over the lazy dog")

// Example 1: Simple cshake
c1 := NewCShake256([]byte("NAME"), []byte("Partition1"))
c1, err := NewCShake256([]byte("NAME"), []byte("Partition1"))
if err != nil {
fmt.Println(err)
return
}
c1.Write(msg)
c1.Read(out)
fmt.Println(hex.EncodeToString(out))

// Example 2: Different customization string produces different digest
c1 = NewCShake256([]byte("NAME"), []byte("Partition2"))
c1, err = NewCShake256([]byte("NAME"), []byte("Partition2"))
if err != nil {
fmt.Println(err)
return
}
c1.Write(msg)
c1.Read(out)
fmt.Println(hex.EncodeToString(out))

// Example 3: Longer output length produces longer digest
out = make([]byte, 64)
c1 = NewCShake256([]byte("NAME"), []byte("Partition1"))
c1, err = NewCShake256([]byte("NAME"), []byte("Partition1"))
if err != nil {
fmt.Println(err)
return
}
c1.Write(msg)
c1.Read(out)
fmt.Println(hex.EncodeToString(out))
Expand Down
19 changes: 13 additions & 6 deletions sha3/shake.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package sha3

import (
"encoding/binary"
"errors"
"hash"
"io"
)
Expand Down Expand Up @@ -80,7 +81,11 @@ func leftEncode(value uint64) []byte {
return b[i-1:]
}

func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) ShakeHash {
func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) (ShakeHash, error) {
if len(N) >= 256 || len(S) >= 256 {
return nil, errors.New("crypto/cSHAKE: N and S can be at most 255 bytes long")
}

c := cshakeState{state: &state{rate: rate, outputLen: outputLen, dsbyte: dsbyte}}

// leftEncode returns max 9 bytes
Expand All @@ -90,7 +95,7 @@ func newCShake(N, S []byte, rate, outputLen int, dsbyte byte) ShakeHash {
c.initBlock = append(c.initBlock, leftEncode(uint64(len(S)*8))...)
c.initBlock = append(c.initBlock, S...)
c.Write(bytepad(c.initBlock, c.rate))
return &c
return &c, nil
}

// Reset resets the hash to initial state.
Expand Down Expand Up @@ -137,9 +142,10 @@ func NewShake256() ShakeHash {
// desired. S is a customization byte string used for domain separation - two cSHAKE
// computations on same input with different S yield unrelated outputs.
// When N and S are both empty, this is equivalent to NewShake128.
func NewCShake128(N, S []byte) ShakeHash {
// N and S can be at most 255 bytes long.
func NewCShake128(N, S []byte) (ShakeHash, error) {
if len(N) == 0 && len(S) == 0 {
return NewShake128()
return NewShake128(), nil
}
return newCShake(N, S, rate128, 32, dsbyteCShake)
}
Expand All @@ -150,9 +156,10 @@ func NewCShake128(N, S []byte) ShakeHash {
// desired. S is a customization byte string used for domain separation - two cSHAKE
// computations on same input with different S yield unrelated outputs.
// When N and S are both empty, this is equivalent to NewShake256.
func NewCShake256(N, S []byte) ShakeHash {
// N and S can be at most 255 bytes long.
func NewCShake256(N, S []byte) (ShakeHash, error) {
if len(N) == 0 && len(S) == 0 {
return NewShake256()
return NewShake256(), nil
}
return newCShake(N, S, rate256, 64, dsbyteCShake)
}
Expand Down