github.com/cloudflare/circl@v1.5.0/oprf/vectors_test.go (about)

     1  package oprf
     2  
     3  import (
     4  	"bytes"
     5  	"encoding"
     6  	"encoding/binary"
     7  	"encoding/hex"
     8  	"encoding/json"
     9  	"fmt"
    10  	"io"
    11  	"os"
    12  	"strings"
    13  	"testing"
    14  
    15  	"github.com/cloudflare/circl/group"
    16  	"github.com/cloudflare/circl/internal/test"
    17  	"github.com/cloudflare/circl/zk/dleq"
    18  )
    19  
// vector is one entry of the JSON test-vector file decoded by readFile.
// Byte-valued fields hold hex strings; list-valued fields hold several hex
// strings separated by commas (decoded with toListBytes / flattenList).
type vector struct {
	// Identifier names the suite; it is passed to GetSuite.
	Identifier string `json:"identifier"`
	// Mode selects which server/client pair SetUpParties builds.
	Mode       Mode   `json:"mode"`
	// Hash is decoded but not read by the tests in this file.
	Hash       string `json:"hash"`
	// PkSm presumably holds the hex-encoded serialized server public key
	// (the pkSm value of the vector file) — NOTE(review): confirm usage.
	PkSm       string `json:"pkSm"`
	// SkSm is the hex-encoded expected private key; SetUpParties compares it
	// with the key derived from Seed and KeyInfo.
	SkSm       string `json:"skSm"`
	// Seed is the hex seed for DeriveKey; it must decode to 32 bytes.
	Seed       string `json:"seed"`
	// KeyInfo is the hex-encoded info passed to DeriveKey.
	KeyInfo    string `json:"keyInfo"`
	// GroupDST is decoded but not read by the tests in this file.
	GroupDST   string `json:"groupDST"`
	// Vectors lists the individual evaluations checked by (*vector).test.
	Vectors    []struct {
		// Batch is decoded but not read by the tests in this file.
		Batch             int    `json:"Batch"`
		// Blind is a comma-separated list of hex-encoded blind scalars.
		Blind             string `json:"Blind"`
		// Info is hex-encoded public info, used only in PartialObliviousMode.
		Info              string `json:"Info"`
		// BlindedElement is a comma-separated list of expected blinded elements.
		BlindedElement    string `json:"BlindedElement"`
		// EvaluationElement is a comma-separated list of expected evaluations.
		EvaluationElement string `json:"EvaluationElement"`
		// Proof holds the expected proof bytes and the scalar randomness R
		// used to regenerate it (verifiable and partially-oblivious modes).
		Proof             struct {
			Proof string `json:"proof"`
			R     string `json:"r"`
		} `json:"Proof"`
		// Input and Output are comma-separated lists of hex-encoded inputs
		// and their expected outputs.
		Input  string `json:"Input"`
		Output string `json:"Output"`
	} `json:"vectors"`
}
    43  
    44  func toBytes(t *testing.T, s, errMsg string) []byte {
    45  	t.Helper()
    46  	bytes, err := hex.DecodeString(s)
    47  	test.CheckNoErr(t, err, "decoding "+errMsg)
    48  
    49  	return bytes
    50  }
    51  
    52  func toListBytes(t *testing.T, s, errMsg string) [][]byte {
    53  	t.Helper()
    54  	strs := strings.Split(s, ",")
    55  	out := make([][]byte, len(strs))
    56  	for i := range strs {
    57  		out[i] = toBytes(t, strs[i], errMsg)
    58  	}
    59  
    60  	return out
    61  }
    62  
    63  func flattenList(t *testing.T, s, errMsg string) []byte {
    64  	t.Helper()
    65  	strs := strings.Split(s, ",")
    66  	out := []byte{0, 0}
    67  	binary.BigEndian.PutUint16(out, uint16(len(strs)))
    68  	for i := range strs {
    69  		out = append(out, toBytes(t, strs[i], errMsg)...)
    70  	}
    71  
    72  	return out
    73  }
    74  
    75  func toScalar(t *testing.T, g group.Group, s, errMsg string) group.Scalar {
    76  	t.Helper()
    77  	r := g.NewScalar()
    78  	rBytes := toBytes(t, s, errMsg)
    79  	err := r.UnmarshalBinary(rBytes)
    80  	test.CheckNoErr(t, err, errMsg)
    81  
    82  	return r
    83  }
    84  
    85  func readFile(t *testing.T, fileName string) []vector {
    86  	t.Helper()
    87  	jsonFile, err := os.Open(fileName)
    88  	if err != nil {
    89  		t.Fatalf("File %v can not be opened. Error: %v", fileName, err)
    90  	}
    91  	defer jsonFile.Close()
    92  	input, err := io.ReadAll(jsonFile)
    93  	if err != nil {
    94  		t.Fatalf("File %v can not be read. Error: %v", fileName, err)
    95  	}
    96  
    97  	var v []vector
    98  	err = json.Unmarshal(input, &v)
    99  	if err != nil {
   100  		t.Fatalf("File %v can not be loaded. Error: %v", fileName, err)
   101  	}
   102  
   103  	return v
   104  }
   105  
   106  func (v *vector) SetUpParties(t *testing.T) (id params, s commonServer, c commonClient) {
   107  	suite, err := GetSuite(v.Identifier)
   108  	test.CheckNoErr(t, err, "suite id")
   109  	seed := toBytes(t, v.Seed, "seed for key derivation")
   110  	test.CheckOk(len(seed) == 32, ErrInvalidSeed.Error(), t)
   111  	keyInfo := toBytes(t, v.KeyInfo, "info for key derivation")
   112  	privateKey, err := DeriveKey(suite, v.Mode, seed, keyInfo)
   113  	test.CheckNoErr(t, err, "deriving key")
   114  
   115  	got, err := privateKey.MarshalBinary()
   116  	test.CheckNoErr(t, err, "serializing private key")
   117  	want := toBytes(t, v.SkSm, "private key")
   118  	v.compareBytes(t, got, want)
   119  
   120  	switch v.Mode {
   121  	case BaseMode:
   122  		s = NewServer(suite, privateKey)
   123  		c = NewClient(suite)
   124  	case VerifiableMode:
   125  		s = NewVerifiableServer(suite, privateKey)
   126  		c = NewVerifiableClient(suite, s.PublicKey())
   127  	case PartialObliviousMode:
   128  		var info []byte
   129  		s = &s1{NewPartialObliviousServer(suite, privateKey), info}
   130  		c = &c1{NewPartialObliviousClient(suite, s.PublicKey()), info}
   131  	}
   132  
   133  	return suite.(params), s, c
   134  }
   135  
   136  func (v *vector) compareLists(t *testing.T, got, want [][]byte) {
   137  	t.Helper()
   138  	for i := range got {
   139  		if !bytes.Equal(got[i], want[i]) {
   140  			test.ReportError(t, got[i], want[i], v.Identifier, v.Mode, i)
   141  		}
   142  	}
   143  }
   144  
   145  func (v *vector) compareBytes(t *testing.T, got, want []byte) {
   146  	t.Helper()
   147  	if !bytes.Equal(got, want) {
   148  		test.ReportError(t, got, want, v.Identifier, v.Mode)
   149  	}
   150  }
   151  
