github.com/Cleverse/go-ethereum@v0.0.0-20220927095127-45113064e7f2/rlp/rlpgen/gen.go (about) 1 // Copyright 2022 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package main 18 19 import ( 20 "bytes" 21 "fmt" 22 "go/format" 23 "go/types" 24 "sort" 25 26 "github.com/ethereum/go-ethereum/rlp/internal/rlpstruct" 27 ) 28 29 // buildContext keeps the data needed for make*Op. 30 type buildContext struct { 31 topType *types.Named // the type we're creating methods for 32 33 encoderIface *types.Interface 34 decoderIface *types.Interface 35 rawValueType *types.Named 36 37 typeToStructCache map[types.Type]*rlpstruct.Type 38 } 39 40 func newBuildContext(packageRLP *types.Package) *buildContext { 41 enc := packageRLP.Scope().Lookup("Encoder").Type().Underlying() 42 dec := packageRLP.Scope().Lookup("Decoder").Type().Underlying() 43 rawv := packageRLP.Scope().Lookup("RawValue").Type() 44 return &buildContext{ 45 typeToStructCache: make(map[types.Type]*rlpstruct.Type), 46 encoderIface: enc.(*types.Interface), 47 decoderIface: dec.(*types.Interface), 48 rawValueType: rawv.(*types.Named), 49 } 50 } 51 52 func (bctx *buildContext) isEncoder(typ types.Type) bool { 53 return types.Implements(typ, bctx.encoderIface) 54 } 55 56 func (bctx *buildContext) isDecoder(typ types.Type) bool { 57 return types.Implements(typ, bctx.decoderIface) 58 } 59 60 // typeToStructType converts typ to rlpstruct.Type. 61 func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type { 62 if prev := bctx.typeToStructCache[typ]; prev != nil { 63 return prev // short-circuit for recursive types. 64 } 65 66 // Resolve named types to their underlying type, but keep the name. 67 name := types.TypeString(typ, nil) 68 for { 69 utype := typ.Underlying() 70 if utype == typ { 71 break 72 } 73 typ = utype 74 } 75 76 // Create the type and store it in cache. 77 t := &rlpstruct.Type{ 78 Name: name, 79 Kind: typeReflectKind(typ), 80 IsEncoder: bctx.isEncoder(typ), 81 IsDecoder: bctx.isDecoder(typ), 82 } 83 bctx.typeToStructCache[typ] = t 84 85 // Assign element type. 86 switch typ.(type) { 87 case *types.Array, *types.Slice, *types.Pointer: 88 etype := typ.(interface{ Elem() types.Type }).Elem() 89 t.Elem = bctx.typeToStructType(etype) 90 } 91 return t 92 } 93 94 // genContext is passed to the gen* methods of op when generating 95 // the output code. It tracks packages to be imported by the output 96 // file and assigns unique names of temporary variables. 97 type genContext struct { 98 inPackage *types.Package 99 imports map[string]struct{} 100 tempCounter int 101 } 102 103 func newGenContext(inPackage *types.Package) *genContext { 104 return &genContext{ 105 inPackage: inPackage, 106 imports: make(map[string]struct{}), 107 } 108 } 109 110 func (ctx *genContext) temp() string { 111 v := fmt.Sprintf("_tmp%d", ctx.tempCounter) 112 ctx.tempCounter++ 113 return v 114 } 115 116 func (ctx *genContext) resetTemp() { 117 ctx.tempCounter = 0 118 } 119 120 func (ctx *genContext) addImport(path string) { 121 if path == ctx.inPackage.Path() { 122 return // avoid importing the package that we're generating in. 123 } 124 // TODO: renaming? 125 ctx.imports[path] = struct{}{} 126 } 127 128 // importsList returns all packages that need to be imported. 129 func (ctx *genContext) importsList() []string { 130 imp := make([]string, 0, len(ctx.imports)) 131 for k := range ctx.imports { 132 imp = append(imp, k) 133 } 134 sort.Strings(imp) 135 return imp 136 } 137 138 // qualify is the types.Qualifier used for printing types. 139 func (ctx *genContext) qualify(pkg *types.Package) string { 140 if pkg.Path() == ctx.inPackage.Path() { 141 return "" 142 } 143 ctx.addImport(pkg.Path()) 144 // TODO: renaming? 145 return pkg.Name() 146 } 147 148 type op interface { 149 // genWrite creates the encoder. The generated code should write v, 150 // which is any Go expression, to the rlp.EncoderBuffer 'w'. 151 genWrite(ctx *genContext, v string) string 152 153 // genDecode creates the decoder. The generated code should read 154 // a value from the rlp.Stream 'dec' and store it to dst. 155 genDecode(ctx *genContext) (string, string) 156 } 157 158 // basicOp handles basic types bool, uint*, string. 159 type basicOp struct { 160 typ types.Type 161 writeMethod string // calle write the value 162 writeArgType types.Type // parameter type of writeMethod 163 decMethod string 164 decResultType types.Type // return type of decMethod 165 decUseBitSize bool // if true, result bit size is appended to decMethod 166 } 167 168 func (*buildContext) makeBasicOp(typ *types.Basic) (op, error) { 169 op := basicOp{typ: typ} 170 kind := typ.Kind() 171 switch { 172 case kind == types.Bool: 173 op.writeMethod = "WriteBool" 174 op.writeArgType = types.Typ[types.Bool] 175 op.decMethod = "Bool" 176 op.decResultType = types.Typ[types.Bool] 177 case kind >= types.Uint8 && kind <= types.Uint64: 178 op.writeMethod = "WriteUint64" 179 op.writeArgType = types.Typ[types.Uint64] 180 op.decMethod = "Uint" 181 op.decResultType = typ 182 op.decUseBitSize = true 183 case kind == types.String: 184 op.writeMethod = "WriteString" 185 op.writeArgType = types.Typ[types.String] 186 op.decMethod = "String" 187 op.decResultType = types.Typ[types.String] 188 default: 189 return nil, fmt.Errorf("unhandled basic type: %v", typ) 190 } 191 return op, nil 192 } 193 194 func (*buildContext) makeByteSliceOp(typ *types.Slice) op { 195 if !isByte(typ.Elem()) { 196 panic("non-byte slice type in makeByteSliceOp") 197 } 198 bslice := types.NewSlice(types.Typ[types.Uint8]) 199 return basicOp{ 200 typ: typ, 201 writeMethod: "WriteBytes", 202 writeArgType: bslice, 203 decMethod: "Bytes", 204 decResultType: bslice, 205 } 206 } 207 208 func (bctx *buildContext) makeRawValueOp() op { 209 bslice := types.NewSlice(types.Typ[types.Uint8]) 210 return basicOp{ 211 typ: bctx.rawValueType, 212 writeMethod: "Write", 213 writeArgType: bslice, 214 decMethod: "Raw", 215 decResultType: bslice, 216 } 217 } 218 219 func (op basicOp) writeNeedsConversion() bool { 220 return !types.AssignableTo(op.typ, op.writeArgType) 221 } 222 223 func (op basicOp) decodeNeedsConversion() bool { 224 return !types.AssignableTo(op.decResultType, op.typ) 225 } 226 227 func (op basicOp) genWrite(ctx *genContext, v string) string { 228 if op.writeNeedsConversion() { 229 v = fmt.Sprintf("%s(%s)", op.writeArgType, v) 230 } 231 return fmt.Sprintf("w.%s(%s)\n", op.writeMethod, v) 232 } 233 234 func (op basicOp) genDecode(ctx *genContext) (string, string) { 235 var ( 236 resultV = ctx.temp() 237 result = resultV 238 method = op.decMethod 239 ) 240 if op.decUseBitSize { 241 // Note: For now, this only works for platform-independent integer 242 // sizes. makeBasicOp forbids the platform-dependent types. 243 var sizes types.StdSizes 244 method = fmt.Sprintf("%s%d", op.decMethod, sizes.Sizeof(op.typ)*8) 245 } 246 247 // Call the decoder method. 248 var b bytes.Buffer 249 fmt.Fprintf(&b, "%s, err := dec.%s()\n", resultV, method) 250 fmt.Fprintf(&b, "if err != nil { return err }\n") 251 if op.decodeNeedsConversion() { 252 conv := ctx.temp() 253 fmt.Fprintf(&b, "%s := %s(%s)\n", conv, types.TypeString(op.typ, ctx.qualify), resultV) 254 result = conv 255 } 256 return result, b.String() 257 } 258 259 // byteArrayOp handles [...]byte. 260 type byteArrayOp struct { 261 typ types.Type 262 name types.Type // name != typ for named byte array types (e.g. common.Address) 263 } 264 265 func (bctx *buildContext) makeByteArrayOp(name *types.Named, typ *types.Array) byteArrayOp { 266 nt := types.Type(name) 267 if name == nil { 268 nt = typ 269 } 270 return byteArrayOp{typ, nt} 271 } 272 273 func (op byteArrayOp) genWrite(ctx *genContext, v string) string { 274 return fmt.Sprintf("w.WriteBytes(%s[:])\n", v) 275 } 276 277 func (op byteArrayOp) genDecode(ctx *genContext) (string, string) { 278 var resultV = ctx.temp() 279 280 var b bytes.Buffer 281 fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(op.name, ctx.qualify)) 282 fmt.Fprintf(&b, "if err := dec.ReadBytes(%s[:]); err != nil { return err }\n", resultV) 283 return resultV, b.String() 284 } 285 286 // bigIntNoPtrOp handles non-pointer big.Int. 287 // This exists because big.Int has it's own decoder operation on rlp.Stream, 288 // but the decode method returns *big.Int, so it needs to be dereferenced. 289 type bigIntOp struct { 290 pointer bool 291 } 292 293 func (op bigIntOp) genWrite(ctx *genContext, v string) string { 294 var b bytes.Buffer 295 296 fmt.Fprintf(&b, "if %s.Sign() == -1 {\n", v) 297 fmt.Fprintf(&b, " return rlp.ErrNegativeBigInt\n") 298 fmt.Fprintf(&b, "}\n") 299 dst := v 300 if !op.pointer { 301 dst = "&" + v 302 } 303 fmt.Fprintf(&b, "w.WriteBigInt(%s)\n", dst) 304 305 // Wrap with nil check. 306 if op.pointer { 307 code := b.String() 308 b.Reset() 309 fmt.Fprintf(&b, "if %s == nil {\n", v) 310 fmt.Fprintf(&b, " w.Write(rlp.EmptyString)") 311 fmt.Fprintf(&b, "} else {\n") 312 fmt.Fprint(&b, code) 313 fmt.Fprintf(&b, "}\n") 314 } 315 316 return b.String() 317 } 318 319 func (op bigIntOp) genDecode(ctx *genContext) (string, string) { 320 var resultV = ctx.temp() 321 322 var b bytes.Buffer 323 fmt.Fprintf(&b, "%s, err := dec.BigInt()\n", resultV) 324 fmt.Fprintf(&b, "if err != nil { return err }\n") 325 326 result := resultV 327 if !op.pointer { 328 result = "(*" + resultV + ")" 329 } 330 return result, b.String() 331 } 332 333 // encoderDecoderOp handles rlp.Encoder and rlp.Decoder. 334 // In order to be used with this, the type must implement both interfaces. 335 // This restriction may be lifted in the future by creating separate ops for 336 // encoding and decoding. 337 type encoderDecoderOp struct { 338 typ types.Type 339 } 340 341 func (op encoderDecoderOp) genWrite(ctx *genContext, v string) string { 342 return fmt.Sprintf("if err := %s.EncodeRLP(w); err != nil { return err }\n", v) 343 } 344 345 func (op encoderDecoderOp) genDecode(ctx *genContext) (string, string) { 346 // DecodeRLP must have pointer receiver, and this is verified in makeOp. 347 etyp := op.typ.(*types.Pointer).Elem() 348 var resultV = ctx.temp() 349 350 var b bytes.Buffer 351 fmt.Fprintf(&b, "%s := new(%s)\n", resultV, types.TypeString(etyp, ctx.qualify)) 352 fmt.Fprintf(&b, "if err := %s.DecodeRLP(dec); err != nil { return err }\n", resultV) 353 return resultV, b.String() 354 } 355 356 // ptrOp handles pointer types. 357 type ptrOp struct { 358 elemTyp types.Type 359 elem op 360 nilOK bool 361 nilValue rlpstruct.NilKind 362 } 363 364 func (bctx *buildContext) makePtrOp(elemTyp types.Type, tags rlpstruct.Tags) (op, error) { 365 elemOp, err := bctx.makeOp(nil, elemTyp, rlpstruct.Tags{}) 366 if err != nil { 367 return nil, err 368 } 369 op := ptrOp{elemTyp: elemTyp, elem: elemOp} 370 371 // Determine nil value. 372 if tags.NilOK { 373 op.nilOK = true 374 op.nilValue = tags.NilKind 375 } else { 376 styp := bctx.typeToStructType(elemTyp) 377 op.nilValue = styp.DefaultNilValue() 378 } 379 return op, nil 380 } 381 382 func (op ptrOp) genWrite(ctx *genContext, v string) string { 383 // Note: in writer functions, accesses to v are read-only, i.e. v is any Go 384 // expression. To make all accesses work through the pointer, we substitute 385 // v with (*v). This is required for most accesses including `v`, `call(v)`, 386 // and `v[index]` on slices. 387 // 388 // For `v.field` and `v[:]` on arrays, the dereference operation is not required. 389 var vv string 390 _, isStruct := op.elem.(structOp) 391 _, isByteArray := op.elem.(byteArrayOp) 392 if isStruct || isByteArray { 393 vv = v 394 } else { 395 vv = fmt.Sprintf("(*%s)", v) 396 } 397 398 var b bytes.Buffer 399 fmt.Fprintf(&b, "if %s == nil {\n", v) 400 fmt.Fprintf(&b, " w.Write([]byte{0x%X})\n", op.nilValue) 401 fmt.Fprintf(&b, "} else {\n") 402 fmt.Fprintf(&b, " %s", op.elem.genWrite(ctx, vv)) 403 fmt.Fprintf(&b, "}\n") 404 return b.String() 405 } 406 407 func (op ptrOp) genDecode(ctx *genContext) (string, string) { 408 result, code := op.elem.genDecode(ctx) 409 if !op.nilOK { 410 // If nil pointers are not allowed, we can just decode the element. 411 return "&" + result, code 412 } 413 414 // nil is allowed, so check the kind and size first. 415 // If size is zero and kind matches the nilKind of the type, 416 // the value decodes as a nil pointer. 417 var ( 418 resultV = ctx.temp() 419 kindV = ctx.temp() 420 sizeV = ctx.temp() 421 wantKind string 422 ) 423 if op.nilValue == rlpstruct.NilKindList { 424 wantKind = "rlp.List" 425 } else { 426 wantKind = "rlp.String" 427 } 428 var b bytes.Buffer 429 fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(types.NewPointer(op.elemTyp), ctx.qualify)) 430 fmt.Fprintf(&b, "if %s, %s, err := dec.Kind(); err != nil {\n", kindV, sizeV) 431 fmt.Fprintf(&b, " return err\n") 432 fmt.Fprintf(&b, "} else if %s != 0 || %s != %s {\n", sizeV, kindV, wantKind) 433 fmt.Fprint(&b, code) 434 fmt.Fprintf(&b, " %s = &%s\n", resultV, result) 435 fmt.Fprintf(&b, "}\n") 436 return resultV, b.String() 437 } 438 439 // structOp handles struct types. 440 type structOp struct { 441 named *types.Named 442 typ *types.Struct 443 fields []*structField 444 optionalFields []*structField 445 } 446 447 type structField struct { 448 name string 449 typ types.Type 450 elem op 451 } 452 453 func (bctx *buildContext) makeStructOp(named *types.Named, typ *types.Struct) (op, error) { 454 // Convert fields to []rlpstruct.Field. 455 var allStructFields []rlpstruct.Field 456 for i := 0; i < typ.NumFields(); i++ { 457 f := typ.Field(i) 458 allStructFields = append(allStructFields, rlpstruct.Field{ 459 Name: f.Name(), 460 Exported: f.Exported(), 461 Index: i, 462 Tag: typ.Tag(i), 463 Type: *bctx.typeToStructType(f.Type()), 464 }) 465 } 466 467 // Filter/validate fields. 468 fields, tags, err := rlpstruct.ProcessFields(allStructFields) 469 if err != nil { 470 return nil, err 471 } 472 473 // Create field ops. 474 var op = structOp{named: named, typ: typ} 475 for i, field := range fields { 476 // Advanced struct tags are not supported yet. 477 tag := tags[i] 478 if err := checkUnsupportedTags(field.Name, tag); err != nil { 479 return nil, err 480 } 481 typ := typ.Field(field.Index).Type() 482 elem, err := bctx.makeOp(nil, typ, tags[i]) 483 if err != nil { 484 return nil, fmt.Errorf("field %s: %v", field.Name, err) 485 } 486 f := &structField{name: field.Name, typ: typ, elem: elem} 487 if tag.Optional { 488 op.optionalFields = append(op.optionalFields, f) 489 } else { 490 op.fields = append(op.fields, f) 491 } 492 } 493 return op, nil 494 } 495 496 func checkUnsupportedTags(field string, tag rlpstruct.Tags) error { 497 if tag.Tail { 498 return fmt.Errorf(`field %s has unsupported struct tag "tail"`, field) 499 } 500 return nil 501 } 502 503 func (op structOp) genWrite(ctx *genContext, v string) string { 504 var b bytes.Buffer 505 var listMarker = ctx.temp() 506 fmt.Fprintf(&b, "%s := w.List()\n", listMarker) 507 for _, field := range op.fields { 508 selector := v + "." + field.name 509 fmt.Fprint(&b, field.elem.genWrite(ctx, selector)) 510 } 511 op.writeOptionalFields(&b, ctx, v) 512 fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker) 513 return b.String() 514 } 515 516 func (op structOp) writeOptionalFields(b *bytes.Buffer, ctx *genContext, v string) { 517 if len(op.optionalFields) == 0 { 518 return 519 } 520 // First check zero-ness of all optional fields. 521 var zeroV = make([]string, len(op.optionalFields)) 522 for i, field := range op.optionalFields { 523 selector := v + "." + field.name 524 zeroV[i] = ctx.temp() 525 fmt.Fprintf(b, "%s := %s\n", zeroV[i], nonZeroCheck(selector, field.typ, ctx.qualify)) 526 } 527 // Now write the fields. 528 for i, field := range op.optionalFields { 529 selector := v + "." + field.name 530 cond := "" 531 for j := i; j < len(op.optionalFields); j++ { 532 if j > i { 533 cond += " || " 534 } 535 cond += zeroV[j] 536 } 537 fmt.Fprintf(b, "if %s {\n", cond) 538 fmt.Fprint(b, field.elem.genWrite(ctx, selector)) 539 fmt.Fprintf(b, "}\n") 540 } 541 } 542 543 func (op structOp) genDecode(ctx *genContext) (string, string) { 544 // Get the string representation of the type. 545 // Here, named types are handled separately because the output 546 // would contain a copy of the struct definition otherwise. 547 var typeName string 548 if op.named != nil { 549 typeName = types.TypeString(op.named, ctx.qualify) 550 } else { 551 typeName = types.TypeString(op.typ, ctx.qualify) 552 } 553 554 // Create struct object. 555 var resultV = ctx.temp() 556 var b bytes.Buffer 557 fmt.Fprintf(&b, "var %s %s\n", resultV, typeName) 558 559 // Decode fields. 560 fmt.Fprintf(&b, "{\n") 561 fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n") 562 for _, field := range op.fields { 563 result, code := field.elem.genDecode(ctx) 564 fmt.Fprintf(&b, "// %s:\n", field.name) 565 fmt.Fprint(&b, code) 566 fmt.Fprintf(&b, "%s.%s = %s\n", resultV, field.name, result) 567 } 568 op.decodeOptionalFields(&b, ctx, resultV) 569 fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n") 570 fmt.Fprintf(&b, "}\n") 571 return resultV, b.String() 572 } 573 574 func (op structOp) decodeOptionalFields(b *bytes.Buffer, ctx *genContext, resultV string) { 575 var suffix bytes.Buffer 576 for _, field := range op.optionalFields { 577 result, code := field.elem.genDecode(ctx) 578 fmt.Fprintf(b, "// %s:\n", field.name) 579 fmt.Fprintf(b, "if dec.MoreDataInList() {\n") 580 fmt.Fprint(b, code) 581 fmt.Fprintf(b, "%s.%s = %s\n", resultV, field.name, result) 582 fmt.Fprintf(&suffix, "}\n") 583 } 584 suffix.WriteTo(b) 585 } 586 587 // sliceOp handles slice types. 588 type sliceOp struct { 589 typ *types.Slice 590 elemOp op 591 } 592 593 func (bctx *buildContext) makeSliceOp(typ *types.Slice) (op, error) { 594 elemOp, err := bctx.makeOp(nil, typ.Elem(), rlpstruct.Tags{}) 595 if err != nil { 596 return nil, err 597 } 598 return sliceOp{typ: typ, elemOp: elemOp}, nil 599 } 600 601 func (op sliceOp) genWrite(ctx *genContext, v string) string { 602 var ( 603 listMarker = ctx.temp() // holds return value of w.List() 604 iterElemV = ctx.temp() // iteration variable 605 elemCode = op.elemOp.genWrite(ctx, iterElemV) 606 ) 607 608 var b bytes.Buffer 609 fmt.Fprintf(&b, "%s := w.List()\n", listMarker) 610 fmt.Fprintf(&b, "for _, %s := range %s {\n", iterElemV, v) 611 fmt.Fprint(&b, elemCode) 612 fmt.Fprintf(&b, "}\n") 613 fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker) 614 return b.String() 615 } 616 617 func (op sliceOp) genDecode(ctx *genContext) (string, string) { 618 var sliceV = ctx.temp() // holds the output slice 619 elemResult, elemCode := op.elemOp.genDecode(ctx) 620 621 var b bytes.Buffer 622 fmt.Fprintf(&b, "var %s %s\n", sliceV, types.TypeString(op.typ, ctx.qualify)) 623 fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n") 624 fmt.Fprintf(&b, "for dec.MoreDataInList() {\n") 625 fmt.Fprintf(&b, " %s", elemCode) 626 fmt.Fprintf(&b, " %s = append(%s, %s)\n", sliceV, sliceV, elemResult) 627 fmt.Fprintf(&b, "}\n") 628 fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n") 629 return sliceV, b.String() 630 } 631 632 func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstruct.Tags) (op, error) { 633 switch typ := typ.(type) { 634 case *types.Named: 635 if isBigInt(typ) { 636 return bigIntOp{}, nil 637 } 638 if typ == bctx.rawValueType { 639 return bctx.makeRawValueOp(), nil 640 } 641 if bctx.isDecoder(typ) { 642 return nil, fmt.Errorf("type %v implements rlp.Decoder with non-pointer receiver", typ) 643 } 644 // TODO: same check for encoder? 645 return bctx.makeOp(typ, typ.Underlying(), tags) 646 case *types.Pointer: 647 if isBigInt(typ.Elem()) { 648 return bigIntOp{pointer: true}, nil 649 } 650 // Encoder/Decoder interfaces. 651 if bctx.isEncoder(typ) { 652 if bctx.isDecoder(typ) { 653 return encoderDecoderOp{typ}, nil 654 } 655 return nil, fmt.Errorf("type %v implements rlp.Encoder but not rlp.Decoder", typ) 656 } 657 if bctx.isDecoder(typ) { 658 return nil, fmt.Errorf("type %v implements rlp.Decoder but not rlp.Encoder", typ) 659 } 660 // Default pointer handling. 661 return bctx.makePtrOp(typ.Elem(), tags) 662 case *types.Basic: 663 return bctx.makeBasicOp(typ) 664 case *types.Struct: 665 return bctx.makeStructOp(name, typ) 666 case *types.Slice: 667 etyp := typ.Elem() 668 if isByte(etyp) && !bctx.isEncoder(etyp) { 669 return bctx.makeByteSliceOp(typ), nil 670 } 671 return bctx.makeSliceOp(typ) 672 case *types.Array: 673 etyp := typ.Elem() 674 if isByte(etyp) && !bctx.isEncoder(etyp) { 675 return bctx.makeByteArrayOp(name, typ), nil 676 } 677 return nil, fmt.Errorf("unhandled array type: %v", typ) 678 default: 679 return nil, fmt.Errorf("unhandled type: %v", typ) 680 } 681 } 682 683 // generateDecoder generates the DecodeRLP method on 'typ'. 684 func generateDecoder(ctx *genContext, typ string, op op) []byte { 685 ctx.resetTemp() 686 ctx.addImport(pathOfPackageRLP) 687 688 result, code := op.genDecode(ctx) 689 var b bytes.Buffer 690 fmt.Fprintf(&b, "func (obj *%s) DecodeRLP(dec *rlp.Stream) error {\n", typ) 691 fmt.Fprint(&b, code) 692 fmt.Fprintf(&b, " *obj = %s\n", result) 693 fmt.Fprintf(&b, " return nil\n") 694 fmt.Fprintf(&b, "}\n") 695 return b.Bytes() 696 } 697 698 // generateEncoder generates the EncodeRLP method on 'typ'. 699 func generateEncoder(ctx *genContext, typ string, op op) []byte { 700 ctx.resetTemp() 701 ctx.addImport("io") 702 ctx.addImport(pathOfPackageRLP) 703 704 var b bytes.Buffer 705 fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ) 706 fmt.Fprintf(&b, " w := rlp.NewEncoderBuffer(_w)\n") 707 fmt.Fprint(&b, op.genWrite(ctx, "obj")) 708 fmt.Fprintf(&b, " return w.Flush()\n") 709 fmt.Fprintf(&b, "}\n") 710 return b.Bytes() 711 } 712 713 func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]byte, error) { 714 bctx.topType = typ 715 716 pkg := typ.Obj().Pkg() 717 op, err := bctx.makeOp(nil, typ, rlpstruct.Tags{}) 718 if err != nil { 719 return nil, err 720 } 721 722 var ( 723 ctx = newGenContext(pkg) 724 encSource []byte 725 decSource []byte 726 ) 727 if encoder { 728 encSource = generateEncoder(ctx, typ.Obj().Name(), op) 729 } 730 if decoder { 731 decSource = generateDecoder(ctx, typ.Obj().Name(), op) 732 } 733 734 var b bytes.Buffer 735 fmt.Fprintf(&b, "package %s\n\n", pkg.Name()) 736 for _, imp := range ctx.importsList() { 737 fmt.Fprintf(&b, "import %q\n", imp) 738 } 739 if encoder { 740 fmt.Fprintln(&b) 741 b.Write(encSource) 742 } 743 if decoder { 744 fmt.Fprintln(&b) 745 b.Write(decSource) 746 } 747 748 source := b.Bytes() 749 // fmt.Println(string(source)) 750 return format.Source(source) 751 }