github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/courier/swagger/gen/definition_scanner.go (about)

     1  package gen
     2  
     3  import (
     4  	"fmt"
     5  	"go/types"
     6  	"reflect"
     7  	"regexp"
     8  	"strings"
     9  
    10  	"github.com/morlay/oas"
    11  	"github.com/sirupsen/logrus"
    12  	"golang.org/x/tools/go/loader"
    13  
    14  	"github.com/johnnyeven/libtools/codegen/loaderx"
    15  )
    16  
    17  const (
    18  	XNamed       = `x-go-named`
    19  	XField       = `x-go-name`
    20  	XTagJSON     = `x-go-json`
    21  	XTagName     = `x-tag-name`
    22  	XTagXML      = `x-tag-xml`
    23  	XTagStyle    = `x-tag-style`
    24  	XTagFmt      = `x-tag-fmt`
    25  	XTagValidate = `x-go-validate`
    26  	XPointer     = "x-pointer"
    27  	XEnumValues  = `x-enum-values`
    28  	XEnumLabels  = `x-enum-labels`
    29  	XEnumVals    = `x-enum-vals`
    30  )
    31  
    32  func NewDefinitionScanner(program *loader.Program) *DefinitionScanner {
    33  	return &DefinitionScanner{
    34  		EnumScanner: NewEnumScanner(program),
    35  		program:     program,
    36  	}
    37  }
    38  
    39  type DefinitionScanner struct {
    40  	EnumScanner *EnumScanner
    41  	program     *loader.Program
    42  	definitions map[*types.TypeName]*oas.Schema
    43  }
    44  
    45  func (scanner *DefinitionScanner) BindSchemas(openapi *oas.OpenAPI) {
    46  	for typeName, schema := range scanner.definitions {
    47  		schema.AddExtension(XNamed, fmt.Sprintf("%s.%s", typeName.Pkg().Path(), typeName.Name()))
    48  		defKey := toDefID(typeName.Type().String())
    49  		if _, exists := openapi.Components.Schemas[defKey]; exists {
    50  			logrus.Panicf("`%s` already used by %s", defKey, typeName.String())
    51  		} else {
    52  			openapi.AddSchema(toDefID(typeName.Type().String()), schema)
    53  		}
    54  	}
    55  	return
    56  }
    57  
    58  func (scanner *DefinitionScanner) getSchemaByTypeString(typeString string) *oas.Schema {
    59  	pkgImportPath, _ := loaderx.GetPkgImportPathAndExpose(typeString)
    60  	pkgInfo := scanner.program.Package(loaderx.ResolvePkgImport(pkgImportPath))
    61  	for _, def := range pkgInfo.Defs {
    62  		if typeName, ok := def.(*types.TypeName); ok {
    63  			if typeName.Type().String() == typeString {
    64  				return scanner.getSchemaByType(typeName.Type())
    65  			}
    66  		}
    67  	}
    68  	return nil
    69  }
    70  
    71  func (scanner *DefinitionScanner) Def(typeName *types.TypeName) *oas.Schema {
    72  	if s, ok := scanner.definitions[typeName]; ok {
    73  		return s
    74  	}
    75  
    76  	if typeName.IsAlias() {
    77  		typeName = typeName.Type().(*types.Named).Obj()
    78  	}
    79  
    80  	doc := docOfTypeName(typeName.Type().(*types.Named).Obj(), scanner.program)
    81  
    82  	if doc, fmtName := ParseStrfmt(doc); fmtName != "" {
    83  		return scanner.addDef(typeName, oas.NewSchema(oas.TypeString, fmtName).WithDesc(doc))
    84  	}
    85  
    86  	// todo
    87  	if typeName.Name() == "Time" {
    88  		return scanner.addDef(typeName, oas.DateTime().WithDesc(doc))
    89  	}
    90  
    91  	doc, hasEnum := ParseEnum(doc)
    92  	if hasEnum {
    93  		enum := scanner.EnumScanner.Enum(typeName)
    94  		if len(enum) == 2 {
    95  			values := enum.Values()
    96  			if values[0] == "FALSE" && values[1] == "TRUE" {
    97  				return scanner.addDef(typeName, oas.Boolean())
    98  			}
    99  		}
   100  		if enum == nil {
   101  			panic(fmt.Errorf("missing enum option but annotated by swagger:enum"))
   102  		}
   103  		return scanner.addDef(typeName, enum.ToSchema().WithDesc(doc))
   104  	}
   105  
   106  	return scanner.addDef(typeName, scanner.getSchemaByType(typeName.Type().Underlying()).WithDesc(doc))
   107  }
   108  
   109  func (scanner *DefinitionScanner) addDef(typeName *types.TypeName, schema *oas.Schema) *oas.Schema {
   110  	if scanner.definitions == nil {
   111  		scanner.definitions = map[*types.TypeName]*oas.Schema{}
   112  	}
   113  	scanner.definitions[typeName] = schema
   114  	return schema
   115  }
   116  
   117  func (scanner *DefinitionScanner) getSchemaByType(tpe types.Type) *oas.Schema {
   118  	switch tpe.(type) {
   119  	case *types.Interface:
   120  		return &oas.Schema{
   121  			SchemaObject: oas.SchemaObject{
   122  				Type: oas.TypeObject,
   123  				AdditionalProperties: &oas.SchemaOrBool{
   124  					Allows: true,
   125  				},
   126  			},
   127  		}
   128  	case *types.Named:
   129  		named := tpe.(*types.Named)
   130  		if named.String() == "mime/multipart.FileHeader" {
   131  			return oas.Binary()
   132  		}
   133  		scanner.Def(named.Obj())
   134  		return oas.RefSchema(fmt.Sprintf("#/components/schemas/%s", toDefID(named.String())))
   135  	case *types.Basic:
   136  		typeName, format := getSchemaTypeFromBasicType(tpe.(*types.Basic).Name())
   137  		if typeName != "" {
   138  			return oas.NewSchema(typeName, format)
   139  		}
   140  	case *types.Pointer:
   141  		count := 0
   142  		pointer := tpe.(*types.Pointer)
   143  		elem := pointer.Elem()
   144  		for pointer != nil {
   145  			elem = pointer.Elem()
   146  			pointer, _ = pointer.Elem().(*types.Pointer)
   147  			count++
   148  		}
   149  		s := scanner.getSchemaByType(elem)
   150  		markPointer(s, count)
   151  		return s
   152  	case *types.Map:
   153  		keySchema := scanner.getSchemaByType(tpe.(*types.Map).Key())
   154  		if keySchema != nil && len(keySchema.Type) > 0 && keySchema.Type != "string" {
   155  			panic(fmt.Errorf("only support map[string]interface{}"))
   156  		}
   157  		return oas.MapOf(scanner.getSchemaByType(tpe.(*types.Map).Elem()))
   158  	case *types.Slice:
   159  		return oas.ItemsOf(scanner.getSchemaByType(tpe.(*types.Slice).Elem()))
   160  	case *types.Array:
   161  		typArray := tpe.(*types.Array)
   162  		length := typArray.Len()
   163  		return oas.ItemsOf(scanner.getSchemaByType(typArray.Elem())).WithValidation(&oas.SchemaValidation{
   164  			MaxItems: &length,
   165  			MinItems: &length,
   166  		})
   167  	case *types.Struct:
   168  		var structType = tpe.(*types.Struct)
   169  
   170  		err := StructFieldUniqueChecker{}.Check(structType, false)
   171  		if err != nil {
   172  			panic(fmt.Errorf("type %s: %s", tpe, err))
   173  		}
   174  
   175  		var structSchema = oas.ObjectOf(nil)
   176  		var schemas []*oas.Schema
   177  
   178  		for i := 0; i < structType.NumFields(); i++ {
   179  			field := structType.Field(i)
   180  
   181  			if !field.Exported() {
   182  				continue
   183  			}
   184  
   185  			structFieldType := field.Type()
   186  			structFieldTags := reflect.StructTag(structType.Tag(i))
   187  			jsonTagValue := structFieldTags.Get("json")
   188  
   189  			name, flags := getTagNameAndFlags(jsonTagValue)
   190  			if name == "-" {
   191  				continue
   192  			}
   193  
   194  			if name == "" && field.Anonymous() {
   195  				s := scanner.getSchemaByType(structFieldType)
   196  				if s != nil {
   197  					schemas = append(schemas, s)
   198  				}
   199  				continue
   200  			}
   201  
   202  			if name == "" {
   203  				name = field.Name()
   204  			}
   205  
   206  			defaultValue, hasDefault := structFieldTags.Lookup("default")
   207  			validate, hasValidate := structFieldTags.Lookup("validate")
   208  
   209  			required := true
   210  			if hasOmitempty, ok := flags["omitempty"]; ok {
   211  				required = !hasOmitempty
   212  			} else {
   213  				// todo don't use non-default as required
   214  				required = !hasDefault
   215  			}
   216  
   217  			propSchema := scanner.getSchemaByType(structFieldType)
   218  
   219  			if flags != nil && flags["string"] {
   220  				propSchema.Type = oas.TypeString
   221  			}
   222  
   223  			if defaultValue != "" {
   224  				propSchema.Default = defaultValue
   225  			}
   226  
   227  			if hasValidate {
   228  				BindValidateFromValidateTagString(propSchema, validate)
   229  			}
   230  
   231  			propSchema = propSchema.WithDesc(docOfTypeName(field, scanner.program))
   232  			propSchema.AddExtension(XField, field.Name())
   233  
   234  			if nameValue, hasName := structFieldTags.Lookup("name"); hasName {
   235  				propSchema.AddExtension(XTagName, nameValue)
   236  			}
   237  
   238  			if styleValue, hasStyle := structFieldTags.Lookup("style"); hasStyle {
   239  				propSchema.AddExtension(XTagStyle, styleValue)
   240  			}
   241  
   242  			if fmtValue, hasFmt := structFieldTags.Lookup("fmt"); hasFmt {
   243  				propSchema.AddExtension(XTagFmt, fmtValue)
   244  			}
   245  
   246  			if xmlValue, hasXML := structFieldTags.Lookup("xml"); hasXML {
   247  				propSchema.AddExtension(XTagXML, xmlValue)
   248  			}
   249  
   250  			if jsonTagValue != "" {
   251  				propSchema.AddExtension(XTagJSON, jsonTagValue)
   252  			}
   253  
   254  			if propSchema.Ref != "" {
   255  				composedSchema := oas.AllOf(
   256  					propSchema,
   257  					&oas.Schema{
   258  						SchemaObject: propSchema.SchemaObject,
   259  					},
   260  				)
   261  				composedSchema.SpecExtensions = propSchema.SpecExtensions
   262  				structSchema.SetProperty(name, composedSchema, required)
   263  			} else {
   264  				structSchema.SetProperty(name, propSchema, required)
   265  			}
   266  
   267  		}
   268  
   269  		if len(schemas) > 0 {
   270  			return oas.AllOf(append(schemas, structSchema)...)
   271  		}
   272  		return structSchema
   273  	}
   274  	return nil
   275  }
   276  
   277  type StructFieldUniqueChecker map[string]*types.Var
   278  
   279  func (checker StructFieldUniqueChecker) Check(structType *types.Struct, anonymous bool) error {
   280  	for i := 0; i < structType.NumFields(); i++ {
   281  		field := structType.Field(i)
   282  		if !field.Exported() {
   283  			continue
   284  		}
   285  		if field.Anonymous() {
   286  			if named, ok := field.Type().(*types.Named); ok {
   287  				if st, ok := named.Underlying().(*types.Struct); ok {
   288  					if err := checker.Check(st, true); err != nil {
   289  						return err
   290  					}
   291  				}
   292  			}
   293  			continue
   294  		}
   295  		if anonymous {
   296  			if _, ok := checker[field.Name()]; ok {
   297  				return fmt.Errorf("%s.%s already defined in other anonymous field", structType.String(), field.Name())
   298  			}
   299  			checker[field.Name()] = field
   300  		}
   301  	}
   302  	return nil
   303  }
   304  
   305  type VendorExtensible interface {
   306  	AddExtension(key string, value interface{})
   307  }
   308  
   309  func markPointer(vendorExtensible VendorExtensible, count int) {
   310  	vendorExtensible.AddExtension(XPointer, count)
   311  }
   312  
   313  func toDefID(s string) string {
   314  	_, expose := loaderx.GetPkgImportPathAndExpose(s)
   315  	return expose
   316  }
   317  
   318  var (
   319  	rxEnum   = regexp.MustCompile(`swagger:enum`)
   320  	rxStrFmt = regexp.MustCompile(`swagger:strfmt\s+(\S+)([\s\S]+)?$`)
   321  )
   322  
   323  func ParseEnum(doc string) (string, bool) {
   324  	if rxEnum.MatchString(doc) {
   325  		return strings.TrimSpace(strings.Replace(doc, "swagger:enum", "", -1)), true
   326  	}
   327  	return doc, false
   328  }
   329  
   330  func ParseStrfmt(doc string) (string, string) {
   331  	matched := rxStrFmt.FindAllStringSubmatch(doc, -1)
   332  	if len(matched) > 0 {
   333  		return strings.TrimSpace(matched[0][2]), matched[0][1]
   334  	}
   335  	return doc, ""
   336  }
   337  
   338  func getSchemaTypeFromBasicType(basicTypeName string) (tpe oas.Type, format string) {
   339  	switch basicTypeName {
   340  	case "bool":
   341  		return "boolean", ""
   342  	case "byte":
   343  		return "integer", "uint8"
   344  	case "error":
   345  		return "string", ""
   346  	case "float32":
   347  		return "number", "float"
   348  	case "float64":
   349  		return "number", "double"
   350  	case "int":
   351  		return "integer", "int64"
   352  	case "int8":
   353  		return "integer", "int8"
   354  	case "int16":
   355  		return "integer", "int16"
   356  	case "int32":
   357  		return "integer", "int32"
   358  	case "int64":
   359  		return "integer", "int64"
   360  	case "rune":
   361  		return "integer", "int32"
   362  	case "string":
   363  		return "string", ""
   364  	case "uint":
   365  		return "integer", "uint64"
   366  	case "uint16":
   367  		return "integer", "uint16"
   368  	case "uint32":
   369  		return "integer", "uint32"
   370  	case "uint64":
   371  		return "integer", "uint64"
   372  	case "uint8":
   373  		return "integer", "uint8"
   374  	case "uintptr":
   375  		return "integer", "uint64"
   376  	default:
   377  		panic(fmt.Errorf("unsupported type %q", basicTypeName))
   378  	}
   379  }