github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/courier/swagger/gen/operator_scanner.go (about) 1 package gen 2 3 import ( 4 "fmt" 5 "go/ast" 6 "go/build" 7 "go/types" 8 "net/http" 9 "os" 10 "path/filepath" 11 "reflect" 12 "runtime/debug" 13 "strings" 14 15 "github.com/johnnyeven/libtools/courier/transport_http/transform" 16 17 "github.com/morlay/oas" 18 "github.com/sirupsen/logrus" 19 "golang.org/x/tools/go/loader" 20 21 "github.com/johnnyeven/libtools/codegen/loaderx" 22 "github.com/johnnyeven/libtools/courier/httpx" 23 "github.com/johnnyeven/libtools/courier/status_error" 24 "github.com/johnnyeven/libtools/courier/transport_http" 25 ) 26 27 func FullNameOfType(tpe reflect.Type) string { 28 return fmt.Sprintf("%s.%s", tpe.PkgPath(), tpe.Name()) 29 } 30 31 var TypeWebSocketListeners = FullNameOfType(reflect.TypeOf(transport_http.Listeners{})) 32 var TypeWebSocketClient = FullNameOfType(reflect.TypeOf(transport_http.WSClient{})) 33 34 func ConcatToOperation(method string, operators ...Operator) *oas.Operation { 35 operation := &oas.Operation{} 36 length := len(operators) 37 for idx, operator := range operators { 38 operator.BindOperation(method, operation, idx == length-1) 39 } 40 return operation 41 } 42 43 func NewOperatorScanner(program *loader.Program) *OperatorScanner { 44 return &OperatorScanner{ 45 DefinitionScanner: NewDefinitionScanner(program), 46 StatusErrorScanner: NewStatusErrorScanner(program), 47 program: program, 48 } 49 } 50 51 type OperatorScanner struct { 52 *DefinitionScanner 53 *StatusErrorScanner 54 program *loader.Program 55 operators map[*types.TypeName]Operator 56 } 57 58 func (scanner *OperatorScanner) Operator(typeName *types.TypeName) *Operator { 59 if typeName == nil { 60 return nil 61 } 62 63 if operator, ok := scanner.operators[typeName]; ok { 64 return &operator 65 } 66 67 defer func() { 68 if e := recover(); e != nil { 69 logrus.Errorf("scan Operator `%v` failed, panic: %s; calltrace: %s", typeName, fmt.Sprint(e), string(debug.Stack())) 70 } 71 }() 72 73 if typeStruct, ok := typeName.Type().Underlying().(*types.Struct); ok { 74 operator := Operator{ 75 ID: typeName.Name(), 76 Tag: getTagNameByPkgPath(typeName.Pkg().Path()), 77 } 78 79 scanner.bindParameterOrRequestBody(&operator, typeStruct) 80 scanner.bindReturns(&operator, typeName) 81 82 if scanner.operators == nil { 83 scanner.operators = map[*types.TypeName]Operator{} 84 } 85 86 operator.Summary = docOfTypeName(typeName, scanner.program) 87 88 scanner.operators[typeName] = operator 89 90 return &operator 91 } 92 93 return nil 94 } 95 96 func getTagNameByPkgPath(pkgPath string) string { 97 cwd, _ := os.Getwd() 98 p, _ := build.Default.Import(pkgPath, "", build.FindOnly) 99 tag, _ := filepath.Rel(cwd, p.Dir) 100 i := strings.Index(tag, "routes/") 101 if i >= 0 { 102 tag = string([]byte(tag)[i:]) 103 } 104 return strings.Replace(tag, "routes/", "", 1) 105 } 106 107 func (scanner *OperatorScanner) bindWebSocketMessages(op *Operator, schema *oas.Schema, typeVar *types.Var) { 108 if strings.Contains(typeVar.Type().String(), TypeWebSocketClient) { 109 for pkg, pkgInfo := range scanner.program.AllPackages { 110 if pkg == typeVar.Pkg() { 111 for selectExpr := range pkgInfo.Selections { 112 if ident, ok := selectExpr.X.(*ast.Ident); ok { 113 if pkgInfo.ObjectOf(ident) == typeVar && "Send" == selectExpr.Sel.Name { 114 file := loaderx.FileOf(selectExpr, pkgInfo.Files...) 115 ast.Inspect(file, func(node ast.Node) bool { 116 switch node.(type) { 117 case *ast.CallExpr: 118 callExpr := node.(*ast.CallExpr) 119 if callExpr.Fun == selectExpr { 120 tpe := pkgInfo.TypeOf(callExpr.Args[0]) 121 subSchema := scanner.getSchemaByType(tpe.(*types.Named)) 122 op.AddWebSocketMessage(schema, subSchema) 123 return false 124 } 125 } 126 return true 127 }) 128 } 129 } 130 } 131 } 132 } 133 } 134 } 135 136 func (scanner *OperatorScanner) bindWebSocketListeners(op *Operator, typeFunc *types.Func) { 137 scope := typeFunc.Scope() 138 for _, name := range scope.Names() { 139 n := scope.Lookup(name) 140 if strings.Contains(n.Type().String(), TypeWebSocketListeners) { 141 for pkg, pkgInfo := range scanner.program.AllPackages { 142 if pkg == n.Pkg() { 143 for selectExpr := range pkgInfo.Selections { 144 if ident, ok := selectExpr.X.(*ast.Ident); ok { 145 if pkgInfo.ObjectOf(ident) == n && "On" == selectExpr.Sel.Name { 146 file := loaderx.FileOf(selectExpr, pkgInfo.Files...) 147 ast.Inspect(file, func(node ast.Node) bool { 148 switch node.(type) { 149 case *ast.CallExpr: 150 callExpr := node.(*ast.CallExpr) 151 if callExpr.Fun == selectExpr { 152 tpe := pkgInfo.TypeOf(callExpr.Args[0]) 153 schema := scanner.getSchemaByType(tpe.(*types.Named)) 154 op.AddWebSocketMessage(schema) 155 156 params := pkgInfo.TypeOf(callExpr.Args[1]).(*types.Signature).Params() 157 158 for i := 0; i < params.Len(); i++ { 159 scanner.bindWebSocketMessages(op, schema, params.At(i)) 160 } 161 return false 162 } 163 } 164 return true 165 }) 166 } 167 } 168 } 169 } 170 } 171 } 172 } 173 } 174 175 func (scanner *OperatorScanner) bindReturns(op *Operator, typeName *types.TypeName) { 176 typeFunc := loaderx.MethodOf(typeName.Type().(*types.Named), "Output") 177 178 if typeFunc != nil { 179 metaData := ParseSuccessMetadata(docOfTypeName(typeFunc, scanner.program)) 180 181 loaderx.ForEachFuncResult(scanner.program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) { 182 successType := resultTypeAndValues[0].Type 183 184 if strings.Contains(successType.String(), TypeWebSocketListeners) { 185 scanner.bindWebSocketListeners(op, typeFunc) 186 return 187 } 188 189 if successType.String() != types.Typ[types.UntypedNil].String() { 190 if op.SuccessType != nil && op.SuccessType.String() != successType.String() { 191 logrus.Warnf(fmt.Sprintf("%s success result must be same struct, but got %v, already set %v", op.ID, successType, op.SuccessType)) 192 } 193 op.SuccessType = successType 194 op.SuccessStatus, op.SuccessResponse = scanner.getResponse(successType, metaData.Get("content-type")) 195 } 196 197 op.StatusErrors = scanner.StatusErrorScanner.StatusErrorsInFunc(typeFunc) 198 op.StatusErrorSchema = scanner.DefinitionScanner.getSchemaByTypeString(statusErrorTypeString) 199 }) 200 } 201 } 202 203 func (scanner *OperatorScanner) getResponse(tpe types.Type, contentType string) (status int, response *oas.Response) { 204 response = &oas.Response{} 205 206 if tpe.String() == "error" { 207 status = http.StatusNoContent 208 return 209 } 210 211 if contentType == "" { 212 contentType = httpx.MIMEJSON 213 } 214 215 if pointer, ok := tpe.(*types.Pointer); ok { 216 tpe = pointer.Elem() 217 } 218 219 if named, ok := tpe.(*types.Named); ok { 220 { 221 typeFunc := loaderx.MethodOf(named, "ContentType") 222 if typeFunc != nil { 223 loaderx.ForEachFuncResult(scanner.program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) { 224 if resultTypeAndValues[0].IsValue() { 225 contentType = getConstVal(resultTypeAndValues[0].Value).(string) 226 } 227 }) 228 } 229 } 230 231 { 232 typeFunc := loaderx.MethodOf(named, "Status") 233 if typeFunc != nil { 234 loaderx.ForEachFuncResult(scanner.program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) { 235 if resultTypeAndValues[0].IsValue() { 236 status = int(getConstVal(resultTypeAndValues[0].Value).(int64)) 237 } 238 }) 239 } 240 } 241 } 242 243 response.AddContent(contentType, oas.NewMediaTypeWithSchema(scanner.DefinitionScanner.getSchemaByType(tpe))) 244 245 return 246 } 247 248 func (scanner *OperatorScanner) bindParameterOrRequestBody(op *Operator, typeStruct *types.Struct) { 249 for i := 0; i < typeStruct.NumFields(); i++ { 250 var field = typeStruct.Field(i) 251 252 if !field.Exported() { 253 continue 254 } 255 256 var fieldType = field.Type() 257 var fieldName = field.Name() 258 var structFieldTags = reflect.StructTag(typeStruct.Tag(i)) 259 260 location, locationFlags := getTagNameAndFlags(structFieldTags.Get("in")) 261 262 if location == "" { 263 if fieldName == "Body" { 264 location = "body" 265 } 266 } 267 268 if location == "context" { 269 continue 270 } 271 272 if field.Anonymous() { 273 if typeStruct, ok := fieldType.Underlying().(*types.Struct); ok { 274 scanner.bindParameterOrRequestBody(op, typeStruct) 275 } 276 continue 277 } 278 279 if location == "" { 280 panic(fmt.Errorf("missing tag `in` for %s of %s", fieldName, op.ID)) 281 } 282 283 name, flags := getTagNameAndFlags(structFieldTags.Get("name")) 284 if name == "" { 285 name, flags = getTagNameAndFlags(structFieldTags.Get("json")) 286 } 287 288 var param *oas.Parameter 289 290 if location == "body" || location == "formData" { 291 op.SetRequestBody(scanner.getRequestBody(fieldType, location, locationFlags["multipart"])) 292 continue 293 } 294 295 if name == "" { 296 panic(fmt.Errorf("missing tag `name` or `json` for parameter %s of %s", fieldName, op.ID)) 297 } 298 299 param = scanner.getNonBodyParameter(name, flags, location, structFieldTags, fieldType) 300 301 if param.Schema != nil && flags != nil && flags["string"] { 302 param.Schema.Type = oas.TypeString 303 } 304 305 if styleValue, hasStyle := structFieldTags.Lookup("style"); hasStyle { 306 param.AddExtension(XTagStyle, styleValue) 307 } 308 309 if fmtValue, hasFmt := structFieldTags.Lookup("fmt"); hasFmt { 310 param.AddExtension(XTagFmt, fmtValue) 311 } 312 313 param = param.WithDesc(docOfTypeName(field, scanner.program)) 314 param.AddExtension(XField, field.Name()) 315 op.AddNonBodyParameter(param) 316 } 317 } 318 319 func (scanner *OperatorScanner) getRequestBody(t types.Type, location string, isMultipart bool) *oas.RequestBody { 320 reqBody := oas.NewRequestBody("", true) 321 schema := scanner.DefinitionScanner.getSchemaByType(t) 322 323 contentType := httpx.MIMEJSON 324 325 if location == "formData" { 326 if isMultipart { 327 contentType = httpx.MIMEMultipartPOSTForm 328 } else { 329 contentType = httpx.MIMEPOSTForm 330 } 331 } 332 333 reqBody.Required = true 334 reqBody.AddContent(contentType, oas.NewMediaTypeWithSchema(schema)) 335 return reqBody 336 } 337 338 func (scanner *OperatorScanner) getNonBodyParameter(name string, nameFlags transform.TagFlags, location string, tags reflect.StructTag, t types.Type) *oas.Parameter { 339 schema := scanner.DefinitionScanner.getSchemaByType(t) 340 341 defaultValue, hasDefault := tags.Lookup("default") 342 if hasDefault { 343 schema.Default = defaultValue 344 } 345 346 required := true 347 if hasOmitempty, ok := nameFlags["omitempty"]; ok { 348 required = !hasOmitempty 349 } else { 350 // todo don't use non-default as required 351 required = !hasDefault 352 } 353 354 validate, hasValidate := tags.Lookup("validate") 355 if hasValidate { 356 BindValidateFromValidateTagString(schema, validate) 357 } 358 359 if schema != nil && schema.Ref != "" { 360 schema = oas.AllOf( 361 schema, 362 &oas.Schema{ 363 SchemaObject: schema.SchemaObject, 364 SpecExtensions: schema.SpecExtensions, 365 }, 366 ) 367 } 368 369 switch location { 370 case "query": 371 return oas.QueryParameter(name, schema, required) 372 case "cookie": 373 return oas.CookieParameter(name, schema, required) 374 case "header": 375 return oas.HeaderParameter(name, schema, required) 376 case "path": 377 return oas.PathParameter(name, schema) 378 } 379 return nil 380 } 381 382 type Operator struct { 383 ID string 384 NonBodyParameters map[string]*oas.Parameter 385 RequestBody *oas.RequestBody 386 387 StatusErrors status_error.StatusErrorCodeMap 388 StatusErrorSchema *oas.Schema 389 390 Tag string 391 Summary string 392 SuccessType types.Type 393 SuccessStatus int 394 SuccessResponse *oas.Response 395 WebSocketMessages map[*oas.Schema][]*oas.Schema 396 } 397 398 func (operator *Operator) AddWebSocketMessage(schema *oas.Schema, returns ...*oas.Schema) { 399 if operator.WebSocketMessages == nil { 400 operator.WebSocketMessages = map[*oas.Schema][]*oas.Schema{} 401 } 402 operator.WebSocketMessages[schema] = append(operator.WebSocketMessages[schema], returns...) 403 } 404 405 func (operator *Operator) AddNonBodyParameter(parameter *oas.Parameter) { 406 if operator.NonBodyParameters == nil { 407 operator.NonBodyParameters = map[string]*oas.Parameter{} 408 } 409 operator.NonBodyParameters[parameter.Name] = parameter 410 } 411 412 func (operator *Operator) SetRequestBody(requestBody *oas.RequestBody) { 413 operator.RequestBody = requestBody 414 } 415 416 func (operator *Operator) BindOperation(method string, operation *oas.Operation, last bool) { 417 if operator.WebSocketMessages != nil { 418 schema := oas.ObjectOf(nil) 419 420 for msgSchema, list := range operator.WebSocketMessages { 421 s := oas.ObjectOf(nil) 422 423 s.SetProperty(typeOfSchema(msgSchema), msgSchema, false) 424 425 if list != nil { 426 sub := oas.ObjectOf(nil) 427 for _, item := range list { 428 sub.SetProperty(typeOfSchema(item), item, false) 429 } 430 schema.SetProperty("out", sub, false) 431 } 432 schema.SetProperty("in", s, false) 433 } 434 435 requestBody := oas.NewRequestBody("WebSocket", true) 436 requestBody.AddContent(httpx.MIMEJSON, oas.NewMediaTypeWithSchema(schema)) 437 438 operation.SetRequestBody(requestBody) 439 return 440 } 441 442 parameterNames := map[string]bool{} 443 for _, parameter := range operation.Parameters { 444 parameterNames[parameter.Name] = true 445 } 446 447 for _, parameter := range operator.NonBodyParameters { 448 if !parameterNames[parameter.Name] { 449 operation.Parameters = append(operation.Parameters, parameter) 450 } 451 } 452 453 if operator.RequestBody != nil { 454 operation.SetRequestBody(operator.RequestBody) 455 } 456 457 for code, statusError := range operator.StatusErrors { 458 resp := (*oas.Response)(nil) 459 if operation.Responses.Responses != nil { 460 resp = operation.Responses.Responses[statusError.Status()] 461 } 462 statusErrors := status_error.StatusErrorCodeMap{} 463 if resp != nil { 464 statusErrors = pickStatusErrorsFromDoc(resp.Description) 465 } 466 statusErrors[code] = statusError 467 resp = oas.NewResponse(statusErrors.String()) 468 resp.AddContent(httpx.MIMEJSON, oas.NewMediaTypeWithSchema(operator.StatusErrorSchema)) 469 operation.AddResponse(statusError.Status(), resp) 470 } 471 472 if last { 473 operation.OperationId = operator.ID 474 docs := strings.Split(operator.Summary, "\n") 475 if operator.Tag != "" { 476 operation.Tags = []string{operator.Tag} 477 } 478 operation.Summary = docs[0] 479 if len(docs) > 1 { 480 operation.Description = strings.Join(docs[1:], "\n") 481 } 482 if operator.SuccessType == nil { 483 operation.Responses.AddResponse(http.StatusNoContent, &oas.Response{}) 484 } else { 485 status := operator.SuccessStatus 486 if status == 0 { 487 status = http.StatusOK 488 if method == http.MethodPost { 489 status = http.StatusCreated 490 } 491 } 492 if status >= http.StatusMultipleChoices && status < http.StatusBadRequest { 493 operator.SuccessResponse = oas.NewResponse(operator.SuccessResponse.Description) 494 } 495 operation.Responses.AddResponse(status, operator.SuccessResponse) 496 } 497 } 498 } 499 500 func typeOfSchema(schema *oas.Schema) string { 501 l := strings.Split(schema.Ref, "/") 502 return l[len(l)-1] 503 }