github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/courier/swagger/gen/operator_scanner.go (about)

     1  package gen
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/build"
     7  	"go/types"
     8  	"net/http"
     9  	"os"
    10  	"path/filepath"
    11  	"reflect"
    12  	"runtime/debug"
    13  	"strings"
    14  
    15  	"github.com/go-courier/oas"
    16  	"github.com/sirupsen/logrus"
    17  	"golang.org/x/tools/go/loader"
    18  
    19  	"github.com/artisanhe/tools/codegen/loaderx"
    20  	"github.com/artisanhe/tools/courier/httpx"
    21  	"github.com/artisanhe/tools/courier/status_error"
    22  	"github.com/artisanhe/tools/courier/transport_http"
    23  	"github.com/artisanhe/tools/courier/transport_http/transform"
    24  )
    25  
    26  func FullNameOfType(tpe reflect.Type) string {
    27  	return fmt.Sprintf("%s.%s", tpe.PkgPath(), tpe.Name())
    28  }
    29  
    30  var TypeWebSocketListeners = FullNameOfType(reflect.TypeOf(transport_http.Listeners{}))
    31  var TypeWebSocketClient = FullNameOfType(reflect.TypeOf(transport_http.WSClient{}))
    32  
    33  func ConcatToOperation(method string, operators ...Operator) *oas.Operation {
    34  	operation := &oas.Operation{}
    35  	length := len(operators)
    36  	for idx, operator := range operators {
    37  		operator.BindOperation(method, operation, idx == length-1)
    38  	}
    39  	return operation
    40  }
    41  
    42  func NewOperatorScanner(program *loader.Program) *OperatorScanner {
    43  	return &OperatorScanner{
    44  		DefinitionScanner:  NewDefinitionScanner(program),
    45  		StatusErrorScanner: NewStatusErrorScanner(program),
    46  		program:            program,
    47  	}
    48  }
    49  
