// Protocol Buffers for Go with Gadgets
//
// Copyright (c) 2013, The GoGo Authors. All rights reserved.
// http://github.com/gogo/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

/*
The marshalto plugin generates Marshal, MarshalTo and MarshalToSizedBuffer methods for each message.
The generated `Marshal() ([]byte, error)` method makes the message
implement the Marshaler interface.
This allows proto.Marshal to be faster by calling the generated Marshal method rather than using reflection to marshal the struct.

It is enabled by the following extensions:

  - marshaler
  - marshaler_all

Or by the following extensions:

  - unsafe_marshaler
  - unsafe_marshaler_all

Use these if you want the generated code to use the unsafe package.
The speed-up from using the unsafe package is not very significant.

The generation of marshalling tests is enabled using one of the following extensions:

  - testgen
  - testgen_all

The generation of benchmarks is enabled using one of the following extensions:

  - benchgen
  - benchgen_all

Let us look at:

  github.com/gogo/protobuf/test/example/example.proto

All the generated output can be seen at:

  github.com/gogo/protobuf/test/example/*

The following message:

  option (gogoproto.marshaler_all) = true;

  message B {
	option (gogoproto.description) = true;
	optional A A = 1 [(gogoproto.nullable) = false, (gogoproto.embed) = true];
	repeated bytes G = 2 [(gogoproto.customtype) = "github.com/gogo/protobuf/test/custom.Uint128", (gogoproto.nullable) = false];
  }

when given to the marshalto plugin, will generate the following code:

	func (m *B) Marshal() (dAtA []byte, err error) {
		size := m.Size()
		dAtA = make([]byte, size)
		n, err := m.MarshalToSizedBuffer(dAtA[:size])
		if err != nil {
			return nil, err
		}
		return dAtA[:n], nil
	}

	func (m *B) MarshalTo(dAtA []byte) (int, error) {
		size := m.Size()
		return m.MarshalToSizedBuffer(dAtA[:size])
	}

	func (m *B) MarshalToSizedBuffer(dAtA []byte) (int, error) {
		i := len(dAtA)
		_ = i
		var l int
		_ = l
		if m.XXX_unrecognized != nil {
			i -= len(m.XXX_unrecognized)
			copy(dAtA[i:], m.XXX_unrecognized)
		}
		if len(m.G) > 0 {
			for iNdEx := len(m.G) - 1; iNdEx >= 0; iNdEx-- {
				{
					size := m.G[iNdEx].Size()
					i -= size
					if _, err := m.G[iNdEx].MarshalTo(dAtA[i:]); err != nil {
						return 0, err
					}
					i = encodeVarintExample(dAtA, i, uint64(size))
				}
				i--
				dAtA[i] = 0x12
			}
		}
		{
			size, err := m.A.MarshalToSizedBuffer(dAtA[:i])
			if err != nil {
				return 0, err
			}
			i -= size
			i = encodeVarintExample(dAtA, i, uint64(size))
		}
		i--
		dAtA[i] = 0xa
		return len(dAtA) - i, nil
	}

As shown above, Marshal calculates the size of the not-yet-marshalled message
and allocates a buffer of that size.
It then calls the MarshalToSizedBuffer method, which requires a preallocated buffer and marshals backwards, from the end of the buffer towards its start.
The MarshalTo method instead allows a user to pass in a preallocated, reusable buffer.

The Size method is generated using the size plugin and the gogoproto.sizer, gogoproto.sizer_all extensions.
The user can also use the generated Size method to check that a reusable buffer is still big enough.

The generated tests and benchmarks will keep you safe and show that this is really a significant speed improvement.

An additional message-level option `stable_marshaler` (and the file-level
option `stable_marshaler_all`) exists which causes the generated marshalling
code to behave deterministically. Today, this only changes the serialization of
maps; they are serialized in sort order.
143 */ 144 package marshalto 145 146 import ( 147 "fmt" 148 "sort" 149 "strconv" 150 "strings" 151 152 "github.com/gogo/protobuf/gogoproto" 153 "github.com/gogo/protobuf/proto" 154 descriptor "github.com/gogo/protobuf/protoc-gen-gogo/descriptor" 155 "github.com/gogo/protobuf/protoc-gen-gogo/generator" 156 "github.com/gogo/protobuf/vanity" 157 ) 158 159 type NumGen interface { 160 Next() string 161 Current() string 162 } 163 164 type numGen struct { 165 index int 166 } 167 168 func NewNumGen() NumGen { 169 return &numGen{0} 170 } 171 172 func (this *numGen) Next() string { 173 this.index++ 174 return this.Current() 175 } 176 177 func (this *numGen) Current() string { 178 return strconv.Itoa(this.index) 179 } 180 181 type marshalto struct { 182 *generator.Generator 183 generator.PluginImports 184 atleastOne bool 185 errorsPkg generator.Single 186 protoPkg generator.Single 187 sortKeysPkg generator.Single 188 mathPkg generator.Single 189 typesPkg generator.Single 190 binaryPkg generator.Single 191 localName string 192 } 193 194 func NewMarshal() *marshalto { 195 return &marshalto{} 196 } 197 198 func (p *marshalto) Name() string { 199 return "marshalto" 200 } 201 202 func (p *marshalto) Init(g *generator.Generator) { 203 p.Generator = g 204 } 205 206 func (p *marshalto) callFixed64(varName ...string) { 207 p.P(`i -= 8`) 208 p.P(p.binaryPkg.Use(), `.LittleEndian.PutUint64(dAtA[i:], uint64(`, strings.Join(varName, ""), `))`) 209 } 210 211 func (p *marshalto) callFixed32(varName ...string) { 212 p.P(`i -= 4`) 213 p.P(p.binaryPkg.Use(), `.LittleEndian.PutUint32(dAtA[i:], uint32(`, strings.Join(varName, ""), `))`) 214 } 215 216 func (p *marshalto) callVarint(varName ...string) { 217 p.P(`i = encodeVarint`, p.localName, `(dAtA, i, uint64(`, strings.Join(varName, ""), `))`) 218 } 219 220 func (p *marshalto) encodeKey(fieldNumber int32, wireType int) { 221 x := uint32(fieldNumber)<<3 | uint32(wireType) 222 i := 0 223 keybuf := make([]byte, 0) 224 for i = 0; x > 127; i++ { 225 
keybuf = append(keybuf, 0x80|uint8(x&0x7F)) 226 x >>= 7 227 } 228 keybuf = append(keybuf, uint8(x)) 229 for i = len(keybuf) - 1; i >= 0; i-- { 230 p.P(`i--`) 231 p.P(`dAtA[i] = `, fmt.Sprintf("%#v", keybuf[i])) 232 } 233 } 234 235 func keySize(fieldNumber int32, wireType int) int { 236 x := uint32(fieldNumber)<<3 | uint32(wireType) 237 size := 0 238 for size = 0; x > 127; size++ { 239 x >>= 7 240 } 241 size++ 242 return size 243 } 244 245 func wireToType(wire string) int { 246 switch wire { 247 case "fixed64": 248 return proto.WireFixed64 249 case "fixed32": 250 return proto.WireFixed32 251 case "varint": 252 return proto.WireVarint 253 case "bytes": 254 return proto.WireBytes 255 case "group": 256 return proto.WireBytes 257 case "zigzag32": 258 return proto.WireVarint 259 case "zigzag64": 260 return proto.WireVarint 261 } 262 panic("unreachable") 263 } 264 265 func (p *marshalto) mapField(numGen NumGen, field *descriptor.FieldDescriptorProto, kvField *descriptor.FieldDescriptorProto, varName string, protoSizer bool) { 266 switch kvField.GetType() { 267 case descriptor.FieldDescriptorProto_TYPE_DOUBLE: 268 p.callFixed64(p.mathPkg.Use(), `.Float64bits(float64(`, varName, `))`) 269 case descriptor.FieldDescriptorProto_TYPE_FLOAT: 270 p.callFixed32(p.mathPkg.Use(), `.Float32bits(float32(`, varName, `))`) 271 case descriptor.FieldDescriptorProto_TYPE_INT64, 272 descriptor.FieldDescriptorProto_TYPE_UINT64, 273 descriptor.FieldDescriptorProto_TYPE_INT32, 274 descriptor.FieldDescriptorProto_TYPE_UINT32, 275 descriptor.FieldDescriptorProto_TYPE_ENUM: 276 p.callVarint(varName) 277 case descriptor.FieldDescriptorProto_TYPE_FIXED64, 278 descriptor.FieldDescriptorProto_TYPE_SFIXED64: 279 p.callFixed64(varName) 280 case descriptor.FieldDescriptorProto_TYPE_FIXED32, 281 descriptor.FieldDescriptorProto_TYPE_SFIXED32: 282 p.callFixed32(varName) 283 case descriptor.FieldDescriptorProto_TYPE_BOOL: 284 p.P(`i--`) 285 p.P(`if `, varName, ` {`) 286 p.In() 287 p.P(`dAtA[i] = 1`) 288 
p.Out() 289 p.P(`} else {`) 290 p.In() 291 p.P(`dAtA[i] = 0`) 292 p.Out() 293 p.P(`}`) 294 case descriptor.FieldDescriptorProto_TYPE_STRING, 295 descriptor.FieldDescriptorProto_TYPE_BYTES: 296 if gogoproto.IsCustomType(field) && kvField.IsBytes() { 297 p.forward(varName, true, protoSizer) 298 } else { 299 p.P(`i -= len(`, varName, `)`) 300 p.P(`copy(dAtA[i:], `, varName, `)`) 301 p.callVarint(`len(`, varName, `)`) 302 } 303 case descriptor.FieldDescriptorProto_TYPE_SINT32: 304 p.callVarint(`(uint32(`, varName, `) << 1) ^ uint32((`, varName, ` >> 31))`) 305 case descriptor.FieldDescriptorProto_TYPE_SINT64: 306 p.callVarint(`(uint64(`, varName, `) << 1) ^ uint64((`, varName, ` >> 63))`) 307 case descriptor.FieldDescriptorProto_TYPE_MESSAGE: 308 if !p.marshalAllSizeOf(kvField, `(*`+varName+`)`, numGen.Next()) { 309 if gogoproto.IsCustomType(field) { 310 p.forward(varName, true, protoSizer) 311 } else { 312 p.backward(varName, true) 313 } 314 } 315 316 } 317 } 318 319 type orderFields []*descriptor.FieldDescriptorProto 320 321 func (this orderFields) Len() int { 322 return len(this) 323 } 324 325 func (this orderFields) Less(i, j int) bool { 326 return this[i].GetNumber() < this[j].GetNumber() 327 } 328 329 func (this orderFields) Swap(i, j int) { 330 this[i], this[j] = this[j], this[i] 331 } 332 333 func (p *marshalto) generateField(proto3 bool, numGen NumGen, file *generator.FileDescriptor, message *generator.Descriptor, field *descriptor.FieldDescriptorProto) { 334 fieldname := p.GetOneOfFieldName(message, field) 335 nullable := gogoproto.IsNullable(field) 336 repeated := field.IsRepeated() 337 required := field.IsRequired() 338 339 protoSizer := gogoproto.IsProtoSizer(file.FileDescriptorProto, message.DescriptorProto) 340 doNilCheck := gogoproto.NeedsNilCheck(proto3, field) 341 if required && nullable { 342 p.P(`if m.`, fieldname, `== nil {`) 343 p.In() 344 if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { 345 p.P(`return 0, new(`, p.protoPkg.Use(), 
`.RequiredNotSetError)`) 346 } else { 347 p.P(`return 0, `, p.protoPkg.Use(), `.NewRequiredNotSetError("`, field.GetName(), `")`) 348 } 349 p.Out() 350 p.P(`} else {`) 351 } else if repeated { 352 p.P(`if len(m.`, fieldname, `) > 0 {`) 353 p.In() 354 } else if doNilCheck { 355 p.P(`if m.`, fieldname, ` != nil {`) 356 p.In() 357 } 358 packed := field.IsPacked() || (proto3 && field.IsPacked3()) 359 wireType := field.WireType() 360 fieldNumber := field.GetNumber() 361 if packed { 362 wireType = proto.WireBytes 363 } 364 switch *field.Type { 365 case descriptor.FieldDescriptorProto_TYPE_DOUBLE: 366 if packed { 367 val := p.reverseListRange(`m.`, fieldname) 368 p.P(`f`, numGen.Next(), ` := `, p.mathPkg.Use(), `.Float64bits(float64(`, val, `))`) 369 p.callFixed64("f" + numGen.Current()) 370 p.Out() 371 p.P(`}`) 372 p.callVarint(`len(m.`, fieldname, `) * 8`) 373 p.encodeKey(fieldNumber, wireType) 374 } else if repeated { 375 val := p.reverseListRange(`m.`, fieldname) 376 p.P(`f`, numGen.Next(), ` := `, p.mathPkg.Use(), `.Float64bits(float64(`, val, `))`) 377 p.callFixed64("f" + numGen.Current()) 378 p.encodeKey(fieldNumber, wireType) 379 p.Out() 380 p.P(`}`) 381 } else if proto3 { 382 p.P(`if m.`, fieldname, ` != 0 {`) 383 p.In() 384 p.callFixed64(p.mathPkg.Use(), `.Float64bits(float64(m.`+fieldname, `))`) 385 p.encodeKey(fieldNumber, wireType) 386 p.Out() 387 p.P(`}`) 388 } else if !nullable { 389 p.callFixed64(p.mathPkg.Use(), `.Float64bits(float64(m.`+fieldname, `))`) 390 p.encodeKey(fieldNumber, wireType) 391 } else { 392 p.callFixed64(p.mathPkg.Use(), `.Float64bits(float64(*m.`+fieldname, `))`) 393 p.encodeKey(fieldNumber, wireType) 394 } 395 case descriptor.FieldDescriptorProto_TYPE_FLOAT: 396 if packed { 397 val := p.reverseListRange(`m.`, fieldname) 398 p.P(`f`, numGen.Next(), ` := `, p.mathPkg.Use(), `.Float32bits(float32(`, val, `))`) 399 p.callFixed32("f" + numGen.Current()) 400 p.Out() 401 p.P(`}`) 402 p.callVarint(`len(m.`, fieldname, `) * 4`) 403 
p.encodeKey(fieldNumber, wireType) 404 } else if repeated { 405 val := p.reverseListRange(`m.`, fieldname) 406 p.P(`f`, numGen.Next(), ` := `, p.mathPkg.Use(), `.Float32bits(float32(`, val, `))`) 407 p.callFixed32("f" + numGen.Current()) 408 p.encodeKey(fieldNumber, wireType) 409 p.Out() 410 p.P(`}`) 411 } else if proto3 { 412 p.P(`if m.`, fieldname, ` != 0 {`) 413 p.In() 414 p.callFixed32(p.mathPkg.Use(), `.Float32bits(float32(m.`+fieldname, `))`) 415 p.encodeKey(fieldNumber, wireType) 416 p.Out() 417 p.P(`}`) 418 } else if !nullable { 419 p.callFixed32(p.mathPkg.Use(), `.Float32bits(float32(m.`+fieldname, `))`) 420 p.encodeKey(fieldNumber, wireType) 421 } else { 422 p.callFixed32(p.mathPkg.Use(), `.Float32bits(float32(*m.`+fieldname, `))`) 423 p.encodeKey(fieldNumber, wireType) 424 } 425 case descriptor.FieldDescriptorProto_TYPE_INT64, 426 descriptor.FieldDescriptorProto_TYPE_UINT64, 427 descriptor.FieldDescriptorProto_TYPE_INT32, 428 descriptor.FieldDescriptorProto_TYPE_UINT32, 429 descriptor.FieldDescriptorProto_TYPE_ENUM: 430 if packed { 431 jvar := "j" + numGen.Next() 432 p.P(`dAtA`, numGen.Next(), ` := make([]byte, len(m.`, fieldname, `)*10)`) 433 p.P(`var `, jvar, ` int`) 434 if *field.Type == descriptor.FieldDescriptorProto_TYPE_INT64 || 435 *field.Type == descriptor.FieldDescriptorProto_TYPE_INT32 { 436 p.P(`for _, num1 := range m.`, fieldname, ` {`) 437 p.In() 438 p.P(`num := uint64(num1)`) 439 } else { 440 p.P(`for _, num := range m.`, fieldname, ` {`) 441 p.In() 442 } 443 p.P(`for num >= 1<<7 {`) 444 p.In() 445 p.P(`dAtA`, numGen.Current(), `[`, jvar, `] = uint8(uint64(num)&0x7f|0x80)`) 446 p.P(`num >>= 7`) 447 p.P(jvar, `++`) 448 p.Out() 449 p.P(`}`) 450 p.P(`dAtA`, numGen.Current(), `[`, jvar, `] = uint8(num)`) 451 p.P(jvar, `++`) 452 p.Out() 453 p.P(`}`) 454 p.P(`i -= `, jvar) 455 p.P(`copy(dAtA[i:], dAtA`, numGen.Current(), `[:`, jvar, `])`) 456 p.callVarint(jvar) 457 p.encodeKey(fieldNumber, wireType) 458 } else if repeated { 459 val := 
p.reverseListRange(`m.`, fieldname) 460 p.callVarint(val) 461 p.encodeKey(fieldNumber, wireType) 462 p.Out() 463 p.P(`}`) 464 } else if proto3 { 465 p.P(`if m.`, fieldname, ` != 0 {`) 466 p.In() 467 p.callVarint(`m.`, fieldname) 468 p.encodeKey(fieldNumber, wireType) 469 p.Out() 470 p.P(`}`) 471 } else if !nullable { 472 p.callVarint(`m.`, fieldname) 473 p.encodeKey(fieldNumber, wireType) 474 } else { 475 p.callVarint(`*m.`, fieldname) 476 p.encodeKey(fieldNumber, wireType) 477 } 478 case descriptor.FieldDescriptorProto_TYPE_FIXED64, 479 descriptor.FieldDescriptorProto_TYPE_SFIXED64: 480 if packed { 481 val := p.reverseListRange(`m.`, fieldname) 482 p.callFixed64(val) 483 p.Out() 484 p.P(`}`) 485 p.callVarint(`len(m.`, fieldname, `) * 8`) 486 p.encodeKey(fieldNumber, wireType) 487 } else if repeated { 488 val := p.reverseListRange(`m.`, fieldname) 489 p.callFixed64(val) 490 p.encodeKey(fieldNumber, wireType) 491 p.Out() 492 p.P(`}`) 493 } else if proto3 { 494 p.P(`if m.`, fieldname, ` != 0 {`) 495 p.In() 496 p.callFixed64("m." + fieldname) 497 p.encodeKey(fieldNumber, wireType) 498 p.Out() 499 p.P(`}`) 500 } else if !nullable { 501 p.callFixed64("m." + fieldname) 502 p.encodeKey(fieldNumber, wireType) 503 } else { 504 p.callFixed64("*m." + fieldname) 505 p.encodeKey(fieldNumber, wireType) 506 } 507 case descriptor.FieldDescriptorProto_TYPE_FIXED32, 508 descriptor.FieldDescriptorProto_TYPE_SFIXED32: 509 if packed { 510 val := p.reverseListRange(`m.`, fieldname) 511 p.callFixed32(val) 512 p.Out() 513 p.P(`}`) 514 p.callVarint(`len(m.`, fieldname, `) * 4`) 515 p.encodeKey(fieldNumber, wireType) 516 } else if repeated { 517 val := p.reverseListRange(`m.`, fieldname) 518 p.callFixed32(val) 519 p.encodeKey(fieldNumber, wireType) 520 p.Out() 521 p.P(`}`) 522 } else if proto3 { 523 p.P(`if m.`, fieldname, ` != 0 {`) 524 p.In() 525 p.callFixed32("m." + fieldname) 526 p.encodeKey(fieldNumber, wireType) 527 p.Out() 528 p.P(`}`) 529 } else if !nullable { 530 p.callFixed32("m." 
+ fieldname) 531 p.encodeKey(fieldNumber, wireType) 532 } else { 533 p.callFixed32("*m." + fieldname) 534 p.encodeKey(fieldNumber, wireType) 535 } 536 case descriptor.FieldDescriptorProto_TYPE_BOOL: 537 if packed { 538 val := p.reverseListRange(`m.`, fieldname) 539 p.P(`i--`) 540 p.P(`if `, val, ` {`) 541 p.In() 542 p.P(`dAtA[i] = 1`) 543 p.Out() 544 p.P(`} else {`) 545 p.In() 546 p.P(`dAtA[i] = 0`) 547 p.Out() 548 p.P(`}`) 549 p.Out() 550 p.P(`}`) 551 p.callVarint(`len(m.`, fieldname, `)`) 552 p.encodeKey(fieldNumber, wireType) 553 } else if repeated { 554 val := p.reverseListRange(`m.`, fieldname) 555 p.P(`i--`) 556 p.P(`if `, val, ` {`) 557 p.In() 558 p.P(`dAtA[i] = 1`) 559 p.Out() 560 p.P(`} else {`) 561 p.In() 562 p.P(`dAtA[i] = 0`) 563 p.Out() 564 p.P(`}`) 565 p.encodeKey(fieldNumber, wireType) 566 p.Out() 567 p.P(`}`) 568 } else if proto3 { 569 p.P(`if m.`, fieldname, ` {`) 570 p.In() 571 p.P(`i--`) 572 p.P(`if m.`, fieldname, ` {`) 573 p.In() 574 p.P(`dAtA[i] = 1`) 575 p.Out() 576 p.P(`} else {`) 577 p.In() 578 p.P(`dAtA[i] = 0`) 579 p.Out() 580 p.P(`}`) 581 p.encodeKey(fieldNumber, wireType) 582 p.Out() 583 p.P(`}`) 584 } else if !nullable { 585 p.P(`i--`) 586 p.P(`if m.`, fieldname, ` {`) 587 p.In() 588 p.P(`dAtA[i] = 1`) 589 p.Out() 590 p.P(`} else {`) 591 p.In() 592 p.P(`dAtA[i] = 0`) 593 p.Out() 594 p.P(`}`) 595 p.encodeKey(fieldNumber, wireType) 596 } else { 597 p.P(`i--`) 598 p.P(`if *m.`, fieldname, ` {`) 599 p.In() 600 p.P(`dAtA[i] = 1`) 601 p.Out() 602 p.P(`} else {`) 603 p.In() 604 p.P(`dAtA[i] = 0`) 605 p.Out() 606 p.P(`}`) 607 p.encodeKey(fieldNumber, wireType) 608 } 609 case descriptor.FieldDescriptorProto_TYPE_STRING: 610 if repeated { 611 val := p.reverseListRange(`m.`, fieldname) 612 p.P(`i -= len(`, val, `)`) 613 p.P(`copy(dAtA[i:], `, val, `)`) 614 p.callVarint(`len(`, val, `)`) 615 p.encodeKey(fieldNumber, wireType) 616 p.Out() 617 p.P(`}`) 618 } else if proto3 { 619 p.P(`if len(m.`, fieldname, `) > 0 {`) 620 p.In() 621 p.P(`i -= 
len(m.`, fieldname, `)`) 622 p.P(`copy(dAtA[i:], m.`, fieldname, `)`) 623 p.callVarint(`len(m.`, fieldname, `)`) 624 p.encodeKey(fieldNumber, wireType) 625 p.Out() 626 p.P(`}`) 627 } else if !nullable { 628 p.P(`i -= len(m.`, fieldname, `)`) 629 p.P(`copy(dAtA[i:], m.`, fieldname, `)`) 630 p.callVarint(`len(m.`, fieldname, `)`) 631 p.encodeKey(fieldNumber, wireType) 632 } else { 633 p.P(`i -= len(*m.`, fieldname, `)`) 634 p.P(`copy(dAtA[i:], *m.`, fieldname, `)`) 635 p.callVarint(`len(*m.`, fieldname, `)`) 636 p.encodeKey(fieldNumber, wireType) 637 } 638 case descriptor.FieldDescriptorProto_TYPE_GROUP: 639 panic(fmt.Errorf("marshaler does not support group %v", fieldname)) 640 case descriptor.FieldDescriptorProto_TYPE_MESSAGE: 641 if p.IsMap(field) { 642 m := p.GoMapType(nil, field) 643 keygoTyp, keywire := p.GoType(nil, m.KeyField) 644 keygoAliasTyp, _ := p.GoType(nil, m.KeyAliasField) 645 // keys may not be pointers 646 keygoTyp = strings.Replace(keygoTyp, "*", "", 1) 647 keygoAliasTyp = strings.Replace(keygoAliasTyp, "*", "", 1) 648 keyCapTyp := generator.CamelCase(keygoTyp) 649 valuegoTyp, valuewire := p.GoType(nil, m.ValueField) 650 valuegoAliasTyp, _ := p.GoType(nil, m.ValueAliasField) 651 nullable, valuegoTyp, valuegoAliasTyp = generator.GoMapValueTypes(field, m.ValueField, valuegoTyp, valuegoAliasTyp) 652 var val string 653 if gogoproto.IsStableMarshaler(file.FileDescriptorProto, message.DescriptorProto) { 654 keysName := `keysFor` + fieldname 655 p.P(keysName, ` := make([]`, keygoTyp, `, 0, len(m.`, fieldname, `))`) 656 p.P(`for k := range m.`, fieldname, ` {`) 657 p.In() 658 p.P(keysName, ` = append(`, keysName, `, `, keygoTyp, `(k))`) 659 p.Out() 660 p.P(`}`) 661 p.P(p.sortKeysPkg.Use(), `.`, keyCapTyp, `s(`, keysName, `)`) 662 val = p.reverseListRange(keysName) 663 } else { 664 p.P(`for k := range m.`, fieldname, ` {`) 665 val = "k" 666 p.In() 667 } 668 if gogoproto.IsStableMarshaler(file.FileDescriptorProto, message.DescriptorProto) { 669 p.P(`v := 
m.`, fieldname, `[`, keygoAliasTyp, `(`, val, `)]`) 670 } else { 671 p.P(`v := m.`, fieldname, `[`, val, `]`) 672 } 673 p.P(`baseI := i`) 674 accessor := `v` 675 676 if m.ValueField.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE { 677 if valuegoTyp != valuegoAliasTyp && !gogoproto.IsStdType(m.ValueAliasField) { 678 if nullable { 679 // cast back to the type that has the generated methods on it 680 accessor = `((` + valuegoTyp + `)(` + accessor + `))` 681 } else { 682 accessor = `((*` + valuegoTyp + `)(&` + accessor + `))` 683 } 684 } else if !nullable { 685 accessor = `(&v)` 686 } 687 } 688 689 nullableMsg := nullable && (m.ValueField.GetType() == descriptor.FieldDescriptorProto_TYPE_MESSAGE || 690 gogoproto.IsCustomType(field) && m.ValueField.IsBytes()) 691 plainBytes := m.ValueField.IsBytes() && !gogoproto.IsCustomType(field) 692 if nullableMsg { 693 p.P(`if `, accessor, ` != nil { `) 694 p.In() 695 } else if plainBytes { 696 if proto3 { 697 p.P(`if len(`, accessor, `) > 0 {`) 698 } else { 699 p.P(`if `, accessor, ` != nil {`) 700 } 701 p.In() 702 } 703 p.mapField(numGen, field, m.ValueAliasField, accessor, protoSizer) 704 p.encodeKey(2, wireToType(valuewire)) 705 if nullableMsg || plainBytes { 706 p.Out() 707 p.P(`}`) 708 } 709 710 p.mapField(numGen, field, m.KeyField, val, protoSizer) 711 p.encodeKey(1, wireToType(keywire)) 712 713 p.callVarint(`baseI - i`) 714 715 p.encodeKey(fieldNumber, wireType) 716 p.Out() 717 p.P(`}`) 718 } else if repeated { 719 val := p.reverseListRange(`m.`, fieldname) 720 sizeOfVarName := val 721 if gogoproto.IsNullable(field) { 722 sizeOfVarName = `*` + val 723 } 724 if !p.marshalAllSizeOf(field, sizeOfVarName, ``) { 725 if gogoproto.IsCustomType(field) { 726 p.forward(val, true, protoSizer) 727 } else { 728 p.backward(val, true) 729 } 730 } 731 p.encodeKey(fieldNumber, wireType) 732 p.Out() 733 p.P(`}`) 734 } else { 735 sizeOfVarName := `m.` + fieldname 736 if gogoproto.IsNullable(field) { 737 sizeOfVarName = `*` + 
sizeOfVarName 738 } 739 if !p.marshalAllSizeOf(field, sizeOfVarName, numGen.Next()) { 740 if gogoproto.IsCustomType(field) { 741 p.forward(`m.`+fieldname, true, protoSizer) 742 } else { 743 p.backward(`m.`+fieldname, true) 744 } 745 } 746 p.encodeKey(fieldNumber, wireType) 747 } 748 case descriptor.FieldDescriptorProto_TYPE_BYTES: 749 if !gogoproto.IsCustomType(field) { 750 if repeated { 751 val := p.reverseListRange(`m.`, fieldname) 752 p.P(`i -= len(`, val, `)`) 753 p.P(`copy(dAtA[i:], `, val, `)`) 754 p.callVarint(`len(`, val, `)`) 755 p.encodeKey(fieldNumber, wireType) 756 p.Out() 757 p.P(`}`) 758 } else if proto3 { 759 p.P(`if len(m.`, fieldname, `) > 0 {`) 760 p.In() 761 p.P(`i -= len(m.`, fieldname, `)`) 762 p.P(`copy(dAtA[i:], m.`, fieldname, `)`) 763 p.callVarint(`len(m.`, fieldname, `)`) 764 p.encodeKey(fieldNumber, wireType) 765 p.Out() 766 p.P(`}`) 767 } else { 768 p.P(`i -= len(m.`, fieldname, `)`) 769 p.P(`copy(dAtA[i:], m.`, fieldname, `)`) 770 p.callVarint(`len(m.`, fieldname, `)`) 771 p.encodeKey(fieldNumber, wireType) 772 } 773 } else { 774 if repeated { 775 val := p.reverseListRange(`m.`, fieldname) 776 p.forward(val, true, protoSizer) 777 p.encodeKey(fieldNumber, wireType) 778 p.Out() 779 p.P(`}`) 780 } else { 781 p.forward(`m.`+fieldname, true, protoSizer) 782 p.encodeKey(fieldNumber, wireType) 783 } 784 } 785 case descriptor.FieldDescriptorProto_TYPE_SINT32: 786 if packed { 787 datavar := "dAtA" + numGen.Next() 788 jvar := "j" + numGen.Next() 789 p.P(datavar, ` := make([]byte, len(m.`, fieldname, ")*5)") 790 p.P(`var `, jvar, ` int`) 791 p.P(`for _, num := range m.`, fieldname, ` {`) 792 p.In() 793 xvar := "x" + numGen.Next() 794 p.P(xvar, ` := (uint32(num) << 1) ^ uint32((num >> 31))`) 795 p.P(`for `, xvar, ` >= 1<<7 {`) 796 p.In() 797 p.P(datavar, `[`, jvar, `] = uint8(uint64(`, xvar, `)&0x7f|0x80)`) 798 p.P(jvar, `++`) 799 p.P(xvar, ` >>= 7`) 800 p.Out() 801 p.P(`}`) 802 p.P(datavar, `[`, jvar, `] = uint8(`, xvar, `)`) 803 p.P(jvar, `++`) 
804 p.Out() 805 p.P(`}`) 806 p.P(`i -= `, jvar) 807 p.P(`copy(dAtA[i:], `, datavar, `[:`, jvar, `])`) 808 p.callVarint(jvar) 809 p.encodeKey(fieldNumber, wireType) 810 } else if repeated { 811 val := p.reverseListRange(`m.`, fieldname) 812 p.P(`x`, numGen.Next(), ` := (uint32(`, val, `) << 1) ^ uint32((`, val, ` >> 31))`) 813 p.callVarint(`x`, numGen.Current()) 814 p.encodeKey(fieldNumber, wireType) 815 p.Out() 816 p.P(`}`) 817 } else if proto3 { 818 p.P(`if m.`, fieldname, ` != 0 {`) 819 p.In() 820 p.callVarint(`(uint32(m.`, fieldname, `) << 1) ^ uint32((m.`, fieldname, ` >> 31))`) 821 p.encodeKey(fieldNumber, wireType) 822 p.Out() 823 p.P(`}`) 824 } else if !nullable { 825 p.callVarint(`(uint32(m.`, fieldname, `) << 1) ^ uint32((m.`, fieldname, ` >> 31))`) 826 p.encodeKey(fieldNumber, wireType) 827 } else { 828 p.callVarint(`(uint32(*m.`, fieldname, `) << 1) ^ uint32((*m.`, fieldname, ` >> 31))`) 829 p.encodeKey(fieldNumber, wireType) 830 } 831 case descriptor.FieldDescriptorProto_TYPE_SINT64: 832 if packed { 833 jvar := "j" + numGen.Next() 834 xvar := "x" + numGen.Next() 835 datavar := "dAtA" + numGen.Next() 836 p.P(`var `, jvar, ` int`) 837 p.P(datavar, ` := make([]byte, len(m.`, fieldname, `)*10)`) 838 p.P(`for _, num := range m.`, fieldname, ` {`) 839 p.In() 840 p.P(xvar, ` := (uint64(num) << 1) ^ uint64((num >> 63))`) 841 p.P(`for `, xvar, ` >= 1<<7 {`) 842 p.In() 843 p.P(datavar, `[`, jvar, `] = uint8(uint64(`, xvar, `)&0x7f|0x80)`) 844 p.P(jvar, `++`) 845 p.P(xvar, ` >>= 7`) 846 p.Out() 847 p.P(`}`) 848 p.P(datavar, `[`, jvar, `] = uint8(`, xvar, `)`) 849 p.P(jvar, `++`) 850 p.Out() 851 p.P(`}`) 852 p.P(`i -= `, jvar) 853 p.P(`copy(dAtA[i:], `, datavar, `[:`, jvar, `])`) 854 p.callVarint(jvar) 855 p.encodeKey(fieldNumber, wireType) 856 } else if repeated { 857 val := p.reverseListRange(`m.`, fieldname) 858 p.P(`x`, numGen.Next(), ` := (uint64(`, val, `) << 1) ^ uint64((`, val, ` >> 63))`) 859 p.callVarint("x" + numGen.Current()) 860 
p.encodeKey(fieldNumber, wireType) 861 p.Out() 862 p.P(`}`) 863 } else if proto3 { 864 p.P(`if m.`, fieldname, ` != 0 {`) 865 p.In() 866 p.callVarint(`(uint64(m.`, fieldname, `) << 1) ^ uint64((m.`, fieldname, ` >> 63))`) 867 p.encodeKey(fieldNumber, wireType) 868 p.Out() 869 p.P(`}`) 870 } else if !nullable { 871 p.callVarint(`(uint64(m.`, fieldname, `) << 1) ^ uint64((m.`, fieldname, ` >> 63))`) 872 p.encodeKey(fieldNumber, wireType) 873 } else { 874 p.callVarint(`(uint64(*m.`, fieldname, `) << 1) ^ uint64((*m.`, fieldname, ` >> 63))`) 875 p.encodeKey(fieldNumber, wireType) 876 } 877 default: 878 panic("not implemented") 879 } 880 if (required && nullable) || repeated || doNilCheck { 881 p.Out() 882 p.P(`}`) 883 } 884 } 885 886 func (p *marshalto) Generate(file *generator.FileDescriptor) { 887 numGen := NewNumGen() 888 p.PluginImports = generator.NewPluginImports(p.Generator) 889 890 p.atleastOne = false 891 p.localName = generator.FileName(file) 892 893 p.mathPkg = p.NewImport("math") 894 p.sortKeysPkg = p.NewImport("github.com/gogo/protobuf/sortkeys") 895 p.protoPkg = p.NewImport("github.com/gogo/protobuf/proto") 896 if !gogoproto.ImportsGoGoProto(file.FileDescriptorProto) { 897 p.protoPkg = p.NewImport("github.com/golang/protobuf/proto") 898 } 899 p.errorsPkg = p.NewImport("errors") 900 p.binaryPkg = p.NewImport("encoding/binary") 901 p.typesPkg = p.NewImport("github.com/gogo/protobuf/types") 902 903 for _, message := range file.Messages() { 904 if message.DescriptorProto.GetOptions().GetMapEntry() { 905 continue 906 } 907 ccTypeName := generator.CamelCaseSlice(message.TypeName()) 908 if !gogoproto.IsMarshaler(file.FileDescriptorProto, message.DescriptorProto) && 909 !gogoproto.IsUnsafeMarshaler(file.FileDescriptorProto, message.DescriptorProto) { 910 continue 911 } 912 p.atleastOne = true 913 914 p.P(`func (m *`, ccTypeName, `) Marshal() (dAtA []byte, err error) {`) 915 p.In() 916 if gogoproto.IsProtoSizer(file.FileDescriptorProto, message.DescriptorProto) { 
917 p.P(`size := m.ProtoSize()`) 918 } else { 919 p.P(`size := m.Size()`) 920 } 921 p.P(`dAtA = make([]byte, size)`) 922 p.P(`n, err := m.MarshalToSizedBuffer(dAtA[:size])`) 923 p.P(`if err != nil {`) 924 p.In() 925 p.P(`return nil, err`) 926 p.Out() 927 p.P(`}`) 928 p.P(`return dAtA[:n], nil`) 929 p.Out() 930 p.P(`}`) 931 p.P(``) 932 p.P(`func (m *`, ccTypeName, `) MarshalTo(dAtA []byte) (int, error) {`) 933 p.In() 934 if gogoproto.IsProtoSizer(file.FileDescriptorProto, message.DescriptorProto) { 935 p.P(`size := m.ProtoSize()`) 936 } else { 937 p.P(`size := m.Size()`) 938 } 939 p.P(`return m.MarshalToSizedBuffer(dAtA[:size])`) 940 p.Out() 941 p.P(`}`) 942 p.P(``) 943 p.P(`func (m *`, ccTypeName, `) MarshalToSizedBuffer(dAtA []byte) (int, error) {`) 944 p.In() 945 p.P(`i := len(dAtA)`) 946 p.P(`_ = i`) 947 p.P(`var l int`) 948 p.P(`_ = l`) 949 if gogoproto.HasUnrecognized(file.FileDescriptorProto, message.DescriptorProto) { 950 p.P(`if m.XXX_unrecognized != nil {`) 951 p.In() 952 p.P(`i -= len(m.XXX_unrecognized)`) 953 p.P(`copy(dAtA[i:], m.XXX_unrecognized)`) 954 p.Out() 955 p.P(`}`) 956 } 957 if message.DescriptorProto.HasExtension() { 958 if gogoproto.HasExtensionsMap(file.FileDescriptorProto, message.DescriptorProto) { 959 p.P(`if n, err := `, p.protoPkg.Use(), `.EncodeInternalExtensionBackwards(m, dAtA[:i]); err != nil {`) 960 p.In() 961 p.P(`return 0, err`) 962 p.Out() 963 p.P(`} else {`) 964 p.In() 965 p.P(`i -= n`) 966 p.Out() 967 p.P(`}`) 968 } else { 969 p.P(`if m.XXX_extensions != nil {`) 970 p.In() 971 p.P(`i -= len(m.XXX_extensions)`) 972 p.P(`copy(dAtA[i:], m.XXX_extensions)`) 973 p.Out() 974 p.P(`}`) 975 } 976 } 977 fields := orderFields(message.GetField()) 978 sort.Sort(fields) 979 oneofs := make(map[string]struct{}) 980 for i := len(message.Field) - 1; i >= 0; i-- { 981 field := message.Field[i] 982 oneof := field.OneofIndex != nil 983 if !oneof { 984 proto3 := gogoproto.IsProto3(file.FileDescriptorProto) 985 p.generateField(proto3, numGen, file, 
message, field) 986 } else { 987 fieldname := p.GetFieldName(message, field) 988 if _, ok := oneofs[fieldname]; !ok { 989 oneofs[fieldname] = struct{}{} 990 p.P(`if m.`, fieldname, ` != nil {`) 991 p.In() 992 p.forward(`m.`+fieldname, false, gogoproto.IsProtoSizer(file.FileDescriptorProto, message.DescriptorProto)) 993 p.Out() 994 p.P(`}`) 995 } 996 } 997 } 998 p.P(`return len(dAtA) - i, nil`) 999 p.Out() 1000 p.P(`}`) 1001 p.P() 1002 1003 //Generate MarshalTo methods for oneof fields 1004 m := proto.Clone(message.DescriptorProto).(*descriptor.DescriptorProto) 1005 for _, field := range m.Field { 1006 oneof := field.OneofIndex != nil 1007 if !oneof { 1008 continue 1009 } 1010 ccTypeName := p.OneOfTypeName(message, field) 1011 p.P(`func (m *`, ccTypeName, `) MarshalTo(dAtA []byte) (int, error) {`) 1012 p.In() 1013 if gogoproto.IsProtoSizer(file.FileDescriptorProto, message.DescriptorProto) { 1014 p.P(`size := m.ProtoSize()`) 1015 } else { 1016 p.P(`size := m.Size()`) 1017 } 1018 p.P(`return m.MarshalToSizedBuffer(dAtA[:size])`) 1019 p.Out() 1020 p.P(`}`) 1021 p.P(``) 1022 p.P(`func (m *`, ccTypeName, `) MarshalToSizedBuffer(dAtA []byte) (int, error) {`) 1023 p.In() 1024 p.P(`i := len(dAtA)`) 1025 vanity.TurnOffNullableForNativeTypes(field) 1026 p.generateField(false, numGen, file, message, field) 1027 p.P(`return len(dAtA) - i, nil`) 1028 p.Out() 1029 p.P(`}`) 1030 } 1031 } 1032 1033 if p.atleastOne { 1034 p.P(`func encodeVarint`, p.localName, `(dAtA []byte, offset int, v uint64) int {`) 1035 p.In() 1036 p.P(`offset -= sov`, p.localName, `(v)`) 1037 p.P(`base := offset`) 1038 p.P(`for v >= 1<<7 {`) 1039 p.In() 1040 p.P(`dAtA[offset] = uint8(v&0x7f|0x80)`) 1041 p.P(`v >>= 7`) 1042 p.P(`offset++`) 1043 p.Out() 1044 p.P(`}`) 1045 p.P(`dAtA[offset] = uint8(v)`) 1046 p.P(`return base`) 1047 p.Out() 1048 p.P(`}`) 1049 } 1050 1051 } 1052 1053 func (p *marshalto) reverseListRange(expression ...string) string { 1054 exp := strings.Join(expression, "") 1055 p.P(`for iNdEx := 
len(`, exp, `) - 1; iNdEx >= 0; iNdEx-- {`) 1056 p.In() 1057 return exp + `[iNdEx]` 1058 } 1059 1060 func (p *marshalto) marshalAllSizeOf(field *descriptor.FieldDescriptorProto, varName, num string) bool { 1061 if gogoproto.IsStdTime(field) { 1062 p.marshalSizeOf(`StdTimeMarshalTo`, `SizeOfStdTime`, varName, num) 1063 } else if gogoproto.IsStdDuration(field) { 1064 p.marshalSizeOf(`StdDurationMarshalTo`, `SizeOfStdDuration`, varName, num) 1065 } else if gogoproto.IsStdDouble(field) { 1066 p.marshalSizeOf(`StdDoubleMarshalTo`, `SizeOfStdDouble`, varName, num) 1067 } else if gogoproto.IsStdFloat(field) { 1068 p.marshalSizeOf(`StdFloatMarshalTo`, `SizeOfStdFloat`, varName, num) 1069 } else if gogoproto.IsStdInt64(field) { 1070 p.marshalSizeOf(`StdInt64MarshalTo`, `SizeOfStdInt64`, varName, num) 1071 } else if gogoproto.IsStdUInt64(field) { 1072 p.marshalSizeOf(`StdUInt64MarshalTo`, `SizeOfStdUInt64`, varName, num) 1073 } else if gogoproto.IsStdInt32(field) { 1074 p.marshalSizeOf(`StdInt32MarshalTo`, `SizeOfStdInt32`, varName, num) 1075 } else if gogoproto.IsStdUInt32(field) { 1076 p.marshalSizeOf(`StdUInt32MarshalTo`, `SizeOfStdUInt32`, varName, num) 1077 } else if gogoproto.IsStdBool(field) { 1078 p.marshalSizeOf(`StdBoolMarshalTo`, `SizeOfStdBool`, varName, num) 1079 } else if gogoproto.IsStdString(field) { 1080 p.marshalSizeOf(`StdStringMarshalTo`, `SizeOfStdString`, varName, num) 1081 } else if gogoproto.IsStdBytes(field) { 1082 p.marshalSizeOf(`StdBytesMarshalTo`, `SizeOfStdBytes`, varName, num) 1083 } else { 1084 return false 1085 } 1086 return true 1087 } 1088 1089 func (p *marshalto) marshalSizeOf(marshal, size, varName, num string) { 1090 p.P(`n`, num, `, err`, num, ` := `, p.typesPkg.Use(), `.`, marshal, `(`, varName, `, dAtA[i-`, p.typesPkg.Use(), `.`, size, `(`, varName, `):])`) 1091 p.P(`if err`, num, ` != nil {`) 1092 p.In() 1093 p.P(`return 0, err`, num) 1094 p.Out() 1095 p.P(`}`) 1096 p.P(`i -= n`, num) 1097 p.callVarint(`n`, num) 1098 } 1099 1100 func 
(p *marshalto) backward(varName string, varInt bool) { 1101 p.P(`{`) 1102 p.In() 1103 p.P(`size, err := `, varName, `.MarshalToSizedBuffer(dAtA[:i])`) 1104 p.P(`if err != nil {`) 1105 p.In() 1106 p.P(`return 0, err`) 1107 p.Out() 1108 p.P(`}`) 1109 p.P(`i -= size`) 1110 if varInt { 1111 p.callVarint(`size`) 1112 } 1113 p.Out() 1114 p.P(`}`) 1115 } 1116 1117 func (p *marshalto) forward(varName string, varInt, protoSizer bool) { 1118 p.P(`{`) 1119 p.In() 1120 if protoSizer { 1121 p.P(`size := `, varName, `.ProtoSize()`) 1122 } else { 1123 p.P(`size := `, varName, `.Size()`) 1124 } 1125 p.P(`i -= size`) 1126 p.P(`if _, err := `, varName, `.MarshalTo(dAtA[i:]); err != nil {`) 1127 p.In() 1128 p.P(`return 0, err`) 1129 p.Out() 1130 p.P(`}`) 1131 p.Out() 1132 if varInt { 1133 p.callVarint(`size`) 1134 } 1135 p.P(`}`) 1136 } 1137 1138 func init() { 1139 generator.RegisterPlugin(NewMarshal()) 1140 }