github.com/cosmos/cosmos-proto@v1.0.0-beta.3/features/fastreflection/map.go (about)

     1  package fastreflection
     2  
     3  import (
     4  	"fmt"
     5  	"github.com/cosmos/cosmos-proto/generator"
     6  
     7  	"google.golang.org/protobuf/compiler/protogen"
     8  	"google.golang.org/protobuf/reflect/protoreflect"
     9  )
    10  
    11  type mapGen struct {
    12  	*generator.GeneratedFile
    13  
    14  	field    *protogen.Field // TODO(fdmylja): maybe we could split this field into 2 fields one for key and one for value for the sake of being more readable
    15  	typeName string
    16  }
    17  
    18  func (g *mapGen) generate() {
    19  	g.typeName = mapTypeName(g.field)
    20  
    21  	g.genAssertions()
    22  	g.genType()
    23  	g.genLen()
    24  	g.genRange()
    25  	g.genHas()
    26  	g.genClear()
    27  	g.genGet()
    28  	g.genSet()
    29  	g.genMutable()
    30  	g.genNewValue()
    31  	g.genIsValid()
    32  }
    33  
    34  // genAssertions generates protoreflect.Map type assertions
    35  func (g *mapGen) genAssertions() {
    36  	g.P("var _ ", protoreflectPkg.Ident("Map"), " = (*", g.typeName, ")(nil)")
    37  }
    38  
    39  // genType generates the type definition for the protoreflect.Map implementer
    40  func (g *mapGen) genType() {
    41  	g.P("type ", g.typeName, " struct {")
    42  	g.P("m *map[", getGoType(g.GeneratedFile, g.field.Message.Fields[0]), "]", getGoType(g.GeneratedFile, g.field.Message.Fields[1]))
    43  	g.P("}")
    44  	g.P()
    45  }
    46  
    47  // genLen generates the implementation of protoreflect.Map.Len
    48  func (g *mapGen) genLen() {
    49  	g.P("func (x *", g.typeName, ") Len() int {")
    50  	// invalid map
    51  	g.P("if x.m == nil {")
    52  	g.P("return 0")
    53  	g.P("}")
    54  	// valid map
    55  	g.P("return len(*x.m)")
    56  	g.P("}")
    57  	g.P()
    58  }
    59  
    60  // genRange generates the implementation for protoreflect.Map.Range
    61  func (g *mapGen) genRange() {
    62  	g.P("func (x *", g.typeName, ") Range(f func(", protoreflectPkg.Ident("MapKey"), ", ", protoreflectPkg.Ident("Value"), ") bool) {")
    63  	// invalid map
    64  	g.P("if x.m == nil {")
    65  	g.P("return")
    66  	g.P("}")
    67  	// valid map
    68  	g.P("for k, v := range *x.m {")
    69  	g.P("mapKey := (", protoreflectPkg.Ident("MapKey"), ")(", kindToValueConstructor(g.field.Message.Fields[0].Desc.Kind()), "(k))")
    70  	switch g.field.Message.Fields[1].Desc.Kind() {
    71  	case protoreflect.MessageKind:
    72  		g.P("mapValue := ", kindToValueConstructor(g.field.Message.Fields[1].Desc.Kind()), "(v.ProtoReflect())")
    73  	case protoreflect.EnumKind:
    74  		g.P("mapValue := ", kindToValueConstructor(g.field.Message.Fields[1].Desc.Kind()), "(v.Number())")
    75  	default:
    76  		g.P("mapValue := ", kindToValueConstructor(g.field.Message.Fields[1].Desc.Kind()), "(v)")
    77  	}
    78  	g.P("if !f(mapKey, mapValue) {")
    79  	g.P("break")
    80  	g.P("}")
    81  	g.P("}")
    82  	g.P("}")
    83  	g.P()
    84  }
    85  
    86  // genHas generates the implementation for protoreflect.Map.Has
    87  func (g *mapGen) genHas() {
    88  	g.P("func (x *", g.typeName, ") Has(key ", protoreflectPkg.Ident("MapKey"), ") bool {")
    89  	// invalid map
    90  	g.P("if x.m == nil {")
    91  	g.P("return false")
    92  	g.P("}")
    93  	// valid map
    94  	genPrefValueToGoValue(g.GeneratedFile, g.field.Message.Fields[0], "key", "concreteValue")
    95  	g.P("_, ok := (*x.m)[concreteValue]")
    96  	g.P("return ok")
    97  	g.P("}")
    98  	g.P()
    99  }
   100  
   101  func (g *mapGen) genClear() {
   102  	g.P("func (x *", g.typeName, ") Clear(key ", protoreflectPkg.Ident("MapKey"), ") {")
   103  	// invalid map
   104  	g.P("if x.m == nil {")
   105  	g.P("return")
   106  	g.P("}")
   107  	// valid map
   108  	genPrefValueToGoValue(g.GeneratedFile, g.field.Message.Fields[0], "key", "concreteKey")
   109  	g.P("delete(*x.m, concreteKey)")
   110  	g.P("}")
   111  	g.P()
   112  }
   113  
   114  func (g *mapGen) genGet() {
   115  	g.P("func (x *", g.typeName, ") Get(key ", protoreflectPkg.Ident("MapKey"), ") ", protoreflectPkg.Ident("Value"), "{")
   116  	g.P("if x.m == nil {")
   117  	g.P("return ", protoreflectPkg.Ident("Value"), "{}")
   118  	g.P("}")
   119  	genPrefValueToGoValue(g.GeneratedFile, g.field.Message.Fields[0], "key", "concreteKey")
   120  	g.P("v, ok := (*x.m)[concreteKey]")
   121  	g.P("if !ok {")
   122  	g.P("return ", protoreflectPkg.Ident("Value"), "{}")
   123  	g.P("}")
   124  	switch g.field.Message.Fields[1].Desc.Kind() {
   125  	case protoreflect.MessageKind:
   126  		g.P("return ", kindToValueConstructor(g.field.Message.Fields[1].Desc.Kind()), "(v.ProtoReflect())")
   127  	case protoreflect.EnumKind:
   128  		g.P("return ", kindToValueConstructor(g.field.Message.Fields[1].Desc.Kind()), "((", protoreflectPkg.Ident("EnumNumber"), ")(v))")
   129  	default:
   130  		g.P("return ", kindToValueConstructor(g.field.Message.Fields[1].Desc.Kind()), "(v)")
   131  	}
   132  	g.P("}")
   133  	g.P()
   134  }
   135  
   136  func (g *mapGen) genSet() {
   137  	g.P("func (x *", g.typeName, ") Set(key ", protoreflectPkg.Ident("MapKey"), ", value ", protoreflectPkg.Ident("Value"), ") {")
   138  	g.P("if !key.IsValid() || !value.IsValid() {")
   139  	g.P("panic(\"invalid key or value provided\")")
   140  	g.P("}")
   141  	genPrefValueToGoValue(g.GeneratedFile, g.field.Message.Fields[0], "key", "concreteKey")
   142  	genPrefValueToGoValue(g.GeneratedFile, g.field.Message.Fields[1], "value", "concreteValue")
   143  	g.P("(*x.m)[concreteKey] = concreteValue")
   144  	g.P("}")
   145  	g.P()
   146  }
   147  
   148  func (g *mapGen) genMutable() {
   149  	// if it's not a message value type, we construct a panic function
   150  	g.P("func (x *", g.typeName, ") Mutable(key ", protoreflectPkg.Ident("MapKey"), ") ", protoreflectPkg.Ident("Value"), " {")
   151  	if g.field.Message.Fields[1].Desc.Kind() != protoreflect.MessageKind {
   152  		panicMsg := "should not call Mutable on protoreflect.Map whose value is not of type protoreflect.Message"
   153  		g.P("panic(\"", panicMsg, "\")")
   154  		g.P("}")
   155  		g.P()
   156  		return
   157  	}
   158  	// generate mutable message logic
   159  	genPrefValueToGoValue(g.GeneratedFile, g.field.Message.Fields[0], "key", "concreteKey")
   160  	g.P("v, ok := (*x.m)[concreteKey]")
   161  	g.P("if ok {")
   162  	g.P("return ", protoreflectPkg.Ident("ValueOfMessage"), "(v.ProtoReflect())")
   163  	g.P("}")
   164  	g.P("newValue := new(", g.QualifiedGoIdent(g.field.Message.Fields[1].Message.GoIdent), ")")
   165  	g.P("(*x.m)[concreteKey] = newValue")
   166  	g.P("return ", protoreflectPkg.Ident("ValueOfMessage"), "(newValue.ProtoReflect())")
   167  	g.P("}")
   168  	g.P()
   169  }
   170  
   171  func (g *mapGen) genNewValue() {
   172  	g.P("func (x *", g.typeName, ") NewValue() ", protoreflectPkg.Ident("Value"), " {")
   173  	valueField := g.field.Message.Fields[1]
   174  	switch {
   175  	case valueField.Desc.Kind() == protoreflect.BytesKind:
   176  		g.P("var v []byte")
   177  	default:
   178  		g.P("v := ", zeroValueForField(g.GeneratedFile, valueField))
   179  	}
   180  	switch valueField.Desc.Kind() {
   181  	case protoreflect.MessageKind:
   182  		g.P("return ", kindToValueConstructor(valueField.Desc.Kind()), "(v.ProtoReflect())")
   183  	case protoreflect.EnumKind:
   184  		g.P("return ", kindToValueConstructor(valueField.Desc.Kind()), "((", protoreflectPkg.Ident("EnumNumber"), ")(v))")
   185  	default:
   186  		g.P("return ", kindToValueConstructor(valueField.Desc.Kind()), "(v)")
   187  	}
   188  	g.P("}")
   189  	g.P()
   190  }
   191  
   192  func (g *mapGen) genIsValid() {
   193  	g.P("func (x *", g.typeName, ") IsValid() bool {")
   194  	g.P("return x.m != nil")
   195  	g.P("}")
   196  	g.P()
   197  }
   198  
   199  func mapTypeName(field *protogen.Field) string {
   200  	return fmt.Sprintf("_%s_%d_map", field.Parent.GoIdent.GoName, field.Desc.Number())
   201  }