// OperatorScanner builds Operator descriptions from the request struct types
// of a loaded program. Schema and status-error lookups are delegated to the
// embedded DefinitionScanner and StatusErrorScanner.
type OperatorScanner struct {
	*DefinitionScanner
	*StatusErrorScanner
	// program is the loaded program that types, docs and AST are read from.
	program   *loader.Program
	// operators memoizes scan results per type name; nil until the first
	// successful scan.
	operators map[*types.TypeName]Operator
}
    56  
    57  func (scanner *OperatorScanner) Operator(typeName *types.TypeName) *Operator {
    58  	if typeName == nil {
    59  		return nil
    60  	}
    61  
    62  	if operator, ok := scanner.operators[typeName]; ok {
    63  		return &operator
    64  	}
    65  
    66  	defer func() {
    67  		if e := recover(); e != nil {
    68  			panic(fmt.Errorf("scan Operator `%v` failed, panic: %s; calltrace: %s", typeName, fmt.Sprint(e), string(debug.Stack())))
    69  		}
    70  	}()
    71  
    72  	if typeStruct, ok := typeName.Type().Underlying().(*types.Struct); ok {
    73  		operator := Operator{
    74  			ID:  typeName.Name(),
    75  			Tag: getTagNameByPkgPath(typeName.Pkg().Path()),
    76  		}
    77  
    78  		scanner.bindParameterOrRequestBody(&operator, typeStruct)
    79  		scanner.bindReturns(&operator, typeName)
    80  
    81  		if scanner.operators == nil {
    82  			scanner.operators = map[*types.TypeName]Operator{}
    83  		}
    84  
    85  		operator.Summary = docOfTypeName(typeName, scanner.program)
    86  
    87  		scanner.operators[typeName] = operator
    88  
    89  		return &operator
    90  	}
    91  
    92  	return nil
    93  }
    94  
    95  func getTagNameByPkgPath(pkgPath string) string {
    96  	cwd, _ := os.Getwd()
    97  	p, _ := build.Default.Import(pkgPath, "", build.FindOnly)
    98  	tag, _ := filepath.Rel(cwd, p.Dir)
    99  	return strings.Replace(tag, "routes/", "", 1)
   100  }
   101  
   102  func (scanner *OperatorScanner) bindWebSocketMessages(op *Operator, schema *oas.Schema, typeVar *types.Var) {
   103  	if strings.Contains(typeVar.Type().String(), TypeWebSocketClient) {
   104  		for pkg, pkgInfo := range scanner.program.AllPackages {
   105  			if pkg == typeVar.Pkg() {
   106  				for selectExpr := range pkgInfo.Selections {
   107  					if ident, ok := selectExpr.X.(*ast.Ident); ok {
   108  						if pkgInfo.ObjectOf(ident) == typeVar && "Send" == selectExpr.Sel.Name {
   109  							file := loaderx.FileOf(selectExpr, pkgInfo.Files...)
   110  							ast.Inspect(file, func(node ast.Node) bool {
   111  								switch node.(type) {
   112  								case *ast.CallExpr:
   113  									callExpr := node.(*ast.CallExpr)
   114  									if callExpr.Fun == selectExpr {
   115  										tpe := pkgInfo.TypeOf(callExpr.Args[0])
   116  										subSchema := scanner.getSchemaByType(tpe.(*types.Named))
   117  										op.AddWebSocketMessage(schema, subSchema)
   118  										return false
   119  									}
   120  								}
   121  								return true
   122  							})
   123  						}
   124  					}
   125  				}
   126  			}
   127  		}
   128  	}
   129  }
   130  
   131  func (scanner *OperatorScanner) bindWebSocketListeners(op *Operator, typeFunc *types.Func) {
   132  	scope := typeFunc.Scope()
   133  	for _, name := range scope.Names() {
   134  		n := scope.Lookup(name)
   135  		if strings.Contains(n.Type().String(), TypeWebSocketListeners) {
   136  			for pkg, pkgInfo := range scanner.program.AllPackages {
   137  				if pkg == n.Pkg() {
   138  					for selectExpr := range pkgInfo.Selections {
   139  						if ident, ok := selectExpr.X.(*ast.Ident); ok {
   140  							if pkgInfo.ObjectOf(ident) == n && "On" == selectExpr.Sel.Name {
   141  								file := loaderx.FileOf(selectExpr, pkgInfo.Files...)
   142  								ast.Inspect(file, func(node ast.Node) bool {
   143  									switch node.(type) {
   144  									case *ast.CallExpr:
   145  										callExpr := node.(*ast.CallExpr)
   146  										if callExpr.Fun == selectExpr {
   147  											tpe := pkgInfo.TypeOf(callExpr.Args[0])
   148  											schema := scanner.getSchemaByType(tpe.(*types.Named))
   149  											op.AddWebSocketMessage(schema)
   150  
   151  											params := pkgInfo.TypeOf(callExpr.Args[1]).(*types.Signature).Params()
   152  
   153  											for i := 0; i < params.Len(); i++ {
   154  												scanner.bindWebSocketMessages(op, schema, params.At(i))
   155  											}
   156  											return false
   157  										}
   158  									}
   159  									return true
   160  								})
   161  							}
   162  						}
   163  					}
   164  				}
   165  			}
   166  		}
   167  	}
   168  }
   169  
   170  func (scanner *OperatorScanner) bindReturns(op *Operator, typeName *types.TypeName) {
   171  	typeFunc := loaderx.MethodOf(typeName.Type().(*types.Named), "Output")
   172  
   173  	if typeFunc != nil {
   174  		metaData := ParseSuccessMetadata(docOfTypeName(typeFunc, scanner.program))
   175  
   176  		loaderx.ForEachFuncResult(scanner.program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) {
   177  			successType := resultTypeAndValues[0].Type
   178  
   179  			if strings.Contains(successType.String(), TypeWebSocketListeners) {
   180  				scanner.bindWebSocketListeners(op, typeFunc)
   181  				return
   182  			}
   183  
   184  			if successType.String() != types.Typ[types.UntypedNil].String() {
   185  				if op.SuccessType != nil && op.SuccessType.String() != successType.String() {
   186  					logrus.Warnf(fmt.Sprintf("%s success result must be same struct, but got %v, already set %v", op.ID, successType, op.SuccessType))
   187  				}
   188  				op.SuccessType = successType
   189  				op.SuccessStatus, op.SuccessResponse = scanner.getResponse(successType, metaData.Get("content-type"))
   190  			}
   191  
   192  			op.StatusErrors = scanner.StatusErrorScanner.StatusErrorsInFunc(typeFunc)
   193  			op.StatusErrorSchema = scanner.DefinitionScanner.getSchemaByTypeString(statusErrorTypeString)
   194  		})
   195  	}
   196  }
   197  
   198  func (scanner *OperatorScanner) getResponse(tpe types.Type, contentType string) (status int, response *oas.Response) {
   199  	response = &oas.Response{}
   200  
   201  	if tpe.String() == "error" {
   202  		status = http.StatusNoContent
   203  		return
   204  	}
   205  
   206  	if contentType == "" {
   207  		contentType = httpx.MIMEJSON
   208  	}
   209  
   210  	if pointer, ok := tpe.(*types.Pointer); ok {
   211  		tpe = pointer.Elem()
   212  	}
   213  
   214  	if named, ok := tpe.(*types.Named); ok {
   215  		{
   216  			typeFunc := loaderx.MethodOf(named, "ContentType")
   217  			if typeFunc != nil {
   218  				loaderx.ForEachFuncResult(scanner.program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) {
   219  					if resultTypeAndValues[0].IsValue() {
   220  						contentType = getConstVal(resultTypeAndValues[0].Value).(string)
   221  					}
   222  				})
   223  			}
   224  		}
   225  
   226  		{
   227  			typeFunc := loaderx.MethodOf(named, "Status")
   228  			if typeFunc != nil {
   229  				loaderx.ForEachFuncResult(scanner.program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) {
   230  					if resultTypeAndValues[0].IsValue() {
   231  						status = int(getConstVal(resultTypeAndValues[0].Value).(int64))
   232  					}
   233  				})
   234  			}
   235  		}
   236  	}
   237  
   238  	response.AddContent(contentType, oas.NewMediaTypeWithSchema(scanner.DefinitionScanner.getSchemaByType(tpe)))
   239  
   240  	return
   241  }
   242  
   243  func (scanner *OperatorScanner) bindParameterOrRequestBody(op *Operator, typeStruct *types.Struct) {
   244  	for i := 0; i < typeStruct.NumFields(); i++ {
   245  		var field = typeStruct.Field(i)
   246  
   247  		if !field.Exported() {
   248  			continue
   249  		}
   250  
   251  		var fieldType = field.Type()
   252  		var fieldName = field.Name()
   253  		var structFieldTags = reflect.StructTag(typeStruct.Tag(i))
   254  
   255  		location, locationFlags := getTagNameAndFlags(structFieldTags.Get("in"))
   256  
   257  		if location == "" {
   258  			if fieldName == "Body" {
   259  				location = "body"
   260  			}
   261  		}
   262  
   263  		if location == "context" {
   264  			continue
   265  		}
   266  
   267  		if field.Anonymous() {
   268  			if typeStruct, ok := fieldType.Underlying().(*types.Struct); ok {
   269  				scanner.bindParameterOrRequestBody(op, typeStruct)
   270  			}
   271  			continue
   272  		}
   273  
   274  		if location == "" {
   275  			panic(fmt.Errorf("missing tag `in` for %s of %s", fieldName, op.ID))
   276  		}
   277  
   278  		name, flags := getTagNameAndFlags(structFieldTags.Get("name"))
   279  		if name == "" {
   280  			name, flags = getTagNameAndFlags(structFieldTags.Get("json"))
   281  		}
   282  
   283  		var param *oas.Parameter
   284  
   285  		if location == "body" || location == "formData" {
   286  			op.SetRequestBody(scanner.getRequestBody(fieldType, location, locationFlags["multipart"]))
   287  			continue
   288  		}
   289  
   290  		if name == "" {
   291  			panic(fmt.Errorf("missing tag `name` or `json` for parameter %s of %s", fieldName, op.ID))
   292  		}
   293  
   294  		param = scanner.getNonBodyParameter(name, flags, location, structFieldTags, fieldType)
   295  
   296  		if param.Schema != nil && flags != nil && flags["string"] {
   297  			param.Schema.Type = oas.TypeString
   298  		}
   299  
   300  		if styleValue, hasStyle := structFieldTags.Lookup("style"); hasStyle {
   301  			param.AddExtension(XTagStyle, styleValue)
   302  		}
   303  
   304  		if fmtValue, hasFmt := structFieldTags.Lookup("fmt"); hasFmt {
   305  			param.AddExtension(XTagFmt, fmtValue)
   306  		}
   307  
   308  		param = param.WithDesc(docOfTypeName(field, scanner.program))
   309  		param.AddExtension(XField, field.Name())
   310  		op.AddNonBodyParameter(param)
   311  	}
   312  }
   313  
   314  func (scanner *OperatorScanner) getRequestBody(t types.Type, location string, isMultipart bool) *oas.RequestBody {
   315  	reqBody := oas.NewRequestBody("", true)
   316  	schema := scanner.DefinitionScanner.getSchemaByType(t)
   317  
   318  	contentType := httpx.MIMEJSON
   319  
   320  	if location == "formData" {
   321  		if isMultipart {
   322  			contentType = httpx.MIMEMultipartPOSTForm
   323  		} else {
   324  			contentType = httpx.MIMEPOSTForm
   325  		}
   326  	}
   327  
   328  	reqBody.Required = true
   329  	reqBody.AddContent(contentType, oas.NewMediaTypeWithSchema(schema))
   330  	return reqBody
   331  }
   332  
   333  func (scanner *OperatorScanner) getNonBodyParameter(name string, nameFlags transform.TagFlags, location string, tags reflect.StructTag, t types.Type) *oas.Parameter {
   334  	schema := scanner.DefinitionScanner.getSchemaByType(t)
   335  
   336  	defaultValue, hasDefault := tags.Lookup("default")
   337  	if hasDefault {
   338  		schema.Default = defaultValue
   339  	}
   340  
   341  	required := true
   342  	if hasOmitempty, ok := nameFlags["omitempty"]; ok {
   343  		required = !hasOmitempty
   344  	} else {
   345  		// todo don't use non-default as required
   346  		required = !hasDefault
   347  	}
   348  
   349  	validate, hasValidate := tags.Lookup("validate")
   350  	if hasValidate {
   351  		BindValidateFromValidateTagString(schema, validate)
   352  	}
   353  
   354  	if schema != nil && schema.Refer.RefString() != "" {
   355  		schema = oas.AllOf(
   356  			schema,
   357  			&oas.Schema{
   358  				SchemaObject:   schema.SchemaObject,
   359  				SpecExtensions: schema.SpecExtensions,
   360  			},
   361  		)
   362  	}
   363  
   364  	switch location {
   365  	case "query":
   366  		return oas.QueryParameter(name, schema, required)
   367  	case "cookie":
   368  		return oas.CookieParameter(name, schema, required)
   369  	case "header":
   370  		return oas.HeaderParameter(name, schema, required)
   371  	case "path":
   372  		return oas.PathParameter(name, schema)
   373  	}
   374  	return nil
   375  }
   376  
