github.com/cosmos/cosmos-proto@v1.0.0-beta.3/features/fastreflection/list.go (about)

     1  package fastreflection
     2  
     3  import (
     4  	"fmt"
     5  	"github.com/cosmos/cosmos-proto/generator"
     6  
     7  	"google.golang.org/protobuf/compiler/protogen"
     8  	"google.golang.org/protobuf/reflect/protoreflect"
     9  )
    10  
// listGen generates, for one repeated (list) field, a wrapper type that
// implements protoreflect.List on top of a pointer to the field's Go slice.
type listGen struct {
	// GeneratedFile is the output file the wrapper type and its methods are
	// written to.
	*generator.GeneratedFile
	// field is the repeated field the wrapper is generated for.
	field *protogen.Field

	// typeName is the name of the generated wrapper type; it is assigned by
	// generate (via listTypeName) before any other gen* method runs.
	typeName string
}
    17  
    18  func (g *listGen) generate() {
    19  	g.typeName = listTypeName(g.field)
    20  
    21  	g.genAssertions()
    22  	g.genType()
    23  	g.genLen()
    24  	g.genGet()
    25  	g.genSet()
    26  	g.genAppend()
    27  	g.genAppendMutable()
    28  	g.genTruncate()
    29  	g.genNewElement()
    30  	g.genIsValid()
    31  }
    32  
    33  // genAssertions generates protoreflect.List type assertions
    34  func (g *listGen) genAssertions() {
    35  	// type assertion
    36  	g.P("var _ ", protoreflectPkg.Ident("List"), " = (*", g.typeName, ")(nil)")
    37  }
    38  
    39  // genType generates the list type
    40  func (g *listGen) genType() {
    41  	g.P("type ", g.typeName, " struct {")
    42  	g.P("list *[]", getGoType(g.GeneratedFile, g.field))
    43  	g.P("}")
    44  	g.P()
    45  }
    46  
    47  // genLen generates the implementation for protoreflect.List.Len
    48  func (g *listGen) genLen() {
    49  	g.P("func (x *", g.typeName, ") Len() int {")
    50  	g.P("if x.list == nil {")
    51  	g.P("return 0")
    52  	g.P("}")
    53  	g.P("return len(*x.list)")
    54  	g.P("}")
    55  	g.P()
    56  }
    57  
    58  // genGet generates the implementation for protoreflect.List.Get
    59  func (g *listGen) genGet() {
    60  	g.P("func (x *", g.typeName, ") Get(i int) ", protoreflectPkg.Ident("Value"), " {")
    61  	constructor := kindToValueConstructor(g.field.Desc.Kind())
    62  	switch g.field.Desc.Kind() {
    63  	case protoreflect.MessageKind:
    64  		g.P("return ", constructor, "((*x.list)[i].ProtoReflect())")
    65  	case protoreflect.EnumKind:
    66  		g.P("return ", constructor, "((", protoreflectPkg.Ident("EnumNumber"), ")((*x.list)[i]))")
    67  	default:
    68  		g.P("return ", constructor, "((*x.list)[i])")
    69  	}
    70  	g.P("}")
    71  	g.P()
    72  }
    73  
    74  // genSet generates the implementation for protoreflect.List.Set
    75  func (g *listGen) genSet() {
    76  	// Set
    77  	g.P("func (x *", g.typeName, ") Set(i int, value ", protoreflectPkg.Ident("Value"), ") {")
    78  	concreteValueName := genPrefValueToGoValue(g.GeneratedFile, g.field, "value", "concreteValue")
    79  	g.P("(*x.list)[i] = ", concreteValueName)
    80  	g.P("}")
    81  	g.P()
    82  }
    83  
    84  // genAppend generates the protoreflect.List.Append implementation
    85  func (g *listGen) genAppend() {
    86  	g.P("func (x *", g.typeName, ") Append(value ", protoreflectPkg.Ident("Value"), ") {")
    87  	concreteValueName := genPrefValueToGoValue(g.GeneratedFile, g.field, "value", "concreteValue")
    88  	g.P("*x.list = append(*x.list, ", concreteValueName, ")")
    89  	g.P("}")
    90  	g.P()
    91  }
    92  
    93  // genAppendMutable generates the protoreflect.List.AppendMutable implementation
    94  func (g *listGen) genAppendMutable() {
    95  	g.P("func (x *", g.typeName, ") AppendMutable() ", protoreflectPkg.Ident("Value"), " {")
    96  	switch g.field.Desc.Kind() {
    97  	case protoreflect.MessageKind:
    98  		g.P("v := new(", g.QualifiedGoIdent(g.field.Message.GoIdent), ")")
    99  		g.P("*x.list = append(*x.list, v)")
   100  		g.P("return ", protoreflectPkg.Ident("ValueOfMessage"), "(v.ProtoReflect())")
   101  	default:
   102  		panicMsg := fmt.Sprintf("AppendMutable can not be called on message %s at list field %s as it is not of Message kind", g.field.Parent.GoIdent.GoName, g.field.GoName)
   103  		g.P("panic(", fmtPkg.Ident("Errorf"), "(\"", panicMsg, "\"))")
   104  	}
   105  	g.P("}")
   106  	g.P()
   107  }
   108  
   109  // genTruncate generates the protoreflect.List.Truncate implementation
   110  func (g *listGen) genTruncate() {
   111  	g.P("func (x *", g.typeName, ") Truncate(n int)", "{")
   112  
   113  	switch g.field.Desc.Kind() {
   114  	case protoreflect.MessageKind: // zero message kinds to avoid keeping data alive
   115  		g.P("for i := n; i < len(*x.list); i++ {")
   116  		g.P("(*x.list)[i] = nil")
   117  		g.P("}")
   118  	}
   119  	g.P("*x.list = (*x.list)[:n]") // truncate
   120  	g.P("}")
   121  	g.P()
   122  }
   123  
   124  // genNewElement generates the protoreflect.List.NewElement implementation
   125  func (g *listGen) genNewElement() {
   126  	g.P("func (x *", g.typeName, ") NewElement() ", protoreflectPkg.Ident("Value"), "{")
   127  	switch g.field.Desc.Kind() {
   128  	case protoreflect.BytesKind:
   129  		g.P("var v []byte")
   130  	default:
   131  		zeroValue := zeroValueForField(g.GeneratedFile, g.field)
   132  		g.P("v := ", zeroValue)
   133  	}
   134  	switch g.field.Desc.Kind() {
   135  	case protoreflect.MessageKind:
   136  		g.P("return ", kindToValueConstructor(g.field.Desc.Kind()), "(v.ProtoReflect())")
   137  	case protoreflect.EnumKind:
   138  		g.P("return ", kindToValueConstructor(g.field.Desc.Kind()), "((", protoreflectPkg.Ident("EnumNumber"), ")(v))")
   139  	default:
   140  		g.P("return ", kindToValueConstructor(g.field.Desc.Kind()), "(v)")
   141  	}
   142  	g.P("}")
   143  	g.P()
   144  }
   145  
   146  func (g *listGen) genIsValid() {
   147  	g.P("func (x *", g.typeName, ") IsValid() bool {")
   148  	g.P("return x.list != nil") // if we generate this type it's always valid, it either comes from Mutable path, or Get path in which the list is valid.
   149  	g.P("}")
   150  	g.P()
   151  }
   152  
   153  func listTypeName(field *protogen.Field) string {
   154  	return fmt.Sprintf("_%s_%d_list", field.Parent.GoIdent.GoName, field.Desc.Number())
   155  }
   156  
   157  func getGoType(g *generator.GeneratedFile, field *protogen.Field) (goType string) {
   158  	if field.Desc.IsWeak() {
   159  		return "struct{}"
   160  	}
   161  
   162  	switch field.Desc.Kind() {
   163  	case protoreflect.BoolKind:
   164  		goType = "bool"
   165  	case protoreflect.EnumKind:
   166  		goType = g.QualifiedGoIdent(field.Enum.GoIdent)
   167  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
   168  		goType = "int32"
   169  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   170  		goType = "uint32"
   171  	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
   172  		goType = "int64"
   173  	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
   174  		goType = "uint64"
   175  	case protoreflect.FloatKind:
   176  		goType = "float32"
   177  	case protoreflect.DoubleKind:
   178  		goType = "float64"
   179  	case protoreflect.StringKind:
   180  		goType = "string"
   181  	case protoreflect.BytesKind:
   182  		goType = "[]byte"
   183  	case protoreflect.MessageKind, protoreflect.GroupKind:
   184  		goType = "*" + g.QualifiedGoIdent(field.Message.GoIdent)
   185  	}
   186  	return goType
   187  }
   188  
   189  func kindToValueConstructor(kind protoreflect.Kind) protogen.GoIdent {
   190  	switch kind {
   191  	case protoreflect.BoolKind:
   192  		return protoreflectPkg.Ident("ValueOfBool")
   193  	case protoreflect.EnumKind:
   194  		return protoreflectPkg.Ident("ValueOfEnum")
   195  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
   196  		return protoreflectPkg.Ident("ValueOfInt32")
   197  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   198  		return protoreflectPkg.Ident("ValueOfUint32")
   199  	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
   200  		return protoreflectPkg.Ident("ValueOfInt64")
   201  	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
   202  		return protoreflectPkg.Ident("ValueOfUint64")
   203  	case protoreflect.FloatKind:
   204  		return protoreflectPkg.Ident("ValueOfFloat32")
   205  	case protoreflect.DoubleKind:
   206  		return protoreflectPkg.Ident("ValueOfFloat64")
   207  	case protoreflect.StringKind:
   208  		return protoreflectPkg.Ident("ValueOfString")
   209  	case protoreflect.BytesKind:
   210  		return protoreflectPkg.Ident("ValueOfBytes")
   211  	case protoreflect.MessageKind, protoreflect.GroupKind:
   212  		return protoreflectPkg.Ident("ValueOfMessage")
   213  	default:
   214  		panic("should not reach here")
   215  	}
   216  }
   217  
   218  // valueUnwrapper provides the function to call on value
   219  // in order to get the concrete underlying type
   220  func valueUnwrapper(kind protoreflect.Kind) string {
   221  	switch kind {
   222  	case protoreflect.BoolKind:
   223  		return "Bool"
   224  	case protoreflect.EnumKind:
   225  		return "Enum"
   226  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
   227  		return "Int"
   228  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   229  		return "Uint"
   230  	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
   231  		return "Int"
   232  	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
   233  		return "Uint"
   234  	case protoreflect.FloatKind:
   235  		return "Float"
   236  	case protoreflect.DoubleKind:
   237  		return "Float"
   238  	case protoreflect.StringKind:
   239  		return "String"
   240  	case protoreflect.BytesKind:
   241  		return "Bytes"
   242  	case protoreflect.MessageKind, protoreflect.GroupKind:
   243  		return "Message"
   244  	default:
   245  		panic("should not reach here")
   246  	}
   247  }
   248  
   249  func genPrefValueToGoValue(g *generator.GeneratedFile, field *protogen.Field, inputName string, outputName string) string {
   250  
   251  	unwrapperFunc := valueUnwrapper(field.Desc.Kind())
   252  	unwrapperVar := fmt.Sprintf("%sUnwrapped", inputName)
   253  	g.P(unwrapperVar, " := ", inputName, ".", unwrapperFunc, "()")
   254  	switch field.Desc.Kind() {
   255  	case protoreflect.MessageKind:
   256  		g.P(outputName, " := ", unwrapperVar, ".Interface().(*", g.QualifiedGoIdent(field.Message.GoIdent), ")")
   257  	case protoreflect.EnumKind:
   258  		g.P(outputName, " := (", g.QualifiedGoIdent(field.Enum.GoIdent), ")(", unwrapperVar, ")")
   259  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
   260  		g.P(outputName, " := (int32)(", unwrapperVar, ")")
   261  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   262  		g.P(outputName, " := (uint32)(", unwrapperVar, ")")
   263  	case protoreflect.FloatKind:
   264  		g.P(outputName, " := (float32)(", unwrapperVar, ")")
   265  	default:
   266  		g.P(outputName, " := ", unwrapperVar)
   267  	}
   268  
   269  	return outputName
   270  }
   271  
   272  func zeroValueForField(g *generator.GeneratedFile, field *protogen.Field) string {
   273  	switch field.Desc.Kind() {
   274  	case protoreflect.BoolKind:
   275  		return "false"
   276  	case protoreflect.EnumKind:
   277  		return fmt.Sprintf("%d", field.Enum.Desc.Values().Get(0).Number())
   278  	case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind:
   279  		return "int32(0)"
   280  	case protoreflect.Uint32Kind, protoreflect.Fixed32Kind:
   281  		return "uint32(0)"
   282  	case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind:
   283  		return "int64(0)"
   284  	case protoreflect.Uint64Kind, protoreflect.Fixed64Kind:
   285  		return "uint64(0)"
   286  	case protoreflect.FloatKind:
   287  		return "float32(0)"
   288  	case protoreflect.DoubleKind:
   289  		return "float64(0)"
   290  	case protoreflect.StringKind:
   291  		return "\"\""
   292  	case protoreflect.BytesKind:
   293  		return "nil"
   294  	case protoreflect.MessageKind, protoreflect.GroupKind:
   295  		return fmt.Sprintf("new(%s)", g.QualifiedGoIdent(field.Message.GoIdent))
   296  	default:
   297  		panic("should not reach")
   298  	}
   299  }