github.com/profzone/eden-framework@v1.0.10/internal/generator/scanner/operator_scanner.go (about)

     1  package scanner
     2  
     3  import (
     4  	"fmt"
     5  	"github.com/go-courier/oas"
     6  	"github.com/profzone/eden-framework/pkg/courier/httpx"
     7  	"github.com/profzone/eden-framework/pkg/courier/status_error"
     8  	"github.com/profzone/eden-framework/pkg/courier/transport_http"
     9  	"github.com/profzone/eden-framework/pkg/packagex"
    10  	"github.com/profzone/eden-framework/pkg/reflectx"
    11  	"github.com/sirupsen/logrus"
    12  	"go/ast"
    13  	"go/types"
    14  	"net/http"
    15  	"reflect"
    16  	"runtime/debug"
    17  	"sort"
    18  	"strconv"
    19  	"strings"
    20  )
    21  
// OperatorScanner extracts Operator descriptions (route metadata, parameters,
// request body, success response and status errors) from struct types of a
// package. Results of Operator are cached per *types.TypeName.
type OperatorScanner struct {
	// DefinitionScanner resolves schemas for field, response and status-error types.
	*DefinitionScanner
	// StatusErrScanner collects the status errors returned from Output methods.
	*StatusErrScanner
	// pkg is the scanned package; its import path is also the base for tags.
	pkg       *packagex.Package
	// operators caches scanned operators; allocated lazily by Operator.
	operators map[*types.TypeName]*Operator
}
    28  
    29  func NewOperatorScanner(pkg *packagex.Package) *OperatorScanner {
    30  	return &OperatorScanner{
    31  		pkg:               pkg,
    32  		DefinitionScanner: NewDefinitionScanner(pkg),
    33  		StatusErrScanner:  NewStatusErrScanner(pkg),
    34  	}
    35  }
    36  
    37  func (scanner *OperatorScanner) Operator(typeName *types.TypeName) *Operator {
    38  	if typeName == nil {
    39  		return nil
    40  	}
    41  
    42  	if operator, ok := scanner.operators[typeName]; ok {
    43  		return operator
    44  	}
    45  
    46  	logrus.Debugf("scanning Operator `%s.%s`", typeName.Pkg().Path(), typeName.Name())
    47  
    48  	defer func() {
    49  		if e := recover(); e != nil {
    50  			panic(fmt.Errorf("scan Operator `%s` failed, panic: %s; calltrace: %s", fullTypeName(typeName), fmt.Sprint(e), string(debug.Stack())))
    51  		}
    52  	}()
    53  
    54  	if typeStruct, ok := typeName.Type().Underlying().(*types.Struct); ok {
    55  		operator := &Operator{}
    56  
    57  		operator.Tag = scanner.tagFrom(typeName.Pkg().Path())
    58  
    59  		scanner.scanRouteMeta(operator, typeName)
    60  		scanner.scanParameterOrRequestBody(operator, typeStruct)
    61  		scanner.scanReturns(operator, typeName)
    62  
    63  		// cached scanned
    64  		if scanner.operators == nil {
    65  			scanner.operators = map[*types.TypeName]*Operator{}
    66  		}
    67  
    68  		scanner.operators[typeName] = operator
    69  
    70  		return operator
    71  	}
    72  
    73  	return nil
    74  }
    75  
    76  func (scanner *OperatorScanner) singleReturnOf(typeName *types.TypeName, name string) (string, bool) {
    77  	if typeName == nil {
    78  		return "", false
    79  	}
    80  
    81  	for _, typ := range []types.Type{
    82  		typeName.Type(),
    83  		types.NewPointer(typeName.Type()),
    84  	} {
    85  		method, ok := reflectx.FromTType(typ).MethodByName(name)
    86  		if ok {
    87  			results, n := scanner.pkg.FuncResultsOf(method.(*reflectx.TMethod).Func)
    88  			if n == 1 {
    89  				for _, v := range results[0] {
    90  					if v.Value != nil {
    91  						s, err := strconv.Unquote(v.Value.ExactString())
    92  						if err != nil {
    93  							panic(fmt.Errorf("%s: %s", err, v.Value))
    94  						}
    95  						return s, true
    96  					}
    97  				}
    98  			}
    99  		}
   100  	}
   101  
   102  	return "", false
   103  }
   104  
   105  func (scanner *OperatorScanner) tagFrom(pkgPath string) string {
   106  	tag := strings.TrimPrefix(pkgPath, scanner.pkg.PkgPath)
   107  	return strings.TrimPrefix(tag, "/")
   108  }
   109  
   110  func (scanner *OperatorScanner) scanRouteMeta(op *Operator, typeName *types.TypeName) {
   111  	typeStruct := typeName.Type().Underlying().(*types.Struct)
   112  
   113  	op.ID = typeName.Name()
   114  
   115  	for i := 0; i < typeStruct.NumFields(); i++ {
   116  		f := typeStruct.Field(i)
   117  		tags := reflect.StructTag(typeStruct.Tag(i))
   118  
   119  		if f.Anonymous() && strings.Contains(f.Type().String(), pkgImportPathHttpx+".Method") {
   120  			if path, ok := tags.Lookup("path"); ok {
   121  				vs := strings.Split(path, ",")
   122  				op.Path = vs[0]
   123  
   124  				if len(vs) > 0 {
   125  					for i := range vs {
   126  						switch vs[i] {
   127  						case "deprecated":
   128  							op.Deprecated = true
   129  							break
   130  						}
   131  					}
   132  				}
   133  			}
   134  
   135  			if basePath, ok := tags.Lookup("basePath"); ok {
   136  				op.BasePath = basePath
   137  			}
   138  
   139  			if summary, ok := tags.Lookup("summary"); ok {
   140  				op.Summary = summary
   141  			}
   142  
   143  			break
   144  		}
   145  	}
   146  
   147  	lines := scanner.pkg.CommentsOf(scanner.pkg.IdentOf(typeName))
   148  	comments := strings.Split(lines, "\n")
   149  
   150  	for i := range comments {
   151  		if strings.Index(comments[i], "@deprecated") != -1 {
   152  			op.Deprecated = true
   153  		}
   154  	}
   155  
   156  	if op.Summary == "" {
   157  		comments = filterMarkedLines(comments)
   158  
   159  		if comments[0] != "" {
   160  			op.Summary = comments[0]
   161  			if len(comments) > 1 {
   162  				op.Description = strings.Join(comments[1:], "\n")
   163  			}
   164  		}
   165  	}
   166  
   167  	if method, ok := scanner.singleReturnOf(typeName, "Method"); ok {
   168  		op.Method = method
   169  	}
   170  
   171  	if path, ok := scanner.singleReturnOf(typeName, "Path"); ok {
   172  		op.Path = path
   173  	}
   174  
   175  	if bathPath, ok := scanner.singleReturnOf(typeName, "BasePath"); ok {
   176  		op.BasePath = bathPath
   177  	}
   178  }
   179  
   180  func (scanner *OperatorScanner) scanReturns(op *Operator, typeName *types.TypeName) {
   181  	for _, typ := range []types.Type{
   182  		typeName.Type(),
   183  		types.NewPointer(typeName.Type()),
   184  	} {
   185  		method, ok := reflectx.FromTType(typ).MethodByName("Output")
   186  		if ok {
   187  			results, n := scanner.pkg.FuncResultsOf(method.(*reflectx.TMethod).Func)
   188  			if n == 2 {
   189  				for _, v := range results[0] {
   190  					if v.Type != nil {
   191  						if v.Type.String() != types.Typ[types.UntypedNil].String() {
   192  							if op.SuccessType != nil && op.SuccessType.String() != v.Type.String() {
   193  								logrus.Warnf(fmt.Sprintf("%s success result must be same struct, but got %v, already set %v", op.ID, v.Type, op.SuccessType))
   194  							}
   195  							op.SuccessType = v.Type
   196  							op.SuccessStatus, op.SuccessResponse = scanner.getResponse(v.Type, v.Expr)
   197  						}
   198  					}
   199  				}
   200  			}
   201  
   202  			if scanner.StatusErrScanner.StatusErrType != nil {
   203  				op.StatusErrors = scanner.StatusErrScanner.StatusErrorsInFunc(method.(*reflectx.TMethod).Func)
   204  				op.StatusErrorSchema = scanner.DefinitionScanner.GetSchemaByType(scanner.StatusErrScanner.StatusErrType)
   205  			}
   206  		}
   207  	}
   208  }
   209  
   210  func (scanner *OperatorScanner) firstValueOfFunc(named *types.Named, name string) (interface{}, bool) {
   211  	method, ok := reflectx.FromTType(types.NewPointer(named)).MethodByName(name)
   212  	if ok {
   213  		results, n := scanner.pkg.FuncResultsOf(method.(*reflectx.TMethod).Func)
   214  		if n == 1 {
   215  			for _, r := range results[0] {
   216  				if r.IsValue() {
   217  					if v := valueOf(r.Value); v != nil {
   218  						return v, true
   219  					}
   220  				}
   221  			}
   222  			return nil, true
   223  		}
   224  	}
   225  	return nil, false
   226  }
   227  
   228  func (scanner *OperatorScanner) getResponse(tpe types.Type, expr ast.Expr) (statusCode int, response *oas.Response) {
   229  	response = &oas.Response{}
   230  
   231  	if tpe.String() == "error" {
   232  		statusCode = http.StatusNoContent
   233  		return
   234  	}
   235  
   236  	var contentType string
   237  
   238  	if pointer, ok := tpe.(*types.Pointer); ok {
   239  		tpe = pointer.Elem()
   240  	}
   241  
   242  	if named, ok := tpe.(*types.Named); ok {
   243  		if v, ok := scanner.firstValueOfFunc(named, "ContentType"); ok {
   244  			if s, ok := v.(string); ok {
   245  				contentType = s
   246  			}
   247  			if contentType == "" {
   248  				contentType = "*"
   249  			}
   250  		}
   251  		if v, ok := scanner.firstValueOfFunc(named, "StatusCode"); ok {
   252  			if i, ok := v.(int64); ok {
   253  				statusCode = int(i)
   254  			}
   255  		}
   256  	}
   257  
   258  	if contentType == "" {
   259  		contentType = httpx.MIME_JSON
   260  	}
   261  
   262  	response.AddContent(contentType, oas.NewMediaTypeWithSchema(scanner.DefinitionScanner.GetSchemaByType(tpe)))
   263  
   264  	return
   265  }
   266  
   267  func (scanner *OperatorScanner) scanParameterOrRequestBody(op *Operator, typeStruct *types.Struct) {
   268  	reflectx.EachField(reflectx.FromTType(typeStruct), "name", func(field reflectx.StructField, fieldDisplayName string, omitempty bool) bool {
   269  		location, _ := tagValueAndFlagsByTagString(field.Tag().Get("in"))
   270  
   271  		if location == "" {
   272  			panic(fmt.Errorf("missing tag `in` for %s of %s", field.Name(), op.ID))
   273  		}
   274  
   275  		name, flags := tagValueAndFlagsByTagString(field.Tag().Get("name"))
   276  
   277  		schema := scanner.DefinitionScanner.propSchemaByField(
   278  			field.Name(),
   279  			field.Type().(*reflectx.TType).Type,
   280  			field.Tag(),
   281  			name,
   282  			flags,
   283  			scanner.pkg.CommentsOf(scanner.pkg.IdentOf(field.(*reflectx.TStructField).Var)),
   284  		)
   285  
   286  		//transformer, err := transform.TransformerMgrDefault.NewTransformer(nil, field.Type(), transform.TransformerOption{
   287  		//	MIME: field.Tag().Get("mime"),
   288  		//})
   289  
   290  		//if err != nil {
   291  		//	panic(err)
   292  		//}
   293  
   294  		switch location {
   295  		case "body":
   296  			reqBody := oas.NewRequestBody("", true)
   297  			// TODO
   298  			reqBody.AddContent("application/json", oas.NewMediaTypeWithSchema(schema))
   299  			op.SetRequestBody(reqBody)
   300  		case "query":
   301  			op.AddNonBodyParameter(oas.QueryParameter(fieldDisplayName, schema, !omitempty))
   302  		case "cookie":
   303  			op.AddNonBodyParameter(oas.CookieParameter(fieldDisplayName, schema, !omitempty))
   304  		case "header":
   305  			op.AddNonBodyParameter(oas.HeaderParameter(fieldDisplayName, schema, !omitempty))
   306  		case "path":
   307  			op.AddNonBodyParameter(oas.PathParameter(fieldDisplayName, schema))
   308  		}
   309  
   310  		return true
   311  	}, "in")
   312  }
   313  
