github.com/cloudflare/circl@v1.5.0/oprf/oprf_test.go (about)

     1  package oprf
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"encoding"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"testing"
    10  
    11  	"github.com/cloudflare/circl/group"
    12  	"github.com/cloudflare/circl/internal/test"
    13  )
    14  
    15  type commonClient interface {
    16  	blind(inputs [][]byte, blinds []Blind) (*FinalizeData, *EvaluationRequest, error)
    17  	DeterministicBlind(inputs [][]byte, blinds []Blind) (*FinalizeData, *EvaluationRequest, error)
    18  	Blind(inputs [][]byte) (*FinalizeData, *EvaluationRequest, error)
    19  	Finalize(d *FinalizeData, e *Evaluation) ([][]byte, error)
    20  }
    21  
    22  type c1 struct {
    23  	PartialObliviousClient
    24  	info []byte
    25  }
    26  
    27  func (c *c1) Finalize(f *FinalizeData, e *Evaluation) ([][]byte, error) {
    28  	return c.PartialObliviousClient.Finalize(f, e, c.info)
    29  }
    30  
    31  type commonServer interface {
    32  	Evaluate(req *EvaluationRequest) (*Evaluation, error)
    33  	FullEvaluate(input []byte) ([]byte, error)
    34  	VerifyFinalize(input, expectedOutput []byte) bool
    35  	PublicKey() *PublicKey
    36  }
    37  
    38  type s1 struct {
    39  	PartialObliviousServer
    40  	info []byte
    41  }
    42  
    43  func (s *s1) Evaluate(req *EvaluationRequest) (*Evaluation, error) {
    44  	return s.PartialObliviousServer.Evaluate(req, s.info)
    45  }
    46  
    47  func (s *s1) FullEvaluate(input []byte) ([]byte, error) {
    48  	return s.PartialObliviousServer.FullEvaluate(input, s.info)
    49  }
    50  
    51  func (s *s1) VerifyFinalize(input, expectedOutput []byte) bool {
    52  	return s.PartialObliviousServer.VerifyFinalize(input, s.info, expectedOutput)
    53  }
    54  
    55  type canMarshal interface {
    56  	encoding.BinaryMarshaler
    57  	UnmarshalBinary(id Suite, data []byte) (err error)
    58  }
    59  
    60  func testMarshal(t *testing.T, suite Suite, x, y canMarshal, name string) {
    61  	t.Helper()
    62  
    63  	wantBytes, err := x.MarshalBinary()
    64  	test.CheckNoErr(t, err, "error on marshaling "+name)
    65  
    66  	err = y.UnmarshalBinary(suite, wantBytes)
    67  	test.CheckNoErr(t, err, "error on unmarshaling "+name)
    68  
    69  	gotBytes, err := x.MarshalBinary()
    70  	test.CheckNoErr(t, err, "error on marshaling "+name)
    71  
    72  	if !bytes.Equal(gotBytes, wantBytes) {
    73  		test.ReportError(t, gotBytes, wantBytes)
    74  	}
    75  }
    76  
    77  func elementsMarshalBinary(g group.Group, e []group.Element) ([]byte, error) {
    78  	output := make([]byte, 2, 2+len(e)*int(g.Params().CompressedElementLength))
    79  	binary.BigEndian.PutUint16(output[0:2], uint16(len(e)))
    80  
    81  	for i := range e {
    82  		ei, err := e[i].MarshalBinaryCompress()
    83  		if err != nil {
    84  			return nil, err
    85  		}
    86  		output = append(output, ei...)
    87  	}
    88  
    89  	return output, nil
    90  }
    91  
    92  func testAPI(t *testing.T, server commonServer, client commonClient) {
    93  	t.Helper()
    94  
    95  	inputs := [][]byte{{0x00}, {0xFF}}
    96  	finData, evalReq, err := client.Blind(inputs)
    97  	test.CheckNoErr(t, err, "invalid blinding of client")
    98  
    99  	blinds := finData.CopyBlinds()
   100  	_, detEvalReq, err := client.DeterministicBlind(inputs, blinds)
   101  	test.CheckNoErr(t, err, "invalid deterministic blinding of client")
   102  	test.CheckOk(len(detEvalReq.Elements) == len(evalReq.Elements), "invalid number of evaluations", t)
   103  	for i := range evalReq.Elements {
   104  		test.CheckOk(evalReq.Elements[i].IsEqual(detEvalReq.Elements[i]), "invalid blinded element mismatch", t)
   105  	}
   106  
   107  	eval, err := server.Evaluate(evalReq)
   108  	test.CheckNoErr(t, err, "invalid evaluation of server")
   109  	test.CheckOk(eval != nil, "invalid evaluation of server: no evaluation", t)
   110  
   111  	clientOutputs, err := client.Finalize(finData, eval)
   112  	test.CheckNoErr(t, err, "invalid finalize of client")
   113  	test.CheckOk(clientOutputs != nil, "invalid finalize of client: no outputs", t)
   114  
   115  	for i := range inputs {
   116  		valid := server.VerifyFinalize(inputs[i], clientOutputs[i])
   117  		test.CheckOk(valid, "invalid verification from the server", t)
   118  
   119  		serverOutput, err := server.FullEvaluate(inputs[i])
   120  		test.CheckNoErr(t, err, "FullEvaluate failed")
   121  
   122  		if !bytes.Equal(serverOutput, clientOutputs[i]) {
   123  			test.ReportError(t, serverOutput, clientOutputs[i])
   124  		}
   125  	}
   126  }
   127  
   128  func TestAPI(t *testing.T) {
   129  	info := []byte("shared info")
   130  
   131  	for _, suite := range []Suite{
   132  		SuiteRistretto255,
   133  		SuiteP256,
   134  		SuiteP384,
   135  		SuiteP521,
   136  	} {
   137  		t.Run(suite.(fmt.Stringer).String(), func(t *testing.T) {
   138  			private, err := GenerateKey(suite, rand.Reader)
   139  			test.CheckNoErr(t, err, "failed private key generation")
   140  			testMarshal(t, suite, private, new(PrivateKey), "PrivateKey")
   141  			public := private.Public()
   142  			testMarshal(t, suite, public, new(PublicKey), "PublicKey")
   143  
   144  			t.Run("OPRF", func(t *testing.T) {
   145  				s := NewServer(suite, private)
   146  				c := NewClient(suite)
   147  				testAPI(t, s, c)
   148  			})
   149  
   150  			t.Run("VOPRF", func(t *testing.T) {
   151  				s := NewVerifiableServer(suite, private)
   152  				c := NewVerifiableClient(suite, s.PublicKey())
   153  				testAPI(t, s, c)
   154  			})
   155  
   156  			t.Run("POPRF", func(t *testing.T) {
   157  				s := &s1{NewPartialObliviousServer(suite, private), info}
   158  				c := &c1{NewPartialObliviousClient(suite, s.PublicKey()), info}
   159  				testAPI(t, s, c)
   160  			})
   161  		})
   162  	}
   163  }
   164  
   165  func TestErrors(t *testing.T) {
   166  	goodID := SuiteP256
   167  	strErrNil := "must be nil"
   168  	strErrK := "must fail key"
   169  	strErrC := "must fail client"
   170  	strErrS := "must fail server"
   171  
   172  	t.Run("badID", func(t *testing.T) {
   173  		var badID Suite
   174  
   175  		k, err := GenerateKey(badID, rand.Reader)
   176  		test.CheckIsErr(t, err, strErrK)
   177  		test.CheckOk(k == nil, strErrNil, t)
   178  
   179  		k, err = DeriveKey(badID, BaseMode, nil, nil)
   180  		test.CheckIsErr(t, err, strErrK)
   181  		test.CheckOk(k == nil, strErrNil, t)
   182  
   183  		err = new(PrivateKey).UnmarshalBinary(badID, nil)
   184  		test.CheckIsErr(t, err, strErrK)
   185  
   186  		err = new(PublicKey).UnmarshalBinary(badID, nil)
   187  		test.CheckIsErr(t, err, strErrK)
   188  
   189  		err = test.CheckPanic(func() { NewClient(badID) })
   190  		test.CheckNoErr(t, err, strErrC)
   191  
   192  		err = test.CheckPanic(func() { NewServer(badID, nil) })
   193  		test.CheckNoErr(t, err, strErrS)
   194  
   195  		err = test.CheckPanic(func() { NewVerifiableClient(badID, nil) })
   196  		test.CheckNoErr(t, err, strErrC)
   197  
   198  		err = test.CheckPanic(func() { NewVerifiableServer(badID, nil) })
   199  		test.CheckNoErr(t, err, strErrS)
   200  
   201  		err = test.CheckPanic(func() { NewPartialObliviousClient(badID, nil) })
   202  		test.CheckNoErr(t, err, strErrC)
   203  
   204  		err = test.CheckPanic(func() { NewPartialObliviousServer(badID, nil) })
   205  		test.CheckNoErr(t, err, strErrS)
   206  	})
   207  
   208  	t.Run("nilPubKey", func(t *testing.T) {
   209  		err := test.CheckPanic(func() { NewVerifiableClient(goodID, nil) })
   210  		test.CheckNoErr(t, err, strErrC)
   211  	})
   212  
   213  	t.Run("nilCalls", func(t *testing.T) {
   214  		c := NewClient(goodID)
   215  		finData, evalReq, err := c.Blind(nil)
   216  		test.CheckIsErr(t, err, strErrC)
   217  		test.CheckOk(finData == nil, strErrNil, t)
   218  		test.CheckOk(evalReq == nil, strErrNil, t)
   219  
   220  		var emptyEval Evaluation
   221  		finData, _, _ = c.Blind([][]byte{[]byte("in0"), []byte("in1")})
   222  		out, err := c.Finalize(finData, &emptyEval)
   223  		test.CheckIsErr(t, err, strErrC)
   224  		test.CheckOk(out == nil, strErrNil, t)
   225  	})
   226  
   227  	t.Run("invalidProof", func(t *testing.T) {
   228  		key, _ := GenerateKey(goodID, rand.Reader)
   229  		s := NewVerifiableServer(goodID, key)
   230  		c := NewVerifiableClient(goodID, key.Public())
   231  
   232  		finData, evalReq, _ := c.Blind([][]byte{[]byte("in0"), []byte("in1")})
   233  		_, _ = s.Evaluate(evalReq)
   234  		_, evalReq, _ = c.Blind([][]byte{[]byte("in0"), []byte("in2")})
   235  		badEV, _ := s.Evaluate(evalReq)
   236  		_, err := c.Finalize(finData, badEV)
   237  		test.CheckIsErr(t, err, strErrC)
   238  	})
   239  
   240  	t.Run("badKeyGen", func(t *testing.T) {
   241  		key, err := GenerateKey(goodID, nil)
   242  		test.CheckIsErr(t, err, strErrNil)
   243  		test.CheckOk(key == nil, strErrNil, t)
   244  
   245  		key, err = DeriveKey(goodID, Mode(8), nil, nil)
   246  		test.CheckIsErr(t, err, strErrK)
   247  		test.CheckOk(key == nil, strErrNil, t)
   248  	})
   249  }
   250  
   251  func Example_oprf() {
   252  	suite := SuiteP256
   253  	//                                  Server(sk, pk, info*)
   254  	private, _ := GenerateKey(suite, rand.Reader)
   255  	server := NewServer(suite, private)
   256  	//   Client(info*)
   257  	client := NewClient(suite)
   258  	//   =================================================================
   259  	//   finData, evalReq = Blind(input)
   260  	inputs := [][]byte{[]byte("first input"), []byte("second input")}
   261  	finData, evalReq, _ := client.Blind(inputs)
   262  	//
   263  	//                               evalReq
   264  	//                             ---------->
   265  	//
   266  	//                               evaluation = Evaluate(evalReq, info*)
   267  	evaluation, _ := server.Evaluate(evalReq)
   268  	//
   269  	//                              evaluation
   270  	//                             <----------
   271  	//
   272  	//   output = Finalize(finData, evaluation, info*)
   273  	outputs, err := client.Finalize(finData, evaluation)
   274  	fmt.Print(err == nil && len(inputs) == len(outputs))
   275  	// Output: true
   276  }
   277  
   278  func BenchmarkAPI(b *testing.B) {
   279  	for _, suite := range []Suite{
   280  		SuiteRistretto255,
   281  		SuiteP256,
   282  		SuiteP384,
   283  		SuiteP521,
   284  	} {
   285  		key, err := GenerateKey(suite, rand.Reader)
   286  		test.CheckNoErr(b, err, "failed key generation")
   287  
   288  		b.Run("OPRF/"+suite.Identifier(), func(b *testing.B) {
   289  			s := NewServer(suite, key)
   290  			c := NewClient(suite)
   291  			benchAPI(b, s, c)
   292  		})
   293  
   294  		b.Run("VOPRF/"+suite.Identifier(), func(b *testing.B) {
   295  			s := NewVerifiableServer(suite, key)
   296  			c := NewVerifiableClient(suite, s.PublicKey())
   297  			benchAPI(b, s, c)
   298  		})
   299  
   300  		b.Run("POPRF/"+suite.Identifier(), func(b *testing.B) {
   301  			info := []byte("shared info")
   302  			s := &s1{NewPartialObliviousServer(suite, key), info}
   303  			c := &c1{NewPartialObliviousClient(suite, s.PublicKey()), info}
   304  			benchAPI(b, s, c)
   305  		})
   306  	}
   307  }
   308  
   309  func benchAPI(b *testing.B, server commonServer, client commonClient) {
   310  	b.Helper()
   311  	inputs := [][]byte{[]byte("first input"), []byte("second input")}
   312  	finData, evalReq, err := client.Blind(inputs)
   313  	test.CheckNoErr(b, err, "failed client request")
   314  
   315  	eval, err := server.Evaluate(evalReq)
   316  	test.CheckNoErr(b, err, "failed server evaluate")
   317  
   318  	clientOutputs, err := client.Finalize(finData, eval)
   319  	test.CheckNoErr(b, err, "failed client finalize")
   320  
   321  	b.Run("Client/Request", func(b *testing.B) {
   322  		for i := 0; i < b.N; i++ {
   323  			_, _, _ = client.Blind(inputs)
   324  		}
   325  	})
   326  
   327  	b.Run("Server/Evaluate", func(b *testing.B) {
   328  		for i := 0; i < b.N; i++ {
   329  			_, _ = server.Evaluate(evalReq)
   330  		}
   331  	})
   332  
   333  	b.Run("Client/Finalize", func(b *testing.B) {
   334  		for i := 0; i < b.N; i++ {
   335  			_, _ = client.Finalize(finData, eval)
   336  		}
   337  	})
   338  
   339  	b.Run("Server/VerifyFinalize", func(b *testing.B) {
   340  		for i := 0; i < b.N; i++ {
   341  			for j := range inputs {
   342  				server.VerifyFinalize(inputs[j], clientOutputs[j])
   343  			}
   344  		}
   345  	})
   346  
   347  	b.Run("Server/FullEvaluate", func(b *testing.B) {
   348  		for i := 0; i < b.N; i++ {
   349  			for j := range inputs {
   350  				_, _ = server.FullEvaluate(inputs[j])
   351  			}
   352  		}
   353  	})
   354  }