// Source: github.com/gogo/protobuf@v1.3.2/proto/extensions_test.go

// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2014 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 32 package proto_test 33 34 import ( 35 "bytes" 36 "fmt" 37 "io" 38 "reflect" 39 "sort" 40 "strings" 41 "testing" 42 43 "github.com/gogo/protobuf/proto" 44 pb "github.com/gogo/protobuf/proto/test_proto" 45 ) 46 47 func TestGetExtensionsWithMissingExtensions(t *testing.T) { 48 msg := &pb.MyMessage{} 49 ext1 := &pb.Ext{} 50 if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil { 51 t.Fatalf("Could not set ext1: %s", err) 52 } 53 exts, err := proto.GetExtensions(msg, []*proto.ExtensionDesc{ 54 pb.E_Ext_More, 55 pb.E_Ext_Text, 56 }) 57 if err != nil { 58 t.Fatalf("GetExtensions() failed: %s", err) 59 } 60 if exts[0] != ext1 { 61 t.Errorf("ext1 not in returned extensions: %T %v", exts[0], exts[0]) 62 } 63 if exts[1] != nil { 64 t.Errorf("ext2 in returned extensions: %T %v", exts[1], exts[1]) 65 } 66 } 67 68 func TestGetExtensionWithEmptyBuffer(t *testing.T) { 69 // Make sure that GetExtension returns an error if its 70 // undecoded buffer is empty. 71 msg := &pb.MyMessage{} 72 proto.SetRawExtension(msg, pb.E_Ext_More.Field, []byte{}) 73 _, err := proto.GetExtension(msg, pb.E_Ext_More) 74 if want := io.ErrUnexpectedEOF; err != want { 75 t.Errorf("unexpected error in GetExtension from empty buffer: got %v, want %v", err, want) 76 } 77 } 78 79 func TestGetExtensionForIncompleteDesc(t *testing.T) { 80 msg := &pb.MyMessage{Count: proto.Int32(0)} 81 extdesc1 := &proto.ExtensionDesc{ 82 ExtendedType: (*pb.MyMessage)(nil), 83 ExtensionType: (*bool)(nil), 84 Field: 123456789, 85 Name: "a.b", 86 Tag: "varint,123456789,opt", 87 } 88 ext1 := proto.Bool(true) 89 if err := proto.SetExtension(msg, extdesc1, ext1); err != nil { 90 t.Fatalf("Could not set ext1: %s", err) 91 } 92 extdesc2 := &proto.ExtensionDesc{ 93 ExtendedType: (*pb.MyMessage)(nil), 94 ExtensionType: ([]byte)(nil), 95 Field: 123456790, 96 Name: "a.c", 97 Tag: "bytes,123456790,opt", 98 } 99 ext2 := []byte{0, 1, 2, 3, 4, 5, 6, 7} 100 if err := proto.SetExtension(msg, extdesc2, ext2); err != nil { 101 
t.Fatalf("Could not set ext2: %s", err) 102 } 103 extdesc3 := &proto.ExtensionDesc{ 104 ExtendedType: (*pb.MyMessage)(nil), 105 ExtensionType: (*pb.Ext)(nil), 106 Field: 123456791, 107 Name: "a.d", 108 Tag: "bytes,123456791,opt", 109 } 110 ext3 := &pb.Ext{Data: proto.String("foo")} 111 if err := proto.SetExtension(msg, extdesc3, ext3); err != nil { 112 t.Fatalf("Could not set ext3: %s", err) 113 } 114 115 b, err := proto.Marshal(msg) 116 if err != nil { 117 t.Fatalf("Could not marshal msg: %v", err) 118 } 119 if err := proto.Unmarshal(b, msg); err != nil { 120 t.Fatalf("Could not unmarshal into msg: %v", err) 121 } 122 123 var expected proto.Buffer 124 if err := expected.EncodeVarint(uint64((extdesc1.Field << 3) | proto.WireVarint)); err != nil { 125 t.Fatalf("failed to compute expected prefix for ext1: %s", err) 126 } 127 if err := expected.EncodeVarint(1 /* bool true */); err != nil { 128 t.Fatalf("failed to compute expected value for ext1: %s", err) 129 } 130 131 if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc1.Field}); err != nil { 132 t.Fatalf("Failed to get raw value for ext1: %s", err) 133 } else if !reflect.DeepEqual(b, expected.Bytes()) { 134 t.Fatalf("Raw value for ext1: got %v, want %v", b, expected.Bytes()) 135 } 136 137 expected = proto.Buffer{} // reset 138 if err := expected.EncodeVarint(uint64((extdesc2.Field << 3) | proto.WireBytes)); err != nil { 139 t.Fatalf("failed to compute expected prefix for ext2: %s", err) 140 } 141 if err := expected.EncodeRawBytes(ext2); err != nil { 142 t.Fatalf("failed to compute expected value for ext2: %s", err) 143 } 144 145 if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc2.Field}); err != nil { 146 t.Fatalf("Failed to get raw value for ext2: %s", err) 147 } else if !reflect.DeepEqual(b, expected.Bytes()) { 148 t.Fatalf("Raw value for ext2: got %v, want %v", b, expected.Bytes()) 149 } 150 151 expected = proto.Buffer{} // reset 152 if err := 
expected.EncodeVarint(uint64((extdesc3.Field << 3) | proto.WireBytes)); err != nil { 153 t.Fatalf("failed to compute expected prefix for ext3: %s", err) 154 } 155 if b, err := proto.Marshal(ext3); err != nil { 156 t.Fatalf("failed to compute expected value for ext3: %s", err) 157 } else if err := expected.EncodeRawBytes(b); err != nil { 158 t.Fatalf("failed to compute expected value for ext3: %s", err) 159 } 160 161 if b, err := proto.GetExtension(msg, &proto.ExtensionDesc{Field: extdesc3.Field}); err != nil { 162 t.Fatalf("Failed to get raw value for ext3: %s", err) 163 } else if !reflect.DeepEqual(b, expected.Bytes()) { 164 t.Fatalf("Raw value for ext3: got %v, want %v", b, expected.Bytes()) 165 } 166 } 167 168 func TestExtensionDescsWithUnregisteredExtensions(t *testing.T) { 169 msg := &pb.MyMessage{Count: proto.Int32(0)} 170 extdesc1 := pb.E_Ext_More 171 if descs, err := proto.ExtensionDescs(msg); len(descs) != 0 || err != nil { 172 t.Errorf("proto.ExtensionDescs: got %d descs, error %v; want 0, nil", len(descs), err) 173 } 174 175 ext1 := &pb.Ext{} 176 if err := proto.SetExtension(msg, extdesc1, ext1); err != nil { 177 t.Fatalf("Could not set ext1: %s", err) 178 } 179 extdesc2 := &proto.ExtensionDesc{ 180 ExtendedType: (*pb.MyMessage)(nil), 181 ExtensionType: (*bool)(nil), 182 Field: 123456789, 183 Name: "a.b", 184 Tag: "varint,123456789,opt", 185 } 186 ext2 := proto.Bool(false) 187 if err := proto.SetExtension(msg, extdesc2, ext2); err != nil { 188 t.Fatalf("Could not set ext2: %s", err) 189 } 190 191 b, err := proto.Marshal(msg) 192 if err != nil { 193 t.Fatalf("Could not marshal msg: %v", err) 194 } 195 if err = proto.Unmarshal(b, msg); err != nil { 196 t.Fatalf("Could not unmarshal into msg: %v", err) 197 } 198 199 descs, err := proto.ExtensionDescs(msg) 200 if err != nil { 201 t.Fatalf("proto.ExtensionDescs: got error %v", err) 202 } 203 sortExtDescs(descs) 204 wantDescs := []*proto.ExtensionDesc{extdesc1, {Field: extdesc2.Field}} 205 if 
!reflect.DeepEqual(descs, wantDescs) { 206 t.Errorf("proto.ExtensionDescs(msg) sorted extension ids: got %+v, want %+v", descs, wantDescs) 207 } 208 } 209 210 type ExtensionDescSlice []*proto.ExtensionDesc 211 212 func (s ExtensionDescSlice) Len() int { return len(s) } 213 func (s ExtensionDescSlice) Less(i, j int) bool { return s[i].Field < s[j].Field } 214 func (s ExtensionDescSlice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } 215 216 func sortExtDescs(s []*proto.ExtensionDesc) { 217 sort.Sort(ExtensionDescSlice(s)) 218 } 219 220 func TestGetExtensionStability(t *testing.T) { 221 check := func(m *pb.MyMessage) bool { 222 ext1, err := proto.GetExtension(m, pb.E_Ext_More) 223 if err != nil { 224 t.Fatalf("GetExtension() failed: %s", err) 225 } 226 ext2, err := proto.GetExtension(m, pb.E_Ext_More) 227 if err != nil { 228 t.Fatalf("GetExtension() failed: %s", err) 229 } 230 return ext1 == ext2 231 } 232 msg := &pb.MyMessage{Count: proto.Int32(4)} 233 ext0 := &pb.Ext{} 234 if err := proto.SetExtension(msg, pb.E_Ext_More, ext0); err != nil { 235 t.Fatalf("Could not set ext1: %s", ext0) 236 } 237 if !check(msg) { 238 t.Errorf("GetExtension() not stable before marshaling") 239 } 240 bb, err := proto.Marshal(msg) 241 if err != nil { 242 t.Fatalf("Marshal() failed: %s", err) 243 } 244 msg1 := &pb.MyMessage{} 245 err = proto.Unmarshal(bb, msg1) 246 if err != nil { 247 t.Fatalf("Unmarshal() failed: %s", err) 248 } 249 if !check(msg1) { 250 t.Errorf("GetExtension() not stable after unmarshaling") 251 } 252 } 253 254 func TestGetExtensionDefaults(t *testing.T) { 255 var setFloat64 float64 = 1 256 var setFloat32 float32 = 2 257 var setInt32 int32 = 3 258 var setInt64 int64 = 4 259 var setUint32 uint32 = 5 260 var setUint64 uint64 = 6 261 var setBool = true 262 var setBool2 = false 263 var setString = "Goodnight string" 264 var setBytes = []byte("Goodnight bytes") 265 var setEnum = pb.DefaultsMessage_TWO 266 267 type testcase struct { 268 ext *proto.ExtensionDesc // Extension we 
are testing. 269 want interface{} // Expected value of extension, or nil (meaning that GetExtension will fail). 270 def interface{} // Expected value of extension after ClearExtension(). 271 } 272 tests := []testcase{ 273 {pb.E_NoDefaultDouble, setFloat64, nil}, 274 {pb.E_NoDefaultFloat, setFloat32, nil}, 275 {pb.E_NoDefaultInt32, setInt32, nil}, 276 {pb.E_NoDefaultInt64, setInt64, nil}, 277 {pb.E_NoDefaultUint32, setUint32, nil}, 278 {pb.E_NoDefaultUint64, setUint64, nil}, 279 {pb.E_NoDefaultSint32, setInt32, nil}, 280 {pb.E_NoDefaultSint64, setInt64, nil}, 281 {pb.E_NoDefaultFixed32, setUint32, nil}, 282 {pb.E_NoDefaultFixed64, setUint64, nil}, 283 {pb.E_NoDefaultSfixed32, setInt32, nil}, 284 {pb.E_NoDefaultSfixed64, setInt64, nil}, 285 {pb.E_NoDefaultBool, setBool, nil}, 286 {pb.E_NoDefaultBool, setBool2, nil}, 287 {pb.E_NoDefaultString, setString, nil}, 288 {pb.E_NoDefaultBytes, setBytes, nil}, 289 {pb.E_NoDefaultEnum, setEnum, nil}, 290 {pb.E_DefaultDouble, setFloat64, float64(3.1415)}, 291 {pb.E_DefaultFloat, setFloat32, float32(3.14)}, 292 {pb.E_DefaultInt32, setInt32, int32(42)}, 293 {pb.E_DefaultInt64, setInt64, int64(43)}, 294 {pb.E_DefaultUint32, setUint32, uint32(44)}, 295 {pb.E_DefaultUint64, setUint64, uint64(45)}, 296 {pb.E_DefaultSint32, setInt32, int32(46)}, 297 {pb.E_DefaultSint64, setInt64, int64(47)}, 298 {pb.E_DefaultFixed32, setUint32, uint32(48)}, 299 {pb.E_DefaultFixed64, setUint64, uint64(49)}, 300 {pb.E_DefaultSfixed32, setInt32, int32(50)}, 301 {pb.E_DefaultSfixed64, setInt64, int64(51)}, 302 {pb.E_DefaultBool, setBool, true}, 303 {pb.E_DefaultBool, setBool2, true}, 304 {pb.E_DefaultString, setString, "Hello, string,def=foo"}, 305 {pb.E_DefaultBytes, setBytes, []byte("Hello, bytes")}, 306 {pb.E_DefaultEnum, setEnum, pb.DefaultsMessage_ONE}, 307 } 308 309 checkVal := func(test testcase, msg *pb.DefaultsMessage, valWant interface{}) error { 310 val, err := proto.GetExtension(msg, test.ext) 311 if err != nil { 312 if valWant != nil { 313 
return fmt.Errorf("GetExtension(): %s", err) 314 } 315 if want := proto.ErrMissingExtension; err != want { 316 return fmt.Errorf("Unexpected error: got %v, want %v", err, want) 317 } 318 return nil 319 } 320 321 // All proto2 extension values are either a pointer to a value or a slice of values. 322 ty := reflect.TypeOf(val) 323 tyWant := reflect.TypeOf(test.ext.ExtensionType) 324 if got, want := ty, tyWant; got != want { 325 return fmt.Errorf("unexpected reflect.TypeOf(): got %v want %v", got, want) 326 } 327 tye := ty.Elem() 328 tyeWant := tyWant.Elem() 329 if got, want := tye, tyeWant; got != want { 330 return fmt.Errorf("unexpected reflect.TypeOf().Elem(): got %v want %v", got, want) 331 } 332 333 // Check the name of the type of the value. 334 // If it is an enum it will be type int32 with the name of the enum. 335 if got, want := tye.Name(), tye.Name(); got != want { 336 return fmt.Errorf("unexpected reflect.TypeOf().Elem().Name(): got %v want %v", got, want) 337 } 338 339 // Check that value is what we expect. 340 // If we have a pointer in val, get the value it points to. 341 valExp := val 342 if ty.Kind() == reflect.Ptr { 343 valExp = reflect.ValueOf(val).Elem().Interface() 344 } 345 if got, want := valExp, valWant; !reflect.DeepEqual(got, want) { 346 return fmt.Errorf("unexpected reflect.DeepEqual(): got %v want %v", got, want) 347 } 348 349 return nil 350 } 351 352 setTo := func(test testcase) interface{} { 353 setTo := reflect.ValueOf(test.want) 354 if typ := reflect.TypeOf(test.ext.ExtensionType); typ.Kind() == reflect.Ptr { 355 setTo = reflect.New(typ).Elem() 356 setTo.Set(reflect.New(setTo.Type().Elem())) 357 setTo.Elem().Set(reflect.ValueOf(test.want)) 358 } 359 return setTo.Interface() 360 } 361 362 for _, test := range tests { 363 msg := &pb.DefaultsMessage{} 364 name := test.ext.Name 365 366 // Check the initial value. 
367 if err := checkVal(test, msg, test.def); err != nil { 368 t.Errorf("%s: %v", name, err) 369 } 370 371 // Set the per-type value and check value. 372 name = fmt.Sprintf("%s (set to %T %v)", name, test.want, test.want) 373 if err := proto.SetExtension(msg, test.ext, setTo(test)); err != nil { 374 t.Errorf("%s: SetExtension(): %v", name, err) 375 continue 376 } 377 if err := checkVal(test, msg, test.want); err != nil { 378 t.Errorf("%s: %v", name, err) 379 continue 380 } 381 382 // Set and check the value. 383 name += " (cleared)" 384 proto.ClearExtension(msg, test.ext) 385 if err := checkVal(test, msg, test.def); err != nil { 386 t.Errorf("%s: %v", name, err) 387 } 388 } 389 } 390 391 func TestNilMessage(t *testing.T) { 392 name := "nil interface" 393 if got, err := proto.GetExtension(nil, pb.E_Ext_More); err == nil { 394 t.Errorf("%s: got %T %v, expected to fail", name, got, got) 395 } else if !strings.Contains(err.Error(), "extendable") { 396 t.Errorf("%s: got error %v, expected not-extendable error", name, err) 397 } 398 399 // Regression tests: all functions of the Extension API 400 // used to panic when passed (*M)(nil), where M is a concrete message 401 // type. Now they handle this gracefully as a no-op or reported error. 
402 var nilMsg *pb.MyMessage 403 desc := pb.E_Ext_More 404 405 isNotExtendable := func(err error) bool { 406 return strings.Contains(fmt.Sprint(err), "not extendable") 407 } 408 409 if proto.HasExtension(nilMsg, desc) { 410 t.Error("HasExtension(nil) = true") 411 } 412 413 if _, err := proto.GetExtensions(nilMsg, []*proto.ExtensionDesc{desc}); !isNotExtendable(err) { 414 t.Errorf("GetExtensions(nil) = %q (wrong error)", err) 415 } 416 417 if _, err := proto.ExtensionDescs(nilMsg); !isNotExtendable(err) { 418 t.Errorf("ExtensionDescs(nil) = %q (wrong error)", err) 419 } 420 421 if err := proto.SetExtension(nilMsg, desc, nil); !isNotExtendable(err) { 422 t.Errorf("SetExtension(nil) = %q (wrong error)", err) 423 } 424 425 proto.ClearExtension(nilMsg, desc) // no-op 426 proto.ClearAllExtensions(nilMsg) // no-op 427 } 428 429 func TestExtensionsRoundTrip(t *testing.T) { 430 msg := &pb.MyMessage{} 431 ext1 := &pb.Ext{ 432 Data: proto.String("hi"), 433 } 434 ext2 := &pb.Ext{ 435 Data: proto.String("there"), 436 } 437 exists := proto.HasExtension(msg, pb.E_Ext_More) 438 if exists { 439 t.Error("Extension More present unexpectedly") 440 } 441 if err := proto.SetExtension(msg, pb.E_Ext_More, ext1); err != nil { 442 t.Error(err) 443 } 444 if err := proto.SetExtension(msg, pb.E_Ext_More, ext2); err != nil { 445 t.Error(err) 446 } 447 e, err := proto.GetExtension(msg, pb.E_Ext_More) 448 if err != nil { 449 t.Error(err) 450 } 451 x, ok := e.(*pb.Ext) 452 if !ok { 453 t.Errorf("e has type %T, expected test_proto.Ext", e) 454 } else if *x.Data != "there" { 455 t.Errorf("SetExtension failed to overwrite, got %+v, not 'there'", x) 456 } 457 proto.ClearExtension(msg, pb.E_Ext_More) 458 if _, err = proto.GetExtension(msg, pb.E_Ext_More); err != proto.ErrMissingExtension { 459 t.Errorf("got %v, expected ErrMissingExtension", e) 460 } 461 if _, err := proto.GetExtension(msg, pb.E_X215); err == nil { 462 t.Error("expected bad extension error, got nil") 463 } 464 if err := 
proto.SetExtension(msg, pb.E_X215, 12); err == nil { 465 t.Error("expected extension err") 466 } 467 if err := proto.SetExtension(msg, pb.E_Ext_More, 12); err == nil { 468 t.Error("expected some sort of type mismatch error, got nil") 469 } 470 } 471 472 func TestNilExtension(t *testing.T) { 473 msg := &pb.MyMessage{ 474 Count: proto.Int32(1), 475 } 476 if err := proto.SetExtension(msg, pb.E_Ext_Text, proto.String("hello")); err != nil { 477 t.Fatal(err) 478 } 479 if err := proto.SetExtension(msg, pb.E_Ext_More, (*pb.Ext)(nil)); err == nil { 480 t.Error("expected SetExtension to fail due to a nil extension") 481 } else if want := fmt.Sprintf("proto: SetExtension called with nil value of type %T", new(pb.Ext)); err.Error() != want { 482 t.Errorf("expected error %v, got %v", want, err) 483 } 484 // Note: if the behavior of Marshal is ever changed to ignore nil extensions, update 485 // this test to verify that E_Ext_Text is properly propagated through marshal->unmarshal. 486 } 487 488 func TestMarshalUnmarshalRepeatedExtension(t *testing.T) { 489 // Add a repeated extension to the result. 490 tests := []struct { 491 name string 492 ext []*pb.ComplexExtension 493 }{ 494 { 495 "two fields", 496 []*pb.ComplexExtension{ 497 {First: proto.Int32(7)}, 498 {Second: proto.Int32(11)}, 499 }, 500 }, 501 { 502 "repeated field", 503 []*pb.ComplexExtension{ 504 {Third: []int32{1000}}, 505 {Third: []int32{2000}}, 506 }, 507 }, 508 { 509 "two fields and repeated field", 510 []*pb.ComplexExtension{ 511 {Third: []int32{1000}}, 512 {First: proto.Int32(9)}, 513 {Second: proto.Int32(21)}, 514 {Third: []int32{2000}}, 515 }, 516 }, 517 } 518 for _, test := range tests { 519 // Marshal message with a repeated extension. 
520 msg1 := new(pb.OtherMessage) 521 err := proto.SetExtension(msg1, pb.E_RComplex, test.ext) 522 if err != nil { 523 t.Fatalf("[%s] Error setting extension: %v", test.name, err) 524 } 525 b, err := proto.Marshal(msg1) 526 if err != nil { 527 t.Fatalf("[%s] Error marshaling message: %v", test.name, err) 528 } 529 530 // Unmarshal and read the merged proto. 531 msg2 := new(pb.OtherMessage) 532 err = proto.Unmarshal(b, msg2) 533 if err != nil { 534 t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err) 535 } 536 e, err := proto.GetExtension(msg2, pb.E_RComplex) 537 if err != nil { 538 t.Fatalf("[%s] Error getting extension: %v", test.name, err) 539 } 540 ext := e.([]*pb.ComplexExtension) 541 if ext == nil { 542 t.Fatalf("[%s] Invalid extension", test.name) 543 } 544 if len(ext) != len(test.ext) { 545 t.Errorf("[%s] Wrong length of ComplexExtension: got: %v want: %v\n", test.name, len(ext), len(test.ext)) 546 } 547 for i := range test.ext { 548 if !proto.Equal(ext[i], test.ext[i]) { 549 t.Errorf("[%s] Wrong value for ComplexExtension[%d]: got: %v want: %v\n", test.name, i, ext[i], test.ext[i]) 550 } 551 } 552 } 553 } 554 555 func TestUnmarshalRepeatingNonRepeatedExtension(t *testing.T) { 556 // We may see multiple instances of the same extension in the wire 557 // format. For example, the proto compiler may encode custom options in 558 // this way. Here, we verify that we merge the extensions together. 
559 tests := []struct { 560 name string 561 ext []*pb.ComplexExtension 562 }{ 563 { 564 "two fields", 565 []*pb.ComplexExtension{ 566 {First: proto.Int32(7)}, 567 {Second: proto.Int32(11)}, 568 }, 569 }, 570 { 571 "repeated field", 572 []*pb.ComplexExtension{ 573 {Third: []int32{1000}}, 574 {Third: []int32{2000}}, 575 }, 576 }, 577 { 578 "two fields and repeated field", 579 []*pb.ComplexExtension{ 580 {Third: []int32{1000}}, 581 {First: proto.Int32(9)}, 582 {Second: proto.Int32(21)}, 583 {Third: []int32{2000}}, 584 }, 585 }, 586 } 587 for _, test := range tests { 588 var buf bytes.Buffer 589 var want pb.ComplexExtension 590 591 // Generate a serialized representation of a repeated extension 592 // by catenating bytes together. 593 for i, e := range test.ext { 594 // Merge to create the wanted proto. 595 proto.Merge(&want, e) 596 597 // serialize the message 598 msg := new(pb.OtherMessage) 599 err := proto.SetExtension(msg, pb.E_Complex, e) 600 if err != nil { 601 t.Fatalf("[%s] Error setting extension %d: %v", test.name, i, err) 602 } 603 b, err := proto.Marshal(msg) 604 if err != nil { 605 t.Fatalf("[%s] Error marshaling message %d: %v", test.name, i, err) 606 } 607 buf.Write(b) 608 } 609 610 // Unmarshal and read the merged proto. 
611 msg2 := new(pb.OtherMessage) 612 err := proto.Unmarshal(buf.Bytes(), msg2) 613 if err != nil { 614 t.Fatalf("[%s] Error unmarshaling message: %v", test.name, err) 615 } 616 e, err := proto.GetExtension(msg2, pb.E_Complex) 617 if err != nil { 618 t.Fatalf("[%s] Error getting extension: %v", test.name, err) 619 } 620 ext := e.(*pb.ComplexExtension) 621 if ext == nil { 622 t.Fatalf("[%s] Invalid extension", test.name) 623 } 624 if !proto.Equal(ext, &want) { 625 t.Errorf("[%s] Wrong value for ComplexExtension: got: %v want: %v\n", test.name, ext, &want) 626 627 } 628 } 629 } 630 631 func TestClearAllExtensions(t *testing.T) { 632 // unregistered extension 633 desc := &proto.ExtensionDesc{ 634 ExtendedType: (*pb.MyMessage)(nil), 635 ExtensionType: (*bool)(nil), 636 Field: 101010100, 637 Name: "emptyextension", 638 Tag: "varint,0,opt", 639 } 640 m := &pb.MyMessage{} 641 if proto.HasExtension(m, desc) { 642 t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m)) 643 } 644 if err := proto.SetExtension(m, desc, proto.Bool(true)); err != nil { 645 t.Errorf("proto.SetExtension(m, desc, true): got error %q, want nil", err) 646 } 647 if !proto.HasExtension(m, desc) { 648 t.Errorf("proto.HasExtension(%s): got false, want true", proto.MarshalTextString(m)) 649 } 650 proto.ClearAllExtensions(m) 651 if proto.HasExtension(m, desc) { 652 t.Errorf("proto.HasExtension(%s): got true, want false", proto.MarshalTextString(m)) 653 } 654 } 655 656 func TestMarshalRace(t *testing.T) { 657 ext := &pb.Ext{} 658 m := &pb.MyMessage{Count: proto.Int32(4)} 659 if err := proto.SetExtension(m, pb.E_Ext_More, ext); err != nil { 660 t.Fatalf("proto.SetExtension(m, desc, true): got error %q, want nil", err) 661 } 662 663 b, err := proto.Marshal(m) 664 if err != nil { 665 t.Fatalf("Could not marshal message: %v", err) 666 } 667 if err := proto.Unmarshal(b, m); err != nil { 668 t.Fatalf("Could not unmarshal message: %v", err) 669 } 670 // after Unmarshal, the extension is 
in undecoded form. 671 // GetExtension will decode it lazily. Make sure this does 672 // not race against Marshal. 673 674 errChan := make(chan error, 6) 675 for n := 3; n > 0; n-- { 676 go func() { 677 _, err := proto.Marshal(m) 678 errChan <- err 679 }() 680 go func() { 681 _, err := proto.GetExtension(m, pb.E_Ext_More) 682 errChan <- err 683 }() 684 } 685 for i := 0; i < 6; i++ { 686 err := <-errChan 687 if err != nil { 688 t.Fatal(err) 689 } 690 } 691 }