github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/courier/swagger/gen/operator_scanner.go (about)

     1  package gen
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/build"
     7  	"go/types"
     8  	"net/http"
     9  	"os"
    10  	"path/filepath"
    11  	"reflect"
    12  	"runtime/debug"
    13  	"strings"
    14  
    15  	"github.com/johnnyeven/libtools/courier/transport_http/transform"
    16  
    17  	"github.com/morlay/oas"
    18  	"github.com/sirupsen/logrus"
    19  	"golang.org/x/tools/go/loader"
    20  
    21  	"github.com/johnnyeven/libtools/codegen/loaderx"
    22  	"github.com/johnnyeven/libtools/courier/httpx"
    23  	"github.com/johnnyeven/libtools/courier/status_error"
    24  	"github.com/johnnyeven/libtools/courier/transport_http"
    25  )
    26  
    27  func FullNameOfType(tpe reflect.Type) string {
    28  	return fmt.Sprintf("%s.%s", tpe.PkgPath(), tpe.Name())
    29  }
    30  
    31  var TypeWebSocketListeners = FullNameOfType(reflect.TypeOf(transport_http.Listeners{}))
    32  var TypeWebSocketClient = FullNameOfType(reflect.TypeOf(transport_http.WSClient{}))
    33  
    34  func ConcatToOperation(method string, operators ...Operator) *oas.Operation {
    35  	operation := &oas.Operation{}
    36  	length := len(operators)
    37  	for idx, operator := range operators {
    38  		operator.BindOperation(method, operation, idx == length-1)
    39  	}
    40  	return operation
    41  }
    42  
    43  func NewOperatorScanner(program *loader.Program) *OperatorScanner {
    44  	return &OperatorScanner{
    45  		DefinitionScanner:  NewDefinitionScanner(program),
    46  		StatusErrorScanner: NewStatusErrorScanner(program),
    47  		program:            program,
    48  	}
    49  }
    50  
