github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/courier/swagger/gen/operator_scanner.go (about) 1 package gen 2 3 import ( 4 "fmt" 5 "go/ast" 6 "go/build" 7 "go/types" 8 "net/http" 9 "os" 10 "path/filepath" 11 "reflect" 12 "runtime/debug" 13 "strings" 14 15 "github.com/go-courier/oas" 16 "github.com/sirupsen/logrus" 17 "golang.org/x/tools/go/loader" 18 19 "github.com/artisanhe/tools/codegen/loaderx" 20 "github.com/artisanhe/tools/courier/httpx" 21 "github.com/artisanhe/tools/courier/status_error" 22 "github.com/artisanhe/tools/courier/transport_http" 23 "github.com/artisanhe/tools/courier/transport_http/transform" 24 ) 25 26 func FullNameOfType(tpe reflect.Type) string { 27 return fmt.Sprintf("%s.%s", tpe.PkgPath(), tpe.Name()) 28 } 29 30 var TypeWebSocketListeners = FullNameOfType(reflect.TypeOf(transport_http.Listeners{})) 31 var TypeWebSocketClient = FullNameOfType(reflect.TypeOf(transport_http.WSClient{})) 32 33 func ConcatToOperation(method string, operators ...Operator) *oas.Operation { 34 operation := &oas.Operation{} 35 length := len(operators) 36 for idx, operator := range operators { 37 operator.BindOperation(method, operation, idx == length-1) 38 } 39 return operation 40 } 41 42 func NewOperatorScanner(program *loader.Program) *OperatorScanner { 43 return &OperatorScanner{ 44 DefinitionScanner: NewDefinitionScanner(program), 45 StatusErrorScanner: NewStatusErrorScanner(program), 46 program: program, 47 } 48 } 49 50 type OperatorScanner struct { 51 *DefinitionScanner 52 *StatusErrorScanner 53 program *loader.Program 54 operators map[*types.TypeName]Operator 55 } 56 57 func (scanner *OperatorScanner) Operator(typeName *types.TypeName) *Operator { 58 if typeName == nil { 59 return nil 60 } 61 62 if operator, ok := scanner.operators[typeName]; ok { 63 return &operator 64 } 65 66 defer func() { 67 if e := recover(); e != nil { 68 panic(fmt.Errorf("scan Operator `%v` failed, panic: %s; calltrace: %s", typeName, fmt.Sprint(e), string(debug.Stack()))) 69 } 70 }() 71 72 if typeStruct, ok := typeName.Type().Underlying().(*types.Struct); ok { 73 operator := Operator{ 74 ID: typeName.Name(), 75 Tag: getTagNameByPkgPath(typeName.Pkg().Path()), 76 } 77 78 scanner.bindParameterOrRequestBody(&operator, typeStruct) 79 scanner.bindReturns(&operator, typeName) 80 81 if scanner.operators == nil { 82 scanner.operators = map[*types.TypeName]Operator{} 83 } 84 85 operator.Summary = docOfTypeName(typeName, scanner.program) 86 87 scanner.operators[typeName] = operator 88 89 return &operator 90 } 91 92 return nil 93 } 94 95 func getTagNameByPkgPath(pkgPath string) string { 96 cwd, _ := os.Getwd() 97 p, _ := build.Default.Import(pkgPath, "", build.FindOnly) 98 tag, _ := filepath.Rel(cwd, p.Dir) 99 return strings.Replace(tag, "routes/", "", 1) 100 } 101 102 func (scanner *OperatorScanner) bindWebSocketMessages(op *Operator, schema *oas.Schema, typeVar *types.Var) { 103 if strings.Contains(typeVar.Type().String(), TypeWebSocketClient) { 104 for pkg, pkgInfo := range scanner.program.AllPackages { 105 if pkg == typeVar.Pkg() { 106 for selectExpr := range pkgInfo.Selections { 107 if ident, ok := selectExpr.X.(*ast.Ident); ok { 108 if pkgInfo.ObjectOf(ident) == typeVar && "Send" == selectExpr.Sel.Name { 109 file := loaderx.FileOf(selectExpr, pkgInfo.Files...) 110 ast.Inspect(file, func(node ast.Node) bool { 111 switch node.(type) { 112 case *ast.CallExpr: 113 callExpr := node.(*ast.CallExpr) 114 if callExpr.Fun == selectExpr { 115 tpe := pkgInfo.TypeOf(callExpr.Args[0]) 116 subSchema := scanner.getSchemaByType(tpe.(*types.Named)) 117 op.AddWebSocketMessage(schema, subSchema) 118 return false 119 } 120 } 121 return true 122 }) 123 } 124 } 125 } 126 } 127 } 128 } 129 } 130 131 func (scanner *OperatorScanner) bindWebSocketListeners(op *Operator, typeFunc *types.Func) { 132 scope := typeFunc.Scope() 133 for _, name := range scope.Names() { 134 n := scope.Lookup(name) 135 if strings.Contains(n.Type().String(), TypeWebSocketListeners) { 136 for pkg, pkgInfo := range scanner.program.AllPackages { 137 if pkg == n.Pkg() { 138 for selectExpr := range pkgInfo.Selections { 139 if ident, ok := selectExpr.X.(*ast.Ident); ok { 140 if pkgInfo.ObjectOf(ident) == n && "On" == selectExpr.Sel.Name { 141 file := loaderx.FileOf(selectExpr, pkgInfo.Files...) 142 ast.Inspect(file, func(node ast.Node) bool { 143 switch node.(type) { 144 case *ast.CallExpr: 145 callExpr := node.(*ast.CallExpr) 146 if callExpr.Fun == selectExpr { 147 tpe := pkgInfo.TypeOf(callExpr.Args[0]) 148 schema := scanner.getSchemaByType(tpe.(*types.Named)) 149 op.AddWebSocketMessage(schema) 150 151 params := pkgInfo.TypeOf(callExpr.Args[1]).(*types.Signature).Params() 152 153 for i := 0; i < params.Len(); i++ { 154 scanner.bindWebSocketMessages(op, schema, params.At(i)) 155 } 156 return false 157 } 158 } 159 return true 160 }) 161 } 162 } 163 } 164 } 165 } 166 } 167 } 168 } 169 170 func (scanner *OperatorScanner) bindReturns(op *Operator, typeName *types.TypeName) { 171 typeFunc := loaderx.MethodOf(typeName.Type().(*types.Named), "Output") 172 173 if typeFunc != nil { 174 metaData := ParseSuccessMetadata(docOfTypeName(typeFunc, scanner.program)) 175 176 loaderx.ForEachFuncResult(scanner.program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) { 177 successType := resultTypeAndValues[0].Type 178 179 if strings.Contains(successType.String(), TypeWebSocketListeners) { 180 scanner.bindWebSocketListeners(op, typeFunc) 181 return 182 } 183 184 if successType.String() != types.Typ[types.UntypedNil].String() { 185 if op.SuccessType != nil && op.SuccessType.String() != successType.String() { 186 logrus.Warnf(fmt.Sprintf("%s success result must be same struct, but got %v, already set %v", op.ID, successType, op.SuccessType)) 187 } 188 op.SuccessType = successType 189 op.SuccessStatus, op.SuccessResponse = scanner.getResponse(successType, metaData.Get("content-type")) 190 } 191 192 op.StatusErrors = scanner.StatusErrorScanner.StatusErrorsInFunc(typeFunc) 193 op.StatusErrorSchema = scanner.DefinitionScanner.getSchemaByTypeString(statusErrorTypeString) 194 }) 195 } 196 } 197 198 func (scanner *OperatorScanner) getResponse(tpe types.Type, contentType string) (status int, response *oas.Response) { 199 response = &oas.Response{} 200 201 if tpe.String() == "error" { 202 status = http.StatusNoContent 203 return 204 } 205 206 if contentType == "" { 207 contentType = httpx.MIMEJSON 208 } 209 210 if pointer, ok := tpe.(*types.Pointer); ok { 211 tpe = pointer.Elem() 212 } 213 214 if named, ok := tpe.(*types.Named); ok { 215 { 216 typeFunc := loaderx.MethodOf(named, "ContentType") 217 if typeFunc != nil { 218 loaderx.ForEachFuncResult(scanner.program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) { 219 if resultTypeAndValues[0].IsValue() { 220 contentType = getConstVal(resultTypeAndValues[0].Value).(string) 221 } 222 }) 223 } 224 } 225 226 { 227 typeFunc := loaderx.MethodOf(named, "Status") 228 if typeFunc != nil { 229 loaderx.ForEachFuncResult(scanner.program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) { 230 if resultTypeAndValues[0].IsValue() { 231 status = int(getConstVal(resultTypeAndValues[0].Value).(int64)) 232 } 233 }) 234 } 235 } 236 } 237 238 response.AddContent(contentType, oas.NewMediaTypeWithSchema(scanner.DefinitionScanner.getSchemaByType(tpe))) 239 240 return 241 } 242 243 func (scanner *OperatorScanner) bindParameterOrRequestBody(op *Operator, typeStruct *types.Struct) { 244 for i := 0; i < typeStruct.NumFields(); i++ { 245 var field = typeStruct.Field(i) 246 247 if !field.Exported() { 248 continue 249 } 250 251 var fieldType = field.Type() 252 var fieldName = field.Name() 253 var structFieldTags = reflect.StructTag(typeStruct.Tag(i)) 254 255 location, locationFlags := getTagNameAndFlags(structFieldTags.Get("in")) 256 257 if location == "" { 258 if fieldName == "Body" { 259 location = "body" 260 } 261 } 262 263 if location == "context" { 264 continue 265 } 266 267 if field.Anonymous() { 268 if typeStruct, ok := fieldType.Underlying().(*types.Struct); ok { 269 scanner.bindParameterOrRequestBody(op, typeStruct) 270 } 271 continue 272 } 273 274 if location == "" { 275 panic(fmt.Errorf("missing tag `in` for %s of %s", fieldName, op.ID)) 276 } 277 278 name, flags := getTagNameAndFlags(structFieldTags.Get("name")) 279 if name == "" { 280 name, flags = getTagNameAndFlags(structFieldTags.Get("json")) 281 } 282 283 var param *oas.Parameter 284 285 if location == "body" || location == "formData" { 286 op.SetRequestBody(scanner.getRequestBody(fieldType, location, locationFlags["multipart"])) 287 continue 288 } 289 290 if name == "" { 291 panic(fmt.Errorf("missing tag `name` or `json` for parameter %s of %s", fieldName, op.ID)) 292 } 293 294 param = scanner.getNonBodyParameter(name, flags, location, structFieldTags, fieldType) 295 296 if param.Schema != nil && flags != nil && flags["string"] { 297 param.Schema.Type = oas.TypeString 298 } 299 300 if styleValue, hasStyle := structFieldTags.Lookup("style"); hasStyle { 301 param.AddExtension(XTagStyle, styleValue) 302 } 303 304 if fmtValue, hasFmt := structFieldTags.Lookup("fmt"); hasFmt { 305 param.AddExtension(XTagFmt, fmtValue) 306 } 307 308 param = param.WithDesc(docOfTypeName(field, scanner.program)) 309 param.AddExtension(XField, field.Name()) 310 op.AddNonBodyParameter(param) 311 } 312 } 313 314 func (scanner *OperatorScanner) getRequestBody(t types.Type, location string, isMultipart bool) *oas.RequestBody { 315 reqBody := oas.NewRequestBody("", true) 316 schema := scanner.DefinitionScanner.getSchemaByType(t) 317 318 contentType := httpx.MIMEJSON 319 320 if location == "formData" { 321 if isMultipart { 322 contentType = httpx.MIMEMultipartPOSTForm 323 } else { 324 contentType = httpx.MIMEPOSTForm 325 } 326 } 327 328 reqBody.Required = true 329 reqBody.AddContent(contentType, oas.NewMediaTypeWithSchema(schema)) 330 return reqBody 331 } 332 333 func (scanner *OperatorScanner) getNonBodyParameter(name string, nameFlags transform.TagFlags, location string, tags reflect.StructTag, t types.Type) *oas.Parameter { 334 schema := scanner.DefinitionScanner.getSchemaByType(t) 335 336 defaultValue, hasDefault := tags.Lookup("default") 337 if hasDefault { 338 schema.Default = defaultValue 339 } 340 341 required := true 342 if hasOmitempty, ok := nameFlags["omitempty"]; ok { 343 required = !hasOmitempty 344 } else { 345 // todo don't use non-default as required 346 required = !hasDefault 347 } 348 349 validate, hasValidate := tags.Lookup("validate") 350 if hasValidate { 351 BindValidateFromValidateTagString(schema, validate) 352 } 353 354 if schema != nil && schema.Refer.RefString() != "" { 355 schema = oas.AllOf( 356 schema, 357 &oas.Schema{ 358 SchemaObject: schema.SchemaObject, 359 SpecExtensions: schema.SpecExtensions, 360 }, 361 ) 362 } 363 364 switch location { 365 case "query": 366 return oas.QueryParameter(name, schema, required) 367 case "cookie": 368 return oas.CookieParameter(name, schema, required) 369 case "header": 370 return oas.HeaderParameter(name, schema, required) 371 case "path": 372 return oas.PathParameter(name, schema) 373 } 374 return nil 375 } 376 377 type Operator struct { 378 ID string 379 NonBodyParameters map[string]*oas.Parameter 380 RequestBody *oas.RequestBody 381 382 StatusErrors status_error.StatusErrorCodeMap 383 StatusErrorSchema *oas.Schema 384 385 Tag string 386 Summary string 387 SuccessType types.Type 388 SuccessStatus int 389 SuccessResponse *oas.Response 390 WebSocketMessages map[*oas.Schema][]*oas.Schema 391 } 392 393 func (operator *Operator) AddWebSocketMessage(schema *oas.Schema, returns ...*oas.Schema) { 394 if operator.WebSocketMessages == nil { 395 operator.WebSocketMessages = map[*oas.Schema][]*oas.Schema{} 396 } 397 operator.WebSocketMessages[schema] = append(operator.WebSocketMessages[schema], returns...) 398 } 399 400 func (operator *Operator) AddNonBodyParameter(parameter *oas.Parameter) { 401 if operator.NonBodyParameters == nil { 402 operator.NonBodyParameters = map[string]*oas.Parameter{} 403 } 404 operator.NonBodyParameters[parameter.Name] = parameter 405 } 406 407 func (operator *Operator) SetRequestBody(requestBody *oas.RequestBody) { 408 operator.RequestBody = requestBody 409 } 410 411 func (operator *Operator) BindOperation(method string, operation *oas.Operation, last bool) { 412 if operator.WebSocketMessages != nil { 413 schema := oas.ObjectOf(nil) 414 415 for msgSchema, list := range operator.WebSocketMessages { 416 s := oas.ObjectOf(nil) 417 418 s.SetProperty(typeOfSchema(msgSchema), msgSchema, false) 419 420 if list != nil { 421 sub := oas.ObjectOf(nil) 422 for _, item := range list { 423 sub.SetProperty(typeOfSchema(item), item, false) 424 } 425 schema.SetProperty("out", sub, false) 426 } 427 schema.SetProperty("in", s, false) 428 } 429 430 requestBody := oas.NewRequestBody("WebSocket", true) 431 requestBody.AddContent(httpx.MIMEJSON, oas.NewMediaTypeWithSchema(schema)) 432 433 operation.SetRequestBody(requestBody) 434 return 435 } 436 437 parameterNames := map[string]bool{} 438 for _, parameter := range operation.Parameters { 439 parameterNames[parameter.Name] = true 440 } 441 442 for _, parameter := range operator.NonBodyParameters { 443 if !parameterNames[parameter.Name] { 444 operation.Parameters = append(operation.Parameters, parameter) 445 } 446 } 447 448 if operator.RequestBody != nil { 449 operation.SetRequestBody(operator.RequestBody) 450 } 451 452 for code, statusError := range operator.StatusErrors { 453 resp := (*oas.Response)(nil) 454 if operation.Responses.Responses != nil { 455 resp = operation.Responses.Responses[statusError.Status()] 456 } 457 statusErrors := status_error.StatusErrorCodeMap{} 458 if resp != nil { 459 statusErrors = pickStatusErrorsFromDoc(resp.Description) 460 } 461 statusErrors[code] = statusError 462 resp = oas.NewResponse(statusErrors.String()) 463 resp.AddContent(httpx.MIMEJSON, oas.NewMediaTypeWithSchema(operator.StatusErrorSchema)) 464 operation.AddResponse(statusError.Status(), resp) 465 } 466 467 if last { 468 operation.OperationId = operator.ID 469 docs := strings.Split(operator.Summary, "\n") 470 if operator.Tag != "" { 471 operation.Tags = []string{operator.Tag} 472 } 473 operation.Summary = docs[0] 474 if len(docs) > 1 { 475 operation.Description = strings.Join(docs[1:], "\n") 476 } 477 if operator.SuccessType == nil { 478 operation.Responses.AddResponse(http.StatusNoContent, &oas.Response{}) 479 } else { 480 status := operator.SuccessStatus 481 if status == 0 { 482 status = http.StatusOK 483 if method == http.MethodPost { 484 status = http.StatusCreated 485 } 486 } 487 if status >= http.StatusMultipleChoices && status < http.StatusBadRequest { 488 operator.SuccessResponse = oas.NewResponse(operator.SuccessResponse.Description) 489 } 490 operation.Responses.AddResponse(status, operator.SuccessResponse) 491 } 492 } 493 } 494 495 func typeOfSchema(schema *oas.Schema) string { 496 l := strings.Split(schema.Refer.RefString(), "/") 497 return l[len(l)-1] 498 }