github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/courier/swagger/gen/definition_scanner.go (about) 1 package gen 2 3 import ( 4 "fmt" 5 "go/types" 6 "reflect" 7 "regexp" 8 "strings" 9 10 "github.com/go-courier/ptr" 11 12 "github.com/artisanhe/tools/codegen" 13 "github.com/go-courier/oas" 14 "github.com/sirupsen/logrus" 15 "golang.org/x/tools/go/loader" 16 17 "github.com/artisanhe/tools/codegen/loaderx" 18 ) 19 20 const ( 21 XNamed = `x-go-named` 22 XField = `x-go-name` 23 XTagJSON = `x-go-json` 24 XTagName = `x-tag-name` 25 XTagXML = `x-tag-xml` 26 XTagStyle = `x-tag-style` 27 XTagFmt = `x-tag-fmt` 28 XTagValidate = `x-go-validate` 29 XPointer = "x-pointer" 30 XEnumValues = `x-enum-values` 31 XEnumLabels = `x-enum-labels` 32 XEnumVals = `x-enum-vals` 33 ) 34 35 func NewDefinitionScanner(program *loader.Program) *DefinitionScanner { 36 return &DefinitionScanner{ 37 EnumScanner: NewEnumScanner(program), 38 program: program, 39 } 40 } 41 42 type DefinitionScanner struct { 43 EnumScanner *EnumScanner 44 program *loader.Program 45 definitions map[*types.TypeName]*oas.Schema 46 } 47 48 func (scanner *DefinitionScanner) BindSchemas(openapi *oas.OpenAPI) { 49 for typeName, schema := range scanner.definitions { 50 schema.AddExtension(XNamed, fmt.Sprintf("%s.%s", typeName.Pkg().Path(), typeName.Name())) 51 defKey := toDefID(typeName.Type().String()) 52 if _, exists := openapi.Components.Schemas[defKey]; exists { 53 logrus.Panicf("`%s` already used by %s", defKey, typeName.String()) 54 } else { 55 openapi.AddSchema(toDefID(typeName.Type().String()), schema) 56 } 57 } 58 return 59 } 60 61 func Package(program *loader.Program, path string) *loader.PackageInfo { 62 for _, info := range program.AllPackages { 63 pkgPath := info.Pkg.Path() 64 if path == pkgPath || path == codegen.DeVendor(pkgPath) || codegen.DeVendor(path) == codegen.DeVendor(pkgPath) { 65 return info 66 } 67 } 68 for _, info := range program.Created { 69 pkgPath := info.Pkg.Path() 70 if path == pkgPath || path == codegen.DeVendor(pkgPath) || codegen.DeVendor(path) == codegen.DeVendor(pkgPath) { 71 return info 72 } 73 } 74 return nil 75 } 76 77 func (scanner *DefinitionScanner) getSchemaByTypeString(typeString string) *oas.Schema { 78 pkgImportPath, _ := loaderx.GetPkgImportPathAndExpose(typeString) 79 pkgImportPath = loaderx.ResolvePkgImport(pkgImportPath) 80 pkgInfo := Package(scanner.program, pkgImportPath) 81 if pkgInfo == nil { 82 panic(fmt.Errorf("missing pkg %s of %s", pkgImportPath, typeString)) 83 } 84 85 for _, def := range pkgInfo.Defs { 86 if typeName, ok := def.(*types.TypeName); ok { 87 if typeName.Type().String() == typeString { 88 return scanner.getSchemaByType(typeName.Type()) 89 } 90 } 91 } 92 return nil 93 } 94 95 func (scanner *DefinitionScanner) Def(typeName *types.TypeName) *oas.Schema { 96 if s, ok := scanner.definitions[typeName]; ok { 97 return s 98 } 99 100 if typeName.IsAlias() { 101 typeName = typeName.Type().(*types.Named).Obj() 102 } 103 104 doc := docOfTypeName(typeName.Type().(*types.Named).Obj(), scanner.program) 105 106 if doc, fmtName := ParseStrfmt(doc); fmtName != "" { 107 return scanner.addDef(typeName, oas.NewSchema(oas.TypeString, fmtName).WithDesc(doc)) 108 } 109 110 // todo 111 if typeName.Name() == "Time" { 112 return scanner.addDef(typeName, oas.DateTime().WithDesc(doc)) 113 } 114 115 doc, hasEnum := ParseEnum(doc) 116 if hasEnum { 117 enum := scanner.EnumScanner.Enum(typeName) 118 if len(enum) == 2 { 119 values := enum.Values() 120 if values[0] == "FALSE" && values[1] == "TRUE" { 121 return scanner.addDef(typeName, oas.Boolean()) 122 } 123 } 124 if enum == nil { 125 panic(fmt.Errorf("missing enum option but annotated by swagger:enum")) 126 } 127 return scanner.addDef(typeName, enum.ToSchema().WithDesc(doc)) 128 } 129 130 return scanner.addDef(typeName, scanner.getSchemaByType(typeName.Type().Underlying()).WithDesc(doc)) 131 } 132 133 func (scanner *DefinitionScanner) addDef(typeName *types.TypeName, schema *oas.Schema) *oas.Schema { 134 if scanner.definitions == nil { 135 scanner.definitions = map[*types.TypeName]*oas.Schema{} 136 } 137 scanner.definitions[typeName] = schema 138 return schema 139 } 140 141 func (scanner *DefinitionScanner) getSchemaByType(tpe types.Type) *oas.Schema { 142 switch tpe.(type) { 143 case *types.Interface: 144 return &oas.Schema{ 145 SchemaObject: oas.SchemaObject{ 146 Type: oas.TypeObject, 147 AdditionalProperties: &oas.SchemaOrBool{ 148 Allows: true, 149 }, 150 }, 151 } 152 case *types.Named: 153 named := tpe.(*types.Named) 154 if named.String() == "mime/multipart.FileHeader" { 155 return oas.Binary() 156 } 157 scanner.Def(named.Obj()) 158 return oas.RefSchema(fmt.Sprintf("#/components/schemas/%s", toDefID(named.String()))) 159 case *types.Basic: 160 typeName, format := getSchemaTypeFromBasicType(tpe.(*types.Basic).Name()) 161 if typeName != "" { 162 return oas.NewSchema(typeName, format) 163 } 164 case *types.Pointer: 165 count := 0 166 pointer := tpe.(*types.Pointer) 167 elem := pointer.Elem() 168 for pointer != nil { 169 elem = pointer.Elem() 170 pointer, _ = pointer.Elem().(*types.Pointer) 171 count++ 172 } 173 s := scanner.getSchemaByType(elem) 174 markPointer(s, count) 175 return s 176 case *types.Map: 177 keySchema := scanner.getSchemaByType(tpe.(*types.Map).Key()) 178 if keySchema != nil && len(keySchema.Type) > 0 && keySchema.Type != "string" { 179 panic(fmt.Errorf("only support map[string]interface{}")) 180 } 181 return oas.MapOf(scanner.getSchemaByType(tpe.(*types.Map).Elem())) 182 case *types.Slice: 183 return oas.ItemsOf(scanner.getSchemaByType(tpe.(*types.Slice).Elem())) 184 case *types.Array: 185 typArray := tpe.(*types.Array) 186 length := typArray.Len() 187 return oas.ItemsOf(scanner.getSchemaByType(typArray.Elem())).WithValidation(&oas.SchemaValidation{ 188 MaxItems: ptr.Uint64(uint64(length)), 189 MinItems: ptr.Uint64(uint64(length)), 190 }) 191 case *types.Struct: 192 var structType = tpe.(*types.Struct) 193 194 err := StructFieldUniqueChecker{}.Check(structType, false) 195 if err != nil { 196 panic(fmt.Errorf("type %s: %s", tpe, err)) 197 } 198 199 var structSchema = oas.ObjectOf(nil) 200 var schemas []*oas.Schema 201 202 for i := 0; i < structType.NumFields(); i++ { 203 field := structType.Field(i) 204 205 if !field.Exported() { 206 continue 207 } 208 209 structFieldType := field.Type() 210 structFieldTags := reflect.StructTag(structType.Tag(i)) 211 jsonTagValue := structFieldTags.Get("json") 212 if jsonTagValue == "" { 213 jsonTagValue = structFieldTags.Get("name") 214 } 215 216 name, flags := getTagNameAndFlags(jsonTagValue) 217 if name == "-" { 218 continue 219 } 220 221 if name == "" && field.Anonymous() { 222 s := scanner.getSchemaByType(structFieldType) 223 if s != nil { 224 schemas = append(schemas, s) 225 } 226 continue 227 } 228 229 if name == "" { 230 name = field.Name() 231 } 232 233 defaultValue, hasDefault := structFieldTags.Lookup("default") 234 validate, hasValidate := structFieldTags.Lookup("validate") 235 236 required := true 237 if hasOmitempty, ok := flags["omitempty"]; ok { 238 required = !hasOmitempty 239 } else { 240 // todo don't use non-default as required 241 required = !hasDefault 242 } 243 244 propSchema := scanner.getSchemaByType(structFieldType) 245 246 if flags != nil && flags["string"] { 247 propSchema.Type = oas.TypeString 248 } 249 250 if defaultValue != "" { 251 propSchema.Default = defaultValue 252 } 253 254 if hasValidate { 255 BindValidateFromValidateTagString(propSchema, validate) 256 } 257 258 propSchema = propSchema.WithDesc(docOfTypeName(field, scanner.program)) 259 propSchema.AddExtension(XField, field.Name()) 260 261 if nameValue, hasName := structFieldTags.Lookup("name"); hasName { 262 propSchema.AddExtension(XTagName, nameValue) 263 } 264 265 if styleValue, hasStyle := structFieldTags.Lookup("style"); hasStyle { 266 propSchema.AddExtension(XTagStyle, styleValue) 267 } 268 269 if fmtValue, hasFmt := structFieldTags.Lookup("fmt"); hasFmt { 270 propSchema.AddExtension(XTagFmt, fmtValue) 271 } 272 273 if xmlValue, hasXML := structFieldTags.Lookup("xml"); hasXML { 274 propSchema.AddExtension(XTagXML, xmlValue) 275 } 276 277 if jsonTagValue != "" { 278 propSchema.AddExtension(XTagJSON, jsonTagValue) 279 } 280 281 if propSchema.Refer.RefString() != "" { 282 composedSchema := oas.AllOf( 283 propSchema, 284 &oas.Schema{ 285 SchemaObject: propSchema.SchemaObject, 286 }, 287 ) 288 composedSchema.SpecExtensions = propSchema.SpecExtensions 289 structSchema.SetProperty(name, composedSchema, required) 290 } else { 291 structSchema.SetProperty(name, propSchema, required) 292 } 293 294 } 295 296 if len(schemas) > 0 { 297 return oas.AllOf(append(schemas, structSchema)...) 298 } 299 return structSchema 300 } 301 return nil 302 } 303 304 type StructFieldUniqueChecker map[string]*types.Var 305 306 func (checker StructFieldUniqueChecker) Check(structType *types.Struct, anonymous bool) error { 307 for i := 0; i < structType.NumFields(); i++ { 308 field := structType.Field(i) 309 if !field.Exported() { 310 continue 311 } 312 if field.Anonymous() { 313 if named, ok := field.Type().(*types.Named); ok { 314 if st, ok := named.Underlying().(*types.Struct); ok { 315 if err := checker.Check(st, true); err != nil { 316 return err 317 } 318 } 319 } 320 continue 321 } 322 if anonymous { 323 if _, ok := checker[field.Name()]; ok { 324 return fmt.Errorf("%s.%s already defined in other anonymous field", structType.String(), field.Name()) 325 } 326 checker[field.Name()] = field 327 } 328 } 329 return nil 330 } 331 332 type VendorExtensible interface { 333 AddExtension(key string, value interface{}) 334 } 335 336 func markPointer(vendorExtensible VendorExtensible, count int) { 337 vendorExtensible.AddExtension(XPointer, count) 338 } 339 340 func toDefID(s string) string { 341 _, expose := loaderx.GetPkgImportPathAndExpose(s) 342 return expose 343 } 344 345 var ( 346 rxEnum = regexp.MustCompile(`swagger:enum`) 347 rxStrFmt = regexp.MustCompile(`swagger:strfmt\s+(\S+)([\s\S]+)?$`) 348 ) 349 350 func ParseEnum(doc string) (string, bool) { 351 if rxEnum.MatchString(doc) { 352 return strings.TrimSpace(strings.Replace(doc, "swagger:enum", "", -1)), true 353 } 354 return doc, false 355 } 356 357 func ParseStrfmt(doc string) (string, string) { 358 matched := rxStrFmt.FindAllStringSubmatch(doc, -1) 359 if len(matched) > 0 { 360 return strings.TrimSpace(matched[0][2]), matched[0][1] 361 } 362 return doc, "" 363 } 364 365 func getSchemaTypeFromBasicType(basicTypeName string) (tpe oas.Type, format string) { 366 switch basicTypeName { 367 case "bool": 368 return "boolean", "" 369 case "byte": 370 return "integer", "uint8" 371 case "error": 372 return "string", "" 373 case "float32": 374 return "number", "float" 375 case "float64": 376 return "number", "double" 377 case "int": 378 return "integer", "int64" 379 case "int8": 380 return "integer", "int8" 381 case "int16": 382 return "integer", "int16" 383 case "int32": 384 return "integer", "int32" 385 case "int64": 386 return "integer", "int64" 387 case "rune": 388 return "integer", "int32" 389 case "string": 390 return "string", "" 391 case "uint": 392 return "integer", "uint64" 393 case "uint16": 394 return "integer", "uint16" 395 case "uint32": 396 return "integer", "uint32" 397 case "uint64": 398 return "integer", "uint64" 399 case "uint8": 400 return "integer", "uint8" 401 case "uintptr": 402 return "integer", "uint64" 403 default: 404 panic(fmt.Errorf("unsupported type %q", basicTypeName)) 405 } 406 }