github.com/cosmos/cosmos-proto@v1.0.0-beta.3/features/fastreflection/list.go (about) 1 package fastreflection 2 3 import ( 4 "fmt" 5 "github.com/cosmos/cosmos-proto/generator" 6 7 "google.golang.org/protobuf/compiler/protogen" 8 "google.golang.org/protobuf/reflect/protoreflect" 9 ) 10 11 type listGen struct { 12 *generator.GeneratedFile 13 field *protogen.Field 14 15 typeName string 16 } 17 18 func (g *listGen) generate() { 19 g.typeName = listTypeName(g.field) 20 21 g.genAssertions() 22 g.genType() 23 g.genLen() 24 g.genGet() 25 g.genSet() 26 g.genAppend() 27 g.genAppendMutable() 28 g.genTruncate() 29 g.genNewElement() 30 g.genIsValid() 31 } 32 33 // genAssertions generates protoreflect.List type assertions 34 func (g *listGen) genAssertions() { 35 // type assertion 36 g.P("var _ ", protoreflectPkg.Ident("List"), " = (*", g.typeName, ")(nil)") 37 } 38 39 // genType generates the list type 40 func (g *listGen) genType() { 41 g.P("type ", g.typeName, " struct {") 42 g.P("list *[]", getGoType(g.GeneratedFile, g.field)) 43 g.P("}") 44 g.P() 45 } 46 47 // genLen generates the implementation for protoreflect.List.Len 48 func (g *listGen) genLen() { 49 g.P("func (x *", g.typeName, ") Len() int {") 50 g.P("if x.list == nil {") 51 g.P("return 0") 52 g.P("}") 53 g.P("return len(*x.list)") 54 g.P("}") 55 g.P() 56 } 57 58 // genGet generates the implementation for protoreflect.List.Get 59 func (g *listGen) genGet() { 60 g.P("func (x *", g.typeName, ") Get(i int) ", protoreflectPkg.Ident("Value"), " {") 61 constructor := kindToValueConstructor(g.field.Desc.Kind()) 62 switch g.field.Desc.Kind() { 63 case protoreflect.MessageKind: 64 g.P("return ", constructor, "((*x.list)[i].ProtoReflect())") 65 case protoreflect.EnumKind: 66 g.P("return ", constructor, "((", protoreflectPkg.Ident("EnumNumber"), ")((*x.list)[i]))") 67 default: 68 g.P("return ", constructor, "((*x.list)[i])") 69 } 70 g.P("}") 71 g.P() 72 } 73 74 // genSet generates the implementation for protoreflect.List.Set 75 func (g *listGen) genSet() { 76 // Set 77 g.P("func (x *", g.typeName, ") Set(i int, value ", protoreflectPkg.Ident("Value"), ") {") 78 concreteValueName := genPrefValueToGoValue(g.GeneratedFile, g.field, "value", "concreteValue") 79 g.P("(*x.list)[i] = ", concreteValueName) 80 g.P("}") 81 g.P() 82 } 83 84 // genAppend generates the protoreflect.List.Append implementation 85 func (g *listGen) genAppend() { 86 g.P("func (x *", g.typeName, ") Append(value ", protoreflectPkg.Ident("Value"), ") {") 87 concreteValueName := genPrefValueToGoValue(g.GeneratedFile, g.field, "value", "concreteValue") 88 g.P("*x.list = append(*x.list, ", concreteValueName, ")") 89 g.P("}") 90 g.P() 91 } 92 93 // genAppendMutable generates the protoreflect.List.AppendMutable implementation 94 func (g *listGen) genAppendMutable() { 95 g.P("func (x *", g.typeName, ") AppendMutable() ", protoreflectPkg.Ident("Value"), " {") 96 switch g.field.Desc.Kind() { 97 case protoreflect.MessageKind: 98 g.P("v := new(", g.QualifiedGoIdent(g.field.Message.GoIdent), ")") 99 g.P("*x.list = append(*x.list, v)") 100 g.P("return ", protoreflectPkg.Ident("ValueOfMessage"), "(v.ProtoReflect())") 101 default: 102 panicMsg := fmt.Sprintf("AppendMutable can not be called on message %s at list field %s as it is not of Message kind", g.field.Parent.GoIdent.GoName, g.field.GoName) 103 g.P("panic(", fmtPkg.Ident("Errorf"), "(\"", panicMsg, "\"))") 104 } 105 g.P("}") 106 g.P() 107 } 108 109 // genTruncate generates the protoreflect.List.Truncate implementation 110 func (g *listGen) genTruncate() { 111 g.P("func (x *", g.typeName, ") Truncate(n int)", "{") 112 113 switch g.field.Desc.Kind() { 114 case protoreflect.MessageKind: // zero message kinds to avoid keeping data alive 115 g.P("for i := n; i < len(*x.list); i++ {") 116 g.P("(*x.list)[i] = nil") 117 g.P("}") 118 } 119 g.P("*x.list = (*x.list)[:n]") // truncate 120 g.P("}") 121 g.P() 122 } 123 124 // genNewElement generates the protoreflect.List.NewElement implementation 125 func (g *listGen) genNewElement() { 126 g.P("func (x *", g.typeName, ") NewElement() ", protoreflectPkg.Ident("Value"), "{") 127 switch g.field.Desc.Kind() { 128 case protoreflect.BytesKind: 129 g.P("var v []byte") 130 default: 131 zeroValue := zeroValueForField(g.GeneratedFile, g.field) 132 g.P("v := ", zeroValue) 133 } 134 switch g.field.Desc.Kind() { 135 case protoreflect.MessageKind: 136 g.P("return ", kindToValueConstructor(g.field.Desc.Kind()), "(v.ProtoReflect())") 137 case protoreflect.EnumKind: 138 g.P("return ", kindToValueConstructor(g.field.Desc.Kind()), "((", protoreflectPkg.Ident("EnumNumber"), ")(v))") 139 default: 140 g.P("return ", kindToValueConstructor(g.field.Desc.Kind()), "(v)") 141 } 142 g.P("}") 143 g.P() 144 } 145 146 func (g *listGen) genIsValid() { 147 g.P("func (x *", g.typeName, ") IsValid() bool {") 148 g.P("return x.list != nil") // if we generate this type it's always valid, it either comes from Mutable path, or Get path in which the list is valid. 149 g.P("}") 150 g.P() 151 } 152 153 func listTypeName(field *protogen.Field) string { 154 return fmt.Sprintf("_%s_%d_list", field.Parent.GoIdent.GoName, field.Desc.Number()) 155 } 156 157 func getGoType(g *generator.GeneratedFile, field *protogen.Field) (goType string) { 158 if field.Desc.IsWeak() { 159 return "struct{}" 160 } 161 162 switch field.Desc.Kind() { 163 case protoreflect.BoolKind: 164 goType = "bool" 165 case protoreflect.EnumKind: 166 goType = g.QualifiedGoIdent(field.Enum.GoIdent) 167 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 168 goType = "int32" 169 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 170 goType = "uint32" 171 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 172 goType = "int64" 173 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 174 goType = "uint64" 175 case protoreflect.FloatKind: 176 goType = "float32" 177 case protoreflect.DoubleKind: 178 goType = "float64" 179 case protoreflect.StringKind: 180 goType = "string" 181 case protoreflect.BytesKind: 182 goType = "[]byte" 183 case protoreflect.MessageKind, protoreflect.GroupKind: 184 goType = "*" + g.QualifiedGoIdent(field.Message.GoIdent) 185 } 186 return goType 187 } 188 189 func kindToValueConstructor(kind protoreflect.Kind) protogen.GoIdent { 190 switch kind { 191 case protoreflect.BoolKind: 192 return protoreflectPkg.Ident("ValueOfBool") 193 case protoreflect.EnumKind: 194 return protoreflectPkg.Ident("ValueOfEnum") 195 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 196 return protoreflectPkg.Ident("ValueOfInt32") 197 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 198 return protoreflectPkg.Ident("ValueOfUint32") 199 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 200 return protoreflectPkg.Ident("ValueOfInt64") 201 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 202 return protoreflectPkg.Ident("ValueOfUint64") 203 case protoreflect.FloatKind: 204 return protoreflectPkg.Ident("ValueOfFloat32") 205 case protoreflect.DoubleKind: 206 return protoreflectPkg.Ident("ValueOfFloat64") 207 case protoreflect.StringKind: 208 return protoreflectPkg.Ident("ValueOfString") 209 case protoreflect.BytesKind: 210 return protoreflectPkg.Ident("ValueOfBytes") 211 case protoreflect.MessageKind, protoreflect.GroupKind: 212 return protoreflectPkg.Ident("ValueOfMessage") 213 default: 214 panic("should not reach here") 215 } 216 } 217 218 // valueUnwrapper provides the function to call on value 219 // in order to get the concrete underlying type 220 func valueUnwrapper(kind protoreflect.Kind) string { 221 switch kind { 222 case protoreflect.BoolKind: 223 return "Bool" 224 case protoreflect.EnumKind: 225 return "Enum" 226 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 227 return "Int" 228 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 229 return "Uint" 230 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 231 return "Int" 232 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 233 return "Uint" 234 case protoreflect.FloatKind: 235 return "Float" 236 case protoreflect.DoubleKind: 237 return "Float" 238 case protoreflect.StringKind: 239 return "String" 240 case protoreflect.BytesKind: 241 return "Bytes" 242 case protoreflect.MessageKind, protoreflect.GroupKind: 243 return "Message" 244 default: 245 panic("should not reach here") 246 } 247 } 248 249 func genPrefValueToGoValue(g *generator.GeneratedFile, field *protogen.Field, inputName string, outputName string) string { 250 251 unwrapperFunc := valueUnwrapper(field.Desc.Kind()) 252 unwrapperVar := fmt.Sprintf("%sUnwrapped", inputName) 253 g.P(unwrapperVar, " := ", inputName, ".", unwrapperFunc, "()") 254 switch field.Desc.Kind() { 255 case protoreflect.MessageKind: 256 g.P(outputName, " := ", unwrapperVar, ".Interface().(*", g.QualifiedGoIdent(field.Message.GoIdent), ")") 257 case protoreflect.EnumKind: 258 g.P(outputName, " := (", g.QualifiedGoIdent(field.Enum.GoIdent), ")(", unwrapperVar, ")") 259 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 260 g.P(outputName, " := (int32)(", unwrapperVar, ")") 261 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 262 g.P(outputName, " := (uint32)(", unwrapperVar, ")") 263 case protoreflect.FloatKind: 264 g.P(outputName, " := (float32)(", unwrapperVar, ")") 265 default: 266 g.P(outputName, " := ", unwrapperVar) 267 } 268 269 return outputName 270 } 271 272 func zeroValueForField(g *generator.GeneratedFile, field *protogen.Field) string { 273 switch field.Desc.Kind() { 274 case protoreflect.BoolKind: 275 return "false" 276 case protoreflect.EnumKind: 277 return fmt.Sprintf("%d", field.Enum.Desc.Values().Get(0).Number()) 278 case protoreflect.Int32Kind, protoreflect.Sint32Kind, protoreflect.Sfixed32Kind: 279 return "int32(0)" 280 case protoreflect.Uint32Kind, protoreflect.Fixed32Kind: 281 return "uint32(0)" 282 case protoreflect.Int64Kind, protoreflect.Sint64Kind, protoreflect.Sfixed64Kind: 283 return "int64(0)" 284 case protoreflect.Uint64Kind, protoreflect.Fixed64Kind: 285 return "uint64(0)" 286 case protoreflect.FloatKind: 287 return "float32(0)" 288 case protoreflect.DoubleKind: 289 return "float64(0)" 290 case protoreflect.StringKind: 291 return "\"\"" 292 case protoreflect.BytesKind: 293 return "nil" 294 case protoreflect.MessageKind, protoreflect.GroupKind: 295 return fmt.Sprintf("new(%s)", g.QualifiedGoIdent(field.Message.GoIdent)) 296 default: 297 panic("should not reach") 298 } 299 }