// test runs every entry of v.Vectors against the parties built by
// SetUpParties. For each entry it checks, in order:
//   - the blinded elements the client produces from the vector's fixed blinds,
//   - the server's evaluated elements,
//   - in VerifiableMode and PartialObliviousMode, a DLEQ proof regenerated
//     with the vector's fixed randomness,
//   - the client's finalized outputs, and
//   - for each input, the server's FullEvaluate result and VerifyFinalize.
func (v *vector) test(t *testing.T) {
	params, server, client := v.SetUpParties(t)

	for i, vi := range v.Vectors {
		// The public info can differ per entry, so it is written into the
		// s1/c1 wrappers that SetUpParties created for this mode.
		if v.Mode == PartialObliviousMode {
			info := toBytes(t, vi.Info, "info")
			ss := server.(*s1)
			cc := client.(*c1)
			ss.info = info
			cc.info = info
		}

		inputs := toListBytes(t, vi.Input, "input")
		blindsBytes := toListBytes(t, vi.Blind, "blind")

		// Decode the vector's fixed blinds so that blinding is reproducible
		// and the blinded elements can be compared byte for byte.
		blinds := make([]Blind, len(blindsBytes))
		for j := range blindsBytes {
			blinds[j] = params.group.NewScalar()
			err := blinds[j].UnmarshalBinary(blindsBytes[j])
			test.CheckNoErr(t, err, "invalid blind")
		}

		// flattenList prepends a 2-byte big-endian item count to the expected
		// elements, presumably mirroring elementsMarshalBinary's format.
		finData, evalReq, err := client.blind(inputs, blinds)
		test.CheckNoErr(t, err, "invalid client request")
		evalReqBytes, err := elementsMarshalBinary(params.group, evalReq.Elements)
		test.CheckNoErr(t, err, "bad serialization")
		v.compareBytes(t, evalReqBytes, flattenList(t, vi.BlindedElement, "blindedElement"))

		eval, err := server.Evaluate(evalReq)
		test.CheckNoErr(t, err, "invalid evaluation")
		elemBytes, err := elementsMarshalBinary(params.group, eval.Elements)
		test.CheckNoErr(t, err, "invalid evaluations marshaling")
		v.compareBytes(t, elemBytes, flattenList(t, vi.EvaluationElement, "evaluation"))

		// Regenerate the DLEQ proof with the vector's randomness R so the
		// resulting proof bytes are deterministic and can be compared.
		if v.Mode == VerifiableMode || v.Mode == PartialObliviousMode {
			randomness := toScalar(t, params.group, vi.Proof.R, "invalid proof random scalar")
			var proof encoding.BinaryMarshaler
			switch v.Mode {
			case VerifiableMode:
				// Prove the server key relates blinded elements to evaluations.
				ss := server.(VerifiableServer)
				prover := dleq.Prover{Params: ss.getDLEQParams()}
				proof, err = prover.ProveBatchWithRandomness(
					ss.privateKey.k,
					ss.params.group.Generator(),
					server.PublicKey().e,
					evalReq.Elements,
					eval.Elements,
					randomness)
			case PartialObliviousMode:
				// The proving scalar comes from secretFromInfo for the
				// current info. The element lists are passed in the reverse
				// order of the verifiable case: evaluations first, then
				// blinded elements. The other two results of secretFromInfo
				// are discarded — NOTE(review): if the last one is an error,
				// it is silently ignored here.
				ss := server.(*s1)
				keyProof, _, _ := ss.secretFromInfo(ss.info)
				prover := dleq.Prover{Params: ss.getDLEQParams()}
				proof, err = prover.ProveBatchWithRandomness(
					keyProof,
					ss.params.group.Generator(),
					ss.params.group.NewElement().MulGen(keyProof),
					eval.Elements,
					evalReq.Elements,
					randomness)
			}
			test.CheckNoErr(t, err, "failed proof generation")
			proofBytes, errr := proof.MarshalBinary()
			test.CheckNoErr(t, errr, "failed proof marshaling")
			v.compareBytes(t, proofBytes, toBytes(t, vi.Proof.Proof, "proof"))
		}

		outputs, err := client.Finalize(finData, eval)
		test.CheckNoErr(t, err, "invalid finalize")
		expectedOutputs := toListBytes(t, vi.Output, "output")
		v.compareLists(t,
			outputs,
			expectedOutputs,
		)

		// The server's FullEvaluate of each input must reproduce the same
		// expected output, and VerifyFinalize must accept it. This indexes
		// expectedOutputs by input position, so it assumes the vector lists
		// one output per input.
		for j := range inputs {
			output, err := server.FullEvaluate(inputs[j])
			test.CheckNoErr(t, err, "invalid full evaluate")
			got := output
			want := expectedOutputs[j]
			if !bytes.Equal(got, want) {
				test.ReportError(t, got, want, v.Identifier, v.Mode, i, j)
			}

			test.CheckOk(server.VerifyFinalize(inputs[j], output), "verify finalize", t)
		}
	}
}
   239  
   240  func TestVectors(t *testing.T) {
   241  	// RFC-9497 published at https://www.rfc-editor.org/info/rfc9497
   242  	// Test vectors at https://github.com/cfrg/draft-irtf-cfrg-voprf
   243  	v := readFile(t, "testdata/rfc9497.json")
   244  
   245  	for i := range v {
   246  		suite, err := GetSuite(v[i].Identifier)
   247  		if err != nil {
   248  			t.Log(v[i].Identifier + " not supported yet")
   249  			continue
   250  		}
   251  		t.Run(fmt.Sprintf("%v/Mode%v", suite, v[i].Mode), v[i].test)
   252  	}
   253  }