github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/courier/swagger/gen/definition_scanner.go (about) 1 package gen 2 3 import ( 4 "fmt" 5 "go/types" 6 "reflect" 7 "regexp" 8 "strings" 9 10 "github.com/morlay/oas" 11 "github.com/sirupsen/logrus" 12 "golang.org/x/tools/go/loader" 13 14 "github.com/johnnyeven/libtools/codegen/loaderx" 15 ) 16 17 const ( 18 XNamed = `x-go-named` 19 XField = `x-go-name` 20 XTagJSON = `x-go-json` 21 XTagName = `x-tag-name` 22 XTagXML = `x-tag-xml` 23 XTagStyle = `x-tag-style` 24 XTagFmt = `x-tag-fmt` 25 XTagValidate = `x-go-validate` 26 XPointer = "x-pointer" 27 XEnumValues = `x-enum-values` 28 XEnumLabels = `x-enum-labels` 29 XEnumVals = `x-enum-vals` 30 ) 31 32 func NewDefinitionScanner(program *loader.Program) *DefinitionScanner { 33 return &DefinitionScanner{ 34 EnumScanner: NewEnumScanner(program), 35 program: program, 36 } 37 } 38 39 type DefinitionScanner struct { 40 EnumScanner *EnumScanner 41 program *loader.Program 42 definitions map[*types.TypeName]*oas.Schema 43 } 44 45 func (scanner *DefinitionScanner) BindSchemas(openapi *oas.OpenAPI) { 46 for typeName, schema := range scanner.definitions { 47 schema.AddExtension(XNamed, fmt.Sprintf("%s.%s", typeName.Pkg().Path(), typeName.Name())) 48 defKey := toDefID(typeName.Type().String()) 49 if _, exists := openapi.Components.Schemas[defKey]; exists { 50 logrus.Panicf("`%s` already used by %s", defKey, typeName.String()) 51 } else { 52 openapi.AddSchema(toDefID(typeName.Type().String()), schema) 53 } 54 } 55 return 56 } 57 58 func (scanner *DefinitionScanner) getSchemaByTypeString(typeString string) *oas.Schema { 59 pkgImportPath, _ := loaderx.GetPkgImportPathAndExpose(typeString) 60 pkgInfo := scanner.program.Package(loaderx.ResolvePkgImport(pkgImportPath)) 61 for _, def := range pkgInfo.Defs { 62 if typeName, ok := def.(*types.TypeName); ok { 63 if typeName.Type().String() == typeString { 64 return scanner.getSchemaByType(typeName.Type()) 65 } 66 } 67 } 68 return nil 69 } 70 71 func (scanner *DefinitionScanner) Def(typeName *types.TypeName) *oas.Schema { 72 if s, ok := scanner.definitions[typeName]; ok { 73 return s 74 } 75 76 if typeName.IsAlias() { 77 typeName = typeName.Type().(*types.Named).Obj() 78 } 79 80 doc := docOfTypeName(typeName.Type().(*types.Named).Obj(), scanner.program) 81 82 if doc, fmtName := ParseStrfmt(doc); fmtName != "" { 83 return scanner.addDef(typeName, oas.NewSchema(oas.TypeString, fmtName).WithDesc(doc)) 84 } 85 86 // todo 87 if typeName.Name() == "Time" { 88 return scanner.addDef(typeName, oas.DateTime().WithDesc(doc)) 89 } 90 91 doc, hasEnum := ParseEnum(doc) 92 if hasEnum { 93 enum := scanner.EnumScanner.Enum(typeName) 94 if len(enum) == 2 { 95 values := enum.Values() 96 if values[0] == "FALSE" && values[1] == "TRUE" { 97 return scanner.addDef(typeName, oas.Boolean()) 98 } 99 } 100 if enum == nil { 101 panic(fmt.Errorf("missing enum option but annotated by swagger:enum")) 102 } 103 return scanner.addDef(typeName, enum.ToSchema().WithDesc(doc)) 104 } 105 106 return scanner.addDef(typeName, scanner.getSchemaByType(typeName.Type().Underlying()).WithDesc(doc)) 107 } 108 109 func (scanner *DefinitionScanner) addDef(typeName *types.TypeName, schema *oas.Schema) *oas.Schema { 110 if scanner.definitions == nil { 111 scanner.definitions = map[*types.TypeName]*oas.Schema{} 112 } 113 scanner.definitions[typeName] = schema 114 return schema 115 } 116 117 func (scanner *DefinitionScanner) getSchemaByType(tpe types.Type) *oas.Schema { 118 switch tpe.(type) { 119 case *types.Interface: 120 return &oas.Schema{ 121 SchemaObject: oas.SchemaObject{ 122 Type: oas.TypeObject, 123 AdditionalProperties: &oas.SchemaOrBool{ 124 Allows: true, 125 }, 126 }, 127 } 128 case *types.Named: 129 named := tpe.(*types.Named) 130 if named.String() == "mime/multipart.FileHeader" { 131 return oas.Binary() 132 } 133 scanner.Def(named.Obj()) 134 return oas.RefSchema(fmt.Sprintf("#/components/schemas/%s", toDefID(named.String()))) 135 case *types.Basic: 136 typeName, format := getSchemaTypeFromBasicType(tpe.(*types.Basic).Name()) 137 if typeName != "" { 138 return oas.NewSchema(typeName, format) 139 } 140 case *types.Pointer: 141 count := 0 142 pointer := tpe.(*types.Pointer) 143 elem := pointer.Elem() 144 for pointer != nil { 145 elem = pointer.Elem() 146 pointer, _ = pointer.Elem().(*types.Pointer) 147 count++ 148 } 149 s := scanner.getSchemaByType(elem) 150 markPointer(s, count) 151 return s 152 case *types.Map: 153 keySchema := scanner.getSchemaByType(tpe.(*types.Map).Key()) 154 if keySchema != nil && len(keySchema.Type) > 0 && keySchema.Type != "string" { 155 panic(fmt.Errorf("only support map[string]interface{}")) 156 } 157 return oas.MapOf(scanner.getSchemaByType(tpe.(*types.Map).Elem())) 158 case *types.Slice: 159 return oas.ItemsOf(scanner.getSchemaByType(tpe.(*types.Slice).Elem())) 160 case *types.Array: 161 typArray := tpe.(*types.Array) 162 length := typArray.Len() 163 return oas.ItemsOf(scanner.getSchemaByType(typArray.Elem())).WithValidation(&oas.SchemaValidation{ 164 MaxItems: &length, 165 MinItems: &length, 166 }) 167 case *types.Struct: 168 var structType = tpe.(*types.Struct) 169 170 err := StructFieldUniqueChecker{}.Check(structType, false) 171 if err != nil { 172 panic(fmt.Errorf("type %s: %s", tpe, err)) 173 } 174 175 var structSchema = oas.ObjectOf(nil) 176 var schemas []*oas.Schema 177 178 for i := 0; i < structType.NumFields(); i++ { 179 field := structType.Field(i) 180 181 if !field.Exported() { 182 continue 183 } 184 185 structFieldType := field.Type() 186 structFieldTags := reflect.StructTag(structType.Tag(i)) 187 jsonTagValue := structFieldTags.Get("json") 188 189 name, flags := getTagNameAndFlags(jsonTagValue) 190 if name == "-" { 191 continue 192 } 193 194 if name == "" && field.Anonymous() { 195 s := scanner.getSchemaByType(structFieldType) 196 if s != nil { 197 schemas = append(schemas, s) 198 } 199 continue 200 } 201 202 if name == "" { 203 name = field.Name() 204 } 205 206 defaultValue, hasDefault := structFieldTags.Lookup("default") 207 validate, hasValidate := structFieldTags.Lookup("validate") 208 209 required := true 210 if hasOmitempty, ok := flags["omitempty"]; ok { 211 required = !hasOmitempty 212 } else { 213 // todo don't use non-default as required 214 required = !hasDefault 215 } 216 217 propSchema := scanner.getSchemaByType(structFieldType) 218 219 if flags != nil && flags["string"] { 220 propSchema.Type = oas.TypeString 221 } 222 223 if defaultValue != "" { 224 propSchema.Default = defaultValue 225 } 226 227 if hasValidate { 228 BindValidateFromValidateTagString(propSchema, validate) 229 } 230 231 propSchema = propSchema.WithDesc(docOfTypeName(field, scanner.program)) 232 propSchema.AddExtension(XField, field.Name()) 233 234 if nameValue, hasName := structFieldTags.Lookup("name"); hasName { 235 propSchema.AddExtension(XTagName, nameValue) 236 } 237 238 if styleValue, hasStyle := structFieldTags.Lookup("style"); hasStyle { 239 propSchema.AddExtension(XTagStyle, styleValue) 240 } 241 242 if fmtValue, hasFmt := structFieldTags.Lookup("fmt"); hasFmt { 243 propSchema.AddExtension(XTagFmt, fmtValue) 244 } 245 246 if xmlValue, hasXML := structFieldTags.Lookup("xml"); hasXML { 247 propSchema.AddExtension(XTagXML, xmlValue) 248 } 249 250 if jsonTagValue != "" { 251 propSchema.AddExtension(XTagJSON, jsonTagValue) 252 } 253 254 if propSchema.Ref != "" { 255 composedSchema := oas.AllOf( 256 propSchema, 257 &oas.Schema{ 258 SchemaObject: propSchema.SchemaObject, 259 }, 260 ) 261 composedSchema.SpecExtensions = propSchema.SpecExtensions 262 structSchema.SetProperty(name, composedSchema, required) 263 } else { 264 structSchema.SetProperty(name, propSchema, required) 265 } 266 267 } 268 269 if len(schemas) > 0 { 270 return oas.AllOf(append(schemas, structSchema)...) 271 } 272 return structSchema 273 } 274 return nil 275 } 276 277 type StructFieldUniqueChecker map[string]*types.Var 278 279 func (checker StructFieldUniqueChecker) Check(structType *types.Struct, anonymous bool) error { 280 for i := 0; i < structType.NumFields(); i++ { 281 field := structType.Field(i) 282 if !field.Exported() { 283 continue 284 } 285 if field.Anonymous() { 286 if named, ok := field.Type().(*types.Named); ok { 287 if st, ok := named.Underlying().(*types.Struct); ok { 288 if err := checker.Check(st, true); err != nil { 289 return err 290 } 291 } 292 } 293 continue 294 } 295 if anonymous { 296 if _, ok := checker[field.Name()]; ok { 297 return fmt.Errorf("%s.%s already defined in other anonymous field", structType.String(), field.Name()) 298 } 299 checker[field.Name()] = field 300 } 301 } 302 return nil 303 } 304 305 type VendorExtensible interface { 306 AddExtension(key string, value interface{}) 307 } 308 309 func markPointer(vendorExtensible VendorExtensible, count int) { 310 vendorExtensible.AddExtension(XPointer, count) 311 } 312 313 func toDefID(s string) string { 314 _, expose := loaderx.GetPkgImportPathAndExpose(s) 315 return expose 316 } 317 318 var ( 319 rxEnum = regexp.MustCompile(`swagger:enum`) 320 rxStrFmt = regexp.MustCompile(`swagger:strfmt\s+(\S+)([\s\S]+)?$`) 321 ) 322 323 func ParseEnum(doc string) (string, bool) { 324 if rxEnum.MatchString(doc) { 325 return strings.TrimSpace(strings.Replace(doc, "swagger:enum", "", -1)), true 326 } 327 return doc, false 328 } 329 330 func ParseStrfmt(doc string) (string, string) { 331 matched := rxStrFmt.FindAllStringSubmatch(doc, -1) 332 if len(matched) > 0 { 333 return strings.TrimSpace(matched[0][2]), matched[0][1] 334 } 335 return doc, "" 336 } 337 338 func getSchemaTypeFromBasicType(basicTypeName string) (tpe oas.Type, format string) { 339 switch basicTypeName { 340 case "bool": 341 return "boolean", "" 342 case "byte": 343 return "integer", "uint8" 344 case "error": 345 return "string", "" 346 case "float32": 347 return "number", "float" 348 case "float64": 349 return "number", "double" 350 case "int": 351 return "integer", "int64" 352 case "int8": 353 return "integer", "int8" 354 case "int16": 355 return "integer", "int16" 356 case "int32": 357 return "integer", "int32" 358 case "int64": 359 return "integer", "int64" 360 case "rune": 361 return "integer", "int32" 362 case "string": 363 return "string", "" 364 case "uint": 365 return "integer", "uint64" 366 case "uint16": 367 return "integer", "uint16" 368 case "uint32": 369 return "integer", "uint32" 370 case "uint64": 371 return "integer", "uint64" 372 case "uint8": 373 return "integer", "uint8" 374 case "uintptr": 375 return "integer", "uint64" 376 default: 377 panic(fmt.Errorf("unsupported type %q", basicTypeName)) 378 } 379 }