// OperatorScanner builds Operator descriptions from type names of a loaded
// program.
//
// It embeds a DefinitionScanner (used to obtain schemas) and a
// StatusErrorScanner (used to collect status errors). program is the loaded
// program being inspected, and operators caches the results of Operator,
// keyed by type name; the map is allocated lazily.
type OperatorScanner struct {
	*DefinitionScanner
	*StatusErrorScanner
	program   *loader.Program
	operators map[*types.TypeName]Operator
}
    57  
    58  func (scanner *OperatorScanner) Operator(typeName *types.TypeName) *Operator {
    59  	if typeName == nil {
    60  		return nil
    61  	}
    62  
    63  	if operator, ok := scanner.operators[typeName]; ok {
    64  		return &operator
    65  	}
    66  
    67  	defer func() {
    68  		if e := recover(); e != nil {
    69  			logrus.Errorf("scan Operator `%v` failed, panic: %s; calltrace: %s", typeName, fmt.Sprint(e), string(debug.Stack()))
    70  		}
    71  	}()
    72  
    73  	if typeStruct, ok := typeName.Type().Underlying().(*types.Struct); ok {
    74  		operator := Operator{
    75  			ID:  typeName.Name(),
    76  			Tag: getTagNameByPkgPath(typeName.Pkg().Path()),
    77  		}
    78  
    79  		scanner.bindParameterOrRequestBody(&operator, typeStruct)
    80  		scanner.bindReturns(&operator, typeName)
    81  
    82  		if scanner.operators == nil {
    83  			scanner.operators = map[*types.TypeName]Operator{}
    84  		}
    85  
    86  		operator.Summary = docOfTypeName(typeName, scanner.program)
    87  
    88  		scanner.operators[typeName] = operator
    89  
    90  		return &operator
    91  	}
    92  
    93  	return nil
    94  }
    95  
    96  func getTagNameByPkgPath(pkgPath string) string {
    97  	cwd, _ := os.Getwd()
    98  	p, _ := build.Default.Import(pkgPath, "", build.FindOnly)
    99  	tag, _ := filepath.Rel(cwd, p.Dir)
   100  	i := strings.Index(tag, "routes/")
   101  	if i >= 0 {
   102  		tag = string([]byte(tag)[i:])
   103  	}
   104  	return strings.Replace(tag, "routes/", "", 1)
   105  }
   106  
   107  func (scanner *OperatorScanner) bindWebSocketMessages(op *Operator, schema *oas.Schema, typeVar *types.Var) {
   108  	if strings.Contains(typeVar.Type().String(), TypeWebSocketClient) {
   109  		for pkg, pkgInfo := range scanner.program.AllPackages {
   110  			if pkg == typeVar.Pkg() {
   111  				for selectExpr := range pkgInfo.Selections {
   112  					if ident, ok := selectExpr.X.(*ast.Ident); ok {
   113  						if pkgInfo.ObjectOf(ident) == typeVar && "Send" == selectExpr.Sel.Name {
   114  							file := loaderx.FileOf(selectExpr, pkgInfo.Files...)
   115  							ast.Inspect(file, func(node ast.Node) bool {
   116  								switch node.(type) {
   117  								case *ast.CallExpr:
   118  									callExpr := node.(*ast.CallExpr)
   119  									if callExpr.Fun == selectExpr {
   120  										tpe := pkgInfo.TypeOf(callExpr.Args[0])
   121  										subSchema := scanner.getSchemaByType(tpe.(*types.Named))
   122  										op.AddWebSocketMessage(schema, subSchema)
   123  										return false
   124  									}
   125  								}
   126  								return true
   127  							})
   128  						}
   129  					}
   130  				}
   131  			}
   132  		}
   133  	}
   134  }
   135  
   136  func (scanner *OperatorScanner) bindWebSocketListeners(op *Operator, typeFunc *types.Func) {
   137  	scope := typeFunc.Scope()
   138  	for _, name := range scope.Names() {
   139  		n := scope.Lookup(name)
   140  		if strings.Contains(n.Type().String(), TypeWebSocketListeners) {
   141  			for pkg, pkgInfo := range scanner.program.AllPackages {
   142  				if pkg == n.Pkg() {
   143  					for selectExpr := range pkgInfo.Selections {
   144  						if ident, ok := selectExpr.X.(*ast.Ident); ok {
   145  							if pkgInfo.ObjectOf(ident) == n && "On" == selectExpr.Sel.Name {
   146  								file := loaderx.FileOf(selectExpr, pkgInfo.Files...)
   147  								ast.Inspect(file, func(node ast.Node) bool {
   148  									switch node.(type) {
   149  									case *ast.CallExpr:
   150  										callExpr := node.(*ast.CallExpr)
   151  										if callExpr.Fun == selectExpr {
   152  											tpe := pkgInfo.TypeOf(callExpr.Args[0])
   153  											schema := scanner.getSchemaByType(tpe.(*types.Named))
   154  											op.AddWebSocketMessage(schema)
   155  
   156  											params := pkgInfo.TypeOf(callExpr.Args[1]).(*types.Signature).Params()
   157  
   158  											for i := 0; i < params.Len(); i++ {
   159  												scanner.bindWebSocketMessages(op, schema, params.At(i))
   160  											}
   161  											return false
   162  										}
   163  									}
   164  									return true
   165  								})
   166  							}
   167  						}
   168  					}
   169  				}
   170  			}
   171  		}
   172  	}
   173  }
   174  
   175  func (scanner *OperatorScanner) bindReturns(op *Operator, typeName *types.TypeName) {
   176  	typeFunc := loaderx.MethodOf(typeName.Type().(*types.Named), "Output")
   177  
   178  	if typeFunc != nil {
   179  		metaData := ParseSuccessMetadata(docOfTypeName(typeFunc, scanner.program))
   180  
   181  		loaderx.ForEachFuncResult(scanner.program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) {
   182  			successType := resultTypeAndValues[0].Type
   183  
   184  			if strings.Contains(successType.String(), TypeWebSocketListeners) {
   185  				scanner.bindWebSocketListeners(op, typeFunc)
   186  				return
   187  			}
   188  
   189  			if successType.String() != types.Typ[types.UntypedNil].String() {
   190  				if op.SuccessType != nil && op.SuccessType.String() != successType.String() {
   191  					logrus.Warnf(fmt.Sprintf("%s success result must be same struct, but got %v, already set %v", op.ID, successType, op.SuccessType))
   192  				}
   193  				op.SuccessType = successType
   194  				op.SuccessStatus, op.SuccessResponse = scanner.getResponse(successType, metaData.Get("content-type"))
   195  			}
   196  
   197  			op.StatusErrors = scanner.StatusErrorScanner.StatusErrorsInFunc(typeFunc)
   198  			op.StatusErrorSchema = scanner.DefinitionScanner.getSchemaByTypeString(statusErrorTypeString)
   199  		})
   200  	}
   201  }
   202  
   203  func (scanner *OperatorScanner) getResponse(tpe types.Type, contentType string) (status int, response *oas.Response) {
   204  	response = &oas.Response{}
   205  
   206  	if tpe.String() == "error" {
   207  		status = http.StatusNoContent
   208  		return
   209  	}
   210  
   211  	if contentType == "" {
   212  		contentType = httpx.MIMEJSON
   213  	}
   214  
   215  	if pointer, ok := tpe.(*types.Pointer); ok {
   216  		tpe = pointer.Elem()
   217  	}
   218  
   219  	if named, ok := tpe.(*types.Named); ok {
   220  		{
   221  			typeFunc := loaderx.MethodOf(named, "ContentType")
   222  			if typeFunc != nil {
   223  				loaderx.ForEachFuncResult(scanner.program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) {
   224  					if resultTypeAndValues[0].IsValue() {
   225  						contentType = getConstVal(resultTypeAndValues[0].Value).(string)
   226  					}
   227  				})
   228  			}
   229  		}
   230  
   231  		{
   232  			typeFunc := loaderx.MethodOf(named, "Status")
   233  			if typeFunc != nil {
   234  				loaderx.ForEachFuncResult(scanner.program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) {
   235  					if resultTypeAndValues[0].IsValue() {
   236  						status = int(getConstVal(resultTypeAndValues[0].Value).(int64))
   237  					}
   238  				})
   239  			}
   240  		}
   241  	}
   242  
   243  	response.AddContent(contentType, oas.NewMediaTypeWithSchema(scanner.DefinitionScanner.getSchemaByType(tpe)))
   244  
   245  	return
   246  }
   247  
   248  func (scanner *OperatorScanner) bindParameterOrRequestBody(op *Operator, typeStruct *types.Struct) {
   249  	for i := 0; i < typeStruct.NumFields(); i++ {
   250  		var field = typeStruct.Field(i)
   251  
   252  		if !field.Exported() {
   253  			continue
   254  		}
   255  
   256  		var fieldType = field.Type()
   257  		var fieldName = field.Name()
   258  		var structFieldTags = reflect.StructTag(typeStruct.Tag(i))
   259  
   260  		location, locationFlags := getTagNameAndFlags(structFieldTags.Get("in"))
   261  
   262  		if location == "" {
   263  			if fieldName == "Body" {
   264  				location = "body"
   265  			}
   266  		}
   267  
   268  		if location == "context" {
   269  			continue
   270  		}
   271  
   272  		if field.Anonymous() {
   273  			if typeStruct, ok := fieldType.Underlying().(*types.Struct); ok {
   274  				scanner.bindParameterOrRequestBody(op, typeStruct)
   275  			}
   276  			continue
   277  		}
   278  
   279  		if location == "" {
   280  			panic(fmt.Errorf("missing tag `in` for %s of %s", fieldName, op.ID))
   281  		}
   282  
   283  		name, flags := getTagNameAndFlags(structFieldTags.Get("name"))
   284  		if name == "" {
   285  			name, flags = getTagNameAndFlags(structFieldTags.Get("json"))
   286  		}
   287  
   288  		var param *oas.Parameter
   289  
   290  		if location == "body" || location == "formData" {
   291  			op.SetRequestBody(scanner.getRequestBody(fieldType, location, locationFlags["multipart"]))
   292  			continue
   293  		}
   294  
   295  		if name == "" {
   296  			panic(fmt.Errorf("missing tag `name` or `json` for parameter %s of %s", fieldName, op.ID))
   297  		}
   298  
   299  		param = scanner.getNonBodyParameter(name, flags, location, structFieldTags, fieldType)
   300  
   301  		if param.Schema != nil && flags != nil && flags["string"] {
   302  			param.Schema.Type = oas.TypeString
   303  		}
   304  
   305  		if styleValue, hasStyle := structFieldTags.Lookup("style"); hasStyle {
   306  			param.AddExtension(XTagStyle, styleValue)
   307  		}
   308  
   309  		if fmtValue, hasFmt := structFieldTags.Lookup("fmt"); hasFmt {
   310  			param.AddExtension(XTagFmt, fmtValue)
   311  		}
   312  
   313  		param = param.WithDesc(docOfTypeName(field, scanner.program))
   314  		param.AddExtension(XField, field.Name())
   315  		op.AddNonBodyParameter(param)
   316  	}
   317  }
   318  
   319  func (scanner *OperatorScanner) getRequestBody(t types.Type, location string, isMultipart bool) *oas.RequestBody {
   320  	reqBody := oas.NewRequestBody("", true)
   321  	schema := scanner.DefinitionScanner.getSchemaByType(t)
   322  
   323  	contentType := httpx.MIMEJSON
   324  
   325  	if location == "formData" {
   326  		if isMultipart {
   327  			contentType = httpx.MIMEMultipartPOSTForm
   328  		} else {
   329  			contentType = httpx.MIMEPOSTForm
   330  		}
   331  	}
   332  
   333  	reqBody.Required = true
   334  	reqBody.AddContent(contentType, oas.NewMediaTypeWithSchema(schema))
   335  	return reqBody
   336  }
   337  
   338  func (scanner *OperatorScanner) getNonBodyParameter(name string, nameFlags transform.TagFlags, location string, tags reflect.StructTag, t types.Type) *oas.Parameter {
   339  	schema := scanner.DefinitionScanner.getSchemaByType(t)
   340  
   341  	defaultValue, hasDefault := tags.Lookup("default")
   342  	if hasDefault {
   343  		schema.Default = defaultValue
   344  	}
   345  
   346  	required := true
   347  	if hasOmitempty, ok := nameFlags["omitempty"]; ok {
   348  		required = !hasOmitempty
   349  	} else {
   350  		// todo don't use non-default as required
   351  		required = !hasDefault
   352  	}
   353  
   354  	validate, hasValidate := tags.Lookup("validate")
   355  	if hasValidate {
   356  		BindValidateFromValidateTagString(schema, validate)
   357  	}
   358  
   359  	if schema != nil && schema.Ref != "" {
   360  		schema = oas.AllOf(
   361  			schema,
   362  			&oas.Schema{
   363  				SchemaObject:   schema.SchemaObject,
   364  				SpecExtensions: schema.SpecExtensions,
   365  			},
   366  		)
   367  	}
   368  
   369  	switch location {
   370  	case "query":
   371  		return oas.QueryParameter(name, schema, required)
   372  	case "cookie":
   373  		return oas.CookieParameter(name, schema, required)
   374  	case "header":
   375  		return oas.HeaderParameter(name, schema, required)
   376  	case "path":
   377  		return oas.PathParameter(name, schema)
   378  	}
   379  	return nil
   380  }
   381  
