github.com/cloudflare/circl@v1.5.0/expander/expander_test.go (about)

     1  package expander_test
     2  
     3  import (
     4  	"bytes"
     5  	"crypto"
     6  	_ "crypto/sha256"
     7  	_ "crypto/sha512"
     8  	"encoding/hex"
     9  	"encoding/json"
    10  	"fmt"
    11  	"os"
    12  	"path/filepath"
    13  	"strconv"
    14  	"testing"
    15  
    16  	"github.com/cloudflare/circl/expander"
    17  	"github.com/cloudflare/circl/internal/test"
    18  	"github.com/cloudflare/circl/xof"
    19  )
    20  
    21  func TestExpander(t *testing.T) {
    22  	fileNames, err := filepath.Glob("./testdata/*.json")
    23  	if err != nil {
    24  		t.Fatal(err)
    25  	}
    26  
    27  	for _, fileName := range fileNames {
    28  		f, err := os.Open(fileName)
    29  		if err != nil {
    30  			t.Fatal(err)
    31  		}
    32  		dec := json.NewDecoder(f)
    33  		var v vectorExpanderSuite
    34  		err = dec.Decode(&v)
    35  		if err != nil {
    36  			t.Fatal(err)
    37  		}
    38  		f.Close()
    39  
    40  		t.Run(v.Name+"/"+v.Hash, func(t *testing.T) { testExpander(t, &v) })
    41  	}
    42  }
    43  
    44  func testExpander(t *testing.T, vs *vectorExpanderSuite) {
    45  	var exp expander.Expander
    46  	switch vs.Hash {
    47  	case "SHA256":
    48  		exp = expander.NewExpanderMD(crypto.SHA256, []byte(vs.DST))
    49  	case "SHA512":
    50  		exp = expander.NewExpanderMD(crypto.SHA512, []byte(vs.DST))
    51  	case "SHAKE128":
    52  		exp = expander.NewExpanderXOF(xof.SHAKE128, vs.K, []byte(vs.DST))
    53  	case "SHAKE256":
    54  		exp = expander.NewExpanderXOF(xof.SHAKE256, vs.K, []byte(vs.DST))
    55  	default:
    56  		t.Skip("hash not supported: " + vs.Hash)
    57  	}
    58  
    59  	for i, v := range vs.Tests {
    60  		lenBytes, err := strconv.ParseUint(v.Len, 0, 64)
    61  		if err != nil {
    62  			t.Fatal(err)
    63  		}
    64  
    65  		got := exp.Expand([]byte(v.Msg), uint(lenBytes))
    66  		want, err := hex.DecodeString(v.UniformBytes)
    67  		if err != nil {
    68  			t.Fatal(err)
    69  		}
    70  
    71  		if !bytes.Equal(got, want) {
    72  			test.ReportError(t, got, want, i)
    73  		}
    74  	}
    75  }
    76  
    77  type vectorExpanderSuite struct {
    78  	DST   string `json:"DST"`
    79  	Hash  string `json:"hash"`
    80  	Name  string `json:"name"`
    81  	K     uint   `json:"k"`
    82  	Tests []struct {
    83  		DstPrime     string `json:"DST_prime"`
    84  		Len          string `json:"len_in_bytes"`
    85  		Msg          string `json:"msg"`
    86  		MsgPrime     string `json:"msg_prime"`
    87  		UniformBytes string `json:"uniform_bytes"`
    88  	} `json:"tests"`
    89  }
    90  
    91  func BenchmarkExpander(b *testing.B) {
    92  	in := []byte("input")
    93  	dst := []byte("dst")
    94  
    95  	for _, v := range []struct {
    96  		Name string
    97  		Exp  expander.Expander
    98  	}{
    99  		{"XMD", expander.NewExpanderMD(crypto.SHA256, dst)},
   100  		{"XOF", expander.NewExpanderXOF(xof.SHAKE128, 0, dst)},
   101  	} {
   102  		exp := v.Exp
   103  		for l := 8; l <= 10; l++ {
   104  			max := int64(1) << uint(l)
   105  
   106  			b.Run(fmt.Sprintf("%v/%v", v.Name, max), func(b *testing.B) {
   107  				b.SetBytes(max)
   108  				b.ResetTimer()
   109  				for i := 0; i < b.N; i++ {
   110  					exp.Expand(in, uint(max))
   111  				}
   112  			})
   113  		}
   114  	}
   115  }