// Operator is the OpenAPI description gathered from one request struct: its
// parameters, request body, success response, status errors and, for
// WebSocket endpoints, its message schemas. BindOperation merges it into an
// oas.Operation.
type Operator struct {
	// ID is the request type name; it becomes the operation ID.
	ID                string
	// NonBodyParameters holds the query/cookie/header/path parameters, keyed
	// by parameter name.
	NonBodyParameters map[string]*oas.Parameter
	// RequestBody comes from a "body" or "formData" field, if any.
	RequestBody       *oas.RequestBody

	// StatusErrors are the status errors found in the Output method.
	StatusErrors      status_error.StatusErrorCodeMap
	// StatusErrorSchema is the content schema of every error response.
	StatusErrorSchema *oas.Schema

	// Tag is derived from the directory of the request type's package.
	Tag               string
	// Summary is the request type's doc comment. Its first line becomes the
	// operation summary and the remaining lines the description.
	Summary           string
	// SuccessType is the non-nil first result type of the Output method.
	SuccessType       types.Type
	// SuccessStatus is the status declared by the success type's Status
	// method; 0 lets BindOperation pick one from the HTTP method.
	SuccessStatus     int
	// SuccessResponse is the response built for SuccessType.
	SuccessResponse   *oas.Response
	// WebSocketMessages maps each incoming message schema to the schemas of
	// the messages sent via WSClient.Send in its handler. It is only set for
	// WebSocket endpoints.
	WebSocketMessages map[*oas.Schema][]*oas.Schema
}
   392  
   393  func (operator *Operator) AddWebSocketMessage(schema *oas.Schema, returns ...*oas.Schema) {
   394  	if operator.WebSocketMessages == nil {
   395  		operator.WebSocketMessages = map[*oas.Schema][]*oas.Schema{}
   396  	}
   397  	operator.WebSocketMessages[schema] = append(operator.WebSocketMessages[schema], returns...)
   398  }
   399  
   400  func (operator *Operator) AddNonBodyParameter(parameter *oas.Parameter) {
   401  	if operator.NonBodyParameters == nil {
   402  		operator.NonBodyParameters = map[string]*oas.Parameter{}
   403  	}
   404  	operator.NonBodyParameters[parameter.Name] = parameter
   405  }
   406  