// Operator is the scanned description of a single operator type, ready to be
// bound onto an oas.Operation with BindOperation.
//
// ID, Tag and Summary are filled in by OperatorScanner.Operator from the
// type name, the package path and the type's doc. NonBodyParameters is keyed
// by parameter name and RequestBody holds the request body, if any.
// StatusErrors and StatusErrorSchema describe the error responses, while
// SuccessType, SuccessStatus and SuccessResponse describe the success
// response. WebSocketMessages maps the schema of each registered message to
// the schemas recorded as its returns; a non-nil map makes BindOperation
// describe the operator as a WebSocket endpoint.
type Operator struct {
	ID                string
	NonBodyParameters map[string]*oas.Parameter
	RequestBody       *oas.RequestBody

	StatusErrors      status_error.StatusErrorCodeMap
	StatusErrorSchema *oas.Schema

	Tag               string
	Summary           string
	SuccessType       types.Type
	SuccessStatus     int
	SuccessResponse   *oas.Response
	WebSocketMessages map[*oas.Schema][]*oas.Schema
}
   397  
   398  func (operator *Operator) AddWebSocketMessage(schema *oas.Schema, returns ...*oas.Schema) {
   399  	if operator.WebSocketMessages == nil {
   400  		operator.WebSocketMessages = map[*oas.Schema][]*oas.Schema{}
   401  	}
   402  	operator.WebSocketMessages[schema] = append(operator.WebSocketMessages[schema], returns...)
   403  }
   404  
   405  func (operator *Operator) AddNonBodyParameter(parameter *oas.Parameter) {
   406  	if operator.NonBodyParameters == nil {
   407  		operator.NonBodyParameters = map[string]*oas.Parameter{}
   408  	}
   409  	operator.NonBodyParameters[parameter.Name] = parameter
   410  }
   411  