// Operator is the scanned description of one HTTP operation. It is produced by
// OperatorScanner.Operator and rendered into an OpenAPI operation by
// BindOperation.
type Operator struct {
	// RouteMeta supplies ID, Method, Path, BasePath, Summary and Deprecated,
	// which the scanner fills in from struct tags, comments and methods.
	transport_http.RouteMeta

	// Tag is derived from the operator's package path relative to the
	// scanned package ("" for the scanned package itself).
	Tag         string
	// Description holds the comment lines after the summary line; it is set
	// only when the summary is taken from comments.
	Description string

	// NonBodyParameters holds query/cookie/header/path parameters keyed by
	// parameter name only (the same name in two locations collides).
	NonBodyParameters map[string]*oas.Parameter
	// RequestBody is built from the field tagged in:"body", if any.
	RequestBody       *oas.RequestBody

	// StatusErrors are the status errors found in the Output method.
	StatusErrors      []*status_error.StatusError
	// StatusErrorSchema is the content schema of every status-error response.
	StatusErrorSchema *oas.Schema

	// SuccessStatus comes from the result type's StatusCode method; 0 lets
	// BindOperation choose 200 (201 for POST).
	SuccessStatus   int
	// SuccessType is the result type of Output; nil yields a 204 response.
	SuccessType     types.Type
	// SuccessResponse is the response object derived from SuccessType.
	SuccessResponse *oas.Response
}
   330  
   331  func (operator *Operator) AddNonBodyParameter(parameter *oas.Parameter) {
   332  	if operator.NonBodyParameters == nil {
   333  		operator.NonBodyParameters = map[string]*oas.Parameter{}
   334  	}
   335  	operator.NonBodyParameters[parameter.Name] = parameter
   336  }
   337  
   338  func (operator *Operator) SetRequestBody(requestBody *oas.RequestBody) {
   339  	operator.RequestBody = requestBody
   340  }
   341  
   342  func (operator *Operator) BindOperation(method string, operation *oas.Operation, last bool) {
   343  	parameterNames := map[string]bool{}
   344  	for _, parameter := range operation.Parameters {
   345  		parameterNames[parameter.Name] = true
   346  	}
   347  
   348  	for _, parameter := range operator.NonBodyParameters {
   349  		if !parameterNames[parameter.Name] {
   350  			operation.Parameters = append(operation.Parameters, parameter)
   351  		}
   352  	}
   353  
   354  	if operator.RequestBody != nil {
   355  		operation.SetRequestBody(operator.RequestBody)
   356  	}
   357  
   358  	for _, statusError := range operator.StatusErrors {
   359  		statusErrorList := make([]string, 0)
   360  
   361  		if operation.Responses.Responses != nil {
   362  			if resp, ok := operation.Responses.Responses[statusError.Status()]; ok {
   363  				if resp.Extensions != nil {
   364  					if v, ok := resp.Extensions[XStatusErrs]; ok {
   365  						if list, ok := v.([]string); ok {
   366  							statusErrorList = append(statusErrorList, list...)
   367  						}
   368  					}
   369  				}
   370  			}
   371  		}
   372  
   373  		statusErrorList = append(statusErrorList, statusError.String())
   374  
   375  		sort.Strings(statusErrorList)
   376  
   377  		resp := oas.NewResponse("")
   378  		resp.AddExtension(XStatusErrs, statusErrorList)
   379  		resp.AddContent(httpx.MIME_JSON, oas.NewMediaTypeWithSchema(operator.StatusErrorSchema))
   380  		operation.AddResponse(statusError.Status(), resp)
   381  	}
   382  
   383  	if last {
   384  		operation.OperationId = operator.ID
   385  		operation.Deprecated = operator.Deprecated
   386  		operation.Summary = operator.Summary
   387  		operation.Description = operator.Description
   388  
   389  		if operator.Tag != "" {
   390  			operation.Tags = []string{operator.Tag}
   391  		}
   392  
   393  		if operator.SuccessType == nil {
   394  			operation.Responses.AddResponse(http.StatusNoContent, &oas.Response{})
   395  		} else {
   396  			status := operator.SuccessStatus
   397  			if status == 0 {
   398  				status = http.StatusOK
   399  				if method == http.MethodPost {
   400  					status = http.StatusCreated
   401  				}
   402  			}
   403  			if status >= http.StatusMultipleChoices && status < http.StatusBadRequest {
   404  				operator.SuccessResponse = oas.NewResponse(operator.SuccessResponse.Description)
   405  			}
   406  			operation.Responses.AddResponse(status, operator.SuccessResponse)
   407  		}
   408  	}
   409  
   410  	// sort all parameters by postion and name
   411  	if len(operation.Parameters) > 0 {
   412  		sort.Slice(operation.Parameters, func(i, j int) bool {
   413  			return positionOrders[operation.Parameters[i].In]+operation.Parameters[i].Name <
   414  				positionOrders[operation.Parameters[j].In]+operation.Parameters[j].Name
   415  		})
   416  	}
   417  }