github.com/cosmos/cosmos-proto@v1.0.0-beta.3/features/fastreflection/proto_size.go (about)

     1  package fastreflection
     2  
     3  import (
     4  	"fmt"
     5  	"github.com/cosmos/cosmos-proto/generator"
     6  	"google.golang.org/protobuf/compiler/protogen"
     7  	"google.golang.org/protobuf/encoding/protowire"
     8  	"google.golang.org/protobuf/reflect/protoreflect"
     9  	"strconv"
    10  	"strings"
    11  )
    12  
    13  var kindToGoType = map[protoreflect.Kind]string{
    14  	protoreflect.BoolKind:     "bool",
    15  	protoreflect.EnumKind:     "Enumeration",
    16  	protoreflect.Int32Kind:    "int32",
    17  	protoreflect.Sint32Kind:   "int32",
    18  	protoreflect.Uint32Kind:   "uint32",
    19  	protoreflect.Int64Kind:    "int64",
    20  	protoreflect.Sint64Kind:   "int64",
    21  	protoreflect.Uint64Kind:   "uint64",
    22  	protoreflect.Sfixed32Kind: "int32",
    23  	protoreflect.Fixed32Kind:  "uint32",
    24  	protoreflect.FloatKind:    "float32",
    25  	protoreflect.Sfixed64Kind: "int64",
    26  	protoreflect.Fixed64Kind:  "uint64",
    27  	protoreflect.DoubleKind:   "float64",
    28  	protoreflect.StringKind:   "string",
    29  	protoreflect.BytesKind:    "byte",
    30  }
    31  
    32  func (g *fastGenerator) genSizeMethod() {
    33  
    34  	g.P(`size := func(input `, protoifacePkg.Ident("SizeInput"), ") ", protoifacePkg.Ident("SizeOutput"), " {")
    35  	g.P("x := input.Message.Interface().(*", g.message.GoIdent, ")")
    36  	g.P(`if x == nil {`)
    37  	g.P(`return `, protoifacePkg.Ident("SizeOutput"), "{ ")
    38  	g.P("NoUnkeyedLiterals: input.NoUnkeyedLiterals,")
    39  	g.P("Size: 0,")
    40  	g.P("}")
    41  	g.P(`}`)
    42  	g.P("options := ", runtimePackage.Ident("SizeInputToOptions"), "(input)")
    43  	g.P("_ = options")
    44  	g.P(`var n int`)
    45  	g.P(`var l int`)
    46  	g.P(`_ = l`)
    47  	oneofs := make(map[string]struct{})
    48  	for _, field := range g.message.Fields {
    49  		oneof := field.Oneof != nil && !field.Oneof.Desc.IsSynthetic()
    50  		if !oneof {
    51  			g.field(true, field, false)
    52  		} else {
    53  			fieldName := field.Oneof.GoName
    54  			if _, ok := oneofs[fieldName]; !ok {
    55  				oneofs[fieldName] = struct{}{}
    56  				g.P("switch x := x.", fieldName, ".(type) {")
    57  				for _, ooField := range field.Oneof.Fields {
    58  
    59  					g.P("case *", ooField.GoIdent, ": ")
    60  					g.P("if x == nil {")
    61  					g.P("break")
    62  					g.P("}")
    63  					g.field(true, ooField, true)
    64  				}
    65  				g.P("}")
    66  			}
    67  		}
    68  	}
    69  
    70  	// last thing to do
    71  	g.P(`if x.unknownFields != nil {`)
    72  	g.P(`n+=len(x.unknownFields)`)
    73  	g.P(`}`)
    74  	g.P(`return `, protoifacePkg.Ident("SizeOutput"), "{ ")
    75  	g.P("NoUnkeyedLiterals: input.NoUnkeyedLiterals,")
    76  	g.P("Size: n,")
    77  	g.P("}")
    78  	g.P(`}`)
    79  	g.P()
    80  }
    81  
    82  func (g *fastGenerator) field(proto3 bool, field *protogen.Field, oneof bool) {
    83  	fieldname := field.GoName
    84  	nullable := field.Message != nil || (field.Oneof != nil && field.Oneof.Desc.IsSynthetic())
    85  	repeated := field.Desc.Cardinality() == protoreflect.Repeated
    86  	if repeated && !oneof {
    87  		g.P(`if len(x.`, fieldname, `) > 0 {`)
    88  	} else if nullable && !oneof {
    89  		g.P(`if x.`, fieldname, ` != nil {`)
    90  	}
    91  
    92  	packed := field.Desc.IsPacked()
    93  	wireType := generator.ProtoWireType(field.Desc.Kind())
    94  	fieldNumber := field.Desc.Number()
    95  	if packed {
    96  		wireType = protowire.BytesType
    97  	}
    98  	key := generator.KeySize(fieldNumber, wireType)
    99  	switch field.Desc.Kind() {
   100  	case protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind:
   101  		if packed {
   102  			g.P(`n+=`, strconv.Itoa(key), `+`, runtimePackage.Ident("Sov"), `(uint64(len(x.`, fieldname, `)*8))`, `+len(x.`, fieldname, `)*8`)
   103  		} else if repeated {
   104  			g.P(`n+=`, strconv.Itoa(key+8), `*len(x.`, fieldname, `)`)
   105  		} else if proto3 && !nullable {
   106  			if !oneof {
   107  				g.P(`if x.`, fieldname, ` != 0 {`)
   108  			}
   109  			g.P(`n+=`, strconv.Itoa(key+8))
   110  			if !oneof {
   111  				g.P(`}`)
   112  			}
   113  		} else {
   114  			g.P(`n+=`, strconv.Itoa(key+8))
   115  		}
   116  	case protoreflect.DoubleKind:
   117  		if packed {
   118  			g.P(`n+=`, strconv.Itoa(key), `+`, runtimePackage.Ident("Sov"), `(uint64(len(x.`, fieldname, `)*8))`, `+len(x.`, fieldname, `)*8`)
   119  		} else if repeated {
   120  			g.P(`n+=`, strconv.Itoa(key+8), `*len(x.`, fieldname, `)`)
   121  		} else if proto3 && !nullable {
   122  			if !oneof {
   123  				g.P(`if x.`, fieldname, ` != 0 || `, mathPackage.Ident("Signbit"), `(x.`, fieldname, `) {`)
   124  			}
   125  			g.P(`n+=`, strconv.Itoa(key+8))
   126  			if !oneof {
   127  				g.P(`}`)
   128  			}
   129  		} else {
   130  			g.P(`n+=`, strconv.Itoa(key+8))
   131  		}
   132  	case protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind:
   133  		if packed {
   134  			g.P(`n+=`, strconv.Itoa(key), `+`, runtimePackage.Ident("Sov"), `(uint64(len(x.`, fieldname, `)*4))`, `+len(x.`, fieldname, `)*4`)
   135  		} else if repeated {
   136  			g.P(`n+=`, strconv.Itoa(key+4), `*len(x.`, fieldname, `)`)
   137  		} else if proto3 && !nullable {
   138  			if !oneof {
   139  				g.P(`if x.`, fieldname, ` != 0 {`)
   140  			}
   141  			g.P(`n+=`, strconv.Itoa(key+4))
   142  			if !oneof {
   143  				g.P(`}`)
   144  			}
   145  		} else {
   146  			g.P(`n+=`, strconv.Itoa(key+4))
   147  		}
   148  	case protoreflect.FloatKind:
   149  		if packed {
   150  			g.P(`n+=`, strconv.Itoa(key), `+`, runtimePackage.Ident("Sov"), `(uint64(len(x.`, fieldname, `)*4))`, `+len(x.`, fieldname, `)*4`)
   151  		} else if repeated {
   152  			g.P(`n+=`, strconv.Itoa(key+4), `*len(x.`, fieldname, `)`)
   153  		} else if proto3 && !nullable {
   154  			if !oneof {
   155  				g.P(`if x.`, fieldname, ` != 0 || `, mathPackage.Ident("Signbit"), `(float64(x.`, fieldname, `)) {`)
   156  			}
   157  			g.P(`n+=`, strconv.Itoa(key+4))
   158  			if !oneof {
   159  				g.P(`}`)
   160  			}
   161  		} else {
   162  			g.P(`n+=`, strconv.Itoa(key+4))
   163  		}
   164  	case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Uint32Kind, protoreflect.EnumKind, protoreflect.Int32Kind:
   165  		if packed {
   166  			g.P(`l = 0`)
   167  			g.P(`for _, e := range x.`, fieldname, ` {`)
   168  			g.P(`l+=`, runtimePackage.Ident("Sov"), `(uint64(e))`)
   169  			g.P(`}`)
   170  			g.P(`n+=`, strconv.Itoa(key), `+`, runtimePackage.Ident("Sov"), `(uint64(l))+l`)
   171  		} else if repeated {
   172  			g.P(`for _, e := range x.`, fieldname, ` {`)
   173  			g.P(`n+=`, strconv.Itoa(key), `+`, runtimePackage.Ident("Sov"), `(uint64(e))`)
   174  			g.P(`}`)
   175  		} else if nullable {
   176  			g.P(`n+=`, strconv.Itoa(key), `+`, runtimePackage.Ident("Sov"), `(uint64(*x.`, fieldname, `))`)
   177  		} else if proto3 {
   178  			if !oneof {
   179  				g.P(`if x.`, fieldname, ` != 0 {`)
   180  			}
   181  			g.P(`n+=`, strconv.Itoa(key), `+`, runtimePackage.Ident("Sov"), `(uint64(x.`, fieldname, `))`)
   182  			if !oneof {
   183  				g.P(`}`)
   184  			}
   185  		} else {
   186  			g.P(`n+=`, strconv.Itoa(key), `+`, runtimePackage.Ident("Sov"), `(uint64(x.`, fieldname, `))`)
   187  		}
   188  	case protoreflect.BoolKind:
   189  		if packed {
   190  			g.P(`n+=`, strconv.Itoa(key), `+`, runtimePackage.Ident("Sov"), `(uint64(len(x.`, fieldname, `)))`, `+len(x.`, fieldname, `)*1`)
   191  		} else if repeated {
   192  			g.P(`n+=`, strconv.Itoa(key+1), `*len(x.`, fieldname, `)`)
   193  		} else if proto3 && !nullable {
   194  			if !oneof {
   195  				g.P(`if x.`, fieldname, ` {`)
   196  			}
   197  			g.P(`n+=`, strconv.Itoa(key+1))
   198  			if !oneof {
   199  				g.P(`}`)
   200  			}
   201  		} else {
   202  			g.P(`n+=`, strconv.Itoa(key+1))
   203  		}
   204  	case protoreflect.StringKind:
   205  		if repeated {
   206  			g.P(`for _, s := range x.`, fieldname, ` { `)
   207  			g.P(`l = len(s)`)
   208  			g.P(`n+=`, strconv.Itoa(key), `+l+`, runtimePackage.Ident("Sov"), `(uint64(l))`)
   209  			g.P(`}`)
   210  		} else if nullable {
   211  			g.P(`l=len(*x.`, fieldname, `)`)
   212  			g.P(`n+=`, strconv.Itoa(key), `+l+`, runtimePackage.Ident("Sov"), `(uint64(l))`)
   213  		} else if proto3 {
   214  			g.P(`l=len(x.`, fieldname, `)`)
   215  			if !oneof {
   216  				g.P(`if l > 0 {`)
   217  			}
   218  			g.P(`n+=`, strconv.Itoa(key), `+l+`, runtimePackage.Ident("Sov"), `(uint64(l))`)
   219  			if !oneof {
   220  				g.P(`}`)
   221  			}
   222  		} else {
   223  			g.P(`l=len(x.`, fieldname, `)`)
   224  			g.P(`n+=`, strconv.Itoa(key), `+l+`, runtimePackage.Ident("Sov"), `(uint64(l))`)
   225  		}
   226  	case protoreflect.GroupKind:
   227  		panic(fmt.Errorf("size does not support group %v", fieldname))
   228  	case protoreflect.MessageKind:
   229  		if field.Desc.IsMap() {
   230  			fieldKeySize := generator.KeySize(field.Desc.Number(), generator.ProtoWireType(field.Desc.Kind()))
   231  			goTypeK, _ := g.FieldGoType(field.Message.Fields[0])
   232  			goTypeV, ptr := g.FieldGoType(field.Message.Fields[1])
   233  			if ptr {
   234  				goTypeV = "*" + goTypeV
   235  			}
   236  			keyKeySize := generator.KeySize(1, generator.ProtoWireType(field.Message.Fields[0].Desc.Kind()))
   237  			valueKeySize := generator.KeySize(2, generator.ProtoWireType(field.Message.Fields[1].Desc.Kind()))
   238  
   239  			sum := []string{strconv.Itoa(keyKeySize)}
   240  			g.P("SiZeMaP := func(k ", goTypeK, ", v ", goTypeV, ") {")
   241  			switch field.Desc.MapKey().Kind() {
   242  			case protoreflect.DoubleKind, protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind:
   243  				sum = append(sum, `8`)
   244  			case protoreflect.FloatKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind:
   245  				sum = append(sum, `4`)
   246  			case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Uint32Kind, protoreflect.EnumKind, protoreflect.Int32Kind:
   247  				sum = append(sum, fmt.Sprintf("%s%s", g.QualifiedGoIdent(runtimePackage.Ident("Sov")), `(uint64(k))`))
   248  			case protoreflect.BoolKind:
   249  				sum = append(sum, `1`)
   250  			case protoreflect.StringKind, protoreflect.BytesKind:
   251  				sum = append(sum, `len(k)`, fmt.Sprintf("%s%s", g.QualifiedGoIdent(runtimePackage.Ident("Sov")), `(uint64(len(k)))`))
   252  			case protoreflect.Sint32Kind, protoreflect.Sint64Kind:
   253  				sum = append(sum, fmt.Sprintf("%s%s", g.QualifiedGoIdent(runtimePackage.Ident("Soz")), `(uint64(k))`))
   254  			}
   255  
   256  			switch field.Desc.MapValue().Kind() {
   257  			case protoreflect.DoubleKind, protoreflect.Fixed64Kind, protoreflect.Sfixed64Kind:
   258  				sum = append(sum, strconv.Itoa(valueKeySize))
   259  				sum = append(sum, strconv.Itoa(8))
   260  			case protoreflect.FloatKind, protoreflect.Fixed32Kind, protoreflect.Sfixed32Kind:
   261  				sum = append(sum, strconv.Itoa(valueKeySize))
   262  				sum = append(sum, strconv.Itoa(4))
   263  			case protoreflect.Int64Kind, protoreflect.Uint64Kind, protoreflect.Uint32Kind, protoreflect.EnumKind, protoreflect.Int32Kind:
   264  				sum = append(sum, strconv.Itoa(valueKeySize))
   265  				sum = append(sum, fmt.Sprintf("%s%s", g.QualifiedGoIdent(runtimePackage.Ident("Sov")), `(uint64(v))`))
   266  			case protoreflect.BoolKind:
   267  				sum = append(sum, strconv.Itoa(valueKeySize))
   268  				sum = append(sum, `1`)
   269  			case protoreflect.StringKind:
   270  				sum = append(sum, strconv.Itoa(valueKeySize))
   271  				sum = append(sum, `len(v)`, fmt.Sprintf("%s%s", g.QualifiedGoIdent(runtimePackage.Ident("Sov")), `(uint64(len(v)))`))
   272  			case protoreflect.BytesKind:
   273  				g.P(`l = `, strconv.Itoa(valueKeySize), ` + len(v)+`, runtimePackage.Ident("Sov"), `(uint64(len(v)))`)
   274  				sum = append(sum, `l`)
   275  			case protoreflect.Sint32Kind, protoreflect.Sint64Kind:
   276  				sum = append(sum, strconv.Itoa(valueKeySize))
   277  				sum = append(sum, fmt.Sprintf("%s%s", g.QualifiedGoIdent(runtimePackage.Ident("Soz")), `(uint64(v))`))
   278  			case protoreflect.MessageKind:
   279  				g.P(`l := 0`)
   280  				g.P(`if v != nil {`)
   281  				g.messageSize("v", field.Message.Fields[1].Message)
   282  				g.P(`}`)
   283  				g.P(`l += `, strconv.Itoa(valueKeySize), `+`, runtimePackage.Ident("Sov"), `(uint64(l))`)
   284  				sum = append(sum, `l`)
   285  			}
   286  			g.P(`mapEntrySize := `, strings.Join(sum, "+"))
   287  			g.P(`n+=mapEntrySize+`, fieldKeySize, `+`, runtimePackage.Ident("Sov"), `(uint64(mapEntrySize))`)
   288  			g.P("}")
   289  			// first we have to sort the key
   290  			typ, ok := kindToGoType[field.Desc.MapKey().Kind()]
   291  			if !ok {
   292  				panic(fmt.Sprintf("pulsar does not support %s types as map keys", field.Desc.MapKey().Kind().String()))
   293  			}
   294  			g.P("if options.Deterministic {")
   295  			g.P("sortme := make([]", typ, ", 0, len(x.", field.GoName, "))")
   296  			g.P("for k := range x.", fieldname, " {")
   297  			g.P("sortme = append(sortme, k)")
   298  			g.P("}")
   299  			switch field.Desc.MapKey().Kind() {
   300  			case protoreflect.StringKind:
   301  				g.P(sortPkg.Ident("Strings"), "(sortme)")
   302  			default:
   303  				g.P(sortPkg.Ident("Slice"), "(sortme, func(i, j int) bool {")
   304  				switch field.Desc.MapKey().Kind() {
   305  				case protoreflect.BoolKind:
   306  					g.P("return !sortme[i] && sortme[j]")
   307  				default:
   308  					g.P("return sortme[i] < sortme[j]")
   309  				}
   310  				g.P("})")
   311  
   312  			}
   313  
   314  			g.P(`for _, k := range sortme {`)
   315  			g.P("v := x.", fieldname, "[k]")
   316  			g.P("SiZeMaP(k,v)")
   317  			g.P(`}`)
   318  			g.P("} else {")
   319  			g.P("for k,v := range x.", fieldname, " {")
   320  			g.P("SiZeMaP(k,v)")
   321  			g.P("}")
   322  			g.P("}")
   323  		} else if field.Desc.IsList() {
   324  			g.P(`for _, e := range x.`, fieldname, ` { `)
   325  			g.messageSize("e", field.Message)
   326  			g.P(`n+=`, strconv.Itoa(key), `+l+`, runtimePackage.Ident("Sov"), `(uint64(l))`)
   327  			g.P(`}`)
   328  		} else {
   329  			g.messageSize("x."+fieldname, field.Message)
   330  			g.P(`n+=`, strconv.Itoa(key), `+l+`, runtimePackage.Ident("Sov"), `(uint64(l))`)
   331  		}
   332  	case protoreflect.BytesKind:
   333  		if repeated {
   334  			g.P(`for _, b := range x.`, fieldname, ` { `)
   335  			g.P(`l = len(b)`)
   336  			g.P(`n+=`, strconv.Itoa(key), `+l+`, runtimePackage.Ident("Sov"), `(uint64(l))`)
   337  			g.P(`}`)
   338  		} else if proto3 {
   339  			g.P(`l=len(x.`, fieldname, `)`)
   340  			if !oneof {
   341  				g.P(`if l > 0 {`)
   342  			}
   343  			g.P(`n+=`, strconv.Itoa(key), `+l+`, runtimePackage.Ident("Sov"), `(uint64(l))`)
   344  			if !oneof {
   345  				g.P(`}`)
   346  			}
   347  		} else {
   348  			g.P(`l=len(x.`, fieldname, `)`)
   349  			g.P(`n+=`, strconv.Itoa(key), `+l+`, runtimePackage.Ident("Sov"), `(uint64(l))`)
   350  		}
   351  	case protoreflect.Sint32Kind, protoreflect.Sint64Kind:
   352  		if packed {
   353  			g.P(`l = 0`)
   354  			g.P(`for _, e := range x.`, fieldname, ` {`)
   355  			g.P(`l+=`, runtimePackage.Ident("Soz"), `(uint64(e))`)
   356  			g.P(`}`)
   357  			g.P(`n+=`, strconv.Itoa(key), `+`, runtimePackage.Ident("Sov"), `(uint64(l))+l`)
   358  		} else if repeated {
   359  			g.P(`for _, e := range x.`, fieldname, ` {`)
   360  			g.P(`n+=`, strconv.Itoa(key), `+`, runtimePackage.Ident("Soz"), `(uint64(e))`)
   361  			g.P(`}`)
   362  		} else if nullable {
   363  			g.P(`n+=`, strconv.Itoa(key), `+`, runtimePackage.Ident("Soz"), `(uint64(*x.`, fieldname, `))`)
   364  		} else if proto3 {
   365  			if !oneof {
   366  				g.P(`if x.`, fieldname, ` != 0 {`)
   367  			}
   368  			g.P(`n+=`, strconv.Itoa(key), `+`, runtimePackage.Ident("Soz"), `(uint64(x.`, fieldname, `))`)
   369  			g.P(`}`)
   370  		} else {
   371  			g.P(`n+=`, strconv.Itoa(key), `+`, runtimePackage.Ident("Soz"), `(uint64(x.`, fieldname, `))`)
   372  		}
   373  	default:
   374  		panic("not implemented")
   375  	}
   376  	if (repeated || nullable) && !oneof {
   377  		g.P(`}`)
   378  	}
   379  }
   380  
   381  func (g *fastGenerator) messageSize(varName string, message *protogen.Message) {
   382  	g.P(`l = options.Size(`, varName, `)`)
   383  }