github.com/profzone/eden-framework@v1.0.10/internal/generator/scanner/definition_scanner.go (about)

     1  package scanner
     2  
     3  import (
     4  	"fmt"
     5  	"github.com/profzone/eden-framework/pkg/packagex"
     6  	"github.com/profzone/eden-framework/pkg/reflectx"
     7  	str "github.com/profzone/eden-framework/pkg/strings"
     8  	"go/types"
     9  	"reflect"
    10  	"regexp"
    11  	"sort"
    12  	"strings"
    13  
    14  	"github.com/go-courier/oas"
    15  	"github.com/sirupsen/logrus"
    16  )
    17  
    18  func NewDefinitionScanner(pkg *packagex.Package) *DefinitionScanner {
    19  	return &DefinitionScanner{
    20  		enumScanner:       NewEnumScanner(pkg),
    21  		pkg:               pkg,
    22  		ioWriterInterface: packagex.NewPackage(pkg.Pkg("io")).TypeName("Writer").Type().Underlying().(*types.Interface),
    23  	}
    24  }
    25  
// DefinitionScanner turns named Go types into OpenAPI schemas. Schemas are
// cached per *types.TypeName in definitions and re-keyed by unique component
// name in schemas.
type DefinitionScanner struct {
	// NOTE(review): never assigned or read in this file — confirm it is still needed.
	enumInterfaceType *types.Interface
	// pkg is the package being scanned; its PkgPath decides which types count as internal.
	pkg               *packagex.Package
	// enumScanner supplies the options of types annotated with openapi:enum.
	enumScanner       *EnumScanner
	// definitions holds one schema per scanned type name (written by setDef).
	definitions       map[*types.TypeName]*oas.Schema
	// schemas is definitions keyed by unique component name, rebuilt by reformatSchemas.
	schemas           map[string]*oas.Schema
	// ioWriterInterface is io.Writer; types whose pointer implements it are scanned as binary.
	ioWriterInterface *types.Interface
}
    34  
    35  func addExtension(s *oas.Schema, key string, v interface{}) {
    36  	if s == nil {
    37  		return
    38  	}
    39  	if len(s.AllOf) > 0 {
    40  		s.AllOf[len(s.AllOf)-1].AddExtension(key, v)
    41  	} else {
    42  		s.AddExtension(key, v)
    43  	}
    44  }
    45  
    46  func setMetaFromDoc(s *oas.Schema, doc string) {
    47  	if s == nil {
    48  		return
    49  	}
    50  
    51  	lines := strings.Split(doc, "\n")
    52  
    53  	for i := range lines {
    54  		if strings.Index(lines[i], "@deprecated") != -1 {
    55  			s.Deprecated = true
    56  		}
    57  	}
    58  
    59  	description := dropMarkedLines(lines)
    60  
    61  	if len(s.AllOf) > 0 {
    62  		s.AllOf[len(s.AllOf)-1].Description = description
    63  	} else {
    64  		s.Description = description
    65  	}
    66  }
    67  
    68  func (scanner *DefinitionScanner) BindSchemas(openapi *oas.OpenAPI) {
    69  	openapi.Components.Schemas = scanner.schemas
    70  }
    71  
    72  func (scanner *DefinitionScanner) Def(typeName *types.TypeName) *oas.Schema {
    73  	if s, ok := scanner.definitions[typeName]; ok {
    74  		return s
    75  	}
    76  
    77  	logrus.Debugf("scanning Type `%s.%s`", typeName.Pkg().Path(), typeName.Name())
    78  
    79  	if typeName.IsAlias() {
    80  		typeName = typeName.Type().(*types.Named).Obj()
    81  	}
    82  
    83  	doc := scanner.pkg.CommentsOf(scanner.pkg.IdentOf(typeName.Type().(*types.Named).Obj()))
    84  
    85  	// register empty before scan
    86  	// to avoid cycle
    87  	scanner.setDef(typeName, &oas.Schema{})
    88  
    89  	if doc, fmtName := parseStrfmt(doc); fmtName != "" {
    90  		s := oas.NewSchema(oas.TypeString, fmtName)
    91  		setMetaFromDoc(s, doc)
    92  		return scanner.setDef(typeName, s)
    93  	}
    94  
    95  	if doc, typ := parseType(doc); typ != "" {
    96  		s := oas.NewSchema(oas.Type(typ), "")
    97  		setMetaFromDoc(s, doc)
    98  		return scanner.setDef(typeName, s)
    99  	}
   100  
   101  	if reflectx.FromTType(types.NewPointer(typeName.Type())).Implements(reflectx.FromTType(scanner.ioWriterInterface)) {
   102  		return scanner.setDef(typeName, oas.Binary())
   103  	}
   104  
   105  	if typeName.Pkg() != nil {
   106  		if typeName.Pkg().Path() == "time" && typeName.Name() == "Time" {
   107  			return scanner.setDef(typeName, oas.DateTime())
   108  		}
   109  	}
   110  
   111  	doc, hasEnum := ParseEnum(doc)
   112  	if hasEnum {
   113  		enumOptions := scanner.enumScanner.Enum(typeName)
   114  		if enumOptions == nil {
   115  			panic(fmt.Errorf("missing enum option but annotated by openapi:enum"))
   116  		}
   117  		s := oas.String()
   118  		for _, e := range enumOptions {
   119  			s.Enum = append(s.Enum, e.Value)
   120  		}
   121  		s.AddExtension(XEnumOptions, enumOptions)
   122  		return scanner.setDef(typeName, s)
   123  	}
   124  
   125  	s := scanner.GetSchemaByType(typeName.Type().Underlying())
   126  
   127  	setMetaFromDoc(s, doc)
   128  
   129  	return scanner.setDef(typeName, s)
   130  }
   131  
   132  func (scanner *DefinitionScanner) isInternal(typeName *types.TypeName) bool {
   133  	return strings.HasPrefix(typeName.Pkg().Path(), scanner.pkg.PkgPath)
   134  }
   135  
   136  func (scanner *DefinitionScanner) typeUniqueName(typeName *types.TypeName, isExist func(name string) bool) (string, bool) {
   137  	typePkgPath := typeName.Pkg().Path()
   138  	name := typeName.Name()
   139  
   140  	if scanner.isInternal(typeName) {
   141  		pathParts := strings.Split(typePkgPath, "/")
   142  		count := 1
   143  		for isExist(name) {
   144  			name = str.ToUpperCamelCase(pathParts[len(pathParts)-count]) + name
   145  			count++
   146  		}
   147  		return name, true
   148  	}
   149  
   150  	return str.ToUpperCamelCase(typePkgPath) + name, false
   151  }
   152  
   153  func (scanner *DefinitionScanner) reformatSchemas() {
   154  	typeNameList := make([]*types.TypeName, 0)
   155  
   156  	for typeName := range scanner.definitions {
   157  		v := typeName
   158  		typeNameList = append(typeNameList, v)
   159  	}
   160  
   161  	sort.Slice(typeNameList, func(i, j int) bool {
   162  		return scanner.isInternal(typeNameList[i]) && fullTypeName(typeNameList[i]) < fullTypeName(typeNameList[j])
   163  	})
   164  
   165  	schemas := map[string]*oas.Schema{}
   166  
   167  	for _, typeName := range typeNameList {
   168  		name, isInternal := scanner.typeUniqueName(typeName, func(name string) bool {
   169  			_, exists := schemas[name]
   170  			return exists
   171  		})
   172  
   173  		s := scanner.definitions[typeName]
   174  		addExtension(s, XID, name)
   175  		if !isInternal {
   176  			addExtension(s, XGoVendorType, fullTypeName(typeName))
   177  		}
   178  		schemas[name] = s
   179  	}
   180  
   181  	scanner.schemas = schemas
   182  }
   183  
   184  func (scanner *DefinitionScanner) setDef(typeName *types.TypeName, schema *oas.Schema) *oas.Schema {
   185  	if scanner.definitions == nil {
   186  		scanner.definitions = map[*types.TypeName]*oas.Schema{}
   187  	}
   188  	scanner.definitions[typeName] = schema
   189  	scanner.reformatSchemas()
   190  	return schema
   191  }
   192  
   193  func NewSchemaRefer(s *oas.Schema) *SchemaRefer {
   194  	return &SchemaRefer{
   195  		Schema: s,
   196  	}
   197  }
   198  
   199  type SchemaRefer struct {
   200  	*oas.Schema
   201  }
   202  
   203  func (r SchemaRefer) RefString() string {
   204  	s := r.Schema
   205  	if r.Schema.AllOf != nil {
   206  		s = r.AllOf[len(r.Schema.AllOf)-1]
   207  	}
   208  	return oas.NewComponentRefer("schemas", s.Extensions[XID].(string)).RefString()
   209  }
   210  
   211  func (scanner *DefinitionScanner) GetSchemaByType(typ types.Type) *oas.Schema {
   212  	switch t := typ.(type) {
   213  	case *types.Named:
   214  		if t.String() == "mime/multipart.FileHeader" {
   215  			return oas.Binary()
   216  		}
   217  		return oas.RefSchemaByRefer(NewSchemaRefer(scanner.Def(t.Obj())))
   218  	case *types.Interface:
   219  		return &oas.Schema{}
   220  	case *types.Basic:
   221  		typeName, format := getSchemaTypeFromBasicType(reflectx.FromTType(t).Kind().String())
   222  		if typeName != "" {
   223  			return oas.NewSchema(typeName, format)
   224  		}
   225  	case *types.Pointer:
   226  		count := 1
   227  		elem := t.Elem()
   228  
   229  		for {
   230  			if p, ok := elem.(*types.Pointer); ok {
   231  				elem = p.Elem()
   232  				count++
   233  			} else {
   234  				break
   235  			}
   236  		}
   237  
   238  		s := scanner.GetSchemaByType(elem)
   239  		markPointer(s, count)
   240  		return s
   241  	case *types.Map:
   242  		keySchema := scanner.GetSchemaByType(t.Key())
   243  		if keySchema != nil && len(keySchema.Type) > 0 && keySchema.Type != "string" {
   244  			panic(fmt.Errorf("only support map[string]interface{}"))
   245  		}
   246  		return oas.KeyValueOf(keySchema, scanner.GetSchemaByType(t.Elem()))
   247  	case *types.Slice:
   248  		return oas.ItemsOf(scanner.GetSchemaByType(t.Elem()))
   249  	case *types.Array:
   250  		length := uint64(t.Len())
   251  		s := oas.ItemsOf(scanner.GetSchemaByType(t.Elem()))
   252  		s.MaxItems = &length
   253  		s.MinItems = &length
   254  		return s
   255  	case *types.Struct:
   256  		err := (StructFieldUniqueChecker{}).Check(t, false)
   257  		if err != nil {
   258  			panic(fmt.Errorf("type %s: %s", typ, err))
   259  		}
   260  
   261  		structSchema := oas.ObjectOf(nil)
   262  		schemas := make([]*oas.Schema, 0)
   263  
   264  		for i := 0; i < t.NumFields(); i++ {
   265  			field := t.Field(i)
   266  
   267  			if !field.Exported() {
   268  				continue
   269  			}
   270  
   271  			structFieldType := field.Type()
   272  
   273  			tags := reflect.StructTag(t.Tag(i))
   274  
   275  			tagValueForName := tags.Get("json")
   276  			if tagValueForName == "" {
   277  				tagValueForName = tags.Get("name")
   278  			}
   279  
   280  			name, flags := tagValueAndFlagsByTagString(tagValueForName)
   281  			if name == "-" {
   282  				continue
   283  			}
   284  
   285  			if name == "" && field.Anonymous() {
   286  				if field.Type().String() == "bytes.Buffer" {
   287  					structSchema = oas.Binary()
   288  					break
   289  				}
   290  				s := scanner.GetSchemaByType(structFieldType)
   291  				if s != nil {
   292  					schemas = append(schemas, s)
   293  				}
   294  				continue
   295  			}
   296  
   297  			if name == "" {
   298  				name = field.Name()
   299  			}
   300  
   301  			required := true
   302  			if hasOmitempty, ok := flags["omitempty"]; ok {
   303  				required = !hasOmitempty
   304  			}
   305  
   306  			structSchema.SetProperty(
   307  				name,
   308  				scanner.propSchemaByField(field.Name(), structFieldType, tags, name, flags, scanner.pkg.CommentsOf(scanner.pkg.IdentOf(field))),
   309  				required,
   310  			)
   311  		}
   312  
   313  		if len(schemas) > 0 {
   314  			return oas.AllOf(append(schemas, structSchema)...)
   315  		}
   316  
   317  		return structSchema
   318  	}
   319  	return nil
   320  }
   321  
   322  func (scanner *DefinitionScanner) propSchemaByField(
   323  	fieldName string,
   324  	fieldType types.Type,
   325  	tags reflect.StructTag,
   326  	name string,
   327  	flags map[string]bool,
   328  	desc string,
   329  ) *oas.Schema {
   330  	propSchema := scanner.GetSchemaByType(fieldType)
   331  
   332  	refSchema := (*oas.Schema)(nil)
   333  
   334  	if propSchema.Refer != nil {
   335  		refSchema = propSchema
   336  		propSchema = &oas.Schema{}
   337  		propSchema.Extensions = refSchema.Extensions
   338  	}
   339  
   340  	defaultValue := tags.Get("default")
   341  	//validate, hasValidate := tags.Lookup("validate")
   342  
   343  	if flags != nil && flags["string"] {
   344  		propSchema.Type = oas.TypeString
   345  	}
   346  
   347  	if defaultValue != "" {
   348  		propSchema.Default = defaultValue
   349  	}
   350  
   351  	//if hasValidate {
   352  	//	if err := BindSchemaValidationByValidateBytes(propSchema, fieldType, []byte(validate)); err != nil {
   353  	//		panic(err)
   354  	//	}
   355  	//}
   356  
   357  	setMetaFromDoc(propSchema, desc)
   358  	propSchema.AddExtension(XGoFieldName, fieldName)
   359  
   360  	tagKeys := map[string]string{
   361  		"name":     XTagName,
   362  		"mime":     XTagMime,
   363  		"json":     XTagJSON,
   364  		"xml":      XTagXML,
   365  		"validate": XTagValidate,
   366  	}
   367  
   368  	for k, extKey := range tagKeys {
   369  		if v, ok := tags.Lookup(k); ok {
   370  			propSchema.AddExtension(extKey, v)
   371  		}
   372  	}
   373  
   374  	if refSchema != nil {
   375  		return oas.AllOf(
   376  			refSchema,
   377  			propSchema,
   378  		)
   379  	}
   380  
   381  	return propSchema
   382  }
   383  
   384  type StructFieldUniqueChecker map[string]*types.Var
   385  
   386  func (checker StructFieldUniqueChecker) Check(structType *types.Struct, anonymous bool) error {
   387  	for i := 0; i < structType.NumFields(); i++ {
   388  		field := structType.Field(i)
   389  		if !field.Exported() {
   390  			continue
   391  		}
   392  		if field.Anonymous() {
   393  			if named, ok := field.Type().(*types.Named); ok {
   394  				if st, ok := named.Underlying().(*types.Struct); ok {
   395  					if err := checker.Check(st, true); err != nil {
   396  						return err
   397  					}
   398  				}
   399  			}
   400  			continue
   401  		}
   402  		if anonymous {
   403  			if _, ok := checker[field.Name()]; ok {
   404  				return fmt.Errorf("%s.%s already defined in other anonymous field", structType.String(), field.Name())
   405  			}
   406  			checker[field.Name()] = field
   407  		}
   408  	}
   409  	return nil
   410  }
   411  
   412  type VendorExtensible interface {
   413  	AddExtension(key string, value interface{})
   414  }
   415  
   416  func markPointer(vendorExtensible VendorExtensible, count int) {
   417  	vendorExtensible.AddExtension(XGoStarLevel, count)
   418  }
   419  
   420  var (
   421  	reStrFmt = regexp.MustCompile(`open-?api:strfmt\s+(\S+)([\s\S]+)?$`)
   422  	reType   = regexp.MustCompile(`open-?api:type\s+(\S+)([\s\S]+)?$`)
   423  )
   424  
   425  func parseStrfmt(doc string) (string, string) {
   426  	matched := reStrFmt.FindAllStringSubmatch(doc, -1)
   427  	if len(matched) > 0 {
   428  		return strings.TrimSpace(matched[0][2]), matched[0][1]
   429  	}
   430  	return doc, ""
   431  }
   432  
   433  func parseType(doc string) (string, string) {
   434  	matched := reType.FindAllStringSubmatch(doc, -1)
   435  	if len(matched) > 0 {
   436  		return strings.TrimSpace(matched[0][2]), matched[0][1]
   437  	}
   438  	return doc, ""
   439  }
   440  
   441  var basicTypeToSchemaType = map[string][2]string{
   442  	"invalid": {"null", ""},
   443  
   444  	"bool":    {"boolean", ""},
   445  	"error":   {"string", "string"},
   446  	"float32": {"number", "float"},
   447  	"float64": {"number", "double"},
   448  
   449  	"int":   {"integer", "int32"},
   450  	"int8":  {"integer", "int8"},
   451  	"int16": {"integer", "int16"},
   452  	"int32": {"integer", "int32"},
   453  	"int64": {"integer", "int64"},
   454  
   455  	"rune": {"integer", "int32"},
   456  
   457  	"uint":   {"integer", "uint32"},
   458  	"uint8":  {"integer", "uint8"},
   459  	"uint16": {"integer", "uint16"},
   460  	"uint32": {"integer", "uint32"},
   461  	"uint64": {"integer", "uint64"},
   462  
   463  	"byte": {"integer", "uint8"},
   464  
   465  	"string": {"string", ""},
   466  }
   467  
   468  func getSchemaTypeFromBasicType(basicTypeName string) (typ oas.Type, format string) {
   469  	if schemaTypeAndFormat, ok := basicTypeToSchemaType[basicTypeName]; ok {
   470  		return oas.Type(schemaTypeAndFormat[0]), schemaTypeAndFormat[1]
   471  	}
   472  	panic(fmt.Errorf("unsupported type %q", basicTypeName))
   473  }