github.com/cloudflare/circl@v1.5.0/oprf/vectors_test.go (about) 1 package oprf 2 3 import ( 4 "bytes" 5 "encoding" 6 "encoding/binary" 7 "encoding/hex" 8 "encoding/json" 9 "fmt" 10 "io" 11 "os" 12 "strings" 13 "testing" 14 15 "github.com/cloudflare/circl/group" 16 "github.com/cloudflare/circl/internal/test" 17 "github.com/cloudflare/circl/zk/dleq" 18 ) 19 20 type vector struct { 21 Identifier string `json:"identifier"` 22 Mode Mode `json:"mode"` 23 Hash string `json:"hash"` 24 PkSm string `json:"pkSm"` 25 SkSm string `json:"skSm"` 26 Seed string `json:"seed"` 27 KeyInfo string `json:"keyInfo"` 28 GroupDST string `json:"groupDST"` 29 Vectors []struct { 30 Batch int `json:"Batch"` 31 Blind string `json:"Blind"` 32 Info string `json:"Info"` 33 BlindedElement string `json:"BlindedElement"` 34 EvaluationElement string `json:"EvaluationElement"` 35 Proof struct { 36 Proof string `json:"proof"` 37 R string `json:"r"` 38 } `json:"Proof"` 39 Input string `json:"Input"` 40 Output string `json:"Output"` 41 } `json:"vectors"` 42 } 43 44 func toBytes(t *testing.T, s, errMsg string) []byte { 45 t.Helper() 46 bytes, err := hex.DecodeString(s) 47 test.CheckNoErr(t, err, "decoding "+errMsg) 48 49 return bytes 50 } 51 52 func toListBytes(t *testing.T, s, errMsg string) [][]byte { 53 t.Helper() 54 strs := strings.Split(s, ",") 55 out := make([][]byte, len(strs)) 56 for i := range strs { 57 out[i] = toBytes(t, strs[i], errMsg) 58 } 59 60 return out 61 } 62 63 func flattenList(t *testing.T, s, errMsg string) []byte { 64 t.Helper() 65 strs := strings.Split(s, ",") 66 out := []byte{0, 0} 67 binary.BigEndian.PutUint16(out, uint16(len(strs))) 68 for i := range strs { 69 out = append(out, toBytes(t, strs[i], errMsg)...) 70 } 71 72 return out 73 } 74 75 func toScalar(t *testing.T, g group.Group, s, errMsg string) group.Scalar { 76 t.Helper() 77 r := g.NewScalar() 78 rBytes := toBytes(t, s, errMsg) 79 err := r.UnmarshalBinary(rBytes) 80 test.CheckNoErr(t, err, errMsg) 81 82 return r 83 } 84 85 func readFile(t *testing.T, fileName string) []vector { 86 t.Helper() 87 jsonFile, err := os.Open(fileName) 88 if err != nil { 89 t.Fatalf("File %v can not be opened. Error: %v", fileName, err) 90 } 91 defer jsonFile.Close() 92 input, err := io.ReadAll(jsonFile) 93 if err != nil { 94 t.Fatalf("File %v can not be read. Error: %v", fileName, err) 95 } 96 97 var v []vector 98 err = json.Unmarshal(input, &v) 99 if err != nil { 100 t.Fatalf("File %v can not be loaded. Error: %v", fileName, err) 101 } 102 103 return v 104 } 105 106 func (v *vector) SetUpParties(t *testing.T) (id params, s commonServer, c commonClient) { 107 suite, err := GetSuite(v.Identifier) 108 test.CheckNoErr(t, err, "suite id") 109 seed := toBytes(t, v.Seed, "seed for key derivation") 110 test.CheckOk(len(seed) == 32, ErrInvalidSeed.Error(), t) 111 keyInfo := toBytes(t, v.KeyInfo, "info for key derivation") 112 privateKey, err := DeriveKey(suite, v.Mode, seed, keyInfo) 113 test.CheckNoErr(t, err, "deriving key") 114 115 got, err := privateKey.MarshalBinary() 116 test.CheckNoErr(t, err, "serializing private key") 117 want := toBytes(t, v.SkSm, "private key") 118 v.compareBytes(t, got, want) 119 120 switch v.Mode { 121 case BaseMode: 122 s = NewServer(suite, privateKey) 123 c = NewClient(suite) 124 case VerifiableMode: 125 s = NewVerifiableServer(suite, privateKey) 126 c = NewVerifiableClient(suite, s.PublicKey()) 127 case PartialObliviousMode: 128 var info []byte 129 s = &s1{NewPartialObliviousServer(suite, privateKey), info} 130 c = &c1{NewPartialObliviousClient(suite, s.PublicKey()), info} 131 } 132 133 return suite.(params), s, c 134 } 135 136 func (v *vector) compareLists(t *testing.T, got, want [][]byte) { 137 t.Helper() 138 for i := range got { 139 if !bytes.Equal(got[i], want[i]) { 140 test.ReportError(t, got[i], want[i], v.Identifier, v.Mode, i) 141 } 142 } 143 } 144 145 func (v *vector) compareBytes(t *testing.T, got, want []byte) { 146 t.Helper() 147 if !bytes.Equal(got, want) { 148 test.ReportError(t, got, want, v.Identifier, v.Mode) 149 } 150 } 151 152 func (v *vector) test(t *testing.T) { 153 params, server, client := v.SetUpParties(t) 154 155 for i, vi := range v.Vectors { 156 if v.Mode == PartialObliviousMode { 157 info := toBytes(t, vi.Info, "info") 158 ss := server.(*s1) 159 cc := client.(*c1) 160 ss.info = info 161 cc.info = info 162 } 163 164 inputs := toListBytes(t, vi.Input, "input") 165 blindsBytes := toListBytes(t, vi.Blind, "blind") 166 167 blinds := make([]Blind, len(blindsBytes)) 168 for j := range blindsBytes { 169 blinds[j] = params.group.NewScalar() 170 err := blinds[j].UnmarshalBinary(blindsBytes[j]) 171 test.CheckNoErr(t, err, "invalid blind") 172 } 173 174 finData, evalReq, err := client.blind(inputs, blinds) 175 test.CheckNoErr(t, err, "invalid client request") 176 evalReqBytes, err := elementsMarshalBinary(params.group, evalReq.Elements) 177 test.CheckNoErr(t, err, "bad serialization") 178 v.compareBytes(t, evalReqBytes, flattenList(t, vi.BlindedElement, "blindedElement")) 179 180 eval, err := server.Evaluate(evalReq) 181 test.CheckNoErr(t, err, "invalid evaluation") 182 elemBytes, err := elementsMarshalBinary(params.group, eval.Elements) 183 test.CheckNoErr(t, err, "invalid evaluations marshaling") 184 v.compareBytes(t, elemBytes, flattenList(t, vi.EvaluationElement, "evaluation")) 185 186 if v.Mode == VerifiableMode || v.Mode == PartialObliviousMode { 187 randomness := toScalar(t, params.group, vi.Proof.R, "invalid proof random scalar") 188 var proof encoding.BinaryMarshaler 189 switch v.Mode { 190 case VerifiableMode: 191 ss := server.(VerifiableServer) 192 prover := dleq.Prover{Params: ss.getDLEQParams()} 193 proof, err = prover.ProveBatchWithRandomness( 194 ss.privateKey.k, 195 ss.params.group.Generator(), 196 server.PublicKey().e, 197 evalReq.Elements, 198 eval.Elements, 199 randomness) 200 case PartialObliviousMode: 201 ss := server.(*s1) 202 keyProof, _, _ := ss.secretFromInfo(ss.info) 203 prover := dleq.Prover{Params: ss.getDLEQParams()} 204 proof, err = prover.ProveBatchWithRandomness( 205 keyProof, 206 ss.params.group.Generator(), 207 ss.params.group.NewElement().MulGen(keyProof), 208 eval.Elements, 209 evalReq.Elements, 210 randomness) 211 } 212 test.CheckNoErr(t, err, "failed proof generation") 213 proofBytes, errr := proof.MarshalBinary() 214 test.CheckNoErr(t, errr, "failed proof marshaling") 215 v.compareBytes(t, proofBytes, toBytes(t, vi.Proof.Proof, "proof")) 216 } 217 218 outputs, err := client.Finalize(finData, eval) 219 test.CheckNoErr(t, err, "invalid finalize") 220 expectedOutputs := toListBytes(t, vi.Output, "output") 221 v.compareLists(t, 222 outputs, 223 expectedOutputs, 224 ) 225 226 for j := range inputs { 227 output, err := server.FullEvaluate(inputs[j]) 228 test.CheckNoErr(t, err, "invalid full evaluate") 229 got := output 230 want := expectedOutputs[j] 231 if !bytes.Equal(got, want) { 232 test.ReportError(t, got, want, v.Identifier, v.Mode, i, j) 233 } 234 235 test.CheckOk(server.VerifyFinalize(inputs[j], output), "verify finalize", t) 236 } 237 } 238 } 239 240 func TestVectors(t *testing.T) { 241 // RFC-9497 published at https://www.rfc-editor.org/info/rfc9497 242 // Test vectors at https://github.com/cfrg/draft-irtf-cfrg-voprf 243 v := readFile(t, "testdata/rfc9497.json") 244 245 for i := range v { 246 suite, err := GetSuite(v[i].Identifier) 247 if err != nil { 248 t.Log(v[i].Identifier + " not supported yet") 249 continue 250 } 251 t.Run(fmt.Sprintf("%v/Mode%v", suite, v[i].Mode), v[i].test) 252 } 253 }