// SetRequestBody sets the request body of the operator, replacing any
// previously set one.
func (operator *Operator) SetRequestBody(requestBody *oas.RequestBody) {
	operator.RequestBody = requestBody
}
   415  
   416  func (operator *Operator) BindOperation(method string, operation *oas.Operation, last bool) {
   417  	if operator.WebSocketMessages != nil {
   418  		schema := oas.ObjectOf(nil)
   419  
   420  		for msgSchema, list := range operator.WebSocketMessages {
   421  			s := oas.ObjectOf(nil)
   422  
   423  			s.SetProperty(typeOfSchema(msgSchema), msgSchema, false)
   424  
   425  			if list != nil {
   426  				sub := oas.ObjectOf(nil)
   427  				for _, item := range list {
   428  					sub.SetProperty(typeOfSchema(item), item, false)
   429  				}
   430  				schema.SetProperty("out", sub, false)
   431  			}
   432  			schema.SetProperty("in", s, false)
   433  		}
   434  
   435  		requestBody := oas.NewRequestBody("WebSocket", true)
   436  		requestBody.AddContent(httpx.MIMEJSON, oas.NewMediaTypeWithSchema(schema))
   437  
   438  		operation.SetRequestBody(requestBody)
   439  		return
   440  	}
   441  
   442  	parameterNames := map[string]bool{}
   443  	for _, parameter := range operation.Parameters {
   444  		parameterNames[parameter.Name] = true
   445  	}
   446  
   447  	for _, parameter := range operator.NonBodyParameters {
   448  		if !parameterNames[parameter.Name] {
   449  			operation.Parameters = append(operation.Parameters, parameter)
   450  		}
   451  	}
   452  
   453  	if operator.RequestBody != nil {
   454  		operation.SetRequestBody(operator.RequestBody)
   455  	}
   456  
   457  	for code, statusError := range operator.StatusErrors {
   458  		resp := (*oas.Response)(nil)
   459  		if operation.Responses.Responses != nil {
   460  			resp = operation.Responses.Responses[statusError.Status()]
   461  		}
   462  		statusErrors := status_error.StatusErrorCodeMap{}
   463  		if resp != nil {
   464  			statusErrors = pickStatusErrorsFromDoc(resp.Description)
   465  		}
   466  		statusErrors[code] = statusError
   467  		resp = oas.NewResponse(statusErrors.String())
   468  		resp.AddContent(httpx.MIMEJSON, oas.NewMediaTypeWithSchema(operator.StatusErrorSchema))
   469  		operation.AddResponse(statusError.Status(), resp)
   470  	}
   471  
   472  	if last {
   473  		operation.OperationId = operator.ID
   474  		docs := strings.Split(operator.Summary, "\n")
   475  		if operator.Tag != "" {
   476  			operation.Tags = []string{operator.Tag}
   477  		}
   478  		operation.Summary = docs[0]
   479  		if len(docs) > 1 {
   480  			operation.Description = strings.Join(docs[1:], "\n")
   481  		}
   482  		if operator.SuccessType == nil {
   483  			operation.Responses.AddResponse(http.StatusNoContent, &oas.Response{})
   484  		} else {
   485  			status := operator.SuccessStatus
   486  			if status == 0 {
   487  				status = http.StatusOK
   488  				if method == http.MethodPost {
   489  					status = http.StatusCreated
   490  				}
   491  			}
   492  			if status >= http.StatusMultipleChoices && status < http.StatusBadRequest {
   493  				operator.SuccessResponse = oas.NewResponse(operator.SuccessResponse.Description)
   494  			}
   495  			operation.Responses.AddResponse(status, operator.SuccessResponse)
   496  		}
   497  	}
   498  }
   499  
   500  func typeOfSchema(schema *oas.Schema) string {
   501  	l := strings.Split(schema.Ref, "/")
   502  	return l[len(l)-1]
   503  }