github.com/tickoalcantara12/micro/v3@v3.0.0-20221007104245-9d75b9bcbab9/cmd/protoc-gen-openapi/converter/message.go (about)

     1  package converter
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  
     7  	"github.com/getkin/kin-openapi/openapi3"
     8  	"github.com/golang/protobuf/proto"
     9  	"github.com/golang/protobuf/protoc-gen-go/descriptor"
    10  	"github.com/tickoalcantara12/micro/v3/service/logger"
    11  	"google.golang.org/protobuf/compiler/protogen"
    12  )
    13  
    14  const (
    15  	openAPIFormatByte     = "byte"
    16  	openAPIFormatDateTime = "date-time"
    17  	openAPIFormatDouble   = "double"
    18  	openAPIFormatInt32    = "int32"
    19  	openAPIFormatInt64    = "int64"
    20  	openAPITypeArray      = "array"
    21  	openAPITypeBoolean    = "boolean"
    22  	openAPITypeNumber     = "number"
    23  	openAPITypeObject     = "object"
    24  	openAPITypeString     = "string"
    25  )
    26  
    27  var (
    28  	globalPkg = &ProtoPackage{
    29  		name:     "",
    30  		parent:   nil,
    31  		children: make(map[string]*ProtoPackage),
    32  		types:    make(map[string]*descriptor.DescriptorProto),
    33  	}
    34  
    35  	wellKnownTypes = map[string]bool{
    36  		"DoubleValue": true,
    37  		"FloatValue":  true,
    38  		"Int64Value":  true,
    39  		"UInt64Value": true,
    40  		"Int32Value":  true,
    41  		"UInt32Value": true,
    42  		"BoolValue":   true,
    43  		"StringValue": true,
    44  		"BytesValue":  true,
    45  		"Value":       true,
    46  	}
    47  )
    48  
    49  func (c *Converter) registerType(pkgName *string, msg *descriptor.DescriptorProto) {
    50  	pkg := globalPkg
    51  	if pkgName != nil {
    52  		for _, node := range strings.Split(*pkgName, ".") {
    53  			if pkg == globalPkg && node == "" {
    54  				// Skips leading "."
    55  				continue
    56  			}
    57  			child, ok := pkg.children[node]
    58  			if !ok {
    59  				child = &ProtoPackage{
    60  					name:     pkg.name + "." + node,
    61  					parent:   pkg,
    62  					children: make(map[string]*ProtoPackage),
    63  					types:    make(map[string]*descriptor.DescriptorProto),
    64  				}
    65  				pkg.children[node] = child
    66  			}
    67  			pkg = child
    68  		}
    69  	}
    70  	pkg.types[msg.GetName()] = msg
    71  }
    72  
    73  func (c *Converter) relativelyLookupNestedType(desc *descriptor.DescriptorProto, name string) (*descriptor.DescriptorProto, bool) {
    74  	components := strings.Split(name, ".")
    75  componentLoop:
    76  	for _, component := range components {
    77  		for _, nested := range desc.GetNestedType() {
    78  			if nested.GetName() == component {
    79  				desc = nested
    80  				continue componentLoop
    81  			}
    82  		}
    83  		logger.Infof("no such nested message (%s.%s)", component, desc.GetName())
    84  		return nil, false
    85  	}
    86  	return desc, true
    87  }
    88  
    89  // @todo a bit of a copypaste from the function below, i did not know what to do
    90  // with enums in the callsite of this function
    91  func toTypeAndFormat(desc *descriptor.FieldDescriptorProto) (string, string, error) {
    92  	switch desc.GetType() {
    93  	case descriptor.FieldDescriptorProto_TYPE_DOUBLE,
    94  		descriptor.FieldDescriptorProto_TYPE_FLOAT:
    95  		return openAPITypeNumber, openAPIFormatDouble, nil
    96  
    97  	case descriptor.FieldDescriptorProto_TYPE_INT32,
    98  		descriptor.FieldDescriptorProto_TYPE_UINT32,
    99  		descriptor.FieldDescriptorProto_TYPE_FIXED32,
   100  		descriptor.FieldDescriptorProto_TYPE_SFIXED32,
   101  		descriptor.FieldDescriptorProto_TYPE_SINT32:
   102  		return openAPITypeNumber, openAPIFormatInt32, nil
   103  
   104  	case descriptor.FieldDescriptorProto_TYPE_INT64,
   105  		descriptor.FieldDescriptorProto_TYPE_UINT64,
   106  		descriptor.FieldDescriptorProto_TYPE_FIXED64,
   107  		descriptor.FieldDescriptorProto_TYPE_SFIXED64,
   108  		descriptor.FieldDescriptorProto_TYPE_SINT64:
   109  		return openAPITypeNumber, openAPIFormatInt64, nil
   110  
   111  	case descriptor.FieldDescriptorProto_TYPE_STRING:
   112  		return openAPITypeString, "", nil
   113  
   114  	case descriptor.FieldDescriptorProto_TYPE_BYTES:
   115  		return openAPITypeString, openAPIFormatByte, nil
   116  
   117  	case descriptor.FieldDescriptorProto_TYPE_ENUM:
   118  		return "string", "", nil
   119  
   120  	case descriptor.FieldDescriptorProto_TYPE_BOOL:
   121  		return openAPITypeBoolean, "", nil
   122  
   123  	case descriptor.FieldDescriptorProto_TYPE_GROUP, descriptor.FieldDescriptorProto_TYPE_MESSAGE:
   124  		switch desc.GetTypeName() {
   125  		case ".google.protobuf.Timestamp":
   126  			return openAPITypeString, openAPIFormatDateTime, nil
   127  		default:
   128  			return openAPITypeObject, "", nil
   129  		}
   130  
   131  	default:
   132  		return "", "", fmt.Errorf("unrecognized field type: %s", desc.GetType().String())
   133  	}
   134  }
   135  
   136  // Convert a proto "field" (essentially a type-switch with some recursion):
   137  func (c *Converter) convertField(curPkg *ProtoPackage, desc *descriptor.FieldDescriptorProto, msg *descriptor.DescriptorProto) (*openapi3.Schema, error) {
   138  
   139  	// Prepare a new jsonschema.Type for our eventual return value:
   140  	componentSchema := &openapi3.Schema{}
   141  
   142  	// Generate a description from src comments (if available)
   143  	if src := c.sourceInfo.GetField(desc); src != nil {
   144  		componentSchema.Description = formatDescription(src)
   145  	}
   146  
   147  	// Switch the types, and pick a JSONSchema equivalent:
   148  	switch desc.GetType() {
   149  	case descriptor.FieldDescriptorProto_TYPE_DOUBLE,
   150  		descriptor.FieldDescriptorProto_TYPE_FLOAT:
   151  		componentSchema.Type = openAPITypeNumber
   152  		componentSchema.Format = openAPIFormatDouble
   153  
   154  	case descriptor.FieldDescriptorProto_TYPE_INT32,
   155  		descriptor.FieldDescriptorProto_TYPE_UINT32,
   156  		descriptor.FieldDescriptorProto_TYPE_FIXED32,
   157  		descriptor.FieldDescriptorProto_TYPE_SFIXED32,
   158  		descriptor.FieldDescriptorProto_TYPE_SINT32:
   159  		componentSchema.Type = openAPITypeNumber
   160  		componentSchema.Format = openAPIFormatInt32
   161  
   162  	case descriptor.FieldDescriptorProto_TYPE_INT64,
   163  		descriptor.FieldDescriptorProto_TYPE_UINT64,
   164  		descriptor.FieldDescriptorProto_TYPE_FIXED64,
   165  		descriptor.FieldDescriptorProto_TYPE_SFIXED64,
   166  		descriptor.FieldDescriptorProto_TYPE_SINT64:
   167  		componentSchema.Type = openAPITypeNumber
   168  		componentSchema.Format = openAPIFormatInt64
   169  
   170  	case descriptor.FieldDescriptorProto_TYPE_STRING:
   171  		componentSchema.Type = openAPITypeString
   172  
   173  	case descriptor.FieldDescriptorProto_TYPE_BYTES:
   174  		componentSchema.Type = openAPITypeString
   175  		componentSchema.Format = openAPIFormatByte
   176  
   177  	case descriptor.FieldDescriptorProto_TYPE_ENUM:
   178  		componentSchema.Type = "string"
   179  
   180  		// Go through all the enums we have, see if we can match any to this field by name:
   181  		for _, enumDescriptor := range msg.GetEnumType() {
   182  
   183  			// Each one has several values:
   184  			for _, enumValue := range enumDescriptor.Value {
   185  
   186  				// Figure out the entire name of this field:
   187  				fullFieldName := fmt.Sprintf(".%v.%v", *msg.Name, *enumDescriptor.Name)
   188  
   189  				// If we find ENUM values for this field then put them into the JSONSchema list of allowed ENUM values:
   190  				if strings.HasSuffix(desc.GetTypeName(), fullFieldName) {
   191  					componentSchema.Enum = append(componentSchema.Enum, *enumValue.Name)
   192  				}
   193  			}
   194  		}
   195  
   196  	case descriptor.FieldDescriptorProto_TYPE_BOOL:
   197  		componentSchema.Type = openAPITypeBoolean
   198  
   199  	case descriptor.FieldDescriptorProto_TYPE_GROUP, descriptor.FieldDescriptorProto_TYPE_MESSAGE:
   200  		switch desc.GetTypeName() {
   201  		case ".google.protobuf.Timestamp":
   202  			componentSchema.Type = openAPITypeString
   203  			componentSchema.Format = openAPIFormatDateTime
   204  		default:
   205  			componentSchema.Type = openAPITypeObject
   206  		}
   207  
   208  	default:
   209  		return nil, fmt.Errorf("unrecognized field type: %s", desc.GetType().String())
   210  	}
   211  
   212  	isList := false
   213  	var field *protogen.Field
   214  	if c.plug != nil && len(c.plug.Files) > 0 {
   215  		for _, file := range c.plug.Files {
   216  			for _, message := range file.Messages {
   217  				for _, f := range message.Fields {
   218  					parts := strings.Split(string(f.GoIdent.GoName), "_")
   219  					messageName := parts[0]
   220  
   221  					if messageName != *msg.Name {
   222  						continue
   223  					}
   224  					fieldName := parts[1]
   225  					if strings.ToLower(fieldName) == *desc.Name {
   226  						isList = f.Desc.IsList()
   227  						field = f
   228  					}
   229  				}
   230  			}
   231  		}
   232  
   233  	}
   234  	// Recurse array of primitive types:
   235  	if isList && componentSchema.Type != openAPITypeObject {
   236  		componentSchema.Items = &openapi3.SchemaRef{
   237  			Value: &openapi3.Schema{
   238  				Type: componentSchema.Type,
   239  			},
   240  		}
   241  		componentSchema.Type = openAPITypeArray
   242  		return componentSchema, nil
   243  	}
   244  
   245  	// Recurse nested objects / arrays of objects (if necessary):
   246  	if componentSchema.Type == openAPITypeObject {
   247  
   248  		recordType, pkgName, ok := c.lookupType(curPkg, desc.GetTypeName())
   249  		if !ok {
   250  			return nil, fmt.Errorf("no such message type named %s", desc.GetTypeName())
   251  		}
   252  
   253  		// Recurse the recordType:
   254  		recursedComponentSchema, err := c.recursiveConvertMessageType(curPkg, recordType, pkgName)
   255  		if err != nil {
   256  			return nil, err
   257  		}
   258  
   259  		// Maps, arrays, and objects are structured in different ways:
   260  		switch {
   261  		case field != nil && field.Desc.Message().FullName() == "google.protobuf.Struct":
   262  			if !isList {
   263  				componentSchema.Type = openAPITypeObject
   264  			} else {
   265  				componentSchema.Items = &openapi3.SchemaRef{
   266  					Value: &openapi3.Schema{
   267  						Type:       openAPITypeObject,
   268  						Properties: map[string]*openapi3.SchemaRef{},
   269  					},
   270  				}
   271  
   272  				componentSchema.Type = openAPITypeArray
   273  			}
   274  		// Arrays:
   275  		case isList:
   276  			componentSchema.Items = &openapi3.SchemaRef{
   277  				Value: &openapi3.Schema{
   278  					Type:       openAPITypeObject,
   279  					Properties: recursedComponentSchema.Properties,
   280  				},
   281  			}
   282  
   283  			componentSchema.Type = openAPITypeArray
   284  		// Maps:
   285  		case recordType.Options.GetMapEntry():
   286  			logger.Tracef("Found a map (%s.%s)", *msg.Name, recordType.GetName())
   287  			componentSchema.Type = openAPITypeObject
   288  			// fields of a map: key, value. we need the type of value here as key is always string
   289  			// see https://swagger.io/docs/specification/data-models/dictionaries/
   290  			typ, format, err := toTypeAndFormat(recordType.Field[1])
   291  			if err != nil {
   292  				return nil, err
   293  			}
   294  			componentSchema.AdditionalProperties = &openapi3.SchemaRef{
   295  				Value: &openapi3.Schema{
   296  					Type:   typ,
   297  					Format: format,
   298  				},
   299  			}
   300  		// Objects:
   301  		default:
   302  			componentSchema.Properties = recursedComponentSchema.Properties
   303  			// recursedComponentSchemaRef := fmt.Sprintf("#/components/schemas/%s", recursedComponentSchema.Title)
   304  			// componentSchema.Properties = openapi3.NewSchemaRef(recursedComponentSchemaRef, nil)
   305  		}
   306  	}
   307  
   308  	return componentSchema, nil
   309  }
   310  
   311  // Converts a proto "MESSAGE" into an OpenAPI schema:
   312  func (c *Converter) convertMessageType(curPkg *ProtoPackage, msg *descriptor.DescriptorProto) (*openapi3.Schema, error) {
   313  
   314  	// main schema for the message
   315  	rootType, err := c.recursiveConvertMessageType(curPkg, msg, "")
   316  	if err != nil {
   317  		return nil, err
   318  	}
   319  
   320  	return rootType, nil
   321  }
   322  
   323  type nameAndCounter struct {
   324  	name    string
   325  	counter int
   326  }
   327  
   328  func (c *Converter) recursiveConvertMessageType(curPkg *ProtoPackage, msg *descriptor.DescriptorProto, pkgName string) (*openapi3.Schema, error) {
   329  	if msg.Name != nil && wellKnownTypes[*msg.Name] && pkgName == ".google.protobuf" {
   330  		componentSchema := &openapi3.Schema{
   331  			Title: msg.GetName(),
   332  		}
   333  		switch *msg.Name {
   334  		case "DoubleValue", "FloatValue":
   335  			componentSchema.Type = openAPITypeNumber
   336  			componentSchema.Format = openAPIFormatDouble
   337  		case "Int32Value", "UInt32Value":
   338  			componentSchema.Type = openAPITypeNumber
   339  			componentSchema.Format = openAPIFormatInt32
   340  		case "Int64Value", "UInt64Value":
   341  			componentSchema.Type = openAPITypeNumber
   342  			componentSchema.Format = openAPIFormatInt64
   343  		case "BoolValue":
   344  			componentSchema.Type = openAPITypeBoolean
   345  		case "StringValue":
   346  			componentSchema.Type = openAPITypeString
   347  		case "BytesValue":
   348  			componentSchema.Type = openAPITypeString
   349  			componentSchema.Format = openAPIFormatByte
   350  		case "Value":
   351  			componentSchema.Type = openAPITypeObject
   352  		}
   353  		return componentSchema, nil
   354  	}
   355  
   356  	// Prepare a new jsonschema:
   357  	componentSchema := &openapi3.Schema{
   358  		Properties: make(map[string]*openapi3.SchemaRef),
   359  		Title:      msg.GetName(),
   360  		Type:       openAPITypeObject,
   361  	}
   362  
   363  	// Generate a description from src comments (if available)
   364  	if src := c.sourceInfo.GetMessage(msg); src != nil {
   365  		componentSchema.Description = formatDescription(src)
   366  	}
   367  
   368  	logger.Tracef("Converting message (%s)", proto.MarshalTextString(msg))
   369  
   370  	// Recurse each field:
   371  	for _, fieldDesc := range msg.GetField() {
   372  		recursedComponentSchema, err := c.convertField(curPkg, fieldDesc, msg)
   373  		if err != nil {
   374  			logger.Errorf("Failed to convert field (%s.%s): %v", msg.GetName(), fieldDesc.GetName(), err)
   375  			return nil, err
   376  		}
   377  		logger.Tracef("Converted field: %s => %s", fieldDesc.GetName(), recursedComponentSchema.Type)
   378  
   379  		// Add it to the properties (by its JSON name):
   380  		componentSchema.Properties[fieldDesc.GetJsonName()] = openapi3.NewSchemaRef("", recursedComponentSchema)
   381  	}
   382  
   383  	return componentSchema, nil
   384  }
   385  
   386  func formatDescription(sl *descriptor.SourceCodeInfo_Location) string {
   387  	var lines []string
   388  	for _, str := range sl.GetLeadingDetachedComments() {
   389  		if s := strings.TrimSpace(str); s != "" {
   390  			lines = append(lines, s)
   391  		}
   392  	}
   393  	if s := strings.TrimSpace(sl.GetLeadingComments()); s != "" {
   394  		lines = append(lines, s)
   395  	}
   396  	if s := strings.TrimSpace(sl.GetTrailingComments()); s != "" {
   397  		lines = append(lines, s)
   398  	}
   399  	return strings.Join(lines, "\n\n")
   400  }