github.com/klaytn/klaytn@v1.12.1/rlp/rlpgen/gen.go (about)

     1  // Copyright 2022 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package main
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  	"go/format"
    23  	"go/types"
    24  	"sort"
    25  
    26  	"github.com/klaytn/klaytn/rlp/internal/rlpstruct"
    27  )
    28  
    29  // buildContext keeps the data needed for make*Op.
    30  type buildContext struct {
    31  	topType *types.Named // the type we're creating methods for
    32  
    33  	encoderIface *types.Interface
    34  	decoderIface *types.Interface
    35  	rawValueType *types.Named
    36  
    37  	typeToStructCache map[types.Type]*rlpstruct.Type
    38  }
    39  
    40  func newBuildContext(packageRLP *types.Package) *buildContext {
    41  	enc := packageRLP.Scope().Lookup("Encoder").Type().Underlying()
    42  	dec := packageRLP.Scope().Lookup("Decoder").Type().Underlying()
    43  	rawv := packageRLP.Scope().Lookup("RawValue").Type()
    44  	return &buildContext{
    45  		typeToStructCache: make(map[types.Type]*rlpstruct.Type),
    46  		encoderIface:      enc.(*types.Interface),
    47  		decoderIface:      dec.(*types.Interface),
    48  		rawValueType:      rawv.(*types.Named),
    49  	}
    50  }
    51  
    52  func (bctx *buildContext) isEncoder(typ types.Type) bool {
    53  	return types.Implements(typ, bctx.encoderIface)
    54  }
    55  
    56  func (bctx *buildContext) isDecoder(typ types.Type) bool {
    57  	return types.Implements(typ, bctx.decoderIface)
    58  }
    59  
    60  // typeToStructType converts typ to rlpstruct.Type.
    61  func (bctx *buildContext) typeToStructType(typ types.Type) *rlpstruct.Type {
    62  	if prev := bctx.typeToStructCache[typ]; prev != nil {
    63  		return prev // short-circuit for recursive types.
    64  	}
    65  
    66  	// Resolve named types to their underlying type, but keep the name.
    67  	name := types.TypeString(typ, nil)
    68  	for {
    69  		utype := typ.Underlying()
    70  		if utype == typ {
    71  			break
    72  		}
    73  		typ = utype
    74  	}
    75  
    76  	// Create the type and store it in cache.
    77  	t := &rlpstruct.Type{
    78  		Name:      name,
    79  		Kind:      typeReflectKind(typ),
    80  		IsEncoder: bctx.isEncoder(typ),
    81  		IsDecoder: bctx.isDecoder(typ),
    82  	}
    83  	bctx.typeToStructCache[typ] = t
    84  
    85  	// Assign element type.
    86  	switch typ.(type) {
    87  	case *types.Array, *types.Slice, *types.Pointer:
    88  		etype := typ.(interface{ Elem() types.Type }).Elem()
    89  		t.Elem = bctx.typeToStructType(etype)
    90  	}
    91  	return t
    92  }
    93  
    94  // genContext is passed to the gen* methods of op when generating
    95  // the output code. It tracks packages to be imported by the output
    96  // file and assigns unique names of temporary variables.
    97  type genContext struct {
    98  	inPackage   *types.Package
    99  	imports     map[string]struct{}
   100  	tempCounter int
   101  }
   102  
   103  func newGenContext(inPackage *types.Package) *genContext {
   104  	return &genContext{
   105  		inPackage: inPackage,
   106  		imports:   make(map[string]struct{}),
   107  	}
   108  }
   109  
   110  func (ctx *genContext) temp() string {
   111  	v := fmt.Sprintf("_tmp%d", ctx.tempCounter)
   112  	ctx.tempCounter++
   113  	return v
   114  }
   115  
   116  func (ctx *genContext) resetTemp() {
   117  	ctx.tempCounter = 0
   118  }
   119  
   120  func (ctx *genContext) addImport(path string) {
   121  	if path == ctx.inPackage.Path() {
   122  		return // avoid importing the package that we're generating in.
   123  	}
   124  	// TODO: renaming?
   125  	ctx.imports[path] = struct{}{}
   126  }
   127  
   128  // importsList returns all packages that need to be imported.
   129  func (ctx *genContext) importsList() []string {
   130  	imp := make([]string, 0, len(ctx.imports))
   131  	for k := range ctx.imports {
   132  		imp = append(imp, k)
   133  	}
   134  	sort.Strings(imp)
   135  	return imp
   136  }
   137  
   138  // qualify is the types.Qualifier used for printing types.
   139  func (ctx *genContext) qualify(pkg *types.Package) string {
   140  	if pkg.Path() == ctx.inPackage.Path() {
   141  		return ""
   142  	}
   143  	ctx.addImport(pkg.Path())
   144  	// TODO: renaming?
   145  	return pkg.Name()
   146  }
   147  
   148  type op interface {
   149  	// genWrite creates the encoder. The generated code should write v,
   150  	// which is any Go expression, to the rlp.EncoderBuffer 'w'.
   151  	genWrite(ctx *genContext, v string) string
   152  
   153  	// genDecode creates the decoder. The generated code should read
   154  	// a value from the rlp.Stream 'dec' and store it to dst.
   155  	genDecode(ctx *genContext) (string, string)
   156  }
   157  
   158  // basicOp handles basic types bool, uint*, string.
   159  type basicOp struct {
   160  	typ           types.Type
   161  	writeMethod   string     // calle write the value
   162  	writeArgType  types.Type // parameter type of writeMethod
   163  	decMethod     string
   164  	decResultType types.Type // return type of decMethod
   165  	decUseBitSize bool       // if true, result bit size is appended to decMethod
   166  }
   167  
   168  func (*buildContext) makeBasicOp(typ *types.Basic) (op, error) {
   169  	op := basicOp{typ: typ}
   170  	kind := typ.Kind()
   171  	switch {
   172  	case kind == types.Bool:
   173  		op.writeMethod = "WriteBool"
   174  		op.writeArgType = types.Typ[types.Bool]
   175  		op.decMethod = "Bool"
   176  		op.decResultType = types.Typ[types.Bool]
   177  	case kind >= types.Uint8 && kind <= types.Uint64:
   178  		op.writeMethod = "WriteUint64"
   179  		op.writeArgType = types.Typ[types.Uint64]
   180  		op.decMethod = "Uint"
   181  		op.decResultType = typ
   182  		op.decUseBitSize = true
   183  	case kind == types.String:
   184  		op.writeMethod = "WriteString"
   185  		op.writeArgType = types.Typ[types.String]
   186  		op.decMethod = "String"
   187  		op.decResultType = types.Typ[types.String]
   188  	default:
   189  		return nil, fmt.Errorf("unhandled basic type: %v", typ)
   190  	}
   191  	return op, nil
   192  }
   193  
   194  func (*buildContext) makeByteSliceOp(typ *types.Slice) op {
   195  	if !isByte(typ.Elem()) {
   196  		panic("non-byte slice type in makeByteSliceOp")
   197  	}
   198  	bslice := types.NewSlice(types.Typ[types.Uint8])
   199  	return basicOp{
   200  		typ:           typ,
   201  		writeMethod:   "WriteBytes",
   202  		writeArgType:  bslice,
   203  		decMethod:     "Bytes",
   204  		decResultType: bslice,
   205  	}
   206  }
   207  
   208  func (bctx *buildContext) makeRawValueOp() op {
   209  	bslice := types.NewSlice(types.Typ[types.Uint8])
   210  	return basicOp{
   211  		typ:           bctx.rawValueType,
   212  		writeMethod:   "Write",
   213  		writeArgType:  bslice,
   214  		decMethod:     "Raw",
   215  		decResultType: bslice,
   216  	}
   217  }
   218  
   219  func (op basicOp) writeNeedsConversion() bool {
   220  	return !types.AssignableTo(op.typ, op.writeArgType)
   221  }
   222  
   223  func (op basicOp) decodeNeedsConversion() bool {
   224  	return !types.AssignableTo(op.decResultType, op.typ)
   225  }
   226  
   227  func (op basicOp) genWrite(ctx *genContext, v string) string {
   228  	if op.writeNeedsConversion() {
   229  		v = fmt.Sprintf("%s(%s)", op.writeArgType, v)
   230  	}
   231  	return fmt.Sprintf("w.%s(%s)\n", op.writeMethod, v)
   232  }
   233  
   234  func (op basicOp) genDecode(ctx *genContext) (string, string) {
   235  	var (
   236  		resultV = ctx.temp()
   237  		result  = resultV
   238  		method  = op.decMethod
   239  	)
   240  	if op.decUseBitSize {
   241  		// Note: For now, this only works for platform-independent integer
   242  		// sizes. makeBasicOp forbids the platform-dependent types.
   243  		var sizes types.StdSizes
   244  		method = fmt.Sprintf("%s%d", op.decMethod, sizes.Sizeof(op.typ)*8)
   245  	}
   246  
   247  	// Call the decoder method.
   248  	var b bytes.Buffer
   249  	fmt.Fprintf(&b, "%s, err := dec.%s()\n", resultV, method)
   250  	fmt.Fprintf(&b, "if err != nil { return err }\n")
   251  	if op.decodeNeedsConversion() {
   252  		conv := ctx.temp()
   253  		fmt.Fprintf(&b, "%s := %s(%s)\n", conv, types.TypeString(op.typ, ctx.qualify), resultV)
   254  		result = conv
   255  	}
   256  	return result, b.String()
   257  }
   258  
   259  // byteArrayOp handles [...]byte.
   260  type byteArrayOp struct {
   261  	typ  types.Type
   262  	name types.Type // name != typ for named byte array types (e.g. common.Address)
   263  }
   264  
   265  func (bctx *buildContext) makeByteArrayOp(name *types.Named, typ *types.Array) byteArrayOp {
   266  	nt := types.Type(name)
   267  	if name == nil {
   268  		nt = typ
   269  	}
   270  	return byteArrayOp{typ, nt}
   271  }
   272  
   273  func (op byteArrayOp) genWrite(ctx *genContext, v string) string {
   274  	return fmt.Sprintf("w.WriteBytes(%s[:])\n", v)
   275  }
   276  
   277  func (op byteArrayOp) genDecode(ctx *genContext) (string, string) {
   278  	resultV := ctx.temp()
   279  
   280  	var b bytes.Buffer
   281  	fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(op.name, ctx.qualify))
   282  	fmt.Fprintf(&b, "if err := dec.ReadBytes(%s[:]); err != nil { return err }\n", resultV)
   283  	return resultV, b.String()
   284  }
   285  
   286  // bigIntNoPtrOp handles non-pointer big.Int.
   287  // This exists because big.Int has it's own decoder operation on rlp.Stream,
   288  // but the decode method returns *big.Int, so it needs to be dereferenced.
   289  type bigIntOp struct {
   290  	pointer bool
   291  }
   292  
   293  func (op bigIntOp) genWrite(ctx *genContext, v string) string {
   294  	var b bytes.Buffer
   295  
   296  	fmt.Fprintf(&b, "if %s.Sign() == -1 {\n", v)
   297  	fmt.Fprintf(&b, "  return rlp.ErrNegativeBigInt\n")
   298  	fmt.Fprintf(&b, "}\n")
   299  	dst := v
   300  	if !op.pointer {
   301  		dst = "&" + v
   302  	}
   303  	fmt.Fprintf(&b, "w.WriteBigInt(%s)\n", dst)
   304  
   305  	// Wrap with nil check.
   306  	if op.pointer {
   307  		code := b.String()
   308  		b.Reset()
   309  		fmt.Fprintf(&b, "if %s == nil {\n", v)
   310  		fmt.Fprintf(&b, "  w.Write(rlp.EmptyString)")
   311  		fmt.Fprintf(&b, "} else {\n")
   312  		fmt.Fprint(&b, code)
   313  		fmt.Fprintf(&b, "}\n")
   314  	}
   315  
   316  	return b.String()
   317  }
   318  
   319  func (op bigIntOp) genDecode(ctx *genContext) (string, string) {
   320  	resultV := ctx.temp()
   321  
   322  	var b bytes.Buffer
   323  	fmt.Fprintf(&b, "%s, err := dec.BigInt()\n", resultV)
   324  	fmt.Fprintf(&b, "if err != nil { return err }\n")
   325  
   326  	result := resultV
   327  	if !op.pointer {
   328  		result = "(*" + resultV + ")"
   329  	}
   330  	return result, b.String()
   331  }
   332  
   333  // encoderDecoderOp handles rlp.Encoder and rlp.Decoder.
   334  // In order to be used with this, the type must implement both interfaces.
   335  // This restriction may be lifted in the future by creating separate ops for
   336  // encoding and decoding.
   337  type encoderDecoderOp struct {
   338  	typ types.Type
   339  }
   340  
   341  func (op encoderDecoderOp) genWrite(ctx *genContext, v string) string {
   342  	return fmt.Sprintf("if err := %s.EncodeRLP(w); err != nil { return err }\n", v)
   343  }
   344  
   345  func (op encoderDecoderOp) genDecode(ctx *genContext) (string, string) {
   346  	// DecodeRLP must have pointer receiver, and this is verified in makeOp.
   347  	etyp := op.typ.(*types.Pointer).Elem()
   348  	resultV := ctx.temp()
   349  
   350  	var b bytes.Buffer
   351  	fmt.Fprintf(&b, "%s := new(%s)\n", resultV, types.TypeString(etyp, ctx.qualify))
   352  	fmt.Fprintf(&b, "if err := %s.DecodeRLP(dec); err != nil { return err }\n", resultV)
   353  	return resultV, b.String()
   354  }
   355  
   356  // ptrOp handles pointer types.
   357  type ptrOp struct {
   358  	elemTyp  types.Type
   359  	elem     op
   360  	nilOK    bool
   361  	nilValue rlpstruct.NilKind
   362  }
   363  
   364  func (bctx *buildContext) makePtrOp(elemTyp types.Type, tags rlpstruct.Tags) (op, error) {
   365  	elemOp, err := bctx.makeOp(nil, elemTyp, rlpstruct.Tags{})
   366  	if err != nil {
   367  		return nil, err
   368  	}
   369  	op := ptrOp{elemTyp: elemTyp, elem: elemOp}
   370  
   371  	// Determine nil value.
   372  	if tags.NilOK {
   373  		op.nilOK = true
   374  		op.nilValue = tags.NilKind
   375  	} else {
   376  		styp := bctx.typeToStructType(elemTyp)
   377  		op.nilValue = styp.DefaultNilValue()
   378  	}
   379  	return op, nil
   380  }
   381  
   382  func (op ptrOp) genWrite(ctx *genContext, v string) string {
   383  	// Note: in writer functions, accesses to v are read-only, i.e. v is any Go
   384  	// expression. To make all accesses work through the pointer, we substitute
   385  	// v with (*v). This is required for most accesses including `v`, `call(v)`,
   386  	// and `v[index]` on slices.
   387  	//
   388  	// For `v.field` and `v[:]` on arrays, the dereference operation is not required.
   389  	var vv string
   390  	_, isStruct := op.elem.(structOp)
   391  	_, isByteArray := op.elem.(byteArrayOp)
   392  	if isStruct || isByteArray {
   393  		vv = v
   394  	} else {
   395  		vv = fmt.Sprintf("(*%s)", v)
   396  	}
   397  
   398  	var b bytes.Buffer
   399  	fmt.Fprintf(&b, "if %s == nil {\n", v)
   400  	fmt.Fprintf(&b, "  w.Write([]byte{0x%X})\n", op.nilValue)
   401  	fmt.Fprintf(&b, "} else {\n")
   402  	fmt.Fprintf(&b, "  %s", op.elem.genWrite(ctx, vv))
   403  	fmt.Fprintf(&b, "}\n")
   404  	return b.String()
   405  }
   406  
   407  func (op ptrOp) genDecode(ctx *genContext) (string, string) {
   408  	result, code := op.elem.genDecode(ctx)
   409  	if !op.nilOK {
   410  		// If nil pointers are not allowed, we can just decode the element.
   411  		return "&" + result, code
   412  	}
   413  
   414  	// nil is allowed, so check the kind and size first.
   415  	// If size is zero and kind matches the nilKind of the type,
   416  	// the value decodes as a nil pointer.
   417  	var (
   418  		resultV  = ctx.temp()
   419  		kindV    = ctx.temp()
   420  		sizeV    = ctx.temp()
   421  		wantKind string
   422  	)
   423  	if op.nilValue == rlpstruct.NilKindList {
   424  		wantKind = "rlp.List"
   425  	} else {
   426  		wantKind = "rlp.String"
   427  	}
   428  	var b bytes.Buffer
   429  	fmt.Fprintf(&b, "var %s %s\n", resultV, types.TypeString(types.NewPointer(op.elemTyp), ctx.qualify))
   430  	fmt.Fprintf(&b, "if %s, %s, err := dec.Kind(); err != nil {\n", kindV, sizeV)
   431  	fmt.Fprintf(&b, "  return err\n")
   432  	fmt.Fprintf(&b, "} else if %s != 0 || %s != %s {\n", sizeV, kindV, wantKind)
   433  	fmt.Fprint(&b, code)
   434  	fmt.Fprintf(&b, "  %s = &%s\n", resultV, result)
   435  	fmt.Fprintf(&b, "}\n")
   436  	return resultV, b.String()
   437  }
   438  
   439  // structOp handles struct types.
   440  type structOp struct {
   441  	named          *types.Named
   442  	typ            *types.Struct
   443  	fields         []*structField
   444  	optionalFields []*structField
   445  }
   446  
   447  type structField struct {
   448  	name string
   449  	typ  types.Type
   450  	elem op
   451  }
   452  
   453  func (bctx *buildContext) makeStructOp(named *types.Named, typ *types.Struct) (op, error) {
   454  	// Convert fields to []rlpstruct.Field.
   455  	var allStructFields []rlpstruct.Field
   456  	for i := 0; i < typ.NumFields(); i++ {
   457  		f := typ.Field(i)
   458  		allStructFields = append(allStructFields, rlpstruct.Field{
   459  			Name:     f.Name(),
   460  			Exported: f.Exported(),
   461  			Index:    i,
   462  			Tag:      typ.Tag(i),
   463  			Type:     *bctx.typeToStructType(f.Type()),
   464  		})
   465  	}
   466  
   467  	// Filter/validate fields.
   468  	fields, tags, err := rlpstruct.ProcessFields(allStructFields)
   469  	if err != nil {
   470  		return nil, err
   471  	}
   472  
   473  	// Create field ops.
   474  	op := structOp{named: named, typ: typ}
   475  	for i, field := range fields {
   476  		// Advanced struct tags are not supported yet.
   477  		tag := tags[i]
   478  		if err := checkUnsupportedTags(field.Name, tag); err != nil {
   479  			return nil, err
   480  		}
   481  		typ := typ.Field(field.Index).Type()
   482  		elem, err := bctx.makeOp(nil, typ, tags[i])
   483  		if err != nil {
   484  			return nil, fmt.Errorf("field %s: %v", field.Name, err)
   485  		}
   486  		f := &structField{name: field.Name, typ: typ, elem: elem}
   487  		if tag.Optional {
   488  			op.optionalFields = append(op.optionalFields, f)
   489  		} else {
   490  			op.fields = append(op.fields, f)
   491  		}
   492  	}
   493  	return op, nil
   494  }
   495  
   496  func checkUnsupportedTags(field string, tag rlpstruct.Tags) error {
   497  	if tag.Tail {
   498  		return fmt.Errorf(`field %s has unsupported struct tag "tail"`, field)
   499  	}
   500  	return nil
   501  }
   502  
   503  func (op structOp) genWrite(ctx *genContext, v string) string {
   504  	var b bytes.Buffer
   505  	listMarker := ctx.temp()
   506  	fmt.Fprintf(&b, "%s := w.List()\n", listMarker)
   507  	for _, field := range op.fields {
   508  		selector := v + "." + field.name
   509  		fmt.Fprint(&b, field.elem.genWrite(ctx, selector))
   510  	}
   511  	op.writeOptionalFields(&b, ctx, v)
   512  	fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker)
   513  	return b.String()
   514  }
   515  
   516  func (op structOp) writeOptionalFields(b *bytes.Buffer, ctx *genContext, v string) {
   517  	if len(op.optionalFields) == 0 {
   518  		return
   519  	}
   520  	// First check zero-ness of all optional fields.
   521  	zeroV := make([]string, len(op.optionalFields))
   522  	for i, field := range op.optionalFields {
   523  		selector := v + "." + field.name
   524  		zeroV[i] = ctx.temp()
   525  		fmt.Fprintf(b, "%s := %s\n", zeroV[i], nonZeroCheck(selector, field.typ, ctx.qualify))
   526  	}
   527  	// Now write the fields.
   528  	for i, field := range op.optionalFields {
   529  		selector := v + "." + field.name
   530  		cond := ""
   531  		for j := i; j < len(op.optionalFields); j++ {
   532  			if j > i {
   533  				cond += " || "
   534  			}
   535  			cond += zeroV[j]
   536  		}
   537  		fmt.Fprintf(b, "if %s {\n", cond)
   538  		fmt.Fprint(b, field.elem.genWrite(ctx, selector))
   539  		fmt.Fprintf(b, "}\n")
   540  	}
   541  }
   542  
   543  func (op structOp) genDecode(ctx *genContext) (string, string) {
   544  	// Get the string representation of the type.
   545  	// Here, named types are handled separately because the output
   546  	// would contain a copy of the struct definition otherwise.
   547  	var typeName string
   548  	if op.named != nil {
   549  		typeName = types.TypeString(op.named, ctx.qualify)
   550  	} else {
   551  		typeName = types.TypeString(op.typ, ctx.qualify)
   552  	}
   553  
   554  	// Create struct object.
   555  	resultV := ctx.temp()
   556  	var b bytes.Buffer
   557  	fmt.Fprintf(&b, "var %s %s\n", resultV, typeName)
   558  
   559  	// Decode fields.
   560  	fmt.Fprintf(&b, "{\n")
   561  	fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n")
   562  	for _, field := range op.fields {
   563  		result, code := field.elem.genDecode(ctx)
   564  		fmt.Fprintf(&b, "// %s:\n", field.name)
   565  		fmt.Fprint(&b, code)
   566  		fmt.Fprintf(&b, "%s.%s = %s\n", resultV, field.name, result)
   567  	}
   568  	op.decodeOptionalFields(&b, ctx, resultV)
   569  	fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n")
   570  	fmt.Fprintf(&b, "}\n")
   571  	return resultV, b.String()
   572  }
   573  
   574  func (op structOp) decodeOptionalFields(b *bytes.Buffer, ctx *genContext, resultV string) {
   575  	var suffix bytes.Buffer
   576  	for _, field := range op.optionalFields {
   577  		result, code := field.elem.genDecode(ctx)
   578  		fmt.Fprintf(b, "// %s:\n", field.name)
   579  		fmt.Fprintf(b, "if dec.MoreDataInList() {\n")
   580  		fmt.Fprint(b, code)
   581  		fmt.Fprintf(b, "%s.%s = %s\n", resultV, field.name, result)
   582  		fmt.Fprintf(&suffix, "}\n")
   583  	}
   584  	suffix.WriteTo(b)
   585  }
   586  
   587  // sliceOp handles slice types.
   588  type sliceOp struct {
   589  	typ    *types.Slice
   590  	elemOp op
   591  }
   592  
   593  func (bctx *buildContext) makeSliceOp(typ *types.Slice) (op, error) {
   594  	elemOp, err := bctx.makeOp(nil, typ.Elem(), rlpstruct.Tags{})
   595  	if err != nil {
   596  		return nil, err
   597  	}
   598  	return sliceOp{typ: typ, elemOp: elemOp}, nil
   599  }
   600  
   601  func (op sliceOp) genWrite(ctx *genContext, v string) string {
   602  	var (
   603  		listMarker = ctx.temp() // holds return value of w.List()
   604  		iterElemV  = ctx.temp() // iteration variable
   605  		elemCode   = op.elemOp.genWrite(ctx, iterElemV)
   606  	)
   607  
   608  	var b bytes.Buffer
   609  	fmt.Fprintf(&b, "%s := w.List()\n", listMarker)
   610  	fmt.Fprintf(&b, "for _, %s := range %s {\n", iterElemV, v)
   611  	fmt.Fprint(&b, elemCode)
   612  	fmt.Fprintf(&b, "}\n")
   613  	fmt.Fprintf(&b, "w.ListEnd(%s)\n", listMarker)
   614  	return b.String()
   615  }
   616  
   617  func (op sliceOp) genDecode(ctx *genContext) (string, string) {
   618  	sliceV := ctx.temp() // holds the output slice
   619  	elemResult, elemCode := op.elemOp.genDecode(ctx)
   620  
   621  	var b bytes.Buffer
   622  	fmt.Fprintf(&b, "var %s %s\n", sliceV, types.TypeString(op.typ, ctx.qualify))
   623  	fmt.Fprintf(&b, "if _, err := dec.List(); err != nil { return err }\n")
   624  	fmt.Fprintf(&b, "for dec.MoreDataInList() {\n")
   625  	fmt.Fprintf(&b, "  %s", elemCode)
   626  	fmt.Fprintf(&b, "  %s = append(%s, %s)\n", sliceV, sliceV, elemResult)
   627  	fmt.Fprintf(&b, "}\n")
   628  	fmt.Fprintf(&b, "if err := dec.ListEnd(); err != nil { return err }\n")
   629  	return sliceV, b.String()
   630  }
   631  
   632  func (bctx *buildContext) makeOp(name *types.Named, typ types.Type, tags rlpstruct.Tags) (op, error) {
   633  	switch typ := typ.(type) {
   634  	case *types.Named:
   635  		if isBigInt(typ) {
   636  			return bigIntOp{}, nil
   637  		}
   638  		if typ == bctx.rawValueType {
   639  			return bctx.makeRawValueOp(), nil
   640  		}
   641  		if bctx.isDecoder(typ) {
   642  			return nil, fmt.Errorf("type %v implements rlp.Decoder with non-pointer receiver", typ)
   643  		}
   644  		// TODO: same check for encoder?
   645  		return bctx.makeOp(typ, typ.Underlying(), tags)
   646  	case *types.Pointer:
   647  		if isBigInt(typ.Elem()) {
   648  			return bigIntOp{pointer: true}, nil
   649  		}
   650  		// Encoder/Decoder interfaces.
   651  		if bctx.isEncoder(typ) {
   652  			if bctx.isDecoder(typ) {
   653  				return encoderDecoderOp{typ}, nil
   654  			}
   655  			return nil, fmt.Errorf("type %v implements rlp.Encoder but not rlp.Decoder", typ)
   656  		}
   657  		if bctx.isDecoder(typ) {
   658  			return nil, fmt.Errorf("type %v implements rlp.Decoder but not rlp.Encoder", typ)
   659  		}
   660  		// Default pointer handling.
   661  		return bctx.makePtrOp(typ.Elem(), tags)
   662  	case *types.Basic:
   663  		return bctx.makeBasicOp(typ)
   664  	case *types.Struct:
   665  		return bctx.makeStructOp(name, typ)
   666  	case *types.Slice:
   667  		etyp := typ.Elem()
   668  		if isByte(etyp) && !bctx.isEncoder(etyp) {
   669  			return bctx.makeByteSliceOp(typ), nil
   670  		}
   671  		return bctx.makeSliceOp(typ)
   672  	case *types.Array:
   673  		etyp := typ.Elem()
   674  		if isByte(etyp) && !bctx.isEncoder(etyp) {
   675  			return bctx.makeByteArrayOp(name, typ), nil
   676  		}
   677  		return nil, fmt.Errorf("unhandled array type: %v", typ)
   678  	default:
   679  		return nil, fmt.Errorf("unhandled type: %v", typ)
   680  	}
   681  }
   682  
   683  // generateDecoder generates the DecodeRLP method on 'typ'.
   684  func generateDecoder(ctx *genContext, typ string, op op) []byte {
   685  	ctx.resetTemp()
   686  	ctx.addImport(pathOfPackageRLP)
   687  
   688  	result, code := op.genDecode(ctx)
   689  	var b bytes.Buffer
   690  	fmt.Fprintf(&b, "func (obj *%s) DecodeRLP(dec *rlp.Stream) error {\n", typ)
   691  	fmt.Fprint(&b, code)
   692  	fmt.Fprintf(&b, "  *obj = %s\n", result)
   693  	fmt.Fprintf(&b, "  return nil\n")
   694  	fmt.Fprintf(&b, "}\n")
   695  	return b.Bytes()
   696  }
   697  
   698  // generateEncoder generates the EncodeRLP method on 'typ'.
   699  func generateEncoder(ctx *genContext, typ string, op op) []byte {
   700  	ctx.resetTemp()
   701  	ctx.addImport("io")
   702  	ctx.addImport(pathOfPackageRLP)
   703  
   704  	var b bytes.Buffer
   705  	fmt.Fprintf(&b, "func (obj *%s) EncodeRLP(_w io.Writer) error {\n", typ)
   706  	fmt.Fprintf(&b, "  w := rlp.NewEncoderBuffer(_w)\n")
   707  	fmt.Fprint(&b, op.genWrite(ctx, "obj"))
   708  	fmt.Fprintf(&b, "  return w.Flush()\n")
   709  	fmt.Fprintf(&b, "}\n")
   710  	return b.Bytes()
   711  }
   712  
   713  func (bctx *buildContext) generate(typ *types.Named, encoder, decoder bool) ([]byte, error) {
   714  	bctx.topType = typ
   715  
   716  	pkg := typ.Obj().Pkg()
   717  	op, err := bctx.makeOp(nil, typ, rlpstruct.Tags{})
   718  	if err != nil {
   719  		return nil, err
   720  	}
   721  
   722  	var (
   723  		ctx       = newGenContext(pkg)
   724  		encSource []byte
   725  		decSource []byte
   726  	)
   727  	if encoder {
   728  		encSource = generateEncoder(ctx, typ.Obj().Name(), op)
   729  	}
   730  	if decoder {
   731  		decSource = generateDecoder(ctx, typ.Obj().Name(), op)
   732  	}
   733  
   734  	var b bytes.Buffer
   735  	fmt.Fprintf(&b, "package %s\n\n", pkg.Name())
   736  	for _, imp := range ctx.importsList() {
   737  		fmt.Fprintf(&b, "import %q\n", imp)
   738  	}
   739  	if encoder {
   740  		fmt.Fprintln(&b)
   741  		b.Write(encSource)
   742  	}
   743  	if decoder {
   744  		fmt.Fprintln(&b)
   745  		b.Write(decSource)
   746  	}
   747  
   748  	source := b.Bytes()
   749  	// fmt.Println(string(source))
   750  	return format.Source(source)
   751  }