github.com/tickoalcantara12/micro/v3@v3.0.0-20221007104245-9d75b9bcbab9/cmd/protoc-gen-openapi/converter/message.go (about) 1 package converter 2 3 import ( 4 "fmt" 5 "strings" 6 7 "github.com/getkin/kin-openapi/openapi3" 8 "github.com/golang/protobuf/proto" 9 "github.com/golang/protobuf/protoc-gen-go/descriptor" 10 "github.com/tickoalcantara12/micro/v3/service/logger" 11 "google.golang.org/protobuf/compiler/protogen" 12 ) 13 14 const ( 15 openAPIFormatByte = "byte" 16 openAPIFormatDateTime = "date-time" 17 openAPIFormatDouble = "double" 18 openAPIFormatInt32 = "int32" 19 openAPIFormatInt64 = "int64" 20 openAPITypeArray = "array" 21 openAPITypeBoolean = "boolean" 22 openAPITypeNumber = "number" 23 openAPITypeObject = "object" 24 openAPITypeString = "string" 25 ) 26 27 var ( 28 globalPkg = &ProtoPackage{ 29 name: "", 30 parent: nil, 31 children: make(map[string]*ProtoPackage), 32 types: make(map[string]*descriptor.DescriptorProto), 33 } 34 35 wellKnownTypes = map[string]bool{ 36 "DoubleValue": true, 37 "FloatValue": true, 38 "Int64Value": true, 39 "UInt64Value": true, 40 "Int32Value": true, 41 "UInt32Value": true, 42 "BoolValue": true, 43 "StringValue": true, 44 "BytesValue": true, 45 "Value": true, 46 } 47 ) 48 49 func (c *Converter) registerType(pkgName *string, msg *descriptor.DescriptorProto) { 50 pkg := globalPkg 51 if pkgName != nil { 52 for _, node := range strings.Split(*pkgName, ".") { 53 if pkg == globalPkg && node == "" { 54 // Skips leading "." 55 continue 56 } 57 child, ok := pkg.children[node] 58 if !ok { 59 child = &ProtoPackage{ 60 name: pkg.name + "." + node, 61 parent: pkg, 62 children: make(map[string]*ProtoPackage), 63 types: make(map[string]*descriptor.DescriptorProto), 64 } 65 pkg.children[node] = child 66 } 67 pkg = child 68 } 69 } 70 pkg.types[msg.GetName()] = msg 71 } 72 73 func (c *Converter) relativelyLookupNestedType(desc *descriptor.DescriptorProto, name string) (*descriptor.DescriptorProto, bool) { 74 components := strings.Split(name, ".") 75 componentLoop: 76 for _, component := range components { 77 for _, nested := range desc.GetNestedType() { 78 if nested.GetName() == component { 79 desc = nested 80 continue componentLoop 81 } 82 } 83 logger.Infof("no such nested message (%s.%s)", component, desc.GetName()) 84 return nil, false 85 } 86 return desc, true 87 } 88 89 // @todo a bit of a copypaste from the function below, i did not know what to do 90 // with enums in the callsite of this function 91 func toTypeAndFormat(desc *descriptor.FieldDescriptorProto) (string, string, error) { 92 switch desc.GetType() { 93 case descriptor.FieldDescriptorProto_TYPE_DOUBLE, 94 descriptor.FieldDescriptorProto_TYPE_FLOAT: 95 return openAPITypeNumber, openAPIFormatDouble, nil 96 97 case descriptor.FieldDescriptorProto_TYPE_INT32, 98 descriptor.FieldDescriptorProto_TYPE_UINT32, 99 descriptor.FieldDescriptorProto_TYPE_FIXED32, 100 descriptor.FieldDescriptorProto_TYPE_SFIXED32, 101 descriptor.FieldDescriptorProto_TYPE_SINT32: 102 return openAPITypeNumber, openAPIFormatInt32, nil 103 104 case descriptor.FieldDescriptorProto_TYPE_INT64, 105 descriptor.FieldDescriptorProto_TYPE_UINT64, 106 descriptor.FieldDescriptorProto_TYPE_FIXED64, 107 descriptor.FieldDescriptorProto_TYPE_SFIXED64, 108 descriptor.FieldDescriptorProto_TYPE_SINT64: 109 return openAPITypeNumber, openAPIFormatInt64, nil 110 111 case descriptor.FieldDescriptorProto_TYPE_STRING: 112 return openAPITypeString, "", nil 113 114 case descriptor.FieldDescriptorProto_TYPE_BYTES: 115 return openAPITypeString, openAPIFormatByte, nil 116 117 case descriptor.FieldDescriptorProto_TYPE_ENUM: 118 return "string", "", nil 119 120 case descriptor.FieldDescriptorProto_TYPE_BOOL: 121 return openAPITypeBoolean, "", nil 122 123 case descriptor.FieldDescriptorProto_TYPE_GROUP, descriptor.FieldDescriptorProto_TYPE_MESSAGE: 124 switch desc.GetTypeName() { 125 case ".google.protobuf.Timestamp": 126 return openAPITypeString, openAPIFormatDateTime, nil 127 default: 128 return openAPITypeObject, "", nil 129 } 130 131 default: 132 return "", "", fmt.Errorf("unrecognized field type: %s", desc.GetType().String()) 133 } 134 } 135 136 // Convert a proto "field" (essentially a type-switch with some recursion): 137 func (c *Converter) convertField(curPkg *ProtoPackage, desc *descriptor.FieldDescriptorProto, msg *descriptor.DescriptorProto) (*openapi3.Schema, error) { 138 139 // Prepare a new jsonschema.Type for our eventual return value: 140 componentSchema := &openapi3.Schema{} 141 142 // Generate a description from src comments (if available) 143 if src := c.sourceInfo.GetField(desc); src != nil { 144 componentSchema.Description = formatDescription(src) 145 } 146 147 // Switch the types, and pick a JSONSchema equivalent: 148 switch desc.GetType() { 149 case descriptor.FieldDescriptorProto_TYPE_DOUBLE, 150 descriptor.FieldDescriptorProto_TYPE_FLOAT: 151 componentSchema.Type = openAPITypeNumber 152 componentSchema.Format = openAPIFormatDouble 153 154 case descriptor.FieldDescriptorProto_TYPE_INT32, 155 descriptor.FieldDescriptorProto_TYPE_UINT32, 156 descriptor.FieldDescriptorProto_TYPE_FIXED32, 157 descriptor.FieldDescriptorProto_TYPE_SFIXED32, 158 descriptor.FieldDescriptorProto_TYPE_SINT32: 159 componentSchema.Type = openAPITypeNumber 160 componentSchema.Format = openAPIFormatInt32 161 162 case descriptor.FieldDescriptorProto_TYPE_INT64, 163 descriptor.FieldDescriptorProto_TYPE_UINT64, 164 descriptor.FieldDescriptorProto_TYPE_FIXED64, 165 descriptor.FieldDescriptorProto_TYPE_SFIXED64, 166 descriptor.FieldDescriptorProto_TYPE_SINT64: 167 componentSchema.Type = openAPITypeNumber 168 componentSchema.Format = openAPIFormatInt64 169 170 case descriptor.FieldDescriptorProto_TYPE_STRING: 171 componentSchema.Type = openAPITypeString 172 173 case descriptor.FieldDescriptorProto_TYPE_BYTES: 174 componentSchema.Type = openAPITypeString 175 componentSchema.Format = openAPIFormatByte 176 177 case descriptor.FieldDescriptorProto_TYPE_ENUM: 178 componentSchema.Type = "string" 179 180 // Go through all the enums we have, see if we can match any to this field by name: 181 for _, enumDescriptor := range msg.GetEnumType() { 182 183 // Each one has several values: 184 for _, enumValue := range enumDescriptor.Value { 185 186 // Figure out the entire name of this field: 187 fullFieldName := fmt.Sprintf(".%v.%v", *msg.Name, *enumDescriptor.Name) 188 189 // If we find ENUM values for this field then put them into the JSONSchema list of allowed ENUM values: 190 if strings.HasSuffix(desc.GetTypeName(), fullFieldName) { 191 componentSchema.Enum = append(componentSchema.Enum, *enumValue.Name) 192 } 193 } 194 } 195 196 case descriptor.FieldDescriptorProto_TYPE_BOOL: 197 componentSchema.Type = openAPITypeBoolean 198 199 case descriptor.FieldDescriptorProto_TYPE_GROUP, descriptor.FieldDescriptorProto_TYPE_MESSAGE: 200 switch desc.GetTypeName() { 201 case ".google.protobuf.Timestamp": 202 componentSchema.Type = openAPITypeString 203 componentSchema.Format = openAPIFormatDateTime 204 default: 205 componentSchema.Type = openAPITypeObject 206 } 207 208 default: 209 return nil, fmt.Errorf("unrecognized field type: %s", desc.GetType().String()) 210 } 211 212 isList := false 213 var field *protogen.Field 214 if c.plug != nil && len(c.plug.Files) > 0 { 215 for _, file := range c.plug.Files { 216 for _, message := range file.Messages { 217 for _, f := range message.Fields { 218 parts := strings.Split(string(f.GoIdent.GoName), "_") 219 messageName := parts[0] 220 221 if messageName != *msg.Name { 222 continue 223 } 224 fieldName := parts[1] 225 if strings.ToLower(fieldName) == *desc.Name { 226 isList = f.Desc.IsList() 227 field = f 228 } 229 } 230 } 231 } 232 233 } 234 // Recurse array of primitive types: 235 if isList && componentSchema.Type != openAPITypeObject { 236 componentSchema.Items = &openapi3.SchemaRef{ 237 Value: &openapi3.Schema{ 238 Type: componentSchema.Type, 239 }, 240 } 241 componentSchema.Type = openAPITypeArray 242 return componentSchema, nil 243 } 244 245 // Recurse nested objects / arrays of objects (if necessary): 246 if componentSchema.Type == openAPITypeObject { 247 248 recordType, pkgName, ok := c.lookupType(curPkg, desc.GetTypeName()) 249 if !ok { 250 return nil, fmt.Errorf("no such message type named %s", desc.GetTypeName()) 251 } 252 253 // Recurse the recordType: 254 recursedComponentSchema, err := c.recursiveConvertMessageType(curPkg, recordType, pkgName) 255 if err != nil { 256 return nil, err 257 } 258 259 // Maps, arrays, and objects are structured in different ways: 260 switch { 261 case field != nil && field.Desc.Message().FullName() == "google.protobuf.Struct": 262 if !isList { 263 componentSchema.Type = openAPITypeObject 264 } else { 265 componentSchema.Items = &openapi3.SchemaRef{ 266 Value: &openapi3.Schema{ 267 Type: openAPITypeObject, 268 Properties: map[string]*openapi3.SchemaRef{}, 269 }, 270 } 271 272 componentSchema.Type = openAPITypeArray 273 } 274 // Arrays: 275 case isList: 276 componentSchema.Items = &openapi3.SchemaRef{ 277 Value: &openapi3.Schema{ 278 Type: openAPITypeObject, 279 Properties: recursedComponentSchema.Properties, 280 }, 281 } 282 283 componentSchema.Type = openAPITypeArray 284 // Maps: 285 case recordType.Options.GetMapEntry(): 286 logger.Tracef("Found a map (%s.%s)", *msg.Name, recordType.GetName()) 287 componentSchema.Type = openAPITypeObject 288 // fields of a map: key, value. we need the type of value here as key is always string 289 // see https://swagger.io/docs/specification/data-models/dictionaries/ 290 typ, format, err := toTypeAndFormat(recordType.Field[1]) 291 if err != nil { 292 return nil, err 293 } 294 componentSchema.AdditionalProperties = &openapi3.SchemaRef{ 295 Value: &openapi3.Schema{ 296 Type: typ, 297 Format: format, 298 }, 299 } 300 // Objects: 301 default: 302 componentSchema.Properties = recursedComponentSchema.Properties 303 // recursedComponentSchemaRef := fmt.Sprintf("#/components/schemas/%s", recursedComponentSchema.Title) 304 // componentSchema.Properties = openapi3.NewSchemaRef(recursedComponentSchemaRef, nil) 305 } 306 } 307 308 return componentSchema, nil 309 } 310 311 // Converts a proto "MESSAGE" into an OpenAPI schema: 312 func (c *Converter) convertMessageType(curPkg *ProtoPackage, msg *descriptor.DescriptorProto) (*openapi3.Schema, error) { 313 314 // main schema for the message 315 rootType, err := c.recursiveConvertMessageType(curPkg, msg, "") 316 if err != nil { 317 return nil, err 318 } 319 320 return rootType, nil 321 } 322 323 type nameAndCounter struct { 324 name string 325 counter int 326 } 327 328 func (c *Converter) recursiveConvertMessageType(curPkg *ProtoPackage, msg *descriptor.DescriptorProto, pkgName string) (*openapi3.Schema, error) { 329 if msg.Name != nil && wellKnownTypes[*msg.Name] && pkgName == ".google.protobuf" { 330 componentSchema := &openapi3.Schema{ 331 Title: msg.GetName(), 332 } 333 switch *msg.Name { 334 case "DoubleValue", "FloatValue": 335 componentSchema.Type = openAPITypeNumber 336 componentSchema.Format = openAPIFormatDouble 337 case "Int32Value", "UInt32Value": 338 componentSchema.Type = openAPITypeNumber 339 componentSchema.Format = openAPIFormatInt32 340 case "Int64Value", "UInt64Value": 341 componentSchema.Type = openAPITypeNumber 342 componentSchema.Format = openAPIFormatInt64 343 case "BoolValue": 344 componentSchema.Type = openAPITypeBoolean 345 case "StringValue": 346 componentSchema.Type = openAPITypeString 347 case "BytesValue": 348 componentSchema.Type = openAPITypeString 349 componentSchema.Format = openAPIFormatByte 350 case "Value": 351 componentSchema.Type = openAPITypeObject 352 } 353 return componentSchema, nil 354 } 355 356 // Prepare a new jsonschema: 357 componentSchema := &openapi3.Schema{ 358 Properties: make(map[string]*openapi3.SchemaRef), 359 Title: msg.GetName(), 360 Type: openAPITypeObject, 361 } 362 363 // Generate a description from src comments (if available) 364 if src := c.sourceInfo.GetMessage(msg); src != nil { 365 componentSchema.Description = formatDescription(src) 366 } 367 368 logger.Tracef("Converting message (%s)", proto.MarshalTextString(msg)) 369 370 // Recurse each field: 371 for _, fieldDesc := range msg.GetField() { 372 recursedComponentSchema, err := c.convertField(curPkg, fieldDesc, msg) 373 if err != nil { 374 logger.Errorf("Failed to convert field (%s.%s): %v", msg.GetName(), fieldDesc.GetName(), err) 375 return nil, err 376 } 377 logger.Tracef("Converted field: %s => %s", fieldDesc.GetName(), recursedComponentSchema.Type) 378 379 // Add it to the properties (by its JSON name): 380 componentSchema.Properties[fieldDesc.GetJsonName()] = openapi3.NewSchemaRef("", recursedComponentSchema) 381 } 382 383 return componentSchema, nil 384 } 385 386 func formatDescription(sl *descriptor.SourceCodeInfo_Location) string { 387 var lines []string 388 for _, str := range sl.GetLeadingDetachedComments() { 389 if s := strings.TrimSpace(str); s != "" { 390 lines = append(lines, s) 391 } 392 } 393 if s := strings.TrimSpace(sl.GetLeadingComments()); s != "" { 394 lines = append(lines, s) 395 } 396 if s := strings.TrimSpace(sl.GetTrailingComments()); s != "" { 397 lines = append(lines, s) 398 } 399 return strings.Join(lines, "\n\n") 400 }