github.com/profzone/eden-framework@v1.0.10/internal/generator/openapi_generator.go (about)

     1  package generator
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"github.com/go-courier/oas"
     7  	"github.com/profzone/eden-framework/internal/generator/scanner"
     8  	"github.com/profzone/eden-framework/pkg/packagex"
     9  	"github.com/sirupsen/logrus"
    10  	"go/ast"
    11  	"go/types"
    12  	"os"
    13  	"path"
    14  	"regexp"
    15  	"strings"
    16  )
    17  
// OpenApiGenerator assembles an OpenAPI document from the routes of a Go
// application. In this file the workflow is: NewOpenApiGenerator, Load (loads
// the entry package), Pick (collects operations) and Output (serialises the
// document).
type OpenApiGenerator struct {
	// api is the OpenAPI document being built; operations are added by Pick.
	api           *oas.OpenAPI
	// pkg is the entry package loaded from <cwd>/cmd by Load.
	pkg           *packagex.Package
	// routerScanner resolves router variables in pkg into routes; set by Load.
	routerScanner *scanner.RouterScanner
}
    23  
    24  func NewOpenApiGenerator() *OpenApiGenerator {
    25  	return &OpenApiGenerator{
    26  		api: oas.NewOpenAPI(),
    27  	}
    28  }
    29  
    30  func (a *OpenApiGenerator) Load(cwd string) {
    31  	entryPath := path.Join(cwd, "cmd")
    32  	_, err := os.Stat(entryPath)
    33  	if err != nil {
    34  		if !os.IsExist(err) {
    35  			logrus.Panicf("entry path does not exist: %s", entryPath)
    36  		}
    37  	}
    38  	pkg, err := packagex.Load(entryPath)
    39  	if err != nil {
    40  		logrus.Panic(err)
    41  	}
    42  
    43  	a.pkg = pkg
    44  	a.routerScanner = scanner.NewRouterScanner(pkg)
    45  }
    46  
    47  func (a *OpenApiGenerator) Pick() {
    48  	defer func() {
    49  		a.routerScanner.OperatorScanner().BindSchemas(a.api)
    50  	}()
    51  
    52  	var routerVar = findRootRouter(a.pkg)
    53  	if routerVar == nil {
    54  		return
    55  	}
    56  
    57  	router := a.routerScanner.Router(routerVar)
    58  	routes := router.Routes()
    59  	operationIDs := map[string]*scanner.Route{}
    60  	for _, r := range routes {
    61  		method := r.Method()
    62  		operation := a.OperationByOperatorTypes(method, r.Operators...)
    63  		if _, exists := operationIDs[operation.OperationId]; exists {
    64  			panic(fmt.Errorf("operationID %s should be unique", operation.OperationId))
    65  		}
    66  		operationIDs[operation.OperationId] = r
    67  		a.api.AddOperation(oas.HttpMethod(strings.ToLower(method)), a.patchPath(r.Path(), operation), operation)
    68  	}
    69  }
    70  
    71  func (a *OpenApiGenerator) OperationByOperatorTypes(method string, operatorTypes ...*scanner.OperatorWithTypeName) *oas.Operation {
    72  	operation := &oas.Operation{}
    73  
    74  	length := len(operatorTypes)
    75  
    76  	for idx := range operatorTypes {
    77  		operatorTypes[idx].BindOperation(method, operation, idx == length-1)
    78  	}
    79  
    80  	return operation
    81  }
    82  
    83  var reHttpRouterPath = regexp.MustCompile("/:([^/]+)")
    84  
    85  func (a *OpenApiGenerator) patchPath(openapiPath string, operation *oas.Operation) string {
    86  	return reHttpRouterPath.ReplaceAllStringFunc(openapiPath, func(str string) string {
    87  		name := reHttpRouterPath.FindAllStringSubmatch(str, -1)[0][1]
    88  
    89  		var isParameterDefined = false
    90  
    91  		for _, parameter := range operation.Parameters {
    92  			if parameter.In == "path" && parameter.Name == name {
    93  				isParameterDefined = true
    94  			}
    95  		}
    96  
    97  		if isParameterDefined {
    98  			return "/{" + name + "}"
    99  		}
   100  
   101  		return "/0"
   102  	})
   103  }
   104  
   105  func (a *OpenApiGenerator) Output(outputPath string) Outputs {
   106  	data, err := json.MarshalIndent(a.api, "", "    ")
   107  	if err != nil {
   108  		logrus.Panic(err)
   109  	}
   110  	return Outputs{
   111  		path.Join(outputPath, "openapi.json"): string(data),
   112  	}
   113  }
   114  
   115  func runnerFunc(node ast.Node) (runner *ast.FuncDecl) {
   116  	switch n := node.(type) {
   117  	case *ast.CallExpr:
   118  		if len(n.Args) > 0 {
   119  			if selectorExpr, ok := n.Fun.(*ast.SelectorExpr); ok {
   120  				if selectorExpr.Sel.Name == "NewApplication" {
   121  					switch node := n.Args[0].(type) {
   122  					case *ast.SelectorExpr:
   123  						runner = node.Sel.Obj.Decl.(*ast.FuncDecl)
   124  					case *ast.Ident:
   125  						runner = node.Obj.Decl.(*ast.FuncDecl)
   126  					case *ast.FuncLit:
   127  						funcDec := &ast.FuncDecl{
   128  							Doc:  nil,
   129  							Recv: nil,
   130  							Name: nil,
   131  							Type: node.Type,
   132  							Body: node.Body,
   133  						}
   134  						runner = funcDec
   135  					}
   136  					return
   137  				}
   138  			}
   139  		}
   140  	}
   141  	return nil
   142  }
   143  
   144  func rootRouter(node ast.Node, p *packagex.Package) *types.Var {
   145  	switch n := node.(type) {
   146  	case *ast.CallExpr:
   147  		if len(n.Args) > 0 {
   148  			if selectorExpr, ok := n.Fun.(*ast.SelectorExpr); ok {
   149  				if selectorExpr.Sel.Name == "Serve" {
   150  					switch node := n.Args[0].(type) {
   151  					case *ast.SelectorExpr:
   152  						return p.TypesInfo.ObjectOf(node.Sel).(*types.Var)
   153  					case *ast.Ident:
   154  						return p.TypesInfo.ObjectOf(node).(*types.Var)
   155  					}
   156  				}
   157  			}
   158  		}
   159  	}
   160  	return nil
   161  }
   162  
   163  func findRootRouter(p *packagex.Package) (router *types.Var) {
   164  	for ident, def := range p.TypesInfo.Defs {
   165  		if typFunc, ok := def.(*types.Func); ok {
   166  			// 搜寻main函数
   167  			if typFunc.Name() != "main" {
   168  				continue
   169  			}
   170  
   171  			// 搜寻runner方法
   172  			var runner *ast.FuncDecl
   173  			ast.Inspect(ident.Obj.Decl.(*ast.FuncDecl), func(node ast.Node) bool {
   174  				runnerDecl := runnerFunc(node)
   175  				if runnerDecl != nil {
   176  					runner = runnerDecl
   177  					return false
   178  				}
   179  				return true
   180  			})
   181  
   182  			// 搜寻router入口
   183  			if runner != nil {
   184  				ast.Inspect(runner, func(node ast.Node) bool {
   185  					if routerVar := rootRouter(node, p); routerVar != nil {
   186  						router = routerVar
   187  						return false
   188  					}
   189  					return true
   190  				})
   191  			}
   192  			return
   193  		}
   194  	}
   195  	return
   196  }