// SetRequestBody sets the operator's request body, replacing any previous one.
func (operator *Operator) SetRequestBody(requestBody *oas.RequestBody) {
	operator.RequestBody = requestBody
}
   410  
   411  func (operator *Operator) BindOperation(method string, operation *oas.Operation, last bool) {
   412  	if operator.WebSocketMessages != nil {
   413  		schema := oas.ObjectOf(nil)
   414  
   415  		for msgSchema, list := range operator.WebSocketMessages {
   416  			s := oas.ObjectOf(nil)
   417  
   418  			s.SetProperty(typeOfSchema(msgSchema), msgSchema, false)
   419  
   420  			if list != nil {
   421  				sub := oas.ObjectOf(nil)
   422  				for _, item := range list {
   423  					sub.SetProperty(typeOfSchema(item), item, false)
   424  				}
   425  				schema.SetProperty("out", sub, false)
   426  			}
   427  			schema.SetProperty("in", s, false)
   428  		}
   429  
   430  		requestBody := oas.NewRequestBody("WebSocket", true)
   431  		requestBody.AddContent(httpx.MIMEJSON, oas.NewMediaTypeWithSchema(schema))
   432  
   433  		operation.SetRequestBody(requestBody)
   434  		return
   435  	}
   436  
   437  	parameterNames := map[string]bool{}
   438  	for _, parameter := range operation.Parameters {
   439  		parameterNames[parameter.Name] = true
   440  	}
   441  
   442  	for _, parameter := range operator.NonBodyParameters {
   443  		if !parameterNames[parameter.Name] {
   444  			operation.Parameters = append(operation.Parameters, parameter)
   445  		}
   446  	}
   447  
   448  	if operator.RequestBody != nil {
   449  		operation.SetRequestBody(operator.RequestBody)
   450  	}
   451  
   452  	for code, statusError := range operator.StatusErrors {
   453  		resp := (*oas.Response)(nil)
   454  		if operation.Responses.Responses != nil {
   455  			resp = operation.Responses.Responses[statusError.Status()]
   456  		}
   457  		statusErrors := status_error.StatusErrorCodeMap{}
   458  		if resp != nil {
   459  			statusErrors = pickStatusErrorsFromDoc(resp.Description)
   460  		}
   461  		statusErrors[code] = statusError
   462  		resp = oas.NewResponse(statusErrors.String())
   463  		resp.AddContent(httpx.MIMEJSON, oas.NewMediaTypeWithSchema(operator.StatusErrorSchema))
   464  		operation.AddResponse(statusError.Status(), resp)
   465  	}
   466  
   467  	if last {
   468  		operation.OperationId = operator.ID
   469  		docs := strings.Split(operator.Summary, "\n")
   470  		if operator.Tag != "" {
   471  			operation.Tags = []string{operator.Tag}
   472  		}
   473  		operation.Summary = docs[0]
   474  		if len(docs) > 1 {
   475  			operation.Description = strings.Join(docs[1:], "\n")
   476  		}
   477  		if operator.SuccessType == nil {
   478  			operation.Responses.AddResponse(http.StatusNoContent, &oas.Response{})
   479  		} else {
   480  			status := operator.SuccessStatus
   481  			if status == 0 {
   482  				status = http.StatusOK
   483  				if method == http.MethodPost {
   484  					status = http.StatusCreated
   485  				}
   486  			}
   487  			if status >= http.StatusMultipleChoices && status < http.StatusBadRequest {
   488  				operator.SuccessResponse = oas.NewResponse(operator.SuccessResponse.Description)
   489  			}
   490  			operation.Responses.AddResponse(status, operator.SuccessResponse)
   491  		}
   492  	}
   493  }
   494  
   495  func typeOfSchema(schema *oas.Schema) string {
   496  	l := strings.Split(schema.Refer.RefString(), "/")
   497  	return l[len(l)-1]
   498  }