github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/courier/swagger/gen/definition_scanner.go (about)

     1  package gen
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"reflect"
     7  	"regexp"
     8  	"strings"
     9  
    10  	"github.com/go-courier/ptr"
    11  
    12  	"github.com/artisanhe/tools/codegen"
    13  	"github.com/go-courier/oas"
    14  	"github.com/sirupsen/logrus"
    15  	"golang.org/x/tools/go/loader"
    16  
    17  	"github.com/artisanhe/tools/codegen/loaderx"
    18  )
    19  
    20  const (
    21  	XNamed       = `x-go-named`
    22  	XField       = `x-go-name`
    23  	XTagJSON     = `x-go-json`
    24  	XTagName     = `x-tag-name`
    25  	XTagXML      = `x-tag-xml`
    26  	XTagStyle    = `x-tag-style`
    27  	XTagFmt      = `x-tag-fmt`
    28  	XTagValidate = `x-go-validate`
    29  	XPointer     = "x-pointer"
    30  	XEnumValues  = `x-enum-values`
    31  	XEnumLabels  = `x-enum-labels`
    32  	XEnumVals    = `x-enum-vals`
    33  )
    34  
    35  func NewDefinitionScanner(program *loader.Program) *DefinitionScanner {
    36  	return &DefinitionScanner{
    37  		EnumScanner: NewEnumScanner(program),
    38  		program:     program,
    39  	}
    40  }
    41  
    42  type DefinitionScanner struct {
    43  	EnumScanner *EnumScanner
    44  	program     *loader.Program
    45  	definitions map[*types.TypeName]*oas.Schema
    46  }
    47  
    48  func (scanner *DefinitionScanner) BindSchemas(openapi *oas.OpenAPI) {
    49  	for typeName, schema := range scanner.definitions {
    50  		schema.AddExtension(XNamed, fmt.Sprintf("%s.%s", typeName.Pkg().Path(), typeName.Name()))
    51  		defKey := toDefID(typeName.Type().String())
    52  		if _, exists := openapi.Components.Schemas[defKey]; exists {
    53  			logrus.Panicf("`%s` already used by %s", defKey, typeName.String())
    54  		} else {
    55  			openapi.AddSchema(toDefID(typeName.Type().String()), schema)
    56  		}
    57  	}
    58  	return
    59  }
    60  
    61  func Package(program *loader.Program, path string) *loader.PackageInfo {
    62  	for _, info := range program.AllPackages {
    63  		pkgPath := info.Pkg.Path()
    64  		if path == pkgPath || path == codegen.DeVendor(pkgPath) || codegen.DeVendor(path) == codegen.DeVendor(pkgPath) {
    65  			return info
    66  		}
    67  	}
    68  	for _, info := range program.Created {
    69  		pkgPath := info.Pkg.Path()
    70  		if path == pkgPath || path == codegen.DeVendor(pkgPath) || codegen.DeVendor(path) == codegen.DeVendor(pkgPath) {
    71  			return info
    72  		}
    73  	}
    74  	return nil
    75  }
    76  
    77  func (scanner *DefinitionScanner) getSchemaByTypeString(typeString string) *oas.Schema {
    78  	pkgImportPath, _ := loaderx.GetPkgImportPathAndExpose(typeString)
    79  	pkgImportPath = loaderx.ResolvePkgImport(pkgImportPath)
    80  	pkgInfo := Package(scanner.program, pkgImportPath)
    81  	if pkgInfo == nil {
    82  		panic(fmt.Errorf("missing pkg %s of %s", pkgImportPath, typeString))
    83  	}
    84  
    85  	for _, def := range pkgInfo.Defs {
    86  		if typeName, ok := def.(*types.TypeName); ok {
    87  			if typeName.Type().String() == typeString {
    88  				return scanner.getSchemaByType(typeName.Type())
    89  			}
    90  		}
    91  	}
    92  	return nil
    93  }
    94  
    95  func (scanner *DefinitionScanner) Def(typeName *types.TypeName) *oas.Schema {
    96  	if s, ok := scanner.definitions[typeName]; ok {
    97  		return s
    98  	}
    99  
   100  	if typeName.IsAlias() {
   101  		typeName = typeName.Type().(*types.Named).Obj()
   102  	}
   103  
   104  	doc := docOfTypeName(typeName.Type().(*types.Named).Obj(), scanner.program)
   105  
   106  	if doc, fmtName := ParseStrfmt(doc); fmtName != "" {
   107  		return scanner.addDef(typeName, oas.NewSchema(oas.TypeString, fmtName).WithDesc(doc))
   108  	}
   109  
   110  	// todo
   111  	if typeName.Name() == "Time" {
   112  		return scanner.addDef(typeName, oas.DateTime().WithDesc(doc))
   113  	}
   114  
   115  	doc, hasEnum := ParseEnum(doc)
   116  	if hasEnum {
   117  		enum := scanner.EnumScanner.Enum(typeName)
   118  		if len(enum) == 2 {
   119  			values := enum.Values()
   120  			if values[0] == "FALSE" && values[1] == "TRUE" {
   121  				return scanner.addDef(typeName, oas.Boolean())
   122  			}
   123  		}
   124  		if enum == nil {
   125  			panic(fmt.Errorf("missing enum option but annotated by swagger:enum"))
   126  		}
   127  		return scanner.addDef(typeName, enum.ToSchema().WithDesc(doc))
   128  	}
   129  
   130  	return scanner.addDef(typeName, scanner.getSchemaByType(typeName.Type().Underlying()).WithDesc(doc))
   131  }
   132  
   133  func (scanner *DefinitionScanner) addDef(typeName *types.TypeName, schema *oas.Schema) *oas.Schema {
   134  	if scanner.definitions == nil {
   135  		scanner.definitions = map[*types.TypeName]*oas.Schema{}
   136  	}
   137  	scanner.definitions[typeName] = schema
   138  	return schema
   139  }
   140  
   141  func (scanner *DefinitionScanner) getSchemaByType(tpe types.Type) *oas.Schema {
   142  	switch tpe.(type) {
   143  	case *types.Interface:
   144  		return &oas.Schema{
   145  			SchemaObject: oas.SchemaObject{
   146  				Type: oas.TypeObject,
   147  				AdditionalProperties: &oas.SchemaOrBool{
   148  					Allows: true,
   149  				},
   150  			},
   151  		}
   152  	case *types.Named:
   153  		named := tpe.(*types.Named)
   154  		if named.String() == "mime/multipart.FileHeader" {
   155  			return oas.Binary()
   156  		}
   157  		scanner.Def(named.Obj())
   158  		return oas.RefSchema(fmt.Sprintf("#/components/schemas/%s", toDefID(named.String())))
   159  	case *types.Basic:
   160  		typeName, format := getSchemaTypeFromBasicType(tpe.(*types.Basic).Name())
   161  		if typeName != "" {
   162  			return oas.NewSchema(typeName, format)
   163  		}
   164  	case *types.Pointer:
   165  		count := 0
   166  		pointer := tpe.(*types.Pointer)
   167  		elem := pointer.Elem()
   168  		for pointer != nil {
   169  			elem = pointer.Elem()
   170  			pointer, _ = pointer.Elem().(*types.Pointer)
   171  			count++
   172  		}
   173  		s := scanner.getSchemaByType(elem)
   174  		markPointer(s, count)
   175  		return s
   176  	case *types.Map:
   177  		keySchema := scanner.getSchemaByType(tpe.(*types.Map).Key())
   178  		if keySchema != nil && len(keySchema.Type) > 0 && keySchema.Type != "string" {
   179  			panic(fmt.Errorf("only support map[string]interface{}"))
   180  		}
   181  		return oas.MapOf(scanner.getSchemaByType(tpe.(*types.Map).Elem()))
   182  	case *types.Slice:
   183  		return oas.ItemsOf(scanner.getSchemaByType(tpe.(*types.Slice).Elem()))
   184  	case *types.Array:
   185  		typArray := tpe.(*types.Array)
   186  		length := typArray.Len()
   187  		return oas.ItemsOf(scanner.getSchemaByType(typArray.Elem())).WithValidation(&oas.SchemaValidation{
   188  			MaxItems: ptr.Uint64(uint64(length)),
   189  			MinItems: ptr.Uint64(uint64(length)),
   190  		})
   191  	case *types.Struct:
   192  		var structType = tpe.(*types.Struct)
   193  
   194  		err := StructFieldUniqueChecker{}.Check(structType, false)
   195  		if err != nil {
   196  			panic(fmt.Errorf("type %s: %s", tpe, err))
   197  		}
   198  
   199  		var structSchema = oas.ObjectOf(nil)
   200  		var schemas []*oas.Schema
   201  
   202  		for i := 0; i < structType.NumFields(); i++ {
   203  			field := structType.Field(i)
   204  
   205  			if !field.Exported() {
   206  				continue
   207  			}
   208  
   209  			structFieldType := field.Type()
   210  			structFieldTags := reflect.StructTag(structType.Tag(i))
   211  			jsonTagValue := structFieldTags.Get("json")
   212  			if jsonTagValue == "" {
   213  				jsonTagValue = structFieldTags.Get("name")
   214  			}
   215  
   216  			name, flags := getTagNameAndFlags(jsonTagValue)
   217  			if name == "-" {
   218  				continue
   219  			}
   220  
   221  			if name == "" && field.Anonymous() {
   222  				s := scanner.getSchemaByType(structFieldType)
   223  				if s != nil {
   224  					schemas = append(schemas, s)
   225  				}
   226  				continue
   227  			}
   228  
   229  			if name == "" {
   230  				name = field.Name()
   231  			}
   232  
   233  			defaultValue, hasDefault := structFieldTags.Lookup("default")
   234  			validate, hasValidate := structFieldTags.Lookup("validate")
   235  
   236  			required := true
   237  			if hasOmitempty, ok := flags["omitempty"]; ok {
   238  				required = !hasOmitempty
   239  			} else {
   240  				// todo don't use non-default as required
   241  				required = !hasDefault
   242  			}
   243  
   244  			propSchema := scanner.getSchemaByType(structFieldType)
   245  
   246  			if flags != nil && flags["string"] {
   247  				propSchema.Type = oas.TypeString
   248  			}
   249  
   250  			if defaultValue != "" {
   251  				propSchema.Default = defaultValue
   252  			}
   253  
   254  			if hasValidate {
   255  				BindValidateFromValidateTagString(propSchema, validate)
   256  			}
   257  
   258  			propSchema = propSchema.WithDesc(docOfTypeName(field, scanner.program))
   259  			propSchema.AddExtension(XField, field.Name())
   260  
   261  			if nameValue, hasName := structFieldTags.Lookup("name"); hasName {
   262  				propSchema.AddExtension(XTagName, nameValue)
   263  			}
   264  
   265  			if styleValue, hasStyle := structFieldTags.Lookup("style"); hasStyle {
   266  				propSchema.AddExtension(XTagStyle, styleValue)
   267  			}
   268  
   269  			if fmtValue, hasFmt := structFieldTags.Lookup("fmt"); hasFmt {
   270  				propSchema.AddExtension(XTagFmt, fmtValue)
   271  			}
   272  
   273  			if xmlValue, hasXML := structFieldTags.Lookup("xml"); hasXML {
   274  				propSchema.AddExtension(XTagXML, xmlValue)
   275  			}
   276  
   277  			if jsonTagValue != "" {
   278  				propSchema.AddExtension(XTagJSON, jsonTagValue)
   279  			}
   280  
   281  			if propSchema.Refer.RefString() != "" {
   282  				composedSchema := oas.AllOf(
   283  					propSchema,
   284  					&oas.Schema{
   285  						SchemaObject: propSchema.SchemaObject,
   286  					},
   287  				)
   288  				composedSchema.SpecExtensions = propSchema.SpecExtensions
   289  				structSchema.SetProperty(name, composedSchema, required)
   290  			} else {
   291  				structSchema.SetProperty(name, propSchema, required)
   292  			}
   293  
   294  		}
   295  
   296  		if len(schemas) > 0 {
   297  			return oas.AllOf(append(schemas, structSchema)...)
   298  		}
   299  		return structSchema
   300  	}
   301  	return nil
   302  }
   303  
   304  type StructFieldUniqueChecker map[string]*types.Var
   305  
   306  func (checker StructFieldUniqueChecker) Check(structType *types.Struct, anonymous bool) error {
   307  	for i := 0; i < structType.NumFields(); i++ {
   308  		field := structType.Field(i)
   309  		if !field.Exported() {
   310  			continue
   311  		}
   312  		if field.Anonymous() {
   313  			if named, ok := field.Type().(*types.Named); ok {
   314  				if st, ok := named.Underlying().(*types.Struct); ok {
   315  					if err := checker.Check(st, true); err != nil {
   316  						return err
   317  					}
   318  				}
   319  			}
   320  			continue
   321  		}
   322  		if anonymous {
   323  			if _, ok := checker[field.Name()]; ok {
   324  				return fmt.Errorf("%s.%s already defined in other anonymous field", structType.String(), field.Name())
   325  			}
   326  			checker[field.Name()] = field
   327  		}
   328  	}
   329  	return nil
   330  }
   331  
   332  type VendorExtensible interface {
   333  	AddExtension(key string, value interface{})
   334  }
   335  
   336  func markPointer(vendorExtensible VendorExtensible, count int) {
   337  	vendorExtensible.AddExtension(XPointer, count)
   338  }
   339  
   340  func toDefID(s string) string {
   341  	_, expose := loaderx.GetPkgImportPathAndExpose(s)
   342  	return expose
   343  }
   344  
   345  var (
   346  	rxEnum   = regexp.MustCompile(`swagger:enum`)
   347  	rxStrFmt = regexp.MustCompile(`swagger:strfmt\s+(\S+)([\s\S]+)?$`)
   348  )
   349  
   350  func ParseEnum(doc string) (string, bool) {
   351  	if rxEnum.MatchString(doc) {
   352  		return strings.TrimSpace(strings.Replace(doc, "swagger:enum", "", -1)), true
   353  	}
   354  	return doc, false
   355  }
   356  
   357  func ParseStrfmt(doc string) (string, string) {
   358  	matched := rxStrFmt.FindAllStringSubmatch(doc, -1)
   359  	if len(matched) > 0 {
   360  		return strings.TrimSpace(matched[0][2]), matched[0][1]
   361  	}
   362  	return doc, ""
   363  }
   364  
   365  func getSchemaTypeFromBasicType(basicTypeName string) (tpe oas.Type, format string) {
   366  	switch basicTypeName {
   367  	case "bool":
   368  		return "boolean", ""
   369  	case "byte":
   370  		return "integer", "uint8"
   371  	case "error":
   372  		return "string", ""
   373  	case "float32":
   374  		return "number", "float"
   375  	case "float64":
   376  		return "number", "double"
   377  	case "int":
   378  		return "integer", "int64"
   379  	case "int8":
   380  		return "integer", "int8"
   381  	case "int16":
   382  		return "integer", "int16"
   383  	case "int32":
   384  		return "integer", "int32"
   385  	case "int64":
   386  		return "integer", "int64"
   387  	case "rune":
   388  		return "integer", "int32"
   389  	case "string":
   390  		return "string", ""
   391  	case "uint":
   392  		return "integer", "uint64"
   393  	case "uint16":
   394  		return "integer", "uint16"
   395  	case "uint32":
   396  		return "integer", "uint32"
   397  	case "uint64":
   398  		return "integer", "uint64"
   399  	case "uint8":
   400  		return "integer", "uint8"
   401  	case "uintptr":
   402  		return "integer", "uint64"
   403  	default:
   404  		panic(fmt.Errorf("unsupported type %q", basicTypeName))
   405  	}
   406  }